Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from typing import List, Dict, Tuple, Optional | |
| import pandas as pd | |
| import logging | |
| from data import fetch_fama_french_factors, fetch_data, fetch_risk_free_rate, fetch_risk_free_series, build_monthly_returns | |
| from database import get_pg_engine | |
| from core_engine import build_spread_map, load_portfolio_state_dict | |
| from core_types import PortfolioState | |
| from config import AppConfig | |
| logger = logging.getLogger('portfolio_engine') | |
@dataclass
class DataSnapshot:
    """
    Immutable-by-convention bundle of the data assembled by
    ``DataRepository.fetch_all`` for a single engine run.

    The class is a dataclass so that it can be constructed with keyword
    arguments; the first twelve fields are required, the remaining ones
    carry defaults.
    """

    # Tickers that survived filtering and are present in the returns frame.
    opt_tickers: List[str]
    # Asset returns at the configured frequency (monthly or daily).
    opt_returns_df: pd.DataFrame
    # Equity-benchmark returns at the same frequency as ``opt_returns_df``
    # (daily when the configured frequency is not 'monthly', despite the name).
    bench_rets_monthly: pd.Series
    # Factor data; None when the selected model does not request it.
    opt_ff_df: Optional[pd.DataFrame]
    # Risk-free rate. NOTE(review): ``fetch_all`` passes a Series here when a
    # non-empty risk-free series is available, and a scalar otherwise.
    rfr: float
    # Raw price series of the volatility benchmark; None when not loaded.
    vol_raw: Optional[pd.Series]
    # Trailing window of the daily returns frame.
    display_df: pd.DataFrame
    # Benchmark daily returns reindexed to ``display_df``.
    bench_display: pd.Series
    # Quoted so the annotation is not evaluated at class-definition time.
    master_state: "PortfolioState"
    spread_map: Dict[str, float]
    train_yrs: float
    test_yrs: float
    # Full-resolution data needed by the pipeline orchestrator
    returns_df: Optional[pd.DataFrame] = None
    bench_rets: Optional[pd.Series] = None
    raw: Optional[Dict] = None
    prices: Optional[Dict] = None
    eq_bench: str = "SPY"
    vol_bench: str = "^VIX"
    rfr_bench: str = "^TNX"
class DataRepository:
    """
    Repository layer responsible for fetching, cleaning, and assembling
    all market data, benchmarks, and portfolio state required by the engine.
    """

    def __init__(self, cfg: AppConfig):
        """Store the configuration and cache the trading-days-per-year setting (default 252)."""
        self.cfg = cfg
        self.trading_days = self.cfg.get("trading_days_per_year", 252)

    def fetch_all(self, input_tickers: List[str], model_id: int) -> DataSnapshot:
        """
        Fetch and assemble everything the engine needs into a DataSnapshot.

        Side effects on ``self.cfg``: ``"risk_free_rate"`` is overwritten with
        the fetched value and ``"_trading_periods"`` is set to 12 (monthly
        returns) or to the trading-days setting (otherwise).

        Args:
            input_tickers: Tickers requested by the caller.
            model_id: Model selector; factor data is fetched only for 4 and 5.

        Returns:
            A populated DataSnapshot.

        Raises:
            SystemExit: If no usable tickers remain after filtering, or none
                has at least two trading-years of return history.
        """
        benchmarks = self.cfg.get("benchmarks", {})
        eq_bench = benchmarks.get("equity", "SPY")
        vol_bench = benchmarks.get("volatility", "^VIX")
        rfr_bench = benchmarks.get("risk_free", "^TNX")

        ff_df = fetch_fama_french_factors() if model_id in (4, 5) else None
        valid_tickers = fetch_data(input_tickers, benchmarks)

        self.cfg["risk_free_rate"] = fetch_risk_free_rate(rfr_bench, self.cfg.get("risk_free_rate", 0.05))
        rfr_series = fetch_risk_free_series(rfr_bench)
        # Prefer the full series; fall back to the scalar rate when it is empty.
        final_rfr = rfr_series if not rfr_series.empty else self.cfg["risk_free_rate"]

        macro_tickers = {eq_bench, vol_bench, rfr_bench, "^IRX"}
        requested = set(input_tickers)
        # Macro tickers are excluded unless the caller explicitly asked for them.
        opt_tickers = [t for t in valid_tickers if t not in macro_tickers or t in requested]
        if not opt_tickers:
            raise SystemExit("No Usable Tickers found.")

        legacy_state_dict = load_portfolio_state_dict() if self.cfg.get('_use_saved_basis', False) else {}

        # De-duplicate while keeping first-seen order.
        all_tks = list(dict.fromkeys(opt_tickers + list(macro_tickers)))
        raw, prices = self._load_price_history(all_tks)

        bench_rets = raw[eq_bench].pct_change().dropna() if eq_bench in raw else pd.Series(dtype=float)
        vol_raw = raw.get(vol_bench, None)

        returns_df = self._build_returns_frame(opt_tickers, raw)
        final_opt_tickers = [t for t in opt_tickers if t in returns_df.columns]
        spread_map = build_spread_map(final_opt_tickers, self.cfg.get("sector_map", {}))

        if self.cfg.get('return_frequency', 'daily') == 'monthly':
            opt_returns_df = build_monthly_returns(returns_df)
            if not bench_rets.empty:
                bench_rets_monthly = build_monthly_returns(pd.DataFrame({eq_bench: bench_rets}))[eq_bench]
            else:
                bench_rets_monthly = bench_rets
            opt_ff_df = build_monthly_returns(ff_df) if ff_df is not None else None
            self.cfg['_trading_periods'] = 12
        else:
            opt_returns_df, bench_rets_monthly, opt_ff_df = returns_df, bench_rets, ff_df
            self.cfg['_trading_periods'] = self.trading_days

        # Display window: the trailing six trading-years of daily returns.
        display_days = self.trading_days * 6
        display_df = returns_df.iloc[-display_days:] if len(returns_df) > display_days else returns_df
        bench_display = bench_rets.reindex(display_df.index).dropna()

        final_tickers = list(returns_df.columns)
        master_state = PortfolioState.build(final_tickers, prices, legacy_state_dict, self.cfg)

        # Out-of-sample split: one trading-year of test data, the rest
        # (but never fewer than 100 rows) for training.
        oos_test_days = self.trading_days
        oos_train_days = max(100, len(returns_df) - oos_test_days)
        train_yrs = oos_train_days / self.trading_days
        test_yrs = oos_test_days / self.trading_days

        return DataSnapshot(
            opt_tickers=final_opt_tickers,
            opt_returns_df=opt_returns_df,
            bench_rets_monthly=bench_rets_monthly,
            opt_ff_df=opt_ff_df,
            rfr=final_rfr,
            vol_raw=vol_raw,
            display_df=display_df,
            bench_display=bench_display,
            master_state=master_state,
            spread_map=spread_map,
            train_yrs=train_yrs,
            test_yrs=test_yrs,
            returns_df=returns_df,
            bench_rets=bench_rets,
            raw=raw,
            prices=prices,
            eq_bench=eq_bench,
            vol_bench=vol_bench,
            rfr_bench=rfr_bench,
        )

    def _load_price_history(self, tickers: List[str]) -> Tuple[Dict, Dict]:
        """
        Read close prices for ``tickers`` from the ``daily_prices`` table.

        Returns:
            ``(raw, prices)`` where ``raw`` maps ticker -> date-indexed close
            price Series and ``prices`` maps ticker -> last close price read.
            A failed read is logged and yields whatever was collected so far
            (possibly two empty dicts) instead of raising.
        """
        pg_engine = get_pg_engine()
        raw, prices = {}, {}
        if not tickers:
            return raw, prices

        # Only placeholder markers are interpolated into the SQL text; the
        # ticker values themselves are passed as bound parameters.
        marker = "?" if pg_engine.name == 'sqlite' else "%s"
        ph = ",".join([marker] * len(tickers))
        query = f"SELECT ticker, date, close_price FROM daily_prices WHERE ticker IN ({ph}) ORDER BY date"
        try:
            df_all = pd.read_sql(query, pg_engine, params=tuple(tickers))
            if not df_all.empty:
                df_all['date'] = pd.to_datetime(df_all['date'])
                for ticker, group in df_all.groupby('ticker'):
                    group = group.drop_duplicates(subset=['date'], keep='last')
                    # Rows arrive ordered by date, so the last row is the latest.
                    prices[ticker] = group.iloc[-1]['close_price']
                    raw[ticker] = group.set_index('date')['close_price']
        except Exception as e:
            # Deliberate best-effort: callers continue with partial/empty data.
            logger.warning("DB read failed, returning empty context: %s", e)
        return raw, prices

    def _build_returns_frame(self, opt_tickers: List[str], raw: Dict) -> pd.DataFrame:
        """
        Build the aligned daily-returns frame for the tickers in ``opt_tickers``.

        Tickers with fewer than two trading-years of returns are skipped.
        Rows containing any remaining NaN are dropped.

        Raises:
            SystemExit: If no ticker has enough history.
        """
        min_days_short = self.trading_days * 2
        all_rets = {}
        for t in opt_tickers:
            if t not in raw:
                continue
            s = raw[t].pct_change().dropna(how='all')
            if len(s) >= min_days_short:
                all_rets[t] = s
        if not all_rets:
            raise SystemExit("No usable tickers with enough history.")

        returns_df = pd.DataFrame(all_rets)

        # Pad delisted assets: from the first return of -99% or worse onward,
        # missing returns are filled with 0.0 so the column survives dropna().
        for col in returns_df.columns:
            dead_indices = returns_df[col][returns_df[col] <= -0.99].index
            if len(dead_indices) > 0:
                dead_idx = dead_indices[0]
                returns_df.loc[dead_idx:, col] = returns_df.loc[dead_idx:, col].fillna(0.0)

        returns_df = returns_df.dropna()
        if returns_df.empty:
            # Behaviour is unchanged (an empty frame is returned); the log
            # makes the otherwise silent condition visible.
            logger.warning("No overlapping dates remain after aligning return histories.")
        return returns_df