# Spaces: Paused
# File size: 5,574 Bytes
# Revision: bc35a94
from __future__ import annotations
import json
import os
import time
from dataclasses import asdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any
@dataclass
class StepMetrics:
    """Aggregate statistics for one logged batch of rollout records.

    Instances are built by ``RewardLogger.log`` and serialised with
    ``dataclasses.asdict``, so every field must stay JSON-friendly.
    """

    # Step index supplied by the caller of ``RewardLogger.log``.
    step: int
    # Fraction of records whose ``terminated`` flag is truthy.
    solve_rate: float
    # Mean of the per-record ``reward`` values.
    reward_mean: float
    # Largest per-record ``reward`` value.
    reward_max: float
    # Mean per-record health (``best_health``, else ``grader_health``).
    health_mean: float
    # Mean of the per-record ``steps`` counts.
    steps_mean: float
    # Number of records seen per ``task_id``.
    task_mix: dict[str, int]
    # Seconds elapsed since the logger was constructed.
    wall_seconds: float
class RewardLogger:
    """Aggregate rollout records into per-step metrics and persist them.

    Every call to :meth:`log` appends one JSON line to
    ``<output_dir>/<run_name>.metrics.jsonl``, optionally mirrors the same
    payload to a Weights & Biases run, prints a one-line summary, and every
    ``transcript_sample_every`` steps writes a few raw records to
    ``<output_dir>/transcripts``.  :meth:`close` finishes the wandb run and,
    when ``hub_repo`` is set, uploads the metrics file to the Hugging Face Hub.

    All external integrations are best-effort: failures are printed and never
    propagated to the training loop.  The logger can also be used as a context
    manager, in which case :meth:`close` runs on exit.
    """

    def __init__(
        self,
        output_dir: str | Path,
        run_name: str = "hpc_grpo",
        wandb_project: str | None = None,
        hub_repo: str | None = None,
        transcript_sample_every: int = 5,
        transcript_max_samples: int = 2,
    ) -> None:
        """Create the output directories and optionally start a wandb run.

        Args:
            output_dir: Directory that receives the metrics file and the
                ``transcripts`` sub-directory; created if missing.
            run_name: Prefix of the metrics file and name of the wandb run.
            wandb_project: When truthy, ``wandb.init`` is attempted; any
                failure (including a missing package) disables wandb logging.
            hub_repo: When truthy, :meth:`close` uploads the metrics file to
                this Hub repository.
            transcript_sample_every: Write transcripts on steps divisible by
                this value; clamped to at least 1.
            transcript_max_samples: Number of records written per transcript
                file; clamped to at least 1.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.run_name = run_name
        self.jsonl_path = self.output_dir / f"{run_name}.metrics.jsonl"
        self.transcripts_dir = self.output_dir / "transcripts"
        self.transcripts_dir.mkdir(parents=True, exist_ok=True)
        self.transcript_sample_every = max(1, int(transcript_sample_every))
        self.transcript_max_samples = max(1, int(transcript_max_samples))
        self._start = time.time()
        self._wandb = None
        if wandb_project:
            try:
                import wandb  # type: ignore

                self._wandb = wandb.init(
                    project=wandb_project,
                    name=run_name,
                    dir=str(self.output_dir),
                    reinit=True,
                )
            except Exception as exc:
                # Deliberately broad: wandb is optional and must never stop a run.
                print(f"reward_logger wandb disabled {type(exc).__name__.lower()} {exc}")
                self._wandb = None
        self.hub_repo = hub_repo

    def __enter__(self) -> RewardLogger:
        """Return the logger itself so it can be bound in a ``with`` block."""
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Close the logger on exit; exceptions from the block still propagate."""
        self.close()

    @staticmethod
    def _mean(values: list[Any]) -> float:
        """Return the arithmetic mean of *values*, or 0.0 for an empty list."""
        return (sum(values) / len(values)) if values else 0.0

    @staticmethod
    def _number(record: Any, name: str) -> Any:
        """Return ``record.<name>``, substituting 0.0 when missing or falsy."""
        return getattr(record, name, 0.0) or 0.0

    @staticmethod
    def _record_health(record: Any) -> float:
        """Return a record's health: ``best_health``, else ``grader_health``.

        A falsy ``best_health`` (missing, ``None`` or 0) falls back to
        ``grader_health``; if that is missing or ``None`` too the result is
        0.0 instead of an ``AttributeError``/``TypeError``.
        """
        best = getattr(record, "best_health", None)
        if best:
            return float(best)
        return float(getattr(record, "grader_health", None) or 0.0)

    def log(self, step: int, records: list[Any]) -> StepMetrics:
        """Aggregate *records* for *step*, persist the result and return it.

        Args:
            step: Training step index; also used as the wandb step.
            records: Rollout records exposing ``reward``, ``steps`` and
                ``task_id``; ``terminated``, ``best_health`` and
                ``grader_health`` are read when present.

        Returns:
            The metrics row that was appended to the JSONL file.  An empty
            *records* list yields zeros and an empty task mix.
        """
        rewards = [float(r.reward) for r in records]
        health = [self._record_health(r) for r in records]
        steps = [int(r.steps) for r in records]
        solved = sum(1 for r in records if getattr(r, "terminated", False))
        mix: dict[str, int] = {}
        for r in records:
            mix[r.task_id] = mix.get(r.task_id, 0) + 1

        metrics = StepMetrics(
            step=step,
            solve_rate=solved / len(records) if records else 0.0,
            reward_mean=self._mean(rewards),
            reward_max=max(rewards) if rewards else 0.0,
            health_mean=self._mean(health),
            steps_mean=self._mean(steps),
            task_mix=mix,
            wall_seconds=time.time() - self._start,
        )
        payload = asdict(metrics)
        with self.jsonl_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload) + "\n")

        if self._wandb is not None:
            try:
                self._wandb.log(payload, step=step)
            except Exception as exc:
                # Best-effort mirror; the JSONL file is the source of truth.
                print(f"reward_logger wandb log failed {type(exc).__name__.lower()} {exc}")

        print(
            f"metrics step {step} solve_rate {metrics.solve_rate:.2f} "
            f"reward_mean {metrics.reward_mean:.2f} health_mean {metrics.health_mean:.2f} "
            f"steps_mean {metrics.steps_mean:.1f} mix {mix}"
        )

        # Judges' guide: "sample outputs frequently and inspect them". Write a
        # couple of transcripts to disk every few steps so reward hacking is
        # catchable by a human reviewer and so tensorboard text panels have
        # something to show.
        if step % self.transcript_sample_every == 0:
            self._write_transcript_sample(step, records)
        return metrics

    def _write_transcript_sample(self, step: int, records: list[Any]) -> None:
        """Write the first ``transcript_max_samples`` records to a JSONL file.

        The file is ``transcripts/step_<step:05d>.jsonl`` and is overwritten
        if it already exists.  Missing or ``None`` numeric attributes are
        written as zero; values JSON cannot encode are stringified.
        Nothing is written when *records* is empty.
        """
        if not records:
            return
        sample_path = self.transcripts_dir / f"step_{step:05d}.jsonl"
        with sample_path.open("w", encoding="utf-8") as f:
            for r in records[: self.transcript_max_samples]:
                payload = {
                    "task_id": getattr(r, "task_id", ""),
                    "reward": float(self._number(r, "reward")),
                    "last_reward": float(self._number(r, "last_reward")),
                    "steps": int(self._number(r, "steps")),
                    "grader_health": float(self._number(r, "grader_health")),
                    "best_health": float(self._number(r, "best_health")),
                    "terminated": bool(getattr(r, "terminated", False)),
                    "truncated": bool(getattr(r, "truncated", False)),
                    "transcript": getattr(r, "transcript", None) or [],
                }
                f.write(json.dumps(payload, default=str) + "\n")

    def close(self) -> None:
        """Finish the wandb run (once) and push metrics to the Hub if configured.

        The wandb handle is dropped after the first call so repeated calls do
        not finish the run again; the Hub upload is repeated on every call.
        """
        run, self._wandb = self._wandb, None
        if run is not None:
            try:
                run.finish()
            except Exception as exc:
                # Report instead of silently swallowing, like the other paths.
                print(f"reward_logger wandb finish failed {type(exc).__name__.lower()} {exc}")
        if self.hub_repo:
            self._push_to_hub()

    def _push_to_hub(self) -> None:
        """Upload the metrics file to ``hub_repo`` under ``runs/``; best-effort.

        Uses the ``HF_TOKEN`` environment variable for authentication.  Any
        failure, including a missing ``huggingface_hub`` package, is printed
        and suppressed.
        """
        if not self.jsonl_path.exists():
            # Nothing was ever logged, so there is no file to upload.
            print(f"reward_logger hub push skipped no metrics file {self.jsonl_path}")
            return
        try:
            from huggingface_hub import HfApi  # type: ignore

            api = HfApi(token=os.environ.get("HF_TOKEN"))
            api.upload_file(
                path_or_fileobj=str(self.jsonl_path),
                path_in_repo=f"runs/{self.jsonl_path.name}",
                repo_id=self.hub_repo,
                repo_type="model",
            )
            print(f"reward_logger pushed metrics to hub {self.hub_repo}")
        except Exception as exc:
            print(f"reward_logger hub push failed {type(exc).__name__.lower()} {exc}")
|