agentic-rl-main / scripts /compare_trimode_logs.py
Jack04810's picture
Upload folder using huggingface_hub
7451c3f verified
Raw
History Blame Contribute Delete
5.98 kB
#!/usr/bin/env python3
"""Compare key TriMode training metrics between baseline and new logs."""
from __future__ import annotations
import argparse
import ast
import re
import sys
from collections import defaultdict
from pathlib import Path
def parse_metrics(path: str) -> list[dict]:
    """Extract every trainer metrics dict printed in the log at *path*.

    A metrics row is any span that starts with ``{'loss':`` and runs to the
    last ``}`` on the same line. Each span is parsed with
    ``ast.literal_eval``; spans it rejects (``SyntaxError`` / ``ValueError``)
    are skipped silently, so the result can hold fewer rows than the log
    printed.

    Returns:
        The parsed rows, in the order they appear in the file.
    """
    log_text = Path(path).read_text(encoding="utf-8", errors="replace")
    rows = []
    for match in re.finditer(r"\{'loss':[^\n]+\}", log_text):
        snippet = match.group()
        try:
            parsed = ast.literal_eval(snippet)
        except (SyntaxError, ValueError):
            # Not a clean Python literal; drop the row and keep scanning.
            continue
        rows.append(parsed)
    return rows
def alert_counts(path: str) -> dict[str, int]:
    """Tally the ``[ALERT]`` codes found in the log at *path*.

    Only alerts written as ``[global_step=<digits>][ALERT] <CODE>`` are
    counted, where ``<CODE>`` is the run of word characters that follows.

    Returns:
        A plain dict mapping each alert code to its number of occurrences,
        keyed in order of first appearance.
    """
    log_text = Path(path).read_text(encoding="utf-8", errors="replace")
    alert_pattern = re.compile(r"\[global_step=\d+\]\[ALERT\] (\w+)")
    tally: dict[str, int] = {}
    for hit in alert_pattern.finditer(log_text):
        code = hit.group(1)
        tally[code] = tally.get(code, 0) + 1
    return tally
def opsd_mask_mean(path: str) -> float:
    """Average the true-fraction of every ``opsd_mask`` probe in a log.

    Each probe looks like ``opsd_mask_true=<T> | opsd_mask_false=<F>`` and
    contributes ``T / (T + F)``. Probes where both counts are zero are
    ignored so they cannot divide by zero.

    Returns:
        The unweighted mean of the per-probe fractions, or 0.0 when the log
        has no usable probe.
    """
    log_text = Path(path).read_text(encoding="utf-8", errors="replace")
    fractions = []
    for hit in re.finditer(r"opsd_mask_true=(\d+) \| opsd_mask_false=(\d+)", log_text):
        n_true = int(hit.group(1))
        n_false = int(hit.group(2))
        total = n_true + n_false
        if total > 0:
            fractions.append(n_true / total)
    if not fractions:
        return 0.0
    return sum(fractions) / len(fractions)
def routing_field_mean(path: str, field: str) -> float:
    """Average one ``routing/<field>`` metric over the trainer rows of a log.

    Rows are the ``{'loss': ...}`` dict literals printed in the log (the
    RLSD health metrics live under the ``routing/`` prefix). Rows that fail
    to parse, or that do not carry the requested field, are skipped.

    Args:
        path: Log file to scan.
        field: Metric name without the ``routing/`` prefix.

    Returns:
        The mean of the field across the rows that have it, or 0.0 when no
        row does.
    """
    log_text = Path(path).read_text(encoding="utf-8", errors="replace")
    wanted = f"routing/{field}"
    samples = []
    for match in re.finditer(r"\{'loss':[^\n]+\}", log_text):
        try:
            record = ast.literal_eval(match.group())
        except (SyntaxError, ValueError):
            continue
        if wanted not in record:
            continue
        samples.append(float(record[wanted]))
    if samples:
        return sum(samples) / len(samples)
    return 0.0
def metric_at(metrics: list[dict], idx: int, key: str, default=0.0):
    """Return ``metrics[idx][key]``, or *default* when it is unavailable.

    Args:
        metrics: Parsed metric rows, in log order.
        idx: Row position. Negative values index from the end, as with
            ordinary list indexing.
        key: Metric name to read from the selected row.
        default: Value returned when *idx* is out of range (in either
            direction) or the selected row has no *key*.

    Returns:
        The stored metric value, or *default*.
    """
    # Check both bounds: callers compute idx as ``min(200, len(metrics) - 1)``,
    # which is -1 for an empty log and would raise IndexError if only the
    # upper bound were tested.
    if not -len(metrics) <= idx < len(metrics):
        return default
    return metrics[idx].get(key, default)
def summarize(label: str, path: str) -> dict:
    """Collect the headline statistics of one training log into a dict.

    Args:
        label: Display name for this run (used as a table column header).
        path: Log file to summarize.

    Returns:
        A dict with the label and path, the number of parsed metric rows,
        the clipped ratio read from row index 1 (the second parsed row) and
        its ``1 - x`` complement, alert counts, the opsd mask mean, three
        ``routing/*`` means, and three "late" metrics read from row 200 or
        from the last row when the log is shorter.
    """
    rows = parse_metrics(path)
    alert_tally = alert_counts(path)

    # Row used for the "late" snapshot: index 200, capped at the final row.
    late_idx = min(200, len(rows) - 1)
    step1_clip = metric_at(rows, 1, "completions/clipped_ratio")

    summary = {
        "label": label,
        "path": path,
        "steps": len(rows),
        "step1_clip": step1_clip,
        # Proxy for the EOS rate when the log has no dedicated eos metric.
        "step1_eos": 1.0 - step1_clip,
        "logit_collapse": alert_tally.get("LOGIT_MODE_COLLAPSE", 0),
        "gen_clip_collapse": alert_tally.get("GEN_CLIP_COLLAPSE", 0),
        "rl_zero": alert_tally.get("RL_ZERO_SIGNAL", 0),
        "opsd_mask_mean": opsd_mask_mean(path),
    }

    for routing_field in (
        "opsd_on_correct_rate",
        "privileged_suffix_has_gold_rate",
        "leakage_pattern_rate",
    ):
        summary[routing_field] = routing_field_mean(path, routing_field)

    late_metrics = (
        ("late_format", "rewards/format/mean"),
        ("late_mean_len", "completions/mean_length"),
        ("late_acc", "rewards/accuracy/mean"),
    )
    for out_key, metric_key in late_metrics:
        summary[out_key] = metric_at(rows, late_idx, metric_key)

    return summary
def main() -> int:
    """Print a Markdown comparison of one or two TriMode training logs.

    Command line: a required ``baseline`` log path and an optional
    ``candidate`` log path. The baseline summary table is always printed.
    When a candidate is given, a side-by-side table with a delta column and
    the antidegen success checklist follow.

    Returns:
        0 on every normal path; meant to be handed to ``sys.exit``.
    """
    parser = argparse.ArgumentParser(description="Compare two TriMode training logs")
    parser.add_argument("baseline", help="Baseline log (e.g. pre-antidegen run)")
    parser.add_argument("candidate", nargs="?", help="New log to compare (optional)")
    args = parser.parse_args()

    base = summarize("baseline", args.baseline)
    print("# TriMode log comparison\n")
    print(f"| metric | {base['label']} |")
    print("|--------|----------|")
    print(f"| steps | {base['steps']} |")
    print(f"| step1 clip | {base['step1_clip']:.3f} |")
    print(f"| LOGIT_MODE_COLLAPSE | {base['logit_collapse']} |")
    print(f"| GEN_CLIP_COLLAPSE | {base['gen_clip_collapse']} |")
    print(f"| opsd_mask mean | {base['opsd_mask_mean']:.3f} |")
    print(f"| opsd_on_correct_rate | {base['opsd_on_correct_rate']:.4f} |")
    print(f"| privileged_suffix_has_gold_rate | {base['privileged_suffix_has_gold_rate']:.4f} |")
    print(f"| leakage_pattern_rate | {base['leakage_pattern_rate']:.4f} |")
    print(f"| step~200 format | {base['late_format']:.3f} |")
    print(f"| step~200 acc | {base['late_acc']:.3f} |")
    print(f"| step~200 mean_len | {base['late_mean_len']:.1f} |")

    if not args.candidate:
        print("\n(Tip: pass a second log path to print delta columns.)")
        return 0

    cand = summarize("candidate", args.candidate)
    print(f"\n| metric | {base['label']} | {cand['label']} | delta |")
    print("|--------|----------|-----------|-------|")
    for key in (
        "steps",
        "step1_clip",
        "logit_collapse",
        "gen_clip_collapse",
        "opsd_mask_mean",
        "opsd_on_correct_rate",
        "privileged_suffix_has_gold_rate",
        "leakage_pattern_rate",
        "late_format",
        "late_acc",
        "late_mean_len",
    ):
        b, c = base[key], cand[key]
        delta = c - b
        # Choose float formatting when EITHER side is a float. Metric values
        # come from parsed log literals, so one log may hold an int where the
        # other holds a float, and ``:+d`` raises ValueError on a float delta.
        if isinstance(b, float) or isinstance(c, float):
            print(f"| {key} | {b:.3f} | {c:.3f} | {delta:+.3f} |")
        else:
            print(f"| {key} | {b} | {c} | {delta:+d} |")

    # Success criteria from antidegen plan
    print("\n## Antidegen success checks (candidate vs baseline)")
    checks = [
        ("step1 clip < 1.0", cand["step1_clip"] < 1.0),
        # NOTE(review): with a baseline count of 0 this strict comparison can
        # never pass, even when the candidate is also 0 - confirm intent.
        ("LOGIT_MODE_COLLAPSE down >30%", cand["logit_collapse"] < base["logit_collapse"] * 0.7),
        ("opsd_mask mean > 8%", cand["opsd_mask_mean"] > 0.08),
        ("opsd_mask improved", cand["opsd_mask_mean"] > base["opsd_mask_mean"]),
        # The three rate checks below treat anything under 1% as zero.
        ("opsd_on_correct_rate == 0 (RLSD)", cand["opsd_on_correct_rate"] < 0.01),
        ("no privileged gold suffix (RLSD)", cand["privileged_suffix_has_gold_rate"] < 0.01),
        ("no leakage patterns", cand["leakage_pattern_rate"] < 0.01),
    ]
    for name, ok in checks:
        print(f"- [{'x' if ok else ' '}] {name}")
    return 0
# Script entry point: run the comparison and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)