Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import json | |
| import sys | |
| from dataclasses import asdict, dataclass, field | |
| from datetime import UTC, datetime | |
| from pathlib import Path | |
| from typing import Any, Protocol | |
| from config import settings | |
# Display label (Portuguese, shown to the user) for each training stage, keyed
# by the stage identifier.  Insertion order is the stage order.
STEP_LABELS: dict[str, str] = dict(
    logistic_regression="Regressão logística + calibração",
    dixon_coles="Dixon-Coles (rho)",
    collaborative_ensemble="Ensemble colaborativo",
    draw_model="Modelo de empate",
    persist_artifact="Salvando artefato",
)

# Ordered stage identifiers; a stage's position here (plus one) is the step
# number reported while training runs.
TRAIN_STEPS = tuple(STEP_LABELS)
# The decorator is required: instances are built with keyword arguments,
# serialised with ``dataclasses.asdict`` and ``metrics`` relies on
# ``field(default_factory=dict)``; none of that works on a plain class.
@dataclass
class TrainProgressState:
    """Serializable snapshot of a training run's progress.

    The reporter writes this structure to the JSON progress file and
    ``read_train_progress`` rebuilds it from that file.

    Attributes:
        status: Run state; the reporter in this module uses ``"idle"``,
            ``"running"``, ``"completed"`` and ``"failed"``.
        step: Identifier of the current stage, or ``None`` before any stage.
        step_label: Display label of the current stage.
        step_index: 1-based number of the current stage; ``0`` before any
            stage has started.
        total_steps: Number of stages, taken from ``TRAIN_STEPS``.
        detail: Free-text detail about the current activity.
        percent: Progress of the current stage on a 0-100 scale, or ``None``
            when unknown.
        started_at: ISO-8601 timestamp of when the run started.
        updated_at: ISO-8601 timestamp of the last persisted update.
        elapsed_sec: Seconds elapsed since the run started.
        metrics: Accumulated metrics reported by finished stages and by the
            final summary.
        error: Error message when the run failed.
    """

    status: str
    step: str | None = None
    step_label: str | None = None
    step_index: int = 0
    total_steps: int = len(TRAIN_STEPS)
    detail: str | None = None
    percent: float | None = None
    started_at: str | None = None
    updated_at: str | None = None
    elapsed_sec: float | None = None
    # A fresh dict per instance; a shared literal default would leak metrics
    # between runs.
    metrics: dict[str, Any] = field(default_factory=dict)
    error: str | None = None
| class TrainProgressReporter(Protocol): | |
| def start(self, fixtures: int) -> None: ... | |
| def step_start(self, step: str, detail: str | None = None) -> None: ... | |
| def step_progress(self, current: int, total: int, detail: str | None = None) -> None: ... | |
| def step_done(self, step: str, metrics: dict[str, Any] | None = None) -> None: ... | |
| def complete(self, summary: dict[str, Any]) -> None: ... | |
| def fail(self, error: str) -> None: ... | |
def progress_path() -> Path:
    """Return the path of the JSON progress file inside the artifact directory."""
    artifact_dir = settings.wc_artifact_dir
    return artifact_dir / "train_progress.json"
def read_train_progress() -> TrainProgressState | None:
    """Load the last persisted training progress, if any.

    Returns:
        The state stored in the progress file, or ``None`` when the file does
        not exist or its content cannot be turned into a
        ``TrainProgressState`` (invalid UTF-8, invalid JSON, or JSON whose
        shape does not match the state's fields).
    """
    try:
        # Read directly instead of checking ``exists()`` first: the file can
        # disappear between the check and the read.
        text = progress_path().read_text(encoding="utf-8")
    except FileNotFoundError:
        return None
    except UnicodeDecodeError:
        # The file holds non-ASCII text; a copy truncated inside a multi-byte
        # character is as unusable as truncated JSON, so treat it the same.
        return None

    try:
        raw = json.loads(text)
        # TypeError covers a non-mapping payload as well as unknown or
        # missing field names.
        return TrainProgressState(**raw)
    except (json.JSONDecodeError, TypeError):
        return None
| class NullTrainProgressReporter: | |
| def start(self, fixtures: int) -> None: | |
| return | |
| def step_start(self, step: str, detail: str | None = None) -> None: | |
| return | |
| def step_progress(self, current: int, total: int, detail: str | None = None) -> None: | |
| return | |
| def step_done(self, step: str, metrics: dict[str, Any] | None = None) -> None: | |
| return | |
| def complete(self, summary: dict[str, Any]) -> None: | |
| return | |
| def fail(self, error: str) -> None: | |
| return | |
| class WcTrainProgressReporter: | |
| def __init__(self, *, console: bool = False) -> None: | |
| self.console = console and sys.stdout.isatty() | |
| self._started: datetime | None = None | |
| self._state = TrainProgressState(status="idle") | |
| self._last_console_pct = -1 | |
| def _persist(self) -> None: | |
| self._state.updated_at = datetime.now(UTC).isoformat() | |
| if self._started is not None: | |
| self._state.elapsed_sec = round( | |
| (datetime.now(UTC) - self._started).total_seconds(), | |
| 1, | |
| ) | |
| path = progress_path() | |
| path.parent.mkdir(parents=True, exist_ok=True) | |
| path.write_text( | |
| json.dumps(asdict(self._state), ensure_ascii=False, indent=2), | |
| encoding="utf-8", | |
| ) | |
| def _println(self, message: str) -> None: | |
| if self.console: | |
| print(message, flush=True) | |
| def start(self, fixtures: int) -> None: | |
| self._started = datetime.now(UTC) | |
| self._last_console_pct = -1 | |
| self._state = TrainProgressState( | |
| status="running", | |
| started_at=self._started.isoformat(), | |
| detail=f"{fixtures} jogos no lake", | |
| ) | |
| self._persist() | |
| self._println(f"Treino WC iniciado ({fixtures} jogos)") | |
| def step_start(self, step: str, detail: str | None = None) -> None: | |
| idx = TRAIN_STEPS.index(step) if step in TRAIN_STEPS else 0 | |
| label = STEP_LABELS.get(step, step) | |
| self._state.step = step | |
| self._state.step_label = label | |
| self._state.step_index = idx + 1 | |
| self._state.detail = detail | |
| self._state.percent = None | |
| self._last_console_pct = -1 | |
| self._persist() | |
| self._println(f"[{idx + 1}/{len(TRAIN_STEPS)}] {label}…") | |
| def step_progress(self, current: int, total: int, detail: str | None = None) -> None: | |
| pct = round((current / total) * 100, 1) if total else 0.0 | |
| self._state.percent = pct | |
| self._state.detail = detail or f"{current}/{total}" | |
| self._persist() | |
| if not self.console or total <= 0: | |
| return | |
| bucket = int(pct // 5) * 5 | |
| if bucket != self._last_console_pct or current >= total: | |
| self._last_console_pct = bucket | |
| label = self._state.step_label or self._state.step or "progresso" | |
| self._println(f" {label}: {current}/{total} ({pct:.0f}%)") | |
| def step_done(self, step: str, metrics: dict[str, Any] | None = None) -> None: | |
| if metrics: | |
| self._state.metrics.update(metrics) | |
| self._state.percent = 100.0 | |
| self._persist() | |
| label = STEP_LABELS.get(step, step) | |
| self._println(f" {label}: concluído") | |
| def complete(self, summary: dict[str, Any]) -> None: | |
| self._state.status = "completed" | |
| self._state.percent = 100.0 | |
| self._state.detail = "Treino finalizado" | |
| self._state.metrics.update(summary) | |
| self._persist() | |
| elapsed = self._state.elapsed_sec | |
| elapsed_txt = f" em {elapsed:.0f}s" if elapsed is not None else "" | |
| self._println(f"Treino WC concluído{elapsed_txt}") | |
| def fail(self, error: str) -> None: | |
| self._state.status = "failed" | |
| self._state.error = error | |
| self._persist() | |
| self._println(f"Treino WC falhou: {error}") | |