Vaishnavey committed on
Commit
49d2467
·
verified ·
1 Parent(s): 6eaad48

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +57 -0
  2. best_model_checkpoint.pt +3 -0
  3. esm3bedding.py +86 -0
  4. modules.py +1744 -0
README.md ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - protein
5
+ - binding-affinity
6
+ - deep-learning
7
+ - esm
8
+ - pytorch
9
+ language:
10
+ - en
11
+ ---
12
+
13
+ # 🧬 Protein Binding Affinity Predictor
14
+
15
+ Dual-head model for predicting protein-protein binding affinity (ΔG) and mutation effects (ΔΔG).
16
+
17
+ ## Model Performance
18
+
19
+ | Metric | Validation Score |
20
+ |--------|-----------------|
21
+ | dG Pearson | 0.51 |
22
+ | ddG Pearson | 0.70 |
23
+ | Sum PCC | 1.21 |
24
+
25
+ ## Architecture
26
+
27
+ - **Backbone**: ESM C 600M (`esmc_600m`, frozen embeddings)
28
+ - **Pooling**: Sliced-Wasserstein Embedding (SWE)
29
+ - **Heads**: Dual-head (dG + ddG)
30
+ - **Input**: Protein sequences (1153-dim = 1152 ESM + 1 mutation channel)
31
+
32
+ ## Usage
33
+
34
+ ```python
35
+ from huggingface_hub import hf_hub_download
36
+ import torch
37
+
38
+ # Download checkpoint
39
+ ckpt = hf_hub_download(repo_id="supanthadey1/protein-binding-affinity", filename="best_model_checkpoint.pt")
40
+ checkpoint = torch.load(ckpt, map_location='cpu')
41
+ model.load_state_dict(checkpoint['model_state_dict'])  # NOTE: `model` must be instantiated beforehand; this snippet does not create it
42
+ ```
43
+
44
+ ## Predictions
45
+
46
+ - **ΔG (kcal/mol)**: Binding free energy. More negative = stronger binding.
47
+ - **ΔΔG (kcal/mol)**: Mutation effect. Negative = stabilizing, Positive = destabilizing.
48
+
49
+ ## Training Data
50
+
51
+ Trained on multiple datasets including SKEMPI, BindingGym, PDBbind, and others.
52
+
53
+ ## Citation
54
+
55
+ ```
56
+ [Citation coming soon]
57
+ ```
best_model_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ac87dbc506c018fcf8b26f296d595350e8544adc2034da24cdd6cdd03e6b9a6
3
+ size 1603771034
esm3bedding.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # esm3bedding.py
2
+
3
+ import os
4
+ import torch
5
+ from esm.models.esmc import ESMC
6
+ from esm.sdk.api import ESMProtein, LogitsConfig
7
+ from huggingface_hub import login
8
+ from utils import get_logger
9
+ from base import Featurizer
10
+
11
+ logg = get_logger()
12
+
13
class ESM3Featurizer(Featurizer):
    """Featurizer that embeds protein sequences with the ``esmc_600m`` ESMC model.

    The instance registers itself with the base ``Featurizer`` under the name
    ``"ESM3"`` with a feature dimension of 1152, logs into the Hugging Face
    Hub with the supplied key and loads the model onto CUDA when available
    (CPU otherwise).

    Args:
        save_dir: Directory forwarded to the base ``Featurizer``.
        api_key: Hugging Face token passed to ``huggingface_hub.login``.
        per_tok: When True (default) ``_transform`` returns one embedding per
            token; when False the embeddings are mean-pooled over the token
            axis.

    Raises:
        RuntimeError: If the Hugging Face login or the model load fails.
    """

    # The 20 canonical amino-acid letters; anything else is replaced by 'A'.
    _VALID_AMINO_ACIDS = frozenset('ACDEFGHIKLMNPQRSTVWY')

    def __init__(self, save_dir: str, api_key: str, per_tok: bool = True):
        super().__init__("ESM3", 1152, save_dir=save_dir)
        self.per_tok = per_tok
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.client = None

        self._login(api_key)
        self._initialize_model()

    def _login(self, api_key: str):
        """Log into the Hugging Face Hub, raising ``RuntimeError`` on failure."""
        try:
            login(api_key)
            logg.info("Successfully logged into Hugging Face Hub.")
        except Exception as e:
            logg.error(f"Failed to log in to Hugging Face Hub: {e}")
            # Keep the original failure attached as the explicit cause.
            raise RuntimeError("Hugging Face login failed. Check your API key.") from e

    def _initialize_model(self):
        """Load ``esmc_600m`` into ``self.client``, trying online first, then offline.

        Raises:
            RuntimeError: If both the online and the offline attempt fail. The
                offline failure is attached as ``__cause__``.
        """
        logg.info("Initializing ESMC model (esmc_600m)...")

        # First try normal online loading
        try:
            self.client = ESMC.from_pretrained("esmc_600m")
            self.client.to(self._device)
            logg.info("ESMC model loaded.")
            return
        except Exception as online_error:
            logg.warning(f"Online model loading failed: {online_error}")
            logg.info("Attempting offline mode (using local cache)...")

        # Fallback: Try offline mode using cached files
        # NOTE(review): these variables are set after huggingface_hub has been
        # imported, and it is not visible here whether ESMC.from_pretrained
        # accepts `local_files_only` -- confirm that this fallback can succeed.
        os.environ["HF_HUB_OFFLINE"] = "1"
        os.environ["TRANSFORMERS_OFFLINE"] = "1"

        try:
            self.client = ESMC.from_pretrained("esmc_600m", local_files_only=True)
            self.client.to(self._device)
            logg.info("ESMC model loaded from local cache (offline mode).")
        except Exception as offline_error:
            logg.error(f"Offline loading also failed: {offline_error}")
            logg.error("="*60)
            logg.error("ESMC MODEL NOT FOUND IN CACHE!")
            logg.error("Run this on a node with internet access to cache the model:")
            logg.error("  python -c \"from esm.models.esmc import ESMC; ESMC.from_pretrained('esmc_600m')\"")
            logg.error("="*60)
            # Raise the actionable message directly instead of re-wrapping it
            # in a second, generic RuntimeError.
            raise RuntimeError(
                "ESMC model not available. See error messages above."
            ) from offline_error

    def _transform(self, sequence: str) -> "torch.Tensor | None":
        """Embed ``sequence`` and return its embeddings, or ``None`` on failure.

        The sequence is upper-cased and every character outside the 20
        canonical amino acids is REPLACED (not removed) by ``'A'`` so the
        sequence length is preserved.

        Returns:
            The embeddings reported by the model with a leading batch axis of
            size one squeezed away; mean-pooled over the first remaining axis
            when ``per_tok`` is False. ``None`` if any step raises (the error
            is logged).
        """
        try:
            valid_aa = self._VALID_AMINO_ACIDS
            clean_sequence = ''.join(c if c in valid_aa else 'A' for c in sequence.upper())

            protein = ESMProtein(sequence=clean_sequence)
            protein_tensor = self.client.encode(protein)
            logits_config = LogitsConfig(sequence=True, return_embeddings=True)
            output = self.client.logits(protein_tensor, logits_config)
            embeddings = output.embeddings  # shape => [1, L, D] or [L, D]
            if embeddings.dim() == 3 and embeddings.shape[0] == 1:
                embeddings = embeddings.squeeze(0)  # => [L, D]

            if not self.per_tok:
                embeddings = embeddings.mean(dim=0)  # => [D]
            return embeddings
        except Exception as e:
            logg.error(f"Error generating embeddings for sequence: {e}")
            return None
modules.py ADDED
@@ -0,0 +1,1744 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, List, Dict, Optional
2
+ from pathlib import Path
3
+ import os
4
+ import math
5
+ import pandas as pd
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from torch.nn.utils.rnn import pad_sequence
10
+ import numpy as np
11
+ import hashlib
12
+ import json
13
+ import time
14
+ from esm3bedding import ESM3Featurizer
15
+ from utils import get_logger
16
+
17
+ logg = get_logger()
18
+
19
+ #########################################
20
+ # Source Type Mapping #
21
+ #########################################
22
SOURCE_TYPE_MAP = {
    # Protein complexes (unique structures)
    'PDBbind': 'protein_complex',
    'PPIKB': 'protein_complex',
    'asd_biomap': 'protein_complex',
    'asd_aae': 'protein_complex',
    'asd_aatp': 'protein_complex',
    'asd_osh': 'protein_complex',
    # True mutations
    'SKEMPI': 'mutation',
    'BindingGym': 'mutation',
    'asd_flab_koenig2017': 'mutation',  # 1-2aa differences
    'asd_flab_warszawski2019': 'mutation',  # 1-2aa differences
    'asd_flab_rosace2023': 'mutation',  # 1-5aa differences
    'PEPBI': 'mutation',
    # Antibody CDR variants
    'asd_abbd': 'antibody_cdr',  # 3-14aa CDR differences
    'abdesign': 'antibody_cdr',
    'asd_flab_hie2022': 'antibody_cdr',  # 2-17aa differences
    'asd_flab_shanehsazzadeh2023': 'antibody_cdr',  # 3-18aa differences
}
SOURCE_TYPE_TO_ID = {'protein_complex': 0, 'mutation': 1, 'antibody_cdr': 2}
DEFAULT_SOURCE_TYPE = 'mutation'  # Default for unknown sources


#########################################
#      Collate function (Siamese)       #
#########################################
def _is_missing_label(value):
    """Return True when ``value`` is ``None`` or NaN (NaN is the only value != itself)."""
    return value is None or value != value


def _pad_with_mask(sequences):
    """Pad a list of ``[L_i, D]`` tensors and build the matching validity mask.

    Returns:
        ``(padded, mask)`` where ``padded`` is ``[B, L_max, D]`` and ``mask`` is
        a ``[B, L_max]`` bool tensor that is True on real (non-padding) rows.
    """
    padded = pad_sequence(sequences, batch_first=True)
    lengths = torch.tensor([seq.shape[0] for seq in sequences], dtype=torch.long)
    positions = torch.arange(padded.shape[1]).unsqueeze(0)
    mask = positions < lengths.unsqueeze(1)
    return padded, mask


def advanced_collate_fn(batch):
    """Collate ``(data, meta)`` samples into padded mutant / wild-type batches.

    Args:
        batch: Iterable of ``(data, meta)`` pairs. ``data`` is the 6-tuple
            ``(c1, c2, y, cw1, cw2, yw)``: mutant chain tensors, mutant label,
            wild-type chain tensors and wild-type label. ``meta`` is a dict of
            per-sample flags and labels read with ``.get``.

    Returns:
        A dict with:
          * ``"mutant"`` / ``"wildtype"``: ``(c1, c1_mask, c2, c2_mask, y)``
            with sequences zero-padded along the length axis and bool masks
            marking real positions.
          * ``"has_wt"``, ``"has_valid_wt"``, ``"is_wt"``, ``"has_dg"``,
            ``"has_ddg"``, ``"has_inferred_ddg"``, ``"has_both_dg_ddg"``:
            bool tensors of shape ``[B]``.
          * ``"ddg_labels"``, ``"ddg_inferred_labels"``: float32 ``[B]``.
          * ``"data_source"`` (list of str), ``"source_type_ids"`` (long
            ``[B]``) and ``"metadata"`` (the original meta dicts).

    Samples without wild-type data get a single all-zero row as placeholder
    and ``has_valid_wt=False``. Samples whose ddG label would be the ``0.0``
    placeholder are never flagged as having a ddG label.
    """
    mut_c1_list, mut_c2_list, mut_y_list = [], [], []
    wt_c1_list, wt_c2_list, wt_y_list = [], [], []
    has_valid_wt_list = []  # CRITICAL: Track which samples have REAL WT embeddings (not zeros)
    meta_list = []

    for data, meta in batch:
        c1, c2, y, cw1, cw2, yw = data
        # mutant
        mut_c1_list.append(c1)
        mut_c2_list.append(c2)
        mut_y_list.append(torch.tensor([y], dtype=torch.float32))
        # wildtype
        if cw1 is not None and cw2 is not None and yw is not None:
            wt_c1_list.append(cw1)
            wt_c2_list.append(cw2)
            wt_y_list.append(torch.tensor([yw], dtype=torch.float32))
            has_valid_wt_list.append(True)  # Real WT data available
        else:
            # Fallback if no known WT - ZEROS corrupt ddG signal, so the sample
            # is flagged invalid. new_zeros keeps the placeholder on the same
            # dtype/device as the mutant tensor so padding stays consistent.
            wt_c1_list.append(c1.new_zeros((1, c1.shape[1])))
            wt_c2_list.append(c2.new_zeros((1, c2.shape[1])))
            wt_y_list.append(torch.tensor([0.0], dtype=torch.float32))
            has_valid_wt_list.append(False)  # INVALID for ddG - would compute mut-0=mut

        meta_list.append(meta)

    B = len(meta_list)

    # pad mutant
    c1_padded, c1_mask = _pad_with_mask(mut_c1_list)
    c2_padded, c2_mask = _pad_with_mask(mut_c2_list)
    y_mut = torch.cat(mut_y_list, dim=0)

    # pad wildtype
    w1_padded, w1_mask = _pad_with_mask(wt_c1_list)
    w2_padded, w2_mask = _pad_with_mask(wt_c2_list)
    y_wt = torch.cat(wt_y_list, dim=0)

    has_wt_list = []
    is_wt_list = []  # Which samples ARE WT (not just have a WT reference)
    has_dg_list = []
    has_ddg_list = []  # Which samples have a usable ddG label (explicit OR inferred)
    has_inferred_ddg_list = []  # Which samples have a usable inferred ddG
    has_both_list = []
    ddg_list = []
    ddg_inferred_list = []  # Inferred ddG values

    # DEBUG: Track data consistency
    n_has_ddg_true = 0
    n_ddg_zero = 0
    n_ddg_nan = 0

    for meta in meta_list:
        # has_any_wt covers both real and inferred WT sequences
        has_wt_list.append(meta.get("has_any_wt", meta.get("has_real_wt", False)))
        is_wt_list.append(meta.get("is_wt", False))
        has_dg_list.append(meta.get("has_dg", False))  # Default False to prevent false positives
        has_both_list.append(meta.get("has_both_dg_ddg", False))  # For symmetric consistency

        has_explicit_ddg = meta.get("has_ddg", False)
        has_inferred_ddg_flag = meta.get("has_inferred_ddg", False)
        ddg_val = meta.get("ddg", float('nan'))
        ddg_inf_val = meta.get("ddg_inferred", float('nan'))
        explicit_missing = _is_missing_label(ddg_val)
        inferred_missing = _is_missing_label(ddg_inf_val)

        # DEBUG: Check for data consistency issues
        if has_explicit_ddg:
            n_has_ddg_true += 1
            if explicit_missing:
                n_ddg_nan += 1
            elif abs(ddg_val) < 1e-8:
                n_ddg_zero += 1

        # Priority: explicit ddG > inferred ddG (dG_mut - dG_wt) > 0.0 placeholder
        if not explicit_missing:
            ddg_list.append(ddg_val)
        elif not inferred_missing:
            ddg_list.append(ddg_inf_val)
        else:
            ddg_list.append(0.0)
        ddg_inferred_list.append(0.0 if inferred_missing else ddg_inf_val)

        # A flag only counts when a real label backs it; otherwise the 0.0
        # placeholder would be treated as a genuine target.
        label_available = not (explicit_missing and inferred_missing)
        has_ddg_list.append(
            bool((has_explicit_ddg or has_inferred_ddg_flag) and label_available)
        )
        has_inferred_ddg_list.append(bool(has_inferred_ddg_flag and not inferred_missing))

    # DEBUG: Log batch statistics if there are issues
    if n_has_ddg_true > 0 and (n_ddg_nan > 0 or n_ddg_zero > B // 2):
        print(f"[COLLATE DEBUG] Batch has_ddg stats: {n_has_ddg_true}/{B} have has_ddg=True, "
              f"{n_ddg_nan} have NaN ddg (BUG!), {n_ddg_zero} have ddg≈0")

    has_wt = torch.tensor(has_wt_list, dtype=torch.bool)
    has_valid_wt = torch.tensor(has_valid_wt_list, dtype=torch.bool)  # Only True if WT is real (not zeros)
    is_wt = torch.tensor(is_wt_list, dtype=torch.bool)  # Sample IS a WT sample
    has_dg = torch.tensor(has_dg_list, dtype=torch.bool)
    has_ddg = torch.tensor(has_ddg_list, dtype=torch.bool)
    has_inferred_ddg = torch.tensor(has_inferred_ddg_list, dtype=torch.bool)
    has_both_dg_ddg = torch.tensor(has_both_list, dtype=torch.bool)
    ddg_labels = torch.tensor(ddg_list, dtype=torch.float32)
    ddg_inferred_labels = torch.tensor(ddg_inferred_list, dtype=torch.float32)

    # DEBUG: Log WT validity stats
    n_valid_wt = has_valid_wt.sum().item()
    n_has_wt = has_wt.sum().item()
    if n_has_wt > 0 and n_valid_wt < n_has_wt:
        print(f"[COLLATE DEBUG] WT validity: {n_valid_wt}/{n_has_wt} have valid WT embeddings "
              f"({n_has_wt - n_valid_wt} samples have zero-fallback and will be EXCLUDED from ddG training)")

    # Collect data_source for per-source metrics
    data_source_list = [meta.get("data_source", "unknown") for meta in meta_list]

    # Map each data source to its source-type id for model conditioning;
    # unknown sources fall back to DEFAULT_SOURCE_TYPE.
    source_type_ids = torch.tensor(
        [
            SOURCE_TYPE_TO_ID[SOURCE_TYPE_MAP.get(src, DEFAULT_SOURCE_TYPE)]
            for src in data_source_list
        ],
        dtype=torch.long,
    )

    out = {
        "mutant": (c1_padded, c1_mask, c2_padded, c2_mask, y_mut),
        "wildtype": (w1_padded, w1_mask, w2_padded, w2_mask, y_wt),
        "has_wt": has_wt,
        "has_valid_wt": has_valid_wt,  # CRITICAL: True only if WT embeddings are real (not zeros)
        "is_wt": is_wt,  # Sample IS a WT sample (for routing to dG head)
        "has_dg": has_dg,  # Whether samples have absolute dG values
        "has_ddg": has_ddg,  # Whether samples have a usable ddG label (explicit or inferred)
        "has_inferred_ddg": has_inferred_ddg,  # Whether samples have inferred ddG
        "has_both_dg_ddg": has_both_dg_ddg,  # For symmetric consistency loss
        "ddg_labels": ddg_labels,  # Explicit ddG, else inferred ddG, else 0.0 placeholder
        "ddg_inferred_labels": ddg_inferred_labels,  # Inferred ddG = dG_mut - dG_wt
        "data_source": data_source_list,  # For per-source validation metrics
        "source_type_ids": source_type_ids,  # 0=protein_complex, 1=mutation, 2=antibody_cdr
        "metadata": meta_list
    }
    return out
216
+
217
+ #########################################
218
+ # SiameseDataset (Simplified) #
219
+ #########################################
220
+ class AdvancedSiameseDataset(Dataset):
221
+ """
222
+ Dataset that handles mutation positions with a simple indicator channel.
223
+
224
+ Reads columns:
225
+ #Pdb, block1_sequence, block1_mut_positions, block1_mutations,
226
+ block2_sequence, block2_mut_positions, block2_mutations, del_g, ...
227
+ """
228
+ def __init__(self, df: pd.DataFrame, featurizer: ESM3Featurizer, embedding_dir: str,
229
+ normalize_embeddings=True, augment=False, max_len=1022,
230
+ wt_reference_df: pd.DataFrame = None):
231
+ super().__init__()
232
+
233
+ # Store WT reference DF (e.g. training set) for looking up missing WTs
234
+ # This enables Implicit ddG (dG_mut - dG_wt) even if WTs are not in the current split
235
+ self.wt_reference_df = wt_reference_df if wt_reference_df is not None else None
236
+ initial_len = len(df)
237
+
238
+ # CRITICAL FIX: Do NOT drop rows based on length because it shifts indices!
239
+ # External splits (indices) rely on the original row numbers.
240
+ # Instead, we TRUNCATE sequences that are too long to maintain alignment.
241
+
242
+ # Identify long sequences
243
+ long_mask = (df["block1_sequence"].astype(str).str.len() > max_len) | \
244
+ (df["block2_sequence"].astype(str).str.len() > max_len)
245
+ n_long = long_mask.sum()
246
+
247
+ if n_long > 0:
248
+ print(f" [Dataset] Truncating {n_long} samples with length > {max_len} to maintain index alignment (CRITICAL FIX).")
249
+ # Truncate sequences in place
250
+ # Use .copy() to avoid SettingWithCopyWarning if df is a slice
251
+ df = df.copy()
252
+ df.loc[long_mask, "block1_sequence"] = df.loc[long_mask, "block1_sequence"].astype(str).str.slice(0, max_len)
253
+ df.loc[long_mask, "block2_sequence"] = df.loc[long_mask, "block2_sequence"].astype(str).str.slice(0, max_len)
254
+
255
+ # No rows dropped, so indices remain aligned with split files
256
+ self.df = df.reset_index(drop=True)
257
+
258
+ #region agent log
259
+ try:
260
+ cols = set(self.df.columns.tolist())
261
+ need = {"block1_mut_positions", "block2_mut_positions", "Mutation(s)_PDB"}
262
+ missing = sorted(list(need - cols))
263
+ payload = {
264
+ "sessionId": "debug-session",
265
+ "runId": "pre-fix",
266
+ "hypothesisId": "G",
267
+ "location": "modules.py:AdvancedSiameseDataset:__init__",
268
+ "message": "Dataset columns presence check for mutation positions",
269
+ "data": {
270
+ "n_rows": int(len(self.df)),
271
+ "has_block1_mut_positions": "block1_mut_positions" in cols,
272
+ "has_block2_mut_positions": "block2_mut_positions" in cols,
273
+ "has_mutation_pdb": "Mutation(s)_PDB" in cols,
274
+ "missing": missing,
275
+ },
276
+ "timestamp": int(time.time() * 1000),
277
+ }
278
+ with open("/Users/supantha/Documents/code_v2/protein/.cursor/debug.log", "a") as f:
279
+ f.write(json.dumps(payload, default=str) + "\n")
280
+ print(f"[AGENTLOG MUTPOSCOLS] missing={missing}")
281
+ except Exception:
282
+ pass
283
+ #endregion
284
+
285
+ #region agent log
286
+ # Disambiguate whether "0 positions" is happening for MUT embeddings or WT embeddings
287
+ try:
288
+ if not hasattr(self, "_agent_embed_call_counter"):
289
+ self._agent_embed_call_counter = 0
290
+ if self._agent_embed_call_counter < 10:
291
+ self._agent_embed_call_counter += 1
292
+ print(
293
+ f"[AGENTLOG EMBCALL] idx={idx} role=mut "
294
+ f"b1_mutpos_n={len(b1_mutpos)} b2_mutpos_n={len(b2_mutpos)} "
295
+ f"seq1_len={len(item.get('seq1',''))} seq2_len={len(item.get('seq2',''))}"
296
+ )
297
+ except Exception:
298
+ pass
299
+ #endregion
300
+
301
+ # Recover antibody WTs (ANTIBODY_MUTATION) before augmentation or indexing
302
+ self.df = self._recover_antibody_wts(self.df)
303
+
304
+ # ---------- OPTIONAL AUGMENT: reverse mutation (mut ↔ WT) ----------
305
+ # Only augment MUTANT samples (not WT) - WT samples don't benefit from reversal
306
+ # and doubling them confuses the pdb_to_wt lookup
307
+ if augment:
308
+ # Identify mutant rows (non-empty Mutation(s)_PDB)
309
+ mut_mask = self.df["Mutation(s)_PDB"].notna() & (self.df["Mutation(s)_PDB"].str.strip() != "")
310
+ mutant_df = self.df[mut_mask].copy()
311
+
312
+ if len(mutant_df) > 0:
313
+ # Create reversed copies of mutant samples only
314
+ rev_df = mutant_df.copy()
315
+ # For the reverse augmentation we invert the sign of ddg
316
+ if "ddg" in rev_df.columns:
317
+ rev_df["ddg"] = -rev_df["ddg"]
318
+ rev_df["is_reverse"] = True # flag for reversed samples
319
+
320
+ # Original samples stay as-is
321
+ self.df["is_reverse"] = False
322
+ self.df = pd.concat([self.df, rev_df], ignore_index=True)
323
+ print(f" [Dataset] Augmented: added {len(rev_df)} reversed mutant samples (antisymmetry training)")
324
+ else:
325
+ self.df["is_reverse"] = False
326
+ else:
327
+ self.df["is_reverse"] = False
328
+ # -------------------------------------------------------------------
329
+
330
+ # ---------- PAIR ID (mutant – WT) ----------------------------------
331
+ # Use PDB + cleaned‑mutation string so mutant and its WT share an ID
332
+ self.df["pair_id"] = (
333
+ self.df["#Pdb"].astype(str) + "_" +
334
+ self.df["Mutation(s)_cleaned"].fillna("") # WT rows have empty mutation
335
+ )
336
+ # -------------------------------------------------------------------
337
+
338
+
339
+ self.featurizer = featurizer
340
+ self.embedding_dir = Path(embedding_dir)
341
+ self.embedding_dir.mkdir(exist_ok=True, parents=True)
342
+ self.normalize = normalize_embeddings
343
+
344
+ self.samples = []
345
+ self._embedding_cache = {} # LRU-style cache for frequently accessed embeddings
346
+ self._cache_max_size = 20000 # Cache up to 20k embeddings (~20-40GB RAM)
347
+ self._cache_hits = 0
348
+ self._cache_misses = 0
349
+
350
+ # map each PDB to a wildtype row index if it exists
351
+ print(f" [Dataset] Building WT index for {len(self.df)} rows...")
352
+ self.pdb_to_wt = {}
353
+ for i, row in self.df.iterrows():
354
+ pdb = row["#Pdb"]
355
+ mut_str = row.get("Mutation(s)_PDB","")
356
+ is_wt = (pd.isna(mut_str) or mut_str.strip()=="")
357
+ if is_wt and pdb not in self.pdb_to_wt:
358
+ self.pdb_to_wt[pdb] = i
359
+
360
+ # Build external WT map if reference DF is provided
361
+ self.external_pdb_to_wt = {}
362
+ if self.wt_reference_df is not None:
363
+ print(f" [Dataset] Building external WT index from {len(self.wt_reference_df)} reference rows...")
364
+ for i, row in self.wt_reference_df.iterrows():
365
+ # Only index actual WTs
366
+ mut_str = row.get("Mutation(s)_PDB","")
367
+ is_wt = (pd.isna(mut_str) or mut_str.strip()=="")
368
+ if 'is_wt' in row: # Prioritize pre-computed flag
369
+ is_wt = is_wt or row['is_wt']
370
+
371
+ pdb = row["#Pdb"]
372
+ if is_wt and pdb not in self.external_pdb_to_wt:
373
+ self.external_pdb_to_wt[pdb] = i
374
+ print(f" [Dataset] Indexed {len(self.external_pdb_to_wt)} external WTs.")
375
+
376
+ # Build external WT map if reference DF is provided
377
+ self.external_pdb_to_wt = {}
378
+ if self.wt_reference_df is not None:
379
+ print(f" [Dataset] Building external WT index from {len(self.wt_reference_df)} reference rows...")
380
+ for i, row in self.wt_reference_df.iterrows():
381
+ # Only index actual WTs
382
+ mut_str = row.get("Mutation(s)_PDB","")
383
+ is_wt = (pd.isna(mut_str) or mut_str.strip()=="")
384
+ # Also check 'is_wt' column if present
385
+ if 'is_wt' in row:
386
+ is_wt = is_wt or row['is_wt']
387
+
388
+ pdb = row["#Pdb"]
389
+ if is_wt and pdb not in self.external_pdb_to_wt:
390
+ self.external_pdb_to_wt[pdb] = i
391
+ print(f" [Dataset] Indexed {len(self.external_pdb_to_wt)} external WTs.")
392
+
393
+ # LAZY LOADING: Only store metadata, NOT embeddings
394
+ # Embeddings will be loaded on-demand in __getitem__
395
+ print(f" [Dataset] Building sample metadata for {len(self.df)} rows (lazy loading)...")
396
+ from tqdm import tqdm
397
+ for i, row in tqdm(self.df.iterrows(), total=len(self.df), desc=" Indexing"):
398
+ # RESET computed mutations for this row to prevent stale data from previous iterations
399
+ if hasattr(self, '_last_computed_mutpos'):
400
+ del self._last_computed_mutpos
401
+
402
+ pdb = row["#Pdb"]
403
+ seq1 = row["block1_sequence"]
404
+ seq2 = row["block2_sequence"]
405
+
406
+ # Data source for per-source validation metrics
407
+ data_source = row.get("data_source", "unknown")
408
+
409
+ # Handle missing dG values (e.g., BindingGym has only ddG)
410
+ raw_delg = row["del_g"]
411
+ delg = float(raw_delg) if pd.notna(raw_delg) and raw_delg != '' else float('nan')
412
+
413
+ # Get ddG if available (for ddG-only datasets like BindingGym)
414
+ raw_ddg = row.get("ddg", None)
415
+ ddg = float(raw_ddg) if pd.notna(raw_ddg) and raw_ddg != '' else float('nan')
416
+
417
+ # Parse mutations (just store the string, parse later)
418
+ b1_mutpos_str = row.get("block1_mut_positions","[]")
419
+ b2_mutpos_str = row.get("block2_mut_positions","[]")
420
+
421
+ # DEBUG: Print first few rows to debug disappearing mutations
422
+ if i < 5:
423
+ print(f"DEBUG ROW {i}: b1='{b1_mutpos_str}' ({type(b1_mutpos_str)}), b2='{b2_mutpos_str}' ({type(b2_mutpos_str)})")
424
+ #region agent log
425
+ try:
426
+ payload = {
427
+ "sessionId": "debug-session",
428
+ "runId": "pre-fix",
429
+ "hypothesisId": "G",
430
+ "location": "modules.py:AdvancedSiameseDataset:__init__:row0_4",
431
+ "message": "Raw mutpos strings from df row (first few)",
432
+ "data": {
433
+ "i": int(i),
434
+ "b1_mutpos_str": str(b1_mutpos_str),
435
+ "b2_mutpos_str": str(b2_mutpos_str),
436
+ "mutation_pdb": str(row.get("Mutation(s)_PDB", "")),
437
+ },
438
+ "timestamp": int(time.time() * 1000),
439
+ }
440
+ with open("/Users/supantha/Documents/code_v2/protein/.cursor/debug.log", "a") as f:
441
+ f.write(json.dumps(payload, default=str) + "\n")
442
+ print(f"[AGENTLOG MUTPOSRAW] i={i} b1={b1_mutpos_str} b2={b2_mutpos_str} mut={row.get('Mutation(s)_PDB','')}")
443
+ except Exception:
444
+ pass
445
+ #endregion
446
+
447
+ # Get chain info for block assignment during WT inference
448
+ b1_chains = str(row.get("block1_chains", "")).upper()
449
+ b2_chains = str(row.get("block2_chains", "")).upper()
450
+
451
+ mut_str = row.get("Mutation(s)_PDB","")
452
+ is_wt = (pd.isna(mut_str) or mut_str.strip()=="")
453
+ wt_idx = self.pdb_to_wt.get(pdb, None)
454
+
455
+ # Get WT info if available (Internal > External)
456
+ row_wt = None
457
+ wt_source = None
458
+
459
+ if not hasattr(self, '_wt_source_stats'):
460
+ self._wt_source_stats = {'internal': 0, 'external': 0}
461
+
462
+ if wt_idx is not None:
463
+ row_wt = self.df.iloc[wt_idx]
464
+ wt_source = 'internal'
465
+ self._wt_source_stats['internal'] += 1
466
+ elif pdb in self.external_pdb_to_wt:
467
+ ext_idx = self.external_pdb_to_wt[pdb]
468
+ row_wt = self.wt_reference_df.iloc[ext_idx]
469
+ wt_source = 'external'
470
+ self._wt_source_stats['external'] += 1
471
+
472
+ if row_wt is not None:
473
+ seq1_wt = row_wt["block1_sequence"]
474
+ seq2_wt = row_wt["block2_sequence"]
475
+ raw_delg_wt = row_wt["del_g"]
476
+ delg_wt = float(raw_delg_wt) if pd.notna(raw_delg_wt) and raw_delg_wt != '' else float('nan')
477
+ b1_wtpos_str = row_wt.get("block1_mut_positions","[]")
478
+ b2_wtpos_str = row_wt.get("block2_mut_positions","[]")
479
+
480
+ # BUGFIX: If we have WT but NO mutation positions in CSV, we MUST calculate them!
481
+ # This fixes the "0% mutation positions" issue when the CSV column is empty/missing
482
+ if not is_wt and (b1_mutpos_str in ["[]", "", "nan", "None"] and b2_mutpos_str in ["[]", "", "nan", "None"]):
483
+ # Run inference to locate mutations (side-effect: sets _last_computed_mutpos)
484
+ # We ignore the inferred WT sequence since we have the real one
485
+ # We pass "[]" to force scanning PDB positions
486
+ self._infer_wt_sequences(
487
+ seq1, seq2, mut_str, "[]", "[]",
488
+ b1_chains, b2_chains
489
+ )
490
+
491
+ # Update mutpos_str if we found mutations
492
+ if hasattr(self, '_last_computed_mutpos'):
493
+ comp_b1, comp_b2 = self._last_computed_mutpos
494
+ if b1_mutpos_str in ["[]", "", "nan", "None"] and comp_b1:
495
+ b1_mutpos_str = str(comp_b1)
496
+ if b2_mutpos_str in ["[]", "", "nan", "None"] and comp_b2:
497
+ b2_mutpos_str = str(comp_b2)
498
+
499
+ else:
500
+ # No WT row found - try to INFER WT sequence by reversing mutations
501
+ # This is crucial for BindingGym data which stores mutant sequences only
502
+ seq1_wt, seq2_wt = self._infer_wt_sequences(
503
+ seq1, seq2, mut_str, b1_mutpos_str, b2_mutpos_str,
504
+ b1_chains, b2_chains # Chain info for block assignment
505
+ )
506
+ delg_wt = float('nan') # No WT dG available for inferred sequences
507
+ b1_wtpos_str, b2_wtpos_str = "[]", "[]" # WT has no mutation positions
508
+
509
+ # FIX Bug #3: Use computed mutation positions from inference if original empty
510
+ if hasattr(self, '_last_computed_mutpos'):
511
+ comp_b1, comp_b2 = self._last_computed_mutpos
512
+ if b1_mutpos_str in ["[]", "", "nan", "None"] and comp_b1:
513
+ b1_mutpos_str = str(comp_b1)
514
+ if b2_mutpos_str in ["[]", "", "nan", "None"] and comp_b2:
515
+ b2_mutpos_str = str(comp_b2)
516
+
517
+ # Check if this sample has BOTH dG and ddG (for symmetric consistency)
518
+ has_dg = not (delg != delg) # False if NaN
519
+ has_ddg = not (ddg != ddg) # False if NaN
520
+ has_both = has_dg and has_ddg
521
+
522
+ # NEW: Compute inferred ddG for samples with dG_mut and dG_wt but no explicit ddG
523
+ # ddG_inferred = dG_mut - dG_wt (can be used as additional training signal)
524
+ has_dg_wt = not (delg_wt != delg_wt) # False if NaN
525
+ has_inferred_ddg = has_dg and has_dg_wt and (not has_ddg) # Only if no explicit ddG
526
+ if has_inferred_ddg:
527
+ ddg_inferred = delg - delg_wt # Computed from dG values
528
+ else:
529
+ ddg_inferred = float('nan')
530
+
531
+ # Track WT availability: real (from row), inferred, or none
532
+ has_real_wt = (wt_idx is not None)
533
+ has_inferred_wt = (wt_idx is None and seq1_wt is not None and seq2_wt is not None)
534
+ has_any_wt = has_real_wt or has_inferred_wt
535
+
536
+ # Store ONLY metadata - no embeddings loaded yet!
537
+ is_reverse = row.get("is_reverse", False) # Track reversed samples
538
+
539
+ # CRITICAL: Swap sequences and dG for reversed samples (antisymmetry augmentation)
540
+ if is_reverse:
541
+ # Swap sequences: New Mutant = Old WT, New WT = Old Mutant
542
+ if seq1_wt is not None and seq2_wt is not None:
543
+ seq1, seq1_wt = seq1_wt, seq1
544
+ seq2, seq2_wt = seq2_wt, seq2
545
+ # Swap dG values
546
+ delg, delg_wt = delg_wt, delg
547
+ # Negate inferred ddG (dG_new_mut - dG_new_wt = dG_old_wt - dG_old_mut = -(dG_old_mut - dG_old_wt))
548
+ if not math.isnan(ddg_inferred):
549
+ ddg_inferred = -ddg_inferred
550
+ # Note: Explicit 'ddg' is already negated in __init__ augmentation logic
551
+ # Note: We do NOT swap mutation positions because the indices of difference
552
+ # are the same for A->B vs B->A. We want the 'input' (new mutant) to have
553
+ # the indicator flags at the difference sites.
554
+
555
+ self.samples.append({
556
+ "pdb": pdb,
557
+ "is_wt": is_wt,
558
+ "is_reverse": is_reverse, # True if this is a reversed (augmented) sample
559
+ "seq1": seq1, "seq2": seq2, "delg": delg,
560
+ "seq1_wt": seq1_wt, "seq2_wt": seq2_wt, "delg_wt": delg_wt,
561
+ "ddg": ddg,
562
+ "ddg_inferred": ddg_inferred, # NEW: Computed from dG_mut - dG_wt
563
+ "has_dg": has_dg,
564
+ "has_ddg": has_ddg,
565
+ "has_inferred_ddg": has_inferred_ddg, # NEW: True if ddg_inferred is valid
566
+ "has_both_dg_ddg": has_both,
567
+ "has_real_wt": has_real_wt,
568
+ "has_inferred_wt": has_inferred_wt,
569
+ "has_any_wt": has_any_wt,
570
+ "b1_mutpos_str": b1_mutpos_str,
571
+ "b2_mutpos_str": b2_mutpos_str,
572
+ "b1_wtpos_str": b1_wtpos_str,
573
+ "b2_wtpos_str": b2_wtpos_str,
574
+ "data_source": data_source
575
+ })
576
+
577
+ # Log WT inference statistics
578
+ n_real_wt = sum(1 for s in self.samples if s["has_real_wt"])
579
+ n_inferred_wt = sum(1 for s in self.samples if s["has_inferred_wt"])
580
+ n_no_wt = len(self.samples) - n_real_wt - n_inferred_wt
581
+
582
+ # Detailed stats for Real WTs (Internal vs External)
583
+ if hasattr(self, '_wt_source_stats'):
584
+ n_internal = self._wt_source_stats.get('internal', 0)
585
+ n_external = self._wt_source_stats.get('external', 0)
586
+ source_msg = f" (Internal: {n_internal}, External: {n_external})"
587
+ else:
588
+ source_msg = ""
589
+
590
+ print(f" [Dataset] Ready! {len(self.samples)} samples indexed (embeddings loaded on-demand)")
591
+ print(f" [Dataset] WT stats: {n_real_wt} real WT{source_msg}, {n_inferred_wt} inferred WT, {n_no_wt} no WT")
592
+
593
+ # Log detailed failure breakdown (for debugging)
594
+ if hasattr(self, '_wt_inference_failures') and hasattr(self, '_wt_inference_fail_count'):
595
+ print(f" [Dataset] ⚠️ WT inference failed for {self._wt_inference_fail_count} samples:")
596
+ fail_dict = self._wt_inference_failures
597
+
598
+ # Count by category (note: these are capped sample counts, not totals)
599
+ n_no_pdb = len(fail_dict.get('no_pdb', []))
600
+ n_del_ins = len(fail_dict.get('del_ins_only', []))
601
+ n_parse = len(fail_dict.get('parse_fail', []))
602
+
603
+ if n_no_pdb > 0:
604
+ print(f" - ANTIBODY samples (no PDB structure): {self._wt_inference_fail_count} samples")
605
+ print(f" (These are antibody design samples without original PDB - only dG usable)")
606
+ elif n_del_ins > 0 or n_parse > 0:
607
+ print(f" - DEL/INS/stop-codon (can't reverse): counted")
608
+ print(f" - Parsing failed (unknown format): counted")
609
+
610
+ # Show samples for non-ANTIBODY failures
611
+ if fail_dict.get('parse_fail') and n_no_pdb == 0:
612
+ print(f" Sample parse failures:")
613
+ for mut in fail_dict['parse_fail'][:5]:
614
+ print(f" '{mut}'")
615
+
616
def _parse_mutpos(self, pos_str) -> List[int]:
    """
    Parse a bracketed position list such as '[]' or '[170, 172]' into ints.

    Args:
        pos_str: Usually a string. None and float NaN yield []. Any other
            non-string value is converted with str() first, so a real list
            such as [170, 172] parses like its string form.

    Returns:
        The integer positions in the order they appear. Input that is not
        wrapped in '[' ... ']' (e.g. '', 'nan', 'None') yields [].

    Tokens that are whole-number floats ('52.0') are accepted; tokens that
    are not whole numbers are skipped rather than raising, so a single
    malformed cell cannot abort sample loading.
    """
    # Handle None and float NaN (NaN is the only float that differs from itself)
    if pos_str is None or (isinstance(pos_str, float) and pos_str != pos_str):
        return []

    text = pos_str if isinstance(pos_str, str) else str(pos_str)
    text = text.strip()
    if not (text.startswith("[") and text.endswith("]")):
        return []

    positions: List[int] = []
    for token in text[1:-1].split(","):
        token = token.strip()
        if not token:
            continue
        try:
            positions.append(int(token))
            continue
        except ValueError:
            pass
        # Not a plain integer literal: accept only whole-number floats
        try:
            as_float = float(token)
        except ValueError:
            continue
        if as_float.is_integer():
            positions.append(int(as_float))
    return positions
640
+
641
def _recover_antibody_wts(self, df: pd.DataFrame) -> pd.DataFrame:
    """
    Recover WT information for antibody samples (ANTIBODY_MUTATION) by
    promoting one representative row per antigen group to wildtype.

    Strategy:
    1. Select rows whose 'Mutation(s)_PDB' contains 'ANTIBODY_MUTATION'
    2. Group them by antigen (MD5 of str(block2_sequence))
    3. Give each group a pseudo-PDB ID: ANTIBODY_GRP_<first 8 hex chars>
    4. Same-length groups: the block1 sequence with the smallest Hamming
       distance to the per-position consensus becomes the WT
    5. Variable-length groups: the row whose del_g is closest to the
       group median becomes the WT
    6. If neither applies (no usable del_g), the first row of the group
    7. The selected row is marked as WT by clearing its mutation string

    All row selection is positional / boolean, so rows that merely share
    an index label with a group member are never modified.

    Args:
        df: Frame with 'Mutation(s)_PDB', 'block1_sequence' and
            'block2_sequence' columns; 'del_g' is optional.

    Returns:
        `df` itself when it has no antibody rows, otherwise a modified copy.
    """
    from collections import Counter

    import numpy as np

    # Identify antibody mutation rows (plain bool array => positional semantics)
    is_antibody = (
        df['Mutation(s)_PDB'].astype(str)
        .str.contains('ANTIBODY_MUTATION', na=False)
        .to_numpy(dtype=bool)
    )
    if not is_antibody.any():
        return df

    print(f" [Dataset] Attempting to recover WT for {is_antibody.sum()} antibody samples...")

    recovered_count = 0
    n_groups = 0
    n_consensus = 0
    n_median = 0
    n_fallback = 0

    # Work on a copy so the caller's frame is left untouched
    df = df.copy()

    # Grouping key: hash of the antigen sequence, one entry per row
    antigen_hash = np.array(
        [hashlib.md5(str(x).encode()).hexdigest() for x in df['block2_sequence']],
        dtype=object,
    )
    block1_values = df['block1_sequence'].tolist()
    mut_col = df.columns.get_loc('Mutation(s)_PDB')
    if 'del_g' in df.columns:
        delg_all = pd.to_numeric(df['del_g'], errors='coerce').to_numpy(dtype=float)
    else:
        delg_all = None

    # dict.fromkeys keeps first-appearance order of the antigen groups
    for h in dict.fromkeys(antigen_hash[is_antibody]):
        in_group = is_antibody & (antigen_hash == h)
        positions = np.flatnonzero(in_group)
        n_groups += 1

        # 1. Unique pseudo-PDB ID for the whole group
        df.loc[in_group, '#Pdb'] = f"ANTIBODY_GRP_{h[:8]}"

        # 2. Select WT; a missing block1 sequence is treated as empty
        seqs = []
        for p in positions:
            value = block1_values[p]
            seqs.append("" if pd.isna(value) else str(value))

        wt_pos = None
        if len({len(s) for s in seqs}) == 1:
            # SAME LENGTH: closest-to-consensus (ties -> first row in group)
            consensus = ''.join(
                Counter(column).most_common(1)[0][0] for column in zip(*seqs)
            )
            distances = [sum(a != b for a, b in zip(s, consensus)) for s in seqs]
            wt_pos = positions[distances.index(min(distances))]
            n_consensus += 1
        elif delg_all is not None:
            # VARIABLE LENGTH: row whose del_g is closest to the group median
            group_delg = delg_all[positions]
            valid = ~np.isnan(group_delg)
            if valid.any():
                median_val = np.median(group_delg[valid])
                deviation = np.where(valid, np.abs(group_delg - median_val), np.inf)
                wt_pos = positions[int(np.argmin(deviation))]
                n_median += 1

        # FINAL FALLBACK: first sample if nothing else worked (e.g. all-NaN dG)
        if wt_pos is None:
            wt_pos = positions[0]
            n_fallback += 1

        # 3. Mark the selected row as WT
        df.iloc[int(wt_pos), mut_col] = ""
        recovered_count += len(positions)

    print(f" [Dataset] Recovered {recovered_count} antibody samples ({n_groups} groups):")
    print(f" - {n_consensus} groups via closest-to-consensus")
    print(f" - {n_median} groups via median-binder (variable-length)")
    if n_fallback > 0:
        print(f" - {n_fallback} groups via first-sample fallback (no dG data)")
    return df
752
+
753
def _infer_wt_sequences(self, mut_seq1: str, mut_seq2: str, mutation_str: str,
                        b1_mutpos_str: str, b2_mutpos_str: str,
                        b1_chains: str = "", b2_chains: str = "") -> Tuple[Optional[str], Optional[str]]:
    """
    Infer wildtype sequences by reversing mutations in the mutant sequences.

    Mutations are in formats like:
    - BindingGym: "H:P53L" or "H:P53L,H:Y57C" (chain:WTresPOSmutres)
    - SKEMPI: "CA182A" (WTres + chain + POS[insertion code] + mutres)
    - Simple: "F139A" (WTres + POS + mutres, no chain; both blocks are tried)

    Deletion / insertion / stop-codon parts (containing 'DEL', 'INS' or '*')
    are skipped because they cannot be reversed by a substitution.

    Two strategies are used:
    1. If precomputed 0-indexed positions are given, a mutation is reversed
       at one of those positions when the residue there equals the mutant
       residue. The position at the mutation's own list index is tried
       first; otherwise the first unused precomputed position of that block
       whose residue matches is used (needed when mutations span both
       blocks, since each block has its own position list).
    2. Otherwise the PDB position is used as a guess (pos - 1), then a
       +/-50 window around it is searched, and - only when the chain is
       unknown and the guess is out of bounds - the whole sequence.

    Side effects:
        self._last_computed_mutpos is reset to ([], []) on entry and set to
        the (block1, block2) positions actually reversed on success, so it
        never carries positions from a previous sample.
        self._wt_inference_stats / _wt_inference_failures /
        _wt_inference_fail_count are updated for logging.

    Args:
        mut_seq1: Mutant sequence for block1
        mut_seq2: Mutant sequence for block2
        mutation_str: Raw mutation string from data
        b1_mutpos_str: Mutation positions for block1 (e.g., "[52, 56]")
        b2_mutpos_str: Mutation positions for block2
        b1_chains: Chain letters in block1 (e.g., "AB")
        b2_chains: Chain letters in block2 (e.g., "HL")

    Returns:
        Tuple of (wt_seq1, wt_seq2); the inputs unchanged when there is no
        mutation string; (None, None) if inference fails.
    """
    import re

    # Never let a caller read positions left over from a previous sample.
    self._last_computed_mutpos = ([], [])

    def _record_failure(category: str, cap: int) -> None:
        # Count the failure and keep at most `cap` example strings per category.
        if not hasattr(self, '_wt_inference_failures'):
            self._wt_inference_failures = {
                'parse_fail': [], 'del_ins_only': [], 'no_pdb': [], 'other': []
            }
        self._wt_inference_fail_count = getattr(self, '_wt_inference_fail_count', 0) + 1
        examples = self._wt_inference_failures.setdefault(category, [])
        if len(examples) < cap:
            examples.append(str(mutation_str)[:80])

    if pd.isna(mutation_str) or str(mutation_str).strip() == '':
        # No mutations = this IS the wildtype
        return mut_seq1, mut_seq2

    mutation_str = str(mutation_str).strip()

    # FALLBACK: ANTIBODY_MUTATION samples that couldn't be recovered
    if 'ANTIBODY_MUTATION' in mutation_str.upper():
        _record_failure('no_pdb', 5)
        return None, None

    try:
        chain_free_pat = re.compile(r'([A-Z])(\d+)([A-Z])')
        skempi_pat = re.compile(r'([A-Z])([A-Z])(-?\d+[a-z]?)([A-Z])')

        # Parse into (chain, position, original_AA, mutant_AA)
        mutations = []
        for part in re.split(r'[,;]', mutation_str):
            part = part.strip().strip('"\'')
            if not part:
                continue
            upper_part = part.upper()
            if 'DEL' in upper_part or 'INS' in upper_part or '*' in part:
                continue  # can't reverse these

            if ':' in part:
                # BindingGym format: "H:P53L"
                pieces = part.split(':')
                chain = pieces[0].strip().upper()
                for mut_part in pieces[1:]:
                    match = chain_free_pat.match(mut_part.strip())
                    if match:
                        mutations.append(
                            (chain, int(match.group(2)), match.group(1), match.group(3))
                        )
                continue

            match = skempi_pat.match(part)
            if match:
                # SKEMPI format: WTresidue + ChainID + Position[insertcode] + MutResidue
                pos = int(re.match(r'-?\d+', match.group(3)).group())
                mutations.append((match.group(2).upper(), pos, match.group(1), match.group(4)))
                continue

            match = chain_free_pat.match(part)
            if match:
                # Simple format without chain - will try both blocks
                mutations.append(('?', int(match.group(2)), match.group(1), match.group(3)))

        if not mutations:
            upper_all = mutation_str.upper()
            if 'DEL' in upper_all or 'INS' in upper_all or '*' in mutation_str:
                _record_failure('del_ins_only', 10)
            else:
                _record_failure('parse_fail', 10)
            return None, None

        # Mutable copies of the sequences
        wt_seq1_list = list(mut_seq1) if mut_seq1 else []
        wt_seq2_list = list(mut_seq2) if mut_seq2 else []

        # Chain sets for block assignment
        b1_chain_set = set(b1_chains.upper()) if b1_chains else set()
        b2_chain_set = set(b2_chains.upper()) if b2_chains else set()

        precomputed_b1 = self._parse_mutpos(b1_mutpos_str)
        precomputed_b2 = self._parse_mutpos(b2_mutpos_str)

        if not hasattr(self, '_wt_inference_stats'):
            self._wt_inference_stats = {'reversed': 0, 'not_found': 0, 'total': 0}
        stats = self._wt_inference_stats

        found_positions_b1 = []
        found_positions_b2 = []

        # STRATEGY 1: precomputed sequence positions
        if precomputed_b1 or precomputed_b2:
            used_b1, used_b2 = set(), set()

            def _reverse_precomputed(seq_list, positions, used, found, order_idx, wt_aa, mut_aa):
                # Same-index position first, then any other unused one of this block.
                candidates = []
                if order_idx < len(positions):
                    candidates.append(positions[order_idx])
                candidates.extend(positions)
                for seq_idx in candidates:
                    if seq_idx in used:
                        continue
                    if 0 <= seq_idx < len(seq_list) and seq_list[seq_idx] == mut_aa:
                        seq_list[seq_idx] = wt_aa
                        used.add(seq_idx)
                        found.append(seq_idx)
                        return True
                return False

            b1_args = (wt_seq1_list, precomputed_b1, used_b1, found_positions_b1)
            b2_args = (wt_seq2_list, precomputed_b2, used_b2, found_positions_b2)

            for order_idx, (chain, _pdb_pos, wt_aa, mut_aa) in enumerate(mutations):
                stats['total'] += 1
                if chain in b2_chain_set:
                    reversed_this = _reverse_precomputed(*b2_args, order_idx, wt_aa, mut_aa)
                elif chain in b1_chain_set:
                    reversed_this = _reverse_precomputed(*b1_args, order_idx, wt_aa, mut_aa)
                else:
                    # Chain unknown - try block1, then block2
                    reversed_this = (
                        _reverse_precomputed(*b1_args, order_idx, wt_aa, mut_aa)
                        or _reverse_precomputed(*b2_args, order_idx, wt_aa, mut_aa)
                    )
                stats['reversed' if reversed_this else 'not_found'] += 1

            self._last_computed_mutpos = (found_positions_b1, found_positions_b2)
            return ''.join(wt_seq1_list), ''.join(wt_seq2_list)

        # STRATEGY 2: PDB position-based search (less reliable)
        for chain, pdb_pos, wt_aa, mut_aa in mutations:
            stats['total'] += 1
            chain_known = chain in b1_chain_set or chain in b2_chain_set

            if chain in b1_chain_set:
                blocks_to_try = [(wt_seq1_list, found_positions_b1)]
            elif chain in b2_chain_set:
                blocks_to_try = [(wt_seq2_list, found_positions_b2)]
            else:
                # Chain info unavailable - try BOTH blocks
                blocks_to_try = [
                    (wt_seq1_list, found_positions_b1),
                    (wt_seq2_list, found_positions_b2),
                ]

            guess_idx = pdb_pos - 1  # Convert to 0-indexed
            reversed_this = False
            for target_seq, pos_list in blocks_to_try:
                in_bounds = 0 <= guess_idx < len(target_seq)
                if in_bounds and target_seq[guess_idx] == mut_aa:
                    found_idx = guess_idx
                else:
                    # First mutant residue inside the +/-50 window
                    window = range(max(0, pdb_pos - 50), min(len(target_seq), pdb_pos + 50))
                    found_idx = next((i for i in window if target_seq[i] == mut_aa), None)

                # Last resort: guess out of bounds AND chain unknown -> scan everything
                if found_idx is None and not chain_known and not in_bounds:
                    found_idx = next(
                        (i for i, aa in enumerate(target_seq) if aa == mut_aa), None
                    )

                if found_idx is not None:
                    target_seq[found_idx] = wt_aa  # Reverse the mutation
                    pos_list.append(found_idx)
                    reversed_this = True
                    break

            stats['reversed' if reversed_this else 'not_found'] += 1

        # ACTUAL 0-indexed positions that were reversed, for the caller
        self._last_computed_mutpos = (found_positions_b1, found_positions_b2)
        return ''.join(wt_seq1_list), ''.join(wt_seq2_list)

    except Exception:
        # Best effort: inference failure must not abort dataset construction,
        # but it is counted so the failure summary stays accurate.
        self._last_computed_mutpos = ([], [])
        _record_failure('other', 10)
        return None, None
1004
+
1005
def _get_embedding(self, seq: str, mut_positions: List[int]) -> torch.Tensor:
    """
    Return the per-residue embedding of `seq` with a mutation indicator
    as its last channel.

    The base embedding comes from self._get_or_create_embedding(seq) and is
    moved to CPU. If its width is 1153 it is treated as already carrying an
    (old) indicator in the last channel, which is reset and rewritten;
    otherwise a new indicator channel is appended.

    Args:
        seq: The protein sequence
        mut_positions: 0-indexed row positions to flag with 1.0. None is
            treated as "no mutations"; non-integer or out-of-range entries
            are ignored.

    Returns:
        Tensor of shape [L, D_out] where L is the number of rows of the
        base embedding and the last channel is 1.0 at flagged rows, else 0.0.
    """
    import numbers

    base_emb = self._get_or_create_embedding(seq).cpu()
    n_rows, width = base_emb.shape

    # Keep only integer positions that address an existing row
    flagged = [
        int(pos) for pos in (mut_positions or [])
        if isinstance(pos, numbers.Integral) and 0 <= pos < n_rows
    ]

    width_with_indicator = 1153  # 1152 embedding channels + 1 indicator channel
    if width == width_with_indicator:
        # Indicator channel already present (e.g. from a cached file): overwrite it
        new_emb = base_emb.clone()
        new_emb[:, -1] = 0.0
    else:
        # Fresh embedding: append a zero-initialised indicator channel
        chan = torch.zeros((n_rows, 1), dtype=base_emb.dtype, device=base_emb.device)
        new_emb = torch.cat([base_emb, chan], dim=-1)

    if flagged:
        new_emb[flagged, -1] = 1.0
    return new_emb
1093
+
1094
def _get_or_create_embedding(self, seq: str) -> torch.Tensor:
    """
    Return the base embedding for `seq`, loading or generating it as needed.

    Lookup order:
    1. In-memory cache (self._embedding_cache), bounded by
       self._cache_max_size with least-recently-used eviction
    2. <embedding_dir>/<md5(seq)>.npy
    3. <embedding_dir>/<md5(seq)>.pt (an unreadable file is deleted)
    4. On-the-fly generation via self.featurizer.transform(seq); the
       result is saved to the .pt path for future use

    Embeddings with fewer than 5 rows are padded to exactly 5 rows by
    repetition.

    Args:
        seq: The protein sequence.

    Returns:
        A clone of the cached tensor, so callers may modify it freely.

    Raises:
        RuntimeError: if no stored embedding exists and generation fails,
            or if the embedding has zero rows.
    """
    # Cache hit: re-insert so the entry becomes the most recently used
    cached = self._embedding_cache.pop(seq, None)
    if cached is not None:
        self._embedding_cache[seq] = cached
        self._cache_hits += 1
        return cached.clone()

    seq_hash = hashlib.md5(seq.encode()).hexdigest()
    pt_file = self.embedding_dir / f"{seq_hash}.pt"
    npy_file = self.embedding_dir / f"{seq_hash}.npy"

    emb = None

    # Try .npy first (pre-computed), then .pt
    if npy_file.is_file():
        try:
            import numpy as np
            emb = torch.from_numpy(np.load(npy_file))
        except Exception:
            emb = None  # unreadable .npy: fall through to the other sources
    if emb is None and pt_file.is_file():
        try:
            # NOTE(review): torch.load unpickles the file; the embedding
            # directory is assumed to be a trusted local cache.
            emb = torch.load(pt_file, map_location="cpu")
        except Exception:
            pt_file.unlink(missing_ok=True)  # Delete corrupted file
    if emb is None:
        # On-the-fly generation for missing sequences (e.g., inferred WT)
        try:
            emb = self.featurizer.transform(seq)
            torch.save(emb, pt_file)  # Save for future use
        except Exception as e:
            raise RuntimeError(
                f"Embedding not found and on-the-fly generation failed for sequence (len={len(seq)}): {e}"
            ) from e

        self._on_the_fly_count = getattr(self, '_on_the_fly_count', 0) + 1
        # Log only the first few generations
        if self._on_the_fly_count <= 5:
            print(f"[EMBEDDING] Generated on-the-fly #{self._on_the_fly_count}: len={len(seq)}, saved to {pt_file.name}")
        elif self._on_the_fly_count == 6:
            print(f"[EMBEDDING] Generated 5+ embeddings on-the-fly (suppressing further logs)")

    # SAFETY: at least 5 rows are required downstream
    n_rows = emb.shape[0]
    if n_rows == 0:
        raise RuntimeError(f"Embedding for sequence (len={len(seq)}) has no rows")
    if n_rows < 5:
        repeats = (5 // n_rows) + 1
        emb = emb.repeat(repeats, 1)[:5]  # Ensure exactly 5 rows

    self._cache_misses += 1

    # Evict the least recently used entry (first key) when the cache is full
    if self._embedding_cache and len(self._embedding_cache) >= self._cache_max_size:
        oldest_key = next(iter(self._embedding_cache))
        del self._embedding_cache[oldest_key]
    self._embedding_cache[seq] = emb

    return emb.clone()  # Return clone to avoid mutation issues
1195
+
1196
def get_cache_stats(self):
    """
    Return embedding-cache statistics as a dict.

    Keys: 'hits', 'misses', 'total' (hits + misses), 'hit_rate' (percentage
    as a float, 0.0 when there were no lookups), 'cache_size', 'cache_max',
    'on_the_fly_generated' and 'wt_embedding_failed' (both 0 if the
    corresponding counter attribute has not been created yet).
    """
    hits = self._cache_hits
    misses = self._cache_misses
    total = hits + misses
    # Always a float so consumers get one consistent type
    hit_rate = (hits / total * 100) if total > 0 else 0.0
    return {
        "hits": hits,
        "misses": misses,
        "total": total,
        "hit_rate": hit_rate,
        "cache_size": len(self._embedding_cache),
        "cache_max": self._cache_max_size,
        "on_the_fly_generated": getattr(self, '_on_the_fly_count', 0),
        "wt_embedding_failed": getattr(self, '_wt_missing_count', 0),
    }
1212
+
1213
def print_cache_stats(self):
    """
    Print a summary of the embedding cache counters.

    Always prints the hit/miss line; the on-the-fly and WT-failure lines
    are printed only when their counters are greater than zero.
    """
    s = self.get_cache_stats()
    summary = (f" [Cache] Hits: {s['hits']:,} | Misses: {s['misses']:,} | "
               f"Hit Rate: {s['hit_rate']:.1f}% | Size: {s['cache_size']:,}/{s['cache_max']:,}")
    print(summary)

    generated = s['on_the_fly_generated']
    if generated > 0:
        print(f" [Cache] On-the-fly generated: {generated:,} embeddings")

    wt_failures = s['wt_embedding_failed']
    if wt_failures > 0:
        print(f" [Cache] ⚠️ WT embedding failures: {wt_failures:,} (excluded from ddG training)")
1222
+
1223
def __len__(self):
    """Return the number of indexed samples."""
    n_samples = len(self.samples)
    return n_samples
1225
+
1226
+ def __getitem__(self, idx):
1227
+ item = self.samples[idx]
1228
+
1229
+ # DEBUG: Track sequence difference statistics
1230
+ if not hasattr(self, '_seq_diff_stats'):
1231
+ self._seq_diff_stats = {'same': 0, 'different': 0, 'no_wt': 0}
1232
+ if not hasattr(self, '_mutpos_stats'):
1233
+ self._mutpos_stats = {'has_mutpos': 0, 'no_mutpos': 0}
1234
+
1235
+ # LAZY LOADING: Load embeddings on-demand
1236
+ b1_mutpos = self._parse_mutpos(item["b1_mutpos_str"])
1237
+ b2_mutpos = self._parse_mutpos(item["b2_mutpos_str"])
1238
+
1239
+ #region agent log
1240
+ try:
1241
+ if not hasattr(self, "_agent_mutpos_getitem_counter"):
1242
+ self._agent_mutpos_getitem_counter = 0
1243
+ if self._agent_mutpos_getitem_counter < 20:
1244
+ self._agent_mutpos_getitem_counter += 1
1245
+ payload = {
1246
+ "sessionId": "debug-session",
1247
+ "runId": "pre-fix",
1248
+ "hypothesisId": "G",
1249
+ "location": "modules.py:AdvancedSiameseDataset:__getitem__",
1250
+ "message": "Parsed mut_positions passed to _get_embedding",
1251
+ "data": {
1252
+ "idx": int(idx),
1253
+ "pdb": str(item.get("pdb")),
1254
+ "is_wt": bool(item.get("is_wt")),
1255
+ "b1_mutpos_str": str(item.get("b1_mutpos_str")),
1256
+ "b2_mutpos_str": str(item.get("b2_mutpos_str")),
1257
+ "b1_mutpos_n": int(len(b1_mutpos)),
1258
+ "b2_mutpos_n": int(len(b2_mutpos)),
1259
+ "b1_mutpos_first5": b1_mutpos[:5],
1260
+ "b2_mutpos_first5": b2_mutpos[:5],
1261
+ },
1262
+ "timestamp": int(time.time() * 1000),
1263
+ }
1264
+ with open("/Users/supantha/Documents/code_v2/protein/.cursor/debug.log", "a") as f:
1265
+ f.write(json.dumps(payload, default=str) + "\n")
1266
+ print(f"[AGENTLOG MUTPOSGET] idx={idx} b1n={len(b1_mutpos)} b2n={len(b2_mutpos)} b1str={item.get('b1_mutpos_str')} b2str={item.get('b2_mutpos_str')}")
1267
+ except Exception:
1268
+ pass
1269
+ #endregion
1270
+
1271
+ # Track mutation position statistics
1272
+ if len(b1_mutpos) > 0 or len(b2_mutpos) > 0:
1273
+ self._mutpos_stats['has_mutpos'] += 1
1274
+ else:
1275
+ self._mutpos_stats['no_mutpos'] += 1
1276
+
1277
+ # Log mutation position stats periodically
1278
+ total = sum(self._mutpos_stats.values())
1279
+ if total in [100, 1000, 10000]:
1280
+ has_mp = self._mutpos_stats['has_mutpos']
1281
+ no_mp = self._mutpos_stats['no_mutpos']
1282
+ print(f" [MUTPOS] After {total} samples: {has_mp} have mutation positions ({100*has_mp/total:.1f}%), "
1283
+ f"{no_mp} have NO mutation positions ({100*no_mp/total:.1f}%)")
1284
+
1285
+ c1_emb = self._get_embedding(item["seq1"], b1_mutpos)
1286
+ c2_emb = self._get_embedding(item["seq2"], b2_mutpos)
1287
+
1288
+ if self.normalize:
1289
+ c1_emb[:, :-1] = torch.nn.functional.normalize(c1_emb[:, :-1], p=2, dim=-1)
1290
+ c2_emb[:, :-1] = torch.nn.functional.normalize(c2_emb[:, :-1], p=2, dim=-1)
1291
+
1292
+ # Load WT embeddings if available
1293
+ if item["seq1_wt"] is not None:
1294
+ # DEBUG: Track sequence differences
1295
+ seq1_same = (item["seq1"] == item["seq1_wt"])
1296
+ seq2_same = (item["seq2"] == item["seq2_wt"])
1297
+ if seq1_same and seq2_same:
1298
+ self._seq_diff_stats['same'] += 1
1299
+ else:
1300
+ self._seq_diff_stats['different'] += 1
1301
+
1302
+ # Periodic logging
1303
+ total_samples = sum(self._seq_diff_stats.values())
1304
+ if total_samples in [100, 1000, 10000, 50000]:
1305
+ same = self._seq_diff_stats['same']
1306
+ diff = self._seq_diff_stats['different']
1307
+ no_wt = self._seq_diff_stats['no_wt']
1308
+ print(f" [SEQ DIFF] After {total_samples} samples: {same} same seq ({100*same/total_samples:.1f}%), "
1309
+ f"{diff} different ({100*diff/total_samples:.1f}%), {no_wt} no WT")
1310
+
1311
+ b1_wtpos = self._parse_mutpos(item["b1_wtpos_str"])
1312
+ b2_wtpos = self._parse_mutpos(item["b2_wtpos_str"])
1313
+
1314
+ #region agent log
1315
+ try:
1316
+ if not hasattr(self, "_agent_embed_call_counter_wt"):
1317
+ self._agent_embed_call_counter_wt = 0
1318
+ if self._agent_embed_call_counter_wt < 10:
1319
+ self._agent_embed_call_counter_wt += 1
1320
+ print(
1321
+ f"[AGENTLOG EMBCALL] idx={idx} role=wt "
1322
+ f"b1_wtpos_n={len(b1_wtpos)} b2_wtpos_n={len(b2_wtpos)} "
1323
+ f"seq1_wt_len={len(item.get('seq1_wt','') or '')} seq2_wt_len={len(item.get('seq2_wt','') or '')}"
1324
+ )
1325
+ except Exception:
1326
+ pass
1327
+ #endregion
1328
+
1329
+ try:
1330
+ cw1 = self._get_embedding(item["seq1_wt"], b1_wtpos)
1331
+ cw2 = self._get_embedding(item["seq2_wt"], b2_wtpos)
1332
+ except RuntimeError as e:
1333
+ # WT embedding unavailable - mark as no WT for this sample
1334
+ # DO NOT use mutant embedding as proxy - this corrupts the mutation signal!
1335
+ # Instead, set cw1, cw2 to None and let training handle missing WT
1336
+ cw1, cw2 = None, None
1337
+ if not hasattr(self, '_wt_missing_count'):
1338
+ self._wt_missing_count = 0
1339
+ self._wt_missing_count += 1
1340
+ if self._wt_missing_count <= 3: # Only log first 3 to avoid spam
1341
+ print(f" [WARN] WT embedding missing #{self._wt_missing_count}, sample will be WT-less: {e}")
1342
+
1343
+ if cw1 is not None and self.normalize:
1344
+ cw1[:, :-1] = torch.nn.functional.normalize(cw1[:, :-1], p=2, dim=-1)
1345
+ cw2[:, :-1] = torch.nn.functional.normalize(cw2[:, :-1], p=2, dim=-1)
1346
+ else:
1347
+ cw1, cw2 = None, None
1348
+ self._seq_diff_stats['no_wt'] += 1
1349
+
1350
+ data_tuple = (c1_emb, c2_emb, item["delg"],
1351
+ cw1, cw2, item["delg_wt"])
1352
+ meta = {
1353
+ "pdb": item["pdb"],
1354
+ "is_wt": item["is_wt"],
1355
+ "has_real_wt": item["has_real_wt"],
1356
+ "has_dg": item["has_dg"],
1357
+ "has_ddg": item["has_ddg"], # Whether sample has valid explicit ddG value
1358
+ "has_inferred_ddg": item["has_inferred_ddg"], # Whether sample has inferred ddG (dG_mut - dG_wt)
1359
+ "has_both_dg_ddg": item["has_both_dg_ddg"],
1360
+ "ddg": item["ddg"],
1361
+ "ddg_inferred": item["ddg_inferred"], # Inferred ddG value (needed for Fix #1)
1362
+ "has_any_wt": item["has_any_wt"], # Include inferred WT status (CRITICAL!)
1363
+ "b1_mutpos": b1_mutpos,
1364
+ "b2_mutpos": b2_mutpos,
1365
+ "data_source": item["data_source"]
1366
+ }
1367
+ return (data_tuple, meta)
1368
+
1369
+ #########################################
1370
+ # AffinityDataModule
1371
+ #########################################
1372
+ from sklearn.model_selection import GroupKFold
1373
+
1374
class AffinityDataModule(pl.LightningDataModule):
    """
    Data module for protein binding affinity prediction.

    ``setup`` picks exactly one splitting strategy, tried in this order:
      0. split_indices_dir containing 'dg_val_indices.csv': load DUAL splits
         (separate dG and ddG train/val/test sets) plus a combined train set
      1. split_indices_dir: load pre-computed cluster-based splits (RECOMMENDED)
      2. use_cluster_split: create new cluster-based splits on the fly
      3. 'split' column in the CSV (legacy); with num_folds > 1 the non-benchmark
         rows are split with GroupKFold grouped on the '#Pdb' column

    Every dataset is an ``AdvancedSiameseDataset`` built with ``augment=False``
    and every dataloader uses ``advanced_collate_fn``.
    """

    # Column used to tell wild-type rows from mutant rows in the test set.
    _MUTATION_COLUMN = 'Mutation(s)_cleaned'

    def __init__(
        self,
        data_csv: str,
        protein_featurizer: ESM3Featurizer,
        embedding_dir: str = "precomputed_esm",
        batch_size: int = 32,
        num_workers: int = 4,
        shuffle: bool = True,
        num_folds: int = 1,
        fold_index: int = 0,
        # New cluster-based splitting options
        split_indices_dir: str = None,  # Path to pre-computed split indices
        benchmark_indices_dir: str = None,  # Path to balanced benchmark subset indices (optional override)
        use_cluster_split: bool = False,  # Create cluster-based splits on the fly
        train_ratio: float = 0.70,
        val_ratio: float = 0.15,
        test_ratio: float = 0.15,
        random_state: int = 42
    ):
        """
        Store the configuration; no data is read until ``setup`` runs.

        Args:
            data_csv: Path of the CSV read by ``setup``.
            protein_featurizer: Featurizer handed to every dataset.
            embedding_dir: Embedding directory handed to every dataset.
            batch_size: Batch size of every dataloader.
            num_workers: Worker count of every dataloader.
            shuffle: Whether the *training* dataloaders shuffle.
            num_folds: When > 1 (legacy strategy only), use GroupKFold.
            fold_index: Which GroupKFold fold to use as validation.
            split_indices_dir: Directory with pre-computed split indices.
            benchmark_indices_dir: Optional directory with balanced benchmark
                indices overriding the ddG val/test sets of a dual split.
            use_cluster_split: Create cluster-based splits on the fly.
            train_ratio: Forwarded to ``create_cluster_splits``.
            val_ratio: Forwarded to ``create_cluster_splits``.
            test_ratio: Forwarded to ``create_cluster_splits``.
            random_state: Forwarded to ``create_cluster_splits``.
        """
        super().__init__()
        self.data_csv = data_csv
        self.featurizer = protein_featurizer
        self.embedding_dir = embedding_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.num_folds = num_folds
        self.fold_index = fold_index

        # Cluster-based splitting options
        self.split_indices_dir = split_indices_dir
        self.benchmark_indices_dir = benchmark_indices_dir  # Optional balanced benchmark override
        self.use_cluster_split = use_cluster_split
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
        self.random_state = random_state

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        # Dual-split datasets (separate for dG and ddG heads)
        self.dg_train_dataset = None  # WT-only training set for Stage A
        self.ddg_train_dataset = None  # Mutation training set for Stage B
        self.dg_val_dataset = None
        self.dg_test_dataset = None
        self.ddg_val_dataset = None
        self.ddg_test_dataset = None
        self.use_dual_split = False

    def prepare_data(self):
        """Fail early with FileNotFoundError when the data CSV is missing."""
        if not os.path.exists(self.data_csv):
            raise FileNotFoundError(f"Data CSV not found => {self.data_csv}")

    # ------------------------------------------------------------------
    # Private construction helpers
    # ------------------------------------------------------------------
    def _make_dataset(self, df, wt_reference_df=None):
        """Build a non-augmented AdvancedSiameseDataset for ``df``.

        ``wt_reference_df`` is only forwarded when given, so the dataset's own
        default applies otherwise.
        """
        kwargs = {"augment": False}
        if wt_reference_df is not None:
            kwargs["wt_reference_df"] = wt_reference_df
        return AdvancedSiameseDataset(df, self.featurizer, self.embedding_dir, **kwargs)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader over ``dataset`` with the module-wide settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=advanced_collate_fn
        )

    def _optional_loader(self, dataset, shuffle):
        """Like ``_make_loader`` but returns None when ``dataset`` is None."""
        if dataset is None:
            return None
        return self._make_loader(dataset, shuffle)

    @staticmethod
    def _take_rows(data, indices):
        """Select rows by position and renumber them from zero."""
        return data.iloc[indices].reset_index(drop=True)

    @classmethod
    def _wt_mask(cls, df):
        """Return a boolean Series (aligned to ``df``) that is True for WT rows.

        A row is WT when its mutation label, stringified and stripped, is
        empty, 'nan' (any case) or exactly 'WT'. A frame without the mutation
        column is treated as all-WT, matching the previous ``row.get(..., '')``
        behaviour. Built without ``DataFrame.apply(axis=1)`` so the result is
        always a Series, including for an empty frame.
        """
        if cls._MUTATION_COLUMN in df.columns:
            labels = df[cls._MUTATION_COLUMN].tolist()
        else:
            labels = [''] * len(df)

        flags = []
        for value in labels:
            text = str(value).strip()
            flags.append(text == '' or text.lower() == 'nan' or text == 'WT')
        return pd.Series(flags, index=df.index, dtype=bool)

    def _benchmark_override(self, data, filename, label, fallback_df):
        """Return the benchmark subset from ``filename`` or ``fallback_df`` if absent."""
        bench_file = os.path.join(self.benchmark_indices_dir, filename)
        if not os.path.exists(bench_file):
            return fallback_df
        bench_idx = pd.read_csv(bench_file, header=None).iloc[:, 0].values.tolist()
        bench_df = self._take_rows(data, bench_idx)
        print(f" ddG {label}: {len(bench_df)} rows (balanced benchmark)")
        return bench_df

    # ------------------------------------------------------------------
    # Split strategies (each returns train_df, val_df, test_df)
    # ------------------------------------------------------------------
    def _setup_dual_split(self, data):
        """Strategy 0: load dual splits and build the per-head datasets."""
        from data_splitting import load_dual_splits
        print(f"\n[DataModule] Loading DUAL splits from {self.split_indices_dir}")

        splits = load_dual_splits(self.split_indices_dir)
        self.use_dual_split = True

        # Combined training set (union of dG and ddG train indices)
        train_df = self._take_rows(data, splits['combined_train'])

        # For backward compatibility, use ddG validation as default val set
        # (since most validation is on mutation data)
        val_df = self._take_rows(data, splits['ddg']['val'])
        test_df = self._take_rows(data, splits['ddg']['test'])

        # CRITICAL: Create separate dG (WT-only) and ddG (MT-only) TRAINING sets
        # This fixes Stage A WT starvation where WT is diluted to 2.75% in combined_train
        dg_train_df = self._take_rows(data, splits['dg']['train'])
        ddg_train_df = self._take_rows(data, splits['ddg']['train'])

        dg_val_df = self._take_rows(data, splits['dg']['val'])
        dg_test_df = self._take_rows(data, splits['dg']['test'])
        ddg_val_df = self._take_rows(data, splits['ddg']['val'])
        ddg_test_df = self._take_rows(data, splits['ddg']['test'])

        print(f"\n[DataModule] Creating dG TRAIN dataset ({len(dg_train_df)} WT rows)...")
        self.dg_train_dataset = self._make_dataset(dg_train_df)  # Baseline: no augment

        print(f"[DataModule] Creating ddG TRAIN dataset ({len(ddg_train_df)} MT rows)...")
        self.ddg_train_dataset = self._make_dataset(ddg_train_df)  # Baseline: no augment

        # === BALANCED BENCHMARK OVERRIDE ===
        # If benchmark_indices_dir is provided, use those for ddG val/test instead
        if self.benchmark_indices_dir and os.path.exists(self.benchmark_indices_dir):
            print(f"\n[DataModule] Loading BALANCED BENCHMARK indices from {self.benchmark_indices_dir}")
            ddg_val_df = self._benchmark_override(
                data, 'ddg_val_benchmark_indices.csv', 'val', ddg_val_df
            )
            ddg_test_df = self._benchmark_override(
                data, 'ddg_test_benchmark_indices.csv', 'test', ddg_test_df
            )

        # NOTE: Do NOT subsample validation - we want accurate metrics on full set.
        # The full table is the WT reference so lookups are robust to split boundaries.
        print(f"\n[DataModule] Creating dG val dataset ({len(dg_val_df)} rows)...")
        self.dg_val_dataset = self._make_dataset(dg_val_df, wt_reference_df=data)

        print(f"\n[DataModule] Creating dG test dataset ({len(dg_test_df)} rows)...")
        self.dg_test_dataset = self._make_dataset(dg_test_df, wt_reference_df=data)

        print(f"\n[DataModule] Creating ddG val dataset ({len(ddg_val_df)} rows)...")
        self.ddg_val_dataset = self._make_dataset(ddg_val_df, wt_reference_df=data)

        print(f"\n[DataModule] Creating ddG test dataset ({len(ddg_test_df)} rows)...")
        self.ddg_test_dataset = self._make_dataset(ddg_test_df, wt_reference_df=data)

        print(f"\n[DataModule] Dual split datasets created:")
        print(f" dG train: {len(self.dg_train_dataset)} samples (WT-only for Stage A)")
        print(f" ddG train: {len(self.ddg_train_dataset)} samples (MT-only)")
        print(f" dG val: {len(self.dg_val_dataset)} samples")
        print(f" dG test: {len(self.dg_test_dataset)} samples")
        print(f" ddG val: {len(self.ddg_val_dataset)} samples")
        print(f" ddG test: {len(self.ddg_test_dataset)} samples")

        return train_df, val_df, test_df

    def _setup_precomputed_split(self, data):
        """Strategy 1: load pre-computed single-split indices and verify them."""
        from data_splitting import load_split_indices, verify_no_leakage
        train_idx, val_idx, test_idx = load_split_indices(self.split_indices_dir)

        train_df = self._take_rows(data, train_idx)
        val_df = self._take_rows(data, val_idx)
        test_df = self._take_rows(data, test_idx)

        # Verify no leakage
        verify_no_leakage(data, train_idx, val_idx, test_idx)
        return train_df, val_df, test_df

    def _setup_cluster_split(self, data):
        """Strategy 2: create cluster-based splits, saved next to the CSV."""
        from data_splitting import create_cluster_splits

        splits_dir = os.path.join(os.path.dirname(self.data_csv), 'splits')
        train_idx, val_idx, test_idx = create_cluster_splits(
            data,
            train_ratio=self.train_ratio,
            val_ratio=self.val_ratio,
            test_ratio=self.test_ratio,
            random_state=self.random_state,
            save_dir=splits_dir
        )
        return (
            self._take_rows(data, train_idx),
            self._take_rows(data, val_idx),
            self._take_rows(data, test_idx),
        )

    def _setup_legacy_split(self, data):
        """Strategy 3: split on the CSV's 'split' column (optionally GroupKFold).

        Raises:
            KeyError: if the CSV has no 'split' column.
        """
        if "split" not in data.columns:
            raise KeyError(
                f"'split' column not found in {self.data_csv}; set split_indices_dir "
                "or use_cluster_split=True, or add a 'split' column to the CSV"
            )

        bench_df = data[data["split"] == "Benchmark test"].copy()
        trainval_df = data[data["split"] != "Benchmark test"].copy()

        if self.num_folds > 1:
            gkf = GroupKFold(n_splits=self.num_folds)
            groups = trainval_df["#Pdb"].values
            folds = list(gkf.split(trainval_df, groups=groups))
            train_idx, val_idx = folds[self.fold_index]
            train_df = self._take_rows(trainval_df, train_idx)
            val_df = self._take_rows(trainval_df, val_idx)
        else:
            train_df = trainval_df[trainval_df["split"] == "train"].reset_index(drop=True)
            val_df = trainval_df[trainval_df["split"] == "val"].reset_index(drop=True)

        # Renumber like every other split frame instead of keeping CSV row labels.
        test_df = bench_df.reset_index(drop=True)
        return train_df, val_df, test_df

    def _ensure_head_test_datasets(self, test_df, data):
        """Build dG (WT) / ddG (MT) test datasets from ``test_df`` when neither exists.

        Needed so per-head test metrics can be logged for the non-dual strategies.
        """
        if self.dg_test_dataset is not None or self.ddg_test_dataset is not None:
            return

        test_is_wt = self._wt_mask(test_df)
        dg_test_df = test_df[test_is_wt].reset_index(drop=True)
        ddg_test_df = test_df[~test_is_wt].reset_index(drop=True)

        if len(dg_test_df) > 0:
            print(f"\n[DataModule] Creating dG TEST dataset ({len(dg_test_df)} WT rows)...")
            self.dg_test_dataset = self._make_dataset(dg_test_df, wt_reference_df=data)
        else:
            print(f"[DataModule] WARNING: No WT rows in test set for dG test dataset!")

        if len(ddg_test_df) > 0:
            print(f"\n[DataModule] Creating ddG TEST dataset ({len(ddg_test_df)} MT rows)...")
            self.ddg_test_dataset = self._make_dataset(ddg_test_df, wt_reference_df=data)
        else:
            print(f"[DataModule] WARNING: No MT rows in test set for ddG test dataset!")

    # ------------------------------------------------------------------
    # Lightning hooks
    # ------------------------------------------------------------------
    def setup(self, stage=None):
        """Read the CSV, choose a split strategy and build all datasets.

        Args:
            stage: Accepted for the Lightning hook signature; not used here.
        """
        data = pd.read_csv(self.data_csv, low_memory=False)

        split_dir = self.split_indices_dir
        split_dir_exists = bool(split_dir) and os.path.exists(split_dir)
        if split_dir and not split_dir_exists:
            # Previously this fell through silently, hiding path typos.
            print(f"[DataModule] WARNING: split_indices_dir not found => {split_dir}; "
                  f"falling back to another split strategy")

        # A dual-split directory is recognised by this marker file.
        dual_split_file = os.path.join(split_dir, 'dg_val_indices.csv') if split_dir else None

        if dual_split_file and os.path.exists(dual_split_file):
            train_df, val_df, test_df = self._setup_dual_split(data)
        elif split_dir_exists:
            train_df, val_df, test_df = self._setup_precomputed_split(data)
        elif self.use_cluster_split:
            train_df, val_df, test_df = self._setup_cluster_split(data)
        else:
            train_df, val_df, test_df = self._setup_legacy_split(data)

        print(f"\n[DataModule] Creating TRAIN dataset ({len(train_df)} rows)...")
        # Baseline: no augment (enable later for antisymmetry)
        self.train_dataset = self._make_dataset(train_df)

        print(f"\n[DataModule] Creating VAL dataset ({len(val_df)} rows)...")
        # Validation is not subsampled; the training frame is the WT source.
        self.val_dataset = self._make_dataset(val_df, wt_reference_df=train_df)

        print(f"\n[DataModule] Creating TEST dataset ({len(test_df)} rows)...")
        self.test_dataset = self._make_dataset(test_df, wt_reference_df=train_df)

        # CRITICAL for sweep runs: without per-head test sets, test metrics are never computed.
        self._ensure_head_test_datasets(test_df, data)

        # Log dataset sizes
        print(f"\nDataset sizes:")
        print(f" Train: {len(self.train_dataset)} samples")
        print(f" Val: {len(self.val_dataset)} samples")
        print(f" Test: {len(self.test_dataset)} samples")
        if self.dg_test_dataset:
            print(f" dG Test: {len(self.dg_test_dataset)} samples (WT)")
        if self.ddg_test_dataset:
            print(f" ddG Test: {len(self.ddg_test_dataset)} samples (MT)")

    def train_dataloader(self):
        """Dataloader over the main training set (shuffled per ``self.shuffle``)."""
        return self._make_loader(self.train_dataset, self.shuffle)

    def val_dataloader(self):
        """Unshuffled dataloader over the main validation set."""
        return self._make_loader(self.val_dataset, False)

    def test_dataloader(self):
        """Unshuffled dataloader over the main test set."""
        return self._make_loader(self.test_dataset, False)

    # Dual-split training dataloaders for separate dG-only (Stage A) and ddG (Stage B) training
    def dg_train_dataloader(self):
        """Training dataloader for dG head (WT data only for Stage A pretraining)."""
        return self._optional_loader(self.dg_train_dataset, self.shuffle)

    def ddg_train_dataloader(self):
        """Training dataloader for ddG head (mutation data for Stage B training)."""
        return self._optional_loader(self.ddg_train_dataset, self.shuffle)

    # Dual-split dataloaders for separate dG and ddG validation
    def dg_val_dataloader(self):
        """Validation dataloader for dG head (WT data only)."""
        return self._optional_loader(self.dg_val_dataset, False)

    def dg_test_dataloader(self):
        """Test dataloader for dG head (WT data only)."""
        return self._optional_loader(self.dg_test_dataset, False)

    def ddg_val_dataloader(self):
        """Validation dataloader for ddG head (mutation data including DMS)."""
        return self._optional_loader(self.ddg_val_dataset, False)

    def ddg_test_dataloader(self):
        """Test dataloader for ddG head (mutation data including DMS)."""
        return self._optional_loader(self.ddg_test_dataset, False)