Spaces:
Paused
Paused
raj921
fix: T4-safe 4bit compute, eval/audit resilience, async inference, plot_rewards CSV/table
54c5725 | #!/usr/bin/env python3 | |
| """ | |
| Plot DriftShield training rewards from a CSV log. | |
| Usage (while training is running or after):: | |
| python plot_rewards.py # auto-find latest reward_log.csv | |
| python plot_rewards.py outputs/*/reward_log.csv # specific file | |
| python plot_rewards.py --live # refresh every 30s | |
| python plot_rewards.py --table # ASCII table, no matplotlib | |
| python plot_rewards.py --table --tail 30 # last 30 rows only | |
| Supports **two CSV schemas** (auto-detected from the header row): | |
| * **Legacy (v1)** — 6 columns:: | |
| episode, total_reward, routing_reward, reply_reward, grounding_reward, timestamp | |
| (Legacy logs named the second metric ``field_reward``; it is stored as ``routing`` | |
| here to match v2's routing column.) | |
| * **Extended (v2, current)** — 10 columns:: | |
| episode, task_id, | |
| total_reward, | |
| investigation, routing, reply_quality, groundedness, submission, | |
| penalty_total, timestamp | |
| The extended schema unlocks four stacked panels (total / components / | |
| penalty / rolling success-rate), which is what the guide asks for in §15: | |
| don't just watch the overall reward — watch each column. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import sys | |
| import time | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Dict, List, Optional | |
| # ---------------------------------------------------------------------- | |
| # CSV loading | |
| # ---------------------------------------------------------------------- | |
@dataclass
class RewardLog:
    """Column-oriented, in-memory view of one reward CSV.

    All list attributes are parallel: index ``i`` in each of them belongs to
    the same CSV row. ``load_csv`` fills ``task_ids`` with ``""`` and
    ``investigation`` / ``submission`` / ``penalty`` with ``0.0`` when the
    file uses the legacy (v1) schema, so every list always has ``n()`` items.

    The ``@dataclass`` decorator is required: the ``field(default_factory=...)``
    defaults below are only turned into fresh per-instance lists (and the
    ``RewardLog(schema=...)`` constructor is only generated) by ``dataclass``.
    """

    schema: str  # "v1" or "v2"
    episodes: List[int] = field(default_factory=list)
    task_ids: List[str] = field(default_factory=list)
    total: List[float] = field(default_factory=list)
    investigation: List[float] = field(default_factory=list)
    routing: List[float] = field(default_factory=list)
    reply: List[float] = field(default_factory=list)
    grounding: List[float] = field(default_factory=list)
    submission: List[float] = field(default_factory=list)
    penalty: List[float] = field(default_factory=list)

    def n(self) -> int:
        """Return the number of rows (episodes) currently stored."""
        return len(self.episodes)
def find_latest_csv() -> Optional[Path]:
    """Locate the reward log to read when no path was given on the CLI.

    Returns the most recently modified ``outputs/*/reward_log.csv``. If there
    is none, falls back to ``reward_log.csv`` in the current directory, and
    returns ``None`` when that does not exist either.
    """
    candidates = list(Path("outputs").glob("*/reward_log.csv"))
    if candidates:
        # Newest modification time wins; on ties the first match is kept.
        return max(candidates, key=lambda candidate: candidate.stat().st_mtime)

    fallback = Path("reward_log.csv")
    return fallback if fallback.exists() else None
| def _safe_float(val: str, default: float = 0.0) -> float: | |
| try: | |
| return float(val) | |
| except (ValueError, TypeError): | |
| return default | |
| def _safe_int(val: str, default: int = 0) -> int: | |
| try: | |
| return int(val) | |
| except (ValueError, TypeError): | |
| return default | |
def load_csv(path: Path) -> RewardLog:
    """Read a reward CSV into a :class:`RewardLog`, auto-detecting its schema.

    The first row is treated as the header. The file is considered v2 when the
    header has at least 10 columns and its second column is ``task_id``;
    anything else is read as legacy v1.

    Row handling:

    * blank rows are skipped;
    * rows that repeat the header (e.g. a header written again mid-file) are
      skipped instead of being parsed as an all-zero episode;
    * rows too short for the detected schema (fewer than 10 cells for v2,
      fewer than 5 for v1) are skipped;
    * unparseable numeric cells become ``0`` / ``0.0``.

    For v1 rows the columns that do not exist in that schema are filled with
    ``""`` (task id) or ``0.0`` (investigation, submission, penalty) so that
    all lists in the result stay the same length.

    The file is decoded as UTF-8 with undecodable bytes replaced, so the result
    does not depend on the platform's default encoding and a truncated
    multibyte character at the end of a file that is still being written does
    not raise.

    Returns an empty ``RewardLog(schema="v1")`` when the file has no header.
    """
    with open(path, newline="", encoding="utf-8", errors="replace") as fh:
        reader = csv.reader(fh)
        header = next(reader, None)
        if header is None:
            return RewardLog(schema="v1")

        header_cells = [cell.strip() for cell in header]
        # v2 has "task_id" as the 2nd column.
        is_v2 = len(header_cells) >= 10 and header_cells[1] == "task_id"
        log = RewardLog(schema="v2" if is_v2 else "v1")

        for row in reader:
            cells = [cell.strip() for cell in row]
            if not any(cells):
                # Empty line or a row made only of blank cells.
                continue
            if cells == header_cells:
                # Duplicate header line; not an episode.
                continue

            if is_v2 and len(row) >= 10:
                # v2: episode, task_id, total, investigation, routing,
                #     reply_quality, groundedness, submission, penalty, timestamp
                log.episodes.append(_safe_int(row[0]))
                log.task_ids.append(row[1])
                log.total.append(_safe_float(row[2]))
                log.investigation.append(_safe_float(row[3]))
                log.routing.append(_safe_float(row[4]))
                log.reply.append(_safe_float(row[5]))
                log.grounding.append(_safe_float(row[6]))
                log.submission.append(_safe_float(row[7]))
                log.penalty.append(_safe_float(row[8]))
            elif not is_v2 and len(row) >= 5:
                # v1: episode, total, routing, reply, grounding, timestamp
                log.episodes.append(_safe_int(row[0]))
                log.task_ids.append("")
                log.total.append(_safe_float(row[1]))
                log.investigation.append(0.0)
                log.routing.append(_safe_float(row[2]))
                log.reply.append(_safe_float(row[3]))
                log.grounding.append(_safe_float(row[4]))
                log.submission.append(0.0)
                log.penalty.append(0.0)
    return log
| # ---------------------------------------------------------------------- | |
| # Metrics | |
| # ---------------------------------------------------------------------- | |
def rolling_avg(values: List[float], window: int) -> List[float]:
    """Return the trailing moving average of *values*.

    The window is clamped to ``[1, len(values)]``. Element ``i`` of the result
    averages the last ``window`` values ending at ``i``; the first few elements
    average over however many values are available so far.
    """
    span = max(1, min(window, len(values)))
    averages: List[float] = []
    for end in range(1, len(values) + 1):
        chunk = values[max(0, end - span):end]
        averages.append(sum(chunk) / len(chunk))
    return averages
def rolling_success_rate(totals: List[float], window: int, threshold: float = 0.5) -> List[float]:
    """Return the trailing fraction of *totals* that reach *threshold*.

    The window is clamped to ``[1, len(totals)]``. Element ``i`` of the result
    is the share of values ``>= threshold`` among the last ``window`` values
    ending at ``i`` (fewer at the start, where less history exists).
    """
    span = max(1, min(window, len(totals)))

    def pass_fraction(recent: List[float]) -> float:
        passed = [score for score in recent if score >= threshold]
        return len(passed) / len(recent)

    return [
        pass_fraction(totals[max(0, end - span):end])
        for end in range(1, len(totals) + 1)
    ]
| def _prefix_maxima(values: List[float]) -> List[float]: | |
| """M[i] = max(values[0 : i+1]) — O(n).""" | |
| out: List[float] = [] | |
| m = float("-inf") | |
| for v in values: | |
| m = max(m, v) | |
| out.append(m) | |
| return out | |
| # ---------------------------------------------------------------------- | |
| # Plot | |
| # ---------------------------------------------------------------------- | |
def plot(path: Path, save_path: Optional[Path] = None, window: int = 10) -> None:
    """Render the reward log at *path* as a four-panel figure and print a summary.

    Panels, top to bottom: total reward (raw + rolling mean), rolling mean of
    each reward component, penalty total (raw + rolling mean), and the rolling
    pass rate (share of episodes with total >= 0.5).

    Args:
        path: CSV file to read via ``load_csv``.
        save_path: Where to write the image. Defaults to *path* with its
            suffix replaced by ``.png``. The output format is taken from the
            suffix (``png`` when there is none).
        window: Rolling-window size. It is clamped to ``[1, number of rows]``
            so the legends always show the window that is actually applied.

    Prints ``"No data yet."`` and returns without writing anything when the
    log contains no rows.
    """
    import matplotlib
    # Select the non-interactive Agg backend before pyplot is imported so the
    # script can run without a display.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    log = load_csv(path)
    if log.n() == 0:
        print("No data yet.")
        return

    eps = log.episodes
    # The rolling helpers clamp non-positive windows to 1; clamp here as well
    # so labels such as "rolling(0)" can never disagree with the plotted data.
    w = max(1, min(window, log.n()))

    fig, axes = plt.subplots(4, 1, figsize=(12, 14), sharex=True)
    try:
        ax_total, ax_comp, ax_pen, ax_succ = axes

        # --- Total ---
        ax_total.plot(eps, log.total, alpha=0.25, color="#1f77b4", label="per episode")
        ax_total.plot(eps, rolling_avg(log.total, w), color="#1f77b4", linewidth=2,
                      label=f"rolling({w})")
        ax_total.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
        ax_total.axhline(y=0.5, color="green", linestyle=":", alpha=0.5, label="pass threshold")
        ax_total.set_ylabel("Total reward")
        ax_total.set_title(f"DriftShield — GRPO training ({log.n()} episodes, schema={log.schema})")
        ax_total.legend(loc="lower right")
        ax_total.grid(True, alpha=0.3)

        # --- Components ---
        components = (
            (log.investigation, "investigation", "#d62728"),
            (log.routing, "routing", "#ff7f0e"),
            (log.reply, "reply_quality", "#2ca02c"),
            (log.grounding, "groundedness", "#9467bd"),
            (log.submission, "submission", "#17becf"),
        )
        for series, label, color in components:
            ax_comp.plot(eps, rolling_avg(series, w), linewidth=2, label=label, color=color)
        ax_comp.set_ylabel("Component reward")
        ax_comp.legend(loc="lower right", ncol=3)
        ax_comp.grid(True, alpha=0.3)

        # --- Penalties (lower is better) ---
        ax_pen.plot(eps, log.penalty, alpha=0.25, color="#8c564b")
        ax_pen.plot(eps, rolling_avg(log.penalty, w), linewidth=2, color="#8c564b",
                    label=f"penalty total rolling({w})")
        ax_pen.set_ylabel("Penalty total (↓)")
        ax_pen.legend(loc="upper right")
        ax_pen.grid(True, alpha=0.3)

        # --- Rolling success rate ---
        ax_succ.plot(eps, rolling_success_rate(log.total, w, threshold=0.5),
                     linewidth=2, color="#2ca02c", label=f"rolling({w}) pass rate (total≥0.5)")
        ax_succ.set_xlabel("Episode")
        ax_succ.set_ylabel("Pass rate")
        ax_succ.set_ylim(-0.05, 1.05)
        ax_succ.legend(loc="lower right")
        ax_succ.grid(True, alpha=0.3)

        plt.tight_layout()
        out = save_path or path.with_suffix(".png")
        fmt = out.suffix.lower().lstrip(".") or "png"
        save_kwargs: Dict[str, object] = {"bbox_inches": "tight"}
        # A fixed dpi is only requested for raster image formats.
        if fmt in ("png", "jpg", "jpeg", "tif", "tiff", "webp"):
            save_kwargs["dpi"] = 150
        plt.savefig(out, format=fmt, **save_kwargs)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    print(f"Plot saved to {out}")

    print(f"\nSchema : {log.schema}")
    print(f"Episodes : {log.n()}")
    print(f"Latest reward : {log.total[-1]:+.3f}")
    print(f"Avg (last 10) : {sum(log.total[-10:]) / min(10, log.n()):+.3f}")
    print(f"Best reward : {max(log.total):+.3f}")
    print(f"Pass rate (≥0.5): {sum(1 for v in log.total if v >= 0.5) / log.n():.1%}")
    if any(log.penalty):
        print(f"Avg penalty : {sum(log.penalty) / log.n():.3f}")
| # ---------------------------------------------------------------------- | |
| # ASCII table | |
| # ---------------------------------------------------------------------- | |
def print_table(path: Path, tail: Optional[int] = None) -> None:
    """Print the reward log at *path* as an ASCII table followed by a summary.

    Args:
        path: CSV file to read via ``load_csv``.
        tail: When a positive number smaller than the row count, only the last
            *tail* rows are listed. The ``Avg10`` column, the best-so-far
            marker and the summary are still computed over the whole log.

    A row is suffixed with ``" *"`` when its total equals the best total seen
    up to and including that row. Prints ``"No data yet."`` for an empty log.
    """
    log = load_csv(path)
    row_count = log.n()
    if row_count == 0:
        print("No data yet.")
        return

    truncated = tail is not None and tail > 0 and row_count > tail
    first_row = row_count - tail if truncated else 0
    if truncated:
        print(f"\n... showing last {tail} of {row_count} rows ...")

    header = (
        f"\n{'Ep':>5} | {'Task':<22} | {'Total':>7} | {'Inv':>6} | "
        f"{'Rout':>6} | {'Reply':>6} | {'Gnd':>6} | {'Sub':>6} | {'Pen':>6} | {'Avg10':>7}"
    )
    print(header)
    print("-" * 113)

    totals = log.total
    best_so_far = _prefix_maxima(totals)
    for idx in range(first_row, row_count):
        # Mean of up to the ten most recent totals, ending at this row.
        recent = totals[max(0, idx - 9):idx + 1]
        avg10 = sum(recent) / min(idx + 1, 10)
        if totals[idx] == best_so_far[idx]:
            marker = " *"
        else:
            marker = ""
        tid = (log.task_ids[idx] or "-")[:22]
        print(
            f"{log.episodes[idx]:>5} | {tid:<22} | {totals[idx]:>+7.2f} | "
            f"{log.investigation[idx]:>+6.2f} | {log.routing[idx]:>+6.2f} | "
            f"{log.reply[idx]:>+6.2f} | {log.grounding[idx]:>+6.2f} | "
            f"{log.submission[idx]:>+6.2f} | {log.penalty[idx]:>+6.2f} | "
            f"{avg10:>+7.2f}{marker}"
        )

    best_total = max(totals)
    best_idx = totals.index(best_total)  # first row that reached the best total
    print(f"\nBest : {best_total:+.3f} (ep {log.episodes[best_idx]}"
          f", task={log.task_ids[best_idx] or 'n/a'})")
    print(f"Avg : {sum(totals) / row_count:+.3f}")
    passes = len([score for score in totals if score >= 0.5])
    print(f"Pass : {passes}/{row_count} ({passes / row_count:.1%})")
| # ---------------------------------------------------------------------- | |
| # CLI | |
| # ---------------------------------------------------------------------- | |
def main() -> None:
    """Command-line entry point: parse arguments and plot or tabulate a log.

    Uses the given CSV path, or the one found by ``find_latest_csv`` when no
    path is given. Exits with status 1 if no existing file can be resolved.
    ``--table`` takes precedence over ``--live``; in live mode the plot is
    regenerated every 30 seconds until interrupted with Ctrl+C.
    """
    parser = argparse.ArgumentParser(description="Plot DriftShield GRPO rewards")
    parser.add_argument("csv_path", nargs="?", help="Path to reward_log.csv")
    parser.add_argument("--live", action="store_true", help="Refresh every 30s")
    parser.add_argument("--table", action="store_true", help="Print ASCII table instead of plot")
    parser.add_argument("--tail", type=int, default=None, help="Show only last N rows in table mode")
    parser.add_argument("--window", type=int, default=10, help="Rolling-window size")
    parser.add_argument("--out", default=None, help="Output image path")
    args = parser.parse_args()

    if args.csv_path:
        path = Path(args.csv_path)
    else:
        path = find_latest_csv()
    if path is None or not path.exists():
        print("No reward_log.csv found. Run training first or specify path.")
        sys.exit(1)
    print(f"Reading: {path}")

    if args.table:
        print_table(path, tail=args.tail)
        return

    image_path = Path(args.out) if args.out else None
    if not args.live:
        plot(path, image_path, window=args.window)
        return

    # Live mode: redraw until the user interrupts, then exit quietly.
    try:
        while True:
            plot(path, image_path, window=args.window)
            print(f"[{time.strftime('%H:%M:%S')}] sleeping 30s (Ctrl+C to stop)")
            time.sleep(30)
    except KeyboardInterrupt:
        pass
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()