# math-backend / database.py
# Uploaded by engineportf using huggingface_hub (commit 558db1e, verified; 6.83 kB).
import os
import sqlite3
import time
import pandas as pd
from typing import Optional
from sqlalchemy import create_engine, Column, String, Float, Date, UniqueConstraint, DateTime, JSON
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.dialects.postgresql import insert
from config import logger, OUTPUT_DIR
# Shared declarative base: every ORM model in this module subclasses it, and
# init_db() creates the tables registered on Base.metadata.
Base = declarative_base()
class DailyPrice(Base):
    """ORM model for the ``daily_prices`` table: one closing price per ticker and date."""

    __tablename__ = 'daily_prices'

    # (ticker, date) together form the composite primary key.
    ticker = Column(String, primary_key=True)
    date = Column(Date, primary_key=True)

    # A price row is only valid with a value, hence NOT NULL.
    close_price = Column(Float, nullable=False)

    # Named uniqueness constraint over the same (ticker, date) pair as the key.
    __table_args__ = (
        UniqueConstraint('ticker', 'date', name='uq_daily_prices_ticker_date'),
    )
class DailyYield(Base):
    """ORM model for the ``daily_yields`` table: one yield percentage per ticker and date."""

    __tablename__ = 'daily_yields'

    # (ticker, date) together form the composite primary key.
    ticker = Column(String, primary_key=True)
    date = Column(Date, primary_key=True)

    # Unlike close_price on daily_prices, this column is nullable.
    yield_pct = Column(Float)

    # Named uniqueness constraint over the same (ticker, date) pair as the key.
    __table_args__ = (
        UniqueConstraint('ticker', 'date', name='uq_daily_yields_ticker_date'),
    )
class StitchMetadata(Base):
    """ORM model for the ``stitch_metadata`` table, keyed by ticker and date.

    Each row records where a data point came from (``source``), which proxy
    was involved (``proxy_used``) and the numeric ``adjustment_factor``.
    """

    __tablename__ = 'stitch_metadata'

    # (ticker, date) together form the composite primary key.
    ticker = Column(String, primary_key=True)
    date = Column(Date, primary_key=True)

    # Expected values: 'direct', 'proxy_stitched', 'synthetic'
    source = Column(String)
    proxy_used = Column(String)
    adjustment_factor = Column(Float)

    # Named uniqueness constraint over the same (ticker, date) pair as the key.
    __table_args__ = (
        UniqueConstraint('ticker', 'date', name='uq_stitch_metadata_ticker_date'),
    )
import uuid
import datetime
class AuditLog(Base):
    """ORM model for the ``audit_log`` table: one row per recorded API request."""

    __tablename__ = 'audit_log'

    # Primary key is a random UUID4 rendered as a string, generated per row.
    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))

    # Only the endpoint is mandatory; every other field may be NULL.
    user_id = Column(String, nullable=True)
    endpoint = Column(String, nullable=False)
    request_hash = Column(String, nullable=True)

    # Request and response payloads are stored as JSON documents.
    request_body = Column(JSON, nullable=True)
    response_weights = Column(JSON, nullable=True)

    # Defaults to the naive UTC time at insert.
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12 - confirm
    # before switching to timezone-aware values, as stored data would change.
    timestamp = Column(DateTime, default=datetime.datetime.utcnow)
    ip_address = Column(String, nullable=True)
class ApiKey(Base):
    """ORM model for the ``api_keys`` table, keyed by the key string itself."""

    __tablename__ = 'api_keys'

    key = Column(String, primary_key=True)

    # Lifetime: creation time defaults to naive UTC "now"; expiry is mandatory.
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.utcnow)
    expires_at = Column(DateTime, nullable=False)

    # Stored as the string "false" by default rather than a Boolean
    # (SQLite boolean compat).
    revoked = Column(String, default="false")

    # Usage tracking; both stay NULL until populated.
    used_at = Column(DateTime, nullable=True)
    used_by_ip = Column(String, nullable=True)
# Module-level cache for the engine built by get_pg_engine(); stays None until
# the first call creates it.
_ENGINE = None
def with_db_retry(max_retries=3):
    """Build a decorator that retries a database call when SQLAlchemy raises.

    The wrapped callable is invoked up to ``max_retries`` times. Any
    ``SQLAlchemyError`` triggers a warning log and an exponentially growing
    sleep (1s, 2s, 4s, ...) before the next attempt; the error raised by the
    final attempt propagates to the caller. Other exception types are never
    retried.

    Args:
        max_retries: Total number of attempts. Values below 1 are treated as
            1 so the wrapped callable always runs at least once (previously
            the call was silently skipped and ``None`` returned).

    Returns:
        A decorator whose wrapper preserves the wrapped callable's metadata
        (``__name__``, ``__doc__``, ...).

    Raises:
        SQLAlchemyError: Re-raised from the last attempt when all fail.
    """
    import functools  # local import keeps this block self-contained

    attempts = max(1, max_retries)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                except SQLAlchemyError as e:
                    # NOTE(review): this catches every SQLAlchemyError, not only
                    # transient ones - confirm whether e.g. integrity errors
                    # should really be retried.
                    if attempt == attempts - 1:
                        raise
                    logger.warning(f"Database operation failed: {e}. Retrying ({attempt + 1}/{max_retries})...")
                    time.sleep(1.0 * (2 ** attempt))  # Exponential backoff
        return wrapper
    return decorator
def get_pg_engine():
    """
    Return the module-wide SQLAlchemy engine, building it on the first call.

    The connection URL is read from the DATABASE_URL environment variable. When it is
    unset or empty, a SQLite file named portfolio_db.sqlite3 under OUTPUT_DIR is used
    instead. Connection-pool options are passed only for non-SQLite URLs.
    """
    global _ENGINE

    if _ENGINE is None:
        url = os.getenv("DATABASE_URL") or f"sqlite:///{os.path.join(OUTPUT_DIR, 'portfolio_db.sqlite3')}"

        engine_options = {"echo": False}
        if not url.startswith("sqlite"):
            # Pool tuning is applied to every URL that is not SQLite.
            engine_options.update(pool_size=10, max_overflow=20, pool_pre_ping=True, pool_recycle=3600)

        _ENGINE = create_engine(url, **engine_options)

    return _ENGINE
def init_db():
    """Create every table registered on Base.metadata that does not exist yet."""
    Base.metadata.create_all(get_pg_engine())
    logger.info("PostgreSQL Database schema initialized.")
def _read_legacy_table(sqlite_conn, table_name, query):
    """Run ``query`` on the legacy SQLite connection; return an empty frame if it fails."""
    try:
        frame = pd.read_sql(query, sqlite_conn)
    except Exception as e:
        # Deliberately broad: a missing or unreadable table is skipped so the
        # other table can still be migrated.
        logger.warning(f"Could not read {table_name} from SQLite: {e}")
        return pd.DataFrame()
    logger.info(f"Extracted {len(frame)} records from SQLite {table_name}.")
    return frame


def _insert_on_conflict_nothing(table, conn, keys, data_iter):
    """pandas ``to_sql`` insertion method: multi-row INSERT ... ON CONFLICT DO NOTHING."""
    # NOTE(review): `insert` is the PostgreSQL-dialect construct, yet
    # get_pg_engine() can fall back to SQLite - confirm this path is only used
    # with a PostgreSQL DATABASE_URL.
    data = [dict(zip(keys, row)) for row in data_iter]
    stmt = insert(table.table).values(data).on_conflict_do_nothing()
    result = conn.execute(stmt)
    return result.rowcount


def _load_table(frame, table_name, engine):
    """Append one extracted frame to ``table_name``; empty frames are skipped."""
    if frame.empty:
        return
    # The legacy dates are parsed and reduced to date objects before loading.
    frame['date'] = pd.to_datetime(frame['date']).dt.date
    frame.to_sql(table_name, engine, if_exists='append', index=False, method=_insert_on_conflict_nothing, chunksize=10000)
    logger.info(f"Successfully migrated {table_name} to PostgreSQL.")


def migrate_sqlite_to_postgres(sqlite_path: Optional[str] = None):
    """
    Copy historical price and yield records from the legacy SQLite database
    into the database returned by get_pg_engine().

    Rows are bulk-inserted in chunks of 10000 with ON CONFLICT DO NOTHING. The
    routine is best-effort: read failures on either legacy table and failures
    during insertion are logged, not raised.

    Args:
        sqlite_path: Path to the legacy SQLite file. Defaults to
            ``finance_data.db`` inside OUTPUT_DIR. If the file does not exist,
            a warning is logged and nothing is migrated.
    """
    from contextlib import closing  # local import keeps this block self-contained

    if sqlite_path is None:
        sqlite_path = os.path.join(OUTPUT_DIR, "finance_data.db")
    if not os.path.exists(sqlite_path):
        logger.warning(f"Legacy SQLite database not found at {sqlite_path}. Nothing to migrate.")
        return
    logger.info(f"Starting migration from SQLite ({sqlite_path}) to PostgreSQL...")

    # 1-2. Connect to SQLite and extract; closing() guarantees the connection
    # is released even if something unexpected escapes the readers.
    with closing(sqlite3.connect(sqlite_path)) as sqlite_conn:
        prices_df = _read_legacy_table(
            sqlite_conn, "daily_prices", "SELECT ticker, date, close_price FROM daily_prices"
        )
        yields_df = _read_legacy_table(
            sqlite_conn, "daily_yields", "SELECT ticker, date, yield_pct FROM daily_yields"
        )

    # 3. Connect to the target database and make sure the schema exists.
    init_db()
    pg_engine = get_pg_engine()

    # 4. Transform and load. to_sql runs directly on the engine, so no ORM
    # session is involved in this step.
    try:
        _load_table(prices_df, 'daily_prices', pg_engine)
        _load_table(yields_df, 'daily_yields', pg_engine)
    except Exception as e:
        logger.error(f"Migration failed during PostgreSQL insertion: {e}")
    logger.info("Migration routine complete.")
if __name__ == "__main__":
    # Script entry point: running this module directly performs the
    # legacy-database migration with the default SQLite path.
    migrate_sqlite_to_postgres()