from typing import Tuple, List, Dict, Optional
from pathlib import Path
import os
import math
import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import hashlib
import json
import time

from esm3bedding import ESM3Featurizer
from utils import get_logger

logg = get_logger()

#########################################
#         Source Type Mapping           #
#########################################
# Maps raw dataset identifiers to one of three coarse source categories that
# the model is conditioned on (see SOURCE_TYPE_TO_ID below).
SOURCE_TYPE_MAP = {
    # Protein complexes (unique structures)
    'PDBbind': 'protein_complex',
    'PPIKB': 'protein_complex',
    'asd_biomap': 'protein_complex',
    'asd_aae': 'protein_complex',
    'asd_aatp': 'protein_complex',
    'asd_osh': 'protein_complex',
    # True mutations
    'SKEMPI': 'mutation',
    'BindingGym': 'mutation',
    'asd_flab_koenig2017': 'mutation',  # 1-2aa differences
    'asd_flab_warszawski2019': 'mutation',  # 1-2aa differences
    'asd_flab_rosace2023': 'mutation',  # 1-5aa differences
    'PEPBI': 'mutation',
    # Antibody CDR variants
    'asd_abbd': 'antibody_cdr',  # 3-14aa CDR differences
    'abdesign': 'antibody_cdr',
    'asd_flab_hie2022': 'antibody_cdr',  # 2-17aa differences
    'asd_flab_shanehsazzadeh2023': 'antibody_cdr',  # 3-18aa differences
}
SOURCE_TYPE_TO_ID = {'protein_complex': 0, 'mutation': 1, 'antibody_cdr': 2}
DEFAULT_SOURCE_TYPE = 'mutation'  # Default for unknown sources


#########################################
#       Collate function (Siamese)      #
#########################################
def _pad_and_mask(tensors):
    """Pad a list of (L_i, D) tensors into one (B, N, D) batch.

    Returns (padded, mask) where mask is a (B, N) bool tensor that is True on
    real positions and False on padding. Equivalent to building per-sample
    [True]*l + [False]*(N-l) lists, but vectorized with a broadcasted compare.
    """
    padded = pad_sequence(tensors, batch_first=True)
    max_n = padded.shape[1]
    lengths = torch.tensor([t.shape[0] for t in tensors], dtype=torch.long)
    mask = torch.arange(max_n).unsqueeze(0) < lengths.unsqueeze(1)
    return padded, mask


def advanced_collate_fn(batch):
    """Collate (mutant, wildtype) sample pairs into a padded Siamese batch.

    Each element of `batch` is `(data, meta)` where
    `data = (c1, c2, y, cw1, cw2, yw)`: mutant chain embeddings (L, D), the
    mutant label, and the (possibly None) wildtype counterparts. `meta` is a
    per-sample dict carrying dG/ddG availability flags and values.

    Returns a dict with padded mutant/wildtype tensors plus masks, label
    tensors, per-sample boolean flags, and conditioning source-type ids.
    Samples without a real WT get zero placeholder embeddings; those are
    flagged via `has_valid_wt=False` so they can be excluded from ddG terms
    (mut - 0 would otherwise corrupt the ddG signal).
    """
    mut_c1_list, mut_c2_list, mut_y_list = [], [], []
    wt_c1_list, wt_c2_list, wt_y_list = [], [], []
    has_valid_wt_list = []  # CRITICAL: Track which samples have REAL WT embeddings (not zeros)
    meta_list = []

    for data, meta in batch:
        (c1, c2, y, cw1, cw2, yw) = data
        # mutant
        mut_c1_list.append(c1)
        mut_c2_list.append(c2)
        mut_y_list.append(torch.tensor([y], dtype=torch.float32))
        # wildtype
        if cw1 is not None and cw2 is not None and yw is not None:
            wt_c1_list.append(cw1)
            wt_c2_list.append(cw2)
            wt_y_list.append(torch.tensor([yw], dtype=torch.float32))
            has_valid_wt_list.append(True)  # Real WT data available
        else:
            # fallback if no known WT - ZEROS corrupt ddG signal!
            wt_c1_list.append(torch.zeros((1, c1.shape[1])))
            wt_c2_list.append(torch.zeros((1, c2.shape[1])))
            wt_y_list.append(torch.tensor([0.0], dtype=torch.float32))
            has_valid_wt_list.append(False)  # INVALID for ddG - would compute mut-0=mut
        meta_list.append(meta)

    # pad mutant
    c1_padded, c1_mask = _pad_and_mask(mut_c1_list)
    c2_padded, c2_mask = _pad_and_mask(mut_c2_list)
    B = c1_padded.shape[0]
    y_mut = torch.cat(mut_y_list, dim=0)

    # pad wildtype
    w1_padded, w1_mask = _pad_and_mask(wt_c1_list)
    w2_padded, w2_mask = _pad_and_mask(wt_c2_list)
    y_wt = torch.cat(wt_y_list, dim=0)

    has_wt_list = []
    is_wt_list = []  # Track which samples ARE WT (not just have a WT reference)
    has_dg_list = []
    has_ddg_list = []  # Track which samples have valid explicit ddG
    has_inferred_ddg_list = []  # Track which samples have inferred ddG
    has_both_list = []
    ddg_list = []
    ddg_inferred_list = []  # Inferred ddG values (dG_mut - dG_wt)

    # DEBUG: Track data consistency
    n_has_ddg_true = 0
    n_ddg_zero = 0
    n_ddg_nan = 0

    for i in range(B):
        # from meta - use has_any_wt to include both real and inferred WT sequences
        has_wt_list.append(meta_list[i].get("has_any_wt", meta_list[i].get("has_real_wt", False)))
        is_wt_list.append(meta_list[i].get("is_wt", False))  # Whether sample IS a WT sample (not mutant)
        has_dg_list.append(meta_list[i].get("has_dg", False))  # Default False to prevent false positives

        # Include inferred ddG in has_ddg flag so validation samples with
        # dG_mut and dG_wt are used.
        has_explicit_ddg = meta_list[i].get("has_ddg", False)
        has_inferred_ddg_flag = meta_list[i].get("has_inferred_ddg", False)
        # has_ddg should be True if we have EITHER explicit OR inferred ddG
        has_ddg_flag = has_explicit_ddg or has_inferred_ddg_flag
        has_ddg_list.append(has_ddg_flag)
        has_inferred_ddg_list.append(has_inferred_ddg_flag)
        has_both_list.append(meta_list[i].get("has_both_dg_ddg", False))  # For symmetric consistency

        # Use explicit ddG if available, otherwise inferred ddG (dG_mut - dG_wt).
        ddg_val = meta_list[i].get("ddg", float('nan'))
        ddg_inf_val = meta_list[i].get("ddg_inferred", float('nan'))
        is_explicit_nan = ddg_val != ddg_val  # NaN is the only value != itself
        is_inferred_nan = ddg_inf_val != ddg_inf_val

        # DEBUG: Check for data consistency issues
        if has_explicit_ddg:
            n_has_ddg_true += 1
            if is_explicit_nan:
                n_ddg_nan += 1
            elif abs(ddg_val) < 1e-8:
                n_ddg_zero += 1

        # Priority: explicit ddG > inferred ddG > 0.0 fallback (masked out)
        if not is_explicit_nan:
            ddg_list.append(ddg_val)
        elif not is_inferred_nan:
            ddg_list.append(ddg_inf_val)  # Use inferred ddG when explicit unavailable
        else:
            ddg_list.append(0.0)  # Fallback (will be masked by has_ddg=False)

        # Collect inferred ddG values for separate tracking (already fetched above)
        ddg_inferred_list.append(ddg_inf_val if not is_inferred_nan else 0.0)

    # DEBUG: Log batch statistics if there are issues
    if n_has_ddg_true > 0 and (n_ddg_nan > 0 or n_ddg_zero > B // 2):
        print(f"[COLLATE DEBUG] Batch has_ddg stats: {n_has_ddg_true}/{B} have has_ddg=True, "
              f"{n_ddg_nan} have NaN ddg (BUG!), {n_ddg_zero} have ddg≈0")

    has_wt = torch.tensor(has_wt_list, dtype=torch.bool)
    has_valid_wt = torch.tensor(has_valid_wt_list, dtype=torch.bool)  # CRITICAL: Only True if WT is real (not zeros)
    is_wt = torch.tensor(is_wt_list, dtype=torch.bool)  # Sample IS a WT sample
    has_dg = torch.tensor(has_dg_list, dtype=torch.bool)
    has_ddg = torch.tensor(has_ddg_list, dtype=torch.bool)
    has_inferred_ddg = torch.tensor(has_inferred_ddg_list, dtype=torch.bool)
    has_both_dg_ddg = torch.tensor(has_both_list, dtype=torch.bool)
    ddg_labels = torch.tensor(ddg_list, dtype=torch.float32)
    ddg_inferred_labels = torch.tensor(ddg_inferred_list, dtype=torch.float32)

    # DEBUG: Log WT validity stats when some WT slots are zero-filled
    n_valid_wt = has_valid_wt.sum().item()
    n_has_wt = has_wt.sum().item()
    if n_has_wt > 0 and n_valid_wt < n_has_wt:
        print(f"[COLLATE DEBUG] WT validity: {n_valid_wt}/{n_has_wt} have valid WT embeddings "
              f"({n_has_wt - n_valid_wt} samples have zero-fallback and will be EXCLUDED from ddG training)")

    # Collect data_source for per-source metrics
    data_source_list = [meta_list[i].get("data_source", "unknown") for i in range(B)]

    # Collect source_type_ids for model conditioning
    source_type_id_list = []
    for i in range(B):
        data_src = meta_list[i].get("data_source", "unknown")
        source_type = SOURCE_TYPE_MAP.get(data_src, DEFAULT_SOURCE_TYPE)
        source_type_id_list.append(SOURCE_TYPE_TO_ID[source_type])
    source_type_ids = torch.tensor(source_type_id_list, dtype=torch.long)

    out = {
        "mutant": (c1_padded, c1_mask, c2_padded, c2_mask, y_mut),
        "wildtype": (w1_padded, w1_mask, w2_padded, w2_mask, y_wt),
        "has_wt": has_wt,
        "has_valid_wt": has_valid_wt,  # True only if WT embeddings are real (not zeros)
        "is_wt": is_wt,  # Sample IS a WT sample (for routing to dG head)
        "has_dg": has_dg,  # Whether samples have absolute dG values
        "has_ddg": has_ddg,  # Whether samples have valid explicit/inferred ddG values
        "has_inferred_ddg": has_inferred_ddg,  # Whether samples have inferred ddG
        "has_both_dg_ddg": has_both_dg_ddg,  # For symmetric consistency loss
        "ddg_labels": ddg_labels,  # Direct ddG labels for BindingGym-style data
        "ddg_inferred_labels": ddg_inferred_labels,  # Inferred ddG = dG_mut - dG_wt
        "data_source": data_source_list,  # For per-source validation metrics
        "source_type_ids": source_type_ids,  # For model conditioning (0=protein_complex, 1=mutation, 2=antibody_cdr)
        "metadata": meta_list
    }
    return out
class AdvancedSiameseDataset(Dataset):
    """
    Dataset that handles mutation positions with a simple indicator channel.

    Reads columns:
      #Pdb, block1_sequence, block1_mut_positions, block1_mutations,
      block2_sequence, block2_mut_positions, block2_mutations, del_g, ...
    """

    def __init__(self, df: pd.DataFrame, featurizer: "ESM3Featurizer", embedding_dir: str,
                 normalize_embeddings=True, augment=False, max_len=1022,
                 wt_reference_df: pd.DataFrame = None):
        """Index all samples lazily (no embeddings are loaded here).

        Args:
            df: Source frame with one row per (mutant or WT) complex.
            featurizer: ESM3 featurizer used later for on-demand embedding.
            embedding_dir: Directory where embeddings are cached on disk.
            normalize_embeddings: Whether embeddings are normalized downstream.
            augment: If True, add sign-flipped reversed copies of mutant rows.
            max_len: Sequences longer than this are TRUNCATED (never dropped,
                which would shift row indices used by external split files).
            wt_reference_df: Optional reference frame (e.g. the training set)
                used to look up WT rows missing from this split, enabling
                implicit ddG = dG_mut - dG_wt.
        """
        super().__init__()
        # Store WT reference DF (e.g. training set) for looking up missing WTs.
        self.wt_reference_df = wt_reference_df if wt_reference_df is not None else None

        # CRITICAL: Do NOT drop rows based on length because it shifts indices!
        # External splits (indices) rely on the original row numbers, so
        # over-length sequences are truncated in place instead.
        long_mask = (df["block1_sequence"].astype(str).str.len() > max_len) | \
                    (df["block2_sequence"].astype(str).str.len() > max_len)
        n_long = long_mask.sum()
        if n_long > 0:
            # FIX: report the actual cutoff (was a stale hard-coded "72,306")
            print(f" [Dataset] Truncating {n_long} samples with length > {max_len} to maintain index alignment (CRITICAL FIX).")
            # Use .copy() to avoid SettingWithCopyWarning if df is a slice
            df = df.copy()
            df.loc[long_mask, "block1_sequence"] = df.loc[long_mask, "block1_sequence"].astype(str).str.slice(0, max_len)
            df.loc[long_mask, "block2_sequence"] = df.loc[long_mask, "block2_sequence"].astype(str).str.slice(0, max_len)
        # No rows dropped, so indices remain aligned with split files
        self.df = df.reset_index(drop=True)

        #region agent log
        # Best-effort debug trace; any failure (e.g. the hard-coded log path
        # not existing on this machine) is deliberately swallowed.
        try:
            cols = set(self.df.columns.tolist())
            need = {"block1_mut_positions", "block2_mut_positions", "Mutation(s)_PDB"}
            missing = sorted(list(need - cols))
            payload = {
                "sessionId": "debug-session",
                "runId": "pre-fix",
                "hypothesisId": "G",
                "location": "modules.py:AdvancedSiameseDataset:__init__",
                "message": "Dataset columns presence check for mutation positions",
                "data": {
                    "n_rows": int(len(self.df)),
                    "has_block1_mut_positions": "block1_mut_positions" in cols,
                    "has_block2_mut_positions": "block2_mut_positions" in cols,
                    "has_mutation_pdb": "Mutation(s)_PDB" in cols,
                    "missing": missing,
                },
                "timestamp": int(time.time() * 1000),
            }
            with open("/Users/supantha/Documents/code_v2/protein/.cursor/debug.log", "a") as f:
                f.write(json.dumps(payload, default=str) + "\n")
            print(f"[AGENTLOG MUTPOSCOLS] missing={missing}")
        except Exception:
            pass
        #endregion
        # NOTE(review): a second agent-log block here referenced names that do
        # not exist at this point (idx, b1_mutpos, b2_mutpos, item); it always
        # raised NameError which was silently swallowed, so it was removed.

        # Recover antibody WTs (ANTIBODY_MUTATION) before augmentation or indexing
        self.df = self._recover_antibody_wts(self.df)

        # ---------- OPTIONAL AUGMENT: reverse mutation (mut ↔ WT) ----------
        # Only augment MUTANT samples (not WT) - WT samples don't benefit from
        # reversal and doubling them confuses the pdb_to_wt lookup.
        if augment:
            # Identify mutant rows (non-empty Mutation(s)_PDB)
            mut_mask = self.df["Mutation(s)_PDB"].notna() & (self.df["Mutation(s)_PDB"].str.strip() != "")
            mutant_df = self.df[mut_mask].copy()
            if len(mutant_df) > 0:
                # Create reversed copies of mutant samples only
                rev_df = mutant_df.copy()
                # For the reverse augmentation we invert the sign of ddg
                if "ddg" in rev_df.columns:
                    rev_df["ddg"] = -rev_df["ddg"]
                rev_df["is_reverse"] = True  # flag for reversed samples
                # Original samples stay as-is
                self.df["is_reverse"] = False
                self.df = pd.concat([self.df, rev_df], ignore_index=True)
                print(f" [Dataset] Augmented: added {len(rev_df)} reversed mutant samples (antisymmetry training)")
            else:
                self.df["is_reverse"] = False
        else:
            self.df["is_reverse"] = False
        # -------------------------------------------------------------------

        # ---------- PAIR ID (mutant – WT) ----------------------------------
        # Use PDB + cleaned‑mutation string so mutant and its WT share an ID
        self.df["pair_id"] = (
            self.df["#Pdb"].astype(str)
            + "_"
            + self.df["Mutation(s)_cleaned"].fillna("")  # WT rows have empty mutation
        )
        # -------------------------------------------------------------------

        self.featurizer = featurizer
        self.embedding_dir = Path(embedding_dir)
        self.embedding_dir.mkdir(exist_ok=True, parents=True)
        self.normalize = normalize_embeddings
        self.samples = []
        self._embedding_cache = {}  # LRU-style cache for frequently accessed embeddings
        self._cache_max_size = 20000  # Cache up to 20k embeddings (~20-40GB RAM)
        self._cache_hits = 0
        self._cache_misses = 0

        # map each PDB to a wildtype row index if it exists
        print(f" [Dataset] Building WT index for {len(self.df)} rows...")
        self.pdb_to_wt = {}
        for i, row in self.df.iterrows():
            pdb = row["#Pdb"]
            mut_str = row.get("Mutation(s)_PDB", "")
            is_wt = (pd.isna(mut_str) or mut_str.strip() == "")
            if is_wt and pdb not in self.pdb_to_wt:
                self.pdb_to_wt[pdb] = i

        # Build external WT map if reference DF is provided.
        # FIX: this block appeared twice verbatim; deduplicated to one pass.
        self.external_pdb_to_wt = {}
        if self.wt_reference_df is not None:
            print(f" [Dataset] Building external WT index from {len(self.wt_reference_df)} reference rows...")
            for i, row in self.wt_reference_df.iterrows():
                # Only index actual WTs
                mut_str = row.get("Mutation(s)_PDB", "")
                is_wt = (pd.isna(mut_str) or mut_str.strip() == "")
                # Also honor a pre-computed 'is_wt' column if present
                if 'is_wt' in row:
                    is_wt = is_wt or row['is_wt']
                pdb = row["#Pdb"]
                if is_wt and pdb not in self.external_pdb_to_wt:
                    self.external_pdb_to_wt[pdb] = i
            print(f" [Dataset] Indexed {len(self.external_pdb_to_wt)} external WTs.")

        # LAZY LOADING: Only store metadata, NOT embeddings.
        # Embeddings will be loaded on-demand in __getitem__.
        print(f" [Dataset] Building sample metadata for {len(self.df)} rows (lazy loading)...")
        try:
            from tqdm import tqdm
        except ImportError:  # FIX: progress bar is optional, not a hard dependency
            def tqdm(iterable, **kwargs):
                return iterable
        for i, row in tqdm(self.df.iterrows(), total=len(self.df), desc=" Indexing"):
            # RESET computed mutations for this row to prevent stale data from
            # previous iterations (set as a side effect of _infer_wt_sequences)
            if hasattr(self, '_last_computed_mutpos'):
                del self._last_computed_mutpos
            pdb = row["#Pdb"]
            seq1 = row["block1_sequence"]
            seq2 = row["block2_sequence"]
            # Data source for per-source validation metrics
            data_source = row.get("data_source", "unknown")

            # Handle missing dG values (e.g., BindingGym has only ddG)
            raw_delg = row["del_g"]
            delg = float(raw_delg) if pd.notna(raw_delg) and raw_delg != '' else float('nan')
            # Get ddG if available (for ddG-only datasets like BindingGym)
            raw_ddg = row.get("ddg", None)
            ddg = float(raw_ddg) if pd.notna(raw_ddg) and raw_ddg != '' else float('nan')

            # Parse mutations (just store the string, parse later)
            b1_mutpos_str = row.get("block1_mut_positions", "[]")
            b2_mutpos_str = row.get("block2_mut_positions", "[]")

            # DEBUG: Print first few rows to debug disappearing mutations
            if i < 5:
                print(f"DEBUG ROW {i}: b1='{b1_mutpos_str}' ({type(b1_mutpos_str)}), b2='{b2_mutpos_str}' ({type(b2_mutpos_str)})")
                #region agent log
                try:
                    payload = {
                        "sessionId": "debug-session",
                        "runId": "pre-fix",
                        "hypothesisId": "G",
                        "location": "modules.py:AdvancedSiameseDataset:__init__:row0_4",
                        "message": "Raw mutpos strings from df row (first few)",
                        "data": {
                            "i": int(i),
                            "b1_mutpos_str": str(b1_mutpos_str),
                            "b2_mutpos_str": str(b2_mutpos_str),
                            "mutation_pdb": str(row.get("Mutation(s)_PDB", "")),
                        },
                        "timestamp": int(time.time() * 1000),
                    }
                    with open("/Users/supantha/Documents/code_v2/protein/.cursor/debug.log", "a") as f:
                        f.write(json.dumps(payload, default=str) + "\n")
                    print(f"[AGENTLOG MUTPOSRAW] i={i} b1={b1_mutpos_str} b2={b2_mutpos_str} mut={row.get('Mutation(s)_PDB','')}")
                except Exception:
                    pass
                #endregion

            # Get chain info for block assignment during WT inference
            b1_chains = str(row.get("block1_chains", "")).upper()
            b2_chains = str(row.get("block2_chains", "")).upper()

            mut_str = row.get("Mutation(s)_PDB", "")
            is_wt = (pd.isna(mut_str) or mut_str.strip() == "")
            wt_idx = self.pdb_to_wt.get(pdb, None)

            # Get WT info if available (Internal > External)
            row_wt = None
            wt_source = None
            if not hasattr(self, '_wt_source_stats'):
                self._wt_source_stats = {'internal': 0, 'external': 0}
            if wt_idx is not None:
                row_wt = self.df.iloc[wt_idx]
                wt_source = 'internal'
                self._wt_source_stats['internal'] += 1
            elif pdb in self.external_pdb_to_wt:
                ext_idx = self.external_pdb_to_wt[pdb]
                row_wt = self.wt_reference_df.iloc[ext_idx]
                wt_source = 'external'
                self._wt_source_stats['external'] += 1

            if row_wt is not None:
                seq1_wt = row_wt["block1_sequence"]
                seq2_wt = row_wt["block2_sequence"]
                raw_delg_wt = row_wt["del_g"]
                delg_wt = float(raw_delg_wt) if pd.notna(raw_delg_wt) and raw_delg_wt != '' else float('nan')
                b1_wtpos_str = row_wt.get("block1_mut_positions", "[]")
                b2_wtpos_str = row_wt.get("block2_mut_positions", "[]")
                # BUGFIX: If we have WT but NO mutation positions in CSV, we MUST
                # calculate them! Fixes the "0% mutation positions" issue when the
                # CSV column is empty/missing.
                if not is_wt and (b1_mutpos_str in ["[]", "", "nan", "None"] and b2_mutpos_str in ["[]", "", "nan", "None"]):
                    # Run inference to locate mutations (side-effect: sets
                    # _last_computed_mutpos). We ignore the inferred WT sequence
                    # since we have the real one; pass "[]" to force scanning
                    # PDB positions.
                    self._infer_wt_sequences(
                        seq1, seq2, mut_str, "[]", "[]", b1_chains, b2_chains
                    )
                    # Update mutpos_str if we found mutations
                    if hasattr(self, '_last_computed_mutpos'):
                        comp_b1, comp_b2 = self._last_computed_mutpos
                        if b1_mutpos_str in ["[]", "", "nan", "None"] and comp_b1:
                            b1_mutpos_str = str(comp_b1)
                        if b2_mutpos_str in ["[]", "", "nan", "None"] and comp_b2:
                            b2_mutpos_str = str(comp_b2)
            else:
                # No WT row found - try to INFER WT sequence by reversing
                # mutations. Crucial for BindingGym data which stores mutant
                # sequences only.
                seq1_wt, seq2_wt = self._infer_wt_sequences(
                    seq1, seq2, mut_str,
                    b1_mutpos_str, b2_mutpos_str,
                    b1_chains, b2_chains  # Chain info for block assignment
                )
                delg_wt = float('nan')  # No WT dG available for inferred sequences
                b1_wtpos_str, b2_wtpos_str = "[]", "[]"  # WT has no mutation positions
                # Use computed mutation positions from inference if original empty
                if hasattr(self, '_last_computed_mutpos'):
                    comp_b1, comp_b2 = self._last_computed_mutpos
                    if b1_mutpos_str in ["[]", "", "nan", "None"] and comp_b1:
                        b1_mutpos_str = str(comp_b1)
                    if b2_mutpos_str in ["[]", "", "nan", "None"] and comp_b2:
                        b2_mutpos_str = str(comp_b2)

            # Check if this sample has BOTH dG and ddG (for symmetric consistency)
            has_dg = not (delg != delg)  # False if NaN (NaN != itself)
            has_ddg = not (ddg != ddg)  # False if NaN
            has_both = has_dg and has_ddg

            # Compute inferred ddG for samples with dG_mut and dG_wt but no
            # explicit ddG: ddG_inferred = dG_mut - dG_wt.
            has_dg_wt = not (delg_wt != delg_wt)  # False if NaN
            has_inferred_ddg = has_dg and has_dg_wt and (not has_ddg)  # Only if no explicit ddG
            if has_inferred_ddg:
                ddg_inferred = delg - delg_wt  # Computed from dG values
            else:
                ddg_inferred = float('nan')

            # Track WT availability: real (internal OR external row), inferred, or none.
            # FIX: external reference WTs are real rows too; the old check
            # (wt_idx is not None) misclassified them as "inferred", which
            # contradicted the Internal/External breakdown printed below.
            has_real_wt = (row_wt is not None)
            has_inferred_wt = (row_wt is None and seq1_wt is not None and seq2_wt is not None)
            has_any_wt = has_real_wt or has_inferred_wt

            # Store ONLY metadata - no embeddings loaded yet!
            is_reverse = row.get("is_reverse", False)  # Track reversed samples

            # CRITICAL: Swap sequences and dG for reversed samples (antisymmetry augmentation)
            if is_reverse:
                # Swap sequences: New Mutant = Old WT, New WT = Old Mutant
                if seq1_wt is not None and seq2_wt is not None:
                    seq1, seq1_wt = seq1_wt, seq1
                    seq2, seq2_wt = seq2_wt, seq2
                # Swap dG values
                delg, delg_wt = delg_wt, delg
                # Negate inferred ddG:
                # dG_new_mut - dG_new_wt = dG_old_wt - dG_old_mut = -(dG_old_mut - dG_old_wt)
                if not math.isnan(ddg_inferred):
                    ddg_inferred = -ddg_inferred
                # Note: Explicit 'ddg' is already negated in the augmentation logic above.
                # Note: We do NOT swap mutation positions because the indices of
                # difference are the same for A->B vs B->A. We want the 'input'
                # (new mutant) to have the indicator flags at the difference sites.

            self.samples.append({
                "pdb": pdb,
                "is_wt": is_wt,
                "is_reverse": is_reverse,  # True if this is a reversed (augmented) sample
                "seq1": seq1,
                "seq2": seq2,
                "delg": delg,
                "seq1_wt": seq1_wt,
                "seq2_wt": seq2_wt,
                "delg_wt": delg_wt,
                "ddg": ddg,
                "ddg_inferred": ddg_inferred,  # Computed from dG_mut - dG_wt
                "has_dg": has_dg,
                "has_ddg": has_ddg,
                "has_inferred_ddg": has_inferred_ddg,  # True if ddg_inferred is valid
                "has_both_dg_ddg": has_both,
                "has_real_wt": has_real_wt,
                "has_inferred_wt": has_inferred_wt,
                "has_any_wt": has_any_wt,
                "b1_mutpos_str": b1_mutpos_str,
                "b2_mutpos_str": b2_mutpos_str,
                "b1_wtpos_str": b1_wtpos_str,
                "b2_wtpos_str": b2_wtpos_str,
                "data_source": data_source
            })

        # Log WT inference statistics
        n_real_wt = sum(1 for s in self.samples if s["has_real_wt"])
        n_inferred_wt = sum(1 for s in self.samples if s["has_inferred_wt"])
        n_no_wt = len(self.samples) - n_real_wt - n_inferred_wt
        # Detailed stats for Real WTs (Internal vs External)
        if hasattr(self, '_wt_source_stats'):
            n_internal = self._wt_source_stats.get('internal', 0)
            n_external = self._wt_source_stats.get('external', 0)
            source_msg = f" (Internal: {n_internal}, External: {n_external})"
        else:
            source_msg = ""
        print(f" [Dataset] Ready! {len(self.samples)} samples indexed (embeddings loaded on-demand)")
        print(f" [Dataset] WT stats: {n_real_wt} real WT{source_msg}, {n_inferred_wt} inferred WT, {n_no_wt} no WT")

        # Log detailed failure breakdown (for debugging)
        if hasattr(self, '_wt_inference_failures') and hasattr(self, '_wt_inference_fail_count'):
            print(f" [Dataset] ⚠️ WT inference failed for {self._wt_inference_fail_count} samples:")
            fail_dict = self._wt_inference_failures
            # Count by category (note: these are capped sample counts, not totals)
            n_no_pdb = len(fail_dict.get('no_pdb', []))
            n_del_ins = len(fail_dict.get('del_ins_only', []))
            n_parse = len(fail_dict.get('parse_fail', []))
            if n_no_pdb > 0:
                print(f" - ANTIBODY samples (no PDB structure): {self._wt_inference_fail_count} samples")
                print(f" (These are antibody design samples without original PDB - only dG usable)")
            elif n_del_ins > 0 or n_parse > 0:
                print(f" - DEL/INS/stop-codon (can't reverse): counted")
                print(f" - Parsing failed (unknown format): counted")
            # Show samples for non-ANTIBODY failures
            if fail_dict.get('parse_fail') and n_no_pdb == 0:
                print(f" Sample parse failures:")
                for mut in fail_dict['parse_fail'][:5]:
                    print(f" '{mut}'")

    def _parse_mutpos(self, pos_str) -> List[int]:
        """Parse a bracketed position list such as '[]' or '[170, 172]'.

        Accepts None/NaN/non-string inputs and returns [] for anything that is
        not a well-formed bracketed list of integers.
        """
        # Handle NaN, None, or non-string values
        if pos_str is None or (isinstance(pos_str, float) and str(pos_str) == 'nan'):
            return []
        if not isinstance(pos_str, str):
            pos_str = str(pos_str)
        pos_str = pos_str.strip()
        if pos_str.startswith("[") and pos_str.endswith("]"):
            inside = pos_str[1:-1].strip()
            if not inside:
                return []
            # split by comma
            arr = inside.split(",")
            out = []
            for x in arr:
                x_ = x.strip()
                if x_:
                    out.append(int(x_))
            return out
        return []

    def _recover_antibody_wts(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Recover WT information for antibody samples (ANTIBODY_MUTATION) by
        finding the closest-to-consensus sequence in each antigen group.

        Strategy:
          1. Identify samples with 'ANTIBODY_MUTATION'
          2. Group by antigen (block2_sequence)
          3. Assign unique Pseudo-PDB ID to each group (e.g. ANTIBODY_GRP_xxx)
          4. For same-length groups: find sequence closest to consensus as WT
          5. For variable-length groups: fallback to median binder by del_g
          6. Mark selected sequence as WT (clear mutation string)
        """
        from collections import Counter

        # Identify antibody mutation rows
        mask = df['Mutation(s)_PDB'].astype(str).str.contains('ANTIBODY_MUTATION', na=False)
        if not mask.any():
            return df

        print(f" [Dataset] Attempting to recover WT for {mask.sum()} antibody samples...")
        recovered_count = 0
        n_groups = 0
        n_consensus = 0
        n_median = 0
        n_fallback = 0

        # We need a copy to avoid SettingWithCopy warnings if df is a slice
        df = df.copy()
        # Add a temporary column for grouping (hash of antigen sequence)
        df['temp_antigen_hash'] = df['block2_sequence'].apply(lambda x: hashlib.md5(str(x).encode()).hexdigest())

        # Get hashes for antibody rows
        ab_hashes = df.loc[mask, 'temp_antigen_hash'].unique()
        for h in ab_hashes:
            # Get all antibody rows for this antigen
            grp_mask = mask & (df['temp_antigen_hash'] == h)
            grp_indices = df.index[grp_mask]
            if len(grp_indices) == 0:
                continue
            n_groups += 1

            # 1. Create unique Pseudo-PDB ID
            pseudo_pdb = f"ANTIBODY_GRP_{h[:8]}"
            df.loc[grp_indices, '#Pdb'] = pseudo_pdb

            # 2. Select WT: closest-to-consensus (same-length) or median binder
            #    (variable-length groups).
            seqs = df.loc[grp_indices, 'block1_sequence'].tolist()
            seq_lens = set(len(s) for s in seqs)
            wt_idx = None
            if len(seq_lens) == 1:
                # SAME LENGTH: Use closest-to-consensus
                seq_len = list(seq_lens)[0]
                # Build consensus sequence (most common residue per position)
                consensus = []
                for pos in range(seq_len):
                    residues = [s[pos] for s in seqs]
                    counts = Counter(residues)
                    most_common = counts.most_common(1)[0][0]
                    consensus.append(most_common)
                consensus_seq = ''.join(consensus)
                # Find sequence with minimum Hamming distance to consensus
                min_dist = float('inf')
                for idx in grp_indices:
                    seq = df.at[idx, 'block1_sequence']
                    dist = sum(c1 != c2 for c1, c2 in zip(seq, consensus_seq))
                    if dist < min_dist:
                        min_dist = dist
                        wt_idx = idx
                n_consensus += 1
            else:
                # VARIABLE LENGTH: Fallback to median binder (more representative
                # than best binder)
                if 'del_g' in df.columns:
                    delg_vals = pd.to_numeric(df.loc[grp_indices, 'del_g'], errors='coerce').dropna()
                    if len(delg_vals) > 0:
                        # Find index of value closest to median
                        median_val = delg_vals.median()
                        median_idx = (delg_vals - median_val).abs().idxmin()
                        wt_idx = median_idx
                        n_median += 1

            # FINAL FALLBACK: Pick first sample if no other method works
            # (e.g., all NaN dG)
            if wt_idx is None and len(grp_indices) > 0:
                wt_idx = grp_indices[0]
                n_fallback += 1

            # 3. Mark selected sequence as WT
            if wt_idx is not None:
                df.at[wt_idx, 'Mutation(s)_PDB'] = ""
                recovered_count += len(grp_indices)

        # Cleanup
        df.drop(columns=['temp_antigen_hash'], inplace=True, errors='ignore')
        print(f" [Dataset] Recovered {recovered_count} antibody samples ({n_groups} groups):")
        print(f" - {n_consensus} groups via closest-to-consensus")
        print(f" - {n_median} groups via median-binder (variable-length)")
        if n_fallback > 0:
            print(f" - {n_fallback} groups via first-sample fallback (no dG data)")
        return df
Mark selected sequence as WT if wt_idx is not None: df.at[wt_idx, 'Mutation(s)_PDB'] = "" recovered_count += len(grp_indices) # Cleanup df.drop(columns=['temp_antigen_hash'], inplace=True, errors='ignore') print(f" [Dataset] Recovered {recovered_count} antibody samples ({n_groups} groups):") print(f" - {n_consensus} groups via closest-to-consensus") print(f" - {n_median} groups via median-binder (variable-length)") if n_fallback > 0: print(f" - {n_fallback} groups via first-sample fallback (no dG data)") return df def _infer_wt_sequences(self, mut_seq1: str, mut_seq2: str, mutation_str: str, b1_mutpos_str: str, b2_mutpos_str: str, b1_chains: str = "", b2_chains: str = "") -> Tuple[Optional[str], Optional[str]]: """ Infer wildtype sequences by reversing mutations in the mutant sequences. IMPROVED: Instead of relying on PDB positions (which don't match 0-indexed sequence positions), this version searches for the mutant residue and reverses it. Also computes actual mutation positions as byproduct. 
Mutations are in formats like: - BindingGym: "H:P53L" or "H:P53L,H:Y57C" (chain:WTresPOSmutres) - SKEMPI: "HP53L" or "CA182A" (chainWTresPOSmutres) Args: mut_seq1: Mutant sequence for block1 mut_seq2: Mutant sequence for block2 mutation_str: Raw mutation string from data b1_mutpos_str: Mutation positions for block1 (e.g., "[52, 56]") b2_mutpos_str: Mutation positions for block2 b1_chains: Chain letters in block1 (e.g., "AB") b2_chains: Chain letters in block2 (e.g., "HL") Returns: Tuple of (wt_seq1, wt_seq2) or (None, None) if inference fails """ import re if pd.isna(mutation_str) or str(mutation_str).strip() == '': # No mutations = this IS the wildtype return mut_seq1, mut_seq2 # FALLBACK: Handle ANTIBODY_MUTATION samples that couldn't be recovered mutation_str_upper = str(mutation_str).strip().upper() if 'ANTIBODY_MUTATION' in mutation_str_upper or mutation_str_upper == 'ANTIBODY_MUTATION': if not hasattr(self, '_wt_inference_failures'): self._wt_inference_failures = {'parse_fail': [], 'del_ins_only': [], 'no_pdb': [], 'other': []} self._wt_inference_fail_count = 0 self._wt_inference_fail_count += 1 if len(self._wt_inference_failures['no_pdb']) < 5: self._wt_inference_failures['no_pdb'].append(mutation_str[:80]) return None, None try: # Parse mutation string to extract (chain, position, original_AA, mutant_AA) mutations = [] mutation_str = str(mutation_str).strip() # Split by common delimiters parts = re.split(r'[,;]', mutation_str) for part in parts: part = part.strip().strip('"\'') if not part: continue # Skip deletion/insertion markers - can't reverse these if 'DEL' in part.upper() or 'INS' in part.upper() or '*' in part: continue # BindingGym format: "H:P53L" or "L:K103R" if ':' in part: chain_mut = part.split(':') if len(chain_mut) >= 2: chain = chain_mut[0].strip().upper() for mut_part in chain_mut[1:]: mut_part = mut_part.strip() if not mut_part: continue match = re.match(r'([A-Z])(\d+)([A-Z])', mut_part) if match: wt_aa = match.group(1) pos = 
int(match.group(2)) # PDB-numbered (1-indexed) mut_aa = match.group(3) mutations.append((chain, pos, wt_aa, mut_aa)) else: # SKEMPI format: "CA182A" = C(WTresidue) + A(chain) + 182(pos) + A(mutant) # Format: WTresidue + ChainID + Position[insertcode] + MutResidue # Example: CA182A means Cysteine at chain A position 182 mutated to Alanine match = re.match(r'([A-Z])([A-Z])(-?\d+[a-z]?)([A-Z])', part) if match: wt_aa = match.group(1) # First char is WT residue chain = match.group(2).upper() # Second char is chain ID pos_str = match.group(3) pos = int(re.match(r'-?\d+', pos_str).group()) mut_aa = match.group(4) # Last char is mutant residue mutations.append((chain, pos, wt_aa, mut_aa)) else: # Simple format without chain: "F139A" (used by PEPBI) # Format: WTresidue + Position + MutResidue match = re.match(r'([A-Z])(\d+)([A-Z])', part) if match: wt_aa = match.group(1) pos = int(match.group(2)) mut_aa = match.group(3) # No chain info - will try both blocks mutations.append(('?', pos, wt_aa, mut_aa)) if not mutations: if not hasattr(self, '_wt_inference_failures'): self._wt_inference_failures = {'parse_fail': [], 'del_ins_only': [], 'other': []} self._wt_inference_fail_count = 0 self._wt_inference_fail_count += 1 if 'DEL' in mutation_str.upper() or 'INS' in mutation_str.upper() or '*' in mutation_str: category = 'del_ins_only' else: category = 'parse_fail' if len(self._wt_inference_failures.get(category, [])) < 10: self._wt_inference_failures.setdefault(category, []).append(mutation_str[:80]) return None, None # Convert sequences to lists for mutation wt_seq1_list = list(mut_seq1) if mut_seq1 else [] wt_seq2_list = list(mut_seq2) if mut_seq2 else [] # Build chain sets for block assignment b1_chain_set = set(b1_chains.upper()) if b1_chains else set() b2_chain_set = set(b2_chains.upper()) if b2_chains else set() # Parse PRECOMPUTED mutation positions (these are correct 0-indexed seq positions) # PDB residue numbers often don't match sequence indices due to numbering offsets 
            # Parse PRECOMPUTED mutation positions (these are correct 0-indexed seq positions)
            # PDB residue numbers often don't match sequence indices due to numbering offsets
            precomputed_b1_positions = self._parse_mutpos(b1_mutpos_str)
            precomputed_b2_positions = self._parse_mutpos(b2_mutpos_str)
            # Track reversal success (counters are lazily created on the instance)
            if not hasattr(self, '_wt_inference_stats'):
                self._wt_inference_stats = {'reversed': 0, 'not_found': 0, 'total': 0}
            # Also track actual mutation positions found
            found_positions_b1 = []
            found_positions_b2 = []
            # STRATEGY 1: Use precomputed positions if available (MOST RELIABLE)
            # These were computed during preprocessing with correct PDB-to-sequence mapping
            if precomputed_b1_positions or precomputed_b2_positions:
                # NOTE(review): a single pos_idx counter indexes BOTH per-block
                # position lists; this assumes mutation order lines up with each
                # block's list. If mutations interleave across blocks the index
                # may skip entries -- confirm against the preprocessing output.
                pos_idx = 0
                for chain, pdb_pos, wt_aa, mut_aa in mutations:
                    self._wt_inference_stats['total'] += 1
                    reversed_this = False
                    # Determine which block based on chain
                    if chain in b2_chain_set:
                        # Use precomputed block2 positions
                        if pos_idx < len(precomputed_b2_positions):
                            seq_idx = precomputed_b2_positions[pos_idx]
                            # Only reverse when the mutant residue actually sits at that index
                            if 0 <= seq_idx < len(wt_seq2_list) and wt_seq2_list[seq_idx] == mut_aa:
                                wt_seq2_list[seq_idx] = wt_aa
                                reversed_this = True
                                found_positions_b2.append(seq_idx)
                    elif chain in b1_chain_set:
                        # Use precomputed block1 positions
                        if pos_idx < len(precomputed_b1_positions):
                            seq_idx = precomputed_b1_positions[pos_idx]
                            if 0 <= seq_idx < len(wt_seq1_list) and wt_seq1_list[seq_idx] == mut_aa:
                                wt_seq1_list[seq_idx] = wt_aa
                                reversed_this = True
                                found_positions_b1.append(seq_idx)
                    else:
                        # Chain unknown - try both precomputed positions
                        if pos_idx < len(precomputed_b1_positions):
                            seq_idx = precomputed_b1_positions[pos_idx]
                            if 0 <= seq_idx < len(wt_seq1_list) and wt_seq1_list[seq_idx] == mut_aa:
                                wt_seq1_list[seq_idx] = wt_aa
                                reversed_this = True
                                found_positions_b1.append(seq_idx)
                        if not reversed_this and pos_idx < len(precomputed_b2_positions):
                            seq_idx = precomputed_b2_positions[pos_idx]
                            if 0 <= seq_idx < len(wt_seq2_list) and wt_seq2_list[seq_idx] == mut_aa:
                                wt_seq2_list[seq_idx] = wt_aa
                                reversed_this = True
                                found_positions_b2.append(seq_idx)
                    if reversed_this:
                        self._wt_inference_stats['reversed'] += 1
                    else:
                        self._wt_inference_stats['not_found'] += 1
                    pos_idx += 1
                self._last_computed_mutpos = (found_positions_b1, found_positions_b2)
                return ''.join(wt_seq1_list), ''.join(wt_seq2_list)
            # STRATEGY 2: Fall back to PDB position-based search (less reliable)
            for chain, pdb_pos, wt_aa, mut_aa in mutations:
                self._wt_inference_stats['total'] += 1
                reversed_this = False
                found_idx = None
                # Determine which block(s) to search based on chain
                chain_known = chain in b1_chain_set or chain in b2_chain_set
                if chain in b1_chain_set:
                    blocks_to_try = [(wt_seq1_list, True, found_positions_b1)]
                elif chain in b2_chain_set:
                    blocks_to_try = [(wt_seq2_list, False, found_positions_b2)]
                else:
                    # Chain info unavailable - try BOTH blocks
                    blocks_to_try = [
                        (wt_seq1_list, True, found_positions_b1),
                        (wt_seq2_list, False, found_positions_b2)
                    ]
                for target_seq, is_block1, pos_list in blocks_to_try:
                    if reversed_this:
                        break  # Already found in previous block
                    guess_idx = pdb_pos - 1  # Convert to 0-indexed
                    # Strategy 1: Try exact position if in bounds
                    if 0 <= guess_idx < len(target_seq) and target_seq[guess_idx] == mut_aa:
                        found_idx = guess_idx
                    else:
                        # Strategy 2: Search ±50 window around expected position
                        # NOTE(review): this matches the FIRST residue equal to mut_aa in
                        # the window, which can hit the wrong position for common residues.
                        search_start = max(0, pdb_pos - 50)
                        search_end = min(len(target_seq), pdb_pos + 50)
                        for idx in range(search_start, search_end):
                            if target_seq[idx] == mut_aa:
                                found_idx = idx
                                break
                    # Strategy 3: If position was out of bounds AND chain unknown,
                    # search the ENTIRE sequence as last resort
                    if found_idx is None and not chain_known:
                        if guess_idx >= len(target_seq) or guess_idx < 0:
                            # Position was out of bounds - search entire sequence
                            for idx in range(len(target_seq)):
                                if target_seq[idx] == mut_aa:
                                    found_idx = idx
                                    break
                    if found_idx is not None:
                        target_seq[found_idx] = wt_aa  # Reverse the mutation!
                        reversed_this = True
                        pos_list.append(found_idx)
                if reversed_this:
                    self._wt_inference_stats['reversed'] += 1
                else:
                    self._wt_inference_stats['not_found'] += 1
            # Store computed mutation positions for later use (helps with Bug #3)
            # These are the ACTUAL 0-indexed positions in the sequence
            self._last_computed_mutpos = (found_positions_b1, found_positions_b2)
            return ''.join(wt_seq1_list), ''.join(wt_seq2_list)
        except Exception as e:
            # On any error, return None to indicate inference failed
            # (callers treat (None, None) as "no WT available" rather than crashing)
            return None, None

    def _get_embedding(self, seq: str, mut_positions: List[int]) -> torch.Tensor:
        """
        Basic embedding with mutation position indicator channel.

        Args:
            seq: The protein sequence
            mut_positions: List of positions that are mutated (0-indexed)

        Returns:
            Tensor of shape [L, 1153]: the base ESM embedding with a trailing
            0/1 mutation-indicator channel set at mut_positions.
        """
        # Get base ESM embedding (already ensures min length of 2)
        base_emb = self._get_or_create_embedding(seq)  # => [L, 1152]
        base_emb = base_emb.cpu()
        # Get sequence length and embedding dimension
        L, D = base_emb.shape
        #region agent log
        # Debug-only instrumentation: logs tail-channel stats for the first few
        # calls; failures are deliberately swallowed so logging can never break
        # data loading.
        try:
            if not hasattr(self, "_agent_log_counter"):
                self._agent_log_counter = 0
            if self._agent_log_counter < 5:
                self._agent_log_counter += 1
                last1_stats = None
                last2_stats = None
                if D >= 1153:
                    v1 = base_emb[:, -1]
                    last1_stats = {
                        "min": float(v1.min().item()),
                        "max": float(v1.max().item()),
                        "mean": float(v1.float().mean().item()),
                        "std": float(v1.float().std().item()),
                    }
                if D >= 1154:
                    v2 = base_emb[:, -2]
                    last2_stats = {
                        "min": float(v2.min().item()),
                        "max": float(v2.max().item()),
                        "mean": float(v2.float().mean().item()),
                        "std": float(v2.float().std().item()),
                    }
                payload = {
                    "sessionId": "debug-session",
                    "runId": "pre-fix",
                    "hypothesisId": "F",
                    "location": "modules.py:AdvancedSiameseDataset:_get_embedding",
                    "message": "Base embedding shape + tail-channel stats before appending mutation indicator",
                    "data": {
                        "L": int(L),
                        "D": int(D),
                        "mut_positions_n": int(len(mut_positions) if mut_positions is not None else -1),
                        "mut_positions_first5": (mut_positions[:5] if mut_positions else []),
                        "base_last1": last1_stats,
"base_last2": last2_stats, }, "timestamp": int(time.time() * 1000), } with open("/Users/supantha/Documents/code_v2/protein/.cursor/debug.log", "a") as f: f.write(json.dumps(payload, default=str) + "\n") # Also emit a concise line to stdout/logs (useful on cluster runs) print(f"[AGENTLOG EMB] D={D} mut_n={len(mut_positions) if mut_positions else 0} last1={last1_stats} last2={last2_stats}") except Exception: pass #endregion # Create mutation indicator channel (just one channel) # FIX FOR DOUBLE-INDICATOR BUG: Check if base_emb already has indicator (D=1153) # If D=1153, the cached embedding already has an old indicator channel - OVERWRITE it # If D=1152, this is a fresh ESM embedding - APPEND indicator channel D = base_emb.shape[-1] L = base_emb.shape[0] if D == 1153: # Already has indicator channel (from cache) - overwrite it with correct mutation positions new_emb = base_emb.clone() new_emb[:, -1] = 0.0 # Reset old indicator for pos in mut_positions: if isinstance(pos, int) and 0 <= pos < L: new_emb[pos, -1] = 1.0 print(f"[AGENTLOG INDICATOR-FIX] D=1153 OVERWRITING last channel with {len(mut_positions)} positions") else: # Fresh ESM embedding (D=1152) - append indicator channel chan = torch.zeros((L, 1), dtype=base_emb.dtype, device=base_emb.device) for pos in mut_positions: if isinstance(pos, int) and 0 <= pos < L: chan[pos, 0] = 1.0 new_emb = torch.cat([base_emb, chan], dim=-1) print(f"[AGENTLOG INDICATOR-FIX] D={D} APPENDING indicator channel with {len(mut_positions)} positions") return new_emb def _get_or_create_embedding(self, seq: str) -> torch.Tensor: # Check LRU cache first (limited size to control memory) if seq in self._embedding_cache: self._cache_hits += 1 return self._embedding_cache[seq].clone() seq_hash = hashlib.md5(seq.encode()).hexdigest() pt_file = self.embedding_dir / f"{seq_hash}.pt" npy_file = self.embedding_dir / f"{seq_hash}.npy" emb = None load_source = None # Track where embedding came from # Try .npy first (pre-computed), then .pt if 
npy_file.is_file(): try: import numpy as np emb = torch.from_numpy(np.load(npy_file)) load_source = "npy" except Exception: pass if emb is None and pt_file.is_file(): try: emb = torch.load(pt_file, map_location="cpu") load_source = "pt" except Exception: pt_file.unlink(missing_ok=True) # Delete corrupted file if emb is None: # On-the-fly embedding generation for missing sequences (e.g., inferred WT) # This is slower but ensures accurate embeddings try: emb = self.featurizer.transform(seq) # [L, 1152] # Save for future use torch.save(emb, pt_file) load_source = "generated" # Track on-the-fly generation stats if not hasattr(self, '_on_the_fly_count'): self._on_the_fly_count = 0 self._on_the_fly_count += 1 # Log first few on-the-fly generations if self._on_the_fly_count <= 5: print(f"[EMBEDDING] Generated on-the-fly #{self._on_the_fly_count}: len={len(seq)}, saved to {pt_file.name}") elif self._on_the_fly_count == 6: print(f"[EMBEDDING] Generated 5+ embeddings on-the-fly (suppressing further logs)") except Exception as e: raise RuntimeError( f"Embedding not found and on-the-fly generation failed for sequence (len={len(seq)}): {e}" ) #region agent log try: if not hasattr(self, "_agent_embload_counter"): self._agent_embload_counter = 0 if self._agent_embload_counter < 8: self._agent_embload_counter += 1 shape = tuple(int(x) for x in emb.shape) D = int(shape[1]) if len(shape) == 2 else None payload = { "sessionId": "debug-session", "runId": "pre-fix", "hypothesisId": "A", "location": "modules.py:AdvancedSiameseDataset:_get_or_create_embedding", "message": "Loaded embedding tensor (source + shape) before any indicator is appended", "data": { "load_source": load_source, "seq_len": int(len(seq)), "shape": shape, "D": D, "looks_like_has_indicator": bool(D is not None and D >= 1153), "file_pt_exists": bool(pt_file.is_file()), "file_npy_exists": bool(npy_file.is_file()), }, "timestamp": int(time.time() * 1000), } with 
open("/Users/supantha/Documents/code_v2/protein/.cursor/debug.log", "a") as f: f.write(json.dumps(payload, default=str) + "\n") print(f"[AGENTLOG EMBLOAD] src={load_source} shape={shape} D={D}") except Exception: pass #endregion # SAFETY: Ensure embedding has valid shape (at least 5 residues for interpolation) if emb.shape[0] < 5: # Pad to minimum length of 5 by repeating repeats = (5 // emb.shape[0]) + 1 emb = emb.repeat(repeats, 1)[:5] # Ensure exactly 5 rows # Track cache miss self._cache_misses += 1 # Add to LRU cache (evict oldest if full) if len(self._embedding_cache) >= self._cache_max_size: # Remove oldest entry (first key in dict) oldest_key = next(iter(self._embedding_cache)) del self._embedding_cache[oldest_key] self._embedding_cache[seq] = emb return emb.clone() # Return clone to avoid mutation issues def get_cache_stats(self): """Return cache statistics.""" total = self._cache_hits + self._cache_misses hit_rate = (self._cache_hits / total * 100) if total > 0 else 0 on_the_fly = getattr(self, '_on_the_fly_count', 0) wt_missing = getattr(self, '_wt_missing_count', 0) return { "hits": self._cache_hits, "misses": self._cache_misses, "total": total, "hit_rate": hit_rate, "cache_size": len(self._embedding_cache), "cache_max": self._cache_max_size, "on_the_fly_generated": on_the_fly, "wt_embedding_failed": wt_missing } def print_cache_stats(self): """Print cache statistics.""" stats = self.get_cache_stats() print(f" [Cache] Hits: {stats['hits']:,} | Misses: {stats['misses']:,} | " f"Hit Rate: {stats['hit_rate']:.1f}% | Size: {stats['cache_size']:,}/{stats['cache_max']:,}") if stats['on_the_fly_generated'] > 0: print(f" [Cache] On-the-fly generated: {stats['on_the_fly_generated']:,} embeddings") if stats['wt_embedding_failed'] > 0: print(f" [Cache] ⚠️ WT embedding failures: {stats['wt_embedding_failed']:,} (excluded from ddG training)") def __len__(self): return len(self.samples) def __getitem__(self, idx): item = self.samples[idx] # DEBUG: Track sequence 
difference statistics if not hasattr(self, '_seq_diff_stats'): self._seq_diff_stats = {'same': 0, 'different': 0, 'no_wt': 0} if not hasattr(self, '_mutpos_stats'): self._mutpos_stats = {'has_mutpos': 0, 'no_mutpos': 0} # LAZY LOADING: Load embeddings on-demand b1_mutpos = self._parse_mutpos(item["b1_mutpos_str"]) b2_mutpos = self._parse_mutpos(item["b2_mutpos_str"]) #region agent log try: if not hasattr(self, "_agent_mutpos_getitem_counter"): self._agent_mutpos_getitem_counter = 0 if self._agent_mutpos_getitem_counter < 20: self._agent_mutpos_getitem_counter += 1 payload = { "sessionId": "debug-session", "runId": "pre-fix", "hypothesisId": "G", "location": "modules.py:AdvancedSiameseDataset:__getitem__", "message": "Parsed mut_positions passed to _get_embedding", "data": { "idx": int(idx), "pdb": str(item.get("pdb")), "is_wt": bool(item.get("is_wt")), "b1_mutpos_str": str(item.get("b1_mutpos_str")), "b2_mutpos_str": str(item.get("b2_mutpos_str")), "b1_mutpos_n": int(len(b1_mutpos)), "b2_mutpos_n": int(len(b2_mutpos)), "b1_mutpos_first5": b1_mutpos[:5], "b2_mutpos_first5": b2_mutpos[:5], }, "timestamp": int(time.time() * 1000), } with open("/Users/supantha/Documents/code_v2/protein/.cursor/debug.log", "a") as f: f.write(json.dumps(payload, default=str) + "\n") print(f"[AGENTLOG MUTPOSGET] idx={idx} b1n={len(b1_mutpos)} b2n={len(b2_mutpos)} b1str={item.get('b1_mutpos_str')} b2str={item.get('b2_mutpos_str')}") except Exception: pass #endregion # Track mutation position statistics if len(b1_mutpos) > 0 or len(b2_mutpos) > 0: self._mutpos_stats['has_mutpos'] += 1 else: self._mutpos_stats['no_mutpos'] += 1 # Log mutation position stats periodically total = sum(self._mutpos_stats.values()) if total in [100, 1000, 10000]: has_mp = self._mutpos_stats['has_mutpos'] no_mp = self._mutpos_stats['no_mutpos'] print(f" [MUTPOS] After {total} samples: {has_mp} have mutation positions ({100*has_mp/total:.1f}%), " f"{no_mp} have NO mutation positions ({100*no_mp/total:.1f}%)") 
c1_emb = self._get_embedding(item["seq1"], b1_mutpos) c2_emb = self._get_embedding(item["seq2"], b2_mutpos) if self.normalize: c1_emb[:, :-1] = torch.nn.functional.normalize(c1_emb[:, :-1], p=2, dim=-1) c2_emb[:, :-1] = torch.nn.functional.normalize(c2_emb[:, :-1], p=2, dim=-1) # Load WT embeddings if available if item["seq1_wt"] is not None: # DEBUG: Track sequence differences seq1_same = (item["seq1"] == item["seq1_wt"]) seq2_same = (item["seq2"] == item["seq2_wt"]) if seq1_same and seq2_same: self._seq_diff_stats['same'] += 1 else: self._seq_diff_stats['different'] += 1 # Periodic logging total_samples = sum(self._seq_diff_stats.values()) if total_samples in [100, 1000, 10000, 50000]: same = self._seq_diff_stats['same'] diff = self._seq_diff_stats['different'] no_wt = self._seq_diff_stats['no_wt'] print(f" [SEQ DIFF] After {total_samples} samples: {same} same seq ({100*same/total_samples:.1f}%), " f"{diff} different ({100*diff/total_samples:.1f}%), {no_wt} no WT") b1_wtpos = self._parse_mutpos(item["b1_wtpos_str"]) b2_wtpos = self._parse_mutpos(item["b2_wtpos_str"]) #region agent log try: if not hasattr(self, "_agent_embed_call_counter_wt"): self._agent_embed_call_counter_wt = 0 if self._agent_embed_call_counter_wt < 10: self._agent_embed_call_counter_wt += 1 print( f"[AGENTLOG EMBCALL] idx={idx} role=wt " f"b1_wtpos_n={len(b1_wtpos)} b2_wtpos_n={len(b2_wtpos)} " f"seq1_wt_len={len(item.get('seq1_wt','') or '')} seq2_wt_len={len(item.get('seq2_wt','') or '')}" ) except Exception: pass #endregion try: cw1 = self._get_embedding(item["seq1_wt"], b1_wtpos) cw2 = self._get_embedding(item["seq2_wt"], b2_wtpos) except RuntimeError as e: # WT embedding unavailable - mark as no WT for this sample # DO NOT use mutant embedding as proxy - this corrupts the mutation signal! 
# Instead, set cw1, cw2 to None and let training handle missing WT cw1, cw2 = None, None if not hasattr(self, '_wt_missing_count'): self._wt_missing_count = 0 self._wt_missing_count += 1 if self._wt_missing_count <= 3: # Only log first 3 to avoid spam print(f" [WARN] WT embedding missing #{self._wt_missing_count}, sample will be WT-less: {e}") if cw1 is not None and self.normalize: cw1[:, :-1] = torch.nn.functional.normalize(cw1[:, :-1], p=2, dim=-1) cw2[:, :-1] = torch.nn.functional.normalize(cw2[:, :-1], p=2, dim=-1) else: cw1, cw2 = None, None self._seq_diff_stats['no_wt'] += 1 data_tuple = (c1_emb, c2_emb, item["delg"], cw1, cw2, item["delg_wt"]) meta = { "pdb": item["pdb"], "is_wt": item["is_wt"], "has_real_wt": item["has_real_wt"], "has_dg": item["has_dg"], "has_ddg": item["has_ddg"], # Whether sample has valid explicit ddG value "has_inferred_ddg": item["has_inferred_ddg"], # Whether sample has inferred ddG (dG_mut - dG_wt) "has_both_dg_ddg": item["has_both_dg_ddg"], "ddg": item["ddg"], "ddg_inferred": item["ddg_inferred"], # Inferred ddG value (needed for Fix #1) "has_any_wt": item["has_any_wt"], # Include inferred WT status (CRITICAL!) "b1_mutpos": b1_mutpos, "b2_mutpos": b2_mutpos, "data_source": item["data_source"] } return (data_tuple, meta) ######################################### # AffinityDataModule ######################################### from sklearn.model_selection import GroupKFold class AffinityDataModule(pl.LightningDataModule): """ Data module for protein binding affinity prediction. Supports multiple splitting strategies: 1. split_indices_dir: Load pre-computed cluster-based splits (RECOMMENDED) 2. use_cluster_split: Create new cluster-based splits on the fly 3. split column: Use existing 'split' column in CSV (legacy) 4. 
 num_folds > 1: GroupKFold on PDB IDs
    """

    def __init__(
        self,
        data_csv: str,
        protein_featurizer: ESM3Featurizer,
        embedding_dir: str = "precomputed_esm",
        batch_size: int = 32,
        num_workers: int = 4,
        shuffle: bool = True,
        num_folds: int = 1,
        fold_index: int = 0,
        # New cluster-based splitting options
        split_indices_dir: str = None,        # Path to pre-computed split indices
        benchmark_indices_dir: str = None,    # Path to balanced benchmark subset indices (optional override)
        use_cluster_split: bool = False,      # Create cluster-based splits on the fly
        train_ratio: float = 0.70,
        val_ratio: float = 0.15,
        test_ratio: float = 0.15,
        random_state: int = 42
    ):
        super().__init__()
        self.data_csv = data_csv
        self.featurizer = protein_featurizer
        self.embedding_dir = embedding_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.num_folds = num_folds
        self.fold_index = fold_index
        # Cluster-based splitting options
        self.split_indices_dir = split_indices_dir
        self.benchmark_indices_dir = benchmark_indices_dir  # Optional balanced benchmark override
        self.use_cluster_split = use_cluster_split
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
        self.random_state = random_state
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        # Dual-split datasets (separate for dG and ddG heads)
        self.dg_train_dataset = None   # WT-only training set for Stage A
        self.ddg_train_dataset = None  # Mutation training set for Stage B
        self.dg_val_dataset = None
        self.dg_test_dataset = None
        self.ddg_val_dataset = None
        self.ddg_test_dataset = None
        self.use_dual_split = False

    def prepare_data(self):
        # Fail fast before any worker processes spin up.
        if not os.path.exists(self.data_csv):
            raise FileNotFoundError(f"Data CSV not found => {self.data_csv}")

    def setup(self, stage=None):
        data = pd.read_csv(self.data_csv, low_memory=False)
        # Check if this is a dual-split directory (presence of dg_val_indices.csv
        # distinguishes a dual dG/ddG split from a single-split directory)
        dual_split_file = os.path.join(self.split_indices_dir, 'dg_val_indices.csv') if self.split_indices_dir else None
        # Strategy 0: Load DUAL splits (separate for dG and ddG heads)
        if self.split_indices_dir and dual_split_file and os.path.exists(dual_split_file):
            from data_splitting import load_dual_splits
            print(f"\n[DataModule] Loading DUAL splits from {self.split_indices_dir}")
            splits = load_dual_splits(self.split_indices_dir)
            self.use_dual_split = True
            # Combined training set (union of dG and ddG train indices)
            train_idx = splits['combined_train']
            train_df = data.iloc[train_idx].reset_index(drop=True)
            # For backward compatibility, use ddG validation as default val set
            # (since most validation is on mutation data)
            val_idx = splits['ddg']['val']
            val_df = data.iloc[val_idx].reset_index(drop=True)
            test_idx = splits['ddg']['test']
            test_df = data.iloc[test_idx].reset_index(drop=True)
            # Create separate datasets for each head
            # CRITICAL: Create separate dG (WT-only) and ddG (MT-only) TRAINING sets
            # This fixes Stage A WT starvation where WT is diluted to 2.75% in combined_train
            dg_train_df = data.iloc[splits['dg']['train']].reset_index(drop=True)
            ddg_train_df = data.iloc[splits['ddg']['train']].reset_index(drop=True)
            dg_val_df = data.iloc[splits['dg']['val']].reset_index(drop=True)
            dg_test_df = data.iloc[splits['dg']['test']].reset_index(drop=True)
            ddg_val_df = data.iloc[splits['ddg']['val']].reset_index(drop=True)
            ddg_test_df = data.iloc[splits['ddg']['test']].reset_index(drop=True)
            print(f"\n[DataModule] Creating dG TRAIN dataset ({len(dg_train_df)} WT rows)...")
            self.dg_train_dataset = AdvancedSiameseDataset(dg_train_df, self.featurizer, self.embedding_dir, augment=False)  # Baseline: no augment
            print(f"[DataModule] Creating ddG TRAIN dataset ({len(ddg_train_df)} MT rows)...")
            self.ddg_train_dataset = AdvancedSiameseDataset(ddg_train_df, self.featurizer, self.embedding_dir, augment=False)  # Baseline: no augment
            # === BALANCED BENCHMARK OVERRIDE ===
            # If benchmark_indices_dir is provided, use those for ddG val/test instead
            if self.benchmark_indices_dir and os.path.exists(self.benchmark_indices_dir):
                print(f"\n[DataModule] Loading BALANCED BENCHMARK indices from {self.benchmark_indices_dir}")
                # Load ddG benchmark val indices
                ddg_val_bench_file = os.path.join(self.benchmark_indices_dir, 'ddg_val_benchmark_indices.csv')
                if os.path.exists(ddg_val_bench_file):
                    bench_val_idx = pd.read_csv(ddg_val_bench_file, header=None).iloc[:, 0].values.tolist()
                    ddg_val_df = data.iloc[bench_val_idx].reset_index(drop=True)
                    print(f" ddG val: {len(ddg_val_df)} rows (balanced benchmark)")
                # Load ddG benchmark test indices
                ddg_test_bench_file = os.path.join(self.benchmark_indices_dir, 'ddg_test_benchmark_indices.csv')
                if os.path.exists(ddg_test_bench_file):
                    bench_test_idx = pd.read_csv(ddg_test_bench_file, header=None).iloc[:, 0].values.tolist()
                    ddg_test_df = data.iloc[bench_test_idx].reset_index(drop=True)
                    print(f" ddG test: {len(ddg_test_df)} rows (balanced benchmark)")
            print(f"\n[DataModule] Creating dG val dataset ({len(dg_val_df)} rows)...")
            # NOTE: Do NOT subsample validation - we want accurate metrics on full set
            self.dg_val_dataset = AdvancedSiameseDataset(
                dg_val_df, self.featurizer, self.embedding_dir,
                augment=False,
                wt_reference_df=data  # FIX: Use full data for WT lookup (robust to split boundaries)
            )
            print(f"\n[DataModule] Creating dG test dataset ({len(dg_test_df)} rows)...")
            self.dg_test_dataset = AdvancedSiameseDataset(
                dg_test_df, self.featurizer, self.embedding_dir,
                augment=False,
                wt_reference_df=data  # FIX: Use full data for WT lookup
            )
            print(f"\n[DataModule] Creating ddG val dataset ({len(ddg_val_df)} rows)...")
            # NOTE: Do NOT subsample validation - we want accurate metrics on full set
            # cap_k only applies to training DMS data
            self.ddg_val_dataset = AdvancedSiameseDataset(
                ddg_val_df, self.featurizer, self.embedding_dir,
                augment=False,
                wt_reference_df=data  # FIX: Use full data for WT lookup
            )
            print(f"\n[DataModule] Creating ddG test dataset ({len(ddg_test_df)} rows)...")
            self.ddg_test_dataset = AdvancedSiameseDataset(
                ddg_test_df, self.featurizer, self.embedding_dir,
                augment=False,
                wt_reference_df=data  # FIX: Use full data for WT lookup
            )
            print(f"\n[DataModule] Dual split datasets created:")
            print(f" dG train: {len(self.dg_train_dataset)} samples (WT-only for Stage A)")
            print(f" ddG train: {len(self.ddg_train_dataset)} samples (MT-only)")
            print(f" dG val: {len(self.dg_val_dataset)} samples")
            print(f" dG test: {len(self.dg_test_dataset)} samples")
            print(f" ddG val: {len(self.ddg_val_dataset)} samples")
            print(f" ddG test: {len(self.ddg_test_dataset)} samples")
        # Strategy 1: Load pre-computed cluster-based splits (single split)
        elif self.split_indices_dir and os.path.exists(self.split_indices_dir):
            from data_splitting import load_split_indices, verify_no_leakage
            train_idx, val_idx, test_idx = load_split_indices(self.split_indices_dir)
            train_df = data.iloc[train_idx].reset_index(drop=True)
            val_df = data.iloc[val_idx].reset_index(drop=True)
            test_df = data.iloc[test_idx].reset_index(drop=True)
            # Verify no leakage
            verify_no_leakage(data, train_idx, val_idx, test_idx)
        # Strategy 2: Create cluster-based splits on the fly
        elif self.use_cluster_split:
            from data_splitting import create_cluster_splits, verify_no_leakage
            # Create splits directory if needed
            splits_dir = os.path.join(os.path.dirname(self.data_csv), 'splits')
            train_idx, val_idx, test_idx = create_cluster_splits(
                data,
                train_ratio=self.train_ratio,
                val_ratio=self.val_ratio,
                test_ratio=self.test_ratio,
                random_state=self.random_state,
                save_dir=splits_dir
            )
            train_df = data.iloc[train_idx].reset_index(drop=True)
            val_df = data.iloc[val_idx].reset_index(drop=True)
            test_df = data.iloc[test_idx].reset_index(drop=True)
        # Strategy 3: Legacy - use 'split' column in CSV
        else:
            # must have block1_sequence, block1_mut_positions, block2_sequence, ...
            bench_df = data[data["split"]=="Benchmark test"].copy()
            trainval_df = data[data["split"]!="Benchmark test"].copy()
            if self.num_folds > 1:
                # Group folds by PDB ID so complexes never straddle train/val
                gkf = GroupKFold(n_splits=self.num_folds)
                groups = trainval_df["#Pdb"].values
                folds = list(gkf.split(trainval_df, groups=groups))
                train_idx, val_idx = folds[self.fold_index]
                train_df = trainval_df.iloc[train_idx].reset_index(drop=True)
                val_df = trainval_df.iloc[val_idx].reset_index(drop=True)
            else:
                train_df = trainval_df[trainval_df["split"]=="train"].reset_index(drop=True)
                val_df = trainval_df[trainval_df["split"]=="val"].reset_index(drop=True)
            test_df = bench_df
        # Shared train/val/test dataset construction: every strategy above has
        # defined train_df / val_df / test_df by this point.
        print(f"\n[DataModule] Creating TRAIN dataset ({len(train_df)} rows)...")
        self.train_dataset = AdvancedSiameseDataset(
            train_df, self.featurizer, self.embedding_dir,
            augment=False  # Baseline: no augment (enable later for antisymmetry)
        )
        print(f"\n[DataModule] Creating VAL dataset ({len(val_df)} rows)...")
        # Subsampling disabled for v20 ablation to ensure robust Macro-PCC evaluation
        # (need full diversity of PDB families for honest reporting)
        self.val_dataset = AdvancedSiameseDataset(
            val_df, self.featurizer, self.embedding_dir,
            augment=False,
            wt_reference_df=train_df  # Pass training set as source for WTs
        )
        print(f"\n[DataModule] Creating TEST dataset ({len(test_df)} rows)...")
        self.test_dataset = AdvancedSiameseDataset(
            test_df, self.featurizer, self.embedding_dir,
            augment=False,
            wt_reference_df=train_df  # Pass training set as source for WTs (no leakage, WTs are known)
        )
        # FIX: Create separate dg_test and ddg_test datasets for proper test metric logging
        # This is CRITICAL for sweep runs - without this, test metrics are never computed!
        # (Skipped in the dual-split path, where both are already populated.)
        if self.dg_test_dataset is None and self.ddg_test_dataset is None:
            # Determine WT/MT based on Mutation(s)_cleaned column
            def is_wt_row(row):
                mut_str = str(row.get('Mutation(s)_cleaned', '')).strip()
                return mut_str == '' or mut_str.lower() == 'nan' or mut_str == 'WT'
            # Separate test_df into WT (for dG test) and MT (for ddG test)
            test_is_wt = test_df.apply(is_wt_row, axis=1)
            dg_test_df = test_df[test_is_wt].reset_index(drop=True)
            ddg_test_df = test_df[~test_is_wt].reset_index(drop=True)
            if len(dg_test_df) > 0:
                print(f"\n[DataModule] Creating dG TEST dataset ({len(dg_test_df)} WT rows)...")
                self.dg_test_dataset = AdvancedSiameseDataset(
                    dg_test_df, self.featurizer, self.embedding_dir,
                    augment=False,
                    wt_reference_df=data  # Use full data for WT lookup
                )
            else:
                print(f"[DataModule] WARNING: No WT rows in test set for dG test dataset!")
            if len(ddg_test_df) > 0:
                print(f"\n[DataModule] Creating ddG TEST dataset ({len(ddg_test_df)} MT rows)...")
                self.ddg_test_dataset = AdvancedSiameseDataset(
                    ddg_test_df, self.featurizer, self.embedding_dir,
                    augment=False,
                    wt_reference_df=data  # Use full data for WT lookup
                )
            else:
                print(f"[DataModule] WARNING: No MT rows in test set for ddG test dataset!")
        # Log dataset sizes
        print(f"\nDataset sizes:")
        print(f" Train: {len(self.train_dataset)} samples")
        print(f" Val: {len(self.val_dataset)} samples")
        print(f" Test: {len(self.test_dataset)} samples")
        if self.dg_test_dataset:
            print(f" dG Test: {len(self.dg_test_dataset)} samples (WT)")
        if self.ddg_test_dataset:
            print(f" ddG Test: {len(self.ddg_test_dataset)} samples (MT)")

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            collate_fn=advanced_collate_fn
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=advanced_collate_fn
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=advanced_collate_fn
        )

    # Dual-split training dataloaders for separate dG-only (Stage A) and ddG (Stage B) training.
    # Each returns None (instead of raising) when its dataset was not created,
    # i.e. when setup() did not take the dual-split path.
    def dg_train_dataloader(self):
        """Training dataloader for dG head (WT data only for Stage A pretraining)."""
        if self.dg_train_dataset is None:
            return None
        return DataLoader(
            self.dg_train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            collate_fn=advanced_collate_fn
        )

    def ddg_train_dataloader(self):
        """Training dataloader for ddG head (mutation data for Stage B training)."""
        if self.ddg_train_dataset is None:
            return None
        return DataLoader(
            self.ddg_train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            collate_fn=advanced_collate_fn
        )

    # Dual-split dataloaders for separate dG and ddG validation
    def dg_val_dataloader(self):
        """Validation dataloader for dG head (WT data only)."""
        if self.dg_val_dataset is None:
            return None
        return DataLoader(
            self.dg_val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=advanced_collate_fn
        )

    def dg_test_dataloader(self):
        """Test dataloader for dG head (WT data only)."""
        if self.dg_test_dataset is None:
            return None
        return DataLoader(
            self.dg_test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=advanced_collate_fn
        )

    def ddg_val_dataloader(self):
        """Validation dataloader for ddG head (mutation data including DMS)."""
        if self.ddg_val_dataset is None:
            return None
        return DataLoader(
            self.ddg_val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=advanced_collate_fn
        )

    def ddg_test_dataloader(self):
        """Test dataloader for ddG head (mutation data including DMS)."""
        if self.ddg_test_dataset is None:
            return None
        return DataLoader(
            self.ddg_test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=advanced_collate_fn
        )