robust-AAE / scripts /run_alpha_robustness_matrix.py
PuLam's picture
Add standalone alpha robustness matrix bundle
79e6483 verified
#!/usr/bin/env python3
"""Run a standard alpha robustness matrix on top of the standalone robust backtester."""
from __future__ import annotations
import argparse
import json
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import pandas as pd
# Directory two levels above the directory that contains this script
# (resolved, so symlinked invocations still point at the real tree).
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
# Per-case backtest runner, launched once per matrix case in a subprocess.
RUNNER_PATH = PROJECT_ROOT.joinpath("deploy", "v2", "jsonl_alpha_robustness.py")
@dataclass(frozen=True)
class MatrixCase:
    """One named scenario in the robustness matrix.

    Instances are immutable (frozen dataclass). ``params`` carries the backtest
    settings that ``_build_case_command`` turns into runner CLI flags.
    """

    # Unique identifier; also the name of the case's output sub-directory.
    case_id: str
    # Label that buckets related cases together (e.g. "fee", "top_k").
    case_group: str
    # Human-readable summary printed before the case runs.
    description: str
    # Backtest settings forwarded to the runner for this case.
    params: dict[str, Any]
def _baseline_params() -> dict[str, Any]:
    """Return a fresh dict holding the baseline backtest configuration.

    A new dict is built on every call, so callers may derive variants with
    ``{**base, ...}`` without affecting other cases. Baseline: custom engine,
    top-5 equal weighting rebalanced every 5 days, 20% per-name cap, 13 bps
    per side fees, identity score transform, and no universe filtering.
    """
    return dict(
        backtest_engine="custom",
        top_k=5,
        rebalance_freq=5,
        custom_weight_mode="equal",
        position_size=1.0,
        max_pos_each_stock=0.2,
        max_daily_volume_participation=0.0,
        max_daily_amount_participation=0.0,
        buy_fee=0.0013,
        sell_fee=0.0013,
        enforce_cash_limit=True,
        score_transform="identity",
        score_clip=3.0,
        universe_filter="none",
        universe_top_n=0,
        universe_lookback_days=20,
        redistribute_unfilled_cash=False,
    )
def _make_standard_cases() -> list[MatrixCase]:
    """Return the fixed, ordered list of robustness scenarios.

    Every case is the baseline configuration (``_baseline_params()``) or the
    alpha-score-weighted variant of it, with a handful of keys overridden.
    The order of the returned list is the order in which cases are run.
    """
    baseline = _baseline_params()
    # Alpha-score weighting keeps the baseline 20% per-name cap and turns on
    # redistribution of unfilled cash; the score-transform cases build on it.
    alpha_weighted = {
        **baseline,
        "custom_weight_mode": "alpha_score",
        "redistribute_unfilled_cash": True,
    }

    def variant(case_id: str, group: str, text: str, template: dict[str, Any], **overrides: Any) -> MatrixCase:
        # Copy the template and apply the overrides on top of it.
        return MatrixCase(case_id, group, text, {**template, **overrides})

    topk_grid = (
        ("topk_10", "TOP_K sensitivity: 10", 10),
        ("topk_15", "TOP_K sensitivity: 15", 15),
    )
    fee_grid = (
        ("fee_0bps", "Fee sensitivity: 0 bps per side", 0.0),
        ("fee_10bps", "Fee sensitivity: 10 bps per side", 0.0010),
        ("fee_20bps", "Fee sensitivity: 20 bps per side", 0.0020),
        ("fee_30bps", "Fee sensitivity: 30 bps per side", 0.0030),
        ("fee_50bps", "Fee sensitivity: 50 bps per side", 0.0050),
    )
    rebalance_grid = (
        ("rebalance_10d", "Rebalance sensitivity: every 10 trading days", 10),
        ("rebalance_20d", "Rebalance sensitivity: every 20 trading days", 20),
    )
    transform_grid = (
        ("score_rank", "Score-transform robustness: rank transform under alpha-score weighting", "rank"),
        ("score_zscore", "Score-transform robustness: zscore transform under alpha-score weighting", "zscore"),
        ("score_rank_zscore", "Score-transform robustness: rank_zscore under alpha-score weighting", "rank_zscore"),
    )

    cases: list[MatrixCase] = [
        MatrixCase("baseline_replay", "baseline", "Baseline replay with fair equal-weight + 20% cap", baseline),
    ]
    cases.extend(variant(cid, "top_k", text, baseline, top_k=k) for cid, text, k in topk_grid)
    cases.append(
        MatrixCase("weight_alpha_score_cap20", "weighting", "Alpha-score weighting with 20% cap", alpha_weighted)
    )
    cases.append(
        variant(
            "weight_alpha_score_no_cap",
            "weighting",
            "Alpha-score weighting with no per-name cap",
            alpha_weighted,
            max_pos_each_stock=1.0,
        )
    )
    cases.extend(
        variant(cid, "fee", text, baseline, buy_fee=fee, sell_fee=fee) for cid, text, fee in fee_grid
    )
    cases.extend(
        variant(cid, "rebalance", text, baseline, rebalance_freq=days) for cid, text, days in rebalance_grid
    )
    cases.extend(
        variant(cid, "score_transform", text, alpha_weighted, score_transform=name)
        for cid, text, name in transform_grid
    )
    cases.append(
        variant(
            "frozen_recent_2026_ytd",
            "frozen_recent",
            "Frozen recent monitor on 2026 YTD",
            baseline,
            start_date="2026-01-01",
        )
    )
    return cases
def _filter_cases(cases: list[MatrixCase], case_filter: set[str] | None, case_limit: int) -> list[MatrixCase]:
    """Select cases by ID, then optionally truncate.

    An empty or ``None`` *case_filter* keeps every case. Selected cases keep
    their original order. A positive *case_limit* keeps only the first
    ``case_limit`` selected cases; zero or negative means no limit. A new list
    is always returned.
    """
    if case_filter:
        selected = [case for case in cases if case.case_id in case_filter]
    else:
        selected = list(cases)
    return selected[:case_limit] if case_limit > 0 else selected
def _bool_flag(enabled: bool, flag: str) -> list[str]:
    """Return ``[flag]`` if *enabled* is truthy, otherwise an empty list.

    Lets a store-true style CLI switch be spliced into an argv with ``extend``.
    """
    if not enabled:
        return []
    return [flag]
def _build_case_command(
    *,
    jsonl_path: Path,
    output_dir: Path,
    period: str,
    data_path: Path | None,
    backtest_workers: int,
    label_forward_days: int,
    trade_guard_config: str | None,
    capture_detail_artifacts: bool,
    case: MatrixCase,
) -> list[str]:
    """Build the argv that runs ``RUNNER_PATH`` for one matrix case.

    The runner is launched with the current interpreter (``sys.executable``).
    Every key in ``required_flags`` must be present in ``case.params``
    (a missing one raises ``KeyError``). ``start_date``/``end_date``,
    ``--data-path`` and ``--trade-guard-config`` are only emitted when set.
    Boolean switches are appended only when truthy.
    """
    params = case.params
    # (CLI flag, params key) for settings every case defines, in CLI order.
    required_flags = (
        ("--backtest-engine", "backtest_engine"),
        ("--top-k", "top_k"),
        ("--rebalance-freq", "rebalance_freq"),
        ("--custom-weight-mode", "custom_weight_mode"),
        ("--position-size", "position_size"),
        ("--max-pos-each-stock", "max_pos_each_stock"),
        ("--max-daily-volume-participation", "max_daily_volume_participation"),
        ("--max-daily-amount-participation", "max_daily_amount_participation"),
        ("--buy-fee", "buy_fee"),
        ("--sell-fee", "sell_fee"),
        ("--score-transform", "score_transform"),
        ("--score-clip", "score_clip"),
        ("--universe-filter", "universe_filter"),
        ("--universe-top-n", "universe_top_n"),
        ("--universe-lookback-days", "universe_lookback_days"),
    )

    argv: list[str] = [
        sys.executable,
        str(RUNNER_PATH),
        "--jsonl", str(jsonl_path),
        "--period", period,
        "--output-dir", str(output_dir),
        "--backtest-workers", str(backtest_workers),
        "--label-forward-days", str(label_forward_days),
    ]
    for flag, key in required_flags:
        argv += [flag, str(params[key])]

    if data_path is not None:
        argv += ["--data-path", str(data_path)]
    # Optional date window: only forwarded when the case sets a truthy value.
    for flag, key in (("--start-date", "start_date"), ("--end-date", "end_date")):
        value = params.get(key)
        if value:
            argv += [flag, str(value)]
    if trade_guard_config:
        argv += ["--trade-guard-config", trade_guard_config]

    switches = (
        ("--enforce-cash-limit", bool(params.get("enforce_cash_limit"))),
        ("--redistribute-unfilled-cash", bool(params.get("redistribute_unfilled_cash"))),
        ("--capture-detail-artifacts", capture_detail_artifacts),
    )
    argv += [flag for flag, enabled in switches if enabled]
    return argv
def _load_case_manifest(case_dir: Path) -> dict[str, Any]:
    """Parse ``robust_manifest.json`` in *case_dir*; return ``{}`` if the file is absent.

    Malformed JSON is not caught and raises ``json.JSONDecodeError``.
    """
    manifest_path = case_dir / "robust_manifest.json"
    if manifest_path.exists():
        with manifest_path.open(encoding="utf-8") as handle:
            return json.load(handle)
    return {}
def _read_csv(path: Path) -> pd.DataFrame:
    """Read *path* as CSV, or return an empty DataFrame when there is nothing to read.

    A missing file and a file with no columns to parse (pandas raises
    ``EmptyDataError``, e.g. for a zero-byte file) both yield
    ``pd.DataFrame()``. Any other read error propagates.
    """
    if path.exists():
        try:
            return pd.read_csv(path)
        except pd.errors.EmptyDataError:
            pass
    return pd.DataFrame()
def _merge_case_frames(case_records: list[dict[str, Any]], filename: str) -> pd.DataFrame:
    """Concatenate *filename* from every case directory into one frame.

    Each non-empty per-case frame is prefixed with ``case_id``,
    ``case_group`` and ``case_description`` columns (in that order).
    Missing or empty files are skipped. Returns an empty DataFrame when no
    case produced rows.
    """
    tagged: list[pd.DataFrame] = []
    for record in case_records:
        frame = _read_csv(Path(record["case_dir"]) / filename)
        if frame.empty:
            continue
        # Inserting at position 0 in reverse order leaves the final layout as
        # case_id, case_group, case_description, <original columns>.
        identity_columns = (
            ("case_description", record["description"]),
            ("case_group", record["case_group"]),
            ("case_id", record["case_id"]),
        )
        for column, value in identity_columns:
            frame.insert(0, column, value)
        tagged.append(frame)
    return pd.concat(tagged, ignore_index=True) if tagged else pd.DataFrame()
def _safe_bool(series: pd.Series) -> pd.Series:
    """Coerce *series* to booleans, treating missing values as ``False``.

    Non-missing values are converted by Python truthiness via ``astype(bool)``.
    """
    filled = series.fillna(False)
    return filled.astype(bool)
def _rows_for_seed(frame: pd.DataFrame, seed_name: str) -> pd.DataFrame:
    """Return a copy of the rows of *frame* whose ``seed_name`` (as str) equals *seed_name*.

    Empty frames, and frames without a ``seed_name`` column, yield an empty
    DataFrame instead of raising ``KeyError``.
    """
    if frame.empty or "seed_name" not in frame.columns:
        return pd.DataFrame()
    return frame[frame["seed_name"].astype(str) == seed_name].copy()


def _portfolio_coverage(seed_port: pd.DataFrame) -> dict[str, Any]:
    """Summarise day counts and cash-weight statistics from ``portfolio_daily.csv`` rows.

    Missing flag columns count as zero days. A missing ``cash_weight`` column
    leaves the cash statistics as ``None``.
    """
    stats: dict[str, Any] = {
        "n_portfolio_days": 0,
        "n_rebalance_days": 0,
        "n_trade_days": 0,
        "cash_weight_mean": None,
        "cash_weight_p95": None,
        "cash_weight_max": None,
    }
    if seed_port.empty:
        return stats
    stats["n_portfolio_days"] = int(len(seed_port))
    if "is_rebalance" in seed_port.columns:
        stats["n_rebalance_days"] = int(_safe_bool(seed_port["is_rebalance"]).sum())
    if "had_trade" in seed_port.columns:
        stats["n_trade_days"] = int(_safe_bool(seed_port["had_trade"]).sum())
    if "cash_weight" in seed_port.columns:
        # Non-numeric entries become NaN and are skipped by the reductions.
        cash = pd.to_numeric(seed_port["cash_weight"], errors="coerce")
        stats["cash_weight_mean"] = float(cash.mean())
        stats["cash_weight_p95"] = float(cash.quantile(0.95))
        stats["cash_weight_max"] = float(cash.max())
    return stats


def _plan_coverage(seed_plan: pd.DataFrame, top_k: int) -> dict[str, Any]:
    """Summarise per-date target counts and unallocated cash from ``rebalance_plan.csv`` rows.

    A date may have several plan rows; the per-date maximum of each column is
    used. Shortfall metrics compare per-date target counts against *top_k*
    and stay at zero when *top_k* is not positive.
    """
    stats: dict[str, Any] = {
        "mean_target_count_eod": None,
        "median_target_count_eod": None,
        "min_target_count_eod": None,
        "max_target_count_eod": None,
        "rebalance_days_lt_topk": 0,
        "pct_rebalance_days_lt_topk": 0.0,
        "unallocated_cash_eod_mean": None,
    }
    if seed_plan.empty or "date" not in seed_plan.columns:
        return stats
    per_date = seed_plan.groupby("date")
    if "target_count_eod" in seed_plan.columns:
        counts = pd.to_numeric(per_date["target_count_eod"].max(), errors="coerce")
        stats["mean_target_count_eod"] = float(counts.mean())
        stats["median_target_count_eod"] = float(counts.median())
        stats["min_target_count_eod"] = float(counts.min())
        stats["max_target_count_eod"] = float(counts.max())
        if top_k > 0:
            below_topk = counts < top_k
            stats["rebalance_days_lt_topk"] = int(below_topk.sum())
            if len(counts):
                stats["pct_rebalance_days_lt_topk"] = float(below_topk.mean())
    if "unallocated_cash_eod" in seed_plan.columns:
        unallocated = pd.to_numeric(per_date["unallocated_cash_eod"].max(), errors="coerce")
        stats["unallocated_cash_eod_mean"] = float(unallocated.mean())
    return stats


def _signal_coverage(seed_signal: pd.DataFrame, top_k: int) -> dict[str, Any]:
    """Summarise top-k name counts and zero-score prevalence from ``signal_selection_daily.csv`` rows.

    When a ``topk_by_score`` column exists only the flagged rows are used.
    A score counts as zero when its magnitude is <= 1e-12 or it is
    missing/non-numeric. Frames lacking ``trade_date``, ``instrument`` or
    ``score`` produce the all-zero defaults instead of raising ``KeyError``.
    """
    stats: dict[str, Any] = {
        "signal_trade_dates": 0,
        "mean_topk_names_per_signal_day": 0.0,
        "min_topk_names_per_signal_day": 0.0,
        "pct_signal_days_lt_topk": 0.0,
        "all_zero_score_days": 0,
        "pct_all_zero_score_days": 0.0,
        "zero_score_row_rate": 0.0,
    }
    if seed_signal.empty or not {"trade_date", "instrument", "score"}.issubset(seed_signal.columns):
        return stats
    if "topk_by_score" in seed_signal.columns:
        topk_rows = seed_signal[_safe_bool(seed_signal["topk_by_score"])]
    else:
        topk_rows = seed_signal
    if topk_rows.empty:
        return stats

    def is_zero(scores: pd.Series) -> pd.Series:
        return pd.to_numeric(scores, errors="coerce").fillna(0.0).abs() <= 1e-12

    per_trade_date = topk_rows.groupby("trade_date", as_index=False).agg(
        topk_names=("instrument", "count"),
        zero_score_names=("score", lambda s: int(is_zero(s).sum())),
    )
    counts = pd.to_numeric(per_trade_date["topk_names"], errors="coerce")
    all_zero = per_trade_date["zero_score_names"] == per_trade_date["topk_names"]
    stats["signal_trade_dates"] = int(len(per_trade_date))
    stats["mean_topk_names_per_signal_day"] = float(counts.mean())
    stats["min_topk_names_per_signal_day"] = float(counts.min())
    stats["pct_signal_days_lt_topk"] = float((counts < top_k).mean()) if top_k > 0 else 0.0
    stats["all_zero_score_days"] = int(all_zero.sum())
    stats["pct_all_zero_score_days"] = float(all_zero.mean())
    stats["zero_score_row_rate"] = float(is_zero(topk_rows["score"]).mean())
    return stats


def _compute_coverage_summary(case_records: list[dict[str, Any]]) -> pd.DataFrame:
    """Build one coverage/sparsity row per (case, seed) from each case's detail CSVs.

    For each case directory this reads ``portfolio_daily.csv``,
    ``signal_selection_daily.csv`` and ``rebalance_plan.csv``. Seed names are
    the union of ``seed_name`` values across those files, and each row
    combines portfolio, rebalance-plan and signal statistics for one seed.
    Missing files or columns degrade to default values rather than raising.
    """
    rows: list[dict[str, Any]] = []
    for record in case_records:
        case_dir = Path(record["case_dir"])
        manifest = _load_case_manifest(case_dir)
        # Prefer the runner's recorded top_k. Fall back to the record itself
        # (case records carry the case params, including top_k) when the
        # manifest is missing or has no top_k.
        top_k = int(manifest.get("top_k") or record.get("top_k") or 0)
        portfolio_df = _read_csv(case_dir / "portfolio_daily.csv")
        signal_df = _read_csv(case_dir / "signal_selection_daily.csv")
        plan_df = _read_csv(case_dir / "rebalance_plan.csv")

        seed_names: set[str] = set()
        for frame in (portfolio_df, signal_df, plan_df):
            if "seed_name" in frame.columns and not frame.empty:
                seed_names.update(frame["seed_name"].dropna().astype(str).unique().tolist())

        for seed_name in sorted(seed_names):
            row: dict[str, Any] = {
                "case_id": record["case_id"],
                "case_group": record["case_group"],
                "case_description": record["description"],
                "seed_name": seed_name,
                "top_k": top_k,
            }
            row.update(_portfolio_coverage(_rows_for_seed(portfolio_df, seed_name)))
            row.update(_plan_coverage(_rows_for_seed(plan_df, seed_name), top_k))
            row.update(_signal_coverage(_rows_for_seed(signal_df, seed_name), top_k))
            rows.append(row)
    return pd.DataFrame(rows)
def _write_outputs(output_root: Path, case_records: list[dict[str, Any]]) -> None:
    """Write the matrix-level CSV files into *output_root*.

    ``matrix_cases.csv`` lists the case records themselves. Each per-case CSV
    named in ``merge_plan`` is concatenated across cases into a ``merged_*``
    file. ``coverage_sparsity_summary.csv`` holds per-seed coverage statistics.
    """
    pd.DataFrame(case_records).to_csv(output_root / "matrix_cases.csv", index=False)

    # (per-case source file, matrix-level destination file)
    merge_plan = (
        ("summary.csv", "merged_summary.csv"),
        ("trials.csv", "merged_trials.csv"),
        ("summary_yearly.csv", "merged_summary_yearly.csv"),
        ("trials_yearly.csv", "merged_trials_yearly.csv"),
        ("aggregate_yearly.csv", "merged_aggregate_yearly.csv"),
    )
    # Every frame is built before any merged file is written, so an error
    # while merging leaves the previous merged files untouched.
    outputs = {destination: _merge_case_frames(case_records, source) for source, destination in merge_plan}
    outputs["coverage_sparsity_summary.csv"] = _compute_coverage_summary(case_records)
    for destination, frame in outputs.items():
        frame.to_csv(output_root / destination, index=False)
def _case_metadata(case: MatrixCase, case_dir: Path) -> dict[str, Any]:
    """Flatten a case into a single record dict.

    Keys are ``case_id``, ``case_group``, ``description`` and ``case_dir``
    (as a string), followed by every entry of ``case.params``. A params key
    with the same name overrides the identity field.
    """
    metadata: dict[str, Any] = {
        "case_id": case.case_id,
        "case_group": case.case_group,
        "description": case.description,
        "case_dir": str(case_dir),
    }
    metadata.update(case.params)
    return metadata
def main() -> None:
    """Command-line entry point: run each selected case, then merge the results.

    Each case runs ``RUNNER_PATH`` in its own subprocess, writing to
    ``<output-root>/cases/<case_id>``. A failing case raises
    ``subprocess.CalledProcessError`` and stops the run. After all cases
    finish, per-case CSVs are merged into matrix-level files and
    ``matrix_manifest.json`` is written. With ``--dry-run`` the commands are
    only printed: no directories are created and no files are written, so an
    existing output root is left untouched.
    """
    import shlex  # local: only needed to print copy-pasteable commands

    parser = argparse.ArgumentParser(description="Run the standard alpha robustness matrix")
    parser.add_argument("--jsonl", required=True, help="Alpha pack JSONL file")
    parser.add_argument("--output-root", required=True, help="Directory that will hold all case outputs")
    parser.add_argument("--period", default="test", choices=["train", "val", "test"])
    parser.add_argument("--data-path", default=None, help="Optional daily_pv.h5 path")
    parser.add_argument("--backtest-workers", type=int, default=1, help="Worker count passed into each case run")
    parser.add_argument("--label-forward-days", type=int, default=5)
    parser.add_argument("--trade-guard-config", default=None)
    parser.add_argument("--case-filter", default="", help="Comma-separated case IDs for debugging / partial runs")
    parser.add_argument("--case-limit", type=int, default=0, help="Optional cap after filtering")
    parser.add_argument("--capture-detail-artifacts", action="store_true", help="Capture full detail artifacts for every case")
    parser.add_argument("--dry-run", action="store_true", help="Print planned commands without executing them")
    args = parser.parse_args()

    jsonl_path = Path(args.jsonl).expanduser().resolve()
    output_root = Path(args.output_root).expanduser().resolve()
    data_path = Path(args.data_path).expanduser().resolve() if args.data_path else None
    case_filter = {item.strip() for item in str(args.case_filter).split(",") if item.strip()}
    backtest_workers = max(int(args.backtest_workers), 1)
    label_forward_days = int(args.label_forward_days)

    all_cases = _make_standard_cases()
    # Surface typos in --case-filter instead of silently running fewer cases.
    unknown_ids = sorted(case_filter - {case.case_id for case in all_cases})
    if unknown_ids:
        print(f"warning: unknown case id(s) ignored: {', '.join(unknown_ids)}", file=sys.stderr, flush=True)
    cases = _filter_cases(all_cases, case_filter or None, int(args.case_limit))
    if not cases:
        raise SystemExit("No robustness cases selected.")

    if not args.dry_run:
        output_root.mkdir(parents=True, exist_ok=True)
    case_records: list[dict[str, Any]] = []
    print(f"jsonl={jsonl_path}", flush=True)
    print(f"output_root={output_root}", flush=True)
    print(f"n_cases={len(cases)}", flush=True)
    for idx, case in enumerate(cases, start=1):
        case_dir = output_root / "cases" / case.case_id
        cmd = _build_case_command(
            jsonl_path=jsonl_path,
            output_dir=case_dir,
            period=args.period,
            data_path=data_path,
            backtest_workers=backtest_workers,
            label_forward_days=label_forward_days,
            trade_guard_config=args.trade_guard_config,
            capture_detail_artifacts=bool(args.capture_detail_artifacts),
            case=case,
        )
        print(f"\n[{idx}/{len(cases)}] {case.case_id} :: {case.description}", flush=True)
        # shlex.join quotes arguments containing spaces/shell metacharacters.
        print(shlex.join(cmd), flush=True)
        if not args.dry_run:
            case_dir.mkdir(parents=True, exist_ok=True)
            subprocess.run(cmd, check=True)
        case_records.append(_case_metadata(case, case_dir))

    if args.dry_run:
        # Writing merged outputs here would overwrite results of a previous
        # real run (possibly with a filtered subset), so stop after planning.
        print("\nDry run: no cases executed and no matrix outputs written.", flush=True)
        return

    _write_outputs(output_root, case_records)
    (output_root / "matrix_manifest.json").write_text(
        json.dumps(
            {
                "jsonl": str(jsonl_path),
                "period": args.period,
                "data_path": str(data_path) if data_path else None,
                "capture_detail_artifacts": bool(args.capture_detail_artifacts),
                # Extra run settings recorded so the matrix can be reproduced.
                "label_forward_days": label_forward_days,
                "backtest_workers": backtest_workers,
                "trade_guard_config": args.trade_guard_config,
                "case_filter": sorted(case_filter),
                "case_limit": int(args.case_limit),
                "cases": case_records,
            },
            ensure_ascii=False,
            indent=2,
        )
        + "\n",
        encoding="utf-8",
    )
    print("\nSaved matrix outputs:", flush=True)
    print(output_root / "matrix_cases.csv", flush=True)
    print(output_root / "merged_summary.csv", flush=True)
    print(output_root / "merged_summary_yearly.csv", flush=True)
    print(output_root / "coverage_sparsity_summary.csv", flush=True)


if __name__ == "__main__":
    main()