# agentic-rl-main/scripts/analyze_trimode_log.py
# Commit 7451c3f (verified) by Jack04810: "Upload folder using huggingface_hub" - 12.2 kB
"""Analyze TriMode training log for routing, rewards, generation quality."""
import ast
import json
import re
import sys
from collections import defaultdict
# Path of the log file to analyze: the first command-line argument when one is
# given, otherwise the default file name below.
LOG = sys.argv[1] if len(sys.argv) > 1 else "train_trimode_4gpu_20260609_175823.log"

# Captures the step number from a "[step=N][every=10]" tag; only lines carrying
# the literal "[every=10]" marker match.
re_step = re.compile(r"\[step=(\d+)\]\[every=10\]")
# Each of the following captures everything after the " | " that follows the
# fixed section title, i.e. the payload later handed to parse_kv().
re_routing = re.compile(r"\[routing\] mode routing summary \| (.+)")
re_reward = re.compile(r"\[reward\] aggregate reward stats \| (.+)")
re_gen = re.compile(r"\[generation\] completion mask summary \| (.+)")
# Captures, in order: global_step, effective_tokens_mean, paren_then_eos_count,
# one_token_count and eos_terminated_rate. The lazy ".*?" gaps skip whatever
# other fields sit between them on the line.
re_probe = re.compile(
    r"global_step=(\d+) \| .*?"
    r"effective_tokens_mean=([\d.]+) \| .*?"
    r"paren_then_eos_count=(\d+) \| .*?"
    r"one_token_count=(\d+) \| .*?"
    r"eos_terminated_rate=([\d.]+)"
)
# Captures the payload after the loss-breakdown title.
re_loss = re.compile(r"\[loss\] GRPO / OPSD loss breakdown \| (.+)")
# Captures, in order: sample index, group, format value and acc value.
# NOTE: "[\d.]+" does not accept a leading minus sign.
re_per_sample = re.compile(
    r"\[reward\] per_sample\[(\d+)\] \| group=(\d+) \| format=([\d.]+) \| acc=([\d.]+)"
)
# Captures, in order: sample index, group, effective_tokens, has_eos and the
# single-quoted text (which stops at the first "'" inside the text).
re_completion = re.compile(
    r"\[generation\] sample\[(\d+)\] \| group=(\d+) \| effective_tokens=(\d+) \| has_eos=(\w+) \| text='([^']*)'"
)
# Compiled once at import time: parse_kv() runs for every matching log line.
# A pair is "key=value" where the value runs up to the next "|" or end of text.
_KV_PAIR_RE = re.compile(r"(\w+)=([^|]+?)(?=\s*\||\s*$)")
_INT_RE = re.compile(r"-?\d+")
# Decimal and/or exponent notation ("0.5", "1e-05", "2.5E3") plus nan / inf.
_FLOAT_RE = re.compile(
    r"-?(?:\d+(?:\.\d+)?(?:[eE][+-]?\d+)?|nan|inf)", re.IGNORECASE
)


def _coerce_value(v: str):
    """Convert one stripped value string to the most specific type it matches.

    Returns a dict/list for container literals, an int or float for numeric
    text, and the unchanged string for anything else.
    """
    if v.startswith(("{", "[")):
        # First choice: JSON after swapping quotes, which covers simple
        # Python-repr dicts such as {'OPSD': 30, 'GRPO': 2}.
        try:
            return json.loads(v.replace("'", '"'))
        except json.JSONDecodeError:
            pass
        # Fallback: Python literals that JSON rejects (True/False/None,
        # non-string keys, strings containing apostrophes).
        try:
            return ast.literal_eval(v)
        except (ValueError, SyntaxError, TypeError):
            return v
    if _INT_RE.fullmatch(v):
        try:
            return int(v)
        except ValueError:
            # e.g. the interpreter's digit-count limit on huge integers.
            return v
    if _FLOAT_RE.fullmatch(v):
        return float(v)
    return v


def parse_kv(s: str) -> dict:
    """Parse a "key=value | key=value | ..." payload into a dict.

    Args:
        s: Text made of ``key=value`` pairs separated by ``|``.

    Returns:
        A dict mapping each key to its value, typed as follows:
        container literals (values starting with ``{`` or ``[``) become
        dicts/lists; integers become ``int``; decimal, exponent-only
        (``1e-05``), ``nan`` and ``inf`` values become ``float``; everything
        else, including values that fail to parse, stays a stripped string.
        A later duplicate key overwrites an earlier one.
    """
    d = {}
    for m in _KV_PAIR_RE.finditer(s):
        d[m.group(1)] = _coerce_value(m.group(2).strip())
    return d
# Substring carried by rank-0 log lines. The world size after the slash is left
# open so that logs from runs with any number of GPUs are analyzed, not only
# 4-GPU ones.
_RANK0_TAG = "rank=0/"
# Denominator used when a routing snapshot carries no usable mode counts.
_FALLBACK_SAMPLES = 32
# Target number of rows per summary table. The stride is ``n // rows``, so
# tables with fewer than twice this many entries are printed in full.
_MAX_TABLE_ROWS = 12


def _sample_indices(n, max_rows=_MAX_TABLE_ROWS):
    """Return evenly spaced indices into a length-n sequence, always ending on the last index."""
    if n <= 0:
        return []
    idx = list(range(0, n, max(1, n // max_rows)))
    if idx[-1] != n - 1:
        idx.append(n - 1)
    return idx


def _as_counts(raw):
    """Coerce a completion_mode_counts value (dict or its text form) into a dict; {} if unusable."""
    if isinstance(raw, dict):
        return raw
    if isinstance(raw, str):
        try:
            parsed = json.loads(raw.replace("'", '"'))
        except json.JSONDecodeError:
            return {}
        if isinstance(parsed, dict):
            return parsed
    return {}


def _num(value, default=0.0):
    """Return value as a float, or default when it is missing or not numeric text."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _parse_log(path):
    """Scan the log once and collect every rank-0 record the report sections need."""
    # Hoisted out of the per-line loop.
    ts_re = re.compile(r"\[(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})\]")
    counts_re = re.compile(r'completion_mode_counts=(\{[^}]+\})')

    data = {
        "routing": {},
        "rewards": {},
        "gens": {},
        "losses_by_step": defaultdict(list),
        "probes": [],
        "per_sample_rewards": defaultdict(list),
        "sample_completions": defaultdict(list),
        "first_ts": None,
        "last_ts": None,
        "grpo_routes": 0,
        "sft_routes": 0,
        "total_debug_routes": 0,
    }

    with open(path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            # Timestamps are tracked on every line, regardless of rank.
            ts = ts_re.search(line)
            if ts:
                if data["first_ts"] is None:
                    data["first_ts"] = ts.group(1)
                data["last_ts"] = ts.group(1)

            is_rank0 = _RANK0_TAG in line

            if is_rank0 and "completion_mode_counts=" in line:
                cm = counts_re.search(line)
                if cm:
                    try:
                        c = json.loads(cm.group(1).replace("'", '"'))
                    except json.JSONDecodeError:
                        # Skip a malformed counts dict instead of aborting the run.
                        c = None
                    if isinstance(c, dict):
                        data["total_debug_routes"] += 1
                        data["grpo_routes"] += c.get("GRPO", 0)
                        data["sft_routes"] += c.get("SFT", 0)

            if is_rank0 and "[OPSD-PROBE]" in line and "[generate] raw generate summary" in line:
                pm = re_probe.search(line)
                if pm:
                    data["probes"].append(
                        (
                            int(pm.group(1)),
                            float(pm.group(2)),
                            int(pm.group(3)),
                            int(pm.group(4)),
                            float(pm.group(5)),
                        )
                    )
                continue

            if not is_rank0 or "[OPSD-DETAIL]" not in line:
                continue
            sm = re_step.search(line)
            if not sm:
                continue
            step = int(sm.group(1))

            if "[routing]" in line:
                rm = re_routing.search(line)
                if rm:
                    data["routing"][step] = parse_kv(rm.group(1))
            elif "[reward] aggregate" in line:
                rm = re_reward.search(line)
                if rm:
                    data["rewards"][step] = parse_kv(rm.group(1))
            elif "[reward] per_sample" in line:
                rm = re_per_sample.search(line)
                if rm:
                    data["per_sample_rewards"][step].append(
                        (int(rm.group(1)), float(rm.group(3)), float(rm.group(4)))
                    )
            elif "[generation] completion mask" in line:
                rm = re_gen.search(line)
                if rm:
                    data["gens"][step] = parse_kv(rm.group(1))
            elif "[generation] sample[" in line:
                rm = re_completion.search(line)
                # Keep at most four completions per step, each cut to 120 chars.
                if rm and len(data["sample_completions"][step]) < 4:
                    data["sample_completions"][step].append(rm.group(5)[:120])
            elif "[loss] GRPO / OPSD loss breakdown" in line:
                rm = re_loss.search(line)
                if rm:
                    data["losses_by_step"][step].append(parse_kv(rm.group(1)))
    return data


def _report_overview(data):
    """Print the time range and how many routing / probe records were found."""
    routing, probes = data["routing"], data["probes"]
    print("=== TRAINING OVERVIEW ===")
    print(f"Time range: {data['first_ts']} -> {data['last_ts']}")
    print(f"OPSD-DETAIL routing snapshots (rank0): {len(routing)}")
    if routing:
        print(f"Global steps: {min(routing)} - {max(routing)}")
    print(f"OPSD-PROBE generate summaries (rank0): {len(probes)}")
    if probes:
        print(f"Probe global steps: {min(p[0] for p in probes)} - {max(p[0] for p in probes)}")


def _report_routing(data):
    """Print sampled per-step routing ratios and the aggregate OPSD/GRPO/SFT shares."""
    routing = data["routing"]
    print("\n=== ROUTING RATIOS (rank0 DETAIL, every 10 steps) ===")
    print(f"{'step':>6} {'OPSD':>7} {'GRPO':>7} {'SFT':>7} {'has_correct':>16} {'mask_ratio':>10}")

    steps_sorted = sorted(routing)
    counts_by_step = {
        step: _as_counts(routing[step].get("completion_mode_counts", {}))
        for step in steps_sorted
    }
    tot_opsd = sum(c.get("OPSD", 0) for c in counts_by_step.values())
    tot_grpo = sum(c.get("GRPO", 0) for c in counts_by_step.values())
    tot_sft = sum(c.get("SFT", 0) for c in counts_by_step.values())
    any_grpo_steps = [s for s in steps_sorted if counts_by_step[s].get("GRPO", 0) > 0]
    any_sft_steps = [s for s in steps_sorted if counts_by_step[s].get("SFT", 0) > 0]

    for i in _sample_indices(len(steps_sorted)):
        step = steps_sorted[i]
        r = routing[step]
        counts = counts_by_step[step]
        total = sum(counts.values()) or _FALLBACK_SAMPLES
        print(
            f"{step:>6} {counts.get('OPSD', 0) / total * 100:>6.1f}%"
            f" {counts.get('GRPO', 0) / total * 100:>6.1f}%"
            f" {counts.get('SFT', 0) / total * 100:>6.1f}%"
            f" {str(r.get('has_correct', '?')):>16}"
            f" {_num(r.get('opsd_mask_ratio')) * 100:>9.1f}%"
        )

    tot = tot_opsd + tot_grpo + tot_sft
    if tot:
        print(
            f"\nAggregate ({len(routing)} snapshots x 32 samples):"
            f" OPSD={tot_opsd / tot * 100:.1f}%"
            f" GRPO={tot_grpo / tot * 100:.1f}%"
            f" SFT={tot_sft / tot * 100:.1f}%"
        )
    else:
        # Nothing was counted (empty or unrelated log): avoid dividing by zero.
        print(f"\nAggregate ({len(routing)} snapshots x 32 samples): no routing counts found")
    print(f"Steps with any GRPO on rank0: {len(any_grpo_steps)} ({any_grpo_steps[:5]}...)")
    print(f"Steps with any SFT on rank0: {len(any_sft_steps)}")

    # Totals taken from every rank-0 line carrying completion_mode_counts,
    # not only the OPSD-DETAIL routing snapshots.
    if data["total_debug_routes"]:
        print(
            "\nAll-rank DEBUG routing lines (rank0 only in file grep):"
            f" GRPO samples={data['grpo_routes']} SFT={data['sft_routes']}"
            f" (from rank=0 completion_mode_counts lines: {data['total_debug_routes']})"
        )


def _report_rewards(data):
    """Print sampled aggregate reward stats, the first-to-last delta and per-sample rates."""
    rewards = data["rewards"]
    per_sample_rewards = data["per_sample_rewards"]

    print("\n=== REWARD / FORMAT LEARNING (rank0) ===")
    print(f"{'step':>6} {'fmt_zero':>9} {'acc_zero':>9} {'acc_sum':>8} {'fmt_sum':>8} {'w_mean':>8}")
    reward_steps = sorted(rewards)
    for i in _sample_indices(len(reward_steps)):
        step = reward_steps[i]
        r = rewards[step]
        print(
            f"{step:>6} {_num(r.get('format_zero_rate')) * 100:>8.1f}%"
            f" {_num(r.get('acc_zero_rate')) * 100:>8.1f}%"
            f" {_num(r.get('acc_sum')):>8.2f}"
            f" {_num(r.get('format_sum')):>8.2f}"
            f" {_num(r.get('weighted_mean')):>8.4f}"
        )

    if rewards:
        first_step, last_step = reward_steps[0], reward_steps[-1]
        s0, sL = rewards[first_step], rewards[last_step]
        print(f"\nStep {first_step} -> {last_step} delta:")
        print(f" format_zero_rate: {_num(s0.get('format_zero_rate')) * 100:.1f}% -> {_num(sL.get('format_zero_rate')) * 100:.1f}%")
        print(f" acc_zero_rate: {_num(s0.get('acc_zero_rate')) * 100:.1f}% -> {_num(sL.get('acc_zero_rate')) * 100:.1f}%")
        print(f" acc_sum: {_num(s0.get('acc_sum')):.2f} -> {_num(sL.get('acc_sum')):.2f}")
        print(f" format_sum: {_num(s0.get('format_sum')):.2f} -> {_num(sL.get('format_sum')):.2f}")

    # Share of samples with a positive format / accuracy reward at fixed steps.
    print("\n=== FORMAT REWARD RATE (per-sample, rank0) ===")
    for step in [0, 10, 50, 100, 200, 300, 400, 500, 600, 700, 800, 850]:
        if step in per_sample_rewards:
            samples = per_sample_rewards[step]
            fmt_ok = sum(1 for _, fmt, _ in samples if fmt > 0)
            acc_ok = sum(1 for _, _, acc in samples if acc > 0)
            print(f"step {step:>4}: format_reward>0: {fmt_ok}/{len(samples)} ({fmt_ok/len(samples)*100:.1f}%) acc>0: {acc_ok}/{len(samples)} ({acc_ok/len(samples)*100:.1f}%)")


def _report_samples(data):
    """Print up to four stored completions for a few fixed steps plus the last step."""
    sample_completions = data["sample_completions"]
    print("\n=== SAMPLE COMPLETIONS (first 4 per step, rank0) ===")
    picks = [0, 10, 100, 300, 500, max(sample_completions) if sample_completions else 0]
    # dict.fromkeys removes duplicates while keeping order, so the last step is
    # not printed twice when it is already one of the fixed picks.
    for step in dict.fromkeys(picks):
        if step in sample_completions:
            print(f"\n--- step {step} ---")
            for i, text in enumerate(sample_completions[step][:4]):
                print(f" [{i}] {text}")


def _report_generation(data):
    """Print sampled generation-length stats and the probe-based degeneration summary."""
    gens, probes = data["gens"], data["probes"]

    print("\n=== GENERATION LENGTH (rank0 DETAIL) ===")
    print(f"{'step':>6} {'eff_tok':>10} {'eos_term':>9} {'at_max':>8}")
    gen_steps = sorted(gens)
    for i in _sample_indices(len(gen_steps)):
        step = gen_steps[i]
        g = gens[step]
        print(
            f"{step:>6} {_num(g.get('effective_tokens_mean')):>10.1f}"
            f" {_num(g.get('eos_terminated_rate')) * 100:>8.1f}%"
            f" {_num(g.get('at_max_length_rate')) * 100:>7.1f}%"
        )

    print("\n=== GENERATION DEGENERATION (rank0 PROBE) ===")
    if not probes:
        return
    print(f"{'gstep':>6} {'eff_tok':>9} {'paren_eos':>10} {'one_tok':>8} {'eos_rate':>8}")
    # When a global step was probed more than once, the last probe wins.
    probe_by_gs = {p[0]: p for p in probes}
    for gs in [0, 1, 2, 3, 5, 10, 20, 50, 100, 200, 400, 600, 800, 850]:
        if gs in probe_by_gs:
            p = probe_by_gs[gs]
            print(f"{p[0]:>6} {p[1]:>9.1f} {p[2]:>10} {p[3]:>8} {p[4]:>7.2f}")
    print(f"\nTotal paren_then_eos: {sum(p[2] for p in probes)} / {len(probes)} regenerates")
    print(f"Total one_token: {sum(p[3] for p in probes)}")
    early = [p for p in probes if p[0] <= 10]
    late = [p for p in probes if p[0] >= 800]
    if early:
        print(f"Early (gs<=10) avg eff_tok={sum(p[1] for p in early)/len(early):.1f}")
    if late:
        print(f"Late (gs>=800) avg eff_tok={sum(p[1] for p in late)/len(late):.1f}, eos_rate={sum(p[4] for p in late)/len(late):.2f}")


def _report_losses(data):
    """Print sampled GRPO / OPSD loss values, using the last loss record of each step."""
    losses_by_step = data["losses_by_step"]
    print("\n=== LOSS (rank0, last micro-batch per step) ===")
    loss_steps = sorted(losses_by_step)
    print(f"{'step':>6} {'grpo':>10} {'opsd':>10}")
    for i in _sample_indices(len(loss_steps)):
        step = loss_steps[i]
        d = losses_by_step[step][-1]
        opsd_l = d.get("opsd_loss_scalar", d.get("opsd_loss", 0))
        print(f"{step:>6} {_num(d.get('grpo_loss_scalar')):>10.4f} {_num(opsd_l):>10.4f}")


def main():
    """Parse the log file named by LOG and print the full analysis report.

    The report covers, in order: overview, routing ratios, reward / format
    stats, sample completions, generation length and degeneration, and loss.
    Only rank-0 lines are analyzed. Exits with an error message when the log
    file cannot be read.
    """
    try:
        data = _parse_log(LOG)
    except OSError as err:
        sys.exit(f"Cannot read log file {LOG!r}: {err}")

    _report_overview(data)
    _report_routing(data)
    _report_rewards(data)
    _report_samples(data)
    _report_generation(data)
    _report_losses(data)
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()