# =============================================================================
# COMPRESSION NAVIGATOR  ·  extended + annotated edition
# =============================================================================
# An LLM is a lossy codec for text. Training compresses a corpus into weights;
# a forward pass decompresses a continuation. These five tools let you watch
# that decompression happen and poke at where facts physically live.
#
# The five tabs are not toys invented here - each one is a real mechanistic-
# interpretability technique you'll find in papers:
#
#   1. Decompress      = LOGIT LENS             (nostalgebraist, 2020)
#   2. Triangulate     = EMBEDDING NEIGHBOURS   (the geometry of the vocab)
#   3. Re-route        = ACTIVATION STEERING    (ActAdd / repr. engineering)
#   4. Diff            = CROSS-MODEL ALIGNMENT  (compare checkpoints by depth)
#   5. Causal trace    = ACTIVATION PATCHING    (ROME, Meng et al., 2022)
#
# WHY THE GLASS-BOX MODELS MATTER
# -------------------------------
# On a real model (gpt2) you never know the ground truth, so you can't tell
# whether a tool is *correct* or just producing plausible-looking output.
# This file ships two models whose internals you fully specify, so you can
# check each tool against a known answer:
#
#   "handmade"  - facts stored as a LOOKUP TABLE keyed on the prompt string.
#                 The computation happens in a side channel (string match),
#                 NOT in the residual stream. Lesson: such a model is almost
#                 invisible to residual-stream interpretability. Logit lens
#                 sees a sudden jump with no build-up; causal tracing finds
#                 nothing, because corrupting activations doesn't touch the
#                 string match. This is a real and underappreciated *limit*
#                 of these methods.
#
#   "glassbox"  - facts stored the way real transformers store them: as
#                 key->value writes into the RESIDUAL STREAM (Geva et al.,
#                 2021, "Transformer Feed-Forward Layers Are Key-Value
#                 Memories"; this is exactly the kind of layer ROME edits).
#                 Because the fact flows through activations, ALL five tools
#                 light up correctly - and you can verify they report the
#                 layer you actually put the fact in. This is a unit-test
#                 harness for interpretability code.
#
# Run order suggestion:  glassbox  ->  handmade  ->  gpt2
#   glassbox shows what "correct" looks like; handmade shows a failure mode;
#   gpt2 shows the fuzzy, distributed real thing.
# =============================================================================

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float32
MODELS = {}                 # name -> (model, tokenizer) cache
STATE = {"name": None}      # currently loaded model name


# =============================================================================
# A tiny shared tokenizer for both glass-box models.
# Case is CANONICALISED to lowercase everywhere (this fixes a real bug in the
# original: "Paris" from a pinned fact and "paris" from the Markov table became
# two different vocab entries, so the boosted token and the *tracked* token
# silently diverged - every neighbour read cos=0.000 and every tracked prob 0).
# =============================================================================
class FakeBatchEncoding(dict):
    def to(self, device):            # let callers do tok(...).to(DEVICE) safely
        return self


class SimpleTok:
    """Whitespace tokenizer over a fixed vocab. Not 'fast' (no offset map)."""
    is_fast = False

    def __init__(self, stoi, itos):
        self.stoi, self.itos = stoi, itos
        self.eos_token_id = stoi["."]      # period doubles as end-of-sequence

    def _ids(self, text):
        words = text.lower().replace(".", " .").split()
        return [self.stoi.get(w, self.stoi["<s>"]) for w in words]

    def __call__(self, text, return_tensors=None, return_offsets_mapping=False):
        ids = self._ids(text) or [self.stoi["<s>"]]
        return FakeBatchEncoding(
            input_ids=torch.tensor([ids]),
            attention_mask=torch.ones(1, len(ids), dtype=torch.long),
        )

    def encode(self, text, add_special_tokens=False):
        return self._ids(text)

    def decode(self, ids, skip_special_tokens=False):
        out = []
        for i in ids:
            w = self.itos.get(int(i), "?")
            if skip_special_tokens and w in ("<pad>", "<s>"):
                continue
            out.append(w)
        return " ".join(out)


class _Out:
    """Mimics a HF CausalLMOutput: .logits and (optional) .hidden_states."""
    def __init__(self, logits, hidden_states):
        self.logits = logits
        self.hidden_states = hidden_states


def _greedy_generate(model, input_ids, max_new_tokens=20, pad_token_id=None, **_):
    """Minimal greedy decode so the steering tab works on the toy models too
    (the originals had no .generate, so that tab crashed on 'handmade')."""
    ids = input_ids
    for _ in range(int(max_new_tokens)):
        nxt = model(input_ids=ids).logits[0, -1].argmax().view(1, 1)
        ids = torch.cat([ids, nxt], dim=1)
        if pad_token_id is not None and int(nxt.item()) == int(pad_token_id):
            break
    return ids


# =============================================================================
# MODEL 1 - "handmade": facts as a LOOKUP TABLE (the side-channel glass box)
# -----------------------------------------------------------------------------
# Embeddings are the identity matrix (each token is its own one-hot). The two
# Neither of the two "layers" reads the residual stream; both key on the raw
# token ids (prompt_ids) and only write into it:
#   - MemoryBlock matches the *decoded prompt string* and boosts the answer.
#   - MarkovBlock adds the hand-built bigram transition row for the last token id.
# Because MemoryBlock keys on the prompt TEXT, not on activations, this is a
# deliberate demonstration of a model that residual-stream interpretability
# cannot see. Use it as the "what failure looks like" control.
# =============================================================================
PINNED = {                              # answers are lowercase now (bug fix)
    "the capital of france is": " paris",
    "the eiffel tower is in":   " paris",
    "two plus two equals":      " four",
}
MARKOV = {
    "<s>":    {"the": 3, "i": 2, "a": 1},
    "the":    {"city": 2, "tower": 2, "answer": 1},
    "i":      {"think": 2, "am": 1},
    "a":      {"model": 2, "city": 1},
    "city":   {"of": 3, "is": 1},
    "of":     {"light": 2, "paris": 1},
    "tower":  {"is": 3},
    "is":     {"in": 2, "a": 1},
    "in":     {"paris": 2, "france": 1},
    "model":  {"is": 2},
    "think":  {"the": 2},
    "paris":  {".": 1},
    "france": {".": 1},
    "light":  {".": 1},
    "four":   {".": 1},
}


def _build_handmade_vocab():
    toks, seen = ["<pad>", "<s>", "."], {"<pad>", "<s>", "."}
    def add(w):
        if w not in seen:
            toks.append(w); seen.add(w)
    for v in PINNED.values():
        add(v.strip())
    for w, nxts in MARKOV.items():
        add(w)
        for x in nxts:
            add(x)
    for k in PINNED:
        for w in k.split():
            add(w)
    return toks


HM_VOCAB = _build_handmade_vocab()
HM_STOI = {w: i for i, w in enumerate(HM_VOCAB)}
HM_ITOS = {i: w for w, i in HM_STOI.items()}
HM_V = len(HM_VOCAB)


class _MemoryBlock(nn.Module):
    """If the decoded prompt ends with a pinned key, slam the answer logit.
    NOTE: this reads prompt_ids (the string), not x - that's the whole point."""
    def forward(self, x, prompt_ids=None):
        out = x.clone()
        if prompt_ids is not None:
            text = " ".join(HM_ITOS.get(int(i), "") for i in prompt_ids).strip()
            for key, ans in PINNED.items():
                if text.endswith(key):
                    out[0, -1, HM_STOI[ans.strip()]] += 12.0
        return (out,)
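
# Sketch (hypothetical helpers, not used by the app) of why a causal trace
# cannot see a fact stored the way _MemoryBlock stores it. One reader keys on
# token ids; the other keys on the residual vector. Perturbing the residual,
# which is what the usual causal-trace corruption does, can only move the
# second reader.
import torch


def id_keyed_score(ids, key_ids):
    """1.0 if the id sequence ends with key_ids, else 0.0. Never reads activations."""
    n = len(key_ids)
    return 1.0 if n and list(ids[-n:]) == list(key_ids) else 0.0


def residual_keyed_score(resid, key_dir, bias=0.5):
    """relu(<resid, key_dir> - bias): one key->value unit that reads the residual."""
    return float(torch.relu(resid @ key_dir - bias))

# Example with a fixed perturbation: for resid=[1, 0] and key_dir=[1, 0],
# residual_keyed_score drops from 0.5 to 0.0 once [-1, 3] is added to resid,
# while id_keyed_score([3, 4, 5], [4, 5]) stays 1.0 because the ids never change.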


class _MarkovBlock(nn.Module):
    """Add a hand-built bigram transition row for the last token."""
    def __init__(self):
        super().__init__()
        T = torch.zeros(HM_V, HM_V)
        for w, nxts in MARKOV.items():
            if w in HM_STOI:
                tot = sum(nxts.values())
                for x, wt in nxts.items():
                    if x in HM_STOI:
                        T[HM_STOI[w], HM_STOI[x]] = wt / tot
        self.register_buffer("T", T)

    def forward(self, x, prompt_ids=None):
        out = x.clone()
        if prompt_ids:
            out[0, -1] += 4.0 * self.T[int(prompt_ids[-1])]
        return (out,)


class _HMTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(HM_V, HM_V)
        with torch.no_grad():
            self.wte.weight.copy_(torch.eye(HM_V))          # one-hot embeddings
        self.h = nn.ModuleList([_MemoryBlock(), _MarkovBlock()])
        self.ln_f = nn.Identity()


class HandmadeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = _HMTransformer()
        self.head = nn.Linear(HM_V, HM_V, bias=False)
        with torch.no_grad():
            self.head.weight.copy_(torch.eye(HM_V))         # identity unembed
        self.tok = SimpleTok(HM_STOI, HM_ITOS)

    def get_input_embeddings(self):  return self.transformer.wte
    def get_output_embeddings(self): return self.head
    def generate(self, input_ids=None, attention_mask=None, **kw):
        return _greedy_generate(self, input_ids, **kw)

    def forward(self, input_ids=None, attention_mask=None, output_hidden_states=False):
        ids = input_ids[0].tolist()
        x = self.transformer.wte(input_ids).float()
        hs = [x]; h = x
        for blk in self.transformer.h:
            (h,) = blk(h, prompt_ids=ids); hs.append(h)
        logits = self.head(self.transformer.ln_f(h))
        return _Out(logits, tuple(hs) if output_hidden_states else None)


# =============================================================================
# MODEL 2 - "glassbox": facts as RESIDUAL-STREAM key->value writes
# -----------------------------------------------------------------------------
# This is the model the original was missing. It stores facts the way real
# transformers do, so every tool works AND can be checked against ground truth.
#
# Vocab + structured embeddings (d=32). Country and its capital deliberately
# SHARE an embedding dimension, so the neighbours tool finds real geometry
# (paris is near france).
#
# Four layers:
#   L0  subject site   : (identity here) the residual the trace will restore
#   L1  pool/attention : copies subject signal from earlier positions -> last
#   L2  fact MLP       : key(subject+relation) -> relu -> value(answer dir)   <- ROME edits this kind of layer
#   L3  cleanup        : identity
#
# Ground truth you can verify:
#   - logit lens: the answer is never top-1 (barely above baseline) until L2,
#     then jumps to the top at p~0.5. Compare with handmade (sudden, no
#     build-up) and gpt2 (fuzzy, spread over many layers).
#   - causal trace: corrupting the subject and restoring layer by layer peaks
#     at L0 - because L1's "attention" re-reads the restored subject. That is
#     the ROME story: the causal site is an early layer at the SUBJECT token.
#   - steering / neighbours: both operate on real directions, so both work.
# =============================================================================
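
# Sketch of causal tracing (activation patching) on a generic layer stack.
# This is NOT the causal-trace tab's code; every name here is hypothetical.
# Assumptions: `layers` is a list of callables mapping a (seq, d) residual
# to a (seq, d) residual, `unembed` is a (V, d) matrix, and corruption is
# Gaussian noise on the subject positions of the input. For each
# (layer, position) the corrupted input is rerun with the clean residual
# pasted back at that single site, and p(answer) is recorded. The procedure
# in the ROME paper is richer (e.g. it averages over several noise samples);
# this sketch uses one site and one noise draw.
import torch


def _trace_forward(layers, x, patch=None):
    """Run `layers` on a (seq, d) residual x. patch = (layer_idx, pos, vec)
    overwrites one residual vector right after that layer. Returns the list
    [input, after layer 0, after layer 1, ...]."""
    hs = [x]
    for i, layer in enumerate(layers):
        x = layer(x)
        if patch is not None and patch[0] == i:
            x = x.clone()                 # never write into a caller's tensor
            x[patch[1]] = patch[2]
        hs.append(x)
    return hs


def sketch_causal_trace(layers, unembed, x_clean, subj_pos, answer_id,
                        noise=3.0, seed=0):
    """grid[layer, pos] = p(answer) when the corrupted run gets the clean
    residual restored at that single (layer, pos) site."""
    g = torch.Generator().manual_seed(seed)
    x_bad = x_clean.clone()
    x_bad[subj_pos] = x_bad[subj_pos] + noise * torch.randn(
        x_bad[subj_pos].shape, generator=g)
    clean = _trace_forward(layers, x_clean)
    seq = x_clean.shape[0]
    grid = torch.zeros(len(layers), seq)
    for i in range(len(layers)):
        for p in range(seq):
            out = _trace_forward(layers, x_bad, patch=(i, p, clean[i + 1][p]))[-1]
            grid[i, p] = torch.softmax(unembed @ out[-1], dim=-1)[answer_id]
    return grid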
GB_D = 32
GB_TOKS = ["<pad>", "<s>", ".", "the", "capital", "of", "is", "in",
           "france", "germany", "japan", "paris", "berlin", "tokyo",
           "london", "rome"]          # spare answers so edits can hit a fresh target
GB_STOI = {w: i for i, w in enumerate(GB_TOKS)}
GB_ITOS = {i: w for w, i in GB_STOI.items()}
GB_V = len(GB_TOKS)
GB_FACTS = [("france", "paris"), ("germany", "berlin"), ("japan", "tokyo")]


def _build_gb_embeddings():
    E = torch.zeros(GB_V, GB_D)
    def setd(tok, pairs):
        for d, v in pairs:
            E[GB_STOI[tok], d] = v
    # country/capital pairs share their first dim -> positive cosine (geometry!)
    setd("france", [(0, 1.0), (1, 0.6), (20, 0.5)])
    setd("paris",  [(0, 0.8), (2, 0.9), (21, 0.5)])
    setd("germany",[(3, 1.0), (4, 0.6), (22, 0.5)])
    setd("berlin", [(3, 0.8), (5, 0.9), (23, 0.5)])
    setd("japan",  [(6, 1.0), (7, 0.6), (24, 0.5)])
    setd("tokyo",  [(6, 0.8), (8, 0.9), (25, 0.5)])
    setd("london", [(27, 1.0), (28, 0.5)])                 # spare answers (own dirs)
    setd("rome",   [(29, 1.0), (30, 0.5)])
    setd("is",     [(9, 1.0), (26, 0.4)])                  # the relation marker
    for i, t in enumerate(GB_TOKS):                        # give fillers an id
        if E[i].abs().sum() == 0:
            E[i, 10 + i % 6] = 1.0
    return E / (E.norm(dim=-1, keepdim=True) + 1e-9)       # unit rows


GB_E = _build_gb_embeddings()
GB_SUBJ = torch.zeros(GB_D, GB_D)                          # projector onto subject dims 0..8
for _d in range(9):
    GB_SUBJ[_d, _d] = 1.0
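
# Worked check of the geometry claim above (standalone sketch; the helper
# name is hypothetical and it rebuilds rows from the same dim->value specs
# instead of reading GB_E). After unit-normalising, only shared dims feed the
# dot product, so cos(france, paris) = 1.0*0.8 / (|france| * |paris|) ~= 0.48,
# which should be the value the neighbours tab shows for paris on glassbox.
# A pair with no shared dim, such as (germany, paris), has cosine exactly 0.
import math


def sparse_cosine(a, b):
    """Cosine similarity of two vectors given as {dim: value} dicts."""
    dot = sum(v * b[d] for d, v in a.items() if d in b)
    na = math.sqrt(sum(v * v for v in a.values()))
    nb = math.sqrt(sum(v * v for v in b.values()))
    return dot / (na * nb)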


class _GBIdent(nn.Module):
    def forward(self, x, prompt_ids=None):
        return (x.clone(),)


class _GBPool(nn.Module):
    """Toy 'attention': sum the subject-projected residual of all earlier
    positions into the last position. Corrupting the subject earlier shows up
    here; restoring the subject BEFORE this layer is what makes the trace
    recover - that is why the causal peak lands at L0, not L1."""
    def forward(self, x, prompt_ids=None):
        out = x.clone()
        if x.shape[1] > 1:
            pooled = (x[0, :-1] @ GB_SUBJ.T).sum(0)
            out[0, -1] = out[0, -1] + 0.9 * pooled
        return (out,)


class _GBFactMLP(nn.Module):
    """Geva-style key->value memory. W_in rows are (subject+relation) keys;
    relu gates which fact fires; W_out columns are answer unembed directions.
    This is structurally the exact layer ROME rewrites to edit a fact."""
    def __init__(self):
        super().__init__()
        Win = torch.zeros(len(GB_FACTS), GB_D)
        Wout = torch.zeros(GB_D, len(GB_FACTS))
        rel = GB_E[GB_STOI["is"]]
        for k, (s, a) in enumerate(GB_FACTS):
            key = (GB_E[GB_STOI[s]] @ GB_SUBJ.T) * 0.9 + rel
            Win[k] = key / key.norm()
            Wout[:, k] = GB_E[GB_STOI[a]]                  # write answer direction
        self.register_buffer("Win", Win)
        self.register_buffer("Wout", Wout)
        self.register_buffer("Win0", Win.clone())          # pristine backups for reset
        self.register_buffer("Wout0", Wout.clone())
        self.bias, self.gain = 0.85, 6.0                   # tuned: clean p~0.5, corrupt p~0.07

    def forward(self, x, prompt_ids=None):
        out = x.clone()
        pre = F.relu(self.Win @ out[0, -1] - self.bias)
        out[0, -1] = out[0, -1] + self.gain * (self.Wout @ pre)
        return (out,)
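
# Standalone sketch of the key->value read/write that _GBFactMLP.forward
# performs (hypothetical helper; W_in, W_out, bias and gain mirror the
# buffers and constants above). Each row of W_in is a key, relu(W_in @ x -
# bias) measures how strongly each key matches the residual x, and W_out
# turns those gates into a write along the matching value direction(s).
import torch
import torch.nn.functional as F


def kv_memory_write(x, W_in, W_out, bias=0.85, gain=6.0):
    """Return x + gain * W_out @ relu(W_in @ x - bias) for one residual x of shape (d,)."""
    gates = F.relu(W_in @ x - bias)      # (n_facts,): which keys fire
    return x + gain * (W_out @ gates)    # (d,): add the matching value(s)

# With orthonormal keys, an input aligned with key 0 writes only value 0,
# and an input whose match falls below `bias` passes through unchanged.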


class _GBTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(GB_V, GB_D)
        with torch.no_grad():
            self.wte.weight.copy_(GB_E)
        self.h = nn.ModuleList([_GBIdent(), _GBPool(), _GBFactMLP(), _GBIdent()])
        self.ln_f = nn.Identity()


class GlassBoxModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = _GBTransformer()
        self.head = nn.Linear(GB_D, GB_V, bias=False)
        with torch.no_grad():
            self.head.weight.copy_(GB_E)                   # tied unembed
        self.tok = SimpleTok(GB_STOI, GB_ITOS)

    # --- knowledge editing (ROME-style, exact on this key->value layer) -------
    @torch.no_grad()
    def edit_fact(self, subject, new_answer, method="rank1", strength=1.0):
        """Rewrite the value a fact-MLP key maps to. Methods:
           rank1 / surgical - the minimal update: change only this fact's value.
           broadcast        - DELIBERATELY sloppy: smear the delta across ALL
                              facts, so the verifier has real collateral to catch.
        """
        fm = self.transformer.h[2]                          # the FactMLP block
        subjects = [s for s, _ in GB_FACTS]
        if subject not in subjects:
            raise ValueError("unknown subject %r" % subject)
        if new_answer not in GB_STOI:
            raise ValueError("unknown answer token %r" % new_answer)
        k = subjects.index(subject)
        delta = (GB_E[GB_STOI[new_answer]] - fm.Wout0[:, k]) * float(strength)
        if method in ("rank1", "surgical"):
            fm.Wout[:, k] = fm.Wout0[:, k] + delta
        elif method == "broadcast":
            fm.Wout += delta.unsqueeze(1)                   # hits every fact
        else:
            raise ValueError("unknown method %r" % method)

    @torch.no_grad()
    def reset(self):
        fm = self.transformer.h[2]
        fm.Win.copy_(fm.Win0); fm.Wout.copy_(fm.Wout0)

    def get_input_embeddings(self):  return self.transformer.wte
    def get_output_embeddings(self): return self.head
    def generate(self, input_ids=None, attention_mask=None, **kw):
        return _greedy_generate(self, input_ids, **kw)

    def forward(self, input_ids=None, attention_mask=None, output_hidden_states=False):
        ids = input_ids[0].tolist()
        x = self.transformer.wte(input_ids).float()
        hs = [x]; h = x
        for blk in self.transformer.h:
            (h,) = blk(h, prompt_ids=ids); hs.append(h)
        logits = self.head(self.transformer.ln_f(h))
        return _Out(logits, tuple(hs) if output_hidden_states else None)
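
# Sketch of why the "rank1" branch of edit_fact is a rank-1 update, plus a
# collateral measure a verifier could use (hypothetical helpers, not part of
# the app). Replacing column k of W_out by a new value v equals
# W + (v - W[:, k]) e_k^T: an outer product, hence rank one, and every other
# column is untouched. The "broadcast" branch instead adds the same delta to
# every column, which is exactly the collateral this measure picks up.
import torch


def rank1_column_edit(W, k, new_value):
    """Return W with column k replaced by new_value, written as a rank-1 update."""
    e_k = torch.zeros(W.shape[1], dtype=W.dtype)
    e_k[k] = 1.0
    delta = new_value - W[:, k]
    return W + torch.outer(delta, e_k)


def collateral(W_before, W_after, k):
    """Largest absolute change to any column other than the edited column k."""
    keep = [j for j in range(W_before.shape[1]) if j != k]
    if not keep:
        return 0.0
    return float((W_after[:, keep] - W_before[:, keep]).abs().max())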


# =============================================================================
# REAL MODELS - resolve the architecture-specific module paths
# =============================================================================
def _resolve(model, paths):
    for path in paths:
        obj, ok = model, True
        for part in path.split("."):
            if hasattr(obj, part):
                obj = getattr(obj, part)
            else:
                ok = False; break
        if ok:
            return obj
    return None
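
# Equivalent sketch of _resolve using operator.attrgetter, which follows
# dotted paths natively (the name resolve_first is hypothetical). For
# ordinary attributes it behaves like _resolve: the first path that exists
# wins, and None comes back when none of them do.
import operator


def resolve_first(obj, paths):
    """Return the attribute at the first dotted path that exists on obj, else None."""
    for path in paths:
        try:
            return operator.attrgetter(path)(obj)
        except AttributeError:
            continue
    return None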


def get_blocks(model):
    blocks = _resolve(model, ["transformer.h", "model.layers",
                              "gpt_neox.layers", "model.decoder.layers"])
    if blocks is None:
        raise RuntimeError("Could not locate transformer blocks.")
    return blocks


def get_final_norm(model):
    norm = _resolve(model, ["transformer.ln_f", "model.norm",
                            "gpt_neox.final_layer_norm",
                            "model.decoder.final_layer_norm"])
    return norm if norm is not None else (lambda x: x)


def get_head(model):
    return model.get_output_embeddings()


def get_handles(name):
    if name not in MODELS:
        if name == "handmade":
            m = HandmadeModel().eval(); MODELS[name] = (m, m.tok)
        elif name == "glassbox":
            m = GlassBoxModel().eval(); MODELS[name] = (m, m.tok)
        else:
            tok = AutoTokenizer.from_pretrained(name)
            model = AutoModelForCausalLM.from_pretrained(
                name, torch_dtype=DTYPE).to(DEVICE).eval()
            MODELS[name] = (model, tok)
    return MODELS[name]


def load_model(name):
    name = name.strip()
    model, _ = get_handles(name)
    STATE["name"] = name
    return "Loaded **%s** (%d layers)." % (name, len(get_blocks(model)))


# =============================================================================
# Shared readout: project every layer's last-token residual to a vocab dist.
# =============================================================================
@torch.no_grad()
def layer_distributions(model, tok, prompt):
    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    out = model(**inputs, output_hidden_states=True)
    hs = out.hidden_states
    norm, head, n = get_final_norm(model), get_head(model), len(out.hidden_states)
    dists = []
    for i, layer_hs in enumerate(hs):
        vec = layer_hs[0, -1].to(DTYPE)
        # HF convention: the LAST hidden_states entry is already post-ln_f,
        # so skip norm there; apply ln_f to intermediates (logit-lens style).
        logits = head(vec) if i == n - 1 else head(norm(vec))
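        # Label rows by 0-indexed block (hidden_states[i] is the output of
        # block i-1), matching the layer numbers used by steering and editing.
        dists.append(("embed" if i == 0 else "L%d" % (i - 1), F.softmax(logits, dim=-1)))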
    return dists


def _entropy_bits(probs):
    p = probs.clamp_min(1e-12)
    return float(-(p * p.log()).sum() / math.log(2))


# =============================================================================
# TAB 1 - LOGIT LENS: watch the answer condense out of the residual stream
# =============================================================================
@torch.no_grad()
def logit_lens(prompt, top_k, track):
    if STATE["name"] is None:
        return "Load a model first."
    model, tok = get_handles(STATE["name"])
    top_k = int(top_k)
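    # Only the FIRST token of `track` is followed. With BPE tokenizers such as
    # gpt2's, mid-sentence words carry a leading space: track " Paris", not "Paris".
    tids = tok.encode(track, add_special_tokens=False) if track.strip() else []
    tid = tids[0] if tids else None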
    dists = layer_distributions(model, tok, prompt)
    header = "layer | top tokens (prob)                       | entropy" \
             + ("   | p(%r)" % track if tid is not None else "")
    lines = ["prompt: %r" % prompt, header, "-" * len(header)]
    for label, probs in dists:
        p, idx = probs.topk(top_k)
        shown = "  ".join("%r:%.2f" % (tok.decode([t]).replace("\n", "\\n"), v)
                          for t, v in zip(idx.tolist(), p.tolist()))
        row = "%5s | %-40s | %4.1fb" % (label, shown, _entropy_bits(probs))
        if tid is not None:
            row += "   | %.3f" % probs[tid].item()
        lines.append(row)
    return "\n".join(lines)
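
# Reading aid for the logit-lens table (hypothetical helper, not wired into
# the UI): given the (label, probs) list that layer_distributions returns,
# report the first row at which a token id becomes the top-1 prediction. On
# glassbox this should be the row right after the fact MLP; on handmade it
# should be the row right after the memory block, with no build-up before it.
import torch


def first_top1_layer(dists, token_id):
    """Label of the first (label, probs) entry whose argmax is token_id, else None."""
    for label, probs in dists:
        if int(probs.argmax()) == int(token_id):
            return label
    return None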


# =============================================================================
# TAB 2 - NEIGHBOURS: the geometry of the (un)embedding space
# =============================================================================
@torch.no_grad()
def neighbors(word, top_k):
    if STATE["name"] is None:
        return "Load a model first."
    model, tok = get_handles(STATE["name"])
    top_k = int(top_k)
    ids = tok.encode(word, add_special_tokens=False)
    if not ids:
        return "Could not tokenize %r." % word
    tid = ids[0]
    W = F.normalize(get_head(model).weight.to(DTYPE), dim=-1)
    sims = W @ W[tid]
    vals, idx = sims.topk(min(top_k + 1, sims.shape[0]))   # clamp for tiny toy vocabs
    note = ""
    if STATE["name"] == "handmade":
        note = ("(handmade uses one-hot embeddings, so every token is "
                "orthogonal -> all cosines are 0 by construction. This is the "
                "tool telling the truth about a model with no vocab geometry.)\n")
    lines = [note + "neighbours of %r:" % word]
    for v, j in zip(vals.tolist(), idx.tolist()):
        if j != tid:
            lines.append("  %14r  cos=%.3f" % (tok.decode([j]), v))
    return "\n".join(lines[: top_k + 1])
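

# Illustrative sketch (hypothetical helper, not called by the app): the
# neighbour ranking above in plain Python. Rows are scaled to unit length and
# ranked by dot product with the query row, i.e. cosine similarity. A zero row
# is left as zeros, which matches how F.normalize treats zero vectors (it
# divides by max(norm, eps)).
import math


def _sketch_cosine_neighbours(rows, query, top_k):
    """Return [(index, cosine)] for the top_k rows most similar to rows[query],
    excluding the query itself."""
    def unit(v):
        n = math.sqrt(sum(x * x for x in v))
        return [x / n for x in v] if n > 0 else [0.0] * len(v)

    units = [unit(r) for r in rows]
    q = units[query]
    sims = [(j, sum(a * b for a, b in zip(u, q))) for j, u in enumerate(units) if j != query]
    sims.sort(key=lambda t: t[1], reverse=True)  # stable: ties keep index order
    return sims[:top_k]

# One-hot rows (the 'handmade' case) are mutually orthogonal, so every cosine
# comes out 0:  _sketch_cosine_neighbours([[1, 0, 0], [0, 1, 0], [0, 0, 1]], 0, 2)
# returns [(1, 0.0), (2, 0.0)].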


# =============================================================================
# TAB 3 - STEERING: bend behaviour by adding a direction, no retraining
# =============================================================================
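def _make_steer_hook(direction, alpha):
    """Forward hook that adds alpha * direction to a block's output at every
    position (prompt and generated tokens alike). HF blocks may return a tuple
    whose first element is the hidden state; the other elements pass through."""
    d = direction * alpha
    def hook(module, inp, out):
        if isinstance(out, tuple):
            return (out[0] + d.to(device=out[0].device, dtype=out[0].dtype),) + out[1:]
        return out + d.to(device=out.device, dtype=out.dtype)
    return hook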


@torch.no_grad()
def steer_generate(prompt, source, target, layer, alpha, max_new):
    if STATE["name"] is None:
        return "Load a model first.", ""
    model, tok = get_handles(STATE["name"])
    layer, max_new = int(layer), int(max_new)
    emb = model.get_input_embeddings().weight
    def first_emb(w):
        ids = tok.encode(w, add_special_tokens=False)
        return emb[ids[0]] if ids else torch.zeros(emb.shape[-1], device=DEVICE)
    direction = F.normalize((first_emb(target) - first_emb(source)).to(DTYPE), dim=-1)
    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    gk = dict(max_new_tokens=max_new, do_sample=False, pad_token_id=tok.eos_token_id)
    base = tok.decode(model.generate(**inputs, **gk)[0], skip_special_tokens=True)
    blocks = get_blocks(model)
    layer = max(0, min(layer, len(blocks) - 1))
    handle = blocks[layer].register_forward_hook(_make_steer_hook(direction, alpha))
    try:
        steered = tok.decode(model.generate(**inputs, **gk)[0], skip_special_tokens=True)
    finally:
        handle.remove()
    return base, "steer %r -> %r @ L%d alpha=%s\n%s" % (source, target, layer, alpha, steered)
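

# Illustrative sketch (hypothetical helper, not called by the app): the
# arithmetic steer_generate performs, in plain Python. The steering direction
# is the unit vector from the source word's embedding to the target word's,
# and the hook adds alpha times it to every position of one block's output.
# The vectors used with it are toy values, not real embeddings.
import math


def _sketch_steer(residual_rows, e_source, e_target, alpha):
    """Return residual_rows with alpha * unit(e_target - e_source) added to each row."""
    diff = [t - s for s, t in zip(e_source, e_target)]
    n = math.sqrt(sum(d * d for d in diff))
    direction = [d / n for d in diff] if n > 0 else [0.0] * len(diff)
    return [[x + alpha * d for x, d in zip(row, direction)] for row in residual_rows]

# Because the direction has unit length, alpha is exactly how far (in L2 norm)
# each position is pushed, which keeps alpha comparable across word pairs.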


# =============================================================================
# TAB 4 - DIFF: compare two models on one prompt, aligned by relative depth
# =============================================================================
@torch.no_grad()
def diff_models(name_a, name_b, prompt, target, top_k):
    ma, ta = get_handles(name_a.strip())
    mb, tb = get_handles(name_b.strip())
    ida = ta.encode(target, add_special_tokens=False)
    idb = tb.encode(target, add_special_tokens=False)
    if not ida or not idb:
        return "Could not tokenize target %r in both models." % target
    ida, idb = ida[0], idb[0]
    da = layer_distributions(ma, ta, prompt)
    db = layer_distributions(mb, tb, prompt)
    nA, nB = len(da) - 1, len(db) - 1
    def top1(probs, tok):
        v, i = probs.topk(1)
        return "%r:%.2f" % (tok.decode([i.item()]), v.item())
    lines = ["prompt: %r   target: %r" % (prompt, target),
             "%18s | %16s %6s | %16s %6s | %7s"
             % ("depth (A/B)", "A top1", "pA", "B top1", "pB", "dp")]
    for i in range(nA + 1):
        frac = (i / nA) if nA > 0 else 0.0
        j = max(0, min(round(frac * nB), nB)) if nB > 0 else 0
        la, pa = da[i]; lb, pb = db[j]
        a_t, b_t = pa[ida].item(), pb[idb].item()
        lines.append("%18s | %16s %6.3f | %16s %6.3f | %+7.3f"
                     % ("%3.0f%% (%s/%s)" % (frac * 100, la, lb),
                        top1(pa, ta), a_t, top1(pb, tb), b_t, b_t - a_t))
    return "\n".join(lines)
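

# Illustrative sketch (hypothetical helper, not called by the app): the depth
# alignment diff_models uses, isolated. Lens index i of model A (0..n_a) is
# paired with the index of model B at the same fraction of total depth.
# Python's round() rounds halves to the even integer, so a fraction that lands
# exactly on .5 goes to the even layer index rather than always up.
def _sketch_depth_pairs(n_a, n_b):
    pairs = []
    for i in range(n_a + 1):
        frac = i / n_a if n_a > 0 else 0.0
        j = max(0, min(round(frac * n_b), n_b)) if n_b > 0 else 0
        pairs.append((i, j))
    return pairs

# _sketch_depth_pairs(4, 12) -> [(0, 0), (1, 3), (2, 6), (3, 9), (4, 12)]
# _sketch_depth_pairs(4, 2)  -> [(0, 0), (1, 0), (2, 1), (3, 2), (4, 2)]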


# =============================================================================
# TAB 5 - CAUSAL TRACE: corrupt the subject, restore each layer, find the site
# -----------------------------------------------------------------------------
# This is ROME's activation patching. We:
#   1. record clean activations and clean p(target)
#   2. add Gaussian noise (noise_scale x embedding std) to the SUBJECT token
#      embeddings -> corrupted p(target)
#   3. for each layer L: run corrupted, but force layer L's residual back to
#      the clean values at the subject positions. How much p(target) recovers
#      tells you how causally important layer L is. The peak is "the site".
# The glass-box gives a clean, verifiable peak; gpt2 gives a realistic band.
# =============================================================================
def _find_subject_positions(tok, input_ids, prompt, subject):
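    """Return (subject token positions, note) for the prompt.

    Fast tokenizers match the subject by character offsets; slow (non-fast)
    tokenizers match any prompt token that also occurs in the subject's ids.
    The final position is always excluded, since p(target) is read there. If
    nothing matches, fall back to the first half of the prompt."""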
    seq_len = input_ids.shape[1]
    if getattr(tok, "is_fast", False):
        enc = tok(prompt, return_tensors="pt", return_offsets_mapping=True)
        cs = prompt.find(subject)
        if cs >= 0:
            ce = cs + len(subject)
            offs = enc["offset_mapping"][0].tolist()
            pos = [i for i, (s, e) in enumerate(offs) if e > cs and s < ce]
            if pos:
                return [p for p in pos if p != seq_len - 1], ""
    else:
        sub_ids = tok.encode(subject, add_special_tokens=False)
        seq = input_ids[0].tolist()
        pos = [i for i, t in enumerate(seq) if t in sub_ids]
        if pos:
            return [p for p in pos if p != seq_len - 1], ""
    fb = list(range(0, max(1, seq_len - 1)))[: max(1, seq_len // 2)]
    return fb, "(subject not found; using fallback window)\n"
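

# Illustrative sketch (hypothetical helper, not called by the app): the
# fast-tokenizer branch of _find_subject_positions in isolation. A token
# belongs to the subject if its character span (start, end) overlaps the
# subject's span [cs, ce) at all. The offsets in the example are hand-written
# in the style of a byte-level BPE tokenizer; whether the leading space falls
# inside a token's span varies between tokenizers, and the overlap test works
# either way.
def _sketch_overlap_positions(prompt, subject, offsets):
    cs = prompt.find(subject)
    if cs < 0:
        return []
    ce = cs + len(subject)
    return [i for i, (s, e) in enumerate(offsets) if e > cs and s < ce]

# "the capital of france is" tokenised as
#   [(0, 3), (3, 11), (11, 14), (14, 21), (21, 24)]   # the, capital, of, france, is
# gives [3]; a two-token subject such as "new york" returns both token indices.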


@torch.no_grad()
def causal_trace(prompt, subject, target, noise_scale, seed):
    if STATE["name"] is None:
        return "Load a model first."
    model, tok = get_handles(STATE["name"])
    seed, noise_scale = int(seed), float(noise_scale)
    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    input_ids = inputs["input_ids"]
    positions, note = _find_subject_positions(tok, input_ids, prompt, subject)
    if not positions:
        return note + "No valid subject positions."
    target_ids = tok.encode(target, add_special_tokens=False)
    if not target_ids:
        return "Could not tokenize target %r." % target
    tid = target_ids[0]

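    # Record every block's raw output on the clean run. Hooks are used instead
    # of output_hidden_states because many HF models (GPT-2 included) apply the
    # final norm before storing the last hidden state, so hidden_states[-1] is
    # not the last block's output and restoring it would apply the norm twice.
    blocks = get_blocks(model)
    clean_block_out = [None] * len(blocks)

    def _recorder(idx):
        def hook(module, inp, out):
            hs = out[0] if isinstance(out, tuple) else out
            clean_block_out[idx] = hs[0].detach().clone()
        return hook

    rec_handles = [b.register_forward_hook(_recorder(i)) for i, b in enumerate(blocks)]
    try:
        out_clean = model(**inputs)
    finally:
        for rh in rec_handles:
            rh.remove()
    clean_p = F.softmax(out_clean.logits[0, -1].to(DTYPE), dim=-1)[tid].item()

    emb_module = model.get_input_embeddings()
    std = emb_module.weight.std().item()
    hidden = emb_module.weight.shape[-1]
    torch.manual_seed(seed)
    noise = torch.randn(len(positions), hidden, device=DEVICE) * noise_scale * std

    def corrupt_hook(module, inp, out):
        # add the fixed noise at the subject positions only
        out = out.clone()
        for k, p in enumerate(positions):
            out[0, p] = out[0, p] + noise[k].to(device=out.device, dtype=out.dtype)
        return out

    h = emb_module.register_forward_hook(corrupt_hook)
    try:
        corrupt_p = F.softmax(model(**inputs).logits[0, -1].to(DTYPE), dim=-1)[tid].item()
    finally:
        h.remove()

    rows = []
    for l in range(len(blocks)):
        def restore_hook(module, inp, out, _clean=clean_block_out[l]):
            # put block l's clean activations back at the subject positions
            is_tuple = isinstance(out, tuple)
            h0 = (out[0] if is_tuple else out).clone()
            for p in positions:
                h0[0, p] = _clean[p].to(h0.dtype)
            return (h0,) + out[1:] if is_tuple else h0
        h1 = emb_module.register_forward_hook(corrupt_hook)
        h2 = blocks[l].register_forward_hook(restore_hook)
        try:
            p_r = F.softmax(model(**inputs).logits[0, -1].to(DTYPE), dim=-1)[tid].item()
        finally:
            h1.remove(); h2.remove()
        rows.append((l, p_r))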

    denom = clean_p - corrupt_p
    lines = [note + "prompt: %r" % prompt,
             "subject: %r   target: %r" % (subject, target),
             "clean p=%.3f   corrupt p=%.3f   noise=%sx std" % (clean_p, corrupt_p, noise_scale),
             "", "%6s | %9s | %9s" % ("layer", "p(target)", "recovery")]
    best_l, best_r = 0, -1e9
    for l, p_r in rows:
        rec = (p_r - corrupt_p) / denom if abs(denom) > 1e-6 else 0.0
        if rec > best_r:
            best_r, best_l = rec, l
        lines.append("  L%-3d | %9.3f | %8.1f%%" % (l, p_r, rec * 100))
    lines.append("")
    if abs(denom) <= 1e-6:
        # no gap between clean and corrupt p(target): every recovery is 0, so
        # there is no peak worth reporting
        lines.append("# (corruption didn't move p(target): on 'handmade' this is "
                     "EXPECTED - the fact lives in a string match, not activations.)")
    else:
        lines.append("# peak at L%d (%.0f%% recovery) <- the causal site" % (best_l, best_r * 100))
    return "\n".join(lines)
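

# Illustrative sketch (hypothetical helper, not called by the app): the
# corrupt-then-restore sweep of causal_trace on a toy stack of plain-Python
# "layers" (functions from a list of position vectors to a new list). Noise is
# added to the subject positions of the input; for each layer l the corrupted
# run is repeated with layer l's subject positions reset to their clean
# values, and recovery = (p_restored - p_corrupt) / (p_clean - p_corrupt).
def _sketch_patch_sweep(layers, x_clean, subject_pos, noise, readout):
    """noise maps position -> vector added to the input there; readout maps the
    final states to p(target). Returns (clean_p, corrupt_p, recovery_per_layer)."""
    def run(x, restore_layer=None, clean_outs=None):
        outs = []
        for l, layer in enumerate(layers):
            x = layer(x)
            if l == restore_layer:
                x = [clean_outs[l][p] if p in subject_pos else v for p, v in enumerate(x)]
            outs.append(x)
        return readout(x), outs

    clean_p, clean_outs = run(x_clean)
    x_bad = [[a + b for a, b in zip(v, noise[p])] if p in noise else v
             for p, v in enumerate(x_clean)]
    corrupt_p, _ = run(x_bad)
    denom = clean_p - corrupt_p
    recovery = []
    for l in range(len(layers)):
        p_r, _ = run(x_bad, l, clean_outs)
        recovery.append((p_r - corrupt_p) / denom if abs(denom) > 1e-6 else 0.0)
    return clean_p, corrupt_p, recovery

# A toy stack in the spirit of the glass-box (identity, pool earlier positions
# into the last one, identity) peaks at layer 0: after the pooling layer the
# subject positions are never read again, so restoring them there is useless.
#   ident = lambda x: [list(v) for v in x]
#   pool = lambda x: [list(v) for v in x[:-1]] + [[x[-1][0] + sum(v[0] for v in x[:-1])]]
#   _sketch_patch_sweep([ident, pool, ident], [[1.0], [0.0]], {0}, {0: [-1.0]},
#                       lambda x: max(0.0, min(1.0, x[-1][0])))
#   -> (1.0, 0.0, [1.0, 0.0, 0.0])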


# =============================================================================
# EDIT LOOP  +  VERIFICATION HARNESS  (the ROME sandbox)
# -----------------------------------------------------------------------------
# Apply a knowledge edit to the glass-box, then PROVE it was surgical:
#   efficacy    - did the target fact change to the new answer?
#   specificity - did the OTHER facts stay exactly as they were? (locality)
#   fluency     - did the output distribution stay sane? (any fact whose
#                 entropy moves by more than 0.8 bits is flagged)
# Because we own the ground truth, "nothing else broke" is checkable, not vibes.
# An optional pass sends the before/after battery to an LLM judge (Claude, a
# Hugging Face-hosted model, or a local OpenAI-compatible server; the first
# one that responds wins) for an independent verdict.
# =============================================================================
GB_ANSWERS = ["paris", "berlin", "tokyo", "london", "rome"]


@torch.no_grad()
def _probe_battery(model, tok):
    """Probe every known fact ("the capital of X is") and record the top-1
    answer, p(original answer), each candidate answer's prob and the entropy."""
    rows = {}
    for country, orig in GB_FACTS:
        prompt = "the capital of %s is" % country
        probs = F.softmax(model(**tok(prompt, return_tensors="pt").to(DEVICE)
                                ).logits[0, -1].to(DTYPE), dim=-1)
        v, i = probs.topk(1)
        rows[country] = {
            "prompt": prompt, "orig": orig,
            "top1": tok.decode([i.item()]), "top1_p": v.item(),
            "p_orig": probs[GB_STOI[orig]].item(),
            "cand": {a: probs[GB_STOI[a]].item() for a in GB_ANSWERS},
            "entropy": _entropy_bits(probs),
        }
    return rows


def _verdict(before, after, subject, new_answer, drift_thresh=0.05):
    eff = after[subject]["top1"] == new_answer
    collateral, max_drift = [], 0.0
    for c in before:
        if c == subject:
            continue
        d = abs(after[c]["p_orig"] - before[c]["p_orig"])
        max_drift = max(max_drift, d)
        if after[c]["top1"] != before[c]["top1"] or d > drift_thresh:
            collateral.append(c)
    ent_blowup = any(abs(after[c]["entropy"] - before[c]["entropy"]) > 0.8 for c in before)
    surgical = eff and not collateral and not ent_blowup
    return eff, collateral, max_drift, ent_blowup, surgical
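

# Illustrative sketch (hypothetical helper, not called by the app): why a
# key->value fact memory can be edited surgically. It follows the shape of the
# FactMLP in _MODELING_PY, pre = relu(Win @ h - bias); h += gain * (Wout @ pre),
# with one key row and one value column per fact. Rewriting one fact's value
# column changes that fact's readout and nothing else, as long as no other
# subject's hidden state activates that key above the bias. Plain lists stand
# in for tensors; the keys, values, bias and gain are toy numbers, not the
# glass-box's real settings.
def _sketch_fact_memory(keys, values, bias=0.5, gain=1.0):
    """keys[i] / values[i] are fact i's key (row of Win) and value (column of
    Wout). Returns (read, edit) closures over a private copy of the values."""
    win = [list(k) for k in keys]
    wout = [list(v) for v in values]

    def read(h):
        pre = [max(0.0, sum(a * b for a, b in zip(k, h)) - bias) for k in win]
        out = list(h)
        for act, v in zip(pre, wout):
            for d in range(len(out)):
                out[d] += gain * act * v[d]
        return out

    def edit(i, new_value):
        wout[i] = list(new_value)  # a rank-one change: only fact i's column moves

    return read, edit

# read, edit = _sketch_fact_memory([[1, 0, 0, 0], [0, 1, 0, 0]],
#                                  [[0, 0, 1, 0], [0, 0, 0, 1]])
# edit(0, [0, 0, 0, 1]) moves fact 0's readout; read([0, 1, 0, 0]) is unchanged.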


def edit_and_verify(subject, new_answer, method, strength, use_llm,
                    anthropic_key, anthropic_model, hf_token, hf_model,
                    local_url, local_model):
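    model, tok = get_handles("glassbox")
    STATE["name"] = "glassbox"
    # normalise once, so the edit, the probe lookup and the verdict all use the
    # same key (the probe battery is keyed by the subject names in GB_FACTS)
    subject, new_answer = subject.strip().lower(), new_answer.strip().lower()
    valid = "Valid subjects: %s. Valid answers: %s" % (
        ", ".join(c for c, _ in GB_FACTS), ", ".join(GB_ANSWERS))
    model.reset()
    before = _probe_battery(model, tok)
    if subject not in before:
        return "Edit failed: unknown subject %r\n%s" % (subject, valid)
    try:
        model.edit_fact(subject, new_answer, method, float(strength))
    except ValueError as e:
        return "Edit failed: %s\n%s" % (e, valid)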
    after = _probe_battery(model, tok)
    eff, collateral, max_drift, ent, surgical = _verdict(before, after, subject, new_answer)

    L = ["EDIT: %s's capital -> %r   (method=%s, strength=%s)" %
         (subject, new_answer, method, strength), "",
         "%-9s | %-22s | %-22s" % ("fact", "before (top1 / p_orig)", "after (top1 / p_orig)"),
         "-" * 60]
    for c in before:
        b, a = before[c], after[c]
        flag = "  <- TARGET" if c == subject else ("  <- COLLATERAL" if c in collateral else "")
        L.append("%-9s | %-22s | %-22s%s" % (
            c, "%s / %.2f" % (b["top1"], b["p_orig"]),
            "%s / %.2f" % (a["top1"], a["p_orig"]), flag))
    L += ["",
          "efficacy    : %s (target now says %r, p=%.2f)" %
          ("PASS" if eff else "FAIL", after[subject]["top1"], after[subject]["top1_p"]),
          "specificity : %s (max drift on other facts = %.3f%s)" %
          ("PASS" if not collateral else "FAIL: " + ", ".join(collateral),
           max_drift, "; entropy spike" if ent else ""),
          "", "VERDICT: %s" % ("SURGICAL EDIT" if surgical else "COLLATERAL DAMAGE")]
    L.append("(model is left in the edited state - inspect it in tabs 1-5, or hit Reset.)")

    llm_report = ""
    if use_llm:
        providers = [
            {"type": "anthropic", "key": anthropic_key, "model": anthropic_model},
            {"type": "hf",        "key": hf_token,       "model": hf_model},
            {"type": "local",     "url": local_url,      "model": local_model},
        ]
        llm_report = _llm_judge_chain(before, after, subject, new_answer, providers)
        L += ["", "-" * 60, "INDEPENDENT LLM REVIEW:", llm_report]

    report = "\n".join(L)
    _log_session(subject, new_answer, method, strength, before, after,
                eff, collateral, max_drift, surgical, llm_report)
    return report


def reset_glassbox():
    model, _ = get_handles("glassbox")
    model.reset()
    return "Glass-box weights restored to pristine. Re-run any tab to confirm."


# --- optional: real LLM calls to verify the edit, with a 3-tier fallback chain
# Anthropic (Claude) -> Hugging Face Inference -> local OpenAI-compatible server
# (e.g. LM Studio). Tries each in order; the first provider that's configured
# AND reachable wins. This means you're never blocked on one vendor being down
# or on not having an Anthropic key at all - your own RTX 5090 can be the judge.
def _build_judge_prompt(before, after, subject, new_answer):
    import json
    payload = {c: {"prompt": before[c]["prompt"],
                   "before_top1": before[c]["top1"], "before_p_orig": round(before[c]["p_orig"], 3),
                   "after_top1": after[c]["top1"],  "after_p_orig": round(after[c]["p_orig"], 3)}
               for c in before}
    sys = ("You audit knowledge edits to a small language model. The intended edit "
           "is: make %s's capital '%s'. Given before/after predictions for every "
           "known fact, decide if the edit was SURGICAL (target changed, all other "
           "facts unchanged) or caused COLLATERAL damage. Reply ONLY as JSON, no "
           'prose, no markdown fences: {"verdict":"surgical|collateral",'
           '"target_changed":bool,"damaged_facts":[...],"confidence":0-1,'
           '"reason":"one sentence"}.') % (subject, new_answer)
    return sys, json.dumps(payload)


def _parse_verdict_json(text, provider_label):
    import json
    clean = text.strip().strip("`")
    if clean.lower().startswith("json"):
        clean = clean[4:].strip()
    start, end = clean.find("{"), clean.rfind("}")
    if start != -1 and end != -1:
        clean = clean[start:end + 1]
    v = json.loads(clean)
    return ("[%s] verdict=%s  target_changed=%s  confidence=%s\n  damaged: %s\n  reason: %s"
            % (provider_label, v.get("verdict"), v.get("target_changed"), v.get("confidence"),
               v.get("damaged_facts") or "none", v.get("reason")))
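

# Illustrative sketch (hypothetical helper, not called by the app):
# _parse_verdict_json accepts whatever keys the judge returns. A stricter check
# against the schema requested in _build_judge_prompt could look like this;
# how to wire it in (e.g. raising on a non-empty result so the chain falls
# through to the next provider) is an assumption, not existing behaviour.
def _sketch_validate_verdict(v):
    """Return a list of problems with a parsed verdict dict (empty = valid)."""
    problems = []
    if v.get("verdict") not in ("surgical", "collateral"):
        problems.append("verdict must be 'surgical' or 'collateral'")
    if not isinstance(v.get("target_changed"), bool):
        problems.append("target_changed must be a bool")
    if not isinstance(v.get("damaged_facts"), list):
        problems.append("damaged_facts must be a list")
    conf = v.get("confidence")
    # bool is a subclass of int in Python, so reject it explicitly
    if isinstance(conf, bool) or not isinstance(conf, (int, float)) or not 0 <= conf <= 1:
        problems.append("confidence must be a number in [0, 1]")
    if not isinstance(v.get("reason"), str):
        problems.append("reason must be a string")
    return problems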


def _try_anthropic(sys, user, cfg):
    import os, json
    key = (cfg.get("key") or "").strip() or os.environ.get("ANTHROPIC_API_KEY", "")
    if not key:
        return None, "anthropic: no key configured"
    body = {"model": (cfg.get("model") or "claude-sonnet-4-6").strip(),
            "max_tokens": 400, "system": sys, "messages": [{"role": "user", "content": user}]}
    try:
        try:
            import anthropic
            client = anthropic.Anthropic(api_key=key)
            msg = client.messages.create(**body)
            text = "".join(b.text for b in msg.content if getattr(b, "type", "") == "text")
        except ImportError:
            import urllib.request
            req = urllib.request.Request(
                "https://api.anthropic.com/v1/messages", data=json.dumps(body).encode(),
                headers={"x-api-key": key, "anthropic-version": "2023-06-01",
                         "content-type": "application/json"})
            with urllib.request.urlopen(req, timeout=30) as r:
                data = json.loads(r.read())
            text = "".join(b.get("text", "") for b in data.get("content", [])
                           if b.get("type") == "text")
        return _parse_verdict_json(text, "anthropic:" + body["model"]), None
    except Exception as e:
        return None, "anthropic failed: %s" % e


def _try_hf(sys, user, cfg):
    token = (cfg.get("key") or "").strip()
    model = (cfg.get("model") or "Qwen/Qwen2.5-7B-Instruct").strip()
    if not token:
        import os
        token = os.environ.get("HF_TOKEN", "")
    if not token:
        return None, "hf: no token configured"
    try:
        from huggingface_hub import InferenceClient
        client = InferenceClient(model=model, token=token)
        resp = client.chat_completion(
            messages=[{"role": "system", "content": sys}, {"role": "user", "content": user}],
            max_tokens=400)
        text = resp.choices[0].message.content
        return _parse_verdict_json(text, "hf:" + model), None
    except Exception as e:
        return None, "hf failed: %s" % e


def _try_local(sys, user, cfg):
    """Any OpenAI-compatible /v1/chat/completions server - LM Studio, vLLM,
    Ollama (with its OpenAI shim), text-generation-webui, etc."""
    import json, urllib.request
    url = (cfg.get("url") or "").strip().rstrip("/")
    if url.endswith("/v1"):  # accept both http://host:port and http://host:port/v1
        url = url[:-3]
    if not url:
        return None, "local: no URL configured"
    model = (cfg.get("model") or "local-model").strip()
    body = json.dumps({"model": model, "max_tokens": 400, "temperature": 0,
                       "messages": [{"role": "system", "content": sys},
                                   {"role": "user", "content": user}]}).encode()
    try:
        req = urllib.request.Request(
            url + "/v1/chat/completions", data=body,
            headers={"content-type": "application/json"})
        with urllib.request.urlopen(req, timeout=20) as r:
            data = json.loads(r.read())
        text = data["choices"][0]["message"]["content"]
        return _parse_verdict_json(text, "local:" + model + "@" + url), None
    except Exception as e:
        return None, "local failed: %s" % e


def _llm_judge_chain(before, after, subject, new_answer, providers):
    sys, user = _build_judge_prompt(before, after, subject, new_answer)
    dispatch = {"anthropic": _try_anthropic, "hf": _try_hf, "local": _try_local}
    skipped = []
    for cfg in providers:
        fn = dispatch.get(cfg["type"])
        if fn is None:
            continue
        result, err = fn(sys, user, cfg)
        if result is not None:
            note = ("" if not skipped else
                    "(skipped: %s)\n" % "; ".join(skipped))
            return note + result
        skipped.append(err)
    return ("all providers unavailable:\n  " + "\n  ".join(skipped) +
            "\n(configure at least one: Anthropic key, HF token, or a local "
            "OpenAI-compatible server URL like http://192.168.188.25:1234)")
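

# Illustrative sketch (hypothetical helper, not called by the app): a generic
# form of the first-success fallback in _llm_judge_chain. Providers are
# (name, fn) pairs whose fn returns (result, error), the same contract the
# _try_* helpers follow. Unlike the original loop, a provider that raises is
# also recorded as a failure, so one buggy provider cannot break the chain.
def _sketch_first_success(providers, *args):
    """Return (result, skipped_errors) from the first provider that succeeds,
    or (None, errors) if every provider fails."""
    skipped = []
    for name, fn in providers:
        try:
            result, err = fn(*args)
        except Exception as exc:
            result, err = None, "%s raised %s" % (name, exc)
        if result is not None:
            return result, skipped
        skipped.append(err or "%s: no result" % name)
    return None, skipped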


# --- session log: every edit+verify run is appended here as JSON, so you can
# download it, or paste the markdown block straight into a future chat with
# Claude for review ("did all work, here's the log").
SESSION_LOG = []


def _log_session(subject, new_answer, method, strength, before, after,
                 eff, collateral, max_drift, surgical, llm_report):
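    import datetime
    SESSION_LOG.append({
        # timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12
        "ts": datetime.datetime.now(datetime.timezone.utc).isoformat().replace("+00:00", "Z"),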
        "subject": subject, "new_answer": new_answer, "method": method,
        "strength": strength, "efficacy_pass": bool(eff),
        "collateral": collateral, "max_drift": round(max_drift, 4),
        "verdict": "SURGICAL" if surgical else "COLLATERAL",
        "before": {c: {"top1": before[c]["top1"], "p_orig": round(before[c]["p_orig"], 4)}
                  for c in before},
        "after": {c: {"top1": after[c]["top1"], "p_orig": round(after[c]["p_orig"], 4)}
                 for c in after},
        "llm_review": llm_report or None,
    })


def export_session_log():
    import json, os
    if not SESSION_LOG:
        return None, "No edits run yet this session - nothing to export."
    os.makedirs("/mnt/user-data/outputs", exist_ok=True)
    path = "/mnt/user-data/outputs/edit_session_log.json"
    with open(path, "w") as f:
        json.dump(SESSION_LOG, f, indent=2)
    # also a markdown rendition meant to be pasted straight into a chat
    md = ["# Edit session log\n"]
    for i, e in enumerate(SESSION_LOG, 1):
        md.append("## Edit %d - %s (%s, %s, strength=%s)\n" %
                  (i, e["verdict"], e["subject"] + "->" + e["new_answer"],
                   e["method"], e["strength"]))
        md.append("- efficacy: %s, max collateral drift: %.4f, damaged: %s" %
                  ("pass" if e["efficacy_pass"] else "fail", e["max_drift"],
                   e["collateral"] or "none"))
        if e["llm_review"]:
            md.append("- LLM review: " + e["llm_review"].replace("\n", " "))
        md.append("")
    md_path = "/mnt/user-data/outputs/edit_session_log.md"
    with open(md_path, "w") as f:
        f.write("\n".join(md))
    return path, "Wrote %d edit(s) to %s and %s" % (len(SESSION_LOG), path, md_path)


# =============================================================================
# EXPORT  +  UPLOAD TO HUGGING FACE
# -----------------------------------------------------------------------------
# Save the glass-box as a self-contained, reloadable repo (weights + config +
# vocab + a standalone modeling file + a model card), and optionally push it -
# and/or this whole app as a Space - to the Hub.
# =============================================================================
_MODELING_PY = '''"""Standalone glass-box model - reload with no other files.

    from modeling_glassbox import load
    m, tok = load(".")            # folder containing config/weights/vocab
    print(tok.decode(m.generate(tok("the capital of france is")["input_ids"])[0]))
"""
import json, torch, torch.nn as nn, torch.nn.functional as F
from safetensors.torch import load_file

def load(path="."):
    cfg = json.load(open(f"{path}/config.json"))
    stoi = json.load(open(f"{path}/vocab.json")); itos = {i: w for w, i in stoi.items()}
    D, V = cfg["d_model"], len(stoi); facts = [tuple(f) for f in cfg["facts"]]
    SUBJ = torch.zeros(D, D)
    for d in range(cfg["subject_dims"]): SUBJ[d, d] = 1.0

    class Tok:
        is_fast = False
        def __init__(s): s.eos_token_id = stoi["."]
        def _ids(s, t): return [stoi.get(w, stoi["<s>"]) for w in t.lower().replace(".", " .").split()] or [stoi["<s>"]]
        def __call__(s, t, **k):
            import torch as T; return {"input_ids": T.tensor([s._ids(t)])}
        def decode(s, ids, **k): return " ".join(itos.get(int(i), "?") for i in ids)
    class Ident(nn.Module):
        def forward(s, x): return (x.clone(),)
    class Pool(nn.Module):
        def forward(s, x):
            o = x.clone()
            if x.shape[1] > 1: o[0, -1] = o[0, -1] + 0.9 * (x[0, :-1] @ SUBJ.T).sum(0)
            return (o,)
    class FactMLP(nn.Module):
        def __init__(s):
            super().__init__()
            s.register_buffer("Win", torch.zeros(len(facts), D))
            s.register_buffer("Wout", torch.zeros(D, len(facts)))
            s.bias, s.gain = cfg["bias"], cfg["gain"]
        def forward(s, x):
            o = x.clone(); pre = F.relu(s.Win @ o[0, -1] - s.bias)
            o[0, -1] = o[0, -1] + s.gain * (s.Wout @ pre); return (o,)
    class T(nn.Module):
        def __init__(s):
            super().__init__(); s.wte = nn.Embedding(V, D)
            s.h = nn.ModuleList([Ident(), Pool(), FactMLP(), Ident()]); s.ln_f = nn.Identity()
    class GlassBox(nn.Module):
        def __init__(s):
            super().__init__(); s.transformer = T(); s.head = nn.Linear(D, V, bias=False)
        def get_input_embeddings(s): return s.transformer.wte
        def forward(s, input_ids=None, **k):
            x = s.transformer.wte(input_ids)
            for b in s.transformer.h: (x,) = b(x)
            class O: pass
            o = O(); o.logits = s.head(x); return o
        @torch.no_grad()
        def generate(s, input_ids=None, max_new_tokens=12, **k):
            ids = input_ids
            for _ in range(max_new_tokens):
                ids = torch.cat([ids, s(input_ids=ids).logits[0, -1].argmax().view(1, 1)], 1)
            return ids
    m = GlassBox().eval()
    sd = load_file(f"{path}/model.safetensors")
    m.load_state_dict({k: v for k, v in sd.items() if not k.endswith("0")}, strict=False)
    return m, Tok()
'''


def export_glassbox(outdir="glassbox_export"):
    import os, json
    from safetensors.torch import save_file
    os.makedirs(outdir, exist_ok=True)
    model, _ = get_handles("glassbox")
    sd = {k: v.contiguous() for k, v in model.state_dict().items()}
    save_file(sd, os.path.join(outdir, "model.safetensors"))
    json.dump({"model_type": "glassbox", "d_model": GB_D, "vocab_size": GB_V,
               "subject_dims": 9, "bias": model.transformer.h[2].bias,
               "gain": model.transformer.h[2].gain,
               "facts": [list(f) for f in GB_FACTS]},
              open(os.path.join(outdir, "config.json"), "w"), indent=2)
    json.dump(GB_STOI, open(os.path.join(outdir, "vocab.json"), "w"), indent=2)
    open(os.path.join(outdir, "modeling_glassbox.py"), "w").write(_MODELING_PY)
    open(os.path.join(outdir, "README.md"), "w").write(
        "---\nlicense: mit\ntags: [interpretability, glass-box, rome, toy-model]\n---\n\n"
        "# Glass-box interpretability model\n\n"
        "A tiny transformer-shaped model whose facts are stored as key->value "
        "writes into the residual stream, so logit-lens, activation steering and "
        "ROME causal tracing all reproduce the *known* ground truth. Built as a "
        "verification harness for interpretability code.\n\n"
        "```python\nfrom modeling_glassbox import load\n"
        "m, tok = load('.')\n"
        "print(tok.decode(m.generate(tok('the capital of france is')['input_ids'])[0]))\n```\n\n"
        "Facts: " + ", ".join("%s->%s" % f for f in GB_FACTS) + ".\n")
    return outdir


def upload_to_hf(repo_id, token, what, app_path=__file__):
    """Push the model and/or this app (as a Space) to the Hub."""
    import os
    try:
        from huggingface_hub import HfApi
    except ImportError:
        return "huggingface_hub not installed. `pip install huggingface_hub`."
    token = (token or "").strip() or os.environ.get("HF_TOKEN", "")
    if not token:
        return "No HF token. Paste a write token or set HF_TOKEN."
    if not repo_id.strip():
        return "Enter a repo id like 'Chris4K/glassbox-interp'."
    api, logs = HfApi(token=token), []
    try:
        if what in ("model", "both"):
            d = export_glassbox()
            api.create_repo(repo_id, repo_type="model", exist_ok=True)
            api.upload_folder(folder_path=d, repo_id=repo_id, repo_type="model")
            logs.append("model -> https://huggingface.co/%s" % repo_id)
        if what in ("space", "both"):
            sid = repo_id + "-space" if what == "both" else repo_id
            api.create_repo(sid, repo_type="space", space_sdk="gradio", exist_ok=True)
            api.upload_file(path_or_fileobj=app_path, path_in_repo="app.py",
                            repo_id=sid, repo_type="space")
            req = "torch\ntransformers\ngradio\nsafetensors\nhuggingface_hub\nanthropic\n"
            api.upload_file(path_or_fileobj=req.encode(), path_in_repo="requirements.txt",
                            repo_id=sid, repo_type="space")
            logs.append("space -> https://huggingface.co/spaces/%s" % sid)
        return "Uploaded:\n  " + "\n  ".join(logs)
    except Exception as e:
        return "Upload failed: %s" % e


# --- upload a REAL model (e.g. a VINDEX-edited Llama checkpoint), not the toy.
# This does NOT load the model into memory (multi-GB Llama weights don't need
# to round-trip through Python) - it just pushes whatever's already on disk.
# Point it at the local folder produced by your save_pretrained()/VINDEX run:
# expects the usual HF layout (config.json + .safetensors shards + tokenizer
# files). Uploading needs write access to the destination repo. If the weights
# derive from a gated model (e.g. meta-llama/*), its license still governs
# redistribution; this upload step neither checks nor enforces it.
def upload_local_checkpoint(local_dir, repo_id, token, private, commit_message):
    import os
    try:
        from huggingface_hub import HfApi
    except ImportError:
        return "huggingface_hub not installed. `pip install huggingface_hub`."
    local_dir = (local_dir or "").strip()
    repo_id = (repo_id or "").strip()
    if not local_dir or not os.path.isdir(local_dir):
        return "local_dir %r does not exist or is not a directory." % local_dir
    if not repo_id:
        return "Enter a repo id like 'Chris4K/vindex-llama3-edited'."
    token = (token or "").strip() or os.environ.get("HF_TOKEN", "")
    if not token:
        return "No HF token. Paste a write token or set HF_TOKEN."
    has_cfg = os.path.exists(os.path.join(local_dir, "config.json"))
    has_weights = any(f.endswith((".safetensors", ".bin"))
                      for f in os.listdir(local_dir))
    warn = "" if (has_cfg and has_weights) else (
        "WARNING: folder is missing config.json or weight files - this may "
        "not be a loadable HF checkpoint. Uploading anyway.\n")
    api = HfApi(token=token)
    try:
        api.create_repo(repo_id, repo_type="model", private=bool(private), exist_ok=True)
        api.upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type="model",
                          commit_message=(commit_message or "upload checkpoint").strip())
        return (warn + "Uploaded %s -> https://huggingface.co/%s\n"
                "Files: %s" % (local_dir, repo_id, ", ".join(sorted(os.listdir(local_dir))[:12])))
    except Exception as e:
        return warn + "Upload failed: %s" % e
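

# Illustrative sketch (hypothetical helper, not called by the app):
# upload_local_checkpoint only checks that some weight file exists. For a
# sharded checkpoint, a stricter pre-flight check could compare the shard names
# listed in the index file against what is actually on disk. Assumes the usual
# HF index layout: model.safetensors.index.json (or pytorch_model.bin.index.json
# for .bin shards) holding a "weight_map" from tensor name to shard file.
import json
import os


def _sketch_missing_shards(local_dir):
    """Return shard files named in the checkpoint index but missing on disk."""
    for index_name in ("model.safetensors.index.json", "pytorch_model.bin.index.json"):
        index_path = os.path.join(local_dir, index_name)
        if os.path.exists(index_path):
            with open(index_path) as f:
                weight_map = json.load(f).get("weight_map", {})
            needed = sorted(set(weight_map.values()))
            return [s for s in needed if not os.path.exists(os.path.join(local_dir, s))]
    return []  # no index file: a single-file checkpoint (or not a checkpoint at all)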


# =============================================================================
# UI
# =============================================================================
INTRO = """
# Compression Navigator
**An LLM is a lossy codec for text.** Training compresses a corpus into weights;
a forward pass decompresses a continuation. These five tools let you watch that
decompression and find where facts physically live.

Each tab is a real interpretability technique: **logit lens, embedding
neighbours, activation steering, cross-model diff, and causal tracing (ROME).**

### Three models, on purpose
| name | how it stores facts | what it teaches |
|---|---|---|
| **`glassbox`** | key→value writes into the **residual stream** (like a real transformer / what ROME edits) | the tools **work and are verifiable** against ground truth you can read in the source |
| **`handmade`** | a **lookup table** keyed on the prompt string (a side channel) | a model can be **invisible** to residual-stream interpretability — a real limitation |
| **`gpt2`** | learned, fuzzy, **distributed** over many layers | what the real, messy thing looks like |

**Suggested order:** load `glassbox` first (see "correct"), then `handmade`
(see a failure mode), then `gpt2` (see reality). Type a name below and Load.
"""

with gr.Blocks(title="Compression Navigator") as demo:
    gr.Markdown(INTRO)
    with gr.Row():
        model_name = gr.Textbox(value="glassbox", label="model name or HF id")
        load_btn = gr.Button("Load", variant="primary")
    load_status = gr.Markdown()
    load_btn.click(load_model, inputs=model_name, outputs=load_status)

    # ---- TAB 1 -------------------------------------------------------------
    with gr.Tab("1 · Decompress (logit lens)"):
        gr.Markdown("""
### Logit lens — watch the answer condense, layer by layer
**What it does:** takes the last-token residual at *every* layer and reads it
through the unembedding, as if the model had to answer right there. You see the
prediction form.

**How to read it:** each row is a layer. Watch your tracked token's probability
(right column) climb, and watch **entropy** (bits) fall as the model commits.

**Ground truth to check:**
- `glassbox` — `paris` is ~0 until **L3** (the readout right after the fact-MLP), then jumps to ~0.51. Sharp and localised because you put it there.
- `handmade` — the answer snaps to 1.00 at **L1** with zero build-up (it's a lookup, not a computation).
- `gpt2` — the answer accretes *gradually* across many middle/late layers. That smear is what "distributed representation" actually looks like.

*(Numbering note: lens row `L0` is the raw embedding, so lens `L1` is the output of the first block. The causal-trace tab numbers the blocks themselves from `L0`. So the fact-MLP is lens-`L3` / trace-block-`L2`, and its causal site shows up at trace-`L0`.)*
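
A minimal NumPy sketch of the lens computation. It assumes you already have the last-token residual after the embedding and after each block, stacked into one array. `U` is the unembedding matrix, and `final_norm` stands in for the model's final layer norm (GPT-2 applies `ln_f` before its unembedding). The function name is illustrative; this is not this app's `logit_lens`.

```python
import numpy as np


def logit_lens(residuals, U, final_norm=None):
    # residuals: (n_rows, d) last-token residual; row 0 = embedding, row i = after block i-1
    # U: (vocab_size, d) unembedding matrix
    # final_norm: optional callable standing in for the model's final layer norm
    rows = []
    for h in residuals:
        if final_norm is not None:
            h = final_norm(h)
        logits = U @ h                      # read this layer as if it were the last one
        p = np.exp(logits - logits.max())   # numerically stable softmax
        p /= p.sum()
        entropy_bits = float(-(p * np.log2(p + 1e-12)).sum())
        rows.append((p, entropy_bits))
    return rows


# Toy check: the residual drifts toward token 2 with depth, so entropy should fall.
U = np.eye(3)
residuals = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0], [0.0, 0.0, 8.0]])
rows = logit_lens(residuals, U)
for i, (p, h_bits) in enumerate(rows):
    print("L%d  p(tok2)=%.3f  entropy=%.2f bits" % (i, p[2], h_bits))
```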
""")
        ll_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        with gr.Row():
            ll_k = gr.Slider(1, 10, value=3, step=1, label="top-k per layer")
            ll_track = gr.Textbox(value="paris", label="track this token's prob")
        ll_out = gr.Textbox(label="output", lines=18)
        gr.Button("Run").click(logit_lens, [ll_prompt, ll_k, ll_track], ll_out)

    # ---- TAB 2 -------------------------------------------------------------
    with gr.Tab("2 · Triangulate (neighbours)"):
        gr.Markdown("""
### Neighbours — the geometry of the vocabulary
**What it does:** ranks tokens by cosine similarity of their unembedding rows.
Directions that point the same way are "near" in the model's compressed space.

**How to read it:** high cosine = the model treats these tokens as related.

**Ground truth to check:**
- `glassbox` — `paris` is near `france` (cos ≈ 0.48): the source deliberately makes a capital share a dimension with its country. Real geometry, by design.
- `handmade` — **every** cosine is 0. One-hot embeddings are mutually orthogonal, so there's no geometry at all. The tool is correctly reporting "nothing here."
- `gpt2` — neighbours are messy but meaningful (casing variants, plurals, semantic kin).
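
The ranking itself is a few lines of linear algebra. Below is a minimal NumPy sketch. It assumes `U` is an unembedding matrix with one row per token and `vocab` maps token strings to row indices. Both names, and the tiny hand-built matrix, are illustrative.

```python
import numpy as np


def neighbours(U, vocab, word, k=10):
    # Rank every other token by cosine similarity between unembedding rows.
    inv = {i: w for w, i in vocab.items()}
    q = U[vocab[word]]
    norms = np.linalg.norm(U, axis=1) * np.linalg.norm(q)
    cos = (U @ q) / np.maximum(norms, 1e-12)   # guard against all-zero rows
    order = np.argsort(-cos, kind="stable")
    return [(inv[i], float(cos[i])) for i in order if i != vocab[word]][:k]


vocab = {"france": 0, "paris": 1, "tokyo": 2, "japan": 3}
# Hand-built geometry: paris and tokyo share a 'capital' dimension,
# and each also shares a dimension with its country.
U_geo = np.array([[1.0, 0.0, 0.0],
                  [0.5, 0.5, 0.0],
                  [0.0, 0.5, 0.5],
                  [0.0, 0.0, 1.0]])
geo = neighbours(U_geo, vocab, "paris", k=3)
# One-hot rows (as in `handmade`) are mutually orthogonal, so every cosine is 0.
flat = neighbours(np.eye(4), vocab, "paris", k=3)
print(geo)
print(flat)
```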
""")
        nb_word = gr.Textbox(value="paris", label="word")
        nb_k = gr.Slider(5, 25, value=10, step=1, label="top neighbours")
        nb_out = gr.Textbox(label="output", lines=15)
        gr.Button("Run").click(neighbors, [nb_word, nb_k], nb_out)

    # ---- TAB 3 -------------------------------------------------------------
    with gr.Tab("3 · Re-route (steering)"):
        gr.Markdown("""
### Steering — bend behaviour with a direction, no retraining
**What it does:** builds the vector `emb(target) − emb(source)`, scales it by
**strength**, and *adds* it to a layer's output during generation. The model
drifts from `source` toward `target`. This is the cheap cousin of fine-tuning
(ActAdd / representation engineering).

**How to read it:** compare *baseline* vs *steered*. Raise **strength** until the
output flips; too high and it turns to noise (you've knocked the residual off
the manifold).

**Tips:** on `gpt2` try `from: Paris  to: London` on the France prompt, layer
0–4, strength 6–14. On `glassbox` it works cleanly too — `from: france
to: japan` at layer 0, strength 8, flips the output from `paris` to `tokyo`
(you're pushing the residual along the subject→subject direction the fact-MLP
keys on).
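
The mechanics fit in a toy. This NumPy sketch makes several assumptions:

- tied embeddings, so logits are `E @ h`;
- a hand-built two-layer model whose second layer maps subject directions to answer directions;
- a hypothetical `forward` helper.

It illustrates ActAdd-style steering and mirrors the `france → japan` at layer 0 recipe above. It is not this app's `steer_generate`.

```python
import numpy as np

vocab = ["france", "japan", "paris", "tokyo"]
E = np.eye(len(vocab))                    # toy one-hot embeddings, tied to the unembedding
M = np.zeros((4, 4))
M[2, 0] = 1.0                             # fact: france -> paris
M[3, 1] = 1.0                             # fact: japan  -> tokyo
layers = [lambda h: h,                    # layer 0: pass-through
          lambda h: h + 2.0 * (M @ h)]    # layer 1: toy fact-MLP writes the answer


def forward(h, steer_layer=None, vec=None, alpha=0.0):
    for i, layer in enumerate(layers):
        h = layer(h)
        if i == steer_layer:
            h = h + alpha * vec           # add the scaled steering direction to this output
    return vocab[int(np.argmax(E @ h))]


h0 = E[vocab.index("france")]
vec = E[vocab.index("japan")] - E[vocab.index("france")]   # emb(target) - emb(source)
print("baseline:", forward(h0))                             # paris
print("steered: ", forward(h0, steer_layer=0, vec=vec, alpha=1.0))   # tokyo
```

Steering before the fact layer works here because the fact-MLP keys on the subject direction. The same vector added after layer 1 would only nudge the logits toward `japan`, not `tokyo`.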
""")
        st_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        with gr.Row():
            st_src = gr.Textbox(value="Paris", label="from")
            st_tgt = gr.Textbox(value="London", label="to")
        with gr.Row():
            st_layer = gr.Slider(0, 11, value=2, step=1, label="layer")
            st_alpha = gr.Slider(0, 30, value=10, step=0.5, label="strength")
            st_max = gr.Slider(8, 80, value=40, step=1, label="max new tokens")
        st_base = gr.Textbox(label="baseline", lines=2)
        st_out = gr.Textbox(label="steered", lines=3)
        gr.Button("Run").click(steer_generate,
                               [st_prompt, st_src, st_tgt, st_layer, st_alpha, st_max],
                               [st_base, st_out])

    # ---- TAB 4 -------------------------------------------------------------
    with gr.Tab("4 · Diff (align by depth)"):
        gr.Markdown("""
### Diff — two models on one prompt, aligned by *relative* depth
**What it does:** runs the logit lens on model A and model B and lines their
layers up by percentage depth (0–100%), so you can compare a small toy with the
12-layer gpt2 side by side. `dp` is `p_B − p_A` for the target token.

**How to read it:** look at *where* on the depth axis each model commits to the
target. A localised model commits at one depth; a distributed one ramps up.

**Try:** A = `gpt2`, B = `glassbox`, target = `paris`. You'll see gpt2 ramp
through the middle while glassbox snaps on at its fact layer — the same fact,
two very different internal shapes.
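
Here is one way to do the alignment, as a hedged sketch. Map each lens row to a fraction of the model's depth. Then, at each point on a shared percentage grid, compare the nearest rows. The helper names and the grid are assumptions, not necessarily how `diff_models` does it. The probabilities are made up.

```python
def nearest_row(pct, n_rows):
    # Lens row 0 is the embedding, row n_rows - 1 the final block; round half up.
    return int(pct * (n_rows - 1) / 100 + 0.5)


def align_by_depth(p_a, p_b, grid=(0, 25, 50, 75, 100)):
    # p_a, p_b: per-lens-row probability of the target token for models A and B
    rows = []
    for pct in grid:
        ia, ib = nearest_row(pct, len(p_a)), nearest_row(pct, len(p_b))
        rows.append((pct, ia, ib, p_b[ib] - p_a[ia]))   # dp = p_B - p_A
    return rows


# Made-up numbers for illustration only: a 13-row gradual ramp vs a 4-row sharp snap.
p_ramp = [0.0, 0.0, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2, 0.3, 0.35, 0.4, 0.42, 0.45]
p_snap = [0.0, 0.0, 0.0, 0.51]
for pct, ia, ib, dp in align_by_depth(p_ramp, p_snap):
    print("%3d%%  A:L%-2d  B:L%d  dp=%+.2f" % (pct, ia, ib, dp))
```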
""")
        with gr.Row():
            df_a = gr.Textbox(value="gpt2", label="model A")
            df_b = gr.Textbox(value="glassbox", label="model B")
        df_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        df_target = gr.Textbox(value="paris", label="target token")
        df_k = gr.Slider(1, 5, value=1, step=1, label="top-k (display)")
        df_out = gr.Textbox(label="output", lines=16)
        gr.Button("Run").click(diff_models,
                               [df_a, df_b, df_prompt, df_target, df_k], df_out)

    # ---- TAB 5 -------------------------------------------------------------
    with gr.Tab("5 · Causal trace (ROME)"):
        gr.Markdown("""
### Causal trace — corrupt the subject, restore each layer, find the site
**What it does:** activation patching, the causal-tracing method from Meng et
al.'s ROME paper. It adds noise to the **subject** token's embedding, which breaks
the prediction, then restores one layer at a time and measures how much of the
answer comes back. The layer that restores the most is where the fact is
*causally* computed.

**How to read it:** `recovery` ≈ 100% means "restoring this layer is enough" →
the fact is read here. The peak line names the site.

**Ground truth to check:**
- `glassbox`: peak at **L0** (≈100%). The fact is read at the early subject site, because the L1 "attention" re-reads the restored subject. You know this is right because you wrote the mechanism.
- `handmade`: `clean p` ≈ `corrupt p`, so recovery is meaningless. **Expected:** the fact is a string match, untouched by activation noise. This is the headline lesson: patching can't see lookup behaviour.
- `gpt2`: a *band* of early-to-middle layers at the subject token lights up, as reported in the ROME paper.
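
The patching loop is easiest to see on a two-position toy. This NumPy sketch uses a hypothetical model: a per-position MLP in every layer, plus a single toy "attention" step at layer 1 where the last position reads the subject position. That loosely mirrors the `glassbox` description above. Recovery is computed as `(p_restored - p_corrupt) / (p_clean - p_corrupt)`, a common normalisation that may differ from this app's exact formula.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_layers, vocab_size = 6, 3, 5
W = rng.normal(scale=0.5, size=(n_layers, d, d))
U = rng.normal(size=(vocab_size, d))
emb = rng.normal(size=(2, d))            # row 0 = subject token, row 1 = last token


def run(x, restore=None):
    # restore: optional (layer, clean_states); puts the clean subject state back after that layer
    h, states = x.copy(), []
    for layer in range(n_layers):
        h = h + np.tanh(h @ W[layer])     # per-position MLP: rows never interact here
        if layer == 1:
            h[1] = h[1] + h[0]            # toy "attention": the last token reads the subject once
        if restore is not None and restore[0] == layer:
            h[0] = restore[1][layer][0]
        states.append(h.copy())
    logits = U @ h[1]                     # predict from the last token only
    p = np.exp(logits - logits.max())
    return p / p.sum(), states


p_clean, clean_states = run(emb)
target = int(np.argmax(p_clean))
noisy = emb.copy()
noisy[0] += 3.0 * emb.std() * rng.normal(size=d)   # corrupt the subject embedding only
p_corrupt, _ = run(noisy)

recovery = []
for layer in range(n_layers):
    p_restored, _ = run(noisy, restore=(layer, clean_states))
    recovery.append(float((p_restored[target] - p_corrupt[target])
                          / (p_clean[target] - p_corrupt[target])))
print(["L%d: %.0f%%" % (i, 100 * r) for i, r in enumerate(recovery)])
```

Restoring the subject after L0 brings back the full answer because the L1 read then sees clean input. Restoring it later changes nothing, since the last token has already read the corrupted subject.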
""")
        ct_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        ct_subject = gr.Textbox(value="france", label="subject to corrupt")
        ct_target = gr.Textbox(value="paris", label="target token")
        with gr.Row():
            ct_noise = gr.Slider(0, 10, value=3, step=0.5, label="noise (x embed std)")
            ct_seed = gr.Slider(0, 100, value=0, step=1, label="seed")
        ct_out = gr.Textbox(label="output", lines=18)
        gr.Button("Run").click(causal_trace,
                               [ct_prompt, ct_subject, ct_target, ct_noise, ct_seed], ct_out)

    # ---- TAB 6 -------------------------------------------------------------
    with gr.Tab("6 · Edit + verify (ROME loop)"):
        gr.Markdown("""
### Edit a fact, then prove nothing else broke
**What it does:** rewrites the value one fact-MLP key maps to (the same kind of
edit ROME/MEMIT make on real models; this is a literal `nn.Module` weight tensor,
not a token or vocab change), then runs a verification battery over **every**
known fact to measure **efficacy** (target changed), **specificity** (others
untouched), and **fluency** (no entropy collapse).

**Two methods, on purpose:**
- `rank1`: the minimal, surgical update. Only the target fact moves → **SURGICAL**.
- `broadcast`: a deliberately sloppy edit that smears the change across all facts, so the harness catches the **COLLATERAL DAMAGE**. This shows the verifier actually discriminates instead of reporting "ok" by default.
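
A minimal NumPy sketch of why a rank-one edit stays surgical while a broadcast-style edit does not. It assumes the fact layer is a linear key→value map with orthonormal subject keys. It uses the plain minimum-norm rank-one update; real ROME also weights the update by an inverse key-covariance estimate. `broadcast_edit` and `verify` are hypothetical stand-ins, not this app's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
subjects = ["france", "germany", "japan"]
keys = np.eye(d)[:3]                      # orthonormal subject keys (toy assumption)
values = rng.normal(size=(3, d))          # the value each key currently maps to
W = values.T @ keys                       # so W @ keys[i] == values[i]


def rank1_edit(W, k, v_new):
    # Minimum-norm update with W' @ k == v_new; directions orthogonal to k are untouched.
    return W + np.outer(v_new - W @ k, k) / (k @ k)


def broadcast_edit(W, k, v_new):
    # Deliberately sloppy: push the same correction through every key at once.
    return W + np.outer(v_new - W @ k, keys.sum(axis=0))


def verify(W_new, target, v_new):
    efficacy = np.allclose(W_new @ keys[target], v_new)
    drift = max(np.linalg.norm(W_new @ keys[i] - W @ keys[i])
                for i in range(len(subjects)) if i != target)
    return {"efficacy": bool(efficacy), "max_drift": float(drift)}


v_london = rng.normal(size=d)             # hypothetical value vector for "london"
r1 = verify(rank1_edit(W, keys[0], v_london), 0, v_london)
bc = verify(broadcast_edit(W, keys[0], v_london), 0, v_london)
print("rank1:    ", r1)
print("broadcast:", bc)
```

Both edits hit the target. Only the broadcast edit moves `germany` and `japan`, which is exactly what a specificity check has to catch.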

**Independent LLM review, with a fallback chain (not locked to one vendor):**
tick the box and it tries, in order: **Anthropic** (Claude, if you give a key)
→ **Hugging Face Inference** (any hosted chat model, if you give an HF token)
→ **your own local server** (LM Studio, vLLM, Ollama's OpenAI-compatible
endpoint, or anything else exposing `/v1/chat/completions`). The first one that
is both configured *and* reachable answers; the rest are skipped and noted. So
your own RTX 5090 can be the judge with zero cloud calls: leave the Anthropic key
and HF token blank and fill in the local URL.
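
The try-in-order logic is simple to sketch. Here each provider is a hypothetical `(name, configured, call)` triple, where `call` raises on a network or auth failure. The real app's provider functions and error handling may differ.

```python
def first_available(providers, prompt):
    # providers: list of (name, configured, call); call(prompt) -> str, raises if unreachable
    notes = []
    for name, configured, call in providers:
        if not configured:
            notes.append("%s: skipped (not configured)" % name)
            continue
        try:
            return name, call(prompt), notes
        except Exception as exc:          # unreachable server, bad key, timeout, ...
            notes.append("%s: skipped (%s)" % (name, exc))
    return None, None, notes


def offline(prompt):
    raise ConnectionError("connection refused")


providers = [
    ("anthropic", False, None),                            # no API key given
    ("hf-inference", True, offline),                       # token set, but unreachable
    ("local", True, lambda prompt: "verdict: SURGICAL"),   # e.g. a local OpenAI-compatible server
]
name, answer, notes = first_available(providers, "Review this edit report.")
print(name, "->", answer)
for note in notes:
    print("  ", note)
```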

Subjects: `france`, `germany`, `japan`. Answers: `paris, berlin, tokyo, london, rome`.
After editing, the model stays edited, so go look at it in tabs 1–5 (the logit
lens will show the new answer rising; the trace still localises to L0). Hit
**Reset model** to restore the original weights. Every run is appended to a
session log you can download below and paste into a future chat for review.
""")
        with gr.Row():
            ed_subj = gr.Textbox(value="france", label="subject")
            ed_new = gr.Textbox(value="london", label="new answer")
            ed_method = gr.Radio(["rank1", "broadcast"], value="rank1", label="method")
            ed_strength = gr.Slider(0.2, 2.0, value=1.0, step=0.1, label="strength")
        ed_llm = gr.Checkbox(value=False, label="also run an independent LLM review")
        with gr.Accordion("LLM review providers (tried in this order)", open=False):
            with gr.Row():
                ed_a_model = gr.Textbox(value="claude-sonnet-4-6", label="1. Anthropic model")
                ed_a_key = gr.Textbox(value="", label="Anthropic API key", type="password")
            with gr.Row():
                ed_h_model = gr.Textbox(value="Qwen/Qwen2.5-7B-Instruct",
                                        label="2. HF Inference model")
                ed_h_key = gr.Textbox(value="", label="HF token", type="password")
            with gr.Row():
                ed_l_url = gr.Textbox(value="http://192.168.188.25:1234",
                                      label="3. Local server URL (LM Studio etc.)")
                ed_l_model = gr.Textbox(value="local-model", label="local model name")
        ed_out = gr.Textbox(label="edit + verification report", lines=24)
        with gr.Row():
            gr.Button("Edit & verify", variant="primary").click(
                edit_and_verify,
                [ed_subj, ed_new, ed_method, ed_strength, ed_llm,
                 ed_a_key, ed_a_model, ed_h_key, ed_h_model, ed_l_url, ed_l_model],
                ed_out)
            gr.Button("Reset model").click(reset_glassbox, outputs=ed_out)
        gr.Markdown("**Session log** (every edit run above, appended):")
        with gr.Row():
            log_btn = gr.Button("Write session log to disk")
            log_file = gr.File(label="download")
        log_status = gr.Markdown()
        log_btn.click(export_session_log, outputs=[log_file, log_status])

    # ---- TAB 7 -------------------------------------------------------------
    with gr.Tab("7 · Export / Upload to HF"):
        gr.Markdown("""
### Ship the toy glass-box
**Export** writes a self-contained, reloadable repo: weights (`safetensors`),
`config.json`, `vocab.json`, a standalone `modeling_glassbox.py` (reload with
`from modeling_glassbox import load`), and a model card.

**Upload** pushes it to the Hub. Choose `model`, `space` (this whole app,
runnable), or `both`. Paste a **write** token (or set `HF_TOKEN`).
""")
        with gr.Row():
            hf_repo = gr.Textbox(value="Chris4K/glassbox-interp", label="repo id")
            hf_what = gr.Radio(["model", "space", "both"], value="model", label="what to push")
            hf_token = gr.Textbox(value="", label="HF write token (optional)", type="password")
        hf_out = gr.Textbox(label="result", lines=6)
        with gr.Row():
            gr.Button("Export locally").click(
                lambda: "Exported to ./%s" % export_glassbox(), outputs=hf_out)
            gr.Button("Upload to HF", variant="primary").click(
                upload_to_hf, [hf_repo, hf_token, hf_what], hf_out)

        gr.Markdown("""
---
### Upload a REAL model — e.g. your VINDEX-edited Llama checkpoint
This does **not** load the model into memory and does **not** assume any
particular architecture — it just pushes whatever's already on disk at
`local_dir` (the usual `save_pretrained()` layout: `config.json` +
`*.safetensors` shards + tokenizer files) straight to a new repo. Large
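weights upload fine through `upload_folder`; for very large repos, installing
`hf_transfer` and setting `HF_HUB_ENABLE_HF_TRANSFER=1` can speed up transfers.
If the base model is gated (e.g. `meta-llama/*`), that gate does not block this
upload step. Access to the new repo is controlled by its own visibility and
gating settings, while the base model's license terms still apply to the
derivative you publish.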
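
Before pushing a multi-gigabyte folder, it is worth checking that it actually has the layout described above. In this hedged sketch, `check_checkpoint_dir` and `push_checkpoint` are hypothetical helpers. The tokenizer file names are common but not universal, and `user/example-repo` is a placeholder. `huggingface_hub`'s `HfApi.create_repo` / `HfApi.upload_folder` are only called when `dry_run=False`.

```python
import os
import tempfile

TOKENIZER_FILES = ("tokenizer.json", "tokenizer.model", "vocab.json")  # common, not exhaustive


def check_checkpoint_dir(local_dir):
    # Return a list of problems; an empty list means it looks like a save_pretrained() folder.
    files = set(os.listdir(local_dir)) if os.path.isdir(local_dir) else set()
    problems = []
    if "config.json" not in files:
        problems.append("missing config.json")
    if not any(f.endswith(".safetensors") for f in files):
        problems.append("no *.safetensors weights")
    if not any(f in files for f in TOKENIZER_FILES):
        problems.append("no tokenizer files")
    return problems


def push_checkpoint(local_dir, repo_id, token=None, private=True, dry_run=True):
    problems = check_checkpoint_dir(local_dir)
    if problems or dry_run:
        return problems                   # report without touching the network
    from huggingface_hub import HfApi     # imported only for a real upload
    api = HfApi(token=token)
    api.create_repo(repo_id, private=private, exist_ok=True)
    api.upload_folder(folder_path=local_dir, repo_id=repo_id,
                      commit_message="upload edited checkpoint")
    return []


with tempfile.TemporaryDirectory() as good, tempfile.TemporaryDirectory() as bad:
    for name in ("config.json", "model-00001-of-00002.safetensors", "tokenizer.json"):
        open(os.path.join(good, name), "w").close()
    open(os.path.join(bad, "config.json"), "w").close()
    good_report = push_checkpoint(good, "user/example-repo")   # dry run
    bad_report = push_checkpoint(bad, "user/example-repo")
print(good_report, bad_report)
```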
""")
        with gr.Row():
            rc_dir = gr.Textbox(value="", label="local checkpoint folder (on this machine)")
            rc_repo = gr.Textbox(value="", label="destination repo id, e.g. Chris4K/vindex-llama3-edited")
        with gr.Row():
            rc_token = gr.Textbox(value="", label="HF write token (optional)", type="password")
            rc_private = gr.Checkbox(value=True, label="private repo")
            rc_msg = gr.Textbox(value="upload edited checkpoint", label="commit message")
        rc_out = gr.Textbox(label="result", lines=6)
        gr.Button("Upload real checkpoint", variant="primary").click(
            upload_local_checkpoint, [rc_dir, rc_repo, rc_token, rc_private, rc_msg], rc_out)

    gr.Markdown("""
---
### Where this goes next
- **Closing the loop (what "self-improving" would actually require):** right now a human picks every edit; the verifier just grades it. A real closed loop needs a policy that *proposes* edits on its own (e.g. scanning eval failures for wrong facts), auto-applies, and auto-commits only on a SURGICAL verdict, rolling back otherwise. The hard part — the verifier — already exists here; the proposal step doesn't yet.
- **A training-method angle worth taking seriously:** instead of accept/reject after the fact, feed the specificity battery's drift score back as a regularizer *during* the edit computation (closer to elastic weight consolidation, or the null-space projection AlphaEdit-style methods use) so collateral is penalized while solving, not caught after.
- **Real-model MEMIT:** the edit loop here is exact because the glass-box's fact layer is literally key→value. The same verify harness (efficacy / specificity / fluency + the multi-provider LLM judge) ports straight onto a gpt2/Llama MEMIT edit — the toy is the regression test you run first.
- **Multi-hop & paraphrase generalization:** add `"the currency of france is"` so two relations share a subject, and have the LLM judge auto-generate paraphrase probes to test that an edit generalizes, not just memorizes the one prompt.
- **Attribution view:** Geva-style "what does this neuron write to the vocab", per-head attention attribution.
- **It already ships:** tab 7 pushes the toy model and this whole app (as a Space) to your Hub, or a real local checkpoint folder to its own repo.
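
The propose → apply → verify → commit-or-rollback loop from the first bullet can be sketched with plain callables. Everything here is hypothetical. `propose`, `apply_edit`, `verify` and `rollback` stand in for a proposal policy and for this app's edit and verification steps. The dict "model" is a toy, and its verdict rule is invented for the demo.

```python
def closed_loop(propose, apply_edit, verify, rollback, max_steps=10):
    # Commit an edit only on a SURGICAL verdict; otherwise undo it.
    committed, rejected = [], []
    for _ in range(max_steps):
        edit = propose()
        if edit is None:                  # nothing left to fix
            break
        snapshot = apply_edit(edit)       # apply, keeping what is needed to undo
        if verify(edit) == "SURGICAL":
            committed.append(edit)
        else:
            rollback(snapshot)
            rejected.append(edit)
    return committed, rejected


# Toy "model": a fact table plus a queue of proposed edits (e.g. mined from eval failures).
facts = {"france": "paris", "germany": "berlin", "japan": "tokyo"}
queue = [("france", "london", "rank1"), ("japan", "rome", "broadcast")]


def propose():
    return queue.pop(0) if queue else None


def apply_edit(edit):
    subject, answer, _method = edit
    snapshot = dict(facts)
    facts[subject] = answer
    return snapshot


def verify(edit):
    # Stand-in verdict: pretend broadcast edits always cause collateral damage.
    return "COLLATERAL DAMAGE" if edit[2] == "broadcast" else "SURGICAL"


def rollback(snapshot):
    facts.clear()
    facts.update(snapshot)


committed, rejected = closed_loop(propose, apply_edit, verify, rollback)
print(committed, rejected, facts)
```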
""")

    demo.load(lambda: load_model("glassbox"), outputs=load_status)

if __name__ == "__main__":
    demo.launch()