amarorn / scripts /hyperparams.py
beAnalytic's picture
feat: sync main with feature/superbet-live-inplay
16c19b8 verified
Raw
History Blame Contribute Delete
4.8 kB
#!/usr/bin/env python3
"""Validação walk-forward em lote por edição da Copa."""
from __future__ import annotations
import argparse
import json
import math
from collections import Counter, defaultdict
from pathlib import Path
from ingest.fixtures.world_cup import edition_label, load_wc_fixtures
from models.wc_artifact import load_artifact
from pipelines.wc_validate import list_edition_matches, validate_historical_match
def brier(probs: dict[str, float], label: str) -> float:
    """Return the multi-class Brier score of a 1X2 forecast.

    Args:
        probs: Forecast probabilities; must contain the keys "1", "X" and "2".
        label: The observed outcome.

    Returns:
        The sum, over the three outcomes, of the squared difference between
        the forecast probability and the one-hot indicator of ``label``.

    Raises:
        KeyError: If ``probs`` lacks one of the three outcome keys.

    Note:
        A ``label`` outside "1"/"X"/"2" is not rejected; every outcome is then
        scored against an indicator of 0.
    """
    one_hot = dict.fromkeys("1X2", 0.0)
    one_hot[label] = 1.0

    score = 0.0
    for outcome in "1X2":
        error = probs[outcome] - one_hot[outcome]
        score += error ** 2
    return score
def log_loss(probs: dict[str, float], label: str, eps: float = 1e-12) -> float:
    """Return the negative log-likelihood of the observed outcome.

    Args:
        probs: Forecast probabilities keyed by outcome.
        label: The observed outcome.
        eps: Floor applied to the probability so the logarithm stays finite;
            also used when ``label`` is missing from ``probs``.

    Returns:
        ``-log(p)`` where ``p`` is the probability given to ``label``,
        clamped from below at ``eps``.
    """
    likelihood = probs.get(label, eps)
    if likelihood < eps:
        # Clamp so that a zero (or tiny) probability does not blow up the log.
        likelihood = eps
    return -math.log(likelihood)
def _rounded_mean(total: float, count: int) -> float | None:
    """Return ``total / count`` rounded to 4 places, or ``None`` if ``count`` is 0."""
    if count == 0:
        return None
    return round(total / count, 4)


def validate_edition(predictor, fixtures, season: int) -> dict:
    """Validate every match of one edition and summarise the outcome.

    Each match returned by ``list_edition_matches(fixtures, season)`` is run
    through ``validate_historical_match``; the prediction is compared with the
    actual result and scored with ``brier`` and ``log_loss``.

    Args:
        predictor: Model object forwarded unchanged to
            ``validate_historical_match``.
        fixtures: Fixture collection forwarded unchanged to
            ``list_edition_matches`` and ``validate_historical_match``.
        season: Edition identifier (e.g. the year).

    Returns:
        A JSON-serialisable dict with the keys ``label``, ``matches``,
        ``accuracy``, ``brier``, ``log_loss``, ``baseline_always_home``,
        ``baseline_majority_class``, ``pred_distribution``,
        ``actual_distribution``, ``by_phase`` (per-phase ``matches``,
        ``accuracy``, ``brier`` and ``log_loss``) and
        ``wrong_high_confidence`` (up to five misses, highest confidence
        first). When the edition has no matches, the averaged metrics are
        ``None`` and the collections are empty instead of raising.
    """
    rows: list[dict] = []
    by_phase: dict[str, dict] = defaultdict(
        lambda: {"n": 0, "correct": 0, "brier": 0.0, "log_loss": 0.0}
    )
    pred_dist: Counter[str] = Counter()
    actual_dist: Counter[str] = Counter()

    for m in list_edition_matches(fixtures, season):
        r = validate_historical_match(
            predictor,
            fixtures,
            season,
            match_id=m["match_id"],
        )
        actual = r["match"]["actual_result"]
        pred = r["prediction"]
        probs = {"1": r["prob_home"], "X": r["prob_draw"], "2": r["prob_away"]}
        phase = r["match"]["phase"]
        hit = pred == actual
        br = brier(probs, actual)
        ll = log_loss(probs, actual)

        rows.append(
            {
                "home": m["home_team"],
                "away": m["away_team"],
                "phase": phase,
                "actual": actual,
                "pred": pred,
                "correct": hit,
                "brier": br,
                "log_loss": ll,
                "confidence": r["confidence"],
            }
        )

        pred_dist[pred] += 1
        actual_dist[actual] += 1

        stats = by_phase[phase]
        stats["n"] += 1
        stats["correct"] += int(hit)
        stats["brier"] += br
        stats["log_loss"] += ll

    n = len(rows)

    # Misclassified matches, most confident mistakes first (ties keep match order).
    misses = [
        {
            "match": f"{r['home']} x {r['away']}",
            "actual": r["actual"],
            "pred": r["pred"],
            "confidence": round(r["confidence"], 3),
        }
        for r in rows
        if not r["correct"]
    ]
    misses.sort(key=lambda item: -item["confidence"])

    return {
        "label": edition_label(season),
        "matches": n,
        # _rounded_mean yields None for an empty edition instead of dividing by zero.
        "accuracy": _rounded_mean(sum(r["correct"] for r in rows), n),
        "brier": _rounded_mean(sum(r["brier"] for r in rows), n),
        "log_loss": _rounded_mean(sum(r["log_loss"] for r in rows), n),
        # Counter returns 0 for a label that never occurred.
        "baseline_always_home": _rounded_mean(actual_dist["1"], n),
        # default=0 keeps max() from raising on an empty distribution.
        "baseline_majority_class": _rounded_mean(
            max(actual_dist.values(), default=0), n
        ),
        "pred_distribution": dict(pred_dist),
        "actual_distribution": dict(actual_dist),
        "by_phase": {
            phase: {
                "matches": stats["n"],
                "accuracy": _rounded_mean(stats["correct"], stats["n"]),
                "brier": _rounded_mean(stats["brier"], stats["n"]),
                # Previously accumulated but never reported.
                "log_loss": _rounded_mean(stats["log_loss"], stats["n"]),
            }
            for phase, stats in sorted(by_phase.items())
        },
        "wrong_high_confidence": misses[:5],
    }
def main() -> None:
    """Run the batch validation from the command line.

    Parses ``--seasons`` (one or more ints, default 2018 and 2022) and
    ``--out`` (report path), loads the predictor artifact and the fixtures,
    validates each requested season with ``validate_edition``, then writes the
    combined report as JSON to ``--out`` and prints it to stdout.

    Raises:
        SystemExit: If ``load_artifact()`` returns ``None``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--seasons", nargs="+", type=int, default=[2018, 2022])
    cli.add_argument(
        "--out",
        type=Path,
        default=Path("data/lake/artifacts/wc_predictor/batch_validation.json"),
    )
    options = cli.parse_args()

    model = load_artifact()
    if model is None:
        raise SystemExit("Artefato WC inválido — rode: train-wc --force")

    wc_fixtures = load_wc_fixtures()

    # The manifest supplies the artifact metadata copied into the report header.
    manifest_text = Path("data/lake/artifacts/wc_predictor/manifest.json").read_text(
        encoding="utf-8"
    )
    manifest = json.loads(manifest_text)

    editions: dict = {}
    report = {
        "artifact_created_at": manifest.get("created_at"),
        "ensemble_weights": manifest.get("ensemble_weights"),
        "holdout_training": manifest.get("training_metrics", {}).get("holdout_accuracy"),
        "fixture_rows": len(wc_fixtures),
        "editions": editions,
    }

    for year in options.seasons:
        print(f"Validando {year}…", flush=True)
        editions[str(year)] = validate_edition(model, wc_fixtures, year)

    destination = options.out
    destination.parent.mkdir(parents=True, exist_ok=True)
    # The report is final here, so one serialisation serves both file and stdout.
    rendered = json.dumps(report, ensure_ascii=False, indent=2)
    destination.write_text(rendered, encoding="utf-8")
    print(rendered)
# Run the batch validation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()