#!/usr/bin/env python
"""Generate a side-by-side untrained-vs-trained trace demo for the blog/README.

Reads two eval JSONs produced by `run_eval.py`, picks N scenarios where the
trained model substantially outscored the base (for example, where the trained
model asked a clarifying question while the base hallucinated), and emits a
markdown file with a clean side-by-side comparison plus rubric breakdown.

Usage:
    python scripts/make_trace_demo.py \\
        --base outputs/eval_qwen3-1.7b_base.json \\
        --trained outputs/eval_qwen3-1.7b_trained.json \\
        --out docs/trace_demo.md \\
        --n 3
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any
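
# Expected eval JSON shape (inferred from the fields this script reads; adjust
# if `run_eval.py` emits a different schema):
# {
#   "label": "...",                # display name used in the markdown output
#   "results": [
#     {
#       "scenario_id": "...", "family": "...", "difficulty": "...",
#       "request": "...", "final_score": 0.82, "questions_asked": 2,
#       "format_pass": true, "score_breakdown": {"FieldMatch": 1.0, ...},
#       "messages": [{"role": "...", "content": "...", "tool_calls": [...]}]
#     }
#   ]
# }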


def _load(path: str) -> dict[str, Any]:
    return json.loads(Path(path).read_text())


def _index(eval_dict: dict[str, Any]) -> dict[str, dict[str, Any]]:
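    """Index each result in an eval dict by its scenario_id for quick lookup."""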
    return {r["scenario_id"]: r for r in eval_dict.get("results", [])}


def _format_breakdown(bd: dict[str, float] | None) -> str:
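    """Format a rubric score breakdown as a compact 'Name=score' string."""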
    if not bd:
        return "_no breakdown_"
    parts = []
    for k in ("FieldMatch", "InfoGain", "QuestionEfficiency", "HallucinationCheck"):
        v = bd.get(k, bd.get(k.lower(), 0.0))
        try:
            parts.append(f"{k}={float(v):.2f}")
        except (TypeError, ValueError):
            parts.append(f"{k}=?")
    return " · ".join(parts)


def _render_messages(messages: list[dict[str, Any]] | None, max_chars: int = 400) -> str:
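    """Render a message trace as markdown, truncating content and noting tool calls."""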
    if not messages:
        return "_(no messages captured)_"
    lines: list[str] = []
    for m in messages[:30]:
        role = m.get("role", "?")
        content = (m.get("content") or "").strip()
        if len(content) > max_chars:
            content = content[: max_chars - 1] + "…"
        if not content:
            tool = m.get("tool_calls") or []
            if tool:
                names = ", ".join(t.get("name", "?") for t in tool)
                content = f"_[tool: {names}]_"
            else:
                continue
        lines.append(f"**{role}**: {content}")
    return "\n\n".join(lines)


def _pick_demo_scenarios(
    base: dict[str, dict[str, Any]],
    trained: dict[str, dict[str, Any]],
    n: int,
) -> list[str]:
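    """Pick up to n scenario ids where the trained model improved the most.

    Candidates are sorted by score delta (largest first); the first pass takes
    at most one scenario per family with a delta above 0.05, and a second pass
    fills remaining slots with any other positive-delta scenarios.
    """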
    common = set(base) & set(trained)
    diffs: list[tuple[str, float, dict[str, Any], dict[str, Any]]] = []
    for sid in common:
        b, t = base[sid], trained[sid]
        b_score = float(b.get("final_score", 0.0))
        t_score = float(t.get("final_score", 0.0))
        delta = t_score - b_score
        diffs.append((sid, delta, b, t))

    diffs.sort(key=lambda x: -x[1])
    seen_families: set[str] = set()
    picks: list[str] = []
    for sid, delta, b, _t in diffs:
        if delta <= 0.05:
            break
        fam = b.get("family", "?")
        if fam in seen_families:
            continue
        seen_families.add(fam)
        picks.append(sid)
        if len(picks) >= n:
            break

    if len(picks) < n:
        for sid, delta, _b, _t in diffs:
            if sid in picks:
                continue
            if delta > 0:
                picks.append(sid)
            if len(picks) >= n:
                break
    return picks


def _emit(
    out_path: Path,
    base: dict[str, Any],
    trained: dict[str, Any],
    picks: list[str],
) -> None:
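    """Write the markdown comparison for the picked scenarios to out_path."""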
    base_idx = _index(base)
    trained_idx = _index(trained)

    base_label = base.get("label", "untrained")
    trained_label = trained.get("label", "trained")

    parts: list[str] = []
    parts.append(f"# Two-trace demo — {base_label} vs {trained_label}")
    parts.append("")
    parts.append(f"_{len(picks)} scenarios where the trained model substantially outperformed the base._")
    parts.append("")
    parts.append("Each row shows: the ambiguous request → the agent's full message trace → final rubric breakdown.")
    parts.append("")

    for i, sid in enumerate(picks, 1):
        b = base_idx[sid]
        t = trained_idx[sid]

        family = b.get("family", "?")
        difficulty = b.get("difficulty", "?")
        request = (b.get("request") or t.get("request") or "_(no request captured)_").strip()
        if len(request) > 240:
            request = request[:240] + "…"

        parts.append(f"## {i}. `{sid}` — `{family}` (`{difficulty}`)")
        parts.append("")
        parts.append(f"**Request**: {request}")
        parts.append("")
        parts.append(
            "| Run | Score | Q's asked | Format pass | Rubric breakdown |"
        )
        parts.append("|-----|-------|-----------|-------------|------------------|")
        parts.append(
            f"| {base_label} | **{float(b.get('final_score', 0.0)):.2f}** | "
            f"{b.get('questions_asked', 0)} | "
            f"{'✓' if b.get('format_pass') else '✗'} | "
            f"{_format_breakdown(b.get('score_breakdown'))} |"
        )
        parts.append(
            f"| {trained_label} | **{float(t.get('final_score', 0.0)):.2f}** | "
            f"{t.get('questions_asked', 0)} | "
            f"{'✓' if t.get('format_pass') else '✗'} | "
            f"{_format_breakdown(t.get('score_breakdown'))} |"
        )
        parts.append("")
        parts.append(f"**{base_label} trace:**")
        parts.append("")
        parts.append(_render_messages(b.get("messages")))
        parts.append("")
        parts.append(f"**{trained_label} trace:**")
        parts.append("")
        parts.append(_render_messages(t.get("messages")))
        parts.append("")
        parts.append("---")
        parts.append("")

    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(parts))
    print(f"[ok] wrote {out_path} with {len(picks)} demo scenarios")


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--base", required=True, help="Path to base eval JSON")
    parser.add_argument("--trained", required=True, help="Path to trained eval JSON")
    parser.add_argument("--out", default="docs/trace_demo.md", help="Output markdown path")
    parser.add_argument("--n", type=int, default=3, help="Number of demo scenarios to include")
    args = parser.parse_args()

    base = _load(args.base)
    trained = _load(args.trained)
    base_idx = _index(base)
    trained_idx = _index(trained)
    picks = _pick_demo_scenarios(base_idx, trained_idx, args.n)
    if not picks:
        print("[warn] No scenarios where the trained model outscored the base — nothing to demo")
        return
    _emit(Path(args.out), base, trained, picks)


if __name__ == "__main__":
    main()