mumble-cleanup / src /cleanup /eval /report.py
adikuma's picture
initial upload: cleanup code and 688-pair seed dataset
fd0b01f verified
Raw
History Blame Contribute Delete
4.73 kB
# turn eval.json + training_history.json + latency_benchmark.json into
# charts and a markdown table. mirrors privacy-filter/eval/report.py.
import json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def plot_learning_curves(history: list[dict], out_path: Path) -> None:
    """Plot training and validation loss against step and save the figure.

    Args:
        history: logged entries, each carrying a "step" key plus either
            "loss" (a training log) or "eval_loss" (a validation log);
            entries with neither key are ignored. (Looks like a trainer
            ``log_history`` dump -- confirm against whatever writes
            training_history.json.)
        out_path: destination image file, written at 120 dpi.
    """
    # an entry that carries eval_loss is a validation log, even if it also has "loss"
    train_loss = [
        (h["step"], h["loss"]) for h in history if "loss" in h and "eval_loss" not in h
    ]
    eval_loss = [(h["step"], h["eval_loss"]) for h in history if "eval_loss" in h]

    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        for label, points, width in (("train", train_loss, 1.5), ("val", eval_loss, 2.0)):
            if points:
                steps, losses = zip(*points)
                ax.plot(steps, losses, label=label, linewidth=width)
        ax.set_xlabel("step")
        ax.set_ylabel("loss")
        ax.set_title("learning curves")
        # legend() with no labelled artists only warns and draws an empty box
        if train_loss or eval_loss:
            ax.legend()
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(out_path, dpi=120)
    finally:
        # always release the figure so a failed save does not leak pyplot state
        plt.close(fig)
def plot_metrics_comparison(reports: dict, out_path: Path) -> None:
    """Draw grouped bars comparing quality metrics across model variants.

    Args:
        reports: mapping of variant name ("raw", "base", "fine_tuned") to that
            variant's aggregate metrics dict. Variants that are absent or empty
            are skipped, so a partial eval.json still renders.
        out_path: destination image file, written at 120 dpi.
    """
    rows = ["raw", "base", "fine_tuned"]
    metrics = ["disfluency_removal_rate", "punctuation_f1", "faithfulness_mean", "pass_rate"]
    labels = ["disfluency removed", "punct f1", "faithfulness", "pass rate"]
    x = list(range(len(metrics)))
    bar_w = 0.27
    # centre of a group of len(rows) bars placed at offsets 0, w, 2w, ...
    group_center = bar_w * (len(rows) - 1) / 2

    fig, ax = plt.subplots(figsize=(9, 4.5))
    try:
        drew_any = False
        for i, row in enumerate(rows):
            agg = reports.get(row)
            if not agg:
                # keep slot i reserved so the remaining variants keep their offsets
                continue
            # a None metric (not applicable to this variant) is drawn as a 0-height bar
            values = [0.0 if agg.get(m) is None else agg.get(m) for m in metrics]
            ax.bar([xi + i * bar_w for xi in x], values, width=bar_w, label=row)
            drew_any = True
        ax.set_xticks([xi + group_center for xi in x])
        ax.set_xticklabels(labels)
        ax.set_ylim(0, 1.05)
        ax.set_title("quality metrics: raw vs base vs fine-tuned")
        if drew_any:
            ax.legend()
        ax.grid(True, alpha=0.3, axis="y")
        fig.tight_layout()
        fig.savefig(out_path, dpi=120)
    finally:
        plt.close(fig)
def plot_latency(sweep_fp32: dict, sweep_int8: dict | None, out_path: Path) -> None:
    """Plot p50/p95 CPU latency against input length for each model precision.

    Args:
        sweep_fp32: mapping of input length -> stats dict with "p50_ms" and
            "p95_ms". Keys must be int-convertible; they may be str (as loaded
            from JSON) or int.
        sweep_int8: same shape as ``sweep_fp32``; skipped when None or empty.
        out_path: destination image file, written at 120 dpi.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        drew_any = False
        for name, sweep in (("fp32", sweep_fp32), ("int8", sweep_int8)):
            if not sweep:
                continue
            # sort the (key, stats) pairs directly instead of rebuilding keys with
            # str(int(k)), which failed for int keys or keys like "064"
            points = sorted(sweep.items(), key=lambda item: int(item[0]))
            lengths = [int(key) for key, _ in points]
            p50 = [stats["p50_ms"] for _, stats in points]
            p95 = [stats["p95_ms"] for _, stats in points]
            ax.plot(lengths, p50, marker="o", label=f"{name} p50")
            ax.plot(lengths, p95, marker="o", linestyle="--", label=f"{name} p95")
            drew_any = True
        ax.set_xlabel("input length (tokens)")
        ax.set_ylabel("latency (ms)")
        ax.set_title("cpu latency by input length")
        if drew_any:
            ax.legend()
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(out_path, dpi=120)
    finally:
        plt.close(fig)
def write_report_markdown(
run_dir: Path,
eval_path: Path,
history_path: Path | None,
latency_path: Path | None,
out_path: Path,
) -> None:
eval_data = json.loads(eval_path.read_text())
lines: list[str] = []
lines.append("# cleanup model report")
lines.append("")
lines.append("## quality metrics")
lines.append("")
lines.append("| model | disfluency removal | punct f1 | faithfulness | length ratio | pass rate |")
lines.append("|---|---:|---:|---:|---:|---:|")
rows = ["raw", "base", "fine_tuned"]
for row in rows:
m = eval_data.get(row)
if not m:
continue
d = m.get("disfluency_removal_rate")
d_str = "n/a" if d is None else f"{d:.3f}"
lines.append(
f"| {row} | {d_str} | {m['punctuation_f1']:.3f} | "
f"{m['faithfulness_mean']:.3f} | {m['length_ratio_mean']:.3f} | "
f"{m['pass_rate']:.3f} |"
)
lines.append("")
if latency_path and latency_path.exists():
lat = json.loads(latency_path.read_text())
lines.append("## cpu latency")
lines.append("")
lines.append("| length | p50 ms | p95 ms | p99 ms | mean ms |")
lines.append("|---:|---:|---:|---:|---:|")
for length, stats in lat.get("results_by_length", {}).items():
lines.append(
f"| {length} | {stats['p50_ms']:.1f} | {stats['p95_ms']:.1f} | "
f"{stats['p99_ms']:.1f} | {stats['mean_ms']:.1f} |"
)
if "realistic_mix" in lat and lat["realistic_mix"]:
r = lat["realistic_mix"]
lines.append("")
lines.append(
f"**realistic mix** ({r['samples']} real inputs, "
f"p50 token length {r['token_length_p50']}): "
f"p50={r['p50_ms']:.1f}ms p95={r['p95_ms']:.1f}ms p99={r['p99_ms']:.1f}ms"
)
Path(out_path).write_text("\n".join(lines))