#!/usr/bin/env python3
"""Batch walk-forward validation per World Cup edition."""

from __future__ import annotations

import argparse
import json
import math
from collections import Counter, defaultdict
from pathlib import Path

from ingest.fixtures.world_cup import edition_label, load_wc_fixtures
from models.wc_artifact import load_artifact
from pipelines.wc_validate import list_edition_matches, validate_historical_match


def brier(probs: dict[str, float], label: str) -> float:
    """Return the multi-class Brier score of a 1X2 forecast.

    Args:
        probs: Probabilities keyed by ``"1"``, ``"X"`` and ``"2"``; all three
            keys must be present.
        label: The observed outcome.

    Returns:
        The sum of squared differences between ``probs`` and the one-hot
        encoding of ``label`` over the three outcomes.

    Raises:
        KeyError: If ``probs`` lacks one of the three outcome keys.
    """
    target = {"1": 0.0, "X": 0.0, "2": 0.0}
    target[label] = 1.0
    return sum((probs[k] - target[k]) ** 2 for k in "1X2")


def log_loss(probs: dict[str, float], label: str, eps: float = 1e-12) -> float:
    """Return the negative log-probability assigned to the observed outcome.

    Args:
        probs: Probabilities keyed by outcome.
        label: The observed outcome.
        eps: Floor applied to the probability so the logarithm stays finite;
            also used when ``label`` is missing from ``probs``.
    """
    return -math.log(max(probs.get(label, eps), eps))


def _rounded_mean(total: float, count: int) -> float | None:
    """Return ``total / count`` rounded to 4 places, or ``None`` if ``count`` is 0."""
    if count == 0:
        return None
    return round(total / count, 4)


def validate_edition(predictor, fixtures, season: int) -> dict:
    """Validate every match of one edition and aggregate the metrics.

    Each match returned by ``list_edition_matches`` is passed to
    ``validate_historical_match``; the predicted and actual results are then
    scored with :func:`brier` and :func:`log_loss`.

    Args:
        predictor: Loaded predictor, forwarded to ``validate_historical_match``.
        fixtures: Fixture data, forwarded to ``list_edition_matches`` and
            ``validate_historical_match``.
        season: Edition identifier (e.g. ``2018``).

    Returns:
        A JSON-serialisable report with overall accuracy, Brier score and
        log-loss, two naive baselines, predicted/actual outcome distributions,
        per-phase metrics and the five highest-confidence wrong predictions.
        When the edition has no matches, ``matches`` is ``0`` and every
        averaged metric is ``None`` instead of raising ``ZeroDivisionError``.
    """
    matches = list_edition_matches(fixtures, season)
    rows: list[dict] = []
    # Running sums per phase; turned into means when the report is built.
    by_phase: dict[str, dict] = defaultdict(
        lambda: {"n": 0, "correct": 0, "brier": 0.0, "log_loss": 0.0}
    )
    pred_dist: Counter[str] = Counter()
    actual_dist: Counter[str] = Counter()

    for m in matches:
        r = validate_historical_match(
            predictor,
            fixtures,
            season,
            match_id=m["match_id"],
        )
        actual = r["match"]["actual_result"]
        pred = r["prediction"]
        probs = {"1": r["prob_home"], "X": r["prob_draw"], "2": r["prob_away"]}
        phase = r["match"]["phase"]
        br = brier(probs, actual)
        ll = log_loss(probs, actual)
        correct = pred == actual
        rows.append(
            {
                "home": m["home_team"],
                "away": m["away_team"],
                "phase": phase,
                "actual": actual,
                "pred": pred,
                "correct": correct,
                "brier": br,
                "log_loss": ll,
                "confidence": r["confidence"],
            }
        )
        pred_dist[pred] += 1
        actual_dist[actual] += 1
        stats = by_phase[phase]
        stats["n"] += 1
        stats["correct"] += int(correct)
        stats["brier"] += br
        stats["log_loss"] += ll

    n = len(rows)
    wrong = [
        {
            "match": f"{r['home']} x {r['away']}",
            "actual": r["actual"],
            "pred": r["pred"],
            "confidence": round(r["confidence"], 3),
        }
        for r in rows
        if not r["correct"]
    ]
    return {
        "label": edition_label(season),
        "matches": n,
        "accuracy": _rounded_mean(sum(r["correct"] for r in rows), n),
        "brier": _rounded_mean(sum(r["brier"] for r in rows), n),
        "log_loss": _rounded_mean(sum(r["log_loss"] for r in rows), n),
        "baseline_always_home": _rounded_mean(
            sum(1 for r in rows if r["actual"] == "1"), n
        ),
        # default=0 keeps max() from raising on an edition with no matches.
        "baseline_majority_class": _rounded_mean(
            max(actual_dist.values(), default=0), n
        ),
        "pred_distribution": dict(pred_dist),
        "actual_distribution": dict(actual_dist),
        "by_phase": {
            phase: {
                "matches": stats["n"],
                "accuracy": _rounded_mean(stats["correct"], stats["n"]),
                "brier": _rounded_mean(stats["brier"], stats["n"]),
                "log_loss": _rounded_mean(stats["log_loss"], stats["n"]),
            }
            for phase, stats in sorted(by_phase.items())
        },
        # Stable descending sort: ties keep their original match order.
        "wrong_high_confidence": sorted(
            wrong,
            key=lambda item: item["confidence"],
            reverse=True,
        )[:5],
    }


def main() -> None:
    """Validate the requested editions and write/print the JSON report.

    Command-line options:
        --seasons: One or more edition years (default ``2018 2022``).
        --out: Destination path of the JSON report.

    Raises:
        SystemExit: If ``load_artifact`` returns ``None``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seasons", nargs="+", type=int, default=[2018, 2022])
    parser.add_argument(
        "--out",
        type=Path,
        default=Path("data/lake/artifacts/wc_predictor/batch_validation.json"),
    )
    args = parser.parse_args()

    predictor = load_artifact()
    if predictor is None:
        raise SystemExit("Artefato WC inválido — rode: train-wc --force")

    fixtures = load_wc_fixtures()
    manifest_path = Path("data/lake/artifacts/wc_predictor/manifest.json")
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))

    report = {
        "artifact_created_at": manifest.get("created_at"),
        "ensemble_weights": manifest.get("ensemble_weights"),
        "holdout_training": manifest.get("training_metrics", {}).get(
            "holdout_accuracy"
        ),
        "fixture_rows": len(fixtures),
        "editions": {},
    }
    for season in args.seasons:
        print(f"Validando {season}…", flush=True)
        report["editions"][str(season)] = validate_edition(
            predictor, fixtures, season
        )

    # Serialise once; the same text goes to the file and to stdout.
    rendered = json.dumps(report, ensure_ascii=False, indent=2)
    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text(rendered, encoding="utf-8")
    print(rendered)


if __name__ == "__main__":
    main()