# stockproject/backtesting/framework/quant_framework.py
# Hugging Face page header (scraped): "harshisageek's picture" /
# "Upload folder using huggingface_hub" / commit 8e50444 (verified).
import itertools
import os
import sys
import warnings

import numpy as np
import pandas as pd
from scipy.stats import spearmanr

warnings.filterwarnings('ignore')

try:
    from v30_causal_engine import evaluate_slice, CAP
except ImportError:
    from backtesting.v30_causal_engine import evaluate_slice, CAP
class QuantValidator:
    """
    Generalized Quantitative Validation Framework.
    Allows institutional-grade auditing of any causal trading strategy.

    The validator pushes a strategy callable through five test phases
    (signal validation, backtest integrity, robustness, parameter
    robustness, forward validity).  Every phase prints its findings to
    stdout and returns ``None``.

    Parameters
    ----------
    dc : pandas.DataFrame
        Price panel whose columns are ticker names (it is indexed with
        ``dc[vf]`` and differenced with ``pct_change``).
    spy : pandas.Series
        Benchmark price series; used for the rolling-mean regime filter and
        as the comparison in the crash / regime tests.
    vf : sequence of str
        Ticker names forming the tradable universe.
    daily_ret : pandas.DataFrame
        Per-ticker daily returns.  Rows are read by position next to ``dc``,
        so the two are assumed to be row-aligned.
    strategy_fn : callable
        Called as ``strategy_fn(dc, spy, vf, daily_ret, **params)``.  Must
        return an equity curve, or a dict holding it under ``'curve'``.
    default_params : dict
        Keyword arguments for ``strategy_fn`` in the baseline run.
    signal_fn : callable, optional
        Called as ``signal_fn(dc, spy, vf)``; its result is read row by row
        with ``.iloc`` at the same positions as ``dc``.  Without it the IC
        tests are skipped.
    audit_fn : callable, optional
        Stored on the instance; no phase in this class calls it.
    txn_param_name : str
        Key in ``default_params`` that holds the transaction cost (bps).
    rebal_param_name : str
        Key in ``default_params`` that holds the rebalance interval (days).

    Notes
    -----
    ``evaluate_slice(curve, start, end)`` comes from the project's causal
    engine; this class only relies on the ``'sharpe'`` and ``'cagr'`` entries
    of the mapping it returns.
    """

    # Evaluation windows handed to ``evaluate_slice`` as (start, end).
    FULL_PERIOD = ("2008-01-01", "2025-12-31")
    TRAIN_PERIOD = ("2008-01-01", "2018-12-31")
    TEST_PERIOD = ("2019-01-01", "2025-12-31")

    # Rolling window for the SPY regime filter; also the IC warm-up length.
    SMA_WINDOW = 200
    # Forward-return horizon and sampling step (in rows) for the IC test.
    IC_HORIZON = 60
    # Minimum number of names required to compute one cross-sectional IC.
    MIN_IC_NAMES = 15

    # Tickers injected into the universe by the survivorship test.
    POISON_TICKERS = (
        "PTON", "CLOV", "NKLA", "QS", "SPCE", "SKLZ", "GOEV", "FSR",
        "SOFI", "BYND", "ZM", "DOCU", "TDOC", "UPST", "AFRM", "CHPT",
    )

    _RULE = "========================================"

    def __init__(self, dc, spy, vf, daily_ret, strategy_fn, default_params,
                 signal_fn=None, audit_fn=None, txn_param_name='txn_bps',
                 rebal_param_name='rebal_days'):
        self.dc = dc
        self.spy = spy
        self.vf = vf
        self.daily_ret = daily_ret
        self.strategy_fn = strategy_fn
        self.default_params = default_params
        self.signal_fn = signal_fn
        self.audit_fn = audit_fn
        self.txn_param_name = txn_param_name
        self.rebal_param_name = rebal_param_name

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def _run(self, dc_sub, spy_sub, vf_sub, ret_sub, params):
        """Run the strategy and return its equity curve.

        A dict result is unwrapped through its ``'curve'`` key; any other
        result is returned unchanged.
        """
        c = self.strategy_fn(dc_sub, spy_sub, vf_sub, ret_sub, **params)
        if isinstance(c, dict) and 'curve' in c:
            return c['curve']
        return c

    def _run_default(self):
        """Run the strategy on the full data set with ``default_params``."""
        return self._run(self.dc, self.spy, self.vf, self.daily_ret,
                         self.default_params)

    @classmethod
    def _banner(cls, title):
        """Print a phase header framed by two rule lines."""
        print("\n" + cls._RULE)
        print(title)
        print(cls._RULE)

    @staticmethod
    def _t_stat(values):
        """Return mean / standard error of *values*.

        NaN is returned when the statistic is undefined: fewer than two
        observations, or a sample standard deviation that is zero or NaN.
        """
        n = len(values)
        if n < 2:
            return float('nan')
        sd = values.std()
        if not sd > 0:  # also true when sd is NaN
            return float('nan')
        return float(values.mean() / (sd / np.sqrt(n)))

    def _period_returns(self, curve, start, end):
        """Return ``(strategy_return, spy_return)`` over ``[start, end]``.

        ``None`` is returned when the window is not fully inside the curve's
        date range, or when either series has no rows in the window.
        """
        if len(curve) == 0:
            return None
        # ISO 'YYYY-MM-DD' strings order lexicographically like dates.
        first = curve.index[0].strftime('%Y-%m-%d')
        last = curve.index[-1].strftime('%Y-%m-%d')
        if start < first or end > last:
            return None
        c_slice = curve.loc[start:end]
        s_slice = self.spy.loc[start:end]
        if len(c_slice) == 0 or len(s_slice) == 0:
            return None
        c_ret = (c_slice.iloc[-1] / c_slice.iloc[0]) - 1
        s_ret = (s_slice.iloc[-1] / s_slice.iloc[0]) - 1
        return c_ret, s_ret

    def _get_ew_baseline(self) -> pd.Series:
        """Build the equal-weight, volatility-targeted baseline NAV curve.

        Each day's return is the cross-sectional mean of the available
        universe returns, scaled by an exposure factor that (a) targets 18%
        annualised volatility from the trailing 21-60 portfolio returns,
        (b) is halved when the previous day's SPY is not above its rolling
        mean, and (c) is clipped to ``[0.05, 1.0]``.  Only information up to
        the previous row is used to size the current row.
        """
        # Loop-invariant work is done once up front.
        cols = [t for t in self.vf if t in self.daily_ret.columns]
        if cols:
            # Row-wise mean skips NaN; rows with no data contribute 0.0.
            mean_rets = self.daily_ret[cols].mean(axis=1).fillna(0.0).to_numpy()
        else:
            mean_rets = np.zeros(len(self.daily_ret))
        spy_vals = self.spy.to_numpy()
        sma_vals = self.spy.rolling(self.SMA_WINDOW).mean().to_numpy()

        nav = CAP
        port_rets = []
        hist = []
        for i in range(1, len(self.dc)):
            if len(port_rets) >= 21:
                window = port_rets[-60:] if len(port_rets) >= 60 else port_rets[-21:]
                vs = 0.18 / (np.std(window) * np.sqrt(252) + 1e-8)
            else:
                vs = 0.5  # not enough history for a volatility estimate yet
            sp, sm = spy_vals[i - 1], sma_vals[i - 1]
            if pd.isna(sm) or sp <= sm:
                vs *= 0.50
            vs = float(np.clip(vs, 0.05, 1.0))
            day_ret = float(mean_rets[i]) * vs
            nav *= (1 + day_ret)
            port_rets.append(day_ret)
            hist.append(nav)
        return pd.Series(hist, index=self.dc.index[1:len(hist) + 1])

    # ------------------------------------------------------------------
    # Orchestration
    # ------------------------------------------------------------------
    def run_all_phases(self, param_grid=None):
        """Run phases 1-3, phase 4 when *param_grid* is given, then phase 5."""
        self.run_phase1()
        self.run_phase2()
        self.run_phase3()
        if param_grid:
            self.run_phase4(param_grid)
        self.run_phase5()

    # ------------------------------------------------------------------
    # Phase 1
    # ------------------------------------------------------------------
    def run_phase1(self):
        """Signal validation: IC analysis and equal-weight comparison."""
        self._banner(" PHASE 1: SIGNAL VALIDATION")
        if self.signal_fn:
            self._test_information_coefficient()
        else:
            print("--- Test 1.1 & 1.2: Skipped (No signal_fn provided) ---")

        print("\n--- Test 1.3: Baseline Comparison (Equal Weight) ---")
        m_orig = evaluate_slice(self._run_default(), *self.FULL_PERIOD)
        m_ew = evaluate_slice(self._get_ew_baseline(), *self.FULL_PERIOD)
        print(f"Strategy Sharpe: {m_orig['sharpe']:.4f}")
        print(f"Eq-Weight Sharpe: {m_ew['sharpe']:.4f}")
        diff = m_orig['sharpe'] - m_ew['sharpe']
        print(f"Excess Sharpe: {diff:+.4f}")
        print(f"Result: {'PASS' if diff > 0.05 else 'FAIL'} (Threshold: > +0.05)")

    def _test_information_coefficient(self):
        """Tests 1.1/1.2: rank IC of the signal vs. forward returns."""
        print("--- Test 1.1 & 1.2: IC Analysis ---")
        horizon = self.IC_HORIZON
        raw_signal = self.signal_fn(self.dc, self.spy, self.vf)
        # Forward return over `horizon` rows, aligned to the signal row.
        fwd = self.dc[self.vf].pct_change(horizon).shift(-horizon)
        sma = self.spy.rolling(self.SMA_WINDOW).mean()

        ic_vals, dates, regimes = [], [], []
        # Sample every `horizon` rows so forward windows do not overlap.
        for i in range(self.SMA_WINDOW, len(self.dc) - horizon, horizon):
            sig = raw_signal.iloc[i].dropna()
            fwd_row = fwd.iloc[i].dropna()
            common = sig.index.intersection(fwd_row.index)
            if len(common) < self.MIN_IC_NAMES:
                continue
            corr, _ = spearmanr(sig[common].values, fwd_row[common].values)
            if np.isnan(corr):
                continue
            ic_vals.append(corr)
            dates.append(self.dc.index[i])
            regimes.append("ON" if self.spy.iloc[i] > sma.iloc[i] else "OFF")

        # Explicit dtype keeps an empty result numeric instead of object.
        ic_series = pd.Series(ic_vals, index=dates, dtype=float)
        ic_mean = ic_series.mean()
        t_stat = self._t_stat(ic_series)
        print(f"Mean IC: {ic_mean:.4f} | t-stat: {t_stat:.2f}")
        print(f"Result: {'PASS' if t_stat > 2.0 else 'FAIL'} (Threshold: t > 2.0)")

        print("\nBy Regime:")
        for reg in ["ON", "OFF"]:
            mask = [r == reg for r in regimes]
            sub = ic_series[mask]
            if len(sub) > 0:
                rt_stat = self._t_stat(sub)
                print(f"Risk-{reg} days: t-stat = {rt_stat:.2f} (n={len(sub)})")

    # ------------------------------------------------------------------
    # Phase 2
    # ------------------------------------------------------------------
    def run_phase2(self):
        """Backtest integrity: train/test split, start date, survivorship."""
        self._banner(" PHASE 2: BACKTEST INTEGRITY")
        # Tests 2.1 and 2.3 need the same default-parameter run; do it once.
        base_curve = self._run_default()
        self._test_train_test_split(base_curve)
        self._test_start_date_sensitivity()
        self._test_survivorship(base_curve)

    def _test_train_test_split(self, curve):
        """Test 2.1: compare Sharpe on the train and test windows."""
        print("--- Test 2.1: Strict Train/Test Split ---")
        m_train = evaluate_slice(curve, *self.TRAIN_PERIOD)
        m_test = evaluate_slice(curve, *self.TEST_PERIOD)
        print(f"Train Sharpe (2008-2018): {m_train['sharpe']:.4f}")
        print(f"Test Sharpe (2019-2025): {m_test['sharpe']:.4f}")
        diff = m_test['sharpe'] - m_train['sharpe']
        print(f"Difference: {diff:+.4f}")
        print(f"Result: {'PASS' if diff > -0.20 else 'FAIL'} (Test decay must not exceed -0.20)")

    def _test_start_date_sensitivity(self):
        """Test 2.2: re-run the strategy from shifted start rows."""
        print("\n--- Test 2.2: Start Date Sensitivity ---")
        rebal = self.default_params.get(self.rebal_param_name, 60)
        step = 4 if rebal <= 60 else 6
        # int() tolerates a float interval; fall back to the unshifted run
        # when the interval is not positive so the range is never empty.
        offsets = list(range(0, int(rebal), step)) or [0]
        print(f"Running {len(offsets)} offsets...", end="", flush=True)
        sharpes = []
        for off in offsets:
            c_off = self._run(self.dc.iloc[off:], self.spy.iloc[off:], self.vf,
                              self.daily_ret.iloc[off:], self.default_params)
            m = evaluate_slice(c_off, *self.FULL_PERIOD)
            sharpes.append(m['sharpe'])
            print(".", end="", flush=True)
        print()
        s_mean, s_range = np.mean(sharpes), max(sharpes) - min(sharpes)
        print(f"Mean Sharpe: {s_mean:.4f} | Range: {s_range:.4f}")
        print(f"Result: {'PASS' if s_range < 0.20 else 'FAIL'} (Path dependency range < 0.20)")

    def _test_survivorship(self, base_curve):
        """Test 2.3: add the poison tickers to the universe and re-run."""
        print("\n--- Test 2.3: Survivorship Bias (Poison Universe) ---")
        base_cagr = evaluate_slice(base_curve, *self.FULL_PERIOD)['cagr']

        # Imported lazily: only this test needs yfinance (and the network).
        try:
            import yfinance as yf
        except ImportError:
            print("Result: SKIPPED (yfinance is not installed)")
            return

        tickers = list(self.POISON_TICKERS)
        poison_raw = yf.download(tickers, start="2006-01-01", end="2025-12-31",
                                 progress=False)
        if isinstance(poison_raw.columns, pd.MultiIndex):
            lvl0 = poison_raw.columns.get_level_values(0).unique().tolist()
            poison_close = poison_raw["Close"] if "Close" in lvl0 else poison_raw
            if isinstance(poison_close.columns, pd.MultiIndex):
                poison_close.columns = poison_close.columns.get_level_values(-1)
        else:
            poison_close = poison_raw

        valid_poison = [t for t in tickers
                        if t in poison_close.columns
                        and poison_close[t].notna().sum() > 100]
        if not valid_poison:
            # Without poison data the "poisoned" run equals the baseline and
            # the test would report a meaningless PASS.
            print("Result: SKIPPED (no poison ticker data available)")
            return

        dc_poisoned = self.dc.copy()
        for t in valid_poison:
            if t not in dc_poisoned.columns:
                dc_poisoned[t] = poison_close[t].reindex(dc_poisoned.index).ffill()
        poison_vf = list(dict.fromkeys(list(self.vf) + valid_poison))
        # NOTE(review): the poisoned run derives returns from
        # dc_poisoned.pct_change() while the baseline uses self.daily_ret;
        # confirm both are computed the same way.
        c_poison = self._run(dc_poisoned, self.spy, poison_vf,
                             dc_poisoned.pct_change(), self.default_params)
        poison_cagr = evaluate_slice(c_poison, *self.FULL_PERIOD)['cagr']
        print(f"Baseline CAGR: {base_cagr:.1f}% | Poison CAGR: {poison_cagr:.1f}%")
        diff = base_cagr - poison_cagr
        print(f"Result: {'PASS' if diff <= 3.0 else 'FAIL'} (Decay <= 3.0%)")

    # ------------------------------------------------------------------
    # Phase 3
    # ------------------------------------------------------------------
    def run_phase3(self):
        """Robustness: transaction-cost stress and crash-period autopsy."""
        self._banner(" PHASE 3: ROBUSTNESS TESTING")

        print("--- Test 3.1: Transaction Cost Stress ---")
        if self.txn_param_name not in self.default_params:
            print(f"Error: Txn param '{self.txn_param_name}' not in default params. Cannot stress test.")
        else:
            for bps in (20, 40, 60):
                p = self.default_params.copy()
                p[self.txn_param_name] = bps
                c = self._run(self.dc, self.spy, self.vf, self.daily_ret, p)
                m = evaluate_slice(c, *self.FULL_PERIOD)
                print(f"At {bps} bps: Sharpe = {m['sharpe']:.4f}")
                if bps == 60:
                    print(f"Result: {'PASS' if m['sharpe'] >= 0.60 else 'FAIL'} (Threshold >= 0.60 at 60bps)")

        print("\n--- Test 3.5: Momentum Crash Autopsy ---")
        c = self._run_default()
        crashes = {
            "GFC (Lehman)": ("2008-09-01", "2009-03-31"),
            "China Shock": ("2015-06-01", "2015-09-30"),
            "Rate Shock": ("2022-01-01", "2022-12-31"),
        }
        for name, (start, end) in crashes.items():
            rets = self._period_returns(c, start, end)
            if rets is None:
                print(f"{name:<15}: skipped (no data for {start}..{end})")
                continue
            c_ret, s_ret = rets
            print(f"{name:<15}: Strat {c_ret*100:>+6.1f}% | SPY {s_ret*100:>+6.1f}%")
        print("Result: PASS (Quantified)")

    # ------------------------------------------------------------------
    # Phase 4
    # ------------------------------------------------------------------
    def run_phase4(self, param_grid):
        """Parameter robustness: grid sweep, then train/test on the best.

        *param_grid* maps parameter names to iterables of candidate values;
        every combination is overlaid on ``default_params``.
        """
        self._banner(" PHASE 4: PARAMETER ROBUSTNESS")
        print("Sweeping parameters...")
        keys = list(param_grid.keys())
        combos = list(itertools.product(*(param_grid[k] for k in keys)))
        if not combos:
            print("No parameter combinations to evaluate.")
            return

        results = []  # (combo, sharpe) pairs; a list so values need not hash
        best_combo, best_sharpe, best_curve = None, float('-inf'), None
        for combo in combos:
            p = self.default_params.copy()
            p.update(zip(keys, combo))
            c = self._run(self.dc, self.spy, self.vf, self.daily_ret, p)
            sharpe = evaluate_slice(c, *self.FULL_PERIOD)['sharpe']
            results.append((combo, sharpe))
            # NaN never wins; keeping the curve avoids re-running the best.
            if not pd.isna(sharpe) and sharpe > best_sharpe:
                best_combo, best_sharpe, best_curve = combo, sharpe, c

        for combo, sharpe in results:
            print(f"Params {dict(zip(keys, combo))}: Sharpe {sharpe:.4f}")
        if best_combo is None:
            print("\nNo combination produced a valid Sharpe; cannot assess overfit risk.")
            return

        print(f"\nBest Params {dict(zip(keys, best_combo))} -> Evaluating Overfit Risk...")
        m_train = evaluate_slice(best_curve, *self.TRAIN_PERIOD)
        m_test = evaluate_slice(best_curve, *self.TEST_PERIOD)
        print(f"Train (2008-2018): {m_train['sharpe']:.4f}")
        print(f"Test (2019-2025): {m_test['sharpe']:.4f}")
        print(f"Result: {'PASS' if m_test['sharpe'] >= m_train['sharpe'] - 0.15 else 'FAIL'} (Test didn't collapse)")

    # ------------------------------------------------------------------
    # Phase 5
    # ------------------------------------------------------------------
    def run_phase5(self):
        """Forward validity: compare strategy vs. SPY in two named regimes."""
        self._banner(" PHASE 5: FORWARD VALIDITY")
        c = self._run_default()
        regimes = {
            "2020 V-Crash": ("2020-02-01", "2020-12-31"),
            "2022 Slow Bear": ("2022-01-01", "2022-12-31"),
        }
        fails = 0
        evaluated = 0
        for name, (start, end) in regimes.items():
            rets = self._period_returns(c, start, end)
            if rets is None:
                print(f"{name}: skipped (no data for {start}..{end})")
                continue
            evaluated += 1
            c_ret, s_ret = rets
            print(f"{name}: Strat {c_ret*100:+.1f}% | SPY {s_ret*100:+.1f}%")
            if name == "2022 Slow Bear" and c_ret < s_ret:
                print(" -> FAIL: Lost more than SPY")
                fails += 1
            elif name == "2020 V-Crash" and c_ret < s_ret * 0.30:
                print(" -> FAIL: Captured < 30% of SPY recovery")
                fails += 1
        if evaluated == 0:
            # Nothing was measured, so a PASS would be misleading.
            print("\nResult: SKIPPED (Regime tests: no regime window had data)")
        else:
            print(f"\nResult: {'PASS' if fails == 0 else 'FAIL'} (Regime tests)")