File size: 57,821 Bytes
9ee6b93
9bbf5fa
34f439a
b8a8f34
9bbf5fa
 
 
 
9ee6b93
 
 
 
 
 
 
 
746c56d
 
 
0919b50
746c56d
9ee6b93
 
b8a8f34
 
2920441
 
4c236ce
9bbf5fa
 
4c236ce
9bbf5fa
90b8ab6
 
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0919b50
 
9bbf5fa
 
 
0919b50
9bbf5fa
 
 
9ee6b93
9bbf5fa
 
 
9ee6b93
 
9bbf5fa
 
 
4c236ce
 
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55a1dab
9bbf5fa
55a1dab
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55a1dab
 
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7a4cee
 
 
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7a4cee
 
 
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90b8ab6
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90b8ab6
 
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
746c56d
9bbf5fa
 
 
 
 
 
 
 
b8a8f34
9bbf5fa
 
 
 
 
 
 
 
 
 
 
746c56d
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
746c56d
 
9bbf5fa
 
 
 
 
 
746c56d
9bbf5fa
746c56d
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8a8f34
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
746c56d
9bbf5fa
 
 
 
 
 
b8a8f34
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0919b50
9bbf5fa
 
 
 
 
 
90b8ab6
9bbf5fa
 
 
 
 
90b8ab6
9bbf5fa
 
 
 
 
90b8ab6
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8a8f34
746c56d
b8a8f34
9bbf5fa
 
b8a8f34
9bbf5fa
b8a8f34
9bbf5fa
 
b8a8f34
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
b8a8f34
9bbf5fa
 
 
 
 
 
 
 
 
 
 
5612ce6
 
9bbf5fa
 
5612ce6
9bbf5fa
 
 
 
 
 
 
 
5612ce6
9bbf5fa
 
b8a8f34
9bbf5fa
 
 
746c56d
9bbf5fa
 
746c56d
 
 
9bbf5fa
746c56d
9bbf5fa
 
 
746c56d
9bbf5fa
 
 
746c56d
9bbf5fa
 
 
 
746c56d
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0919b50
9bbf5fa
0919b50
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8a8f34
5612ce6
 
 
9bbf5fa
 
 
 
 
5612ce6
9bbf5fa
 
 
 
 
 
b8a8f34
9bbf5fa
b8a8f34
9bbf5fa
 
b8a8f34
9bbf5fa
 
b8a8f34
9bbf5fa
b8a8f34
 
 
 
 
 
9ee6b93
9bbf5fa
0919b50
9bbf5fa
 
 
 
9ee6b93
9bbf5fa
 
90b8ab6
9bbf5fa
b8a8f34
5612ce6
9bbf5fa
 
 
5612ce6
9bbf5fa
 
b8a8f34
5612ce6
9bbf5fa
5612ce6
9bbf5fa
5612ce6
 
9bbf5fa
 
 
 
 
 
90b8ab6
9bbf5fa
 
 
 
 
 
 
746c56d
9bbf5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8a8f34
 
0919b50
9bbf5fa
 
 
 
 
 
b8a8f34
9ee6b93
 
90b8ab6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
"""
ARC-AGI-3 Agent Spectator v4
Hugging Face Space: beanapologist/arc-agi

Re/Im solver live demo:
  Im side = bird's eye hypothesis (which transformation?)
  Re side = exact diff (which cells to click?)
  Bridge  = ACTION6 at the Re-side coordinates that close the gap
"""

import gradio as gr
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import torch
import torch.nn as nn
import torch.nn.functional as TF
import io, os, time, threading, queue
from collections import deque
from PIL import Image

# ── Palette ───────────────────────────────────────────────────────────────────

# Ten display colours, one per ARC colour value 0-9
# (assumed: list index == grid value; COLOR_NAMES below gives the names).
ARC_HEX = ['#000000','#1a6faf','#e03a3a','#3aa63a','#f5c400',
           '#c060c0','#d07030','#aaaaaa','#60b8d0','#874010']
# Discrete matplotlib colormap with exactly one entry per ARC_HEX colour.
ARC_CMAP = ListedColormap(ARC_HEX)
# Human-readable colour names, presumably in the same order as ARC_HEX.
COLOR_NAMES = ['black','blue','red','green','yellow',
               'purple','orange','gray','azure','maroon']

# ── Re/Im primitives ──────────────────────────────────────────────────────────

def _sobel(f):
    p=np.pad(f,1,mode='edge')
    gx=(-p[:-2,:-2]-2*p[1:-1,:-2]-p[2:,:-2]+p[:-2,2:]+2*p[1:-1,2:]+p[2:,2:])/8
    gy=(-p[:-2,:-2]-2*p[:-2,1:-1]-p[:-2,2:]+p[2:,:-2]+2*p[2:,1:-1]+p[2:,2:])/8
    return gx,gy

def _sym_axis(grid,axis):
    H,W=grid.shape; best_s,best_i=0.0,0
    if axis=='h':
        for x in range(1,W-1):
            r=min(x,W-1-x)
            s=(grid[:,x-r:x]==grid[:,x+1:x+r+1][:,::-1]).mean()
            if s>best_s: best_s,best_i=s,x
    else:
        for y in range(1,H-1):
            r=min(y,H-1-y)
            s=(grid[y-r:y,:]==grid[y+1:y+r+1,:][::-1,:]).mean()
            if s>best_s: best_s,best_i=s,y
    return best_i,best_s

def _boundary(grid):
    p=np.pad(grid,1,mode='edge')
    return ((p[1:-1,1:-1]!=p[:-2,1:-1])|(p[1:-1,1:-1]!=p[2:,1:-1])|
            (p[1:-1,1:-1]!=p[1:-1,:-2])|(p[1:-1,1:-1]!=p[1:-1,2:])).astype(np.float32)

def _sym(grid,axis):
    H,W=grid.shape; s=np.zeros((H,W),np.float32)
    if axis=='h':
        for x in range(W):
            r=min(x,W-1-x)
            if r==0: s[:,x]=1.; continue
            s[:,x]=(grid[:,x-r:x]==grid[:,x+1:x+r+1][:,::-1]).mean()
    else:
        for y in range(H):
            r=min(y,H-1-y)
            if r==0: s[y,:]=1.; continue
            s[y,:]=(grid[y-r:y,:]==grid[y+1:y+r+1,:][::-1,:]).mean()
    return s

# ── Re/Im board reader — reads the feature tensor (14 channels in this Space) ──
# Im tensor = signal strength | raw grid = coordinates | i· = answer

"""
arc_solver.py — Re/Im analytic solver using the 56-channel feature tensor
=========================================================================

The board IS M = Re(M) + i·Im(M).
The 56-channel extractor already IS log(M) decomposed:

  Re channels (ch 0-47):
    ch 0-15:  one-hot colors           — multiplicative / "what's there"
    ch 16-31: component size maps      — "how many, how big" (radial magnitude)
    ch 32-47: distance-from-color maps — proximity / modulus

  Im channels (ch 48-55):
    ch 48:    H-symmetry map           — fold axis (Re(s)=1/2 analog)
    ch 49:    V-symmetry map           — fold axis (vertical)
    ch 50:    rotational symmetry      — winding structure
    ch 51:    Sobel x (edge horiz)     — gradient Re component
    ch 52:    Sobel y (edge vert)      — gradient Im component  ← arg(M)
    ch 53:    boundary contour         — Cauchy ∮ contour
    ch 54:    curl / winding proxy     — arg(M), rotation direction
    ch 55:    component ID map         — global topological / winding data

The Im channels are not clues — they ARE the answer structure.
i· (column swap) means: read the Im channels to get the Re answer.

  Im H-symmetry strong + Re mass asymmetric  →  mirror to complete
  Im boundary contour + Re solid fill        →  boundary IS the answer
  Im curl/winding strong                     →  rotation is needed
  Im Sobel-y dominates                       →  vertical flow (gravity)
  Im component map has isolated islands      →  count → place

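The full (Kaggle) CNN operates on all 56 channels; this Space build uses a
14-channel subset (see the CH_* constants below), so several of the Im
roles listed above share a channel here.
This solver reads the Im channels analytically to short-circuit the CNN
when the Im signal is unambiguous.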
"""

import numpy as np
from typing import Optional, Tuple, List


# ── Re/Im channel indices (must match the channel order of extract_features
#    defined in this file) ─────────────────────────────────────────────────────

# Space uses 14-channel extractor (subset of full 56-channel Kaggle version)
#   ch 0-9:   Re — one-hot colors
#   ch 10:    Im — H-symmetry (fold axis)
#   ch 11:    Im — V-symmetry
#   ch 12:    Im — Cauchy boundary contour
#   ch 13:    Im — edge magnitude (Sobel combined)
# Several names below come from the 56-channel layout but have no dedicated
# channel here; they alias an existing channel instead (see per-line notes).
CH_ONEHOT_START    = 0
CH_CCSIZE_START    = 0   # not in Space extractor — placeholder aliasing the one-hot block
CH_DIST_START      = 0   # not in Space extractor — placeholder aliasing the one-hot block
CH_H_SYM           = 10  # Im — horizontal symmetry
CH_V_SYM           = 11  # Im — vertical symmetry
CH_ROT_SYM         = 10  # not separate — reuse H-sym
CH_SOBEL_X         = 13  # Im — edge magnitude; X and Y share this channel,
CH_SOBEL_Y         = 13  #   so any X-vs-Y gradient comparison always ties
CH_BOUNDARY        = 12  # Im — Cauchy boundary contour
CH_CURL            = 10  # not separate — reuse H-sym
CH_COMPONENT_ID    = 11  # not separate — reuse V-sym

NUM_CHANNELS       = 14  # Space extractor has 14 channels
# Minimum read_board() confidence for try_analytic_action() to return an action.
CONFIDENCE_THRESHOLD = 0.72


# ── Raw grid primitives (for when we don't have the tensor yet) ───────────────

def _boundary_raw(grid):
    p = np.pad(grid, 1, mode='edge')
    return ((p[1:-1,1:-1]!=p[:-2,1:-1])|(p[1:-1,1:-1]!=p[2:,1:-1])|
            (p[1:-1,1:-1]!=p[1:-1,:-2])|(p[1:-1,1:-1]!=p[1:-1,2:])).astype(np.float32)

def _perimeter(grid):
    """Object perimeter — handles solid blocks where _boundary_raw gives 0."""
    H, W = grid.shape
    p = np.zeros((H,W), dtype=np.float32)
    mask = grid > 0
    if not mask.any(): return p
    padded = np.pad(mask.astype(int), 1, constant_values=0)
    for dy,dx in [(-1,0),(1,0),(0,-1),(0,1)]:
        shifted = padded[1+dy:H+1+dy, 1+dx:W+1+dx]
        p[mask & (shifted==0)] = 1
    if p.sum() == 0:  # solid block: use outer ring
        p[0,:] = mask[0,:].astype(float);  p[-1,:] = mask[-1,:].astype(float)
        p[:,0] = mask[:,0].astype(float);  p[:,-1] = mask[:,-1].astype(float)
    return p


# ── Re/Im board reader ────────────────────────────────────────────────────────

def read_board(grid: np.ndarray, features: Optional[np.ndarray] = None):
    """
    Read the board as complex object M = Re(M) + i·Im(M).

    Seven heuristic "Im signals" are tried in order.  Each may propose a
    candidate answer grid with a confidence; a candidate replaces the current
    best only when its confidence is strictly higher, and later signals are
    skipped once the running confidence exceeds their gate.

    Parameters
    ----------
    grid     : 2D int array, the raw color grid
    features : optional (NUM_CHANNELS, h, w) float array, e.g. from
               extract_features().  It is used only when its first dimension
               equals NUM_CHANNELS (14 in this Space); otherwise, or when
               None, the features are recomputed from ``grid``.  Maps whose
               spatial size differs from the grid are resampled to it with
               nearest-neighbour indexing.

    Returns
    -------
    answer     : np.ndarray or None — derived answer grid (None if no signal fired)
    confidence : float 0-1
    reasoning  : list of strings — one per candidate that was accepted
    signal     : str — which Im channel drove the answer ('none' if nothing fired)
    """
    H, W = grid.shape
    reasoning = []
    answer    = None
    confidence = 0.0
    signal     = 'none'

    # ── Obtain the feature tensor ─────────────────────────────────────────────
    # Use the caller's tensor when it has the expected channel count (same
    # computation the CNN uses); otherwise compute it here.  This is a single
    # pass: the previous version recursed, which would never terminate if
    # NUM_CHANNELS and the fallback layout below ever disagreed.
    if features is None or features.shape[0] != NUM_CHANNELS:
        t = extract_features(grid)
        features = t.numpy() if hasattr(t, 'numpy') else np.array(t)
        if features.shape[0] != NUM_CHANNELS:
            # Channel-count mismatch — build the 14-channel layout directly
            # from the grid (already grid-sized, so no resampling is needed).
            _gx, _gy = _sobel(grid.astype(np.float32)/9)
            _oh = np.zeros((10, H, W), np.float32)
            for _c in range(10):
                _oh[_c] = (grid == _c).astype(np.float32)
            edge_map = np.sqrt(_gx**2 + _gy**2).astype(np.float32)
            features = np.concatenate([_oh,
                                       _sym(grid, 'h')[np.newaxis],
                                       _sym(grid, 'v')[np.newaxis],
                                       _boundary(grid)[np.newaxis],
                                       edge_map[np.newaxis]], axis=0)

    fH, fW = features.shape[1], features.shape[2]

    def _im(ch):
        """Read Im channel *ch*, resampled to grid size (nearest neighbour) if needed."""
        m = features[ch]
        if (fH, fW) != (H, W):
            ry = np.linspace(0, fH-1, H).astype(int)
            rx = np.linspace(0, fW-1, W).astype(int)
            return m[np.ix_(ry, rx)]
        return np.array(m)

    h_sym_map  = _im(CH_H_SYM)
    v_sym_map  = _im(CH_V_SYM)
    rot_map    = _im(CH_ROT_SYM)
    sobel_x    = _im(CH_SOBEL_X)
    sobel_y    = _im(CH_SOBEL_Y)
    bound_map  = _im(CH_BOUNDARY)
    curl_map   = _im(CH_CURL)

    # ── Im signal 1: H-symmetry (fold axis) ───────────────────────────────────
    # The Im map gives the symmetry score — its MAX is the signal strength.
    # The actual best axis is computed in raw grid space (not feature space).
    # This is the Re/Im separation: Im tensor gives the SIGNAL,
    # Re grid gives the COORDINATES.
    best_hs = float(h_sym_map.max())  # Im: how strong is any fold?

    if best_hs > 0.65:
        # Find best axis in grid space (not feature space)
        h_scores_grid = []
        for x in range(1, W-1):
            r = min(x, W-1-x)
            if r > 0:
                s = (grid[:, x-r:x] == grid[:, x+1:x+r+1][:, ::-1]).mean()
                h_scores_grid.append((x, float(s)))
        if h_scores_grid:
            best_hx, best_hs_grid = max(h_scores_grid, key=lambda item: item[1])
            # Use Im tensor strength as upper bound, grid score as actual
            hs = min(best_hs, best_hs_grid) if best_hs_grid > 0.4 else best_hs_grid
            left_px  = int((grid[:, :best_hx] > 0).sum())
            right_px = int((grid[:, best_hx:] > 0).sum())
            total_px = left_px + right_px
            asymmetry = abs(left_px - right_px) / max(total_px, 1)
            if asymmetry > 0.25 and hs > 0.55:
                ans = grid.copy()
                # NOTE(review): cells are mirrored about the grid centre
                # (column W-1-c), not about the detected axis best_hx; the two
                # coincide only when best_hx == (W-1)/2.  Confirm intended.
                if left_px > right_px:
                    for c in range(best_hx):
                        mir = W - 1 - c
                        if 0 <= mir < W:
                            mask = ans[:, mir] == 0
                            ans[mask, mir] = grid[mask, c]
                else:
                    for c in range(best_hx+1, W):
                        mir = W - 1 - c
                        if 0 <= mir < W:
                            mask = ans[:, mir] == 0
                            ans[mask, mir] = grid[mask, c]
                conf = hs * asymmetry * 0.95
                if conf > confidence:
                    answer, confidence, signal = ans, conf, 'Im:H-sym'
                    reasoning.append(
                        f"Im ch48 signal={best_hs:.2f} grid_score={hs:.2f} at x={best_hx} | "
                        f"Re left={left_px} right={right_px} asym={asymmetry:.2f} | "
                        f"i·: complete the fold")

    # ── Im signal 2: V-symmetry ────────────────────────────────────────────────
    best_vs = float(v_sym_map.max())  # Im: V-fold signal strength

    if best_vs > 0.65 and confidence < 0.6:
        v_scores_grid = []
        for y in range(1, H-1):
            r = min(y, H-1-y)
            if r > 0:
                s = (grid[y-r:y, :] == grid[y+1:y+r+1, :][::-1, :]).mean()
                v_scores_grid.append((y, float(s)))
        if v_scores_grid:
            best_vy, best_vs_grid = max(v_scores_grid, key=lambda item: item[1])
            vs = min(best_vs, best_vs_grid) if best_vs_grid > 0.4 else best_vs_grid
            top_px  = int((grid[:best_vy, :] > 0).sum())
            bot_px  = int((grid[best_vy:, :] > 0).sum())
            total_px = top_px + bot_px
            asymmetry = abs(top_px - bot_px) / max(total_px, 1)
            if asymmetry > 0.25 and vs > 0.55:
                ans = grid.copy()
                # NOTE(review): as for H-sym, rows are mirrored about the grid
                # centre (row H-1-r), not about best_vy.
                if top_px > bot_px:
                    for r in range(best_vy):
                        mir = H - 1 - r
                        if 0 <= mir < H:
                            mask = ans[mir, :] == 0
                            ans[mir, mask] = grid[r, mask]
                else:
                    for r in range(best_vy+1, H):
                        mir = H - 1 - r
                        if 0 <= mir < H:
                            mask = ans[mir, :] == 0
                            ans[mir, mask] = grid[r, mask]
                conf = vs * asymmetry * 0.90
                if conf > confidence:
                    answer, confidence, signal = ans, conf, 'Im:V-sym'
                    reasoning.append(
                        f"Im ch49 signal={best_vs:.2f} grid_score={vs:.2f} at y={best_vy} | "
                        f"Re top={top_px} bot={bot_px} | i·: complete V fold")

    # ── Im signal 3: Cauchy boundary contour ──────────────────────────────────
    # The boundary channel IS the Cauchy contour. When Re fill is solid (high
    # density) and the Im boundary is thin, the answer IS the boundary.
    # "The radius cancels (Re collapses); only the angular part i·dθ survives"
    total_px = int((grid > 0).sum())
    if total_px > 0 and confidence < 0.75:
        fill_ratio = total_px / (H * W)
        colors = [c for c in range(1, 10) if (grid == c).any()]

        if len(colors) == 1 and fill_ratio > 0.4:
            # Single color, high fill → Im boundary IS the answer
            perim = _perimeter(grid)
            # Check interior isn't already hollow
            interior = (grid == 0) & (perim == 0)
            if not interior.any():
                ans = np.zeros_like(grid)
                ans[perim > 0] = grid[perim > 0]
                conf = fill_ratio * 0.90
                if conf > confidence:
                    answer, confidence, signal = ans, conf, 'Im:boundary'
                    reasoning.append(
                        f"Im ch53 Cauchy contour | Re fill={fill_ratio:.2f} "
                        f"single color={colors[0]} | "
                        f"i·: Im boundary = Re answer (Cauchy: radius cancels)")

        elif fill_ratio > 0.3 and bound_map.mean() < 0.15:
            # Multi-color but low boundary density → extract boundary
            ans = np.zeros_like(grid)
            ans[bound_map > 0.3] = grid[bound_map > 0.3]
            if (ans > 0).any():
                conf = fill_ratio * (1 - float(bound_map.mean())) * 0.75
                if conf > confidence:
                    answer, confidence, signal = ans, conf, 'Im:boundary'
                    reasoning.append(
                        f"Im ch53 boundary density={bound_map.mean():.3f} | "
                        f"Re fill={fill_ratio:.2f} | i·: Cauchy contour")

    # ── Im signal 4: Interior fill ────────────────────────────────────────────
    # If colored regions fully enclose empty cells → fill the interior.
    # Flood-fill from the grid border to find truly enclosed (unreachable) cells.
    if confidence < 0.55 and total_px > 0:
        # Flood-fill (iterative DFS) of empty cells reachable from the border
        reachable = np.zeros((H, W), dtype=bool)
        fq = []
        for _r in range(H):
            for _c in range(W):
                if grid[_r, _c] == 0 and (_r == 0 or _r == H-1 or _c == 0 or _c == W-1):
                    if not reachable[_r, _c]:
                        reachable[_r, _c] = True
                        fq.append((_r, _c))
        while fq:
            _y, _x = fq.pop()
            for _dy, _dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                _ny, _nx = _y+_dy, _x+_dx
                if 0 <= _ny < H and 0 <= _nx < W and grid[_ny, _nx] == 0 and not reachable[_ny, _nx]:
                    reachable[_ny, _nx] = True
                    fq.append((_ny, _nx))
        truly_interior = (grid == 0) & ~reachable
        if truly_interior.any():
            dominant = int(np.argmax(np.bincount(
                grid[grid > 0].flatten(), minlength=10)[1:])) + 1
            ans = grid.copy()
            ans[truly_interior] = dominant
            conf = truly_interior.sum() / max(1, (grid == 0).sum()) * 0.80
            if conf > confidence:
                answer, confidence, signal = ans, conf, 'Im:hollow→fill'
                reasoning.append(
                    f"Im ch53 enclosed interior={int(truly_interior.sum())}px | "
                    f"Re dominant={dominant} | i·: fill enclosed interior")

    # ── Im signal 5: Gradient flow → gravity ──────────────────────────────────
    # Sobel x/y = gradient field direction = arg(M).
    # Suspended Re pixels + Im gradient direction → gravity answer.
    suspended = sum(
        1 for c in range(W)
        for r in np.where(grid[:, c] > 0)[0]
        if r < H-1 and grid[r+1, c] == 0
    )
    if suspended > 0 and confidence < 0.70:
        gx_mag = float(np.abs(sobel_x).mean())
        gy_mag = float(np.abs(sobel_y).mean())
        # When CH_SOBEL_X and CH_SOBEL_Y name the same channel, the two
        # magnitudes tie and 'down' is always chosen.
        direction = 'down' if gy_mag >= gx_mag else 'right'
        ans = np.zeros_like(grid)
        if direction == 'down':
            for c in range(W):
                vals = grid[:, c][grid[:, c] > 0]
                if len(vals):
                    ans[H-len(vals):H, c] = vals
        else:
            for r in range(H):
                vals = grid[r, :][grid[r, :] > 0]
                if len(vals):
                    ans[r, W-len(vals):W] = vals
        conf = min(0.80, suspended / max(total_px, 1) * 2.5)
        if conf > confidence:
            answer, confidence, signal = ans, conf, 'Im:Sobel→gravity'
            reasoning.append(
                f"Im ch52 Sobel-y={gy_mag:.3f} ch51 Sobel-x={gx_mag:.3f} | "
                f"Re suspended={suspended}px | i·: arg(M) gives gravity {direction}")

    # ── Im signal 6: Rotational symmetry / curl ────────────────────────────────
    rot_score = float(rot_map.max())
    curl_max  = float(np.abs(curl_map).max())
    if rot_score > 0.6 and curl_max > 0.4 and confidence < 0.35:
        ans = np.rot90(grid)
        conf = rot_score * curl_max * 0.60
        if conf > confidence:
            answer, confidence, signal = ans, conf, 'Im:rot+curl'
            reasoning.append(
                f"Im ch50 rot={rot_score:.2f} ch54 curl={curl_max:.2f} | "
                f"i·: rotation indicated")

    # ── Im signal 7: Color remapping (Re→Im count ordering) ───────────────────
    # Fallback: increment every non-zero color by one (c → c % 9 + 1).
    if confidence < 0.30 and total_px > 0:
        colors = [c for c in range(1, 10) if (grid == c).any()]
        if colors and max(colors) < 9:
            ans = grid.copy()
            mask = grid > 0
            ans[mask] = grid[mask] % 9 + 1
            # This signal's own confidence.  Previously the comparison used a
            # stale `conf` left over from an earlier signal, which raised
            # UnboundLocalError when no earlier signal had computed one.
            conf = 0.30
            if conf > confidence:
                answer, confidence, signal = ans, conf, 'Im:color_shift'
                reasoning.append(
                    f"Im ch55 component topology | Re colors {colors} | "
                    f"i·: Re→Im shift = increment colors")

    return answer, confidence, reasoning, signal


# ── Re-side: exact cell targeting ────────────────────────────────────────────

def pixel_diff(current: np.ndarray, target: np.ndarray):
    """List every cell where *current* and *target* disagree.

    Returns ``[(row, col, target_color), ...]`` in row-major order, using
    plain Python ints.  Arrays of different shapes yield an empty list.
    """
    if current.shape != target.shape:
        return []
    rows, cols = np.nonzero(current != target)
    return [(int(r), int(c), int(target[r, c])) for r, c in zip(rows, cols)]

def most_urgent_diff(current: np.ndarray, target: np.ndarray):
    """
    Choose one cell to fix next, preferring cells on a colour boundary.

    Im → Re (the "Cauchy principle"): the boundary contour determines the
    interior, so differing cells that lie on a colour boundary of *current*
    are fixed first.

    Returns a ``(row, col, target_color)`` tuple drawn uniformly at random
    (via ``np.random``) from the boundary diffs, or from all diffs when none
    lies on a boundary; ``None`` when the grids agree or their shapes differ.
    """
    candidates = pixel_diff(current, target)
    if not candidates:
        return None
    edge_map = _boundary_raw(current)
    on_edge = [cell for cell in candidates if edge_map[cell[0], cell[1]] > 0]
    pool = on_edge or candidates
    return pool[np.random.randint(len(pool))]


# ── Main entry point ─────────────────────────────────────────────────────────

def try_analytic_action(
    frame_2d:  np.ndarray,
    available_actions,
    features:  Optional[np.ndarray] = None,
) -> Tuple[Optional[int], Optional[dict], str, float]:
    """
    Read the board's Im channels to derive the answer, then use Re-side
    pixel diff to find the exact cell to click.

    Parameters
    ----------
    frame_2d          : raw 2D color grid
    available_actions : iterable of available GameAction values (objects with
                        ``.value``) or plain ints.  ``None`` means "unknown"
                        and is treated as actions 1-6 all being available; an
                        explicitly empty collection means none is available.
    features          : optional pre-computed (NUM_CHANNELS, H, W) feature
                        array, forwarded to read_board() to avoid recomputing

    Returns (action_id, action_data, signal_name, confidence)
      action_id is 6 (click) with action_data {'x': .., 'y': ..} in 0-63 game
      coordinates, or None with action_data None when no action is proposed;
      signal_name then says why ('none', read_board's signal when confidence
      is below CONFIDENCE_THRESHOLD, 'already_matches' or 'no_action6').
    """
    if frame_2d is None:
        return None, None, 'none', 0.0

    # `available_actions or range(1, 7)` treated an explicit empty list as
    # "everything available" (and raised on numpy arrays); only None is
    # taken to mean "unknown".
    if available_actions is None:
        available_actions = range(1, 7)
    avail_ids = {int(a.value if hasattr(a, 'value') else a)
                 for a in available_actions}

    answer, confidence, _reasoning, signal = read_board(frame_2d, features)

    if answer is None or confidence < CONFIDENCE_THRESHOLD:
        return None, None, signal, confidence

    diffs = pixel_diff(frame_2d, answer)
    if not diffs:
        return None, None, 'already_matches', confidence

    # ACTION6: click the most urgent Re-side cell
    if 6 in avail_ids:
        cell = most_urgent_diff(frame_2d, answer)
        if cell is not None:
            r, c, _ = cell
            H, W = frame_2d.shape
            # Centre of grid cell (r, c) scaled into the 64x64 game space.
            game_y = min(63, max(0, int(r * 64/H + 32/H)))
            game_x = min(63, max(0, int(c * 64/W + 32/W)))
            return 6, {'x': game_x, 'y': game_y}, signal, confidence

    return None, None, 'no_action6', confidence




# ── Feature extractor ─────────────────────────────────────────────────────────

def extract_features(grid,num_colours=10):
    """Build the axis feature tensor for a 2-D colour grid.

    The channels are, in order:
      * num_colours one-hot planes (grid == c),
      * _sym(grid, 'h'), _sym(grid, 'v'), _boundary(grid),
      * the Sobel gradient magnitude of grid / 9.
    The stack is bilinearly resized to 64x64 when the grid is not already
    64x64, giving a float tensor of shape (num_colours + 4, 64, 64).
    """
    height, width = grid.shape
    palette = np.arange(num_colours).reshape(-1, 1, 1)
    # Broadcast (1,H,W) against (C,1,1) -> one boolean plane per colour
    one_hot = (grid[np.newaxis] == palette).astype(np.float32)

    grad_x, grad_y = _sobel(grid.astype(np.float32) / 9)
    magnitude = np.sqrt(grad_x**2 + grad_y**2).astype(np.float32)

    extra_planes = [_sym(grid, 'h'), _sym(grid, 'v'), _boundary(grid), magnitude]
    stacked = np.concatenate(
        [one_hot] + [plane[np.newaxis] for plane in extra_planes], axis=0)

    tensor = torch.from_numpy(stacked).float().unsqueeze(0)
    if (height, width) != (64, 64):
        tensor = TF.interpolate(tensor, size=(64, 64),
                                mode='bilinear', align_corners=False)
    return tensor.squeeze(0)

# ── Gabor filter bank — s-plane cross terms ─────────────────────────────────

"""
gabor_channels.py
=================
Gabor filter bank for ARC-AGI-3 — the s-plane cross terms.

Mathematical position
---------------------
The existing 56-channel extractor covers the AXES of the s-plane:
  Re axis (σ>0, ω=0):  ch16-47  CC sizes, distance maps     — Laplace side
  Im axis (σ=0, ω>0):  ch48-55  symmetry, Sobel, boundary   — Fourier side

A Gabor filter lives at an INTERIOR point (σ>0, ω>0):
  g(x,y) = exp(-σ·r²) · cos(ω·x_θ + φ)
  where r² = x² + y², x_θ = x·cosθ + y·sinθ, exp(-σ·r²) is the
  Re/Laplace envelope and cos(ω·x_θ + φ) is the Im/Fourier carrier.

This is exp(-st) evaluated at s = σ + iω, rotated by θ, phased by φ.
It measures: "is there oscillation at frequency ω in direction θ,
              concentrated within decay radius 1/√σ?"

ARC relevance
-------------
The cross terms detect structures the axis channels miss:
  - Repeating patterns with finite extent (tiling with boundary)
  - Oriented edges at specific spatial scales
  - Localized symmetry (symmetric patch inside asymmetric grid)
  - Diagonal structure (axis channels are H/V only)

Channel layout (72 channels total)
-----------------------------------
3 σ values × 3 ω values × 4 orientations × 2 phases = 72

  σ = 0.3  → slow decay, radius ~1.8px   — widest envelope in the bank
  σ = 1.0  → medium decay, radius ~1.0px — mid-scale
  σ = 2.5  → fast decay, radius ~0.6px   — most local (near single-pixel)
  (decay radius = 1/√σ, so a larger σ means a tighter envelope)

  ω = 0.5  → coarse frequency, period ~12px — large patterns
  ω = 1.5  → medium frequency, period ~4px  — medium patterns
  ω = 3.0  → fine frequency, period ~2px    — fine detail

  θ = 0, π/4, π/2, 3π/4 — 4 orientations (H, diagonal, V, anti-diagonal)

  φ = 0    → cosine (even symmetry, detects symmetric features)
  φ = π/2  → sine   (odd symmetry, detects antisymmetric/edge features)
"""

import numpy as np
import torch
import torch.nn.functional as F
from typing import List, Tuple


# ── Filter bank parameters ────────────────────────────────────────────────────

# Re/Laplace decay rates σ of the Gaussian envelope exp(-σ·r²)
SIGMA_VALUES = [0.3, 1.0, 2.5]
# Im/Fourier carrier frequencies ω (radians per pixel along x_θ)
OMEGA_VALUES = [0.5, 1.5, 3.0]
# Carrier orientations θ: horizontal, diagonal, vertical, anti-diagonal
THETA_VALUES = [0, np.pi/4, np.pi/2, 3*np.pi/4]
# Carrier phases φ: 0 -> cosine (even), π/2 -> sine (odd)
PHASE_VALUES = [0.0, np.pi/2]

# One channel per (σ, ω, θ, φ) combination: 3 · 3 · 4 · 2 = 72
N_GABOR_CHANNELS = (len(PHASE_VALUES) * len(THETA_VALUES)
                    * len(OMEGA_VALUES) * len(SIGMA_VALUES))

# Side length of each square kernel (odd, so the kernel has a centre pixel)
KERNEL_SIZE = 7


def _make_gabor_kernel(sigma: float, omega: float,
                       theta: float, phi: float,
                       size: int = KERNEL_SIZE) -> np.ndarray:
    """
    Return one zero-mean, unit-L2-norm 2D Gabor kernel.

    The kernel samples exp(-σ·r²) · cos(ω·x_θ + φ) on an integer grid centred
    on the origin, where x_θ = x·cosθ + y·sinθ. Its side is 2*(size//2)+1,
    i.e. `size` for odd sizes. The DC component is removed so the kernel
    responds to structure rather than overall brightness.
    """
    radius = size // 2
    offsets = np.arange(-radius, radius + 1, dtype=np.float32)
    x = offsets[np.newaxis, :]      # varies along columns
    y = offsets[:, np.newaxis]      # varies along rows

    cos_t, sin_t = np.cos(theta), np.sin(theta)
    x_rot = x * cos_t + y * sin_t   # broadcasts to (side, side)
    y_rot = -x * sin_t + y * cos_t

    envelope = np.exp(-sigma * (x_rot**2 + y_rot**2))   # Re side: decay
    carrier = np.cos(omega * x_rot + phi)                # Im side: oscillation
    kernel = envelope * carrier

    kernel = kernel - kernel.mean()
    norm = np.sqrt((kernel ** 2).sum())
    if norm > 0:
        kernel = kernel / norm
    return kernel


def build_gabor_bank() -> torch.Tensor:
    """
    Stack every (σ, ω, θ, φ) kernel into a conv2d weight tensor.

    The ordering is σ-major, then ω, θ and φ (φ varies fastest). The result
    has shape (N_GABOR_CHANNELS, 1, K, K) and dtype float32.
    """
    kernels = [
        _make_gabor_kernel(sigma, omega, theta, phi)
        for sigma in SIGMA_VALUES
        for omega in OMEGA_VALUES
        for theta in THETA_VALUES
        for phi in PHASE_VALUES
    ]
    bank = torch.from_numpy(np.stack(kernels, axis=0))   # (N, K, K)
    return bank.float().unsqueeze(1)                      # (N, 1, K, K)


# Pre-built bank — computed once at import time
_GABOR_BANK: torch.Tensor = build_gabor_bank()


def extract_gabor_features(grid_2d: np.ndarray,
                           grid_size: int = 64) -> torch.Tensor:
    """
    Convolve a colour grid with every kernel of the Gabor bank.

    The grid is scaled to [0, 1] (colour / 9) and bilinearly resized to
    grid_size x grid_size when needed. It is then filtered with _GABOR_BANK
    using zero padding of KERNEL_SIZE // 2. Finally it is divided by the
    largest absolute response over all channels, so values lie in [-1, 1].

    Parameters
    ----------
    grid_2d   : np.ndarray (H, W) int — raw color grid
    grid_size : int — target output size (default 64, matching ActionModel)

    Returns
    -------
    torch.Tensor (N_GABOR_CHANNELS, grid_size, grid_size) float32
    """
    src_h, src_w = grid_2d.shape
    image = torch.from_numpy(grid_2d.astype(np.float32) / 9.0)[None, None]

    if (src_h, src_w) != (grid_size, grid_size):
        image = F.interpolate(image, size=(grid_size, grid_size),
                              mode='bilinear', align_corners=False)

    # All kernels in one conv call; [0] drops the batch dimension
    out = F.conv2d(image, _GABOR_BANK, padding=KERNEL_SIZE // 2)[0]

    peak = out.abs().max()
    return out / peak if peak > 0 else out


def channel_descriptions(offset: int = 56) -> List[str]:
    """Human-readable description of each Gabor channel.

    Parameters
    ----------
    offset : int
        Index of the first Gabor channel in the combined feature stack that
        the labels should refer to. The default is 56, which matches the
        56 axis channels described in the module notes. TinyAgent
        concatenates the bank after the 14-channel extract_features output,
        so pass offset=14 to label that tensor.

    Returns
    -------
    list of str, one per channel in bank order (σ, then ω, θ, φ).
    """
    descs = []
    for sigma in SIGMA_VALUES:
        for omega in OMEGA_VALUES:
            for theta in THETA_VALUES:
                theta_deg = int(theta * 180 / np.pi)
                for phi in PHASE_VALUES:
                    phase_name = 'cos' if phi == 0 else 'sin'
                    descs.append(
                        f"ch{offset + len(descs):3d}: Gabor σ={sigma:.1f} ω={omega:.1f} "
                        f"θ={theta_deg}° φ={phase_name}  "
                        f"[s={sigma:.1f}+{omega:.1f}i]"
                    )
    return descs


# ── S-plane visualization ─────────────────────────────────────────────────────

def splane_coverage_report(axis_channels: int = 56):
    """Print the s-plane coverage table.

    There is one row per σ and one column per ω. Each cell holds
    len(THETA_VALUES) * len(PHASE_VALUES) channels. `axis_channels` is the
    number of non-Gabor channels that the totals are reported against.
    It defaults to 56, as before; TinyAgent's own feature stack uses 14.
    """
    per_cell = len(THETA_VALUES) * len(PHASE_VALUES)   # same for every (σ, ω)
    cells = f"  {per_cell:3d}ch" * len(OMEGA_VALUES)
    print("s-plane coverage: σ (Re/Laplace) × ω (Im/Fourier)")
    print("="*55)
    print(f"{'':8}" + "".join(f"  ω={o:.1f}" for o in OMEGA_VALUES))
    for s in SIGMA_VALUES:
        print(f"σ={s:.1f}  {cells}  (×{len(THETA_VALUES)}θ ×{len(PHASE_VALUES)}φ)")
    print()
    print(f"Total Gabor channels: {N_GABOR_CHANNELS}")
    print(f"Existing axis channels: {axis_channels}")
    print(f"Combined total: {axis_channels + N_GABOR_CHANNELS}")




# ── Rendering ─────────────────────────────────────────────────────────────────


def _pil(fig):
    """Rasterise a Matplotlib figure to a PNG-backed PIL image, then close it.

    The figure is closed in a `finally`, so it is released even when saving
    or decoding fails. This keeps pyplot from accumulating open figures
    across frames.
    """
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf,format='png',dpi=80,bbox_inches='tight',
                        facecolor=fig.get_facecolor())
            buf.seek(0)
            # .copy() loads the pixel data so the image outlives the buffer
            return Image.open(buf).copy()
    finally:
        plt.close(fig)

def render_grid(grid,title='',highlight=None,mark_cell=None):
    """Draw a colour grid with per-cell value labels and return a PIL image.

    Parameters
    ----------
    grid      : 2-D array of colour ids 0-9 (None or empty -> returns None)
    title     : optional axis title
    highlight : iterable of (row, col, _) cells tinted red (e.g. pixel diffs)
    mark_cell : a single (row, col, _) cell outlined and starred in cyan
    """
    if grid is None or grid.size==0: return None
    # Colour ids whose digit label is drawn in white (black for the rest)
    white_label_ids=frozenset({0,1,2,3,5,6,9})
    H,W=grid.shape; cell=max(28,min(56,360//max(H,W)))
    fig,ax=plt.subplots(figsize=((W*cell+4)/72,(H*cell+22)/72),dpi=72)
    fig.patch.set_facecolor('#1e1e2e'); ax.set_facecolor('#1e1e2e')
    ax.imshow(grid,cmap=ARC_CMAP,vmin=0,vmax=9,interpolation='nearest',aspect='equal')
    for x in range(W+1): ax.axvline(x-.5,color='#444',lw=.5)
    for y in range(H+1): ax.axhline(y-.5,color='#444',lw=.5)
    for r in range(H):
        for c in range(W):
            v=int(grid[r,c])
            col='white' if v in white_label_ids else 'black'
            ax.text(c,r,str(v),ha='center',va='center',
                    fontsize=max(7,cell//5),color=col,
                    fontweight='bold',fontfamily='monospace')
    if highlight is not None:
        for r,c,_ in highlight:
            ax.add_patch(plt.Rectangle((c-.5,r-.5),1,1,
                fill=True,facecolor='#ff4444',alpha=0.35,lw=0))
    if mark_cell is not None:
        r,c,_=mark_cell
        ax.add_patch(plt.Rectangle((c-.5,r-.5),1,1,
            fill=False,edgecolor='#00ffff',lw=2.5))
        ax.plot(c,r,'*',color='#00ffff',markersize=max(8,cell//4))
    ax.set_xlim(-.5,W-.5); ax.set_ylim(H-.5,-.5); ax.axis('off')
    if title: ax.set_title(title,color='#cdd6f4',fontsize=9,pad=4)
    plt.tight_layout(pad=.3)
    return _pil(fig)

def render_hypothesis_panel(candidates,threshold=0.4):
    """Im side: bar chart of top hypotheses with confidence.

    Parameters
    ----------
    candidates : sequence of (name, grid, confidence). The first six are
                 drawn, and the first one is highlighted as the selection.
    threshold  : x position of the dashed reference line (default 0.4).
                 TinyAgent only clicks analytically at CONF_THRESHOLD (0.72),
                 so callers may pass that value instead.

    Returns a PIL image, or None when there are no candidates.
    """
    if not candidates: return None
    top=candidates[:6]
    names=[c[0] for c in top]; confs=[c[2] for c in top]
    fig,ax=plt.subplots(figsize=(5,2.2))
    fig.patch.set_facecolor('#1e1e2e'); ax.set_facecolor('#1e1e2e')
    colors=['#ffd700' if i==0 else '#4a9eff' for i in range(len(top))]
    # Reverse so the selected hypothesis ends up at the top of the barh chart
    bars=ax.barh(names[::-1],confs[::-1],color=colors[::-1],height=0.6)
    for bar,conf in zip(bars,confs[::-1]):
        ax.text(bar.get_width()+.01,bar.get_y()+bar.get_height()/2,
                f'{conf:.2f}',va='center',color='white',fontsize=8)
    ax.set_xlim(0,1.15); ax.axvline(threshold,color='#ff6666',lw=1,ls='--',alpha=0.7)
    ax.text(threshold+0.01,0,'threshold',color='#ff6666',fontsize=7,va='bottom')
    ax.tick_params(colors='#888',labelsize=8); ax.spines[:].set_visible(False)
    ax.set_title('Im side — hypothesis ranking  🟡=selected',
                 color='#cdd6f4',fontsize=9,pad=3)
    plt.tight_layout(pad=.4)
    return _pil(fig)

def render_action_bar(action_counts,total):
    """Horizontal bar chart of how often each action id has been taken.

    Parameters
    ----------
    action_counts : {action_id: count}
    total         : denominator used for the percentage bar lengths

    Each bar is labelled with its raw count, and each action keeps the same
    colour from frame to frame. Returns None when there is nothing to plot.
    """
    if not action_counts or total==0: return None
    action_ids=sorted(action_counts)
    labels=[f"A{k}" for k in action_ids]
    vals  =[action_counts[k] for k in action_ids]
    pcts  =[v/total*100 for v in vals]
    fig,ax=plt.subplots(figsize=(4,1.6))
    fig.patch.set_facecolor('#1e1e2e'); ax.set_facecolor('#1e1e2e')
    palette=['#4a9eff','#e05050','#50c050','#f5c400','#c060c0','#d07030']
    # Index the palette by action id (A1 -> palette[0]) rather than by bar
    # position, so e.g. A6 keeps its colour whether or not A1-A5 were used.
    colors=[palette[(int(k)-1)%len(palette)] for k in action_ids]
    bars=ax.barh(labels,pcts,color=colors,height=0.6)
    for bar,v,p in zip(bars,vals,pcts):
        ax.text(min(p+1,98),bar.get_y()+bar.get_height()/2,
                f'{v}',va='center',color='white',fontsize=8)
    ax.set_xlim(0,110); ax.tick_params(colors='#888',labelsize=8)
    ax.spines[:].set_visible(False)
    ax.set_title('Action frequency',color='#cdd6f4',fontsize=9,pad=3)
    plt.tight_layout(pad=.4)
    return _pil(fig)

def render_reward_chart(reward_history):
    """Bar chart of per-step shaped rewards, coloured by reward tier.

    Colours:
      * gold for rewards >= 5,
      * green for other positive rewards,
      * red for zero or negative rewards.
    Accepts any iterable, e.g. a list or deque. Returns a PIL image, or None
    for fewer than two points.
    """
    rewards=list(reward_history)
    if len(rewards)<2: return None
    fig,ax=plt.subplots(figsize=(5,1.6))
    fig.patch.set_facecolor('#1e1e2e'); ax.set_facecolor('#1e1e2e')
    colours=['#ffd700' if r>=5 else ('#50c050' if r>0 else '#e05050') for r in rewards]
    # One vectorised bar call instead of a separate ax.bar per step (up to 300)
    ax.bar(range(len(rewards)),rewards,color=colours,width=1,alpha=0.8)
    ax.axhline(0,color='#555',lw=0.5)
    ax.set_xlim(0,len(rewards))
    ax.tick_params(colors='#888',labelsize=7); ax.spines[:].set_visible(False)
    ax.set_title('Reward  🟡=level-up  🟢=change  🔴=dead',
                 color='#cdd6f4',fontsize=8,pad=3)
    plt.tight_layout(pad=.3)
    return _pil(fig)

def render_gabor_panel(grid):
    """
    Plot the four Gabor channels with the largest peak |response| on `grid`.

    Each panel shows one (σ, ω, θ, φ) cross-term channel on a fixed [-1, 1]
    colour scale and is titled with its s = σ + iω coordinates.
    Returns a PIL image, or None when grid is None.
    """
    if grid is None: return None
    responses = extract_gabor_features(grid)  # (72, 64, 64)
    peaks = responses.abs().amax(dim=(1, 2)).numpy()
    strongest = peaks.argsort()[-4:][::-1]    # descending by peak response

    fig, axes = plt.subplots(1, 4, figsize=(10, 2.2))
    fig.patch.set_facecolor('#1e1e2e')

    # One title per channel, in bank order (σ, ω, θ, φ — φ fastest)
    titles = [
        f"s={s:.1f}+{o:.1f}i th={int(t*180/np.pi)} {'cos' if p == 0 else 'sin'}"
        for s in SIGMA_VALUES
        for o in OMEGA_VALUES
        for t in THETA_VALUES
        for p in PHASE_VALUES
    ]

    for ax, ch in zip(axes, strongest):
        ax.set_facecolor('#0d0d1a')
        image = ax.imshow(responses[ch].numpy(), cmap='RdBu', vmin=-1, vmax=1,
                          interpolation='nearest', aspect='equal')
        ax.set_title(titles[ch], color='#cdd6f4', fontsize=7, pad=2)
        ax.axis('off')
        plt.colorbar(image, ax=ax, fraction=.06, pad=.02)

    fig.suptitle('Gabor s-plane responses (top 4 cross-term channels)',
                 color='#cdd6f4', fontsize=9, y=1.02)
    plt.tight_layout(pad=0.5)
    return _pil(fig)


# ── TinyAgent with Re/Im solver ───────────────────────────────────────────────

# read_board confidence required before TinyAgent overrides the CNN with an
# analytic ACTION6 click.
CONF_THRESHOLD = 0.72  # high bar — only act analytically when very sure

class TinyAgent:
    """Hybrid ARC-AGI-3 agent: an analytic Re/Im click solver plus a small CNN.

    On each step, `choose` does the following:
      * builds a (14 + N_GABOR_CHANNELS, 64, 64) feature tensor;
      * asks read_board for a candidate answer grid;
      * shapes a reward for the previous step and stores it;
      * either clicks the most urgent wrong cell (analytic path) or samples
        an action from the CNN policy.
    """

    def __init__(self):
        self.device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Reward-shaping constants; reset() leaves these untouched.
        self.level_up_reward=10.0; self.win_reward=50.0
        self.near_win_reward=2.0;  self.change_reward=0.1
        self.dead_penalty=-0.01;   self.candidate_win_reward=30.0
        self.reset()

    def _make_model(self):
        """Conv net mapping (14 + N_GABOR_CHANNELS) channels to 6 action logits."""
        return nn.Sequential(
            nn.Conv2d(14+N_GABOR_CHANNELS,32,3,padding=1),nn.ReLU(),
            nn.Conv2d(32,64,3,padding=1),nn.ReLU(),
            nn.Conv2d(64,128,3,padding=1),nn.ReLU(),
            nn.AdaptiveAvgPool2d(8),nn.Flatten(),
            nn.Linear(128*8*8,256),nn.ReLU(),
            nn.Linear(256,6),
        )

    def reset(self):
        """Re-create the model and optimiser and clear all per-episode state."""
        self.model=self._make_model().to(self.device)
        self.opt=torch.optim.Adam(self.model.parameters(),lr=1e-4)
        self.buf=[]; self.prev_feat=None; self.prev_action=None
        self.step_count=0; self.action_counts={}; self.prev_levels=0
        self.reward_history=deque(maxlen=300); self.level_history=[]
        self.prev_state=None; self.prev_candidate_dist=1.0
        self.explore_steps=20; self.click_attempts=0; self.click_successes=0
        self._last_was_click=False

    @staticmethod
    def _is_win(state_str):
        """True when a stringified game state denotes a win.

        Uses the same substring test as _run_agent ('WIN' in str(state)). An
        exact == 'WIN' comparison would miss enum reprs such as 'GameState.WIN'.
        """
        return state_str is not None and 'WIN' in state_str

    @staticmethod
    def _available_ids(available_actions):
        """Integer ids (<= 6) of the offered actions; all of 1-6 when none are given."""
        if not available_actions:
            return list(range(1,7))
        ids=(int(a.value if hasattr(a,'value') else a) for a in available_actions)
        return [m for m in ids if m<=6]

    @staticmethod
    def _action_probs(logits,avail):
        """Softmax of the 6 CNN logits restricted to the action ids in `avail`.

        If no id in 1..6 is available, all six actions are allowed. If the
        softmax degenerates (e.g. NaN logits), the allowed actions become
        uniform. Returns a float64 numpy array that sums to 1.
        """
        indices=[m-1 for m in avail if 1<=m<=6]
        if not indices:
            indices=list(range(6))
        masked=torch.full((6,),float('-inf'))
        for i in indices: masked[i]=logits[i]
        probs=torch.softmax(masked,dim=0).numpy().astype(np.float64)
        probs=np.nan_to_num(probs,nan=0.0)
        if probs.sum()==0: probs[indices]=1.0/len(indices)
        return probs/probs.sum()

    def choose(self,grid,available_actions=None,levels=0,state=None):
        """Pick the next action for `grid` and record the previous step's reward.

        Parameters
        ----------
        grid              : 2-D colour grid of the current frame
        available_actions : iterable of GameAction values or ints; None -> 1-6
        levels            : levels completed so far (a rise earns level_up_reward)
        state             : game state; a str() containing 'WIN' counts as a win

        Returns (action, meta). `action` is an arcengine GameAction when that
        package imports, otherwise a plain int. `meta` describes the decision
        for the UI; its 'source' is 'analytic' or 'cnn'.
        """
        feat=extract_features(grid).to(self.device)
        cur_state=str(state) if state else None

        # Axis channels + Gabor cross-term channels, shared by read_board and the CNN.
        # NOTE(review): this tensor has 14 + N_GABOR_CHANNELS channels, whereas
        # try_analytic_action documents read_board's features as (56,H,W).
        # Confirm that read_board indexes the channels it expects.
        gabor_feat=extract_gabor_features(grid,grid_size=64)
        full_feat=torch.cat([feat,gabor_feat.to(self.device)],dim=0)
        cand_answer,cand_conf,cand_reasoning,cand_signal=read_board(
            grid,full_feat.cpu().numpy())

        # Candidate proximity bonus: reward moving the grid towards the answer
        cand_bonus=0.0
        if cand_answer is not None and cand_conf>=0.35:
            curr_dist=(grid!=cand_answer).mean() if grid.shape==cand_answer.shape else 1.0
            if curr_dist==0.0:
                cand_bonus=self.candidate_win_reward
            elif curr_dist<self.prev_candidate_dist:
                cand_bonus=(self.prev_candidate_dist-curr_dist)*5.0
            self.prev_candidate_dist=curr_dist
        else:
            cand_answer=None

        # Store shaped experience for the previous (features, executed action)
        if self.prev_feat is not None:
            changed=not torch.equal(self.prev_feat,full_feat)
            if self._last_was_click:
                self.click_attempts+=1
                if changed: self.click_successes+=1
            just_won=self._is_win(cur_state) and not self._is_win(self.prev_state)
            level_up=levels>self.prev_levels
            if just_won:
                reward=self.win_reward+cand_bonus
                # Spread a decaying bonus over the last (up to) 5 stored steps
                for i in range(min(5,len(self.buf))):
                    idx=len(self.buf)-1-i
                    feat_i,act_i,rew_i=self.buf[idx]
                    self.buf[idx]=(feat_i,act_i,rew_i+self.near_win_reward*(1-i*0.15))
            elif level_up:
                reward=self.level_up_reward+cand_bonus
                self.level_history.append((self.step_count,levels))
            elif changed:
                reward=self.change_reward+cand_bonus
            else:
                reward=self.dead_penalty+cand_bonus
            self.reward_history.append(reward)
            self.buf.append((self.prev_feat,self.prev_action,reward))
            if len(self.buf)>500: self.buf.pop(0)
        self.prev_state=cur_state
        self.prev_levels=levels

        if self.step_count%10==0 and len(self.buf)>=16:
            self._train()

        # ── Im → Re bridge: read board → derive answer → click exact cell ──
        # Only fire after explore_steps of CNN exploration, and only while
        # clicks keep changing the frame often enough.
        analytic_action=None; analytic_meta={}
        click_rate=(self.click_successes/self.click_attempts
                    if self.click_attempts>10 else 1.0)
        if (cand_answer is not None
                and cand_conf>=CONF_THRESHOLD
                and self.step_count>self.explore_steps
                and click_rate>=0.20):
            diffs=pixel_diff(grid,cand_answer)
            if diffs:
                cell=most_urgent_diff(grid,cand_answer)
                if cell is not None:
                    r,c,tgt_color=cell
                    H,W=grid.shape
                    # Cell centre rescaled into 0..63 game coordinates
                    gy=min(63,max(0,int(r*64/H+32/H)))
                    gx=min(63,max(0,int(c*64/W+32/W)))
                    analytic_action=6
                    reasoning_str=' | '.join(cand_reasoning[:2]) if cand_reasoning else 'read_board'
                    analytic_meta={'x':gx,'y':gy,'cell':(r,c,tgt_color),
                                   'hypothesis':reasoning_str,'conf':cand_conf,
                                   'n_diffs':len(diffs),
                                   'candidates':[(reasoning_str,cand_answer,cand_conf)]}

        # ── CNN fallback: sample from the masked policy ──────────────────
        with torch.no_grad():
            logits=self.model(full_feat.unsqueeze(0)).squeeze(0).cpu()
        probs=self._action_probs(logits,self._available_ids(available_actions))
        cnn_action_idx=int(np.random.choice(6,p=probs))

        # Pick final action
        if analytic_action is not None:
            chosen_id=analytic_action
            meta=analytic_meta
            meta['source']='analytic'
        else:
            chosen_id=cnn_action_idx+1
            # Package read_board result for display even in CNN fallback
            cnn_cands=[(cand_signal, cand_answer, cand_conf)] if cand_answer is not None else []
            meta={'source':'cnn','probs':probs.tolist(),'candidates':cnn_cands,
                  'hypothesis':cand_reasoning[0][:60] if cand_reasoning else 'none',
                  'conf':cand_conf,'n_diffs':len(pixel_diff(grid,cand_answer)) if cand_answer is not None else 0}

        # Credit the next reward to the action actually executed (the analytic
        # click is ACTION6 -> index 5), not to the unused CNN sample.
        self.prev_feat=full_feat; self.prev_action=chosen_id-1
        self._last_was_click=(chosen_id==6)
        self.step_count+=1
        a_id=chosen_id
        self.action_counts[a_id]=self.action_counts.get(a_id,0)+1

        try:
            from arcengine import GameAction
            action=GameAction(a_id)
        except Exception:
            action=a_id

        if a_id==6 and 'x' in meta:
            try: action.set_data({'x':meta['x'],'y':meta['y']})
            except Exception: pass  # best effort: plain-int fallback has no set_data

        return action,meta

    def _train(self):
        """One BCE step on a random minibatch of up to 16 stored transitions.

        The target for each taken action's logit is its reward clamped to [0, 1].
        """
        import random
        batch=random.sample(self.buf,min(16,len(self.buf)))
        states =torch.stack([b[0] for b in batch]).to(self.device)
        actions=torch.tensor([b[1] for b in batch],dtype=torch.long,  device=self.device)
        rewards=torch.tensor([b[2] for b in batch],dtype=torch.float32,device=self.device)
        self.opt.zero_grad()
        logits=self.model(states)
        loss=TF.binary_cross_entropy_with_logits(
            logits.gather(1,actions.unsqueeze(1)).squeeze(1),
            torch.clamp(rewards,0,1))
        loss.backward(); self.opt.step()

# ── Session ───────────────────────────────────────────────────────────────────

_agent      = TinyAgent()              # single agent shared by every run
_stop_flag  = threading.Event()        # set to ask the worker loop to stop
_run_thread = None                     # current worker thread, if any
_frame_queue= queue.Queue(maxsize=60)  # worker -> UI messages, drained by pull_frame


def _publish_frame(item):
    """Enqueue `item` for pull_frame without ever blocking the worker thread.

    pull_frame drains the whole queue and renders only the newest message.
    So when the queue is full, the oldest entry is discarded to make room.
    Previously a 5 s blocking put raised queue.Full and aborted the run
    whenever the UI stopped polling.
    """
    while True:
        try:
            _frame_queue.put_nowait(item)
            return
        except queue.Full:
            try:
                _frame_queue.get_nowait()
            except queue.Empty:
                pass


def _run_agent(game_id,api_key,max_steps):
    """Worker-thread body: play `game_id` with the shared _agent.

    The loop runs for at most `max_steps` steps, or until _stop_flag is set
    or the state reports WIN / GAME_OVER. Each step publishes one UI
    message. A final {'done': ...} message follows; if anything raises,
    an {'error': ...} message is published instead.
    """
    try:
        import arc_agi  # inside try so a missing package is reported to the UI
        arc=arc_agi.Arcade(arc_api_key=api_key)
        env=arc.make(game_id,include_frame_data=True)
        frame=env.reset(); _agent.reset()
        prev_grid=None; step=0
        while not _stop_flag.is_set() and step<max_steps:
            if frame is None: break
            raw=np.array(frame.frame,dtype=np.int64)
            grid=raw[-1] if raw.ndim==3 else raw   # keep the last layer of 3-D frames
            avail=getattr(frame,'available_actions',None)
            levels=getattr(frame,'levels_completed',0)
            state=getattr(frame,'state',None)
            action,meta=_agent.choose(grid,avail,levels=levels,state=state)
            # Element-wise diff only makes sense when the grid size is unchanged
            same_shape=prev_grid is not None and prev_grid.shape==grid.shape
            diff=(grid!=prev_grid) if same_shape else None
            prev_grid=grid.copy()
            clicks=_agent.click_attempts
            _publish_frame({
                'grid':grid,'diff':diff,'step':step,
                'action':int(action.value if hasattr(action,'value') else action),
                'levels':levels,'state':str(state),
                'meta':meta,
                'counts':dict(_agent.action_counts),
                'click_rate':round(_agent.click_successes/clicks,2) if clicks>0 else None,
                'reward_history':list(_agent.reward_history),
                'grid_raw':grid.tolist(),
                'level_history':list(_agent.level_history),
            })
            state_str=str(state)
            if 'WIN' in state_str or 'GAME_OVER' in state_str: break
            try:
                from arcengine import GameAction as GA
                a_int=int(action.value if hasattr(action,'value') else action)
                sa=GA(a_int)
                if meta.get('x') is not None:
                    try: sa.set_data({'x':int(meta['x']),'y':int(meta['y'])})
                    except Exception: pass  # best effort: send without coordinates
                frame=env.step(sa)
            except Exception:
                # Last resort: try passing the agent's action object directly
                try: frame=env.step(action)
                except Exception: frame=None
            step+=1
            # Pause between steps, but wake immediately when a stop is requested
            _stop_flag.wait(0.08)
        _publish_frame({'done':True,'step':step,
                        'level_history':list(_agent.level_history)})
    except Exception as e:
        _publish_frame({'error':str(e)})

# ── Pull frame ────────────────────────────────────────────────────────────────

# Most recently rendered UI panels and status line. pull_frame re-serves
# these when the queue has nothing new, including _latest['gabor_img'],
# which it reads on those paths — so the key must exist from the start.
_latest={'grid_img':None,'hyp_img':None,'cand_img':None,
         'bar_img':None,'gabor_img':None,'reward_img':None,
         'status':'*Waiting...*'}

def _pull_frame_outputs():
    """Return the cached UI panels in pull_frame's output order.

    Uses .get so that a key missing from _latest yields None instead of
    raising KeyError.
    """
    return (_latest.get('grid_img'),_latest.get('hyp_img'),_latest.get('cand_img'),
            _latest.get('bar_img'),_latest.get('gabor_img'),_latest.get('reward_img'),
            _latest.get('status'))

def pull_frame():
    """Drain the frame queue and refresh the UI panels from the newest message.

    Only the last queued message is rendered; older ones are discarded.
    'error' and 'done' messages update just the status line. Every path
    returns the same 7-tuple:
    (grid_img, hyp_img, cand_img, bar_img, gabor_img, reward_img, status).
    """
    data=None
    while True:
        try: data=_frame_queue.get_nowait()
        except queue.Empty: break

    if data is None:
        return _pull_frame_outputs()

    if 'error' in data:
        _latest['status']=f"**Error:** {data['error']}"
        return _pull_frame_outputs()

    if data.get('done'):
        lh=data.get('level_history',[])
        _latest['status']=f"**Done** — {data['step']} steps | {len(lh)} levels completed"
        return _pull_frame_outputs()

    grid=data['grid']; meta=data['meta']; step=data['step']
    levels=data['levels']; state=data['state']; action=data['action']
    source=meta.get('source','cnn')

    # meta['candidates'] entries are (reasoning_str, grid, conf); keep those
    # that carry an actual grid, with the label shortened for display.
    disp_cands=[]
    for entry in meta.get('candidates',[]):
        if len(entry)==3 and isinstance(entry[1],np.ndarray):
            disp_cands.append((str(entry[0])[:40],entry[1],entry[2]))

    # On an analytic click, tint (up to 20) wrong cells and mark the clicked one
    mark_cell=None; highlight=None
    if source=='analytic' and 'cell' in meta and disp_cands:
        cand_grid=disp_cands[0][1]
        if cand_grid.shape==grid.shape:
            highlight=pixel_diff(grid,cand_grid)[:20]
            mark_cell=meta['cell']

    source_emoji='🧠' if source=='analytic' else '🎲'
    _latest['grid_img']=render_grid(
        grid,
        title=f"Step {step} | {source_emoji} A{action} | Levels {levels}",
        highlight=highlight,
        mark_cell=mark_cell)

    # Im side: hypothesis confidence bars
    _latest['hyp_img']=render_hypothesis_panel(disp_cands) if disp_cands else None

    # Re side: the candidate answer grid with the cells that still differ
    if disp_cands and disp_cands[0][1].shape==grid.shape:
        cname,cgrid,cconf=disp_cands[0]
        diffs=pixel_diff(grid,cgrid)
        _latest['cand_img']=render_grid(
            cgrid,
            title=f"Im answer: {cname[:35]} (conf={cconf:.2f}) — {len(diffs)} cells differ",
            highlight=diffs[:20])
    else:
        _latest['cand_img']=None

    _latest['bar_img']   =render_action_bar(data['counts'],sum(data['counts'].values()))
    _latest['reward_img']=render_reward_chart(data['reward_history'])
    grid_raw=np.array(data.get('grid_raw',grid.tolist()),dtype=np.int64)
    _latest['gabor_img']=render_gabor_panel(grid_raw)

    last_r=data['reward_history'][-1] if data['reward_history'] else 0
    r_emoji='🟡' if last_r>=5 else ('🟢' if last_r>0 else '🔴')
    hyp_str=(f"`{meta.get('hypothesis','?')}` conf={meta.get('conf',0):.2f} "
             f"→ click ({meta.get('x','?')},{meta.get('y','?')}) "
             f"[{meta.get('n_diffs','?')} cells wrong]"
             if source=='analytic'
             else f"CNN probs: {[round(p,2) for p in meta.get('probs',[])]}")

    _latest['status']=(
        f"{source_emoji} **{'Analytic (Re/Im)' if source=='analytic' else 'CNN fallback'}**"
        f" &nbsp;|&nbsp; Step {step} &nbsp;|&nbsp; Levels {levels}"
        f" &nbsp;|&nbsp; Reward {r_emoji} `{last_r:.2f}` &nbsp;|&nbsp; {state}\n\n"
        f"{hyp_str}")

    # Same 7 outputs as the early returns (gabor_img was previously dropped here)
    return _pull_frame_outputs()

# ── Handlers ──────────────────────────────────────────────────────────────────

def fetch_games(api_key):
    """Query the ARC API for its games and refresh the game dropdown.

    Returns (dropdown update, markdown status). The first game is
    preselected. On any failure the dropdown is emptied and the exception
    text is shown.
    """
    try:
        import arc_agi
        client = arc_agi.Arcade(arc_api_key=api_key)
        game_ids = [env.game_id for env in client.get_environments()]
        first = game_ids[0] if game_ids else None
        return (gr.Dropdown(choices=game_ids, value=first),
                f"Found **{len(game_ids)}** games.")
    except Exception as err:
        return gr.Dropdown(choices=[]), f"**Error:** {err}"

def start_agent(game_id,api_key,max_steps):
    global _run_thread,_stop_flag
    if not game_id: return "Select a game first."
    if not api_key: return "Enter your API key."
    _stop_flag.set()
    if _run_thread and _run_thread.is_alive(): _run_thread.join(timeout=3)
    while not _frame_queue.empty():
        try: _frame_queue.get_nowait()
        except: break
    _stop_flag.clear()
    _run_thread=threading.Thread(
        target=_run_agent,args=(game_id,api_key,int(max_steps)),daemon=True)
    _run_thread.start()
    return f"Agent started on **{game_id}** — 🧠 Re/Im analytic + 🎲 CNN fallback"

def stop_agent():
    """Request that the running agent halt, and report that to the UI.

    This only sets ``_stop_flag`` and returns immediately. It does not join
    the worker thread; presumably the worker polls the flag and exits on its
    own.
    """
    _stop_flag.set()
    message = "Agent stopped."
    return message

# ── UI ────────────────────────────────────────────────────────────────────────

# Gradio spectator page. Components are laid out top-to-bottom in the order
# they are created inside this Blocks context, so statement order here is the
# on-screen order.
with gr.Blocks(title="ARC-AGI-3 Re/Im Agent") as demo:

    gr.Markdown("""
# ARC-AGI-3 Re/Im Agent Spectator
**Im side** = bird's eye hypothesis (which transformation?) &nbsp;|&nbsp;
**Re side** = exact location (which cells to click?)

🧠 = analytic solver (Im picks hypothesis → Re pins cell → ACTION6 click)
🎲 = CNN fallback (when no hypothesis clears the confidence threshold)
""")

    # Row 1: API key (pre-filled from the ARC_API_KEY environment variable)
    # and the button that lists available games.
    with gr.Row():
        with gr.Column(scale=3):
            api_box=gr.Textbox(label="ARC API key",type="password",
                                value=os.environ.get("ARC_API_KEY",""),
                                placeholder="arc-key-... or set ARC_API_KEY secret")
        with gr.Column(scale=1):
            fetch_btn=gr.Button("Fetch games")

    # Row 2: game selector, step budget (20-500, default 150) and start/stop.
    with gr.Row():
        with gr.Column(scale=2):
            game_dd=gr.Dropdown(label="Game",choices=[])
        with gr.Column(scale=1):
            steps_sl=gr.Slider(label="Max steps",minimum=20,maximum=500,value=150,step=10)
        with gr.Column(scale=1):
            with gr.Row():
                start_btn=gr.Button("▶ Watch",variant="primary")
                stop_btn =gr.Button("■ Stop")

    # run_status: written by start/stop handlers and by the timer poll.
    # api_status: written by fetch_games.
    run_status=gr.Markdown("*Fetch games → select → Watch*")
    api_status=gr.Markdown()

    gr.Markdown("---")

    # Live image panels; all are output-only and refreshed by the timer below.
    with gr.Row():
        grid_img=gr.Image(label="Current frame  (🔴=wrong cells  ⭐=target click)",
                          type="pil",interactive=False,height=280)
        hyp_img =gr.Image(label="Im side — hypothesis ranking",
                          type="pil",interactive=False,height=280)

    with gr.Row():
        cand_img  =gr.Image(label="Im candidate — what the answer should look like",
                             type="pil",interactive=False,height=220)
        bar_img   =gr.Image(label="Action frequency",
                             type="pil",interactive=False,height=220)

    with gr.Row():
        gabor_img =gr.Image(label="Gabor s-plane — cross terms (σ>0, ω>0)",
                             type="pil",interactive=False,height=160)
        reward_img=gr.Image(label="Reward  🟡+50 WIN  🟡+10 level  🟢+0.1 change  🔴-0.01 dead",
                             type="pil",interactive=False,height=160)

    # Poll pull_frame every 1.0 s; its return values are mapped positionally
    # onto these 7 output components.
    # NOTE(review): the frame-render code directly above this UI section returns
    # a 6-tuple (grid, hyp, cand, bar, reward, status) with no gabor image,
    # while this list expects 7 values including gabor_img — confirm that
    # pull_frame actually returns 7 values in this order.
    timer=gr.Timer(value=1.0)
    timer.tick(pull_frame,
               outputs=[grid_img,hyp_img,cand_img,bar_img,gabor_img,reward_img,run_status])

    # Button wiring.
    fetch_btn.click(fetch_games,inputs=api_box,outputs=[game_dd,api_status])
    start_btn.click(start_agent,inputs=[game_dd,api_box,steps_sl],outputs=run_status)
    stop_btn.click(stop_agent,outputs=run_status)

    # Footer: explanation of the Re/Im approach shown to the viewer.
    gr.Markdown("""
---
**Re/Im duality in action:**
The Im side reads the whole board at once — symmetry maps, boundary contour, directional
flow — and ranks candidate transformations by confidence.
The Re side then diffs the current frame against the winning candidate and finds the exact
cell (boundary-first, following Cauchy's principle) that most needs fixing.
The agent emits ACTION6 at those precise coordinates instead of guessing randomly.
CNN fires only when no analytic hypothesis clears 0.40 confidence.
""")

# Start the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()