math-backend / data_repository.py
engineportf's picture
Upload folder using huggingface_hub
558db1e verified
Raw
History Blame Contribute Delete
6.55 kB
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import pandas as pd

from data import fetch_fama_french_factors, fetch_data, fetch_risk_free_rate, fetch_risk_free_series, build_monthly_returns
from database import get_pg_engine
from core_engine import build_spread_map, load_portfolio_state_dict
from core_types import PortfolioState
from config import AppConfig
# Module logger. It uses the fixed name "portfolio_engine" rather than
# __name__; presumably shared with other engine modules — confirm before renaming.
logger: logging.Logger = logging.getLogger("portfolio_engine")
@dataclass
class DataSnapshot:
    """Bundle of market data and portfolio state produced by DataRepository.fetch_all.

    The fields without defaults are the model-facing inputs. The defaulted
    fields carry the full-resolution data needed by the pipeline orchestrator.
    """

    # Tickers kept for optimisation after history/alignment filtering.
    opt_tickers: List[str]
    # Asset returns at the model frequency (daily, or monthly when configured).
    opt_returns_df: pd.DataFrame
    # Equity-benchmark returns at the model frequency; despite the name these
    # are daily unless monthly frequency is configured.
    bench_rets_monthly: pd.Series
    # Fama-French factors at the model frequency; None when not loaded.
    opt_ff_df: Optional[pd.DataFrame]
    # Risk-free rate: the fetched per-date series when it is non-empty,
    # otherwise the scalar rate stored in the config.
    rfr: Union[float, pd.Series]
    # Raw price series of the volatility benchmark; None if it was not loaded.
    vol_raw: Optional[pd.Series]
    # Most recent window of aligned daily asset returns, for display.
    display_df: pd.DataFrame
    # Benchmark daily returns restricted to display_df's dates.
    bench_display: pd.Series
    # Portfolio state built for the final ticker set.
    master_state: PortfolioState
    # Per-ticker values from build_spread_map (presumably trading spreads — confirm).
    spread_map: Dict[str, float]
    # Lengths, in years, of the out-of-sample training and test windows.
    train_yrs: float
    test_yrs: float
    # Full-resolution data needed by the pipeline orchestrator
    # Aligned daily asset returns.
    returns_df: Optional[pd.DataFrame] = None
    # Daily equity-benchmark returns.
    bench_rets: Optional[pd.Series] = None
    # Ticker -> date-indexed close-price series.
    raw: Optional[Dict[str, pd.Series]] = None
    # Ticker -> latest close price.
    prices: Optional[Dict[str, float]] = None
    # Benchmark tickers used for equity, volatility and the risk-free rate.
    eq_bench: str = "SPY"
    vol_bench: str = "^VIX"
    rfr_bench: str = "^TNX"
class DataRepository:
    """
    Repository layer responsible for fetching, cleaning, and assembling
    all market data, benchmarks, and portfolio state required by the engine.
    """

    def __init__(self, cfg: AppConfig):
        """Bind the repository to ``cfg``.

        ``cfg`` is read with ``.get`` and written with item assignment by
        ``fetch_all``. ``trading_days_per_year`` (default 252) sets the
        periods-per-year used for history thresholds and window sizes.
        """
        self.cfg = cfg
        self.trading_days = self.cfg.get("trading_days_per_year", 252)

    def fetch_all(self, input_tickers: List[str], model_id: int) -> DataSnapshot:
        """Fetch, clean and assemble all data for one engine run.

        Args:
            input_tickers: Tickers requested by the caller.
            model_id: Model selector; models 4 and 5 also load Fama-French factors.

        Returns:
            A populated DataSnapshot.

        Raises:
            SystemExit: if no usable ticker remains, none has at least two
                years of return history, or the surviving return series share
                no common dates.

        Side effects:
            Writes ``risk_free_rate`` and ``_trading_periods`` into ``self.cfg``.
        """
        b = self.cfg.get("benchmarks", {})
        eq_bench = b.get("equity", "SPY")
        vol_bench = b.get("volatility", "^VIX")
        rfr_bench = b.get("risk_free", "^TNX")

        ff_df = fetch_fama_french_factors() if model_id in (4, 5) else None
        valid_tickers = fetch_data(input_tickers, b)

        # The scalar rate is cached back into cfg (configured value as fallback)
        # and is only used when the full rate series is unavailable.
        self.cfg["risk_free_rate"] = fetch_risk_free_rate(
            rfr_bench, self.cfg.get("risk_free_rate", 0.05)
        )
        rfr_series = fetch_risk_free_series(rfr_bench)
        final_rfr = rfr_series if not rfr_series.empty else self.cfg["risk_free_rate"]

        # Ordered list keeps the SQL parameter order deterministic; the set is
        # used for membership tests.
        macro_list = list(dict.fromkeys([eq_bench, vol_bench, rfr_bench, "^IRX"]))
        macro_tickers = set(macro_list)
        requested = set(input_tickers)
        # Benchmark/macro series are optimised only when explicitly requested.
        opt_tickers = [
            t for t in valid_tickers if t not in macro_tickers or t in requested
        ]
        if not opt_tickers:
            raise SystemExit("No Usable Tickers found.")

        legacy_state_dict = (
            load_portfolio_state_dict() if self.cfg.get('_use_saved_basis', False) else {}
        )

        all_tks = list(dict.fromkeys(opt_tickers + macro_list))
        raw, prices = self._load_price_history(all_tks)

        bench_rets = (
            raw[eq_bench].pct_change().dropna()
            if eq_bench in raw
            else pd.Series(dtype=float)
        )
        vol_raw = raw.get(vol_bench)

        returns_df = self._build_aligned_returns(raw, opt_tickers)
        final_opt_tickers = [t for t in opt_tickers if t in returns_df.columns]
        spread_map = build_spread_map(final_opt_tickers, self.cfg.get("sector_map", {}))

        # Model inputs at the configured frequency; `_trading_periods` records
        # how many periods make up one year at that frequency.
        if self.cfg.get('return_frequency', 'daily') == 'monthly':
            opt_returns_df = build_monthly_returns(returns_df)
            if bench_rets.empty:
                bench_rets_monthly = bench_rets
            else:
                bench_rets_monthly = build_monthly_returns(
                    pd.DataFrame({eq_bench: bench_rets})
                )[eq_bench]
            opt_ff_df = build_monthly_returns(ff_df) if ff_df is not None else None
            self.cfg['_trading_periods'] = 12
        else:
            opt_returns_df, bench_rets_monthly, opt_ff_df = returns_df, bench_rets, ff_df
            self.cfg['_trading_periods'] = self.trading_days

        # Display window: the last six years of aligned daily returns
        # (iloc[-n:] simply returns everything when fewer rows exist).
        display_days = self.trading_days * 6
        display_df = returns_df.iloc[-display_days:]
        bench_display = bench_rets.reindex(display_df.index).dropna()

        # Built after `_trading_periods` is set, since cfg is passed through.
        final_tickers = list(returns_df.columns)
        master_state = PortfolioState.build(final_tickers, prices, legacy_state_dict, self.cfg)

        # Out-of-sample split: one year of test rows; the training length is
        # the remainder, floored at 100 rows.
        # NOTE(review): when len(returns_df) < trading_days + 100 the train and
        # test windows overlap — confirm that is acceptable downstream.
        oos_test_days = self.trading_days
        oos_train_days = max(100, len(returns_df) - oos_test_days)
        train_yrs = oos_train_days / self.trading_days
        test_yrs = oos_test_days / self.trading_days

        return DataSnapshot(
            opt_tickers=final_opt_tickers,
            opt_returns_df=opt_returns_df,
            bench_rets_monthly=bench_rets_monthly,
            opt_ff_df=opt_ff_df,
            rfr=final_rfr,
            vol_raw=vol_raw,
            display_df=display_df,
            bench_display=bench_display,
            master_state=master_state,
            spread_map=spread_map,
            train_yrs=train_yrs,
            test_yrs=test_yrs,
            returns_df=returns_df,
            bench_rets=bench_rets,
            raw=raw,
            prices=prices,
            eq_bench=eq_bench,
            vol_bench=vol_bench,
            rfr_bench=rfr_bench,
        )

    @staticmethod
    def _load_price_history(
        tickers: List[str],
    ) -> Tuple[Dict[str, pd.Series], Dict[str, float]]:
        """Read close prices for ``tickers`` from the ``daily_prices`` table.

        Returns ``(raw, prices)``:
        - ``raw`` maps each ticker found to its date-indexed close series.
          Duplicate dates are collapsed, keeping the last row.
        - ``prices`` maps each ticker to its latest close.

        Tickers with no rows are absent from both. A failure while reading or
        parsing is logged as a warning and does not propagate (best-effort).
        """
        pg_engine = get_pg_engine()
        raw: Dict[str, pd.Series] = {}
        prices: Dict[str, float] = {}
        if not tickers:
            return raw, prices

        # sqlite's driver uses '?' placeholders; other engines here get '%s'.
        # Only placeholders are interpolated; ticker values are bound as params.
        marker = "?" if pg_engine.name == 'sqlite' else "%s"
        placeholders = ",".join([marker] * len(tickers))
        query = (
            "SELECT ticker, date, close_price FROM daily_prices "
            f"WHERE ticker IN ({placeholders}) ORDER BY date"
        )
        try:
            df_all = pd.read_sql(query, pg_engine, params=tuple(tickers))
            if not df_all.empty:
                df_all['date'] = pd.to_datetime(df_all['date'])
                # Rows arrive ordered by date and groupby preserves row order
                # within each group, so iloc[-1] is the latest close.
                for t, group in df_all.groupby('ticker'):
                    group = group.drop_duplicates(subset=['date'], keep='last')
                    prices[t] = group.iloc[-1]['close_price']
                    raw[t] = group.set_index('date')['close_price']
        except Exception as e:
            logger.warning(f"DB read failed, returning empty context: {e}")
        return raw, prices

    def _build_aligned_returns(
        self, raw: Dict[str, pd.Series], tickers: List[str]
    ) -> pd.DataFrame:
        """Build the date-aligned panel of simple (pct_change) returns for ``tickers``.

        Tickers are dropped when they have no price data or fewer than two
        years (``2 * trading_days``) of returns. Delisted assets are
        zero-padded via ``_pad_delisted``. Rows that still contain a NaN are
        then removed, so the result only keeps dates shared by all columns.

        Raises:
            SystemExit: if no ticker has enough history, or the surviving
                series share no common dates.
        """
        min_days = self.trading_days * 2
        all_rets = {}
        for t in tickers:
            series = raw.get(t)
            if series is None:
                continue
            rets = series.pct_change().dropna()
            if len(rets) >= min_days:
                all_rets[t] = rets
        if not all_rets:
            raise SystemExit("No usable tickers with enough history.")

        returns_df = self._pad_delisted(pd.DataFrame(all_rets)).dropna()
        if returns_df.empty:
            # Without this guard an empty panel would flow into display/OOS sizing.
            raise SystemExit("No overlapping return history across usable tickers.")
        return returns_df

    @staticmethod
    def _pad_delisted(returns_df: pd.DataFrame, threshold: float = -0.99) -> pd.DataFrame:
        """Zero-fill each column's NaNs from its first return <= ``threshold`` onward.

        A return at or below ``threshold`` (default -99%) marks the asset as
        wiped out/delisted. Filling its later gaps with 0.0 stops the following
        row-wise ``dropna`` from discarding those dates for every other asset.
        Modifies ``returns_df`` in place and returns it.
        """
        for col in returns_df.columns:
            hits = returns_df.index[returns_df[col] <= threshold]
            if len(hits) == 0:
                continue
            start = hits[0]
            returns_df.loc[start:, col] = returns_df.loc[start:, col].fillna(0.0)
        return returns_df