File size: 69,247 Bytes
a328241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import time
import pytorch_lightning as pl
from torch.optim import AdamW
from torchmetrics import MeanSquaredError, PearsonCorrCoef, SpearmanCorrCoef, R2Score

logger = logging.getLogger(__name__)

# ===================== VERSION STRING FOR CLUSTER VERIFICATION =====================
ARCH_VERSION = "2024-12-24-stability-fix"
print(f"[ARCH] architectures.py loaded: {ARCH_VERSION}")
# ====================================================================================

class Interp1d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y, xnew):
        is_flat = {}
        vals = {'x': x, 'y': y, 'xnew': xnew}
        for name, arr in vals.items():
            is_flat[name] = (arr.dim() == 1)
            if is_flat[name]:
                vals[name] = arr.unsqueeze(0)
        x_2d, y_2d, xnew_2d = vals['x'], vals['y'], vals['xnew']
        B, Nx = x_2d.shape

        # SAFETY: Handle edge case where sequence length is < 5
        if Nx < 5:
            # Return constant interpolation (repeat/average the values)
            ynew_2d = y_2d.mean(dim=1, keepdim=True).expand(-1, xnew_2d.shape[1])
            ctx.save_for_backward(x_2d, y_2d, xnew_2d, 
                                  torch.zeros_like(xnew_2d, dtype=torch.long),
                                  torch.zeros_like(xnew_2d))
            ctx.Nx_was_small = True
            if is_flat['x'] and is_flat['xnew']:
                ynew_2d = ynew_2d.squeeze(0)
            return ynew_2d
        
        ctx.Nx_was_small = False
        idx = torch.searchsorted(x_2d, xnew_2d, right=False) - 1
        idx = idx.clamp(min=0, max=Nx-2)

        xL = torch.gather(x_2d, 1, idx)
        xR = torch.gather(x_2d, 1, idx+1)
        yL = torch.gather(y_2d, 1, idx)
        yR = torch.gather(y_2d, 1, idx+1)

        denom = (xR - xL)
        denom[denom == 0] = 1e-12
        t = (xnew_2d - xL)/denom
        ynew_2d = yL + (yR - yL)*t

        ctx.save_for_backward(x_2d, y_2d, xnew_2d, idx, t)
        if is_flat['x'] and is_flat['xnew']:
            ynew_2d = ynew_2d.squeeze(0)
        return ynew_2d

    @staticmethod
    def backward(ctx, grad_out):
        x_2d, y_2d, xnew_2d, idx, t = ctx.saved_tensors
        grad_x = grad_y = grad_xnew = None
        
        # Handle edge case from forward
        if getattr(ctx, 'Nx_was_small', False):
            if ctx.needs_input_grad[1]:
                grad_y = grad_out.sum(dim=-1, keepdim=True).expand_as(y_2d)
            return grad_x, grad_y, grad_xnew
            
        if ctx.needs_input_grad[1]:
            grad_y_tmp = torch.zeros_like(y_2d)
            idxp1 = (idx + 1).clamp(max=y_2d.shape[1] - 1)  # SAFETY: clamp idxp1
            
            # Calculate gradients
            grad_yL = (1.0 - t) * grad_out
            grad_yR = t * grad_out
            
            # Ensure consistent dtype between source and destination tensors
            grad_yL = grad_yL.to(dtype=grad_y_tmp.dtype)
            grad_yR = grad_yR.to(dtype=grad_y_tmp.dtype)
            
            grad_y_tmp.scatter_add_(1, idx, grad_yL)
            grad_y_tmp.scatter_add_(1, idxp1, grad_yR)
            grad_y = grad_y_tmp
        return grad_x, grad_y, grad_xnew

def interp1d(x, y, xnew):
    return Interp1d.apply(x, y, xnew)


class SWE_Pooling(nn.Module):
    """
    Sliced-Wasserstein Embedding (SWE) Pooling.
    Maps token embeddings [B, L, d_in] => [B, num_slices].
    """
    def __init__(self, d_in, num_slices, num_ref_points, freeze_swe=False):
        super().__init__()
        self.num_slices = num_slices
        self.num_ref_points = num_ref_points

        ref = torch.linspace(-1,1,num_ref_points).unsqueeze(1).repeat(1,num_slices)
        self.reference = nn.Parameter(ref, requires_grad=not freeze_swe)

        self.theta = nn.utils.weight_norm(nn.Linear(d_in, num_slices, bias=False), dim=0)
        self.theta.weight_g.data = torch.ones_like(self.theta.weight_g.data)
        self.theta.weight_g.requires_grad=False
        nn.init.normal_(self.theta.weight_v)

        self.weight = nn.Linear(num_ref_points,1,bias=False)

        if freeze_swe:
            self.theta.weight_v.requires_grad=False
            self.reference.requires_grad=False

    def forward(self, X, mask=None):
        B, N, D = X.shape
        device = X.device

        X_slices = self.theta(X)  # => [B,N,num_slices]
        X_slices_sorted, _ = torch.sort(X_slices, dim=1)

        x_coord = torch.linspace(0,1,N,device=device).unsqueeze(0).repeat(B*self.num_slices,1)
        X_flat = X_slices_sorted.permute(0,2,1).reshape(B*self.num_slices, N)
        xnew = torch.linspace(0,1,self.num_ref_points,device=device).unsqueeze(0).repeat(B*self.num_slices,1)

        y_intp = interp1d(x_coord, X_flat, xnew)
        X_slices_sorted_interp = y_intp.view(B,self.num_slices,self.num_ref_points).permute(0,2,1)

        r_expanded = self.reference.expand_as(X_slices_sorted_interp)
        embeddings = (r_expanded - X_slices_sorted_interp).permute(0,2,1)  # => [B,num_slices,num_ref_points]
        weighted = self.weight(embeddings).sum(dim=-1)  # => [B, num_slices]
        return weighted


#############################################################################
#             Enhanced Mutation-Aware SWE_Pooling                           #
#############################################################################

class MutationAwareSWEPooling(nn.Module):
    """
    Enhanced Sliced-Wasserstein Embedding Pooling with explicit mutation position handling.
    Maps token embeddings [B, L, d_in] => [B, num_slices].
    
    - Preserves mutation position information through weighted aggregation
    - Uses mutation positions to guide the pooling process
    """
    def __init__(self, d_in, num_slices, num_ref_points, freeze_swe=False):
        super().__init__()
        self.num_slices = num_slices
        self.num_ref_points = num_ref_points
        self.d_esm = 1152  # FIXED: Hardcode to 1152 to avoid channel indexing bugs with context window

        # Standard SWE components
        ref = torch.linspace(-1, 1, num_ref_points).unsqueeze(1).repeat(1, num_slices)
        self.reference = nn.Parameter(ref, requires_grad=not freeze_swe)

        # For ESM features (without mutation channel)
        self.theta = nn.utils.weight_norm(nn.Linear(self.d_esm, num_slices, bias=False), dim=0)
        self.theta.weight_g.data = torch.ones_like(self.theta.weight_g.data)
        self.theta.weight_g.requires_grad = False
        nn.init.normal_(self.theta.weight_v)

        # Mutation-aware components
        self.mutation_importance = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, num_slices),
            nn.Sigmoid()
        )
        
        # Position-specific weighting for each slice
        self.pos_weighting = nn.Linear(1, num_slices, bias=False)

        # FIXED: Direct projection of mutation channel (the 1153rd dim)
        # Previously, this channel was only used as a multiplier, meaning if ESM 
        # features had no diff, the result was 0. Now we project it directly.
        self.mut_projection = nn.Linear(1, num_slices, bias=False)
        
        # Final weighting
        self.weight = nn.Linear(num_ref_points, 1, bias=False)

        if freeze_swe:
            self.theta.weight_v.requires_grad = False
            self.reference.requires_grad = False

    def forward(self, X, mask=None):
        """
        X: [B, L, d_in] where d_in = d_esm + 1 (mutation channel)
        mask: [B, L] boolean mask
        """
        B, N, D = X.shape
        device = X.device
        
        # Check if using context window (additional channel)
        use_context = (D > self.d_esm + 1)
        
        if use_context:
            # Split ESM features and channels
            X_esm = X[:, :, :-2]  # [B, L, d_esm]
            X_mut = X[:, :, -2:-1]  # [B, L, 1] - mutation indicator
        else:
            # Split ESM features and mutation channel
            X_esm = X[:, :, :-1]  # [B, L, d_esm]
            X_mut = X[:, :, -1:]  # [B, L, 1] - mutation indicator
        
        # Regular SWE on ESM features
        X_slices = self.theta(X_esm)  # => [B, L, num_slices]
        
        # Compute mutation importance weights
        mut_weights = self.mutation_importance(X_mut)  # [B, L, num_slices]
        
        # Create position encodings (0 to 1 for each sequence)
        pos_tensor = torch.linspace(0, 1, N, device=device).view(1, N, 1).expand(B, N, 1)
        pos_weights = self.pos_weighting(pos_tensor)  # [B, L, num_slices]
        
        # Apply mutation-aware weighting to slices
        # Use both mutation indicator and position information
        # BUGFIX: We ALSO add the projected mutation signal directly. 
        # This ensures the model 'sees' the 1.0 signal even if ESM features are identical.
        X_slices = X_slices * (1.0 + mut_weights * pos_weights) + self.mut_projection(X_mut)
        
        # Sort slices as in standard SWE
        X_slices_sorted, _ = torch.sort(X_slices, dim=1)
        
        # Continue with standard SWE interpolation
        x_coord = torch.linspace(0, 1, N, device=device).unsqueeze(0).repeat(B*self.num_slices, 1)
        X_flat = X_slices_sorted.permute(0, 2, 1).reshape(B*self.num_slices, N)
        xnew = torch.linspace(0, 1, self.num_ref_points, device=device).unsqueeze(0).repeat(B*self.num_slices, 1)
        
        y_intp = interp1d(x_coord, X_flat, xnew)
        X_slices_sorted_interp = y_intp.view(B, self.num_slices, self.num_ref_points).permute(0, 2, 1)
        
        r_expanded = self.reference.expand_as(X_slices_sorted_interp)
        embeddings = (r_expanded - X_slices_sorted_interp).permute(0, 2, 1)  # => [B, num_slices, num_ref_points]
        weighted = self.weight(embeddings).sum(dim=-1)  # => [B, num_slices]
        
        return weighted


#############################################################################
#            Mutation-Specific Cross-Attention with Gating                  #
#############################################################################

class MutationSpecificAttention(nn.Module):
    """
    Enhanced cross-attention that explicitly handles mutation positions with gating.
    - Keeps ESM embeddings (1152-dim) and mutation channel separate
    - Uses specific mutation positions to guide attention
    - Preserves position-specific information throughout the network
    - Adds gating mechanism to control information flow
    - Includes memory-efficient computation for long sequences
    """
    def __init__(self, d_model=1152, num_heads=4, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        # Core attention for ESM embeddings only (1152-dim)
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        
        # Absolute position encoding
        self.pos_encoder = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, d_model)
        )
        
        # Mutation-position specific attention
        self.mut_encoder = nn.Sequential(
            nn.Linear(2, 64),  # Input: [mut_binary, position_normalized]
            nn.ReLU(),
            nn.Linear(64, num_heads)
        )
        
        # Gating mechanism to control information flow
        self.gate = nn.Sequential(
            nn.Linear(d_model*2, d_model),
            nn.Sigmoid()
        )
        
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_model, d_model)
        
    def split_heads(self, x):
        """Split the last dimension into (heads, head_dim)"""
        batch_size, seq_len, _ = x.shape
        x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]
    
    def merge_heads(self, x):
        """Merge the (heads, head_dim) into d_model"""
        batch_size, _, seq_len, _ = x.shape
        x = x.permute(0, 2, 1, 3)  # [batch, seq_len, heads, head_dim]
        return x.reshape(batch_size, seq_len, self.d_model)
    
    def forward(self, q_esm, k_esm, v_esm, q_mut, k_mut, mask=None):
        """
        Inputs:
            q_esm, k_esm, v_esm: ESM embeddings [B, L, 1152]
            q_mut, k_mut: Mutation information [B, L, 1]
            mask: Optional attention mask [B, L] or [B, 1, L]
        """
        batch_size = q_esm.shape[0]
        q_len, k_len = q_esm.shape[1], k_esm.shape[1]
        
        # Create position tensors (0-1 range for each sequence)
        q_pos = torch.linspace(0, 1, q_len, device=q_esm.device).view(1, -1, 1).expand(batch_size, q_len, 1)
        k_pos = torch.linspace(0, 1, k_len, device=k_esm.device).view(1, -1, 1).expand(batch_size, k_len, 1)
        
        # Position encoding
        q_pos_enc = self.pos_encoder(q_pos)
        k_pos_enc = self.pos_encoder(k_pos)
        
        # Add position encodings to ESM features
        q_esm_pos = q_esm + q_pos_enc
        k_esm_pos = k_esm + k_pos_enc
        
        # Process core ESM embeddings with position information
        q = self.split_heads(self.query(q_esm_pos))  # [B, h, q_len, d_k]
        k = self.split_heads(self.key(k_esm_pos))    # [B, h, k_len, d_k]
        v = self.split_heads(self.value(v_esm))      # [B, h, v_len, d_v]
        
        # Concatenate mutation indicator with position
        q_mut_pos = torch.cat([q_mut, q_pos], dim=-1)  # [B, q_len, 2]
        k_mut_pos = torch.cat([k_mut, k_pos], dim=-1)  # [B, k_len, 2]
        
        # Encode position-aware mutation information
        q_mut_enc = self.mut_encoder(q_mut_pos)  # [B, q_len, num_heads]
        k_mut_enc = self.mut_encoder(k_mut_pos)  # [B, k_len, num_heads]
        
        # Standard scaled dot-product attention
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)  # [B, h, q_len, k_len]
        
        # Create mutation-position attention bias
        # This explicitly boosts attention between positions based on mutation status
        mut_attn_bias = torch.matmul(
            q_mut_enc.permute(0, 2, 1).unsqueeze(3),  # [B, h, q_len, 1]
            k_mut_enc.permute(0, 2, 1).unsqueeze(2)   # [B, h, 1, k_len]
        )  # [B, h, q_len, k_len]
        
        # Apply mutation bias to attention scores
        # This makes mutations and their surrounding context attend more to each other
        scores = scores + mut_attn_bias
        
        # Apply mask if provided
        if mask is not None:
            # Fix mask dimension to match scores
            # mask shape should be [B, L] or [B, 1, L]
            if mask.dim() == 2:  # [B, L]
                # For keys mask [B, k_len] -> [B, 1, 1, k_len]
                mask = mask.unsqueeze(1).unsqueeze(2)
            elif mask.dim() == 3 and mask.size(1) == 1:  # [B, 1, L]
                # For keys mask [B, 1, k_len] -> [B, 1, 1, k_len]
                mask = mask.unsqueeze(2)
                
            # Expand mask to match scores dimensions
            # [B, 1, 1, k_len] -> [B, h, q_len, k_len]
            mask = mask.expand(-1, scores.size(1), scores.size(2), -1)
            
            # FIXED: Use -1e4 instead of -1e9 to avoid half-precision overflow
            scores = scores.masked_fill(mask == 0, -1e4)
        
        # Apply softmax and dropout
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        # Apply attention to values
        context = torch.matmul(attention_weights, v)  # [B, h, q_len, d_v]
        context = self.merge_heads(context)  # [B, q_len, d_model]
        attn_output = self.out_proj(context)
        
        # Apply gating mechanism (new addition)
        # Concatenate the original query with the attention output to determine the gate
        gate_input = torch.cat([q_esm, attn_output], dim=-1)
        gate_value = self.gate(gate_input)
        
        # Memory optimization for long sequences
        # Processing the gating operation in chunks to prevent OOM errors
        if q_len > 1000:  # Only use chunking for very long sequences
            chunk_size = 500
            output_chunks = []
            
            for i in range(0, q_len, chunk_size):
                end_idx = min(i + chunk_size, q_len)
                # Process chunks
                chunk_gate = gate_value[:, i:end_idx, :]
                chunk_attn = attn_output[:, i:end_idx, :]
                chunk_q = q_esm[:, i:end_idx, :]
                
                # Apply gating equation to this chunk
                chunk_output = chunk_gate * chunk_attn + (1 - chunk_gate) * chunk_q
                output_chunks.append(chunk_output)
            
            # Combine chunks
            output = torch.cat(output_chunks, dim=1)
        else:
            # Original operation for shorter sequences
            output = gate_value * attn_output + (1 - gate_value) * q_esm
        
        return output


class MutationSpecificCrossAttentionBlock(nn.Module):
    """
    Cross-attention block with explicit mutation position handling.
    Each block processes ESM embeddings and mutation channels separately,
    with special emphasis on mutation positions.
    """
    def __init__(self, d_model=1152, num_heads=4, ffn_dim=2048, dropout=0.1):
        super().__init__()
        # Mutation-aware cross attention
        self.attn_c12 = MutationSpecificAttention(d_model, num_heads, dropout)
        self.attn_c21 = MutationSpecificAttention(d_model, num_heads, dropout)
        
        # Layer normalization for ESM embeddings
        self.norm_c1 = nn.LayerNorm(d_model)
        self.norm_c2 = nn.LayerNorm(d_model)
        
        # FFN for ESM embeddings
        self.ffn_c1 = nn.Sequential(
            nn.Linear(d_model, ffn_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_dim, d_model)
        )
        self.ffn_c2 = nn.Sequential(
            nn.Linear(d_model, ffn_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_dim, d_model)
        )
        self.norm_ffn_c1 = nn.LayerNorm(d_model)
        self.norm_ffn_c2 = nn.LayerNorm(d_model)
        
        # Mutation importance update layer
        self.mut_update = nn.Sequential(
            nn.Linear(d_model + 1, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, c1_esm, c1_mut, c2_esm, c2_mut, mask1=None, mask2=None):
        """
        Inputs:
            c1_esm, c2_esm: ESM embeddings [B, L, 1152]
            c1_mut, c2_mut: Mutation channels [B, L, 1]
            mask1, mask2: Optional masks
        """
        # c1->c2 cross-attention
        c1_attn = self.attn_c12(c1_esm, c2_esm, c2_esm, c1_mut, c2_mut, mask2)
        c1_out = self.norm_c1(c1_esm + c1_attn)
        
        # c2->c1 cross-attention
        c2_attn = self.attn_c21(c2_esm, c1_esm, c1_esm, c2_mut, c1_mut, mask1)
        c2_out = self.norm_c2(c2_esm + c2_attn)
        
        # Feed-forward
        c1_ffn = self.ffn_c1(c1_out)
        c1_ffn_out = self.norm_ffn_c1(c1_out + c1_ffn)
        
        c2_ffn = self.ffn_c2(c2_out)
        c2_ffn_out = self.norm_ffn_c2(c2_out + c2_ffn)
        
        # Update mutation importance based on attention output
        # This creates a feedback loop where mutation effect is refined
        c1_mut_in = torch.cat([c1_ffn_out, c1_mut], dim=-1)
        c2_mut_in = torch.cat([c2_ffn_out, c2_mut], dim=-1)
        
        # Stabilized update: convex combination ensures values stay in [0, 1]
        # Avoids exponential decay (old bug) and unbounded growth (additive bug)
        c1_mut_updated = 0.9 * c1_mut + 0.1 * self.mut_update(c1_mut_in)
        c2_mut_updated = 0.9 * c2_mut + 0.1 * self.mut_update(c2_mut_in)
        
        return c1_ffn_out, c2_ffn_out, c1_mut_updated, c2_mut_updated


class MutationSpecificCrossAttentionStack(nn.Module):
    """
    Stack of Mutation-Specific Cross-Attention blocks.
    Emphasizes mutation positions throughout the network.
    Now includes gradient checkpointing for memory efficiency.
    """
    def __init__(self, d_model=1152, num_heads=4, ffn_dim=2048, dropout=0.1, num_layers=2):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.use_checkpoint = True  # Enable gradient checkpointing by default
        
        self.blocks = nn.ModuleList([
            MutationSpecificCrossAttentionBlock(
                d_model=d_model,
                num_heads=num_heads,
                ffn_dim=ffn_dim,
                dropout=dropout
            ) for _ in range(num_layers)
        ])

    def forward(self, c1, c2, mask1=None, mask2=None):
        """
        Process protein chains with mutation-specific attention.
        c1, c2: [B, L, D] where D can be 1153 (original) or 1154 (with context window)
        Uses gradient checkpointing when in training mode to save memory.
        """
        # Check input dimension to determine if context window is used
        d_in = c1.shape[2]
        use_context = (d_in > 1153)
        
        if use_context:
            # Split ESM embeddings from mutation+context channels
            c1_esm, c1_channels = c1[:, :, :-2], c1[:, :, -2:]  # [B, L, 1152], [B, L, 2]
            c2_esm, c2_channels = c2[:, :, :-2], c2[:, :, -2:]  # [B, L, 1152], [B, L, 2]
            
            # Extract mutation channel (first channel)
            c1_mut = c1_channels[:, :, :1]  # [B, L, 1]
            c2_mut = c2_channels[:, :, :1]  # [B, L, 1]
        else:
            # Original behavior - just split ESM and mutation
            c1_esm, c1_mut = c1[:, :, :-1], c1[:, :, -1:]  # [B, L, 1152], [B, L, 1]
            c2_esm, c2_mut = c2[:, :, :-1], c2[:, :, -1:]  # [B, L, 1152], [B, L, 1]
        
        # Process through attention blocks with optional checkpointing
        for block in self.blocks:
            # Use gradient checkpointing in training mode for memory efficiency
            if self.use_checkpoint and self.training:
                # Define helper function for checkpointing that handles None masks
                def create_checkpoint_fn(block_fn):
                    def checkpoint_fn(esm1, mut1, esm2, mut2, has_mask1, has_mask2, mask1_val, mask2_val):
                        # Conditionally use the masks based on the has_mask flags
                        m1 = mask1_val if has_mask1 else None
                        m2 = mask2_val if has_mask2 else None
                        return block_fn(esm1, mut1, esm2, mut2, m1, m2)
                    return checkpoint_fn
                
                # Convert None masks to flags and dummy tensors for checkpointing
                has_mask1 = mask1 is not None
                has_mask2 = mask2 is not None
                mask1_val = mask1 if has_mask1 else torch.zeros(1, device=c1_esm.device)
                mask2_val = mask2 if has_mask2 else torch.zeros(1, device=c1_esm.device)
                
                # Apply checkpointing
                c1_esm, c2_esm, c1_mut, c2_mut = torch.utils.checkpoint.checkpoint(
                    create_checkpoint_fn(block),
                    c1_esm, c1_mut, c2_esm, c2_mut,
                    torch.tensor(has_mask1, device=c1_esm.device),
                    torch.tensor(has_mask2, device=c1_esm.device),
                    mask1_val, mask2_val
                )
            else:
                c1_esm, c2_esm, c1_mut, c2_mut = block(c1_esm, c1_mut, c2_esm, c2_mut, mask1, mask2)
        
        # Recombine with appropriate channels
        if use_context:
            # Need to preserve the context channel
            context_channels_c1 = c1_channels[:, :, 1:]  # [B, L, 1]
            context_channels_c2 = c2_channels[:, :, 1:]  # [B, L, 1]
            c1_out = torch.cat([c1_esm, c1_mut, context_channels_c1], dim=-1)  # [B, L, 1154]
            c2_out = torch.cat([c2_esm, c2_mut, context_channels_c2], dim=-1)  # [B, L, 1154]
        else:
            # Original behavior
            c1_out = torch.cat([c1_esm, c1_mut], dim=-1)  # [B, L, 1153]
            c2_out = torch.cat([c2_esm, c2_mut], dim=-1)  # [B, L, 1153]
        
        return c1_out, c2_out


#############################################################################
#        AffinityPredictor with Improved Memory Efficiency                  #
#############################################################################

class AffinityPredictor(nn.Module):
    """
    Enhanced AffinityPredictor with explicit mutation position handling.
    embedding_method => "difference", "cosine", "cross_attention", or "cross_attention_swe".
    """
    def __init__(
        self,
        input_dim=1153,  # 1152 (ESM) + 1 (mutation)
        latent_dim=1024,
        num_slices=1024,
        num_ref_points=128,
        dropout_rate=0.2,
        freeze_swe=False,
        embedding_method="difference",
        normalize_difference=False,
        num_hidden_layers=2,
        # cross-attn
        num_cross_attn_layers=2,
        num_attention_heads=4,
        cross_ffn_dim=2048,
    ):
        super().__init__()
        self.embedding_method = embedding_method.lower()
        self.normalize_difference = normalize_difference
        self.input_dim = input_dim
        
        # ESM dimension (without mutation channel)
        self.esm_dim = input_dim - 1  # 1152
        
        # Define cross-attention stack if needed
        self.cross_stack = None
        if "cross_attention" in self.embedding_method:
            self.cross_stack = MutationSpecificCrossAttentionStack(
                d_model=self.esm_dim,  # 1152
                num_heads=num_attention_heads,
                ffn_dim=cross_ffn_dim,
                dropout=dropout_rate,
                num_layers=num_cross_attn_layers
            )
        
        # Enhanced Mutation-Aware SWE Pooling
        self.swe_pooling = None
        if self.embedding_method in ["difference", "cosine", "cross_attention_swe"]:
            # For SWE, we use the full input_dim (1153)
            self.swe_pooling = MutationAwareSWEPooling(
                d_in=input_dim,
                num_slices=num_slices,
                num_ref_points=num_ref_points,
                freeze_swe=freeze_swe
            )
        
        # Define aggregator MLP in-dimensions
        if self.embedding_method == "cosine":
            in_features = 1
        else:
            in_features = num_slices  # difference or cross_attention_swe => [B, num_slices]
        
        # Add projection layer for cross_attention to avoid dynamic creation
        self.cross_attn_projection = None
        if self.embedding_method == "cross_attention":
            cross_proj_in = input_dim  # Full dimension including mutation channel
            cross_proj_out = in_features
            self.cross_attn_projection = nn.Linear(cross_proj_in, cross_proj_out, bias=False)
        
        # Final MLP
        layers = []
        current_dim = in_features
        for _ in range(num_hidden_layers):
            layers.append(nn.Linear(current_dim, latent_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            current_dim = latent_dim
        layers.append(nn.Linear(current_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, chain1, chain1_mask, chain2, chain2_mask):
        """
        chain1, chain2 => [B, L, input_dim] (1153 or 1154 with context)
        """
        if "cross_attention" in self.embedding_method:
            # Process through mutation-specific cross-attention
            c1_out, c2_out = self.cross_stack(chain1, chain2, chain1_mask, chain2_mask)
            
            if self.embedding_method == "cross_attention_swe":
                # Apply enhanced SWE pooling
                rep1 = self.swe_pooling(c1_out, chain1_mask)  # [B, num_slices]
                rep2 = self.swe_pooling(c2_out, chain2_mask)  # [B, num_slices]
                
                # Difference aggregator
                diff = rep1 - rep2
                if self.normalize_difference:
                    diff = F.normalize(diff, p=2, dim=1)
                    
                # Final prediction
                preds = self.mlp(diff).squeeze(-1)
                
                return preds
                
            elif self.embedding_method == "cross_attention":
                # Use mutation-weighted pooling
                # Extract mutation channel to guide pooling
                d_in = c1_out.shape[2]
                use_context = (d_in > 1153)
                if use_context:
                    c1_mut = c1_out[:, :, -2:-1]  # [B, L, 1]
                    c2_mut = c2_out[:, :, -2:-1]  # [B, L, 1]
                else:
                    c1_mut = c1_out[:, :, -1:]  # [B, L, 1]
                    c2_mut = c2_out[:, :, -1:]  # [B, L, 1]
                
                # Weighted pooling - gives higher weight to mutated positions
                c1_weights = F.softmax(c1_mut * 10, dim=1)  # Sharpen weights
                c2_weights = F.softmax(c2_mut * 10, dim=1)
                
                c1_pool = torch.sum(c1_out * c1_weights, dim=1)  # [B, 1153/1154]
                c2_pool = torch.sum(c2_out * c2_weights, dim=1)  # [B, 1153/1154]
                
                # Create difference representation
                diff = c1_pool - c2_pool
                if self.normalize_difference:
                    diff = F.normalize(diff, p=2, dim=1)
                
                # Use pre-defined projection layer instead of creating one dynamically
                if self.cross_attn_projection is not None:
                    diff = self.cross_attn_projection(diff)
                
                preds = self.mlp(diff).squeeze(-1)
                
                return preds
        
        elif self.embedding_method == "cosine":
            # Enhanced SWE => [B, num_slices]
            rep1 = self.swe_pooling(chain1, chain1_mask)
            rep2 = self.swe_pooling(chain2, chain2_mask)
            sim = F.cosine_similarity(rep1, rep2, dim=1).unsqueeze(-1)
            out = self.mlp(sim).squeeze(-1)
            
            return out
        
        else:  # "difference"
            rep1 = self.swe_pooling(chain1, chain1_mask)
            rep2 = self.swe_pooling(chain2, chain2_mask)
            diff = rep1 - rep2
            if self.normalize_difference:
                diff = F.normalize(diff, p=2, dim=1)
            out = self.mlp(diff).squeeze(-1)
            
            return out


class AffinityPredictionModel(pl.LightningModule):
    """
    Lightning wrapper for training. Siamese logic in training.py
    """
    def __init__(self, predictor: AffinityPredictor, learning_rate=1e-4):
        super().__init__()
        self.predictor = predictor
        self.learning_rate = learning_rate

        self.loss_fn = nn.MSELoss()
        self.pearson_corr = PearsonCorrCoef()
        self.spearman_corr = SpearmanCorrCoef()
        self.r2_score = R2Score()
        self.mse_metric = MeanSquaredError()

    def forward(self, chain1, chain1_mask, chain2, chain2_mask):
        return self.predictor(chain1, chain1_mask, chain2, chain2_mask)

    def training_step(self, batch, batch_idx):
        pass

    def validation_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        steps = self.trainer.estimated_stepping_batches
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=self.learning_rate, total_steps=steps
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}


# Enhanced AffinityPredictor with Two-Head Architecture
# Add this to architectures.py

class DualHeadAffinityPredictor(nn.Module):
    """
    Enhanced AffinityPredictor with explicit two-head architecture.
    Simultaneously processes mutant and wildtype proteins to predict both ΔG and ΔΔG.
    
    embedding_method => "difference", "cosine", "cross_attention", or "cross_attention_swe".
    """
    def __init__(
        self,
        input_dim=1153,  # 1152 (ESM) + 1 (mutation)
        latent_dim=1024,
        num_slices=1024,
        num_ref_points=128,
        dropout_rate=0.2,
        freeze_swe=False,
        embedding_method="difference",
        normalize_difference=False,
        num_hidden_layers=2,
        # cross-attn
        num_cross_attn_layers=2,
        num_attention_heads=4,
        cross_ffn_dim=2048,
        use_dual_head=True,  # Enable dual-head by default
        ddg_signal_gain=1.0,  # Initial gain for ddG signal
        ddg_signal_multiplier=20.0, # FIXED: multiplier for ddG signal (vnew65.0)
    ):
        super().__init__()
        self.embedding_method = embedding_method.lower()
        self.normalize_difference = normalize_difference
        self.input_dim = input_dim
        self.use_dual_head = use_dual_head
        self._ddg_log_counter = 0
        
        # DEBUG: Confirm this version is running
        print(f"[MODEL INIT] DualHeadAffinityPredictor created: version={ARCH_VERSION}, dual_head={use_dual_head}, method={self.embedding_method}")
        
        # ESM dimension (without mutation channel)
        self.esm_dim = input_dim - 1  # 1152
        
        # Define cross-attention stack if needed
        self.cross_stack = None
        if "cross_attention" in self.embedding_method:
            self.cross_stack = MutationSpecificCrossAttentionStack(
                d_model=self.esm_dim,  # 1152
                num_heads=num_attention_heads,
                ffn_dim=cross_ffn_dim,
                dropout=dropout_rate,
                num_layers=num_cross_attn_layers
            )
        
        # Enhanced Mutation-Aware SWE Pooling
        self.swe_pooling = None
        if self.embedding_method in ["difference", "cosine", "cross_attention_swe"]:
            # For SWE, we use the full input_dim (1153)
            self.swe_pooling = MutationAwareSWEPooling(
                d_in=input_dim,
                num_slices=num_slices,
                num_ref_points=num_ref_points,
                freeze_swe=freeze_swe
            )
        
        # Define aggregator MLP in-dimensions
        if self.embedding_method == "cosine":
            in_features = 1
        else:
            in_features = num_slices  # difference or cross_attention_swe => [B, num_slices]
        
        # Add projection layer for cross_attention to avoid dynamic creation
        self.cross_attn_projection = None
        if self.embedding_method == "cross_attention":
            cross_proj_in = input_dim  # Full dimension including mutation channel
            cross_proj_out = in_features
            self.cross_attn_projection = nn.Linear(cross_proj_in, cross_proj_out, bias=False)
        
        # Define dG head (main prediction head)
        layers = []
        current_dim = in_features
        for _ in range(num_hidden_layers):
            layers.append(nn.Linear(current_dim, latent_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            current_dim = latent_dim
        layers.append(nn.Linear(current_dim, 1))
        self.dg_mlp = nn.Sequential(*layers)
        
        # Define ΔΔG head for direct prediction
        # FIX (vnew64.0): Shallow 2-layer MLP + residual skip connection.
        # DDGACT diagnostics from vnew62-63 showed 7-layer ReLU network causes 
        # progressive variance collapse (std: 0.1 → 7e-05). Each ReLU zeros ~50%
        # of activations, so 7 layers → 0.5^7 = 0.8% signal survival.
        # Solution: (1) 2 layers only, (2) skip connection preserves raw input signal.
        if self.use_dual_head:
            # Shallow nonlinear pathway (2 layers)
            self.ddg_hidden = nn.Sequential(
                nn.Linear(in_features, latent_dim),
                nn.GELU(),  # GELU instead of ReLU - no zero-capping, smoother gradients
                nn.Linear(latent_dim, latent_dim),
                nn.GELU(),
            )
            # Output projection
            self.ddg_out = nn.Linear(latent_dim, 1)
            # Skip connection: project input directly to output dimension
            self.ddg_skip = nn.Linear(in_features, 1)
        
        # Source type embedding for conditional inference
        # User-friendly types for inference:
        #   0 = "mutant"    - Single mutant predictions (most common)
        #   1 = "wt_pairs"  - Wildtype pairs with absolute binding affinity
        #   2 = "antibody"  - Antibody-antigen binding (CDR-focused)
        self.source_type_embedding = nn.Embedding(3, 32)  # 32-dim embedding
        self.source_type_projection = nn.Linear(in_features + 32, in_features)  # Project back to in_features

        # FIXED: Restoring the historical 'learnable gain' strategy
        # This allows the model to amplify the ddG signal early in Stage B
        self.ddg_signal_gain = nn.Parameter(torch.tensor(float(ddg_signal_gain)))
        self.ddg_signal_multiplier = float(ddg_signal_multiplier)

    def _extract_features(self, chain1, chain1_mask, chain2, chain2_mask):
        """
        Extract feature representation for a protein complex.
        Returns a vector representation suitable for prediction.
        """
        if "cross_attention" in self.embedding_method:
            # Process through mutation-specific cross-attention
            c1_out, c2_out = self.cross_stack(chain1, chain2, chain1_mask, chain2_mask)
            
            if self.embedding_method == "cross_attention_swe":
                # Apply enhanced SWE pooling
                rep1 = self.swe_pooling(c1_out, chain1_mask)  # [B, num_slices]
                rep2 = self.swe_pooling(c2_out, chain2_mask)  # [B, num_slices]
                
                # Difference aggregator
                diff = rep1 - rep2
                if self.normalize_difference:
                    diff = F.normalize(diff, p=2, dim=1)
                return diff
                
            elif self.embedding_method == "cross_attention":
                # Use mutation-weighted pooling
                # Extract mutation channel to guide pooling
                d_in = c1_out.shape[2]
                use_context = (d_in > 1153)
                if use_context:
                    c1_mut = c1_out[:, :, -2:-1]  # [B, L, 1]
                    c2_mut = c2_out[:, :, -2:-1]  # [B, L, 1]
                else:
                    c1_mut = c1_out[:, :, -1:]  # [B, L, 1]
                    c2_mut = c2_out[:, :, -1:]  # [B, L, 1]
                
                # Weighted pooling - gives higher weight to mutated positions
                c1_weights = F.softmax(c1_mut * 10, dim=1)  # Sharpen weights
                c2_weights = F.softmax(c2_mut * 10, dim=1)
                
                c1_pool = torch.sum(c1_out * c1_weights, dim=1)  # [B, 1153/1154]
                c2_pool = torch.sum(c2_out * c2_weights, dim=1)  # [B, 1153/1154]
                
                # Create difference representation
                diff = c1_pool - c2_pool
                if self.normalize_difference:
                    diff = F.normalize(diff, p=2, dim=1)
                
                # Use pre-defined projection layer instead of creating one dynamically
                if self.cross_attn_projection is not None:
                    diff = self.cross_attn_projection(diff)
                
                return diff
        
        elif self.embedding_method == "cosine":
            # Enhanced SWE => [B, num_slices]
            rep1 = self.swe_pooling(chain1, chain1_mask)
            rep2 = self.swe_pooling(chain2, chain2_mask)
            sim = F.cosine_similarity(rep1, rep2, dim=1).unsqueeze(-1)
            return sim
        
        else:  # "difference"
            rep1 = self.swe_pooling(chain1, chain1_mask)
            rep2 = self.swe_pooling(chain2, chain2_mask)
            diff = rep1 - rep2
            if self.normalize_difference:
                diff = F.normalize(diff, p=2, dim=1)
            return diff
    
    def _extract_residue_features(self, chain1, chain1_mask, chain2, chain2_mask):
        """
        Extract RESIDUE-LEVEL features (before pooling) for computing differences.
        Used for ddG to preserve mutation-specific information.
        
        Returns:
            c1_out, c2_out: [B, L1, D] and [B, L2, D] attended residue features
        """
        if "cross_attention" in self.embedding_method:
            c1_out, c2_out = self.cross_stack(chain1, chain2, chain1_mask, chain2_mask)
            return c1_out, c2_out
        else:
            # For non-cross-attention methods, return inputs directly
            return chain1, chain2
    
    def forward(self, mut_chain1, mut_chain1_mask, mut_chain2, mut_chain2_mask, 
            wt_chain1=None, wt_chain1_mask=None, wt_chain2=None, wt_chain2_mask=None,
            source_type_ids=None):
        """
        Dual-head forward method that can handle both modes:
        1. Standard mode: Just predict dG for mutant complex
        2. Dual-head mode: Predict both dG and direct ddG when wildtype is provided
        
        For ddG: Uses RESIDUE-LEVEL differences before pooling to preserve mutation info.
        ddG = ddg_mlp(pool(mut_features - wt_features)) instead of 
              ddg_mlp(pool(mut_features) - pool(wt_features))
        
        Args:
            source_type_ids: Optional[Tensor] of shape [B], values 0/1/2 for conditioning
        
        Returns:
            If wildtype inputs are None or use_dual_head=False:
                Returns mutant dG prediction only
            Else:
                Returns tuple of (mutant_dG, direct_ddG_prediction)
        """
        # ============== OPTIMIZED: Cache residue features ==============
        # Get mutant RESIDUE-LEVEL features first (used for both dG and ddG)
        mut_c1_res, mut_c2_res = self._extract_residue_features(
            mut_chain1, mut_chain1_mask, mut_chain2, mut_chain2_mask)
        
        # Pool for dG prediction (reuses cached residue features)
        if "cross_attention_swe" in self.embedding_method:
            rep1 = self.swe_pooling(mut_c1_res, mut_chain1_mask)
            rep2 = self.swe_pooling(mut_c2_res, mut_chain2_mask)
            mut_features = rep1 - rep2
            if self.normalize_difference:
                mut_features = F.normalize(mut_features, p=2, dim=1, eps=1e-8)
        else:
            # Fallback pooling
            mut_features = self._extract_features(mut_chain1, mut_chain1_mask, mut_chain2, mut_chain2_mask)
        
        # ============== SOURCE TYPE CONDITIONING ==============
        # Apply source type conditioning if provided
        if source_type_ids is not None:
            # Get source type embedding [B, 32]
            src_emb = self.source_type_embedding(source_type_ids)
            # Concatenate with features and project back
            conditioned_features = torch.cat([mut_features, src_emb], dim=-1)
            mut_features = self.source_type_projection(conditioned_features)
        
        # Predict dG for mutant
        dg_pred = self.dg_mlp(mut_features).squeeze(-1)
        
        # If no wildtype or dual head is disabled, just return mutant dG
        if not self.use_dual_head or wt_chain1 is None or wt_chain2 is None:
            return dg_pred
        
        # ============== RESIDUE-LEVEL ddG COMPUTATION ==============
        # Get wildtype RESIDUE-LEVEL features (mutant already cached above)
        wt_c1_res, wt_c2_res = self._extract_residue_features(
            wt_chain1, wt_chain1_mask, wt_chain2, wt_chain2_mask)
        
        
        # (Debug logging removed - was polluting training output)
        
        # Compute RESIDUE-LEVEL differences BEFORE pooling
        # This preserves mutation-specific changes at each position
        # Handle sequence length differences by taking minimum length
        L_c1 = min(mut_c1_res.shape[1], wt_c1_res.shape[1])
        L_c2 = min(mut_c2_res.shape[1], wt_c2_res.shape[1])
        
        c1_diff = mut_c1_res[:, :L_c1, :] - wt_c1_res[:, :L_c1, :]  # [B, L1, D]
        c2_diff = mut_c2_res[:, :L_c2, :] - wt_c2_res[:, :L_c2, :]  # [B, L2, D]
        
        # Update masks for truncated length
        c1_diff_mask = mut_chain1_mask[:, :L_c1] if mut_chain1_mask is not None else None
        c2_diff_mask = mut_chain2_mask[:, :L_c2] if mut_chain2_mask is not None else None
        
        # [DIFF CHECK] Diagnostic Logging
        # Measure cosine similarity at the approximate mutation site to check for embedding collapse.
        if self._ddg_log_counter % 200 == 1:
            with torch.no_grad():
                # Extract center position
                mid_idx = L_c1 // 2
                
                # Vectors at midpoint
                v_mut = mut_c1_res[:, mid_idx, :]
                v_wt = wt_c1_res[:, mid_idx, :]
                
                # Cosine similarity
                cos_sim = F.cosine_similarity(v_mut, v_wt, dim=1).mean().item()
                diff_norm = (v_mut - v_wt).norm(dim=1).mean().item()
                
                logger.info(f"[DIFF CHECK] Batch {self._ddg_log_counter}: CosSim at mid={cos_sim:.5f}, DiffNorm={diff_norm:.5f}")
                
        # NOW pool the differences
        if self.swe_pooling is not None:
            # =================================================================
            # HYBRID POOLING: Global SWE + Local Mutation-Site-Centric (vnew37.0)
            # This ensures local mutation signals are not diluted by global pooling
            # =================================================================
            
            # Concatenate chain differences along sequence dimension
            combined_diff = torch.cat([c1_diff, c2_diff], dim=1)  # [B, L_comb, D=1153]
            if c1_diff_mask is not None and c2_diff_mask is not None:
                combined_mask = torch.cat([c1_diff_mask, c2_diff_mask], dim=1)
            else:
                combined_mask = None

            # A. Global Component: Standard SWE pooling (capture global stability context)
            global_diff = self.swe_pooling(combined_diff, combined_mask) # [B, num_slices]

            # B. Local Component: Mutation-Site-Centric Pooling (MSCP)
            # CRITICAL FIX (v49.0): Extract indicator from RAW INPUT chains, NOT cross-attention output!
            # The cross-attention stack applies 0.9 convex combination at each layer (5 layers = 0.9^5 = 59% decay).
            # Using mut_chain1/mut_chain2 (raw inputs) instead of mut_c1_res/mut_c2_res (diluted outputs).
            
            # Determine if we have context window (1154-dim) or standard (1153-dim)
            d_raw = mut_chain1.shape[2]
            use_context = (d_raw > 1153)
            
            # Extract RAW indicator from INPUT chains (before cross-attention!)
            if use_context:
                c1_mut_raw = mut_chain1[:, :L_c1, -2:-1]
                c2_mut_raw = mut_chain2[:, :L_c2, -2:-1]
            else:
                c1_mut_raw = mut_chain1[:, :L_c1, -1:]
                c2_mut_raw = mut_chain2[:, :L_c2, -1:]
            
            mut_indicator = torch.cat([c1_mut_raw, c2_mut_raw], dim=1)  # [B, L_comb, 1]
            
            # Stability: Clamp indicator to be non-negative and bounded.
            # In case attention stack produced weird values, we force them back to [0, 2]
            mut_indicator = mut_indicator.clamp(min=0.0, max=2.0)
            
            # Weighted average focusing ONLY on the mutation sites
            # use 0.001 epsilon to avoid nan for WT samples (mut_sum=0)
            mut_sum = mut_indicator.sum(dim=1).clamp(min=1e-3)
            
            # MSCP Calculation with additional stability guard
            mscp_esm = (combined_diff[:, :, :1152] * mut_indicator).sum(dim=1) / mut_sum # [B, 1152]
            mscp_mut = (mut_indicator * mut_indicator).sum(dim=1) / mut_sum # [B, 1] (should be ~1.0)
            
            # Project local delta into the same slice space as global features 
            # using the SHARED theta projection and mut_projection from SWE
            # This ensures consistent representation between global and local paths
            local_diff_esm = self.swe_pooling.theta(mscp_esm)
            local_diff_mut = self.swe_pooling.mut_projection(mscp_mut)
            local_diff = local_diff_esm + local_diff_mut  # [B, num_slices]

            # C. Combine: Residual-style addition + Gain
            # FIXED: Apply signal_gain AFTER normalization so it actually has effect
            # Previously, L2-norm undid the scaling entirely.
            diff_multiplier = getattr(self, "ddg_signal_multiplier", 20.0)
            diff_features = (global_diff + local_diff) * diff_multiplier
            
            # ===================================================================
            # SIGNAL FLOW LOGGING: Track local vs global contribution
            if not hasattr(self, '_ddg_log_counter'):
                self._ddg_log_counter = 0
            self._ddg_log_counter += 1
            
            should_log = (self._ddg_log_counter % 200 == 1)
            if should_log:
                g_mag = global_diff.abs().mean().item()
                l_mag = local_diff.abs().mean().item()
                logger.info(f"[DDG SIGNAL] Batch {self._ddg_log_counter}: Global_mag={g_mag:.4f}, Local_mag={l_mag:.4f}")
                #region agent log
                try:
                    # Inspect what the model thinks the mutation indicator is (both tail channels)
                    d_raw_dbg = int(mut_chain1.shape[2])
                    # indicator candidate stats on chain1/chain2 for last and second-last channels
                    def _chan_stats(x):
                        return {
                            "min": float(x.min().item()),
                            "max": float(x.max().item()),
                            "mean": float(x.float().mean().item()),
                            "std": float(x.float().std().item()),
                        }
                    c1_last = _chan_stats(mut_chain1[:, :L_c1, -1])
                    c2_last = _chan_stats(mut_chain2[:, :L_c2, -1])
                    c1_last2 = _chan_stats(mut_chain1[:, :L_c1, -2]) if d_raw_dbg >= 1154 else None
                    c2_last2 = _chan_stats(mut_chain2[:, :L_c2, -2]) if d_raw_dbg >= 1154 else None
                    mut_sum_dbg = float(mut_sum.mean().item()) if "mut_sum" in locals() else None
                    payload = {
                        "sessionId": "debug-session",
                        "runId": "pre-fix",
                        "hypothesisId": "F",
                        "location": "architectures.py:DualHeadAffinityPredictor:mscp_indicator_debug",
                        "message": "Indicator channel stats (last vs second-last) to detect double-indicator / wrong channel selection",
                        "data": {
                            "ddg_log_counter": int(self._ddg_log_counter),
                            "d_raw": d_raw_dbg,
                            "use_context_flag": bool(use_context),
                            "c1_last": c1_last,
                            "c2_last": c2_last,
                            "c1_last2": c1_last2,
                            "c2_last2": c2_last2,
                            "mut_sum_mean": mut_sum_dbg,
                        },
                        "timestamp": int(time.time() * 1000),
                    }
                    with open("/Users/supantha/Documents/code_v2/protein/.cursor/debug.log", "a") as f:
                        f.write(json.dumps(payload, default=str) + "\n")
                    logger.info(f"[AGENTLOG MSCP] d_raw={d_raw_dbg} use_context={use_context} c1_last={c1_last} c1_last2={c1_last2} mut_sum_mean={mut_sum_dbg}")
                except Exception:
                    pass
                #endregion
            
            if self.normalize_difference:
                diff_features = F.normalize(diff_features, p=2, dim=1, eps=1e-8)
            
            # FIXED: Apply signal_gain AFTER normalization so it actually scales the output
            diff_features = diff_features * self.ddg_signal_gain
        else:
            # Fallback: mean pooling of differences
            if c1_diff_mask is not None:
                c1_diff = c1_diff * c1_diff_mask.unsqueeze(-1).float()
                c1_pool = c1_diff.sum(dim=1) / c1_diff_mask.sum(dim=1, keepdim=True).clamp(min=1)
            else:
                c1_pool = c1_diff.mean(dim=1)
            if c2_diff_mask is not None:
                c2_diff = c2_diff * c2_diff_mask.unsqueeze(-1).float()
                c2_pool = c2_diff.sum(dim=1) / c2_diff_mask.sum(dim=1, keepdim=True).clamp(min=1)
            else:
                c2_pool = c2_diff.mean(dim=1)
            diff_features = c1_pool - c2_pool
            if self.normalize_difference:
                diff_features = F.normalize(diff_features, p=2, dim=1)
        
        # DEBUG: Check for NaNs/Infs in diff_features
        if torch.isnan(diff_features).any() or torch.isinf(diff_features).any():
            print(f"[DEBUG MODEL] NaN/Inf in diff_features! Shape: {diff_features.shape}")
            if c1_diff_mask is not None:
                print(f"  c1_mask sum: {c1_diff_mask.sum(dim=1).min().item()}")
            print(f"  c1_diff nan: {torch.isnan(c1_diff).any().item()}")
            print(f"  c2_diff nan: {torch.isnan(c2_diff).any().item()}")

        # ============== SOURCE TYPE CONDITIONING FOR DDG ==============
        # Apply same source conditioning to diff_features for ddG prediction
        # This allows ddG head to learn different behaviors for different data sources
        if source_type_ids is not None:
            src_emb = self.source_type_embedding(source_type_ids)
            conditioned_diff = torch.cat([diff_features, src_emb], dim=-1)
            diff_features = self.source_type_projection(conditioned_diff)

        # MSCP Hybrid: Skip 20x gain since MSCP provides raw, un-diluted signal
        # vnew64.0: Shallow 2-layer GELU + skip connection to preserve variance
        ddg_hidden_out = self.ddg_hidden(diff_features)
        ddg_pred = (self.ddg_out(ddg_hidden_out) + self.ddg_skip(diff_features)).squeeze(-1)

        #region agent log
        # Diagnose "train variance but eval constant" which often indicates dropout-only variance or head collapse.
        try:
            if should_log:
                # Diff feature stats across the batch
                df = diff_features.detach()
                df_mean = float(df.mean().item()) if df.numel() else None
                df_std = float(df.std().item()) if df.numel() else None
                df_abs_mean = float(df.abs().mean().item()) if df.numel() else None
                # per-sample spread: average std over features (helps detect "all samples identical")
                df_per_sample_std = float(df.float().std(dim=1).mean().item()) if df.dim() == 2 and df.shape[0] > 0 else None

                # ddg_pred stats (across batch)
                p = ddg_pred.detach()
                p_mean = float(p.mean().item()) if p.numel() else None
                p_std = float(p.std().item()) if p.numel() else None

                # If dropout is active (training), a second forward pass should differ.
                p2_std = None
                p_diff_std = None
                if self.training:
                    p2 = (self.ddg_out(self.ddg_hidden(diff_features)) + self.ddg_skip(diff_features)).squeeze(-1).detach()
                    p2_std = float(p2.std().item()) if p2.numel() else None
                    p_diff_std = float((p2 - p).std().item()) if p2.numel() else None

                # Weight/bias norms (to detect collapse to near-zero weights or bias-only prediction)
                lin_layers = [m for m in self.ddg_hidden.modules() if isinstance(m, nn.Linear)] if hasattr(self, "ddg_hidden") else []
                w0 = lin_layers[0] if len(lin_layers) > 0 else None
                wL = lin_layers[-1] if len(lin_layers) > 0 else None
                w0_norm = float(w0.weight.detach().norm().item()) if w0 is not None else None
                wL_norm = float(wL.weight.detach().norm().item()) if wL is not None else None
                bL_norm = float(wL.bias.detach().norm().item()) if (wL is not None and wL.bias is not None) else None

                payload = {
                    "sessionId": "debug-session",
                    "runId": "pre-fix",
                    "hypothesisId": "I",
                    "location": "architectures.py:DualHeadAffinityPredictor:ddg_head_eval_vs_train",
                    "message": "ddG head collapse vs dropout-only variance diagnostics",
                    "data": {
                        "ddg_log_counter": int(self._ddg_log_counter),
                        "model_training": bool(self.training),
                        "normalize_difference": bool(getattr(self, "normalize_difference", False)),
                        "ddg_signal_gain": float(self.ddg_signal_gain.detach().item()) if hasattr(self, "ddg_signal_gain") else None,
                        "diff_features_mean": df_mean,
                        "diff_features_std": df_std,
                        "diff_features_abs_mean": df_abs_mean,
                        "diff_features_per_sample_std_mean": df_per_sample_std,
                        "ddg_pred_mean": p_mean,
                        "ddg_pred_std": p_std,
                        "ddg_pred2_std": p2_std,
                        "ddg_pred_repeat_diff_std": p_diff_std,
                        "ddg_w0_norm": w0_norm,
                        "ddg_wL_norm": wL_norm,
                        "ddg_bL_norm": bL_norm,
                    },
                    "timestamp": int(time.time() * 1000),
                }
                # Log first (before file write that may fail on cluster)
                logger.info(
                    f"[AGENTLOG DDGHEAD] train={self.training} df_std={df_std:.4f} df_ps_std={df_per_sample_std:.4f} "
                    f"pred_std={p_std:.4f} pred_mean={p_mean:.4f} rep_diff_std={p_diff_std if p_diff_std is not None else 'NA'} "
                    f"w0={w0_norm:.2f} wL={wL_norm:.2f} bL={bL_norm if bL_norm is not None else 'NA'}"
                )

                # Layerwise activation trace to locate where variance collapses (ReLU dead / dropout-only variance)
                layer_stats = []
                try:
                    with torch.no_grad():
                        x = diff_features.detach()
                        for li, layer in enumerate(self.ddg_hidden):
                            x = layer(x)
                            st = {
                                "i": int(li),
                                "t": layer.__class__.__name__,
                                "mean": float(x.mean().item()) if x.numel() else None,
                                "std": float(x.std().item()) if x.numel() else None,
                            }
                            if isinstance(layer, nn.ReLU):
                                st["zero_frac"] = float((x == 0).float().mean().item()) if x.numel() else None
                            layer_stats.append(st)
                    # Compact summary for logs (first 2 + last 2 layers)
                    compact = (layer_stats[:2] + (["..."] if len(layer_stats) > 4 else []) + layer_stats[-2:])
                    logger.info(f"[AGENTLOG DDGACT] train={self.training} layers={compact}")
                except Exception:
                    layer_stats = []
                # File write may fail on cluster - that's OK
                try:
                    payload["data"]["ddg_layer_stats"] = layer_stats
                    with open("/Users/supantha/Documents/code_v2/protein/.cursor/debug.log", "a") as f:
                        f.write(json.dumps(payload, default=str) + "\n")
                except Exception:
                    pass
        except Exception:
            pass
        #endregion

        # DEBUG: Check ddg_pred
        if torch.isnan(ddg_pred).any() or torch.isinf(ddg_pred).any():
             print(f"[DEBUG MODEL] NaN/Inf in ddg_pred!")
        
        return dg_pred, ddg_pred


class DualHeadAffinityPredictionModel(pl.LightningModule):
    """
    Lightning wrapper for dual-head training.
    """
    def __init__(self, predictor: DualHeadAffinityPredictor, learning_rate=1e-4, ddg_loss_weight=1.0):
        super().__init__()
        self.predictor = predictor
        self.learning_rate = learning_rate
        self.ddg_loss_weight = ddg_loss_weight

        self.loss_fn = nn.MSELoss()
        self.pearson_corr = PearsonCorrCoef()
        self.spearman_corr = SpearmanCorrCoef()
        self.r2_score = R2Score()
        self.mse_metric = MeanSquaredError()
        
        # Save hyperparameters for checkpointing
        self.save_hyperparameters(ignore=['predictor'])

    def forward(self, mut_chain1, mut_chain1_mask, mut_chain2, mut_chain2_mask, 
                wt_chain1=None, wt_chain1_mask=None, wt_chain2=None, wt_chain2_mask=None,
                source_type_ids=None):
        return self.predictor(mut_chain1, mut_chain1_mask, mut_chain2, mut_chain2_mask,
                             wt_chain1, wt_chain1_mask, wt_chain2, wt_chain2_mask,
                             source_type_ids=source_type_ids)

    def training_step(self, batch, batch_idx):
        # Mutant data
        (c1, m1, c2, m2, y_mut) = batch["mutant"]
        
        # Wildtype data with valid mask
        (cw1, w1m, cw2, w2m, y_wt) = batch["wildtype"]
        has_wt = batch["has_wt"]
        
        if self.predictor.use_dual_head and has_wt.sum() > 0:
            # For samples with wildtype available, use dual-head prediction
            valid_samples = has_wt.bool()
            
            # Get predictions for valid samples
            dg_pred, ddg_pred = self(
                c1[valid_samples], m1[valid_samples], 
                c2[valid_samples], m2[valid_samples],
                cw1[valid_samples], w1m[valid_samples], 
                cw2[valid_samples], w2m[valid_samples]
            )
            
            # Calculate losses for valid samples
            dg_loss = self.loss_fn(dg_pred, y_mut[valid_samples])
            
            # Calculate true ddG as difference between mutant and wildtype dG
            true_ddg = y_mut[valid_samples] - y_wt[valid_samples]
            ddg_loss = self.loss_fn(ddg_pred, true_ddg)
            
            # Combined loss with weighting
            loss = dg_loss + self.ddg_loss_weight * ddg_loss
            
            # Process remaining samples (without wildtype) with standard prediction
            if (~valid_samples).sum() > 0:
                standard_dg_pred = self(
                    c1[~valid_samples], m1[~valid_samples], 
                    c2[~valid_samples], m2[~valid_samples]
                )
                standard_loss = self.loss_fn(standard_dg_pred, y_mut[~valid_samples])
                
                # Add to total loss, weighted by proportion of samples
                n_valid = valid_samples.sum()
                n_total = len(valid_samples)
                loss = (n_valid / n_total) * loss + ((n_total - n_valid) / n_total) * standard_loss
        else:
            # Standard prediction for all samples
            dg_pred = self(c1, m1, c2, m2)
            loss = self.loss_fn(dg_pred, y_mut)
        
        # Log metrics
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # Similar to training_step but with more comprehensive metrics
        (c1, m1, c2, m2, y_mut) = batch["mutant"]
        (cw1, w1m, cw2, w2m, y_wt) = batch["wildtype"]
        has_wt = batch["has_wt"]
        
        # Store predictions and targets for all samples
        all_dg_preds = []
        all_dg_targets = []
        all_ddg_preds = []
        all_ddg_targets = []
        
        if self.predictor.use_dual_head and has_wt.sum() > 0:
            valid_samples = has_wt.bool()
            
            # Dual-head prediction for samples with wildtype
            dg_pred, ddg_pred = self(
                c1[valid_samples], m1[valid_samples], 
                c2[valid_samples], m2[valid_samples],
                cw1[valid_samples], w1m[valid_samples], 
                cw2[valid_samples], w2m[valid_samples]
            )
            
            # Calculate true ddG
            true_ddg = y_mut[valid_samples] - y_wt[valid_samples]
            
            # Store predictions and targets
            all_dg_preds.append(dg_pred)
            all_dg_targets.append(y_mut[valid_samples])
            all_ddg_preds.append(ddg_pred)
            all_ddg_targets.append(true_ddg)
            
            # Process remaining samples with standard prediction
            if (~valid_samples).sum() > 0:
                standard_dg_pred = self(
                    c1[~valid_samples], m1[~valid_samples], 
                    c2[~valid_samples], m2[~valid_samples]
                )
                all_dg_preds.append(standard_dg_pred)
                all_dg_targets.append(y_mut[~valid_samples])
        else:
            # Standard prediction for all samples
            dg_pred = self(c1, m1, c2, m2)
            all_dg_preds.append(dg_pred)
            all_dg_targets.append(y_mut)
            
            # For samples with wildtype, calculate implicit ddG
            if has_wt.sum() > 0:
                valid_samples = has_wt.bool()
                wt_dg_pred = self(cw1[valid_samples], w1m[valid_samples], 
                                 cw2[valid_samples], w2m[valid_samples])
                
                implicit_ddg_pred = dg_pred[valid_samples] - wt_dg_pred
                true_ddg = y_mut[valid_samples] - y_wt[valid_samples]
                
                all_ddg_preds.append(implicit_ddg_pred)
                all_ddg_targets.append(true_ddg)
        
        # Concatenate all predictions and targets
        if all_dg_preds:
            all_dg_preds = torch.cat(all_dg_preds)
            all_dg_targets = torch.cat(all_dg_targets)
            
            # Calculate dG metrics
            dg_mse = self.mse_metric(all_dg_preds, all_dg_targets)
            dg_pearson = self.pearson_corr(all_dg_preds, all_dg_targets)
            dg_spearman = self.spearman_corr(all_dg_preds, all_dg_targets)
            dg_r2 = self.r2_score(all_dg_preds, all_dg_targets)
            
            # Log dG metrics
            self.log('val_dg_mse', dg_mse, on_epoch=True, prog_bar=True)
            self.log('val_dg_pearson', dg_pearson, on_epoch=True)
            self.log('val_dg_spearman', dg_spearman, on_epoch=True)
            self.log('val_dg_r2', dg_r2, on_epoch=True)
        
        # Calculate ddG metrics if available
        if all_ddg_preds:
            all_ddg_preds = torch.cat(all_ddg_preds)
            all_ddg_targets = torch.cat(all_ddg_targets)
            
            ddg_mse = self.mse_metric(all_ddg_preds, all_ddg_targets)
            ddg_pearson = self.pearson_corr(all_ddg_preds, all_ddg_targets)
            ddg_spearman = self.spearman_corr(all_ddg_preds, all_ddg_targets)
            ddg_r2 = self.r2_score(all_ddg_preds, all_ddg_targets)
            
            # Log ddG metrics
            self.log('val_ddg_mse', ddg_mse, on_epoch=True, prog_bar=True)
            self.log('val_ddg_pearson', ddg_pearson, on_epoch=True)
            self.log('val_ddg_spearman', ddg_spearman, on_epoch=True)
            self.log('val_ddg_r2', ddg_r2, on_epoch=True)
            
            # Combined validation metric for early stopping
            combined_metric = dg_mse + self.ddg_loss_weight * ddg_mse
            self.log('val_combined_metric', combined_metric, on_epoch=True)
        
        return {'val_dg_mse': dg_mse if 'dg_mse' in locals() else None,
                'val_ddg_mse': ddg_mse if 'ddg_mse' in locals() else None}

    def test_step(self, batch, batch_idx):
        # Similar to validation_step but returns more detailed metrics
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=self.learning_rate, total_steps=self.trainer.estimated_stepping_batches
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_combined_metric"}