File size: 15,009 Bytes
3550904
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c065aa
 
 
3550904
2f2776e
 
 
1c065aa
2f2776e
 
1c065aa
 
2f2776e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3550904
1c065aa
3550904
1c065aa
3550904
2f2776e
 
 
 
3550904
1c065aa
 
 
 
 
 
3550904
 
1c065aa
 
3550904
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f2776e
 
 
1c065aa
 
 
 
 
 
 
 
 
2f2776e
 
 
 
 
1c065aa
 
 
 
 
 
 
3550904
 
 
 
 
 
 
 
 
 
 
1c065aa
 
 
3550904
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c065aa
 
 
 
 
 
 
 
 
 
3550904
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
#!/usr/bin/env python3
"""
Summarize locked-stream multi-seed dropout schedule results.

MIT License

Copyright (c) 2025 Andrej Karpathy

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
"""

from __future__ import annotations

import argparse
import csv
from collections import defaultdict
import json
from pathlib import Path
import statistics


# Conditions summarized when --conditions is not given on the command line.
# Names starting with "static_" are the ones this script treats as static
# baselines when computing paired deltas.
DEFAULT_CONDITIONS = [
    "interaction",
    "baseabc",
    "smooth_low",
    "static_dropout_0.08",
    "static_dropout_0.12",
    "static_dropout_0.18",
]


def read_metrics(paths: list[Path], conditions: set[str]) -> list[dict]:
    """Load JSON-lines metric records, keeping only the requested conditions.

    Args:
        paths: JSONL files to read, in order. Every non-blank line must be a
            JSON object with a ``condition`` key.
        conditions: Condition names to keep; all other records are dropped.

    Returns:
        The kept records, ordered by file and then by line.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``condition`` key.
    """
    rows: list[dict] = []
    for path in paths:
        # Iterate the handle instead of ``read_text().splitlines()``:
        # ``str.splitlines`` also breaks on characters such as U+2028, which
        # may legally appear unescaped inside a JSON string and would cut a
        # record in two. File iteration only splits on real newlines, and it
        # avoids holding the whole file in memory.
        with path.open(encoding="utf-8") as handle:
            for line in handle:
                if not line.strip():
                    continue
                row = json.loads(line)
                if row["condition"] in conditions:
                    rows.append(row)
    return rows


def mean(values: list[float]) -> float:
    """Return the arithmetic mean of *values* as a float.

    Delegates to ``statistics.fmean``, so an empty input raises
    ``statistics.StatisticsError``.
    """
    average = statistics.fmean(values)
    return average


def std(values: list[float]) -> float:
    """Return the sample standard deviation of *values*.

    Fewer than two values have no sample deviation, so 0.0 is reported for
    them instead of raising.
    """
    enough_samples = len(values) >= 2
    return statistics.stdev(values) if enough_samples else 0.0


def fmt(value: float) -> str:
    """Format *value* in fixed-point notation with four decimal places."""
    return format(value, ".4f")


def grouped(rows: list[dict], *keys: str) -> dict[tuple, list[dict]]:
    """Bucket *rows* by the tuple of values found under *keys*.

    The result is a ``defaultdict(list)``: looking up a key tuple that never
    occurred yields an empty list rather than raising. Rows keep their input
    order inside each bucket.
    """
    buckets: dict[tuple, list[dict]] = defaultdict(list)
    for record in rows:
        bucket_key = tuple([record[name] for name in keys])
        buckets[bucket_key].append(record)
    return buckets


def condition_kind(rows: list[dict], condition: str) -> str:
    """Return ``condition_kind`` from the first row of *condition*.

    An empty string is returned when no row belongs to *condition*.
    """
    kinds = (row["condition_kind"] for row in rows if row["condition"] == condition)
    return next(kinds, "")


def dropout_path(rows: list[dict], condition: str) -> str:
    """Describe the per-stage dropout values of one representative run.

    The representative run is the one with the lowest seed among the rows of
    *condition*. Its ``dropout_active_final`` values are listed in stage
    order, formatted to two decimals and joined with ``" -> "``.

    Args:
        rows: Metric records carrying ``condition``, ``seed``, ``stage`` and
            ``dropout_active_final``.
        condition: Name of the condition to describe.

    Returns:
        The joined path, or an empty string when *condition* has no rows.
    """
    matching = [row for row in rows if row["condition"] == condition]
    if not matching:
        return ""
    # Choose the reference seed once, among this condition's own rows, and
    # compare as int. Comparing the raw value against an int never matches
    # string seeds, and a global minimum seed leaves the path empty for a
    # condition that was not run with that seed.
    reference_seed = min(int(row["seed"]) for row in matching)
    items = sorted(
        (row for row in matching if int(row["seed"]) == reference_seed),
        key=lambda row: row["stage"],
    )
    return " -> ".join(f"{float(row['dropout_active_final']):.2f}" for row in items)


def min_seed(rows: list[dict]) -> int:
    """Return the smallest ``seed`` in *rows*, converting each seed to int.

    Raises ``ValueError`` when *rows* is empty.
    """
    seeds = [int(record["seed"]) for record in rows]
    return min(seeds)


def write_csv(path: Path, rows: list[dict], fieldnames: list[str]) -> None:
    """Write *rows* to *path* as CSV with a header row.

    Missing parent directories are created first. Columns are written in
    the order given by *fieldnames*.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="", encoding="utf-8") as handle:
        table = csv.DictWriter(handle, fieldnames=fieldnames)
        table.writeheader()
        for record in rows:
            table.writerow(record)


def build_summary(rows: list[dict], conditions: list[str]) -> tuple[list[dict], list[dict], list[dict]]:
    """Aggregate raw metric rows into the three summary tables.

    Args:
        rows: Metric records. Each one is read for ``condition``, ``seed``,
            ``stage``, ``val_eval_loss``, ``train_eval_loss``,
            ``generalization_gap``, ``token_limit`` and
            ``dropout_active_final``.
        conditions: Conditions to report. Names starting with ``static_``
            are treated as static baselines for the paired comparison.

    Returns:
        A ``(condition_rows, stage_rows, paired_rows)`` tuple:

        * ``condition_rows``: one row per condition that has data, sorted by
          ascending mean final validation loss.
        * ``stage_rows``: one row per (condition, stage) pair that has data,
          in *conditions* order and then ascending stage.
        * ``paired_rows``: for each seed at the globally last stage, one row
          per condition with its delta against that seed's best static
          baseline.

        Conditions without any rows, and seeds without any static baseline
        at the last stage, are left out instead of raising.
    """
    final_by_condition: dict[str, list[float]] = defaultdict(list)
    trajectory_by_condition: dict[str, list[float]] = defaultdict(list)
    gap_by_condition: dict[str, list[float]] = defaultdict(list)

    # One (condition, seed) group is one run; reduce each run to its mean
    # validation loss over stages and to the metrics of its last stage.
    for (condition, _seed), items in grouped(rows, "condition", "seed").items():
        items = sorted(items, key=lambda row: row["stage"])
        trajectory_by_condition[condition].append(mean([float(row["val_eval_loss"]) for row in items]))
        final = max(items, key=lambda row: row["stage"])
        final_by_condition[condition].append(float(final["val_eval_loss"]))
        gap_by_condition[condition].append(float(final["generalization_gap"]))

    condition_rows = []
    for condition in conditions:
        final_values = final_by_condition.get(condition)
        if not final_values:
            # No run was loaded for this condition; averaging an empty list
            # would raise, so the condition is simply not reported.
            continue
        trajectory_values = trajectory_by_condition[condition]
        gap_values = gap_by_condition[condition]
        condition_rows.append(
            {
                "condition": condition,
                "kind": condition_kind(rows, condition),
                "n": len(final_values),
                "mean_trajectory_val": mean(trajectory_values),
                "std_trajectory_val": std(trajectory_values),
                "mean_final_val": mean(final_values),
                "std_final_val": std(final_values),
                "mean_final_gap": mean(gap_values),
                "std_final_gap": std(gap_values),
                "dropout_path": dropout_path(rows, condition),
            }
        )
    condition_rows.sort(key=lambda row: row["mean_final_val"])

    by_condition_stage = grouped(rows, "condition", "stage")
    # The set of stages does not depend on the condition; compute it once.
    stages = sorted({int(row["stage"]) for row in rows})
    stage_rows = []
    for condition in conditions:
        for stage in stages:
            items = by_condition_stage.get((condition, stage))
            if not items:
                continue
            val_values = [float(row["val_eval_loss"]) for row in items]
            train_values = [float(row["train_eval_loss"]) for row in items]
            gap_values = [float(row["generalization_gap"]) for row in items]
            stage_rows.append(
                {
                    "condition": condition,
                    "stage": stage,
                    "token_limit": int(items[0]["token_limit"]),
                    "dropout": mean([float(row["dropout_active_final"]) for row in items]),
                    "n": len(items),
                    "mean_val": mean(val_values),
                    "std_val": std(val_values),
                    "mean_train": mean(train_values),
                    "std_train": std(train_values),
                    "mean_gap": mean(gap_values),
                    "std_gap": std(gap_values),
                }
            )

    static_conditions = {condition for condition in conditions if condition.startswith("static_")}
    # Hoisted out of the filter below: the last stage is the same for every
    # row, so recomputing it per row made this step quadratic.
    final_stage = max((int(row["stage"]) for row in rows), default=None)
    final_rows = [row for row in rows if int(row["stage"]) == final_stage]
    paired_rows = []
    for (seed,), items in sorted(grouped(final_rows, "seed").items()):
        static_items = [row for row in items if row["condition"] in static_conditions]
        if not static_items:
            # Without a static baseline for this seed there is nothing to
            # pair against, so the seed contributes no delta rows.
            continue
        best_static = min(static_items, key=lambda row: float(row["val_eval_loss"]))
        best_static_val = float(best_static["val_eval_loss"])
        for condition in conditions:
            row = next((item for item in items if item["condition"] == condition), None)
            if row is None:
                continue
            final_val = float(row["val_eval_loss"])
            paired_rows.append(
                {
                    "seed": int(seed),
                    "condition": condition,
                    "final_val": final_val,
                    "best_static_condition": best_static["condition"],
                    "best_static_final_val": best_static_val,
                    "delta_vs_best_static": final_val - best_static_val,
                }
            )

    return condition_rows, stage_rows, paired_rows


def write_report(
    path: Path,
    condition_rows: list[dict],
    stage_rows: list[dict],
    paired_rows: list[dict],
    metrics_paths: list[Path],
    title: str,
    date: str,
    context: str,
) -> None:
    """Render the summary tables as a Markdown report and write it to *path*.

    Args:
        path: Destination file; missing parent directories are created.
        condition_rows: Per-condition summary rows. The first entry is
            reported as the best condition and the second as the runner-up,
            so the list is expected to be sorted best-first.
        stage_rows: Per-(condition, stage) summary rows.
        paired_rows: Per-seed deltas against the best static baseline.
        metrics_paths: Source files, listed under "Sources".
        title: Report heading.
        date: Text placed after "Date:".
        context: Optional paragraph inserted after the introduction; skipped
            when empty.

    Interpretation bullets whose inputs are missing (no conditions, no
    ``static_*`` condition, no stage-0 rows, or no paired deltas for a
    condition) are omitted instead of raising.
    """
    seed_ids = sorted({int(row["seed"]) for row in paired_rows})
    seed_count = len(seed_ids)
    best_row = condition_rows[0] if condition_rows else None
    second_row = condition_rows[1] if len(condition_rows) > 1 else None
    static_rows = [row for row in condition_rows if row["condition"].startswith("static_")]
    best_static_row = min(static_rows, key=lambda row: row["mean_final_val"], default=None)
    first_stage_rows = [row for row in stage_rows if int(row["stage"]) == 0]
    best_first_stage = min(first_stage_rows, key=lambda row: row["mean_val"], default=None)

    paired_win_lines = []
    for row in condition_rows:
        condition = row["condition"]
        if condition.startswith("static_"):
            continue
        condition_deltas = [
            item["delta_vs_best_static"]
            for item in paired_rows
            if item["condition"] == condition
        ]
        if not condition_deltas:
            # The condition never reached the final stage alongside a static
            # baseline, so there is no paired statistic to report.
            continue
        wins = sum(delta < 0 for delta in condition_deltas)
        ties = sum(delta == 0 for delta in condition_deltas)
        worst_delta = max(condition_deltas)
        paired_win_lines.append(
            f"- `{condition}` beats the per-seed best static baseline in "
            f"{wins}/{seed_count} seeds"
            + (f" with {ties} exact ties" if ties else "")
            + f"; worst paired delta is {worst_delta:+.4f}."
        )

    lines = [
        f"# {title}",
        "",
        f"Date: {date}",
        "",
        f"This report combines {seed_count} random seeds "
        f"({', '.join(str(seed) for seed in seed_ids)}) from saved streaming runs.",
        "No additional training is performed by this script; it reads saved",
        "`metrics.jsonl` files.",
        "",
    ]
    if context:
        lines.extend([context, ""])

    lines.extend(
        [
            "## Sources",
            "",
        ]
    )
    for path_item in metrics_paths:
        lines.append(f"- `{path_item}`")

    lines.extend(
        [
            "",
            "## Condition Ranking By Final Loss",
            "",
            "| Condition | Kind | N | Mean trajectory val | Std trajectory val | Mean final val | Std final val | Mean final gap | Dropout path |",
            "|---|---|---:|---:|---:|---:|---:|---:|---|",
        ]
    )
    for row in condition_rows:
        lines.append(
            f"| `{row['condition']}` | `{row['kind']}` | {row['n']} | "
            f"{fmt(row['mean_trajectory_val'])} | {fmt(row['std_trajectory_val'])} | "
            f"{fmt(row['mean_final_val'])} | {fmt(row['std_final_val'])} | "
            f"{fmt(row['mean_final_gap'])} | `{row['dropout_path']}` |"
        )

    lines.extend(
        [
            "",
            "## Paired Final-Loss Deltas",
            "",
            "Negative `delta_vs_best_static` means the condition beat the best static",
            "baseline for that seed.",
            "",
            "| Seed | Condition | Final val | Best static | Best static final val | Delta vs best static |",
            "|---:|---|---:|---|---:|---:|",
        ]
    )
    for row in paired_rows:
        lines.append(
            f"| {row['seed']} | `{row['condition']}` | {fmt(row['final_val'])} | "
            f"`{row['best_static_condition']}` | {fmt(row['best_static_final_val'])} | "
            f"{row['delta_vs_best_static']:+.4f} |"
        )

    lines.extend(
        [
            "",
            "## Stage Trajectory",
            "",
            "| Stage | Prefix tokens | Condition | Dropout | N | Mean val | Std val | Mean train | Mean gap |",
            "|---:|---:|---|---:|---:|---:|---:|---:|---:|",
        ]
    )
    for row in sorted(stage_rows, key=lambda item: (item["stage"], item["mean_val"])):
        lines.append(
            f"| {row['stage']} | {row['token_limit']:,} | `{row['condition']}` | "
            f"{row['dropout']:.3f} | {row['n']} | {fmt(row['mean_val'])} | "
            f"{fmt(row['std_val'])} | {fmt(row['mean_train'])} | {fmt(row['mean_gap'])} |"
        )

    # Each bullet is added only when the data it summarizes exists.
    interpretation = ["", "## Interpretation", ""]
    if best_row is not None:
        interpretation.append(
            f"- `{best_row['condition']}` has the best {seed_count}-seed mean final "
            f"validation loss: {fmt(best_row['mean_final_val'])} +/- "
            f"{fmt(best_row['std_final_val'])}."
        )
    if second_row is not None:
        interpretation.append(
            f"- The second-best final condition is `{second_row['condition']}` at "
            f"{fmt(second_row['mean_final_val'])} +/- "
            f"{fmt(second_row['std_final_val'])}."
        )
    if best_static_row is not None:
        interpretation.append(
            f"- The best static baseline by mean final loss is "
            f"`{best_static_row['condition']}` at "
            f"{fmt(best_static_row['mean_final_val'])} +/- "
            f"{fmt(best_static_row['std_final_val'])}."
        )
    interpretation.extend(paired_win_lines)
    if best_first_stage is not None:
        interpretation.append(
            f"- The best first-stage condition is `{best_first_stage['condition']}` "
            f"at prefix {best_first_stage['token_limit']:,} with mean validation "
            f"loss {fmt(best_first_stage['mean_val'])}; compare this with the final "
            "ranking before claiming a schedule is uniformly better."
        )
    interpretation.extend(
        [
            "- This is a saved-run streaming validation artifact. Treat it as strong",
            "  evidence only when the tested conditions, seeds, static baselines, and",
            "  stream protocol match the claim being made.",
        ]
    )
    lines.extend(interpretation)

    # Create the destination directory the same way write_csv does, so a
    # report path in a new directory does not fail after all the work is done.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the summarizer.

    ``--metrics``, ``--output-dir`` and ``--report`` are required paths;
    ``--conditions``, ``--title``, ``--date`` and ``--context`` have defaults.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--metrics", nargs="+", type=Path, required=True)
    for path_flag in ("--output-dir", "--report"):
        cli.add_argument(path_flag, type=Path, required=True)
    cli.add_argument("--conditions", nargs="+", default=DEFAULT_CONDITIONS)
    text_defaults = (
        ("--title", "TinyStories Multi-Seed Streaming Validation"),
        ("--date", "2026-05-30"),
        ("--context", ""),
    )
    for text_flag, fallback in text_defaults:
        cli.add_argument(text_flag, default=fallback)
    return cli


def main() -> None:
    """Run the full pipeline: parse arguments, summarize, write outputs.

    Writes three CSV tables into the output directory, writes the Markdown
    report, and prints a small JSON description of the run to stdout.
    """
    args = build_parser().parse_args()
    out_dir = args.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    rows = read_metrics(args.metrics, set(args.conditions))
    condition_rows, stage_rows, paired_rows = build_summary(rows, args.conditions)

    # (file name, table rows, column order) for every CSV artifact.
    csv_outputs = (
        (
            "condition_summary.csv",
            condition_rows,
            [
                "condition",
                "kind",
                "n",
                "mean_trajectory_val",
                "std_trajectory_val",
                "mean_final_val",
                "std_final_val",
                "mean_final_gap",
                "std_final_gap",
                "dropout_path",
            ],
        ),
        (
            "stage_summary.csv",
            stage_rows,
            [
                "condition",
                "stage",
                "token_limit",
                "dropout",
                "n",
                "mean_val",
                "std_val",
                "mean_train",
                "std_train",
                "mean_gap",
                "std_gap",
            ],
        ),
        (
            "paired_final_deltas.csv",
            paired_rows,
            [
                "seed",
                "condition",
                "final_val",
                "best_static_condition",
                "best_static_final_val",
                "delta_vs_best_static",
            ],
        ),
    )
    for file_name, table, columns in csv_outputs:
        write_csv(out_dir / file_name, table, columns)

    write_report(
        args.report,
        condition_rows,
        stage_rows,
        paired_rows,
        args.metrics,
        args.title,
        args.date,
        args.context,
    )

    run_info = {
        "report": str(args.report),
        "output_dir": str(out_dir),
        "rows": len(rows),
        "conditions": args.conditions,
    }
    print(json.dumps(run_info, indent=2))


# Run the summarizer only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()