| |
| """Compare key TriMode training metrics between baseline and new logs.""" |
| from __future__ import annotations |
|
|
| import argparse |
| import ast |
| import re |
| import sys |
| from collections import defaultdict |
| from pathlib import Path |
|
|
|
|
# One trainer log dict per line; rows are recognised by their leading 'loss' key.
_METRIC_ROW_RE = re.compile(r"\{'loss':[^\n]+\}")
# Bare non-finite float tokens in value position of a dict repr,
# e.g. the ``nan`` in "'loss': nan," or the ``-inf`` in "'kl': -inf}".
_NONFINITE_VALUE_RE = re.compile(r"(?<=: )-?(?:nan|inf)(?=[,}])")
# Prefix of the quoted placeholder that temporarily stands in for such a token.
_NONFINITE_TAG = "__nonfinite__:"


def _parse_metric_row(raw: str) -> dict | None:
    """Parse one dict-repr log row; return None when it cannot be parsed."""
    try:
        return ast.literal_eval(raw)
    except (SyntaxError, ValueError, TypeError):
        pass

    # ``ast.literal_eval`` rejects the bare names nan/inf.  Quote them as
    # tagged strings, parse again, then turn the tagged strings into floats.
    patched = _NONFINITE_VALUE_RE.sub(
        lambda m: repr(_NONFINITE_TAG + m.group()), raw
    )
    if patched == raw:
        return None  # The failure was not caused by a non-finite token.
    try:
        row = ast.literal_eval(patched)
    except (SyntaxError, ValueError, TypeError):
        return None
    return {
        key: float(value[len(_NONFINITE_TAG):])
        if isinstance(value, str) and value.startswith(_NONFINITE_TAG)
        else value
        for key, value in row.items()
    }


def parse_metrics(path: str) -> list[dict]:
    """Return every ``{'loss': ...}`` dict found in the log at *path*.

    Rows are returned in file order.  A row whose values include bare
    ``nan``/``inf``/``-inf`` is kept (with real non-finite floats) instead
    of being dropped, so positional indices into the result still line up
    with the logged rows.  Rows that cannot be parsed at all are skipped.

    Args:
        path: Log file to read (decoded as UTF-8, undecodable bytes replaced).

    Returns:
        The parsed metric dicts, one per matching log line.

    Raises:
        OSError: If the file cannot be read.
    """
    text = Path(path).read_text(encoding="utf-8", errors="replace")
    rows = []
    for match in _METRIC_ROW_RE.finditer(text):
        row = _parse_metric_row(match.group())
        if row is not None:
            rows.append(row)
    return rows
|
|
|
|
def alert_counts(path: str) -> dict[str, int]:
    """Count ``[global_step=N][ALERT] CODE`` lines in the log, per CODE.

    Args:
        path: Log file to read (decoded as UTF-8, undecodable bytes replaced).

    Returns:
        A plain dict mapping each alert code to its number of occurrences,
        in order of first appearance.  Empty when the log has no alerts.
    """
    log_text = Path(path).read_text(encoding="utf-8", errors="replace")
    tally: dict[str, int] = {}
    for hit in re.finditer(r"\[global_step=\d+\]\[ALERT\] (\w+)", log_text):
        alert_code = hit.group(1)
        tally[alert_code] = tally.get(alert_code, 0) + 1
    return tally
|
|
|
|
def opsd_mask_mean(path: str) -> float:
    """Average the per-probe ``true / (true + false)`` OPSD mask ratio.

    Each ``opsd_mask_true=<int> | opsd_mask_false=<int>`` occurrence in the
    log contributes one ratio; probes whose two counts sum to zero are
    ignored.

    Args:
        path: Log file to read (decoded as UTF-8, undecodable bytes replaced).

    Returns:
        The mean of the per-probe ratios, or 0.0 when no usable probe exists.
    """
    log_text = Path(path).read_text(encoding="utf-8", errors="replace")
    probe_pattern = r"opsd_mask_true=(\d+) \| opsd_mask_false=(\d+)"

    ratios = []
    for hit in re.finditer(probe_pattern, log_text):
        n_true = int(hit.group(1))
        n_total = n_true + int(hit.group(2))
        if n_total > 0:
            ratios.append(n_true / n_total)

    if not ratios:
        return 0.0
    return sum(ratios) / len(ratios)
|
|
|
|
def routing_field_mean(path: str, field: str) -> float:
    """Mean of a ``routing/<field>`` value over trainer log dicts (RLSD health metrics).

    The log rows come from ``parse_metrics`` so this function and the rest
    of the script always agree on which lines count as metric rows.

    Args:
        path: Log file to read.
        field: Name of the metric without its ``routing/`` prefix.

    Returns:
        The arithmetic mean of the field over the rows that contain it, or
        0.0 when no row has it.

    Raises:
        OSError: If the file cannot be read.
    """
    key = f"routing/{field}"
    values = [float(row[key]) for row in parse_metrics(path) if key in row]
    return sum(values) / len(values) if values else 0.0
|
|
|
|
def metric_at(metrics: list[dict], idx: int, key: str, default=0.0):
    """Return ``metrics[idx][key]``, or *default* when either lookup misses.

    Args:
        metrics: Parsed metric rows.
        idx: Row index.  Negative indices count from the end, as with any
            list; an index outside the list in either direction is a miss.
        key: Metric name to read from the selected row.
        default: Value returned when the row or the key does not exist.

    Returns:
        The stored metric value, or *default*.
    """
    # Bound the index on both sides: -1 on an empty list (or any index below
    # -len(metrics)) would otherwise raise IndexError instead of falling back.
    if not -len(metrics) <= idx < len(metrics):
        return default
    return metrics[idx].get(key, default)
|
|
|
|
def summarize(label: str, path: str) -> dict:
    """Collect the headline metrics of one training log into a flat dict.

    Args:
        label: Column label for this log (stored under ``"label"``).
        path: Log file to analyse.

    Returns:
        A dict with ``label``, ``path``, ``steps`` (number of parsed metric
        rows), the clipped ratio of row index 1 and its complement
        (``step1_clip`` / ``step1_eos``), the three alert counts, the OPSD
        mask mean, the three ``routing/*`` means, and the format / length /
        accuracy metrics of the "late" row (index 200, or the last row of a
        shorter log).  Metrics that are missing fall back to 0.0.

    Raises:
        OSError: If the file cannot be read.
    """
    metrics = parse_metrics(path)
    alerts = alert_counts(path)

    step1_clip = metric_at(metrics, 1, "completions/clipped_ratio")
    # Row 200, or the last row of a shorter log.  Clamp at 0: a log with no
    # metric rows would otherwise ask for index -1, which is not a valid row
    # of an empty list; index 0 simply yields the out-of-range default.
    late_idx = max(0, min(200, len(metrics) - 1))

    return {
        "label": label,
        "path": path,
        "steps": len(metrics),
        "step1_clip": step1_clip,
        "step1_eos": 1.0 - step1_clip,
        "logit_collapse": alerts.get("LOGIT_MODE_COLLAPSE", 0),
        "gen_clip_collapse": alerts.get("GEN_CLIP_COLLAPSE", 0),
        "rl_zero": alerts.get("RL_ZERO_SIGNAL", 0),
        "opsd_mask_mean": opsd_mask_mean(path),
        "opsd_on_correct_rate": routing_field_mean(path, "opsd_on_correct_rate"),
        "privileged_suffix_has_gold_rate": routing_field_mean(
            path, "privileged_suffix_has_gold_rate"
        ),
        "leakage_pattern_rate": routing_field_mean(path, "leakage_pattern_rate"),
        "late_format": metric_at(metrics, late_idx, "rewards/format/mean"),
        "late_mean_len": metric_at(metrics, late_idx, "completions/mean_length"),
        "late_acc": metric_at(metrics, late_idx, "rewards/accuracy/mean"),
    }
|
|
|
|
# Rows of the single-log table: (printed label, summary key, format spec).
_SINGLE_LOG_ROWS = (
    ("steps", "steps", ""),
    ("step1 clip", "step1_clip", ".3f"),
    ("LOGIT_MODE_COLLAPSE", "logit_collapse", ""),
    ("GEN_CLIP_COLLAPSE", "gen_clip_collapse", ""),
    ("opsd_mask mean", "opsd_mask_mean", ".3f"),
    ("opsd_on_correct_rate", "opsd_on_correct_rate", ".4f"),
    ("privileged_suffix_has_gold_rate", "privileged_suffix_has_gold_rate", ".4f"),
    ("leakage_pattern_rate", "leakage_pattern_rate", ".4f"),
    ("step~200 format", "late_format", ".3f"),
    ("step~200 acc", "late_acc", ".3f"),
    ("step~200 mean_len", "late_mean_len", ".1f"),
)

# Summary keys shown in the baseline-vs-candidate table, in print order.
_DELTA_KEYS = (
    "steps",
    "step1_clip",
    "logit_collapse",
    "gen_clip_collapse",
    "opsd_mask_mean",
    "opsd_on_correct_rate",
    "privileged_suffix_has_gold_rate",
    "leakage_pattern_rate",
    "late_format",
    "late_acc",
    "late_mean_len",
)


def _print_single_log_table(summary: dict) -> None:
    """Print the Markdown title and the one-column table for *summary*."""
    print("# TriMode log comparison\n")
    print(f"| metric | {summary['label']} |")
    print("|--------|----------|")
    for label, key, spec in _SINGLE_LOG_ROWS:
        print(f"| {label} | {summary[key]:{spec}} |")


def _print_delta_table(base: dict, cand: dict) -> None:
    """Print the baseline / candidate / delta Markdown table."""
    print(f"\n| metric | {base['label']} | {cand['label']} | delta |")
    print("|--------|----------|-----------|-------|")
    for key in _DELTA_KEYS:
        b, c = base[key], cand[key]
        delta = c - b
        # Use float formatting when EITHER side is a float: a logged metric
        # may be an int in one log and a float in the other, and the integer
        # presentation type ``d`` raises ValueError for a float delta.
        if isinstance(b, float) or isinstance(c, float):
            print(f"| {key} | {b:.3f} | {c:.3f} | {delta:+.3f} |")
        else:
            print(f"| {key} | {b} | {c} | {delta:+d} |")


def _print_success_checks(base: dict, cand: dict) -> None:
    """Print the antidegen checklist as Markdown task-list items."""
    print("\n## Antidegen success checks (candidate vs baseline)")
    checks = [
        ("step1 clip < 1.0", cand["step1_clip"] < 1.0),
        # NOTE: with a baseline count of 0 the threshold is 0, so this strict
        # comparison cannot pass even when the candidate also has 0 alerts.
        ("LOGIT_MODE_COLLAPSE down >30%", cand["logit_collapse"] < base["logit_collapse"] * 0.7),
        ("opsd_mask mean > 8%", cand["opsd_mask_mean"] > 0.08),
        ("opsd_mask improved", cand["opsd_mask_mean"] > base["opsd_mask_mean"]),
        ("opsd_on_correct_rate == 0 (RLSD)", cand["opsd_on_correct_rate"] < 0.01),
        ("no privileged gold suffix (RLSD)", cand["privileged_suffix_has_gold_rate"] < 0.01),
        ("no leakage patterns", cand["leakage_pattern_rate"] < 0.01),
    ]
    for name, ok in checks:
        print(f"- [{'x' if ok else ' '}] {name}")


def main() -> int:
    """Command-line entry point: summarise one log, or compare two.

    With only the baseline argument, prints a single-column Markdown table
    and a tip.  With a candidate log as well, additionally prints a delta
    table and the antidegen success checklist.

    Returns:
        Process exit status (always 0 when the logs could be read).
    """
    parser = argparse.ArgumentParser(description="Compare two TriMode training logs")
    parser.add_argument("baseline", help="Baseline log (e.g. pre-antidegen run)")
    parser.add_argument("candidate", nargs="?", help="New log to compare (optional)")
    args = parser.parse_args()

    base = summarize("baseline", args.baseline)
    _print_single_log_table(base)

    if not args.candidate:
        print("\n(Tip: pass a second log path to print delta columns.)")
        return 0

    cand = summarize("candidate", args.candidate)
    _print_delta_table(base, cand)
    _print_success_checks(base, cand)
    return 0
|
|
|
|
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
|