File size: 6,550 Bytes
558db1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from dataclasses import dataclass
import logging
from typing import List, Dict, Tuple, Optional, Union

import pandas as pd

from data import fetch_fama_french_factors, fetch_data, fetch_risk_free_rate, fetch_risk_free_series, build_monthly_returns
from database import get_pg_engine
from core_engine import build_spread_map, load_portfolio_state_dict
from core_types import PortfolioState
from config import AppConfig

# Module logger under the fixed name 'portfolio_engine' (not __name__);
# presumably shared with other engine modules -- confirm before renaming.
logger = logging.getLogger('portfolio_engine')

@dataclass
class DataSnapshot:
    """Market data, benchmarks and portfolio state assembled for one engine run.

    The non-defaulted fields hold what the optimiser consumes (returns possibly
    converted to monthly frequency). The defaulted fields carry the
    full-resolution inputs used by the pipeline orchestrator. They are
    ``None`` when a caller builds a snapshot without them.
    """
    # Optimisation universe: tickers that survived history filtering.
    opt_tickers: List[str]
    # Asset returns at the configured frequency (daily, or monthly).
    opt_returns_df: pd.DataFrame
    # Equity-benchmark returns at the same frequency as opt_returns_df.
    # Despite the name, these are daily returns unless monthly mode is on.
    bench_rets_monthly: pd.Series
    # Fama-French factors at the configured frequency, or None if not loaded.
    opt_ff_df: Optional[pd.DataFrame]
    # Risk-free rate: a full series when one could be fetched, else a scalar.
    rfr: Union[float, pd.Series]
    # Raw volatility-benchmark levels (prices, not returns); None if unavailable.
    vol_raw: Optional[pd.Series]
    # Trailing window of daily asset returns used for display.
    display_df: pd.DataFrame
    # Benchmark daily returns aligned to display_df's dates.
    bench_display: pd.Series
    master_state: PortfolioState
    spread_map: Dict[str, float]
    # Lengths (in years) of the out-of-sample train and test windows.
    train_yrs: float
    test_yrs: float
    # Full-resolution data needed by the pipeline orchestrator
    returns_df: Optional[pd.DataFrame] = None
    bench_rets: Optional[pd.Series] = None
    # ticker -> date-indexed close-price series.
    raw: Optional[Dict[str, pd.Series]] = None
    # ticker -> most recent close price.
    prices: Optional[Dict[str, float]] = None
    eq_bench: str = "SPY"
    vol_bench: str = "^VIX"
    rfr_bench: str = "^TNX"

class DataRepository:
    """
    Repository layer responsible for fetching, cleaning, and assembling
    all market data, benchmarks, and portfolio state required by the engine.
    """

    def __init__(self, cfg: AppConfig):
        """Keep a reference to ``cfg`` (written to by ``fetch_all``) and read
        ``trading_days_per_year`` from it (default 252)."""
        self.cfg = cfg
        self.trading_days = self.cfg.get("trading_days_per_year", 252)

    def fetch_all(self, input_tickers: List[str], model_id: int) -> DataSnapshot:
        """Fetch prices and benchmarks, build return frames and package them.

        Args:
            input_tickers: Tickers requested by the caller. Benchmark/macro
                tickers enter the optimisation universe only if listed here.
            model_id: Model selector; ids 4 and 5 additionally load
                Fama-French factors.

        Returns:
            A ``DataSnapshot``. Returns are daily, or monthly when
            ``cfg["return_frequency"] == "monthly"``.

        Raises:
            SystemExit: if no usable tickers remain after validation, if none
                has at least ``2 * trading_days`` returns, or if the surviving
                tickers share no common dates.

        Side effects:
            Writes ``cfg["risk_free_rate"]`` and ``cfg["_trading_periods"]``.
        """
        b = self.cfg.get("benchmarks", {})
        eq_bench = b.get("equity", "SPY")
        vol_bench = b.get("volatility", "^VIX")
        rfr_bench = b.get("risk_free", "^TNX")

        ff_df = fetch_fama_french_factors() if model_id in (4, 5) else None
        valid_tickers = fetch_data(input_tickers, b)

        # Refresh the scalar rate (the current cfg value, or 0.05, is passed in,
        # presumably as a fallback), but prefer the full series when available.
        self.cfg["risk_free_rate"] = fetch_risk_free_rate(rfr_bench, self.cfg.get("risk_free_rate", 0.05))
        rfr_series = fetch_risk_free_series(rfr_bench)
        final_rfr = rfr_series if not rfr_series.empty else self.cfg["risk_free_rate"]

        # Benchmark/macro series are always loaded, but only join the
        # optimisation universe when the caller asked for them explicitly.
        macro_tickers = {eq_bench, vol_bench, rfr_bench, "^IRX"}
        opt_tickers = [t for t in valid_tickers if t not in macro_tickers or t in input_tickers]
        if not opt_tickers:
            raise SystemExit("No Usable Tickers found.")

        legacy_state_dict = load_portfolio_state_dict() if self.cfg.get('_use_saved_basis', False) else {}

        all_tks = list(dict.fromkeys(opt_tickers + list(macro_tickers)))
        raw, prices = self._load_price_history(all_tks)

        bench_rets = raw[eq_bench].pct_change().dropna() if eq_bench in raw else pd.Series(dtype=float)
        vol_raw = raw.get(vol_bench, None)

        returns_df = self._build_returns_frame(raw, opt_tickers)
        final_opt_tickers = [t for t in opt_tickers if t in returns_df.columns]
        spread_map = build_spread_map(final_opt_tickers, self.cfg.get("sector_map", {}))

        if self.cfg.get('return_frequency', 'daily') == 'monthly':
            opt_returns_df = build_monthly_returns(returns_df)
            if bench_rets.empty:
                bench_rets_monthly = bench_rets
            else:
                bench_rets_monthly = build_monthly_returns(pd.DataFrame({eq_bench: bench_rets}))[eq_bench]
            opt_ff_df = build_monthly_returns(ff_df) if ff_df is not None else None
            self.cfg['_trading_periods'] = 12
        else:
            opt_returns_df, bench_rets_monthly, opt_ff_df = returns_df, bench_rets, ff_df
            self.cfg['_trading_periods'] = self.trading_days

        # Display window: the most recent 6 * trading_days rows of daily returns.
        display_days = self.trading_days * 6
        display_df = returns_df.iloc[-display_days:] if len(returns_df) > display_days else returns_df
        bench_display = bench_rets.reindex(display_df.index).dropna()

        master_state = PortfolioState.build(list(returns_df.columns), prices, legacy_state_dict, self.cfg)

        # Out-of-sample split: the test window is trading_days periods; training
        # gets the remainder, floored at 100, so on short histories
        # train + test can exceed the available sample.
        oos_test_days = self.trading_days
        oos_train_days = max(100, len(returns_df) - oos_test_days)
        train_yrs = oos_train_days / self.trading_days
        test_yrs = oos_test_days / self.trading_days

        return DataSnapshot(
            opt_tickers=final_opt_tickers,
            opt_returns_df=opt_returns_df,
            bench_rets_monthly=bench_rets_monthly,
            opt_ff_df=opt_ff_df,
            rfr=final_rfr,
            vol_raw=vol_raw,
            display_df=display_df,
            bench_display=bench_display,
            master_state=master_state,
            spread_map=spread_map,
            train_yrs=train_yrs,
            test_yrs=test_yrs,
            returns_df=returns_df,
            bench_rets=bench_rets,
            raw=raw,
            prices=prices,
            eq_bench=eq_bench,
            vol_bench=vol_bench,
            rfr_bench=rfr_bench
        )

    def _load_price_history(self, tickers: List[str]) -> Tuple[Dict[str, pd.Series], Dict[str, float]]:
        """Read close prices for ``tickers`` from the ``daily_prices`` table.

        Returns ``(raw, prices)``:
            raw: ticker -> date-indexed close series. Duplicate dates are
                collapsed and the last row wins.
            prices: ticker -> latest close.

        A failure while reading or parsing is logged and yields two empty
        dicts, so a half-populated result never reaches the caller.
        """
        pg_engine = get_pg_engine()
        if not tickers:
            return {}, {}

        # Placeholder style depends on the driver: qmark for sqlite, format otherwise.
        marker = "?" if pg_engine.name == 'sqlite' else "%s"
        ph = ",".join([marker] * len(tickers))
        query = f"SELECT ticker, date, close_price FROM daily_prices WHERE ticker IN ({ph}) ORDER BY date"

        raw: Dict[str, pd.Series] = {}
        prices: Dict[str, float] = {}
        try:
            df_all = pd.read_sql(query, pg_engine, params=tuple(tickers))
            if not df_all.empty:
                df_all['date'] = pd.to_datetime(df_all['date'])
                for t, group in df_all.groupby('ticker'):
                    group = group.drop_duplicates(subset=['date'], keep='last')
                    prices[t] = group.iloc[-1]['close_price']
                    raw[t] = group.set_index('date')['close_price']
        except Exception as e:
            logger.warning("DB read failed, returning empty context: %s", e)
            return {}, {}
        return raw, prices

    def _build_returns_frame(self, raw: Dict[str, pd.Series], tickers: List[str]) -> pd.DataFrame:
        """Build a date-aligned daily-returns frame for ``tickers``.

        Steps:
            - Keep only tickers with at least ``2 * trading_days`` returns.
            - Zero-fill missing returns after a column's first return <= -0.99,
              which pads delisted assets.
            - Drop every date on which any remaining column is missing.

        Raises:
            SystemExit: if no ticker has enough history, or if the survivors
                share no common dates.
        """
        min_days = self.trading_days * 2
        all_rets = {}
        for t in tickers:
            if t not in raw:
                continue
            s = raw[t].pct_change().dropna()
            if len(s) >= min_days:
                all_rets[t] = s

        if not all_rets:
            raise SystemExit("No usable tickers with enough history.")

        returns_df = pd.DataFrame(all_rets)
        # Pad delisted assets: after the first wipe-out return, missing values mean
        # "no longer trading", so they are zero-filled instead of dropping the dates.
        for col in returns_df.columns:
            dead_mask = returns_df[col] <= -0.99
            if dead_mask.any():
                dead_idx = dead_mask[dead_mask].index[0]
                returns_df.loc[dead_idx:, col] = returns_df.loc[dead_idx:, col].fillna(0.0)

        returns_df = returns_df.dropna()
        # Every ticker may have enough history on its own yet share no dates with
        # the others; an empty frame here would silently poison everything downstream.
        if returns_df.empty:
            raise SystemExit("No overlapping return history across usable tickers.")
        return returns_df