from __future__ import annotations

import json
import os
import time
from dataclasses import asdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any


@dataclass
class StepMetrics:
    """Aggregate statistics computed by ``RewardLogger.log`` for one step."""

    step: int
    # Fraction of records whose ``terminated`` attribute is truthy.
    solve_rate: float
    reward_mean: float
    reward_max: float
    # Mean of each record's ``best_health``, falling back to ``grader_health``.
    health_mean: float
    steps_mean: float
    # Number of records seen per ``task_id``.
    task_mix: dict[str, int]
    # Seconds elapsed since the logger was constructed.
    wall_seconds: float


def _mean(values: list[float]) -> float:
    """Return the arithmetic mean of *values*, or 0.0 when it is empty."""
    return sum(values) / len(values) if values else 0.0


class RewardLogger:
    """Write per-step rollout metrics to JSONL, stdout and optionally wandb.

    Metrics are appended to ``<output_dir>/<run_name>.metrics.jsonl``. Every
    ``transcript_sample_every`` steps a few raw records are also dumped under
    ``<output_dir>/transcripts`` for manual inspection. On ``close`` the
    metrics file is uploaded to ``hub_repo`` when one was given.
    """

    def __init__(
        self,
        output_dir: str | Path,
        run_name: str = "hpc_grpo",
        wandb_project: str | None = None,
        hub_repo: str | None = None,
        transcript_sample_every: int = 5,
        transcript_max_samples: int = 2,
    ) -> None:
        """Create the output directories and, if requested, start a wandb run.

        Args:
            output_dir: Directory for the metrics file and transcripts;
                created if missing.
            run_name: Prefix of the metrics file and name of the wandb run.
            wandb_project: wandb project to log to; ``None`` disables wandb.
            hub_repo: Repo id the metrics file is uploaded to on ``close``;
                ``None`` disables the upload.
            transcript_sample_every: Write transcript samples on steps that
                are a multiple of this value (clamped to at least 1).
            transcript_max_samples: Records written per sample file
                (clamped to at least 1).
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.run_name = run_name
        self.jsonl_path = self.output_dir / f"{run_name}.metrics.jsonl"
        self.transcripts_dir = self.output_dir / "transcripts"
        self.transcripts_dir.mkdir(parents=True, exist_ok=True)
        self.transcript_sample_every = max(1, int(transcript_sample_every))
        self.transcript_max_samples = max(1, int(transcript_max_samples))
        self._start = time.time()
        self._wandb = None
        if wandb_project:
            # wandb is optional: a missing package or a failed init only
            # disables wandb logging, it never aborts the run.
            try:
                import wandb  # type: ignore

                self._wandb = wandb.init(
                    project=wandb_project,
                    name=run_name,
                    dir=str(self.output_dir),
                    reinit=True,
                )
            except Exception as exc:
                print(f"reward_logger wandb disabled {type(exc).__name__.lower()} {exc}")
                self._wandb = None
        self.hub_repo = hub_repo

    def log(self, step: int, records: list[Any]) -> StepMetrics:
        """Aggregate *records* into a ``StepMetrics`` and emit it everywhere.

        Each record must expose ``reward``, ``steps`` and ``task_id``;
        ``terminated``, ``best_health`` and ``grader_health`` are optional.
        An empty *records* list yields all-zero metrics.

        Args:
            step: Step index, also used as the wandb step.
            records: Rollout records for this step.

        Returns:
            The metrics that were written.
        """
        rewards = [float(r.reward) for r in records]
        # Prefer best_health; fall back to grader_health, then to 0.0. Both
        # lookups are defaulted so a record missing either field does not
        # raise, matching how _write_transcript_sample reads them.
        health = [
            float(
                getattr(r, "best_health", None)
                or getattr(r, "grader_health", None)
                or 0.0
            )
            for r in records
        ]
        steps = [int(r.steps) for r in records]
        solved = sum(1 for r in records if bool(getattr(r, "terminated", False)))
        mix: dict[str, int] = {}
        for r in records:
            mix[r.task_id] = mix.get(r.task_id, 0) + 1

        metrics = StepMetrics(
            step=step,
            solve_rate=solved / len(records) if records else 0.0,
            reward_mean=_mean(rewards),
            reward_max=max(rewards) if rewards else 0.0,
            health_mean=_mean(health),
            steps_mean=_mean(steps),
            task_mix=mix,
            wall_seconds=time.time() - self._start,
        )
        payload = asdict(metrics)
        with self.jsonl_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload) + "\n")
        if self._wandb is not None:
            try:
                self._wandb.log(payload, step=step)
            except Exception as exc:
                print(f"reward_logger wandb log failed {type(exc).__name__.lower()} {exc}")
        print(
            f"metrics step {step} solve_rate {metrics.solve_rate:.2f} "
            f"reward_mean {metrics.reward_mean:.2f} health_mean {metrics.health_mean:.2f} "
            f"steps_mean {metrics.steps_mean:.1f} mix {mix}"
        )
        # judges' guide: "sample outputs frequently and inspect them". write a
        # couple of transcripts to disk every few steps so reward hacking is
        # catchable by a human reviewer and so tensorboard text panels have
        # something to show.
        if step % self.transcript_sample_every == 0:
            self._write_transcript_sample(step, records)
        return metrics

    def _write_transcript_sample(self, step: int, records: list[Any]) -> None:
        """Dump the first ``transcript_max_samples`` records as JSON lines.

        Writes ``transcripts/step_<step>.jsonl``, overwriting any existing
        file for that step. Does nothing when *records* is empty.
        """
        if not records:
            return
        sample_path = self.transcripts_dir / f"step_{step:05d}.jsonl"
        with sample_path.open("w", encoding="utf-8") as f:
            for r in records[: self.transcript_max_samples]:
                transcript = getattr(r, "transcript", None) or []
                payload = {
                    "task_id": getattr(r, "task_id", ""),
                    "reward": float(getattr(r, "reward", 0.0)),
                    "last_reward": float(getattr(r, "last_reward", 0.0)),
                    "steps": int(getattr(r, "steps", 0)),
                    "grader_health": float(getattr(r, "grader_health", 0.0)),
                    "best_health": float(getattr(r, "best_health", 0.0)),
                    "terminated": bool(getattr(r, "terminated", False)),
                    "truncated": bool(getattr(r, "truncated", False)),
                    "transcript": transcript,
                }
                # default=str keeps the dump going when a transcript holds
                # values json cannot encode natively.
                f.write(json.dumps(payload, default=str) + "\n")

    def close(self) -> None:
        """Finish the wandb run, if any, and push metrics to the hub if set."""
        if self._wandb is not None:
            try:
                self._wandb.finish()
            except Exception as exc:
                # Best-effort shutdown, but report the failure like every
                # other wandb error path instead of hiding it.
                print(f"reward_logger wandb finish failed {type(exc).__name__.lower()} {exc}")
            finally:
                # Drop the handle so later log()/close() calls do not touch
                # a run that has already been finished.
                self._wandb = None
        if self.hub_repo:
            self._push_to_hub()

    def _push_to_hub(self) -> None:
        """Upload the metrics file to ``runs/`` in ``hub_repo``.

        Reads the token from the ``HF_TOKEN`` environment variable. Any
        failure, including a missing ``huggingface_hub`` package, is printed
        and otherwise ignored.
        """
        try:
            from huggingface_hub import HfApi  # type: ignore

            api = HfApi(token=os.environ.get("HF_TOKEN"))
            api.upload_file(
                path_or_fileobj=str(self.jsonl_path),
                path_in_repo=f"runs/{self.jsonl_path.name}",
                repo_id=self.hub_repo,
                repo_type="model",
            )
            print(f"reward_logger pushed metrics to hub {self.hub_repo}")
        except Exception as exc:
            print(f"reward_logger hub push failed {type(exc).__name__.lower()} {exc}")