support-ops-env / plot_rewards.py
raj921
fix: T4-safe 4bit compute, eval/audit resilience, async inference, plot_rewards CSV/table
54c5725
Raw
History Blame Contribute Delete
12 kB
#!/usr/bin/env python3
"""
Plot DriftShield training rewards from a CSV log.
Usage (while training is running or after)::
python plot_rewards.py # auto-find latest reward_log.csv
python plot_rewards.py outputs/*/reward_log.csv # specific file
python plot_rewards.py --live # refresh every 30s
python plot_rewards.py --table # ASCII table, no matplotlib
python plot_rewards.py --table --tail 30 # last 30 rows only
Supports **two CSV schemas** (auto-detected from the header row):
* **Legacy (v1)** — 6 columns::
episode, total_reward, routing_reward, reply_reward, grounding_reward, timestamp
(Legacy logs named the second metric ``field_reward``; it is stored as ``routing``
here to match v2's routing column.)
* **Extended (v2, current)** — 10 columns::
episode, task_id,
total_reward,
investigation, routing, reply_quality, groundedness, submission,
penalty_total, timestamp
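The plot always has four stacked panels (total / components / penalty /
rolling success-rate), which is what the guide asks for in §15: don't just
watch the overall reward — watch each column. Only the extended schema fills
all of them: a v1 log has no investigation, submission or penalty columns, so
those series are drawn as flat zeros.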
"""
from __future__ import annotations
import argparse
import csv
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional
# ----------------------------------------------------------------------
# CSV loading
# ----------------------------------------------------------------------
@dataclass
class RewardLog:
    """Column-oriented contents of one reward CSV.

    Every list gets exactly one entry per parsed data row, so index ``i``
    refers to the same episode in all of them.
    """

    # Detected CSV layout: "v1" (legacy) or "v2" (extended).
    schema: str
    # Episode numbers, in file order.
    episodes: List[int] = field(default_factory=list)
    # Task identifier per episode; the loader stores "" for v1 logs.
    task_ids: List[str] = field(default_factory=list)
    # Overall reward per episode.
    total: List[float] = field(default_factory=list)
    # Per-component rewards; components a v1 log lacks are loaded as 0.0.
    investigation: List[float] = field(default_factory=list)
    routing: List[float] = field(default_factory=list)
    reply: List[float] = field(default_factory=list)
    grounding: List[float] = field(default_factory=list)
    submission: List[float] = field(default_factory=list)
    # Penalty total per episode (0.0 for v1 logs).
    penalty: List[float] = field(default_factory=list)

    def n(self) -> int:
        """Return how many episodes have been recorded."""
        episode_count = len(self.episodes)
        return episode_count
def find_latest_csv() -> Optional[Path]:
    """Locate the reward log to use when no path was given.

    Searches ``outputs/*/reward_log.csv`` relative to the current working
    directory and returns the most recently modified match. With no match it
    falls back to ``reward_log.csv`` in the current directory, and returns
    ``None`` if that file does not exist either.
    """
    candidates = list(Path("outputs").glob("*/reward_log.csv"))
    if candidates:
        # max() keeps the first of several equally-new files, which is the
        # same pick a stable newest-first sort would put at index 0.
        return max(candidates, key=lambda candidate: candidate.stat().st_mtime)

    fallback = Path("reward_log.csv")
    return fallback if fallback.exists() else None
def _safe_float(val: str, default: float = 0.0) -> float:
    """Parse *val* as a float, returning *default* when it cannot be parsed."""
    try:
        parsed = float(val)
    except (ValueError, TypeError):
        # ValueError: text that is not a number; TypeError: e.g. None.
        return default
    else:
        return parsed
def _safe_int(val: str, default: int = 0) -> int:
    """Parse *val* as an integer, returning *default* when it cannot be parsed."""
    try:
        parsed = int(val)
    except (ValueError, TypeError):
        # ValueError: text that is not an integer literal; TypeError: e.g. None.
        return default
    else:
        return parsed
def load_csv(path: Path) -> RewardLog:
    """Read a reward CSV into a :class:`RewardLog`.

    The first row is taken as the header and decides the schema: at least 10
    columns with ``task_id`` as the second one means "v2"; anything else is
    treated as "v1".

    Skipped rows:
      * blank rows,
      * rows with fewer columns than the schema needs (10 for v2, 5 for v1),
      * rows that repeat the header (e.g. a resumed run appending to the same
        file), which would otherwise be loaded as a bogus all-zero episode 0.

    Numeric cells that cannot be parsed fall back to 0 via ``_safe_int`` /
    ``_safe_float``. The trailing timestamp column is not read.

    Args:
        path: CSV file to read.

    Returns:
        The parsed log. An empty file yields an empty log tagged "v1".
    """
    # Decode explicitly so parsing does not depend on the platform's default
    # encoding; undecodable bytes are replaced instead of aborting the read.
    with open(path, newline="", encoding="utf-8", errors="replace") as fh:
        reader = csv.reader(fh)
        header = next(reader, None)
        if header is None:
            return RewardLog(schema="v1")

        header_cells = [cell.strip() for cell in header]
        # v2 has "task_id" as the 2nd column.
        is_v2 = len(header_cells) >= 10 and header_cells[1] == "task_id"
        log = RewardLog(schema="v2" if is_v2 else "v1")

        for row in reader:
            cells = [cell.strip() for cell in row]
            if not any(cells):
                continue  # empty or whitespace-only row
            if cells == header_cells:
                continue  # header written again mid-file

            if is_v2 and len(row) >= 10:
                log.episodes.append(_safe_int(row[0]))
                log.task_ids.append(row[1])
                log.total.append(_safe_float(row[2]))
                log.investigation.append(_safe_float(row[3]))
                log.routing.append(_safe_float(row[4]))
                log.reply.append(_safe_float(row[5]))
                log.grounding.append(_safe_float(row[6]))
                log.submission.append(_safe_float(row[7]))
                log.penalty.append(_safe_float(row[8]))
            elif not is_v2 and len(row) >= 5:
                # v1: episode, total, routing, reply, grounding, timestamp.
                # Columns v1 does not have are filled with neutral values so
                # every list stays the same length.
                log.episodes.append(_safe_int(row[0]))
                log.task_ids.append("")
                log.total.append(_safe_float(row[1]))
                log.investigation.append(0.0)
                log.routing.append(_safe_float(row[2]))
                log.reply.append(_safe_float(row[3]))
                log.grounding.append(_safe_float(row[4]))
                log.submission.append(0.0)
                log.penalty.append(0.0)
    return log
# ----------------------------------------------------------------------
# Metrics
# ----------------------------------------------------------------------
def rolling_avg(values: List[float], window: int) -> List[float]:
    """Return the trailing moving average of *values*.

    Entry ``i`` is the mean of the last *window* items ending at ``i``; near
    the start, where fewer items exist, it is the mean of everything seen so
    far. *window* is clamped to the range ``1..len(values)``.
    """
    span = max(1, min(window, len(values)))
    averages: List[float] = []
    for end in range(1, len(values) + 1):
        recent = values[max(0, end - span):end]
        averages.append(sum(recent) / len(recent))
    return averages
def rolling_success_rate(totals: List[float], window: int, threshold: float = 0.5) -> List[float]:
    """Return the trailing fraction of episodes whose total reaches *threshold*.

    Entry ``i`` looks at the last *window* totals ending at ``i`` (fewer near
    the start) and gives the share that are ``>= threshold``. *window* is
    clamped to the range ``1..len(totals)``.
    """
    span = max(1, min(window, len(totals)))

    def pass_share(end: int) -> float:
        recent = totals[max(0, end - span):end]
        passed = [score for score in recent if score >= threshold]
        return len(passed) / len(recent)

    return [pass_share(end) for end in range(1, len(totals) + 1)]
def _prefix_maxima(values: List[float]) -> List[float]:
    """Return the running maximum: entry ``i`` is ``max(values[:i + 1])``.

    Computed in a single pass over *values*.
    """
    best_so_far = float("-inf")
    maxima: List[float] = []
    for value in values:
        # Only a strictly larger value replaces the current maximum.
        if value > best_so_far:
            best_so_far = value
        maxima.append(best_so_far)
    return maxima
# ----------------------------------------------------------------------
# Plot
# ----------------------------------------------------------------------
def _print_reward_summary(log: "RewardLog") -> None:
    """Print headline statistics for a non-empty reward log."""
    print(f"\nSchema : {log.schema}")
    print(f"Episodes : {log.n()}")
    print(f"Latest reward : {log.total[-1]:+.3f}")
    print(f"Avg (last 10) : {sum(log.total[-10:]) / min(10, log.n()):+.3f}")
    print(f"Best reward : {max(log.total):+.3f}")
    print(f"Pass rate (≥0.5): {sum(1 for v in log.total if v >= 0.5) / log.n():.1%}")
    # Only mention penalties when at least one episode has a non-zero value.
    if any(log.penalty):
        print(f"Avg penalty : {sum(log.penalty) / log.n():.3f}")


def plot(path: Path, save_path: Optional[Path] = None, window: int = 10) -> None:
    """Render the reward log at *path* as a four-panel figure and print a summary.

    Panels, top to bottom: total reward (raw and rolling mean), rolling mean
    of each reward component, penalty total (raw and rolling mean), and the
    rolling pass rate (share of episodes with total >= 0.5).

    Args:
        path: Reward CSV to read.
        save_path: Where to write the image. Defaults to *path* with a
            ``.png`` suffix. The suffix selects the output format; a path
            without a suffix is written as PNG. Missing parent directories
            are created.
        window: Rolling-window size, clamped to ``1..number of episodes``.

    If the log has no data rows, prints "No data yet." and writes nothing.
    """
    # Imported here rather than at module level so that table mode does not
    # need matplotlib. Agg is a non-interactive backend, so no display is
    # required.
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    log = load_csv(path)
    if log.n() == 0:
        print("No data yet.")
        return

    eps = log.episodes
    # Clamp at both ends so the legend labels show the window the rolling
    # helpers really use (they treat a non-positive window as 1).
    w = max(1, min(window, log.n()))

    fig, axes = plt.subplots(4, 1, figsize=(12, 14), sharex=True)
    try:
        ax_total, ax_comp, ax_pen, ax_succ = axes

        # --- Total ---
        ax_total.plot(eps, log.total, alpha=0.25, color="#1f77b4", label="per episode")
        ax_total.plot(eps, rolling_avg(log.total, w), color="#1f77b4", linewidth=2,
                      label=f"rolling({w})")
        ax_total.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
        ax_total.axhline(y=0.5, color="green", linestyle=":", alpha=0.5, label="pass threshold")
        ax_total.set_ylabel("Total reward")
        ax_total.set_title(f"DriftShield — GRPO training ({log.n()} episodes, schema={log.schema})")
        ax_total.legend(loc="lower right")
        ax_total.grid(True, alpha=0.3)

        # --- Components ---
        ax_comp.plot(eps, rolling_avg(log.investigation, w), linewidth=2, label="investigation", color="#d62728")
        ax_comp.plot(eps, rolling_avg(log.routing, w), linewidth=2, label="routing", color="#ff7f0e")
        ax_comp.plot(eps, rolling_avg(log.reply, w), linewidth=2, label="reply_quality", color="#2ca02c")
        ax_comp.plot(eps, rolling_avg(log.grounding, w), linewidth=2, label="groundedness", color="#9467bd")
        ax_comp.plot(eps, rolling_avg(log.submission, w), linewidth=2, label="submission", color="#17becf")
        ax_comp.set_ylabel("Component reward")
        ax_comp.legend(loc="lower right", ncol=3)
        ax_comp.grid(True, alpha=0.3)

        # --- Penalties (lower is better) ---
        ax_pen.plot(eps, log.penalty, alpha=0.25, color="#8c564b")
        ax_pen.plot(eps, rolling_avg(log.penalty, w), linewidth=2, color="#8c564b",
                    label=f"penalty total rolling({w})")
        ax_pen.set_ylabel("Penalty total (↓)")
        ax_pen.legend(loc="upper right")
        ax_pen.grid(True, alpha=0.3)

        # --- Rolling success rate ---
        ax_succ.plot(eps, rolling_success_rate(log.total, w, threshold=0.5),
                     linewidth=2, color="#2ca02c", label=f"rolling({w}) pass rate (total≥0.5)")
        ax_succ.set_xlabel("Episode")
        ax_succ.set_ylabel("Pass rate")
        ax_succ.set_ylim(-0.05, 1.05)
        ax_succ.legend(loc="lower right")
        ax_succ.grid(True, alpha=0.3)

        fig.tight_layout()

        out = save_path or path.with_suffix(".png")
        fmt = out.suffix.lower().lstrip(".") or "png"
        save_kwargs: Dict[str, object] = {"bbox_inches": "tight"}
        # dpi is only passed for raster formats.
        if fmt in ("png", "jpg", "jpeg", "tif", "tiff", "webp"):
            save_kwargs["dpi"] = 150
        # Saving into a directory that does not exist yet would fail.
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out, format=fmt, **save_kwargs)
    finally:
        # Release the figure even if drawing or saving failed; this matters
        # when plot() is called repeatedly, as in live mode.
        plt.close(fig)

    print(f"Plot saved to {out}")
    _print_reward_summary(log)
# ----------------------------------------------------------------------
# ASCII table
# ----------------------------------------------------------------------
def print_table(path: Path, tail: Optional[int] = None) -> None:
    """Print the reward log at *path* as a plain-text table with a summary.

    Each row shows the episode, task id (cut to 22 characters, "-" when
    empty), total and component rewards, penalty, and the mean total over the
    last up-to-10 episodes. A trailing ``*`` marks rows whose total equals the
    best total seen up to that row.

    Args:
        path: Reward CSV to read.
        tail: When positive and smaller than the row count, only the last
            *tail* rows are listed. The summary always covers the whole log.

    If the log has no data rows, prints "No data yet." and returns.
    """
    log = load_csv(path)
    row_count = log.n()
    if row_count == 0:
        print("No data yet.")
        return

    first_row = 0
    if tail is not None and 0 < tail < row_count:
        first_row = row_count - tail
        print(f"\n... showing last {tail} of {row_count} rows ...")

    header_line = (
        f"\n{'Ep':>5} | {'Task':<22} | {'Total':>7} | {'Inv':>6} | "
        f"{'Rout':>6} | {'Reply':>6} | {'Gnd':>6} | {'Sub':>6} | {'Pen':>6} | {'Avg10':>7}"
    )
    print(header_line)
    print("-" * 113)

    running_best = _prefix_maxima(log.total)
    for idx in range(first_row, row_count):
        # Mean over the current row and up to nine rows before it; computed
        # on the full log, so it is unaffected by --tail.
        recent = log.total[max(0, idx - 9):idx + 1]
        avg10 = sum(recent) / len(recent)

        marker = ""
        if log.total[idx] == running_best[idx]:
            marker = " *"

        tid = (log.task_ids[idx] or "-")[:22]
        print(
            f"{log.episodes[idx]:>5} | {tid:<22} | {log.total[idx]:>+7.2f} | "
            f"{log.investigation[idx]:>+6.2f} | {log.routing[idx]:>+6.2f} | "
            f"{log.reply[idx]:>+6.2f} | {log.grounding[idx]:>+6.2f} | "
            f"{log.submission[idx]:>+6.2f} | {log.penalty[idx]:>+6.2f} | "
            f"{avg10:>+7.2f}{marker}"
        )

    best_total = max(log.total)
    # index() returns the earliest episode that reached the best total.
    best_idx = log.total.index(best_total)
    print(f"\nBest : {best_total:+.3f} (ep {log.episodes[best_idx]}"
          f", task={log.task_ids[best_idx] or 'n/a'})")
    print(f"Avg : {sum(log.total) / row_count:+.3f}")
    passes = len([score for score in log.total if score >= 0.5])
    print(f"Pass : {passes}/{row_count} ({passes / row_count:.1%})")
# ----------------------------------------------------------------------
# CLI
# ----------------------------------------------------------------------
def main() -> None:
    """Parse the command line and run table, one-shot plot, or live plot mode.

    Exits with status 1 when no reward log can be found. ``--table`` takes
    precedence over ``--live``. Live mode re-plots every 30 seconds until
    interrupted with Ctrl+C.
    """
    parser = argparse.ArgumentParser(description="Plot DriftShield GRPO rewards")
    parser.add_argument("csv_path", nargs="?", help="Path to reward_log.csv")
    parser.add_argument("--live", action="store_true", help="Refresh every 30s")
    parser.add_argument("--table", action="store_true", help="Print ASCII table instead of plot")
    parser.add_argument("--tail", type=int, default=None, help="Show only last N rows in table mode")
    parser.add_argument("--window", type=int, default=10, help="Rolling-window size")
    parser.add_argument("--out", default=None, help="Output image path")
    args = parser.parse_args()

    # An explicit path wins; otherwise fall back to auto-discovery.
    path = find_latest_csv() if not args.csv_path else Path(args.csv_path)
    if path is None or not path.exists():
        print("No reward_log.csv found. Run training first or specify path.")
        sys.exit(1)
    print(f"Reading: {path}")

    if args.table:
        print_table(path, tail=args.tail)
        return

    out_path = Path(args.out) if args.out else None
    if not args.live:
        plot(path, out_path, window=args.window)
        return

    try:
        while True:
            plot(path, out_path, window=args.window)
            print(f"[{time.strftime('%H:%M:%S')}] sleeping 30s (Ctrl+C to stop)")
            time.sleep(30)
    except KeyboardInterrupt:
        # Ctrl+C is the documented way to leave live mode; exit quietly.
        pass
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()