from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional, Union
import pandas as pd
import logging
from data import (
    fetch_fama_french_factors,
    fetch_data,
    fetch_risk_free_rate,
    fetch_risk_free_series,
    build_monthly_returns,
)
from database import get_pg_engine
from core_engine import build_spread_map, load_portfolio_state_dict
from core_types import PortfolioState
from config import AppConfig

logger = logging.getLogger('portfolio_engine')


@dataclass
class DataSnapshot:
    """Bundle of market data, benchmarks and portfolio state built by ``DataRepository.fetch_all``."""

    # Tickers that survived history/overlap filtering, in the caller's input order.
    opt_tickers: List[str]
    # Asset returns at the configured model frequency (daily, or monthly when
    # cfg["return_frequency"] == "monthly").
    opt_returns_df: pd.DataFrame
    # Equity-benchmark returns at the same frequency as opt_returns_df
    # (despite the name, these are daily returns in daily mode).
    bench_rets_monthly: pd.Series
    # Fama-French factors at model frequency; None unless model_id is 4 or 5.
    opt_ff_df: Optional[pd.DataFrame]
    # Risk-free series when one could be fetched, otherwise the scalar rate.
    rfr: Union[float, pd.Series]
    # Raw close-price series of the volatility benchmark, if it was in the DB.
    vol_raw: Optional[pd.Series]
    # Trailing window (at most 6 * trading_days rows) of daily returns.
    display_df: pd.DataFrame
    # Daily benchmark returns aligned to display_df's index (missing dates dropped).
    bench_display: pd.Series
    master_state: PortfolioState
    spread_map: Dict[str, float]
    # Out-of-sample split lengths, in years of trading days.
    train_yrs: float
    test_yrs: float
    # Full-resolution data needed by the pipeline orchestrator
    returns_df: Optional[pd.DataFrame] = None
    bench_rets: Optional[pd.Series] = None
    raw: Optional[Dict[str, pd.Series]] = None
    prices: Optional[Dict[str, float]] = None
    eq_bench: str = "SPY"
    vol_bench: str = "^VIX"
    rfr_bench: str = "^TNX"


class DataRepository:
    """
    Repository layer responsible for fetching, cleaning, and assembling all
    market data, benchmarks, and portfolio state required by the engine.
    """

    def __init__(self, cfg: AppConfig):
        self.cfg = cfg
        self.trading_days = self.cfg.get("trading_days_per_year", 252)

    def fetch_all(self, input_tickers: List[str], model_id: int) -> DataSnapshot:
        """Fetch, clean and assemble every input the engine needs.

        Steps, in order:
        1. Resolve the benchmark tickers from ``cfg["benchmarks"]``.
        2. Optionally fetch Fama-French factors.
        3. Obtain valid tickers from ``fetch_data``.
        4. Refresh the risk-free rate.
        5. Load close prices for the universe plus the macro tickers from the
           ``daily_prices`` table.
        6. Build daily returns.
        7. Convert to the configured model frequency.
        8. Build the portfolio state.

        Args:
            input_tickers: Tickers requested by the caller. Benchmark/macro
                tickers are kept in the optimisation universe only when they
                appear here.
            model_id: Model selector; ids 4 and 5 additionally load
                Fama-French factors.

        Returns:
            A fully populated ``DataSnapshot``.

        Raises:
            SystemExit: if no usable tickers remain after validation, if no
                ticker has at least ``2 * trading_days`` return observations,
                or if the surviving tickers share no common dates.

        Side effects:
            Writes ``cfg["risk_free_rate"]`` and ``cfg["_trading_periods"]``.
        """
        b = self.cfg.get("benchmarks", {})
        eq_bench = b.get("equity", "SPY")
        vol_bench = b.get("volatility", "^VIX")
        rfr_bench = b.get("risk_free", "^TNX")

        ff_df = fetch_fama_french_factors() if model_id in (4, 5) else None
        valid_tickers = fetch_data(input_tickers, b)

        # The previously configured rate (default 0.05) is passed through;
        # presumably it is the fallback if the fetch fails — confirm in data.py.
        self.cfg["risk_free_rate"] = fetch_risk_free_rate(rfr_bench, self.cfg.get("risk_free_rate", 0.05))
        rfr_series = fetch_risk_free_series(rfr_bench)
        final_rfr = rfr_series if not rfr_series.empty else self.cfg["risk_free_rate"]

        macro_tickers = {eq_bench, vol_bench, rfr_bench, "^IRX"}
        # Macro/benchmark series are excluded from the optimisation universe
        # unless the caller explicitly requested them.
        opt_tickers = [t for t in valid_tickers if t not in macro_tickers or t in input_tickers]
        if not opt_tickers:
            raise SystemExit("No Usable Tickers found.")

        legacy_state_dict = load_portfolio_state_dict() if self.cfg.get('_use_saved_basis', False) else {}

        # De-duplicate while preserving order (a ticker may be both requested and a macro series).
        all_tks = list(dict.fromkeys(opt_tickers + list(macro_tickers)))
        raw, prices = self._load_price_history(all_tks)

        bench_rets = raw[eq_bench].pct_change().dropna() if eq_bench in raw else pd.Series(dtype=float)
        vol_raw = raw.get(vol_bench, None)

        returns_df = self._build_daily_returns(raw, opt_tickers)
        final_opt_tickers = [t for t in opt_tickers if t in returns_df.columns]
        spread_map = build_spread_map(final_opt_tickers, self.cfg.get("sector_map", {}))

        # Must run before PortfolioState.build, which receives self.cfg and
        # may read the '_trading_periods' value set here.
        opt_returns_df, bench_rets_monthly, opt_ff_df = self._to_model_frequency(
            returns_df, bench_rets, ff_df, eq_bench
        )

        display_days = self.trading_days * 6
        display_df = returns_df.iloc[-display_days:] if len(returns_df) > display_days else returns_df
        bench_display = bench_rets.reindex(display_df.index).dropna()

        master_state = PortfolioState.build(list(returns_df.columns), prices, legacy_state_dict, self.cfg)
        train_yrs, test_yrs = self._oos_split_years(len(returns_df))

        return DataSnapshot(
            opt_tickers=final_opt_tickers,
            opt_returns_df=opt_returns_df,
            bench_rets_monthly=bench_rets_monthly,
            opt_ff_df=opt_ff_df,
            rfr=final_rfr,
            vol_raw=vol_raw,
            display_df=display_df,
            bench_display=bench_display,
            master_state=master_state,
            spread_map=spread_map,
            train_yrs=train_yrs,
            test_yrs=test_yrs,
            returns_df=returns_df,
            bench_rets=bench_rets,
            raw=raw,
            prices=prices,
            eq_bench=eq_bench,
            vol_bench=vol_bench,
            rfr_bench=rfr_bench,
        )

    @staticmethod
    def _load_price_history(tickers: List[str]) -> Tuple[Dict[str, pd.Series], Dict[str, float]]:
        """Read close prices for ``tickers`` from ``daily_prices``.

        Returns ``(raw, prices)``:
        - ``raw`` maps ticker -> date-indexed close-price series, with
          duplicate dates collapsed to the last row.
        - ``prices`` maps ticker -> the last close price.

        A DB failure is logged and whatever was loaded so far is returned;
        this is best-effort, and callers fail later if nothing usable loaded.
        """
        raw: Dict[str, pd.Series] = {}
        prices: Dict[str, float] = {}
        pg_engine = get_pg_engine()
        if not tickers:
            return raw, prices

        # Tickers are bound as parameters, never interpolated into the SQL;
        # only the placeholder style differs between sqlite and other drivers.
        placeholder = "?" if pg_engine.name == 'sqlite' else "%s"
        ph = ",".join([placeholder] * len(tickers))
        query = f"SELECT ticker, date, close_price FROM daily_prices WHERE ticker IN ({ph}) ORDER BY date"
        try:
            df_all = pd.read_sql(query, pg_engine, params=tuple(tickers))
            if not df_all.empty:
                df_all['date'] = pd.to_datetime(df_all['date'])
                # ORDER BY date plus groupby's stable within-group order means
                # each group is chronological, so iloc[-1] is the latest close.
                for t, group in df_all.groupby('ticker'):
                    group = group.drop_duplicates(subset=['date'], keep='last')
                    prices[t] = group.iloc[-1]['close_price']
                    raw[t] = group.set_index('date')['close_price']
        except Exception as e:
            logger.warning("DB read failed, returning empty context: %s", e)
        return raw, prices

    def _build_daily_returns(self, raw: Dict[str, pd.Series], opt_tickers: List[str]) -> pd.DataFrame:
        """Build the aligned daily-return frame for the optimisation universe.

        Tickers need at least ``2 * trading_days`` return observations.
        Delisted assets are zero-padded, and only dates common to every
        surviving ticker are kept.

        Raises:
            SystemExit: if no ticker has enough history, or the survivors
                share no common dates.
        """
        min_days = self.trading_days * 2
        all_rets = {}
        for t in opt_tickers:
            if t not in raw:
                continue
            s = raw[t].pct_change().dropna(how='all')
            if len(s) >= min_days:
                all_rets[t] = s
        if not all_rets:
            raise SystemExit("No usable tickers with enough history.")

        returns_df = self._pad_delisted(pd.DataFrame(all_rets)).dropna()
        # Tickers with disjoint date ranges can leave nothing after the
        # intersection; fail here rather than hand back an empty universe.
        if returns_df.empty:
            raise SystemExit("No usable tickers with overlapping history.")
        return returns_df

    @staticmethod
    def _pad_delisted(returns_df: pd.DataFrame, threshold: float = -0.99) -> pd.DataFrame:
        """Zero-fill missing returns from each asset's first wipe-out return onward.

        A return ``<= threshold`` is treated as a delisting. The gaps after it
        are filled with 0.0 so the later row-wise ``dropna`` keeps those dates
        instead of truncating the whole frame at the delisting date.
        Modifies ``returns_df`` in place and returns it.
        """
        for col in returns_df.columns:
            series = returns_df[col]
            dead_indices = series.index[series <= threshold]
            if len(dead_indices) == 0:
                continue
            dead_idx = dead_indices[0]
            returns_df.loc[dead_idx:, col] = returns_df.loc[dead_idx:, col].fillna(0.0)
        return returns_df

    def _to_model_frequency(
        self,
        returns_df: pd.DataFrame,
        bench_rets: pd.Series,
        ff_df: Optional[pd.DataFrame],
        eq_bench: str,
    ) -> Tuple[pd.DataFrame, pd.Series, Optional[pd.DataFrame]]:
        """Return (asset returns, benchmark returns, factors) at the configured frequency.

        Also records the periods-per-year in ``cfg['_trading_periods']``:
        12 in monthly mode, ``trading_days`` otherwise.
        """
        if self.cfg.get('return_frequency', 'daily') == 'monthly':
            opt_returns_df = build_monthly_returns(returns_df)
            if bench_rets.empty:
                bench_out = bench_rets
            else:
                bench_out = build_monthly_returns(pd.DataFrame({eq_bench: bench_rets}))[eq_bench]
            opt_ff_df = build_monthly_returns(ff_df) if ff_df is not None else None
            self.cfg['_trading_periods'] = 12
            return opt_returns_df, bench_out, opt_ff_df

        self.cfg['_trading_periods'] = self.trading_days
        return returns_df, bench_rets, ff_df

    def _oos_split_years(self, total_days: int) -> Tuple[float, float]:
        """Return (train_yrs, test_yrs) for the out-of-sample split.

        The test window is one year of trading days. Training takes the rest,
        with a floor of 100 days.
        """
        test_days = self.trading_days
        # NOTE(review): with the 100-day floor, train + test can exceed
        # total_days on short histories — confirm this is intended.
        train_days = max(100, total_days - test_days)
        return train_days / self.trading_days, test_days / self.trading_days