Spaces:
Sleeping
Sleeping
"""
Episode logger and metrics tracker.

Records every step and episode so you can:
- See exactly what the model chose vs. what the scripted oracle policy would have chosen
- Analyze failure patterns across episodes
- Export training data for offline analysis
- Feed live stats to the /metrics endpoint and the visualization
"""
| import json | |
| import os | |
| import time | |
| from collections import deque | |
| from dataclasses import dataclass, field, asdict | |
| from typing import Optional | |
@dataclass
class StepLog:
    """Record of a single environment step.

    Stores the action the model took, the result string the environment
    returned, the reward, and the scripted oracle's choice for the same
    state (if any), so model and oracle can be compared after the fact.
    """

    step: int
    action: str
    result: str
    reward: float
    cumulative_reward: float
    valid_actions: list[str]
    oracle_action: Optional[str]  # what scripted policy would do (None = no oracle for this step)
    chose_oracle: Optional[bool]  # did model match oracle? (None when there was no oracle action)
    holding: Optional[str]
    n_failures_so_far: int
    n_subgoals_done: int
@dataclass
class EpisodeLog:
    """Full record of one episode: its setup, per-step logs, and outcome.

    The outcome fields keep their zero/empty defaults until ``finish()`` is
    called, which derives them from ``steps``.
    """

    episode_id: int
    instruction: str
    difficulty: str
    n_objects: int
    n_blockers: int
    n_targets: int
    had_mid_task_change: bool
    # Quoted so the class does not depend on StepLog being defined first.
    steps: list["StepLog"] = field(default_factory=list)

    # Outcome (filled in by finish())
    success: bool = False
    total_reward: float = 0.0
    total_steps: int = 0
    failure_types: list[str] = field(default_factory=list)  # unique failure result codes, first-seen order
    repeated_failures: int = 0
    oracle_agreement_rate: float = 0.0  # fraction of oracle-labelled steps where model == oracle

    # Timing (time.time() timestamps)
    start_time: float = field(default_factory=time.time)
    end_time: Optional[float] = None

    def finish(self, success: bool) -> None:
        """Close the episode and compute its summary statistics.

        Args:
            success: Whether the episode reached its goal.

        A step counts as a failure when its ``result`` does not start with
        ``"SUCCESS"``. ``failure_types`` lists each distinct failure result
        in order of first occurrence. ``repeated_failures`` counts failure
        steps whose result already occurred earlier in the episode.
        Successful steps never count as repeats. ``oracle_agreement_rate``
        is computed only over steps that had an oracle action. If there were
        none, the field is left unchanged.
        """
        self.success = success
        self.total_steps = len(self.steps)
        self.total_reward = sum(s.reward for s in self.steps)
        self.end_time = time.time()

        failures = [s.result for s in self.steps if not s.result.startswith("SUCCESS")]
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # exported JSON is deterministic (set iteration order is not).
        self.failure_types = list(dict.fromkeys(failures))
        # Every failure beyond the first occurrence of its result is a repeat.
        self.repeated_failures = len(failures) - len(self.failure_types)

        oracle_steps = [s for s in self.steps if s.oracle_action is not None]
        if oracle_steps:
            agreed = sum(1 for s in oracle_steps if s.chose_oracle)
            self.oracle_agreement_rate = agreed / len(oracle_steps)

    def to_jsonl(self) -> str:
        """Serialize the episode, including every step, as a single JSON line."""
        return json.dumps(asdict(self))
class MetricsTracker:
    """
    Rolling statistics across episodes.
    Feeds the /metrics endpoint and the curriculum manager.

    Only the most recent ``max_history`` episodes are kept. The rolling and
    breakdown statistics look at the last ``window`` of those.
    """

    # Keyword found in an upper-cased failure code -> taxonomy bucket.
    # Checked in order; the first match wins, anything else is "other".
    _TAXONOMY_RULES = (
        ("INVALID", "invalid"),
        ("BLOCK", "blocked"),
        ("EMPTY", "empty"),
        ("SLIP", "slip"),
    )

    def __init__(self, window: int = 20, max_history: int = 200):
        self.window = window
        self._history: deque[EpisodeLog] = deque(maxlen=max_history)
        self._episode_count = 0  # lifetime count; unlike _history, never truncated
        self._current_difficulty = "easy"

    def _recent(self) -> list:
        """Return the last ``window`` recorded episodes, oldest first."""
        if self.window <= 0:
            # Guard: a [-0:] slice would return the *entire* history.
            return []
        return list(self._history)[-self.window:]

    def record(self, ep: "EpisodeLog"):
        """Add a finished episode to the history and bump the lifetime count."""
        self._history.append(ep)
        self._episode_count += 1

    def rolling_success_rate(self) -> float:
        """Fraction of recent episodes that succeeded (0.0 if none)."""
        recent = self._recent()
        if not recent:
            return 0.0
        return sum(1 for e in recent if e.success) / len(recent)

    def rolling_avg_reward(self) -> float:
        """Mean total reward over recent episodes (0.0 if none)."""
        recent = self._recent()
        if not recent:
            return 0.0
        return sum(e.total_reward for e in recent) / len(recent)

    def rolling_avg_steps(self) -> float:
        """Mean episode length over recent episodes (0.0 if none)."""
        recent = self._recent()
        if not recent:
            return 0.0
        return sum(e.total_steps for e in recent) / len(recent)

    def oracle_agreement_rate(self) -> float:
        """Mean per-episode oracle agreement over recent episodes.

        Episodes with no oracle-labelled steps are skipped, since their rate
        is meaningless. Episodes where the model never matched the oracle
        still count as 0.0. Filtering on ``rate > 0`` would drop them and
        inflate the average.
        """
        rates = [
            e.oracle_agreement_rate
            for e in self._recent()
            if any(s.oracle_action is not None for s in e.steps)
        ]
        return sum(rates) / len(rates) if rates else 0.0

    def failure_breakdown(self) -> dict[str, int]:
        """Count how often each failure type appears in recent episodes.

        Each episode contributes once per distinct code in its
        ``failure_types``. The result is ordered most frequent first.
        """
        counts: dict[str, int] = {}
        for ep in self._recent():
            for ft in ep.failure_types:
                counts[ft] = counts.get(ft, 0) + 1
        return dict(sorted(counts.items(), key=lambda kv: -kv[1]))

    def failure_taxonomy(self) -> dict[str, int]:
        """Group ``failure_breakdown()`` counts into coarse keyword buckets."""
        tax = {bucket: 0 for _, bucket in self._TAXONOMY_RULES}
        tax["other"] = 0
        for code, n in self.failure_breakdown().items():
            upper = code.upper()
            bucket = next((b for kw, b in self._TAXONOMY_RULES if kw in upper), "other")
            tax[bucket] += n
        return tax

    def reward_curve(self) -> list[float]:
        """Per-episode total reward for plotting."""
        return [e.total_reward for e in self._history]

    def success_curve(self) -> list[int]:
        """Per-episode 0/1 for plotting."""
        return [int(e.success) for e in self._history]

    def to_dict(self) -> dict:
        """Snapshot of all tracked metrics, rounded for display."""
        return {
            "total_episodes": self._episode_count,
            "current_difficulty": self._current_difficulty,
            "rolling_success_rate": round(self.rolling_success_rate(), 3),
            "rolling_avg_reward": round(self.rolling_avg_reward(), 2),
            "rolling_avg_steps": round(self.rolling_avg_steps(), 1),
            "oracle_agreement_rate": round(self.oracle_agreement_rate(), 3),
            "failure_breakdown": self.failure_breakdown(),
            "failure_taxonomy": self.failure_taxonomy(),
            "reward_curve": self.reward_curve()[-50:],  # last 50 for the chart
            "success_curve": self.success_curve()[-50:],
        }
class EpisodeLogger:
    """
    Manages per-episode logging and writes to JSONL.

    Usage: ``begin_episode()``, then ``log_step()`` once per step, then
    ``end_episode()``. Finished episodes are recorded in ``self.metrics``
    and, when ``export_path`` is set, appended to that file as one JSON
    line each.
    """

    def __init__(self, export_path: Optional[str] = None, max_history: int = 200):
        self.metrics = MetricsTracker(max_history=max_history)
        self._current: Optional[EpisodeLog] = None
        self._export_path = export_path
        if export_path:
            parent = os.path.dirname(export_path)
            # A bare filename (e.g. "episodes.jsonl") has no directory part,
            # and os.makedirs("") would raise FileNotFoundError.
            if parent:
                os.makedirs(parent, exist_ok=True)

    def begin_episode(self, episode_id: int, instruction: str, difficulty: str,
                      n_objects: int, n_blockers: int, n_targets: int,
                      had_mid_task_change: bool = False):
        """Start a new episode record.

        Any episode still in progress is discarded without being recorded
        or exported.
        """
        self._current = EpisodeLog(
            episode_id=episode_id,
            instruction=instruction,
            difficulty=difficulty,
            n_objects=n_objects,
            n_blockers=n_blockers,
            n_targets=n_targets,
            had_mid_task_change=had_mid_task_change,
        )

    def log_step(self, step: int, action: str, result: str, reward: float,
                 cumulative_reward: float, valid_actions: list[str],
                 oracle_action: Optional[str], holding: Optional[str],
                 n_failures: int, n_subgoals: int):
        """Append one step to the current episode.

        Calls made while no episode is active are ignored (best effort).
        ``oracle_action=None`` means no oracle label for this step, and
        ``chose_oracle`` is then recorded as None.
        """
        if self._current is None:
            return
        self._current.steps.append(StepLog(
            step=step,
            action=action,
            result=result,
            reward=reward,
            cumulative_reward=cumulative_reward,
            # Snapshot so later mutation of the caller's list can't rewrite the log.
            valid_actions=list(valid_actions),
            oracle_action=oracle_action,
            # Use `is not None` (not truthiness) so this agrees with how
            # oracle-labelled steps are selected when the episode is finished.
            chose_oracle=(action == oracle_action) if oracle_action is not None else None,
            holding=holding,
            n_failures_so_far=n_failures,
            n_subgoals_done=n_subgoals,
        ))

    def end_episode(self, success: bool) -> EpisodeLog:
        """Finish the current episode, record it, optionally export it, and return it.

        Raises:
            RuntimeError: if no episode is active.
        """
        if self._current is None:
            raise RuntimeError("No active episode")
        self._current.finish(success)
        ep = self._current
        self._current = None
        self.metrics.record(ep)
        if self._export_path:
            with open(self._export_path, "a", encoding="utf-8") as f:
                f.write(ep.to_jsonl() + "\n")
        return ep