Spaces:
Running
Running
File size: 9,420 Bytes
5822cc5 70ef685 5822cc5 ed8a768 5822cc5 ed8a768 5822cc5 ed8a768 5822cc5 dd48f0c 28f4a26 dd48f0c 28f4a26 dd48f0c 28f4a26 dd48f0c 5822cc5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 | # evaluate.py
# Runs backtest on test set, computes performance vs SPY/AGG benchmarks.
# Usage:
# python evaluate.py --start_year 2015 --fee_bps 10 --tsl 10 --z 1.1
import argparse
import json
import os
from datetime import datetime
import numpy as np
import pandas as pd
import config
from data_download import load_local
from features import build_features
from env import make_splits
from agent import DQNAgent
WEIGHTS_PATH = os.path.join(config.MODELS_DIR, "dqn_best.pt")
SUMMARY_PATH = os.path.join(config.MODELS_DIR, "training_summary.json")
EVAL_PATH = "evaluation_results.json"
def _sharpe(rets: np.ndarray, tbill: float = 0.036) -> float:
excess = rets - tbill / 252
return float(excess.mean() / (excess.std() + 1e-9) * np.sqrt(252))
def _max_drawdown(equity: np.ndarray) -> float:
peak = np.maximum.accumulate(equity)
dd = (equity - peak) / (peak + 1e-9)
return float(dd.min())
def _calmar(ann_ret: float, max_dd: float) -> float:
return ann_ret / (abs(max_dd) + 1e-9)
def run_backtest(start_year: int,
                 fee_bps: int = config.DEFAULT_FEE_BPS,
                 tsl_pct: float = config.DEFAULT_TSL_PCT,
                 z_reentry: float = config.DEFAULT_Z_REENTRY) -> dict:
    """Run a deterministic backtest of the trained DQN agent on the test split.

    Loads the locally cached ETF/macro data, rebuilds features, restores the
    saved agent weights, then steps the test environment greedily (argmax over
    Q-values) with a trailing-stop-loss (TSL) overlay: once equity falls
    ``tsl_pct`` percent below its running peak, the strategy holds CASH
    (action 0) until the best Q-value z-score clears ``z_reentry``.

    Args:
        start_year: First year of data used for features and splits.
        fee_bps: Per-trade fee in basis points (converted to a fraction).
        tsl_pct: Trailing stop-loss threshold, in percent.
        z_reentry: Q-value z-score required to re-enter after a stop-out.

    Returns:
        dict of performance metrics, benchmark comparisons, equity curve and
        allocation history; also written as JSON to ``EVAL_PATH`` (plus a
        date-stamped sweep cache for selected start years).

    Raises:
        RuntimeError: if no local data cache exists.
    """
    print(f"\n{'='*60}")
    print(f" P2-ETF-DQN-ENGINE β Evaluation")
    print(f" Start year : {start_year}")
    print(f"{'='*60}")
    # --- Load cached data and rebuild the feature frame / test env ----------
    data = load_local()
    if not data:
        raise RuntimeError("No local data. Run data_download.py first.")
    etf_prices = data["etf_prices"]
    macro = data["macro"]
    feat_df = build_features(etf_prices, macro, start_year=start_year)
    fee_pct = fee_bps / 10_000  # basis points -> fraction
    _, _, test_env = make_splits(feat_df, etf_prices, macro, start_year,
                                 fee_pct=fee_pct)
    # --- Restore trained agent ----------------------------------------------
    agent = DQNAgent(state_size=test_env.observation_size)
    agent.load(WEIGHTS_PATH)
    # --- Backtest with trailing stop-loss -----------------------------------
    # FIX: force reset to start_idx (not random) for deterministic evaluation.
    # NOTE(review): this mutates env internals directly (current_idx,
    # held_action, peak_equity, equity, is_stopped_out) — assumes these are
    # the env's full reset state; confirm against env.py.
    test_env.current_idx = test_env.start_idx
    test_env.held_action = 0
    test_env.peak_equity = 1.0
    test_env.equity = 1.0
    test_env.is_stopped_out = False
    state = test_env._get_state()
    rets = []           # per-step returns reported by the env
    allocations = []    # chosen asset label per step (config.ACTIONS)
    q_vals_log = []     # raw Q-values per step (used for the sweep cache)
    equity_curve = [1.0]
    peak_equity = 1.0   # TSL tracks its own peak, separate from the env's
    is_stopped = False
    done = False
    while not done:
        q_values = agent.q_values(state)
        z_scores = _q_zscore(q_values)
        if is_stopped:
            # Re-enter only when the best z-score clears the threshold.
            if z_scores.max() >= z_reentry:
                is_stopped = False
                action = int(q_values.argmax())
            else:
                action = 0  # stay in CASH
        else:
            action = int(q_values.argmax())
        next_state, reward, done, info = test_env.step(action)
        # TSL check: only while holding a non-cash position. The peak ratchets
        # up; a drop of tsl_pct% below it triggers the stop.
        eq = test_env.equity
        if action != 0:
            if eq > peak_equity:
                peak_equity = eq
            if eq < peak_equity * (1 - tsl_pct / 100):
                is_stopped = True
        rets.append(info["day_ret"])  # assumed daily return — confirm in env
        allocations.append(config.ACTIONS[action])
        q_vals_log.append(q_values.tolist())
        equity_curve.append(eq)
        state = next_state
    rets = np.array(rets)
    equity = np.array(equity_curve[1:])  # drop the initial 1.0 seed
    # --- Benchmark returns over the same test period ------------------------
    # NOTE(review): assumes make_splits carves the test set from the same
    # TRAIN_SPLIT + VAL_SPLIT boundary — confirm alignment with env.py.
    test_dates = feat_df.index[
        int(len(feat_df) * (config.TRAIN_SPLIT + config.VAL_SPLIT)):
    ]
    bench = {}         # Sharpe per benchmark ticker
    bench_ann = {}     # annualized return per benchmark
    bench_equity = {}  # normalized equity curve per benchmark
    for b in config.BENCHMARKS:
        if b in etf_prices.columns:
            bp = etf_prices[b].reindex(test_dates).ffill().pct_change().dropna()
            bench[b] = _sharpe(bp.values)
            cum = float((1 + bp).prod())
            bench_ann[b] = round(float(cum ** (252 / max(len(bp), 1)) - 1), 4)
            beq = (1 + bp).cumprod().values
            bench_equity[b] = [round(float(v / beq[0]), 4) for v in beq]
    # --- Strategy metrics ---------------------------------------------------
    n_days = len(rets)
    ann_ret = float((equity[-1]) ** (252 / n_days) - 1) if n_days > 0 else 0.0
    sharpe = _sharpe(rets)
    max_dd = _max_drawdown(equity)
    calmar = _calmar(ann_ret, max_dd)
    hit = float((rets > 0).mean())  # fraction of positive-return steps
    # Allocation breakdown (fraction of days held per asset label)
    alloc_counts = pd.Series(allocations).value_counts(normalize=True).to_dict()
    results = dict(
        start_year = start_year,
        evaluated_at = datetime.now().isoformat(),
        n_test_days = n_days,
        ann_return = round(ann_ret, 4),
        sharpe = round(sharpe, 4),
        max_drawdown = round(max_dd, 4),
        calmar = round(calmar, 4),
        hit_ratio = round(hit, 4),
        final_equity = round(float(equity[-1]), 4),
        benchmark_sharpe = {k: round(v, 4) for k, v in bench.items()},
        benchmark_ann = bench_ann,
        benchmark_equity = bench_equity,
        test_dates = [str(d.date()) for d in test_dates],
        allocation_pct = {k: round(v, 4) for k, v in alloc_counts.items()},
        equity_curve = [round(float(e), 4) for e in equity],
        allocations = allocations,
        fee_bps = fee_bps,
        tsl_pct = tsl_pct,
        z_reentry = z_reentry,
    )
    with open(EVAL_PATH, "w") as f:
        json.dump(results, f, indent=2)
    # --- Date-stamped sweep cache for sweep years ---------------------------
    sweep_years = [2008, 2013, 2015, 2017, 2019, 2021]
    if start_year in sweep_years:
        from datetime import timezone
        today_str = datetime.now(timezone.utc).strftime("%Y%m%d")
        # Z-score from the LAST step's Q-values (already computed during the
        # backtest) — always available, no dependency on predict.py.
        last_q = np.array(q_vals_log[-1]) if q_vals_log else np.zeros(len(config.ACTIONS))
        last_z_arr = _q_zscore(last_q)
        last_action_idx = int(last_q.argmax())
        z_val = float(last_z_arr[last_action_idx])
        z_val = z_val if np.isfinite(z_val) else 0.0  # guard NaN/inf from flat Q
        # Next signal = the last day's chosen action
        next_signal = config.ACTIONS[last_action_idx] if allocations else "CASH"
        conviction = ("Very High" if z_val >= 2.0 else
                      "High" if z_val >= 1.5 else
                      "Moderate" if z_val >= 1.0 else "Low")
        # Most-held ETF (excluding CASH) for display purposes
        non_cash = {k: v for k, v in alloc_counts.items() if k != "CASH"}
        top_held = max(non_cash, key=non_cash.get) if non_cash else next_signal
        sweep_payload = {
            "signal": next_signal,
            "top_held": top_held,
            "ann_return": round(ann_ret, 6),
            "z_score": round(z_val, 4),
            "sharpe": round(sharpe, 4),
            "max_dd": round(max_dd, 6),
            "conviction": conviction,
            # NOTE(review): results never sets "lookback", so this always
            # falls back to config.LOOKBACK_WINDOW — possibly intentional.
            "lookback": results.get("lookback", config.LOOKBACK_WINDOW),
            "start_year": start_year,
            "sweep_date": today_str,
        }
        os.makedirs("results", exist_ok=True)
        sweep_fname = os.path.join("results", f"sweep_{start_year}_{today_str}.json")
        with open(sweep_fname, "w") as sf:
            json.dump(sweep_payload, sf, indent=2)
        print(f" Sweep cache saved β {sweep_fname}")
        print(f" Sweep signal: {next_signal} z={z_val:.3f} conviction={conviction}")
    print(f"\n Ann. Return : {ann_ret:.2%}")
    print(f" Sharpe Ratio : {sharpe:.3f}")
    print(f" Max Drawdown : {max_dd:.2%}")
    print(f" Calmar Ratio : {calmar:.3f}")
    print(f" Hit Ratio : {hit:.1%}")
    print(f" Benchmarks : {bench}")
    print(f"\n Results saved β {EVAL_PATH}")
    return results
def _q_zscore(q_vals: np.ndarray) -> np.ndarray:
mu = q_vals.mean()
std = q_vals.std() + 1e-9
return (q_vals - mu) / std
if __name__ == "__main__":
    # CLI entry point: all flags default to the project-wide config values.
    cli = argparse.ArgumentParser()
    cli.add_argument("--start_year", type=int, default=config.DEFAULT_START_YEAR)
    cli.add_argument("--fee_bps", type=int, default=config.DEFAULT_FEE_BPS)
    cli.add_argument("--tsl", type=float, default=config.DEFAULT_TSL_PCT)
    cli.add_argument("--z", type=float, default=config.DEFAULT_Z_REENTRY)
    opts = cli.parse_args()
    run_backtest(opts.start_year, opts.fee_bps, opts.tsl, opts.z)
|