"""LLM Loss Debugging & Optimization Framework.

A systematic 5-level debugging framework for diagnosing training issues.
Always start from Level 1; fixing lower-level bugs before tuning
hyperparameters saves time.

Levels:
  0. Status Diagnosis    - classify current training health
  1. Data/Implementation - most common cause (70% of issues)
  2. Numerical Stability - dtype, normalization, gradient health
  3. Hyperparameters     - LR, batch size, warmup
  4. Fitting Diagnosis   - overfitting vs underfitting
  5. Architecture        - initialization, component checks
"""

import copy
import inspect
import math
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from llm_lab.config import TrainConfig


# ═══════════════════════════════════════════════════════════════════
# Constants
# ═══════════════════════════════════════════════════════════════════

# Approximate convergence ranges for a 1B model trained on ~10B tokens.
# Estimated from GPT-2 scaling benchmarks (Radford et al. 2019) and
# Chinchilla scaling laws (Hoffmann et al. 2022). Not dataset-specific.
_EXPECTED_TRAIN_LOSS = (2.5, 3.3)
_EXPECTED_VAL_LOSS = (2.7, 3.6)
_EXPECTED_VAL_PPL = (15, 37)

# Status labels
STATUS_NORMAL = "NORMAL"
STATUS_NO_DECREASE = "NO_DECREASE"
STATUS_DIVERGING = "DIVERGING"
STATUS_PLATEAU = "PLATEAU"
STATUS_OVERFITTING = "OVERFITTING"
STATUS_UNSTABLE = "UNSTABLE"
STATUS_NAN_DETECTED = "NAN_DETECTED"
STATUS_LOSS_BOUNCE = "LOSS_BOUNCE"

# GPT-3 LR reference by model size (Brown et al. 2020, Table 2.1)
# (param_count, recommended_lr, batch_tokens_str)
_GPT3_LR_REFERENCE = [
    (125e6, 6e-4, "0.5M"),
    (350e6, 3e-4, "0.5M"),
    (760e6, 2.5e-4, "0.5M"),
    (1.3e9, 2e-4, "1M"),
    (2.7e9, 1.6e-4, "1M"),
    (6.7e9, 1.2e-4, "2M"),
    (13e9, 1e-4, "2M"),
    (175e9, 6e-5, "3.2M"),
]

# Known LLM training references
_LLM_TRAINING_REFS = {
    "TinyLlama-1.1B": {"lr": 4e-4, "beta2": 0.95, "wd": 0.1, "warmup": 2000},
    "LLaMA-7B": {"lr": 3e-4, "beta2": 0.95, "wd": 0.1, "warmup": 2000},
    "Pythia-1B": {"lr": 3e-4, "beta2": 0.95, "wd": 0.01},
    "OLMo-1B": {"lr": 4e-4, "beta2": 0.95, "wd": 0.1},
}

# Recommended β₂ for LLM training
_RECOMMENDED_BETA2 = 0.95
_DEFAULT_PYTORCH_BETA2 = 0.999


def _header(title: str) -> str:
    return f"\n{'=' * 60}\n{title}\n{'=' * 60}"


def _check_result(name: str, passed: bool, detail: str = "") -> Dict[str, Any]:
    return {"name": name, "passed": passed, "detail": detail}


# ═══════════════════════════════════════════════════════════════════
# LossDebugger
# ═══════════════════════════════════════════════════════════════════


class LossDebugger:
    """5-level loss debugging framework for LLM training.

    Usage::

        from llm_lab.training.debugger import LossDebugger

        # Quick status check
        status = LossDebugger.diagnose_status(vocab_size=32000,
                                               metrics_history=trainer.metrics.history)

        # Full diagnostics
        report = LossDebugger.run_diagnostics(
            model=model, dataloader=train_dl, tokenizer=tok,
            train_config=train_cfg, metrics_history=trainer.metrics.history,
            device=device, dtype=torch.bfloat16,
        )
    """

    # ───────────────────────────────────────────────────────────────
    # Level 0: Status Diagnosis
    # ───────────────────────────────────────────────────────────────

    @staticmethod
    def diagnose_status(
        vocab_size: int,
        metrics_history: Dict[str, list],
    ) -> Dict[str, Any]:
        """Classify current training health from metrics history.

        Args:
            vocab_size: model vocabulary size (e.g. 32000)
            metrics_history: dict with keys 'train_loss', 'val_loss', etc.

        Returns:
            dict with 'status', 'severity', 'details', 'recommended_levels'
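
        Example (a minimal sketch; the history values are illustrative)::

            history = {"train_loss": [10.4, 9.6, 8.8, 8.1],
                       "val_loss": [9.8, 8.9]}
            result = LossDebugger.diagnose_status(32000, history)
            if result["severity"] != "green":
                print(result["recommended_levels"])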
        """
        print(_header("Level 0: Training Status Diagnosis"))

        expected_initial = math.log(vocab_size)
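        # e.g. ln(32000) ≈ 10.37; a random-init model should start near this.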
        print(f"  Expected initial loss (random weights): ln({vocab_size}) = {expected_initial:.2f}")
        print(f"  Normal convergence range (1B, 10B tokens):")
        print(f"    Train Loss: {_EXPECTED_TRAIN_LOSS[0]} ~ {_EXPECTED_TRAIN_LOSS[1]}")
        print(f"    Val Loss:   {_EXPECTED_VAL_LOSS[0]} ~ {_EXPECTED_VAL_LOSS[1]}")
        print(f"    Val PPL:    {_EXPECTED_VAL_PPL[0]} ~ {_EXPECTED_VAL_PPL[1]}")

        raw_train_losses = metrics_history.get("train_loss", [])
        train_losses = [l for l in raw_train_losses if not math.isnan(l)]
        val_losses = [v for v in metrics_history.get("val_loss", []) if v is not None]

        if len(train_losses) < 2:
            print("\n  [!] Not enough training data to diagnose. Run more steps first.")
            return {
                "status": "INSUFFICIENT_DATA",
                "severity": "unknown",
                "details": "Need at least 2 logged train loss values.",
                "recommended_levels": [1],
            }

        # Detect NaN presence before filtering
        has_nan = len(train_losses) < len(raw_train_losses)
        if has_nan:
            nan_count = len(raw_train_losses) - len(train_losses)
            print(f"\n  ⚠ {nan_count} NaN values detected in train_loss - filtered for analysis")

        first_loss = train_losses[0]
        last_loss = train_losses[-1]
        loss_change = first_loss - last_loss

        # Split into halves for trend analysis
        mid = len(train_losses) // 2
        first_half_avg = sum(train_losses[:mid]) / mid
        second_half_avg = sum(train_losses[mid:]) / (len(train_losses) - mid)

        # Recent window for spike detection
        recent_n = min(50, len(train_losses))
        recent = train_losses[-recent_n:]
        recent_mean = sum(recent) / len(recent)
        recent_var = sum((x - recent_mean) ** 2 for x in recent) / len(recent)
        recent_std = recent_var ** 0.5

        # Val trend
        val_trend = "unknown"
        if len(val_losses) >= 2:
            val_mid = len(val_losses) // 2
            val_first_avg = sum(val_losses[:max(val_mid, 1)]) / max(val_mid, 1)
            val_second_avg = sum(val_losses[val_mid:]) / max(len(val_losses) - val_mid, 1)
            if val_second_avg < val_first_avg - 0.05:
                val_trend = "decreasing"
            elif val_second_avg > val_first_avg + 0.1:
                val_trend = "increasing"
            else:
                val_trend = "flat"

        # Pre-compute bounce detection using moving-average minimum
        # to avoid false positives from single noisy data points
        _ma_window = max(1, len(train_losses) // 20)  # 5% window
        _ma_losses = [
            sum(train_losses[max(0, i - _ma_window + 1):i + 1])
            / (i - max(0, i - _ma_window + 1) + 1)
            for i in range(len(train_losses))
        ]
        _min_ma_loss = min(_ma_losses)
        _min_ma_idx = _ma_losses.index(_min_ma_loss)
        _last_ma_loss = _ma_losses[-1]
        _bounce_amount = _last_ma_loss - _min_ma_loss
        _has_bounce = (
            loss_change > 0.1
            and _min_ma_idx < len(train_losses) * 0.85
            and _bounce_amount > _min_ma_loss * 0.05
        )
        # Downgrade bounce severity when val loss is still improving
        _val_improving = (
            val_trend == "decreasing"
            or (len(val_losses) >= 4
                and val_losses[-1] <= min(val_losses[:len(val_losses) // 2]))
        )

        # ── Classify ──
        status = STATUS_NORMAL
        severity = "green"
        details = ""
        recommended_levels: List[int] = []

        # Check 1: No decrease at all
        if loss_change < 0.1 and first_loss > expected_initial - 2.0:
            status = STATUS_NO_DECREASE
            severity = "red"
            details = (
                f"Loss barely changed: {first_loss:.4f} -> {last_loss:.4f} "
                f"(delta={loss_change:.4f}). Likely a data or implementation bug."
            )
            recommended_levels = [1, 2]

        # Check 2: Diverging
        elif last_loss > expected_initial + 1.0:
            status = STATUS_DIVERGING
            severity = "red"
            details = (
                f"Loss ({last_loss:.4f}) exceeds initial value ({expected_initial:.2f}). "
                f"Training is diverging - check LR, data, or numerical issues."
            )
            recommended_levels = [1, 2, 3]

        # Check 3: NaN detected in training loss
        elif has_nan:
            nan_count = len(raw_train_losses) - len(train_losses)
            nan_idx = next(i for i, l in enumerate(raw_train_losses) if math.isnan(l))
            status = STATUS_NAN_DETECTED
            severity = "red"
            details = (
                f"NaN detected in train_loss: {nan_count} NaN values "
                f"(first at step ~{nan_idx}). "
                f"Before NaN: {first_loss:.4f} -> {last_loss:.4f}. "
                f"Check gradient norms, LR schedule, and numerical precision."
            )
            recommended_levels = [2, 3]

        # Check 4: Unstable (large spikes)
        elif recent_std > 0.5 * recent_mean:
            status = STATUS_UNSTABLE
            severity = "yellow"
            details = (
                f"High loss variance: std={recent_std:.4f}, mean={recent_mean:.4f}. "
                f"Training is unstable - likely LR too high or batch too small."
            )
            recommended_levels = [3, 2]

        # Check 5: Loss bounce (decreased then increased again)
        elif _has_bounce:
            status = STATUS_LOSS_BOUNCE
            if _val_improving:
                severity = "green"
                details = (
                    f"Train loss bounced (moving-avg): "
                    f"{first_loss:.4f} -> min {_min_ma_loss:.4f} -> {_last_ma_loss:.4f} "
                    f"(bounce={_bounce_amount:.4f}), but val loss is still improving "
                    f"({val_losses[0]:.4f} -> {val_losses[-1]:.4f}). "
                    f"Likely data distribution variation, not a real issue."
                )
                recommended_levels = []
            else:
                severity = "yellow"
                details = (
                    f"Train loss bounced (moving-avg): "
                    f"{first_loss:.4f} -> min {_min_ma_loss:.4f} -> {_last_ma_loss:.4f} "
                    f"(bounce={_bounce_amount:.4f}). "
                    f"Possible LR too high, data issue, or overfitting."
                )
                recommended_levels = [3, 4]

        # Check 6: Overfitting
        elif val_trend == "increasing" and second_half_avg < first_half_avg:
            status = STATUS_OVERFITTING
            severity = "yellow"
            details = (
                f"Train loss decreasing but val loss increasing. "
                f"Train trend: {first_half_avg:.4f} -> {second_half_avg:.4f}, "
                f"Val trend: {val_trend}."
            )
            recommended_levels = [4]

        # Check 7: Plateau
        elif abs(second_half_avg - first_half_avg) < 0.05 and last_loss > _EXPECTED_TRAIN_LOSS[1]:
            status = STATUS_PLATEAU
            severity = "yellow"
            details = (
                f"Loss has plateaued: first half avg={first_half_avg:.4f}, "
                f"second half avg={second_half_avg:.4f}. "
                f"Current loss ({last_loss:.4f}) is above expected range."
            )
            recommended_levels = [3, 4, 5]

        # Normal
        else:
            status = STATUS_NORMAL
            severity = "green"
            details = (
                f"Training looks healthy: {first_loss:.4f} -> {last_loss:.4f} "
                f"(delta={loss_change:.4f}). Val trend: {val_trend}."
            )
            recommended_levels = []

        # ── Print ──
        icons = {"red": "🔴", "yellow": "🟡", "green": "🟢"}
        icon = icons.get(severity, "⚪")
        print(f"\n  {icon} Status: {status}")
        print(f"  {details}")
        if recommended_levels:
            print(f"  Recommended: check Level(s) {recommended_levels}")

        return {
            "status": status,
            "severity": severity,
            "details": details,
            "recommended_levels": recommended_levels,
        }

    # ───────────────────────────────────────────────────────────────
    # Level 1: Data / Implementation Bug Checks
    # ───────────────────────────────────────────────────────────────

    @staticmethod
    def check_data_pipeline(
        model: nn.Module,
        dataloader: DataLoader,
        tokenizer: Any,
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
    ) -> Dict[str, Any]:
        """Run 6 data/implementation checks (Level 1).

        This is the most important level: 70% of loss issues are data bugs.

        Checks:
          1. Shift relationship (targets[t] == input_ids[t+1])
          2. Token range (0 <= ids < vocab_size)
          3. Initial loss (≈ ln(vocab_size) for random weights)
          4. Single-batch overfit (loss → ~0 within 200-400 steps)
          5. Tokenizer roundtrip (encode→decode preserves text)
          6. Data quality sampling (visual inspection)
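
        Example of the shift expected by check 1 (a sketch; assumes the
        batch dict layout used by this repo's dataloader)::

            # input_ids: [t0, t1, t2, t3]
            # targets:   [t1, t2, t3, t4]   (next-token prediction)
            assert (batch["input_ids"][:, 1:] == batch["targets"][:, :-1]).all()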
        """
        print(_header("Level 1: Data / Implementation Bug Checks"))
        print("  (70% of loss issues come from data pipeline bugs)\n")

        results: List[Dict[str, Any]] = []
        batch = next(iter(dataloader))
        input_ids = batch["input_ids"]
        targets = batch["targets"]

        # ── Check 1: Shift relationship ──
        shift_match = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
        passed = shift_match > 0.99
        detail = f"Shift consistency: {shift_match * 100:.1f}% (should be ~100%)"
        results.append(_check_result("Shift relationship", passed, detail))
        icon = "✅" if passed else "❌"
        print(f"  {icon} Check 1: {detail}")

        # ── Check 2: Token range ──
        min_id = input_ids.min().item()
        max_id = input_ids.max().item()
        range_ok = min_id >= 0 and max_id < vocab_size
        detail = f"Token range: [{min_id}, {max_id}], vocab_size={vocab_size}"
        results.append(_check_result("Token range", range_ok, detail))
        icon = "✅" if range_ok else "❌"
        print(f"  {icon} Check 2: {detail}")

        # ── Check 3: Initial loss ──
        expected_loss = math.log(vocab_size)
        model_copy = copy.deepcopy(model)
        model_copy._init_weights()  # re-initialize to random
        model_copy.to(device)
        model_copy.eval()
        with torch.no_grad():
            with torch.amp.autocast(device_type=device.type, dtype=dtype, enabled=(dtype != torch.float32)):
                _, initial_loss = model_copy(
                    input_ids.to(device),
                    targets.to(device),
                )
        initial_loss_val = initial_loss.item()
        loss_diff = abs(initial_loss_val - expected_loss)
        loss_ok = loss_diff < 1.0
        detail = (
            f"Initial loss: {initial_loss_val:.4f} vs expected {expected_loss:.2f} "
            f"(diff={loss_diff:.4f})"
        )
        results.append(_check_result("Initial loss", loss_ok, detail))
        icon = "✅" if loss_ok else "❌"
        print(f"  {icon} Check 3: {detail}")
        if initial_loss_val > expected_loss + 1.0:
            print(f"       Hint: loss >> ln(V) suggests label mismatch or loss function bug")
        elif initial_loss_val < expected_loss - 2.0:
            print(f"       Hint: loss << ln(V) suggests data leakage")
        del model_copy
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # ── Check 4: Single-batch overfit test ──
        # Scale LR and steps based on model size to avoid instability
        num_params = sum(p.numel() for p in model.parameters())
        if num_params > 500e6:
            overfit_lr, overfit_steps = 1e-4, 400
        elif num_params > 50e6:
            overfit_lr, overfit_steps = 3e-4, 300
        else:
            overfit_lr, overfit_steps = 1e-3, 200
        print(f"\n  ⏳ Check 4: Single-batch overfit test ({overfit_steps} steps, lr={overfit_lr:.0e})...")
        overfit_model = copy.deepcopy(model)
        overfit_model.to(device)
        overfit_model.train()
        overfit_optimizer = torch.optim.AdamW(overfit_model.parameters(), lr=overfit_lr)
        single_input = input_ids[:1].to(device)  # single sample
        single_target = targets[:1].to(device)
        log_interval = max(overfit_steps // 4, 1)

        overfit_losses = []
        for step in range(overfit_steps):
            overfit_optimizer.zero_grad()
            with torch.amp.autocast(device_type=device.type, dtype=dtype, enabled=(dtype != torch.float32)):
                _, loss = overfit_model(single_input, single_target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(overfit_model.parameters(), 1.0)
            overfit_optimizer.step()
            overfit_losses.append(loss.item())
            if (step + 1) % log_interval == 0:
                print(f"       Step {step + 1}: Loss = {loss.item():.4f}")

        final_overfit_loss = overfit_losses[-1]
        min_overfit_loss = min(overfit_losses)
        overfit_ok = min_overfit_loss < 0.5
        detail = (
            f"Single-batch overfit: {overfit_losses[0]:.4f} -> {final_overfit_loss:.4f} "
            f"(min={min_overfit_loss:.4f}, target < 0.5)"
        )
        results.append(_check_result("Single-batch overfit", overfit_ok, detail))
        icon = "✅" if overfit_ok else "❌"
        print(f"  {icon} Check 4: {detail}")
        if not overfit_ok:
            print(f"       CRITICAL: Model cannot memorize a single batch!")
            print(f"       This means the model or loss function has a bug.")
        del overfit_model, overfit_optimizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # ── Check 5: Tokenizer roundtrip ──
        test_text = "The quick brown fox jumps over the lazy dog."
        encoded = tokenizer.encode(test_text)
        decoded = tokenizer.decode(encoded)
        roundtrip_ok = test_text.strip() in decoded.strip()
        detail = f"Roundtrip: '{test_text}' -> '{decoded.strip()}'"
        results.append(_check_result("Tokenizer roundtrip", roundtrip_ok, detail))
        icon = "✅" if roundtrip_ok else "❌"
        print(f"  {icon} Check 5: {detail}")

        # ── Check 6: Data quality sampling ──
        print(f"\n  📋 Check 6: Data quality sampling (visual inspection)")
        for i in range(min(3, input_ids.shape[0])):
            sample_tokens = input_ids[i][:100].tolist()
            decoded_text = tokenizer.decode(sample_tokens)
            preview = decoded_text[:200].replace("\n", "\\n")
            print(f"     Sample {i}: {preview}...")

        passed_count = sum(1 for r in results if r["passed"])
        total_count = len(results)
        print(f"\n  Result: {passed_count}/{total_count} checks passed")

        return {
            "level": 1,
            "checks": results,
            "passed": [r for r in results if r["passed"]],
            "failed": [r for r in results if not r["passed"]],
        }

    # ───────────────────────────────────────────────────────────────
    # Level 2: Numerical Stability
    # ───────────────────────────────────────────────────────────────

    @staticmethod
    def check_numerical_stability(
        model: nn.Module,
        dataloader: DataLoader,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
    ) -> Dict[str, Any]:
        """Check for NaN/Inf in gradients, activations, and logits (Level 2).

        Checks:
          - Mixed precision config (RMSNorm fp32 upcast, loss dtype)
          - NaN/Inf gradients → softmax overflow, bad data
          - Inf gradients → log(0) in loss, missing ignore_index
          - Large activations growing per layer → initialization or norm bug
          - Logit scale → should be < 1000
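
        Example (a sketch; assumes model(input_ids, targets) returns
        (logits, loss) as elsewhere in this module)::

            report = LossDebugger.check_numerical_stability(
                model, train_dl, device, dtype=torch.bfloat16)
            failed = [c for c in report["checks"] if not c["passed"]]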
        """
        print(_header("Level 2: Numerical Stability Checks"))

        batch = next(iter(dataloader))
        input_ids = batch["input_ids"].to(device)
        targets = batch["targets"].to(device)

        results: List[Dict[str, Any]] = []
        activation_stats: List[Dict[str, Any]] = []

        # ── Mixed Precision Configuration Check ──
        print("\n  Mixed Precision Config:")
        print(f"    Training dtype: {dtype}")

        # Check RMSNorm fp32 upcast
        norm_fp32_ok = True
        checked_norm_classes: set = set()
        for name, module in model.named_modules():
            cls_name = module.__class__.__name__
            if "Norm" in cls_name and cls_name not in checked_norm_classes:
                checked_norm_classes.add(cls_name)
                try:
                    src = inspect.getsource(type(module).forward)
                    has_upcast = ".float()" in src or "float32" in src
                except (TypeError, OSError):
                    has_upcast = True  # assume ok if can't inspect
                if not has_upcast:
                    norm_fp32_ok = False
                    print(f"    🔴 {cls_name}: no fp32 upcast detected!")
        if norm_fp32_ok:
            print(f"    ✅ Norm layers use fp32 upcast (safe)")

        results.append(_check_result(
            "Norm fp32 upcast", norm_fp32_ok,
            "Norm computes in fp32" if norm_fp32_ok else "Norm may lose precision in half dtype",
        ))

        # Check loss computation dtype
        if dtype in (torch.bfloat16, torch.float16):
            print(f"    ℹ️  Best practice: compute loss in fp32 when using {dtype}")
            print(f"       logits_fp32 = logits.float()")
            print(f"       loss = F.cross_entropy(logits_fp32.view(-1, V), targets.view(-1))")

        # Common numerical issues reference
        print("\n  Common Numerical Issues Reference:")
        print("    ┌──────────────────────┬──────────────────────────┬─────────────────────────┐")
        print("    │ Symptom              │ Likely Cause             │ Solution                │")
        print("    ├──────────────────────┼──────────────────────────┼─────────────────────────┤")
        print("    │ Loss → NaN           │ Large logits → softmax   │ Check init, logit scale │")
        print("    │ Loss → Inf           │ log(0) in CE loss        │ Add eps, ignore_index   │")
        print("    │ Loss oscillation     │ fp16 gradient underflow  │ Switch to bf16 / scaler │")
        print("    │ Late-training NaN    │ Activation growth        │ Check RMSNorm, wd       │")
        print("    └──────────────────────┴──────────────────────────┴─────────────────────────┘")

        # ── Activation monitoring via hooks ──
        hooks = []

        def make_hook(name: str):
            def hook_fn(module, input, output):
                if isinstance(output, torch.Tensor):
                    out_f = output.float()
                    stats = {
                        "name": name,
                        "mean": out_f.mean().item(),
                        "std": out_f.std().item(),
                        "max": out_f.abs().max().item(),
                        "has_nan": bool(torch.isnan(output).any()),
                        "has_inf": bool(torch.isinf(output).any()),
                    }
                    activation_stats.append(stats)
            return hook_fn

        # Register hooks on transformer layers
        for i, layer in enumerate(model.layers):
            h = layer.register_forward_hook(make_hook(f"layer_{i}"))
            hooks.append(h)

        # ── Forward + Backward ──
        model.train()
        model.zero_grad(set_to_none=True)
        use_scaler = dtype == torch.float16 and torch.cuda.is_available()
        scaler = torch.amp.GradScaler() if use_scaler else None

        with torch.amp.autocast(device_type=device.type, dtype=dtype, enabled=(dtype != torch.float32)):
            logits, loss = model(input_ids, targets)

        loss_val = loss.item()
        loss_ok = not (math.isnan(loss_val) or math.isinf(loss_val))
        results.append(_check_result(
            "Loss value",
            loss_ok,
            f"Loss = {loss_val:.4f}" if loss_ok else f"Loss = {loss_val} (NaN/Inf!)"
        ))

        if scaler is not None:
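            # GradScaler.unscale_ needs an optimizer whose param_groups hold the
            # params; a throwaway SGD(lr=0) lets us unscale grads for inspection
            # without taking an optimizer step.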
            scaler.scale(loss).backward()
            _temp_opt = torch.optim.SGD(model.parameters(), lr=0)
            scaler.unscale_(_temp_opt)
        else:
            loss.backward()

        # Remove hooks
        for h in hooks:
            h.remove()

        # ── Gradient checks ──
        print("\n  Gradient Health:")
        grad_issues = []
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            grad = param.grad
            if torch.isnan(grad).any():
                grad_issues.append(f"🔴 NaN gradient: {name}")
            if torch.isinf(grad).any():
                grad_issues.append(f"🔴 Inf gradient: {name}")
            if grad.abs().max().item() > 100:
                grad_issues.append(
                    f"🟡 Large gradient: {name} max={grad.abs().max().item():.1f}"
                )

        grad_ok = len(grad_issues) == 0
        if grad_ok:
            print("    ✅ All gradients are healthy (no NaN/Inf/large values)")
        else:
            for issue in grad_issues[:10]:  # limit output
                print(f"    {issue}")
            if len(grad_issues) > 10:
                print(f"    ... and {len(grad_issues) - 10} more issues")

        results.append(_check_result(
            "Gradient health",
            grad_ok,
            f"{len(grad_issues)} issues found" if not grad_ok else "All healthy",
        ))

        # ── Activation checks ──
        print("\n  Activation Stats (per transformer layer):")
        act_nan_count = 0
        for stats in activation_stats:
            icon = "🔴" if stats["has_nan"] or stats["has_inf"] else "  "
            if stats["has_nan"] or stats["has_inf"]:
                act_nan_count += 1
            print(
                f"    {icon} {stats['name']}: "
                f"mean={stats['mean']:.4f}, "
                f"std={stats['std']:.4f}, "
                f"max={stats['max']:.4f}"
                + (" [NaN!]" if stats["has_nan"] else "")
                + (" [Inf!]" if stats["has_inf"] else "")
            )

        act_ok = act_nan_count == 0
        results.append(_check_result(
            "Activation health",
            act_ok,
            f"{act_nan_count} layers with NaN/Inf" if not act_ok else "All layers healthy",
        ))

        # ── Activation growth trend ──
        if len(activation_stats) >= 2:
            stds = [s["std"] for s in activation_stats]
            if stds[0] > 1e-8:
                growth_ratio = stds[-1] / stds[0]
                growth_ok = growth_ratio < 10
                detail = (
                    f"Activation std ratio (last/first): {growth_ratio:.1f}x "
                    f"(layer_0={stds[0]:.4f}, last={stds[-1]:.4f})"
                )
                results.append(_check_result("Activation growth", growth_ok, detail))
                icon = "✅" if growth_ok else "🟡"
                print(f"    {icon} {detail}")
                if not growth_ok:
                    print(f"       Possible initialization or normalization issue")

        # ── Logit scale check ──
        logit_max = logits.float().abs().max().item()
        logit_ok = logit_max < 1000
        detail = f"Logit max abs value: {logit_max:.1f} (should be < 1000)"
        results.append(_check_result("Logit scale", logit_ok, detail))
        icon = "✅" if logit_ok else "🔴"
        print(f"\n  {icon} Logit scale: {detail}")

        model.zero_grad(set_to_none=True)

        passed_count = sum(1 for r in results if r["passed"])
        print(f"\n  Result: {passed_count}/{len(results)} checks passed")

        return {
            "level": 2,
            "checks": results,
            "activation_stats": activation_stats,
            "grad_issues": grad_issues,
        }

    # ───────────────────────────────────────────────────────────────
    # Level 3: Hyperparameter Diagnosis
    # ───────────────────────────────────────────────────────────────

    @staticmethod
    def diagnose_hyperparameters(
        metrics_history: Dict[str, list],
        config: TrainConfig,
    ) -> Dict[str, Any]:
        """Analyze hyperparameter health from training metrics (Level 3).

        Checks:
          - LR: too high (grad_norm hitting clip limit) or too low (grad_norm tiny)
          - Batch size: loss variance indicates batch too small
          - Warmup: spikes in early steps indicate warmup too short
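
        Example (a sketch; 'grad_norm' must be logged in the history for
        the LR analysis to run)::

            result = LossDebugger.diagnose_hyperparameters(
                trainer.metrics.history, train_cfg)
            for f in result["findings"]:
                print(f["issue"], "->", f["action"])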
        """
        print(_header("Level 3: Hyperparameter Diagnosis"))

        findings: List[Dict[str, str]] = []
        grad_norms = metrics_history.get("grad_norm", [])
        train_losses = metrics_history.get("train_loss", [])

        # ── LR diagnosis ──
        print("\n  Learning Rate Analysis:")
        print(f"    Peak LR: {config.learning_rate:.2e}")
        print(f"    Min LR:  {config.min_learning_rate:.2e}")

        if grad_norms:
            avg_grad = sum(grad_norms) / len(grad_norms)
            # Ref: PyTorch clip_grad_norm_ clips when total_norm > max_norm
            clip_count = sum(1 for g in grad_norms if g >= config.grad_clip)
            clip_rate = clip_count / len(grad_norms)
            # Relative threshold: < 1% of clip limit (model-size independent)
            tiny_threshold = config.grad_clip * 0.01
            tiny_count = sum(1 for g in grad_norms if g < tiny_threshold)
            tiny_rate = tiny_count / len(grad_norms)

            print(f"    Avg grad norm:  {avg_grad:.4f}")
            print(f"    Clip rate:      {clip_rate * 100:.1f}% (hitting max_norm={config.grad_clip})")
            print(f"    Tiny grad rate: {tiny_rate * 100:.1f}% (< {tiny_threshold:.4f})")

            # Heuristic: >50% clipping means most steps are capped, so the
            # effective LR is lower than configured. Practitioners generally
            # treat this as a sign that peak LR is too high.
            if clip_rate > 0.5:
                findings.append({
                    "issue": "LR may be too high",
                    "evidence": f"Grad norm hits clip limit {clip_rate * 100:.0f}% of the time",
                    "action": f"Try LR = {config.learning_rate / 2:.2e} (÷2)",
                })
                print(f"    🟡 Grad clipping frequent ({clip_rate * 100:.0f}%) → LR may be too high")
            elif tiny_rate > 0.5:
                findings.append({
                    "issue": "Possible vanishing gradients",
                    "evidence": f"Grad norm < {tiny_threshold:.4f} in {tiny_rate * 100:.0f}% of steps",
                    "action": "Check weight initialization, layer norms, and model depth",
                })
                print(f"    🟡 Grad norm too small ({tiny_rate * 100:.0f}% < {tiny_threshold:.4f}) → possible vanishing gradients")
            else:
                print(f"    ✅ LR looks appropriate")

        # ── Batch size diagnosis ──
        print("\n  Batch Size Analysis:")
        print(f"    Effective batch: {config.effective_batch_size}")

        if len(train_losses) >= 50:
            recent_losses = train_losses[-50:]
            loss_mean = sum(recent_losses) / len(recent_losses)
            loss_var = sum((x - loss_mean) ** 2 for x in recent_losses) / len(recent_losses)
            loss_cv = (loss_var ** 0.5) / max(loss_mean, 1e-8)

            print(f"    Recent loss CV: {loss_cv:.4f} (coefficient of variation, last 50 steps)")

            if loss_cv > 0.1:
                findings.append({
                    "issue": "Training loss has high variance",
                    "evidence": f"Loss CV = {loss_cv:.4f} over last 50 steps",
                    "action": "Check: (1) LR may be too high, (2) increase gradient_accumulation_steps, (3) inspect data quality",
                })
                print(f"    🟡 High loss variance → check LR, batch size, or data quality")
            else:
                print(f"    ✅ Loss variance is acceptable")

        # ── β₂ diagnosis ──
        print("\n  β₂ (Adam second-moment decay) Analysis:")
        print(f"    Current β₂: {config.beta2}")
        if config.beta2 >= _DEFAULT_PYTORCH_BETA2:
            findings.append({
                "issue": "β₂ may be too high for LLM training",
                "evidence": (
                    f"β₂={config.beta2} (PyTorch default). "
                    f"LLM standard is {_RECOMMENDED_BETA2}"
                ),
                "action": f"Set beta2={_RECOMMENDED_BETA2} (used by LLaMA, TinyLlama, OLMo)",
            })
            print(f"    🟡 β₂={config.beta2} is PyTorch default → "
                  f"LLM training standard is {_RECOMMENDED_BETA2}")
            print(f"       Why: β₂=0.999 averages ~1000 steps of gradient stats,")
            print(f"       β₂=0.95 averages ~20 steps → faster adaptation to changing data")
            print(f"       (Cattaneo & Shigida 2025, 'Tuning Adam(W)')")
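            # Rule of thumb: an EMA with decay β₂ averages over ~1/(1-β₂) steps,
            # e.g. 1/(1-0.999) = 1000 vs 1/(1-0.95) = 20.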
        else:
            print(f"    ✅ β₂={config.beta2} is within LLM standard range")

        # ── Weight Decay diagnosis ──
        print("\n  Weight Decay Analysis:")
        print(f"    Current weight_decay: {config.weight_decay}")
        if config.weight_decay == 0:
            findings.append({
                "issue": "Weight decay is disabled",
                "evidence": "weight_decay=0 increases overfitting risk",
                "action": "Set weight_decay=0.1 (standard for LLaMA, TinyLlama, GPT-3, OLMo)",
            })
            print(f"    🟡 weight_decay=0 → overfitting risk. Standard is 0.1")
        elif config.weight_decay > 0.3:
            findings.append({
                "issue": "Weight decay may be too high",
                "evidence": f"weight_decay={config.weight_decay} (unusually high)",
                "action": "Try weight_decay=0.1 (standard value)",
            })
            print(f"    🟡 weight_decay={config.weight_decay} is unusually high (standard: 0.1)")
        else:
            print(f"    ✅ weight_decay={config.weight_decay} is within normal range")

        # ── Model-size LR reference ──
        print("\n  GPT-3 LR Reference (Brown et al. 2020):")
        print("    ┌──────────┬───────────┬──────────────┐")
        print("    │ Model    │ Peak LR   │ Batch Tokens │")
        print("    ├──────────┼───────────┼──────────────┤")
        for params, lr, batch_tok in _GPT3_LR_REFERENCE:
            label = f"{params / 1e9:.1f}B" if params >= 1e9 else f"{params / 1e6:.0f}M"
            marker = " ←" if abs(params - 1.1e9) < 0.5e9 else ""
            print(f"    │ {label:<8} │ {lr:.1e}   │ {batch_tok:<12} │{marker}")
        print("    └──────────┴───────────┴──────────────┘")
        print("    → Larger models need lower LR and larger batch")

        # ── Batch-LR scaling guidance ──
        print("\n  Batch-LR Scaling Rules:")
        print("    • Batch ×2 → LR ×√2 (square root scaling, recommended for Adam)")
        print("      (Malladi et al. NeurIPS 2022, 'On the SDEs and Scaling Rules for Adaptive Gradient Algorithms')")
        print("    • Batch ×2 → LR ×2   (linear scaling, Goyal et al. 2017, mainly SGD)")
        print("    • 1B model: ~1K-2K sequences (~2-4M tokens) is typical")
        print("      (Pythia-1B: ~2M tokens, TinyLlama: ~2M, OLMo-1B: ~4M)")
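        # Worked example of the √ rule: batch 512 → 2048 (×4) scales LR by
        # √4 = 2, e.g. 3e-4 → 6e-4 (illustrative numbers, not a recommendation).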

        # ── Warmup diagnosis ──
        print("\n  Warmup Analysis:")
        print(f"    Warmup steps: {config.warmup_steps} "
              f"({config.warmup_steps / config.total_steps * 100:.1f}% of total)")

        if len(train_losses) >= 10:
            early_losses = train_losses[:min(50, len(train_losses))]
            # Detect spikes in early training
            spike_count = 0
            for i in range(1, len(early_losses)):
                if early_losses[i] > early_losses[i - 1] * 1.5:
                    spike_count += 1

            if spike_count > 3:
                findings.append({
                    "issue": "Warmup may be too short",
                    "evidence": f"{spike_count} loss spikes in first {len(early_losses)} steps",
                    "action": f"Try warmup_steps = {config.warmup_steps * 2}",
                })
                print(f"    🟡 {spike_count} spikes in early training → warmup may be too short")
            else:
                print(f"    ✅ Early training is stable")

        # ── Summary ──
        if not findings:
            print("\n  ✅ No hyperparameter issues detected")
        else:
            print(f"\n  Found {len(findings)} potential issue(s):")
            for f in findings:
                print(f"    • {f['issue']}: {f['action']}")

        # ── Warmup reference from real projects ──
        print("\n  Warmup Reference (real projects):")
        print("    • TinyLlama 1.1B (3T tokens): 2,000 steps ≈ 0.1% of total")
        print("    • GPT-3 175B:                  375M warmup tokens ≈ 117 steps")
        print("    • General range: 0.1% ~ 5% of total steps")
        print("    • Smaller experiments: 5~10% is also reasonable")

        print("\n  Tuning priority (high → low):")
        print("    1. Learning Rate  ← tune first (10x impact)")
        print("    2. Batch Size     ← adjust with LR")
        print("    3. Warmup Steps   ← early stability")
        print("    4. Weight Decay   ← if overfitting (typically 0.1)")
        print("    5. β₁, β₂ (Adam) ← see β₂ analysis above")
        print("    6. Gradient Clip  ← usually keep at 1.0")

        return {
            "level": 3,
            "findings": findings,
            "config_summary": {
                "learning_rate": config.learning_rate,
                "effective_batch": config.effective_batch_size,
                "warmup_steps": config.warmup_steps,
                "total_steps": config.total_steps,
                "grad_clip": config.grad_clip,
            },
        }

    @staticmethod
    def lr_range_test(
        model: nn.Module,
        dataloader: DataLoader,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        lr_start: float = 1e-7,
        lr_end: float = 1e-1,
        steps: int = 300,
    ) -> Dict[str, Any]:
        """Run an LR range test to find the optimal learning rate (Level 3 bonus).

        Sweeps LR from lr_start to lr_end exponentially, recording loss.
        The optimal LR is where loss decreases fastest (steepest slope),
        divided by 3~10 for stability.

        WARNING: This modifies a copy of the model. The original is untouched.
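
        Example (a sketch; plotting lrs vs losses on a log-x axis helps
        eyeball the curve)::

            res = LossDebugger.lr_range_test(model, train_dl, device,
                                             steps=200)
            print(f"suggested peak LR: {res['suggested_lr']:.2e}")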
        """
        print(_header("Level 3 Bonus: LR Range Test"))
        print(f"  Sweeping LR from {lr_start:.1e} to {lr_end:.1e} over {steps} steps...\n")

        test_model = copy.deepcopy(model)
        test_model.to(device)
        test_model.train()
        optimizer = torch.optim.AdamW(test_model.parameters(), lr=lr_start)

        lr_mult = (lr_end / lr_start) ** (1 / steps)
        lr = lr_start

        lrs: List[float] = []
        losses: List[float] = []
        data_iter = iter(dataloader)

        for step in range(steps):
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(dataloader)
                batch = next(data_iter)

            input_ids = batch["input_ids"].to(device)
            targets_t = batch["targets"].to(device)

            optimizer.zero_grad()
            with torch.amp.autocast(device_type=device.type, dtype=dtype, enabled=(dtype != torch.float32)):
                _, loss = test_model(input_ids, targets_t)
            loss.backward()
            optimizer.step()

            loss_val = loss.item()
            lrs.append(lr)
            losses.append(loss_val)

            if (step + 1) % 50 == 0:
                print(f"    Step {step + 1}: LR = {lr:.2e}, Loss = {loss_val:.4f}")

            # Stop if loss explodes
            if len(losses) > 1 and loss_val > losses[0] * 4:
                print(f"    Loss exploded at LR = {lr:.2e}, stopping.")
                break

            lr *= lr_mult

        del test_model, optimizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Find steepest descent
        best_lr = lr_start
        if len(losses) > 10:
            # Smooth losses and find steepest negative slope
            window = 5
            smoothed = []
            for i in range(len(losses) - window):
                smoothed.append(sum(losses[i:i + window]) / window)

            min_slope = 0
            min_idx = 0
            for i in range(1, len(smoothed)):
                slope = smoothed[i] - smoothed[i - 1]
                if slope < min_slope:
                    min_slope = slope
                    min_idx = i

            best_lr = lrs[min_idx]
            suggested_lr = best_lr / 3  # conservative choice

            print(f"\n  Steepest descent at LR = {best_lr:.2e}")
            print(f"  Suggested peak LR:  {suggested_lr:.2e} (÷3 for stability)")
            print(f"  Conservative range: [{best_lr / 10:.2e}, {best_lr / 3:.2e}]")
        else:
            suggested_lr = 3e-4
            print(f"\n  Not enough data points. Using default LR = {suggested_lr:.2e}")

        return {
            "lrs": lrs,
            "losses": losses,
            "best_lr": best_lr,
            "suggested_lr": suggested_lr,
        }

    # ───────────────────────────────────────────────────────────────
    # Level 4: Overfitting vs Underfitting Diagnosis
    # ───────────────────────────────────────────────────────────────

    @staticmethod
    def diagnose_fitting(
        metrics_history: Dict[str, list],
        model_params: Optional[int] = None,
        total_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Diagnose overfitting vs underfitting from metrics (Level 4).

        Cases:
          1. Both high, decreasing → Normal (still training)
          2. Both high, plateau    → Underfitting
          3. Train↓ Val→ or Val↑   → Overfitting
          4. Both low, plateau     → Converged (or at limit)
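
        Example (a sketch; pass counts to enable the token/param check)::

            result = LossDebugger.diagnose_fitting(
                trainer.metrics.history,
                model_params=1_100_000_000,   # 1.1B
                total_tokens=10_000_000_000,  # 10B -> ratio ~9x
            )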
        """
        print(_header("Level 4: Overfitting vs Underfitting Diagnosis"))

        train_losses = metrics_history.get("train_loss", [])
        val_losses = [v for v in metrics_history.get("val_loss", []) if v is not None]

        if len(train_losses) < 10 or len(val_losses) < 2:
            print("  [!] Not enough data. Need more training steps with eval.")
            return {"level": 4, "case": "insufficient_data", "recommendations": []}

        # Recent train trend
        recent_n = min(50, len(train_losses))
        train_recent = train_losses[-recent_n:]
        train_mid = len(train_recent) // 2
        train_first = sum(train_recent[:train_mid]) / max(train_mid, 1)
        train_second = sum(train_recent[train_mid:]) / max(len(train_recent) - train_mid, 1)
        train_decreasing = train_second < train_first - 0.02

        # Val trend
        val_mid = len(val_losses) // 2
        val_first = sum(val_losses[:max(val_mid, 1)]) / max(val_mid, 1)
        val_second = sum(val_losses[val_mid:]) / max(len(val_losses) - val_mid, 1)
        val_decreasing = val_second < val_first - 0.02
        val_increasing = val_second > val_first + 0.05

        # Train-Val gap
        last_train = train_losses[-1]
        last_val = val_losses[-1]
        gap = last_train - last_val  # negative means val > train (typical)

        print(f"  Train loss (recent): {train_first:.4f} → {train_second:.4f} "
              f"({'↓' if train_decreasing else '→'})")
        print(f"  Val loss:            {val_first:.4f} → {val_second:.4f} "
              f"({'↓' if val_decreasing else '↑' if val_increasing else '→'})")
        print(f"  Train-Val gap:       {abs(gap):.4f}")

        # ── Classify ──
        case = ""
        recommendations: List[str] = []

        if train_decreasing and val_decreasing:
            case = "Case 1: Normal - both decreasing"
            recommendations.append("Training is progressing normally. Continue.")
            if model_params and total_tokens:
                ratio = total_tokens / model_params
                chinchilla = 20  # Chinchilla optimal: 20 tokens per param
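                # e.g. a 1B-param model wants ~20B tokens at this ratio;
                # at 10B tokens the ratio is 10x and more data should help.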
                if ratio < chinchilla:
                    recommendations.append(
                        f"Token/param ratio = {ratio:.1f}x "
                        f"(Chinchilla optimal ≈ {chinchilla}x). "
                        f"Model may benefit from more data."
                    )
            print(f"\n  🟢 {case}")

        elif not train_decreasing and not val_decreasing and last_train > _EXPECTED_TRAIN_LOSS[1]:
            case = "Case 2: Underfitting - both plateaued at high loss"
            recommendations = [
                "Diagnosis priority (check in order):",
                "1) Training insufficient? → check if loss curve still has downward slope",
                "   - Chinchilla: 1B model needs ~20B tokens minimum",
                "   - TinyLlama trains 1.1B on 3T tokens (inference-optimal)",
                "2) LR too low? → try LR ×2, see if loss drops faster",
                "3) Model capacity too small? → train 2x larger model on same data",
                "   - If larger model gets lower loss → capacity was the limit",
                "4) Data quality? → sample and read training data manually",
                "   - Noisy/low-quality data raises the achievable loss floor",
            ]
            if model_params and total_tokens:
                ratio = total_tokens / model_params
                if ratio < 10:
                    recommendations.insert(0,
                        f"⚠ Token/param ratio = {ratio:.1f}x - "
                        f"very likely undertrained. Chinchilla recommends ≥20x."
                    )
                elif ratio < 20:
                    recommendations.insert(0,
                        f"ℹ Token/param ratio = {ratio:.1f}x - "
                        f"below Chinchilla optimal (20x). More tokens may help."
                    )
            print(f"\n  🟑 {case}")

        elif train_decreasing and (val_increasing or not val_decreasing):
            case = "Case 3: Overfitting β€” train↓ but valβ†’/↑"
            recommendations = [
                "Diagnosis priority (check in order):",
                "1) Data repetition? (most common cause in pretraining)",
                "   - Check: total tokens vs unique tokens",
                "   - Epoch > 1 dramatically increases overfitting risk",
                "   - Solution: add more data, stay within 1 epoch",
                "2) Weight decay too low?",
                "   - Check: weight_decay value (standard: 0.1)",
                "   - LLaMA, TinyLlama, OLMo, GPT-3 all use 0.1",
                "   - Experiment: 0.01 / 0.05 / 0.1 / 0.3",
                "3) Data diversity?",
                "   - Single-domain data overfits faster",
                "   - Mix: web, books, code, wiki, etc.",
                "",
                "Note on Dropout in LLM pretraining:",
                "  - Modern LLMs do NOT use dropout in pretraining",
                "    (Pythia, TinyLlama, OLMo, LLaMA all use dropout=0)",
                "  - Sufficient data is the best regularization",
                "  - Dropout is useful for fine-tuning on small datasets",
            ]
            print(f"\n  🟑 {case}")

        else:
            case = "Case 4: Converged β€” loss is low and stable"
            recommendations = [
                "Training has converged (or reached the data/model limit).",
                "To push further: add more data or increase model size.",
            ]
            print(f"\n  🟒 {case}")

        for rec in recommendations:
            print(f"    {rec}")

        return {
            "level": 4,
            "case": case,
            "train_trend": "decreasing" if train_decreasing else "flat",
            "val_trend": "decreasing" if val_decreasing else ("increasing" if val_increasing else "flat"),
            "gap": abs(gap),
            "recommendations": recommendations,
        }
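
    # A quick sanity check of the split-half trend heuristic used above
    # (hypothetical numbers): a clearly decreasing loss sequence should
    # classify as "decreasing".
    #
    #   >>> losses = [3.0, 2.8, 2.6, 2.4]
    #   >>> mid = len(losses) // 2
    #   >>> first_half = sum(losses[:mid]) / mid                    # 2.9
    #   >>> second_half = sum(losses[mid:]) / (len(losses) - mid)   # 2.5
    #   >>> second_half < first_half - 0.02
    #   True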

    # ───────────────────────────────────────────────────────────────
    # Level 5: Architecture Checks
    # ───────────────────────────────────────────────────────────────

    @staticmethod
    def check_architecture(
        model: nn.Module,
        dataloader: DataLoader,
        device: torch.device,
    ) -> Dict[str, Any]:
        """Check weight initialization and per-layer activation health (Level 5).

        Healthy initialization:
          - Activation std roughly constant across layers, mean β‰ˆ 0.0
        Problems:
          - std increasing per layer β†’ activation explosion (init scale too large)
          - std decreasing per layer β†’ activation vanishing (init scale too small)
          - Sudden change at specific layer β†’ implementation bug in that layer
        """
        print(_header("Level 5: Architecture / Initialization Check"))

        batch = next(iter(dataloader))
        sample_input = batch["input_ids"][:1].to(device)

        model.eval()
        layer_stats: List[Dict[str, Any]] = []

        with torch.no_grad():
            h = model.token_embedding(sample_input)
            emb_std = h.float().std().item()
            print(f"\n  Embedding: std={emb_std:.4f}")

            for i, layer in enumerate(model.layers):
                h = layer(h, mask=None, position_offset=0)
                h_f = h.float()
                stats = {
                    "layer": i,
                    "mean": h_f.mean().item(),
                    "std": h_f.std().item(),
                    "max": h_f.abs().max().item(),
                }
                layer_stats.append(stats)

        # Print stats
        print(f"\n  Layer-by-layer activation statistics:")
        print(f"  {'Layer':<8} {'Mean':>10} {'Std':>10} {'Max':>10}")
        print(f"  {'-' * 38}")
        for s in layer_stats:
            print(f"  {s['layer']:<8} {s['mean']:>10.4f} {s['std']:>10.4f} {s['max']:>10.4f}")

        # ── Weight initialization distribution check ──
        print(f"\n  Weight Initialization Distribution:")
        print(f"  {'Parameter':<40} {'Mean':>10} {'Std':>10} {'Shape'}")
        print(f"  {'-' * 75}")
        weight_issues = []
        n_layers = max(len(model.layers), 1)  # for residual-projection scaling below
        for name, param in model.named_parameters():
            if param.ndim < 2:
                continue  # skip biases, norm weights
            p_f = param.float()
            p_mean = p_f.mean().item()
            p_std = p_f.std().item()
            # Expected: std β‰ˆ 0.02 (GPT-2 style); residual projections
            # (o_proj, down_proj) are scaled down by √(2Γ—layers), matching
            # the init pattern printed below
            shape_str = str(list(param.shape))
            is_residual = "o_proj" in name or "down_proj" in name
            expected_std = 0.02 / (2 * n_layers) ** 0.5 if is_residual else 0.02
            if p_std > expected_std * 5:
                weight_issues.append(f"Large std: {name} (std={p_std:.4f})")
                print(f"  🟑 {name:<38} {p_mean:>10.4f} {p_std:>10.4f} {shape_str}")
            elif p_std < expected_std * 0.1:
                weight_issues.append(f"Tiny std: {name} (std={p_std:.6f})")
                print(f"  🟑 {name:<38} {p_mean:>10.4f} {p_std:>10.6f} {shape_str}")
            else:
                print(f"     {name:<38} {p_mean:>10.4f} {p_std:>10.4f} {shape_str}")

        if weight_issues:
            print(f"\n  ⚠ {len(weight_issues)} weight distribution issue(s) found")
            for issue in weight_issues[:5]:
                print(f"    β€’ {issue}")
        else:
            print(f"\n  βœ… All weight distributions look normal (std β‰ˆ 0.02)")

        print(f"\n  Expected init pattern:")
        print(f"    β€’ General Linear: N(0, 0.02)")
        print(f"    β€’ Residual proj (o_proj, down_proj): N(0, 0.02/√(2Γ—layers))")
        print(f"    β€’ Embedding: N(0, 0.02)")

        # ── Ablation study guidance ──
        print(f"\n  Component Ablation Reference:")
        print("    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”")
        print("    β”‚ Experiment           β”‚ Expected Outcome                   β”‚")
        print("    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€")
        print("    β”‚ RMSNorm β†’ LayerNorm  β”‚ Minimal loss diff β†’ OK            β”‚")
        print("    β”‚ RoPE β†’ Absolute PE   β”‚ Similar on short seq (<512)       β”‚")
        print("    β”‚ SwiGLU β†’ ReLU FFN    β”‚ Loss +0.05~0.15 β†’ SwiGLU working β”‚")
        print("    β”‚ GQA β†’ MHA            β”‚ Same loss, less memory β†’ OK       β”‚")
        print("    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜")
        print("    If any replacement shows unexpected results, check that component.")

        # Analyze trends
        stds = [s["std"] for s in layer_stats]
        diagnosis = "healthy"
        detail = ""

        if len(stds) >= 3:
            # Check for monotonic increase/decrease
            first_third = sum(stds[:len(stds) // 3]) / (len(stds) // 3)
            last_third = sum(stds[-(len(stds) // 3):]) / (len(stds) // 3)
            ratio = last_third / max(first_third, 1e-8)

            if ratio > 5:
                diagnosis = "exploding"
                detail = (
                    f"Activation std grows {ratio:.1f}x from early to late layers. "
                    f"Init scale may be too large."
                )
            elif ratio < 0.2:
                diagnosis = "vanishing"
                detail = (
                    f"Activation std shrinks to {ratio:.2f}x of its early-layer "
                    f"value in the late layers. Init scale may be too small."
                )
            else:
                detail = f"Std ratio (last/first third) = {ratio:.2f} β€” within normal range."

            # Check for sudden jumps
            for i in range(1, len(stds)):
                jump = stds[i] / max(stds[i - 1], 1e-8)
                if jump > 10 or jump < 0.1:
                    diagnosis = "anomaly"
                    detail = (
                        f"Sudden activation change at layer {i}: "
                        f"std {stds[i - 1]:.4f} β†’ {stds[i]:.4f}. "
                        f"Possible implementation bug in that layer."
                    )
                    break

        icon = {"healthy": "βœ…", "exploding": "πŸ”΄", "vanishing": "🟑", "anomaly": "πŸ”΄"}
        print(f"\n  {icon.get(diagnosis, 'βšͺ')} Diagnosis: {diagnosis}")
        print(f"  {detail}")

        return {
            "level": 5,
            "diagnosis": diagnosis,
            "detail": detail,
            "layer_stats": layer_stats,
            "weight_issues": weight_issues,
        }

    # ───────────────────────────────────────────────────────────────
    # Main Entry Point
    # ───────────────────────────────────────────────────────────────

    @staticmethod
    def run_diagnostics(
        model: nn.Module,
        dataloader: DataLoader,
        tokenizer: Any,
        train_config: TrainConfig,
        metrics_history: Dict[str, list],
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        vocab_size: int = 32000,
        levels: Optional[List[int]] = None,
    ) -> Dict[str, Any]:
        """Run the full 5-level debugging framework.

        Args:
            model: the LLM model
            dataloader: training dataloader
            tokenizer: tokenizer with encode/decode methods
            train_config: TrainConfig instance
            metrics_history: dict from MetricsTracker.history
            device: torch device
            dtype: mixed precision dtype
            vocab_size: model vocabulary size
            levels: which levels to run (default: all [0,1,2,3,4,5])

        Returns:
            Full diagnostic report dict.
        """
        if levels is None:
            levels = [0, 1, 2, 3, 4, 5]

        print("\n" + "═" * 60)
        print("  LLM Loss Debugging Framework")
        print("  Levels to run: " + ", ".join(str(l) for l in levels))
        print("═" * 60)

        report: Dict[str, Any] = {}

        if 0 in levels:
            report["level_0"] = LossDebugger.diagnose_status(vocab_size, metrics_history)
            # Early exit: status is normal and only level 0 was requested
            if (
                report["level_0"]["status"] == STATUS_NORMAL
                and levels == [0]
            ):
                print("\n  Training is healthy β€” no further debugging needed.")
                return report

        if 1 in levels:
            report["level_1"] = LossDebugger.check_data_pipeline(
                model, dataloader, tokenizer, vocab_size, device, dtype,
            )

        if 2 in levels:
            report["level_2"] = LossDebugger.check_numerical_stability(
                model, dataloader, device, dtype,
            )

        if 3 in levels:
            report["level_3"] = LossDebugger.diagnose_hyperparameters(
                metrics_history, train_config,
            )

        if 4 in levels:
            model_params = sum(p.numel() for p in model.parameters())
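            # Approximation: assumes one "train_loss" history entry per optimizer step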
            total_tokens = len(metrics_history.get("train_loss", [])) * train_config.tokens_per_step
            report["level_4"] = LossDebugger.diagnose_fitting(
                metrics_history, model_params, total_tokens,
            )

        if 5 in levels:
            report["level_5"] = LossDebugger.check_architecture(
                model, dataloader, device,
            )

        # Final summary
        print("\n" + "═" * 60)
        print("  Diagnostics Complete")
        print("═" * 60)

        return report
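
    # Example usage (a sketch; `model`, `train_loader`, `tok`, `cfg`, and
    # `tracker` are assumed to exist in the calling script):
    #
    #   report = LossDebugger.run_diagnostics(
    #       model=model,
    #       dataloader=train_loader,
    #       tokenizer=tok,
    #       train_config=cfg,
    #       metrics_history=tracker.history,
    #       device=torch.device("cuda"),
    #       dtype=torch.bfloat16,
    #       vocab_size=tok.vocab_size,
    #       levels=[0, 3, 4],  # e.g. run only status, HP, and fitting checks
    #   )
    #   print(report["level_4"]["case"])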

    # ───────────────────────────────────────────────────────────────
    # Study Roadmap
    # ───────────────────────────────────────────────────────────────

    @staticmethod
    def print_study_roadmap() -> None:
        """Print the recommended study roadmap for LLM training optimization."""
        print(_header("Study Roadmap β€” LLM Training Optimization"))

        print("""
  ⭐⭐⭐ Top Priority: Optimization Fundamentals
  ─────────────────────────────────────────────
  1. SGD β†’ Momentum β†’ Adam β†’ AdamW progression
     - Why Adam > SGD? Why decouple weight decay in AdamW?
     - β₁, Ξ²β‚‚ intuition (1st / 2nd moment estimates)
     - Ref: Loshchilov & Hutter 2019 (AdamW)
     - Ref: Karpathy "A Recipe for Training Neural Networks"

  2. Loss Landscape
     - Why large LR diverges, small LR stalls
     - Batch size effect on landscape exploration
     - Ref: Li et al. 2018 "Visualizing the Loss Landscape"
     - Ref: McCandlish et al. 2018 "An Empirical Model of Large-Batch Training"

  3. Chinchilla Scaling Law
     - Loss = f(N, D) relationship
     - Compute-optimal model size vs data allocation
     - Ref: Hoffmann et al. 2022 (original)
     - Ref: Kaplan et al. 2020 (predecessor)
     - Ref: Besiroglu et al. 2024 (replication/verification)

  ⭐⭐ Important: Training Stability
  ──────────────────────────────────
  4. Gradient Flow: vanishing/exploding, residual as gradient highway
  5. Weight Init: Xavier / Kaiming / GPT-2 style
  6. Normalization: BatchNorm β†’ LayerNorm β†’ RMSNorm
  7. Weight Decay: L2 vs decoupled, why exclude embed/norm

  ⭐ Advanced: Optimization Techniques
  ─────────────────────────────────────
  8. LR Schedules: cosine vs linear vs step, warmup/cooldown
  9. Gradient Accumulation & Large Batch Training
  10. ΞΌP (Maximal Update Parameterization): transfer HP across scales

  Recommended Experiments (in order):
  ───────────────────────────────────
  1. Single-batch overfit         (30 min)   β†’ basic sanity
  2. LR Range Test                (1 hour)   β†’ optimal LR range
  3. 10M model quick train        (2-3 hrs)  β†’ pipeline validation
  4. Ablation (remove components) (1 day)    β†’ component contribution
  5. 100M model + 5B tokens       (1-2 days) β†’ mid-scale dynamics
  6. 1B model full training       (2-3 days) β†’ scaling law verification
  7. LR / batch size comparison   (1 day)    β†’ HP sensitivity

  Key References:
  ───────────────
  ⭐⭐⭐ Karpathy "Recipe for Training NNs"    β€” debugging mindset
  ⭐⭐⭐ Hoffmann et al. 2022 (Chinchilla)     β€” scaling law
  ⭐⭐  Touvron et al. 2023 (LLaMA)           β€” 1B+ training details
  ⭐⭐  Biderman et al. 2023 (Pythia)         β€” open training logs
  ⭐⭐  Zhang et al. 2024 (TinyLlama)         β€” 1.1B on 3T tokens
  ⭐⭐  Groeneveld et al. 2024 (OLMo)         β€” fully open LLM
  ⭐⭐  Li et al. 2018 (Loss Landscape)       β€” loss terrain intuition
  ⭐⭐  Loshchilov & Hutter 2019 (AdamW)      β€” optimizer basics
  ⭐   Yang et al. 2022 (ΞΌP)                 β€” HP transfer
  ⭐   McCandlish et al. 2018 (Batch size)   β€” critical batch size
""")