# Source: modules.py, uploaded by supanthadey1 with huggingface_hub (commit 6ef15c4, verified).
from typing import Tuple, List, Dict, Optional
from pathlib import Path
import os
import math
import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import hashlib
import json
import time
from esm3bedding import ESM3Featurizer
from utils import get_logger
logg = get_logger()
#########################################
# Source Type Mapping #
#########################################
# Maps each known ``data_source`` name to the coarse category of data it holds.
# The table is written grouped by category and flattened into the lookup dict;
# insertion order is protein complexes, then mutations, then antibody CDR variants.
SOURCE_TYPE_MAP = {
    source_name: source_type
    for source_type, source_names in (
        # Protein complexes (unique structures)
        ('protein_complex', (
            'PDBbind',
            'PPIKB',
            'asd_biomap',
            'asd_aae',
            'asd_aatp',
            'asd_osh',
        )),
        # True mutations
        ('mutation', (
            'SKEMPI',
            'BindingGym',
            'asd_flab_koenig2017',      # 1-2aa differences
            'asd_flab_warszawski2019',  # 1-2aa differences
            'asd_flab_rosace2023',      # 1-5aa differences
            'PEPBI',
        )),
        # Antibody CDR variants
        ('antibody_cdr', (
            'asd_abbd',                     # 3-14aa CDR differences
            'abdesign',
            'asd_flab_hie2022',             # 2-17aa differences
            'asd_flab_shanehsazzadeh2023',  # 3-18aa differences
        )),
    )
    for source_name in source_names
}

# Integer id of each category, in the order the categories are listed above.
SOURCE_TYPE_TO_ID = {
    source_type: type_id
    for type_id, source_type in enumerate(('protein_complex', 'mutation', 'antibody_cdr'))
}

# Category assigned to any data source that is not listed in SOURCE_TYPE_MAP.
DEFAULT_SOURCE_TYPE = 'mutation'
#########################################
# Collate function (Siamese) #
#########################################
def _collate_is_missing(value):
    """Return True when a label value is absent (``None``) or NaN."""
    # NaN is the only value that is not equal to itself.
    return value is None or value != value


def _collate_pad_with_mask(tensors):
    """Pad a list of variable-length tensors and build the matching validity mask.

    Returns ``(padded, mask)`` where ``padded`` comes from
    ``pad_sequence(..., batch_first=True)`` and ``mask[i, j]`` is True exactly when
    position ``j`` lies inside the original length (``shape[0]``) of tensor ``i``.
    """
    padded = pad_sequence(tensors, batch_first=True)
    lengths = torch.tensor([t.shape[0] for t in tensors], dtype=torch.long)
    mask = torch.arange(padded.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)
    return padded, mask


def advanced_collate_fn(batch):
    """Collate ``(data, meta)`` samples into a padded Siamese (mutant / wildtype) batch.

    Each element of ``batch`` is ``(data, meta)`` with
    ``data = (c1, c2, y, cw1, cw2, yw)``: the two mutant chain embeddings and their
    label, followed by the wildtype counterparts (any of which may be ``None``).
    ``meta`` is a dict; the keys read here are ``has_any_wt`` / ``has_real_wt``,
    ``is_wt``, ``has_dg``, ``has_ddg``, ``has_inferred_ddg``, ``has_both_dg_ddg``,
    ``ddg``, ``ddg_inferred`` and ``data_source``.

    Returns:
        dict with
        - ``"mutant"`` / ``"wildtype"``: ``(c1_padded, c1_mask, c2_padded, c2_mask, y)``
        - ``"has_wt"``, ``"has_valid_wt"``, ``"is_wt"``, ``"has_dg"``, ``"has_ddg"``,
          ``"has_inferred_ddg"``, ``"has_both_dg_ddg"``: bool tensors, one entry per sample
        - ``"ddg_labels"``: explicit ddG if present, else inferred ddG, else 0.0
        - ``"ddg_inferred_labels"``: inferred ddG, or 0.0 when unavailable
        - ``"data_source"``: list of source names; ``"source_type_ids"``: long tensor
        - ``"metadata"``: the untouched list of meta dicts

    ``has_ddg`` / ``has_inferred_ddg`` are only True when a real (non-NaN, non-None)
    label backs them, so the 0.0 fallback value is never presented as a valid target.
    """
    mut_c1_list, mut_c2_list, mut_y_list = [], [], []
    wt_c1_list, wt_c2_list, wt_y_list = [], [], []
    has_valid_wt_list = []  # True only where the WT embeddings are real (not zero placeholders)
    meta_list = []

    for data, meta in batch:
        c1, c2, y, cw1, cw2, yw = data

        # mutant
        mut_c1_list.append(c1)
        mut_c2_list.append(c2)
        mut_y_list.append(torch.tensor([y], dtype=torch.float32))

        # wildtype
        wt_available = cw1 is not None and cw2 is not None and yw is not None
        if wt_available:
            wt_c1_list.append(cw1)
            wt_c2_list.append(cw2)
            wt_y_list.append(torch.tensor([yw], dtype=torch.float32))
        else:
            # No WT known: insert a length-1 zero placeholder with the mutant's dtype so
            # padding does not mix dtypes. These rows are flagged invalid below because a
            # zero WT would turn "mut - wt" into just "mut".
            wt_c1_list.append(c1.new_zeros((1, c1.shape[1])))
            wt_c2_list.append(c2.new_zeros((1, c2.shape[1])))
            wt_y_list.append(torch.tensor([0.0], dtype=torch.float32))
        has_valid_wt_list.append(wt_available)
        meta_list.append(meta)

    B = len(meta_list)

    # pad mutant
    c1_padded, c1_mask = _collate_pad_with_mask(mut_c1_list)
    c2_padded, c2_mask = _collate_pad_with_mask(mut_c2_list)
    y_mut = torch.cat(mut_y_list, dim=0)

    # pad wildtype
    w1_padded, w1_mask = _collate_pad_with_mask(wt_c1_list)
    w2_padded, w2_mask = _collate_pad_with_mask(wt_c2_list)
    y_wt = torch.cat(wt_y_list, dim=0)

    has_wt_list = []
    is_wt_list = []              # sample IS a WT sample (not just "has a WT reference")
    has_dg_list = []
    has_ddg_list = []            # a usable ddG label exists (explicit or inferred)
    has_inferred_ddg_list = []
    has_both_list = []
    ddg_list = []
    ddg_inferred_list = []

    # DEBUG: counters for data-consistency reporting
    n_has_ddg_true = 0
    n_ddg_zero = 0
    n_ddg_nan = 0

    for meta in meta_list:
        # has_any_wt covers both real and inferred WT sequences; fall back to has_real_wt
        has_wt_list.append(meta.get("has_any_wt", meta.get("has_real_wt", False)))
        is_wt_list.append(meta.get("is_wt", False))
        has_dg_list.append(meta.get("has_dg", False))  # default False avoids false positives
        has_both_list.append(meta.get("has_both_dg_ddg", False))

        has_explicit_ddg = meta.get("has_ddg", False)
        has_inferred_ddg_flag = meta.get("has_inferred_ddg", False)
        ddg_val = meta.get("ddg", float('nan'))
        ddg_inf_val = meta.get("ddg_inferred", float('nan'))
        explicit_missing = _collate_is_missing(ddg_val)
        inferred_missing = _collate_is_missing(ddg_inf_val)

        # DEBUG: flag says "explicit ddG" but the value is NaN / ~0
        if has_explicit_ddg:
            n_has_ddg_true += 1
            if explicit_missing:
                n_ddg_nan += 1
            elif abs(ddg_val) < 1e-8:
                n_ddg_zero += 1

        # Priority: explicit ddG > inferred ddG > 0.0 fallback (masked out)
        if not explicit_missing:
            ddg_list.append(ddg_val)
        elif not inferred_missing:
            ddg_list.append(ddg_inf_val)
        else:
            ddg_list.append(0.0)
        ddg_inferred_list.append(0.0 if inferred_missing else ddg_inf_val)

        # A flag only counts when a real value stands behind it; otherwise the 0.0
        # fallback above would be trained on as if it were a measured label.
        label_is_real = not (explicit_missing and inferred_missing)
        has_ddg_list.append(bool((has_explicit_ddg or has_inferred_ddg_flag) and label_is_real))
        has_inferred_ddg_list.append(bool(has_inferred_ddg_flag and not inferred_missing))

    # DEBUG: Log batch statistics if there are issues
    if n_has_ddg_true > 0 and (n_ddg_nan > 0 or n_ddg_zero > B // 2):
        print(f"[COLLATE DEBUG] Batch has_ddg stats: {n_has_ddg_true}/{B} have has_ddg=True, "
              f"{n_ddg_nan} have NaN ddg (BUG!), {n_ddg_zero} have ddg≈0")

    has_wt = torch.tensor(has_wt_list, dtype=torch.bool)
    has_valid_wt = torch.tensor(has_valid_wt_list, dtype=torch.bool)
    is_wt = torch.tensor(is_wt_list, dtype=torch.bool)
    has_dg = torch.tensor(has_dg_list, dtype=torch.bool)
    has_ddg = torch.tensor(has_ddg_list, dtype=torch.bool)
    has_inferred_ddg = torch.tensor(has_inferred_ddg_list, dtype=torch.bool)
    has_both_dg_ddg = torch.tensor(has_both_list, dtype=torch.bool)
    ddg_labels = torch.tensor(ddg_list, dtype=torch.float32)
    ddg_inferred_labels = torch.tensor(ddg_inferred_list, dtype=torch.float32)

    # DEBUG: report samples whose metadata claims a WT but whose embeddings are placeholders
    n_valid_wt = has_valid_wt.sum().item()
    n_has_wt = has_wt.sum().item()
    if n_has_wt > 0 and n_valid_wt < n_has_wt:
        print(f"[COLLATE DEBUG] WT validity: {n_valid_wt}/{n_has_wt} have valid WT embeddings "
              f"({n_has_wt - n_valid_wt} samples have zero-fallback and will be EXCLUDED from ddG training)")

    # data_source for per-source metrics, and its category id for model conditioning
    data_source_list = [meta.get("data_source", "unknown") for meta in meta_list]
    source_type_ids = torch.tensor(
        [SOURCE_TYPE_TO_ID[SOURCE_TYPE_MAP.get(src, DEFAULT_SOURCE_TYPE)] for src in data_source_list],
        dtype=torch.long,
    )

    return {
        "mutant": (c1_padded, c1_mask, c2_padded, c2_mask, y_mut),
        "wildtype": (w1_padded, w1_mask, w2_padded, w2_mask, y_wt),
        "has_wt": has_wt,
        "has_valid_wt": has_valid_wt,  # True only if WT embeddings are real (not zeros)
        "is_wt": is_wt,  # Sample IS a WT sample
        "has_dg": has_dg,  # Whether samples have absolute dG values
        "has_ddg": has_ddg,  # Whether ddg_labels holds a real (explicit or inferred) value
        "has_inferred_ddg": has_inferred_ddg,  # Whether ddg_inferred_labels holds a real value
        "has_both_dg_ddg": has_both_dg_ddg,  # For symmetric consistency loss
        "ddg_labels": ddg_labels,
        "ddg_inferred_labels": ddg_inferred_labels,  # Inferred ddG = dG_mut - dG_wt
        "data_source": data_source_list,  # For per-source validation metrics
        "source_type_ids": source_type_ids,  # 0=protein_complex, 1=mutation, 2=antibody_cdr
        "metadata": meta_list
    }
#########################################
# SiameseDataset (Simplified) #
#########################################
class AdvancedSiameseDataset(Dataset):
"""
Dataset that handles mutation positions with a simple indicator channel.
Reads columns:
#Pdb, block1_sequence, block1_mut_positions, block1_mutations,
block2_sequence, block2_mut_positions, block2_mutations, del_g, ...
"""
def __init__(self, df: pd.DataFrame, featurizer: ESM3Featurizer, embedding_dir: str,
             normalize_embeddings=True, augment=False, max_len=1022,
             wt_reference_df: pd.DataFrame = None):
    """
    Index every row of ``df`` into ``self.samples`` (metadata only, no embeddings loaded).

    Args:
        df: Input table. ``#Pdb``, ``block1_sequence``, ``block2_sequence``, ``del_g``,
            ``Mutation(s)_PDB`` and ``Mutation(s)_cleaned`` are read directly;
            ``ddg``, ``data_source``, ``block1_mut_positions``, ``block2_mut_positions``,
            ``block1_chains`` and ``block2_chains`` are read when present.
        featurizer: Stored as ``self.featurizer``.
        embedding_dir: Stored as ``self.embedding_dir``; the directory is created if missing.
        normalize_embeddings: Stored as ``self.normalize``.
        augment: When True, a reversed copy (mutant <-> WT, ddG negated) of every
            mutant row is appended.
        max_len: Sequences longer than this are truncated. Rows are never dropped, so
            externally computed row indices stay aligned.
        wt_reference_df: Optional table searched for a WT row when ``df`` itself
            contains none for a given ``#Pdb``.
    """
    super().__init__()

    mutpos_empty_tokens = ("[]", "", "nan", "None")

    def _is_blank_mutation(value) -> bool:
        # A row is WT when its mutation string is NaN/None or whitespace only.
        return bool(pd.isna(value)) or str(value).strip() == ""

    def _mutpos_is_empty(value) -> bool:
        # Missing cells arrive as float NaN (not the string "nan"), so test for both.
        if value is None or (isinstance(value, float) and math.isnan(value)):
            return True
        return isinstance(value, str) and value.strip() in mutpos_empty_tokens

    def _to_float(raw) -> float:
        # Empty / missing cells become NaN; anything else must be float-convertible.
        return float(raw) if pd.notna(raw) and raw != '' else float('nan')

    # WT reference DF (e.g. training set) for looking up WTs missing from this split.
    # This enables implicit ddG (dG_mut - dG_wt) even if WTs are not in the current split.
    self.wt_reference_df = wt_reference_df

    # Do NOT drop over-long rows: external split files refer to original row numbers.
    # Over-long sequences are truncated instead.
    long_mask = (df["block1_sequence"].astype(str).str.len() > max_len) | \
                (df["block2_sequence"].astype(str).str.len() > max_len)
    n_long = long_mask.sum()
    if n_long > 0:
        print(f" [Dataset] Truncating {n_long} samples with length > {max_len} to maintain index alignment (CRITICAL FIX).")
        # Copy first so a caller-owned frame (or a slice of one) is never modified.
        df = df.copy()
        df.loc[long_mask, "block1_sequence"] = df.loc[long_mask, "block1_sequence"].astype(str).str.slice(0, max_len)
        df.loc[long_mask, "block2_sequence"] = df.loc[long_mask, "block2_sequence"].astype(str).str.slice(0, max_len)

    # After this reset, row labels equal row positions.
    self.df = df.reset_index(drop=True)

    # Recover antibody WTs (ANTIBODY_MUTATION) before augmentation or indexing
    self.df = self._recover_antibody_wts(self.df)

    # ---------- OPTIONAL AUGMENT: reverse mutation (mut <-> WT) ----------
    # Only mutant rows are reversed; duplicating WT rows would confuse the pdb_to_wt lookup.
    self.df["is_reverse"] = False
    if augment:
        mutation_col = self.df["Mutation(s)_PDB"]
        # fillna/astype keeps .str usable even when the whole column is NaN
        mut_mask = mutation_col.notna() & (mutation_col.fillna("").astype(str).str.strip() != "")
        rev_df = self.df[mut_mask].copy()
        if len(rev_df) > 0:
            if "ddg" in rev_df.columns:
                # Reversal flips the sign of ddG; non-numeric cells (e.g. '') become NaN.
                rev_df["ddg"] = -pd.to_numeric(rev_df["ddg"], errors="coerce")
            rev_df["is_reverse"] = True
            self.df = pd.concat([self.df, rev_df], ignore_index=True)
            print(f" [Dataset] Augmented: added {len(rev_df)} reversed mutant samples (antisymmetry training)")

    # ---------- PAIR ID (mutant - WT) ----------
    # PDB + cleaned mutation string; WT rows have an empty mutation part.
    self.df["pair_id"] = (
        self.df["#Pdb"].astype(str) + "_" +
        self.df["Mutation(s)_cleaned"].fillna("").astype(str)
    )

    self.featurizer = featurizer
    self.embedding_dir = Path(embedding_dir)
    self.embedding_dir.mkdir(exist_ok=True, parents=True)
    self.normalize = normalize_embeddings
    self.samples = []
    self._embedding_cache = {}  # cache for frequently accessed embeddings
    self._cache_max_size = 20000  # Cache up to 20k embeddings (~20-40GB RAM)
    self._cache_hits = 0
    self._cache_misses = 0

    # Map each PDB to the position of its first wildtype row, if any.
    print(f" [Dataset] Building WT index for {len(self.df)} rows...")
    self.pdb_to_wt = {}
    for pos, (pdb, mut_str) in enumerate(zip(self.df["#Pdb"], self.df["Mutation(s)_PDB"])):
        if _is_blank_mutation(mut_str) and pdb not in self.pdb_to_wt:
            self.pdb_to_wt[pdb] = pos

    # External WT map. Positions (not index labels) are stored because the lookup
    # below uses .iloc, and the reference frame's index is not reset here.
    self.external_pdb_to_wt = {}
    if self.wt_reference_df is not None:
        print(f" [Dataset] Building external WT index from {len(self.wt_reference_df)} reference rows...")
        for ref_pos, (_, ref_row) in enumerate(self.wt_reference_df.iterrows()):
            ref_is_wt = _is_blank_mutation(ref_row.get("Mutation(s)_PDB", ""))
            if 'is_wt' in ref_row:
                # A pre-computed flag can also mark a WT row; NaN flags are ignored
                # (NaN is truthy in Python and would otherwise mark the row as WT).
                flag = ref_row['is_wt']
                ref_is_wt = ref_is_wt or bool(pd.notna(flag) and flag)
            ref_pdb = ref_row["#Pdb"]
            if ref_is_wt and ref_pdb not in self.external_pdb_to_wt:
                self.external_pdb_to_wt[ref_pdb] = ref_pos
        print(f" [Dataset] Indexed {len(self.external_pdb_to_wt)} external WTs.")

    # LAZY LOADING: only metadata is stored here; embeddings are loaded on demand.
    print(f" [Dataset] Building sample metadata for {len(self.df)} rows (lazy loading)...")
    from tqdm import tqdm
    self._wt_source_stats = {'internal': 0, 'external': 0}
    for i, row in tqdm(self.df.iterrows(), total=len(self.df), desc=" Indexing"):
        # Drop positions computed for the previous row so they cannot leak into this one.
        if hasattr(self, '_last_computed_mutpos'):
            del self._last_computed_mutpos

        pdb = row["#Pdb"]
        seq1 = row["block1_sequence"]
        seq2 = row["block2_sequence"]
        data_source = row.get("data_source", "unknown")

        # dG may be missing (e.g. ddG-only datasets); ddG may be missing too.
        delg = _to_float(row["del_g"])
        ddg = _to_float(row.get("ddg", None))

        # Mutation positions are kept as raw strings and parsed later.
        b1_mutpos_str = row.get("block1_mut_positions", "[]")
        b2_mutpos_str = row.get("block2_mut_positions", "[]")
        # DEBUG: Print first few rows to debug disappearing mutations
        if i < 5:
            print(f"DEBUG ROW {i}: b1='{b1_mutpos_str}' ({type(b1_mutpos_str)}), b2='{b2_mutpos_str}' ({type(b2_mutpos_str)})")

        # Chain info for block assignment during WT inference
        b1_chains = str(row.get("block1_chains", "")).upper()
        b2_chains = str(row.get("block2_chains", "")).upper()
        mut_str = row.get("Mutation(s)_PDB", "")
        is_wt = _is_blank_mutation(mut_str)
        wt_idx = self.pdb_to_wt.get(pdb, None)

        # Locate a WT row: this dataset first, then the external reference.
        row_wt = None
        if wt_idx is not None:
            row_wt = self.df.iloc[wt_idx]
            self._wt_source_stats['internal'] += 1
        elif pdb in self.external_pdb_to_wt:
            row_wt = self.wt_reference_df.iloc[self.external_pdb_to_wt[pdb]]
            self._wt_source_stats['external'] += 1

        if row_wt is not None:
            seq1_wt = row_wt["block1_sequence"]
            seq2_wt = row_wt["block2_sequence"]
            delg_wt = _to_float(row_wt["del_g"])
            b1_wtpos_str = row_wt.get("block1_mut_positions", "[]")
            b2_wtpos_str = row_wt.get("block2_mut_positions", "[]")
            # A mutant with a WT row but no recorded positions: run the inference purely
            # for its side effect of setting _last_computed_mutpos. "[]" is passed so it
            # scans for the positions; the returned WT sequences are not needed.
            if not is_wt and _mutpos_is_empty(b1_mutpos_str) and _mutpos_is_empty(b2_mutpos_str):
                self._infer_wt_sequences(
                    seq1, seq2, mut_str, "[]", "[]",
                    b1_chains, b2_chains
                )
        else:
            # No WT row found: try to infer the WT sequences by reversing the mutations.
            seq1_wt, seq2_wt = self._infer_wt_sequences(
                seq1, seq2, mut_str, b1_mutpos_str, b2_mutpos_str,
                b1_chains, b2_chains
            )
            delg_wt = float('nan')  # No WT dG available for inferred sequences
            b1_wtpos_str, b2_wtpos_str = "[]", "[]"  # WT has no mutation positions

        # Adopt positions computed during inference wherever the table had none.
        if hasattr(self, '_last_computed_mutpos'):
            comp_b1, comp_b2 = self._last_computed_mutpos
            if _mutpos_is_empty(b1_mutpos_str) and comp_b1:
                b1_mutpos_str = str(comp_b1)
            if _mutpos_is_empty(b2_mutpos_str) and comp_b2:
                b2_mutpos_str = str(comp_b2)

        has_ddg = not math.isnan(ddg)

        # Inferred ddG = dG_mut - dG_wt, only when both dG exist and no explicit ddG does.
        has_inferred_ddg = (not math.isnan(delg)) and (not math.isnan(delg_wt)) and (not has_ddg)
        ddg_inferred = delg - delg_wt if has_inferred_ddg else float('nan')

        # WT availability. NOTE(review): a WT taken from wt_reference_df has wt_idx None and
        # is therefore recorded as "inferred" here, as before - confirm this is intended.
        has_real_wt = (wt_idx is not None)
        has_inferred_wt = (wt_idx is None and seq1_wt is not None and seq2_wt is not None)
        has_any_wt = has_real_wt or has_inferred_wt

        is_reverse = row.get("is_reverse", False)
        # Reversed samples: new mutant = old WT, new WT = old mutant. Only possible when
        # the WT sequences are known. Mutation positions are not swapped because the
        # differing indices are the same in both directions.
        if is_reverse and seq1_wt is not None and seq2_wt is not None:
            seq1, seq1_wt = seq1_wt, seq1
            seq2, seq2_wt = seq2_wt, seq2
            delg, delg_wt = delg_wt, delg
            # dG_new_mut - dG_new_wt = -(dG_old_mut - dG_old_wt)
            if not math.isnan(ddg_inferred):
                ddg_inferred = -ddg_inferred
            # The explicit 'ddg' column was already negated when the reversed rows were built.

        # Derived after the possible swap so the flags describe the dG actually stored.
        has_dg = not math.isnan(delg)
        has_both = has_dg and has_ddg

        self.samples.append({
            "pdb": pdb,
            "is_wt": is_wt,
            "is_reverse": is_reverse,  # True if this is a reversed (augmented) sample
            "seq1": seq1, "seq2": seq2, "delg": delg,
            "seq1_wt": seq1_wt, "seq2_wt": seq2_wt, "delg_wt": delg_wt,
            "ddg": ddg,
            "ddg_inferred": ddg_inferred,
            "has_dg": has_dg,
            "has_ddg": has_ddg,
            "has_inferred_ddg": has_inferred_ddg,
            "has_both_dg_ddg": has_both,
            "has_real_wt": has_real_wt,
            "has_inferred_wt": has_inferred_wt,
            "has_any_wt": has_any_wt,
            "b1_mutpos_str": b1_mutpos_str,
            "b2_mutpos_str": b2_mutpos_str,
            "b1_wtpos_str": b1_wtpos_str,
            "b2_wtpos_str": b2_wtpos_str,
            "data_source": data_source
        })

    # Log WT statistics
    n_real_wt = sum(1 for s in self.samples if s["has_real_wt"])
    n_inferred_wt = sum(1 for s in self.samples if s["has_inferred_wt"])
    n_no_wt = len(self.samples) - n_real_wt - n_inferred_wt
    n_internal = self._wt_source_stats.get('internal', 0)
    n_external = self._wt_source_stats.get('external', 0)
    source_msg = f" (Internal: {n_internal}, External: {n_external})"
    print(f" [Dataset] Ready! {len(self.samples)} samples indexed (embeddings loaded on-demand)")
    print(f" [Dataset] WT stats: {n_real_wt} real WT{source_msg}, {n_inferred_wt} inferred WT, {n_no_wt} no WT")

    # Failure breakdown; the attributes only exist if _infer_wt_sequences recorded a failure.
    if hasattr(self, '_wt_inference_failures') and hasattr(self, '_wt_inference_fail_count'):
        print(f" [Dataset] ⚠️ WT inference failed for {self._wt_inference_fail_count} samples:")
        fail_dict = self._wt_inference_failures
        # These lists hold a capped number of examples, not totals.
        n_no_pdb = len(fail_dict.get('no_pdb', []))
        n_del_ins = len(fail_dict.get('del_ins_only', []))
        n_parse = len(fail_dict.get('parse_fail', []))
        if n_no_pdb > 0:
            print(f" - ANTIBODY samples (no PDB structure): {self._wt_inference_fail_count} samples")
            print(f" (These are antibody design samples without original PDB - only dG usable)")
        elif n_del_ins > 0 or n_parse > 0:
            print(f" - DEL/INS/stop-codon (can't reverse): counted")
            print(f" - Parsing failed (unknown format): counted")
        # Show examples for non-ANTIBODY failures
        if fail_dict.get('parse_fail') and n_no_pdb == 0:
            print(f" Sample parse failures:")
            for mut in fail_dict['parse_fail'][:5]:
                print(f" '{mut}'")
def _parse_mutpos(self, pos_str) -> List[int]:
"""
pos_str might be '[]' or '[170, 172]' etc.
We'll do a simple parse.
"""
# Handle NaN, None, or non-string values
if pos_str is None or (isinstance(pos_str, float) and str(pos_str) == 'nan'):
return []
if not isinstance(pos_str, str):
pos_str = str(pos_str)
pos_str = pos_str.strip()
if pos_str.startswith("[") and pos_str.endswith("]"):
inside = pos_str[1:-1].strip()
if not inside:
return []
# split by comma
arr = inside.split(",")
out = []
for x in arr:
x_ = x.strip()
if x_:
out.append(int(x_))
return out
return []
def _recover_antibody_wts(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Recover WT information for antibody samples (ANTIBODY_MUTATION)
by finding the closest-to-consensus sequence in each antigen group.
Strategy:
1. Identify samples with 'ANTIBODY_MUTATION'
2. Group by antigen (block2_sequence)
3. Assign unique Pseudo-PDB ID to each group (e.g. ANTIBODY_GRP_xxx)
4. For same-length groups: find sequence closest to consensus as WT
5. For variable-length groups: fallback to best binder (lowest del_g)
6. Mark selected sequence as WT (clear mutation string)
"""
from collections import Counter
# Identify antibody mutation rows
mask = df['Mutation(s)_PDB'].astype(str).str.contains('ANTIBODY_MUTATION', na=False)
if not mask.any():
return df
print(f" [Dataset] Attempting to recover WT for {mask.sum()} antibody samples...")
recovered_count = 0
n_groups = 0
n_consensus = 0
n_median = 0
n_fallback = 0
# We need a copy to avoid SettingWithCopy warnings if df is a slice
df = df.copy()
# Add a temporary column for grouping (hash of antigen sequence)
df['temp_antigen_hash'] = df['block2_sequence'].apply(lambda x: hashlib.md5(str(x).encode()).hexdigest())
# Get hashes for antibody rows
ab_hashes = df.loc[mask, 'temp_antigen_hash'].unique()
for h in ab_hashes:
# Get all antibody rows for this antigen
grp_mask = mask & (df['temp_antigen_hash'] == h)
grp_indices = df.index[grp_mask]
if len(grp_indices) == 0:
continue
n_groups += 1
# 1. Create unique Pseudo-PDB ID
pseudo_pdb = f"ANTIBODY_GRP_{h[:8]}"
df.loc[grp_indices, '#Pdb'] = pseudo_pdb
# 2. Select WT: closest-to-consensus (same-length) or best-binder (variable-length)
seqs = df.loc[grp_indices, 'block1_sequence'].tolist()
seq_lens = set(len(s) for s in seqs)
wt_idx = None
if len(seq_lens) == 1:
# SAME LENGTH: Use closest-to-consensus
seq_len = list(seq_lens)[0]
# Build consensus sequence
consensus = []
for pos in range(seq_len):
residues = [s[pos] for s in seqs]
counts = Counter(residues)
most_common = counts.most_common(1)[0][0]
consensus.append(most_common)
consensus_seq = ''.join(consensus)
# Find sequence with minimum Hamming distance to consensus
min_dist = float('inf')
for idx in grp_indices:
seq = df.at[idx, 'block1_sequence']
dist = sum(c1 != c2 for c1, c2 in zip(seq, consensus_seq))
if dist < min_dist:
min_dist = dist
wt_idx = idx
n_consensus += 1
else:
# VARIABLE LENGTH: Fallback to median binder (more representative than best)
if 'del_g' in df.columns:
delg_vals = pd.to_numeric(df.loc[grp_indices, 'del_g'], errors='coerce').dropna()
if len(delg_vals) > 0:
# Find index of value closest to median
median_val = delg_vals.median()
median_idx = (delg_vals - median_val).abs().idxmin()
wt_idx = median_idx
n_median += 1
# FINAL FALLBACK: Pick first sample if no other method works (e.g., all NaN dG)
if wt_idx is None and len(grp_indices) > 0:
wt_idx = grp_indices[0]
n_fallback += 1
# 3. Mark selected sequence as WT
if wt_idx is not None:
df.at[wt_idx, 'Mutation(s)_PDB'] = ""
recovered_count += len(grp_indices)
# Cleanup
df.drop(columns=['temp_antigen_hash'], inplace=True, errors='ignore')
print(f" [Dataset] Recovered {recovered_count} antibody samples ({n_groups} groups):")
print(f" - {n_consensus} groups via closest-to-consensus")
print(f" - {n_median} groups via median-binder (variable-length)")
if n_fallback > 0:
print(f" - {n_fallback} groups via first-sample fallback (no dG data)")
return df
def _infer_wt_sequences(self, mut_seq1: str, mut_seq2: str, mutation_str: str,
b1_mutpos_str: str, b2_mutpos_str: str,
b1_chains: str = "", b2_chains: str = "") -> Tuple[Optional[str], Optional[str]]:
"""
Infer wildtype sequences by reversing mutations in the mutant sequences.
IMPROVED: Instead of relying on PDB positions (which don't match 0-indexed
sequence positions), this version searches for the mutant residue and
reverses it. Also computes actual mutation positions as byproduct.
Mutations are in formats like:
- BindingGym: "H:P53L" or "H:P53L,H:Y57C" (chain:WTresPOSmutres)
- SKEMPI: "HP53L" or "CA182A" (chainWTresPOSmutres)
Args:
mut_seq1: Mutant sequence for block1
mut_seq2: Mutant sequence for block2
mutation_str: Raw mutation string from data
b1_mutpos_str: Mutation positions for block1 (e.g., "[52, 56]")
b2_mutpos_str: Mutation positions for block2
b1_chains: Chain letters in block1 (e.g., "AB")
b2_chains: Chain letters in block2 (e.g., "HL")
Returns:
Tuple of (wt_seq1, wt_seq2) or (None, None) if inference fails
"""
import re
if pd.isna(mutation_str) or str(mutation_str).strip() == '':
# No mutations = this IS the wildtype
return mut_seq1, mut_seq2
# FALLBACK: Handle ANTIBODY_MUTATION samples that couldn't be recovered
mutation_str_upper = str(mutation_str).strip().upper()
if 'ANTIBODY_MUTATION' in mutation_str_upper or mutation_str_upper == 'ANTIBODY_MUTATION':
if not hasattr(self, '_wt_inference_failures'):
self._wt_inference_failures = {'parse_fail': [], 'del_ins_only': [], 'no_pdb': [], 'other': []}
self._wt_inference_fail_count = 0
self._wt_inference_fail_count += 1
if len(self._wt_inference_failures['no_pdb']) < 5:
self._wt_inference_failures['no_pdb'].append(mutation_str[:80])
return None, None
try:
# Parse mutation string to extract (chain, position, original_AA, mutant_AA)
mutations = []
mutation_str = str(mutation_str).strip()
# Split by common delimiters
parts = re.split(r'[,;]', mutation_str)
for part in parts:
part = part.strip().strip('"\'')
if not part:
continue
# Skip deletion/insertion markers - can't reverse these
if 'DEL' in part.upper() or 'INS' in part.upper() or '*' in part:
continue
# BindingGym format: "H:P53L" or "L:K103R"
if ':' in part:
chain_mut = part.split(':')
if len(chain_mut) >= 2:
chain = chain_mut[0].strip().upper()
for mut_part in chain_mut[1:]:
mut_part = mut_part.strip()
if not mut_part:
continue
match = re.match(r'([A-Z])(\d+)([A-Z])', mut_part)
if match:
wt_aa = match.group(1)
pos = int(match.group(2)) # PDB-numbered (1-indexed)
mut_aa = match.group(3)
mutations.append((chain, pos, wt_aa, mut_aa))
else:
# SKEMPI format: "CA182A" = C(WTresidue) + A(chain) + 182(pos) + A(mutant)
# Format: WTresidue + ChainID + Position[insertcode] + MutResidue
# Example: CA182A means Cysteine at chain A position 182 mutated to Alanine
match = re.match(r'([A-Z])([A-Z])(-?\d+[a-z]?)([A-Z])', part)
if match:
wt_aa = match.group(1) # First char is WT residue
chain = match.group(2).upper() # Second char is chain ID
pos_str = match.group(3)
pos = int(re.match(r'-?\d+', pos_str).group())
mut_aa = match.group(4) # Last char is mutant residue
mutations.append((chain, pos, wt_aa, mut_aa))
else:
# Simple format without chain: "F139A" (used by PEPBI)
# Format: WTresidue + Position + MutResidue
match = re.match(r'([A-Z])(\d+)([A-Z])', part)
if match:
wt_aa = match.group(1)
pos = int(match.group(2))
mut_aa = match.group(3)
# No chain info - will try both blocks
mutations.append(('?', pos, wt_aa, mut_aa))
if not mutations:
if not hasattr(self, '_wt_inference_failures'):
self._wt_inference_failures = {'parse_fail': [], 'del_ins_only': [], 'other': []}
self._wt_inference_fail_count = 0
self._wt_inference_fail_count += 1
if 'DEL' in mutation_str.upper() or 'INS' in mutation_str.upper() or '*' in mutation_str:
category = 'del_ins_only'
else:
category = 'parse_fail'
if len(self._wt_inference_failures.get(category, [])) < 10:
self._wt_inference_failures.setdefault(category, []).append(mutation_str[:80])
return None, None
# Convert sequences to lists for mutation
wt_seq1_list = list(mut_seq1) if mut_seq1 else []
wt_seq2_list = list(mut_seq2) if mut_seq2 else []
# Build chain sets for block assignment
b1_chain_set = set(b1_chains.upper()) if b1_chains else set()
b2_chain_set = set(b2_chains.upper()) if b2_chains else set()
# Parse PRECOMPUTED mutation positions (these are correct 0-indexed seq positions)
# PDB residue numbers often don't match sequence indices due to numbering offsets
precomputed_b1_positions = self._parse_mutpos(b1_mutpos_str)
precomputed_b2_positions = self._parse_mutpos(b2_mutpos_str)
# Track reversal success
if not hasattr(self, '_wt_inference_stats'):
self._wt_inference_stats = {'reversed': 0, 'not_found': 0, 'total': 0}
# Also track actual mutation positions found
found_positions_b1 = []
found_positions_b2 = []
# STRATEGY 1: Use precomputed positions if available (MOST RELIABLE)
# These were computed during preprocessing with correct PDB-to-sequence mapping
if precomputed_b1_positions or precomputed_b2_positions:
pos_idx = 0
for chain, pdb_pos, wt_aa, mut_aa in mutations:
self._wt_inference_stats['total'] += 1
reversed_this = False
# Determine which block based on chain
if chain in b2_chain_set:
# Use precomputed block2 positions
if pos_idx < len(precomputed_b2_positions):
seq_idx = precomputed_b2_positions[pos_idx]
if 0 <= seq_idx < len(wt_seq2_list) and wt_seq2_list[seq_idx] == mut_aa:
wt_seq2_list[seq_idx] = wt_aa
reversed_this = True
found_positions_b2.append(seq_idx)
elif chain in b1_chain_set:
# Use precomputed block1 positions
if pos_idx < len(precomputed_b1_positions):
seq_idx = precomputed_b1_positions[pos_idx]
if 0 <= seq_idx < len(wt_seq1_list) and wt_seq1_list[seq_idx] == mut_aa:
wt_seq1_list[seq_idx] = wt_aa
reversed_this = True
found_positions_b1.append(seq_idx)
else:
# Chain unknown - try both precomputed positions
if pos_idx < len(precomputed_b1_positions):
seq_idx = precomputed_b1_positions[pos_idx]
if 0 <= seq_idx < len(wt_seq1_list) and wt_seq1_list[seq_idx] == mut_aa:
wt_seq1_list[seq_idx] = wt_aa
reversed_this = True
found_positions_b1.append(seq_idx)
if not reversed_this and pos_idx < len(precomputed_b2_positions):
seq_idx = precomputed_b2_positions[pos_idx]
if 0 <= seq_idx < len(wt_seq2_list) and wt_seq2_list[seq_idx] == mut_aa:
wt_seq2_list[seq_idx] = wt_aa
reversed_this = True
found_positions_b2.append(seq_idx)
if reversed_this:
self._wt_inference_stats['reversed'] += 1
else:
self._wt_inference_stats['not_found'] += 1
pos_idx += 1
self._last_computed_mutpos = (found_positions_b1, found_positions_b2)
return ''.join(wt_seq1_list), ''.join(wt_seq2_list)
# STRATEGY 2: Fall back to PDB position-based search (less reliable)
for chain, pdb_pos, wt_aa, mut_aa in mutations:
self._wt_inference_stats['total'] += 1
reversed_this = False
found_idx = None
# Determine which block(s) to search based on chain
chain_known = chain in b1_chain_set or chain in b2_chain_set
if chain in b1_chain_set:
blocks_to_try = [(wt_seq1_list, True, found_positions_b1)]
elif chain in b2_chain_set:
blocks_to_try = [(wt_seq2_list, False, found_positions_b2)]
else:
# Chain info unavailable - try BOTH blocks
blocks_to_try = [
(wt_seq1_list, True, found_positions_b1),
(wt_seq2_list, False, found_positions_b2)
]
for target_seq, is_block1, pos_list in blocks_to_try:
if reversed_this:
break # Already found in previous block
guess_idx = pdb_pos - 1 # Convert to 0-indexed
# Strategy 1: Try exact position if in bounds
if 0 <= guess_idx < len(target_seq) and target_seq[guess_idx] == mut_aa:
found_idx = guess_idx
else:
# Strategy 2: Search ±50 window around expected position
search_start = max(0, pdb_pos - 50)
search_end = min(len(target_seq), pdb_pos + 50)
for idx in range(search_start, search_end):
if target_seq[idx] == mut_aa:
found_idx = idx
break
# Strategy 3: If position was out of bounds AND chain unknown,
# search the ENTIRE sequence as last resort
if found_idx is None and not chain_known:
if guess_idx >= len(target_seq) or guess_idx < 0:
# Position was out of bounds - search entire sequence
for idx in range(len(target_seq)):
if target_seq[idx] == mut_aa:
found_idx = idx
break
if found_idx is not None:
target_seq[found_idx] = wt_aa # Reverse the mutation!
reversed_this = True
pos_list.append(found_idx)
if reversed_this:
self._wt_inference_stats['reversed'] += 1
else:
self._wt_inference_stats['not_found'] += 1
# Store computed mutation positions for later use (helps with Bug #3)
# These are the ACTUAL 0-indexed positions in the sequence
self._last_computed_mutpos = (found_positions_b1, found_positions_b2)
return ''.join(wt_seq1_list), ''.join(wt_seq2_list)
except Exception as e:
# On any error, return None to indicate inference failed
return None, None
def _get_embedding(self, seq: str, mut_positions: List[int]) -> torch.Tensor:
"""
Basic embedding with mutation position indicator channel.
Args:
seq: The protein sequence
mut_positions: List of positions that are mutated (0-indexed)
"""
# Get base ESM embedding (already ensures min length of 2)
base_emb = self._get_or_create_embedding(seq) # => [L, 1152]
base_emb = base_emb.cpu()
# Get sequence length and embedding dimension
L, D = base_emb.shape
#region agent log
try:
if not hasattr(self, "_agent_log_counter"):
self._agent_log_counter = 0
if self._agent_log_counter < 5:
self._agent_log_counter += 1
last1_stats = None
last2_stats = None
if D >= 1153:
v1 = base_emb[:, -1]
last1_stats = {
"min": float(v1.min().item()),
"max": float(v1.max().item()),
"mean": float(v1.float().mean().item()),
"std": float(v1.float().std().item()),
}
if D >= 1154:
v2 = base_emb[:, -2]
last2_stats = {
"min": float(v2.min().item()),
"max": float(v2.max().item()),
"mean": float(v2.float().mean().item()),
"std": float(v2.float().std().item()),
}
payload = {
"sessionId": "debug-session",
"runId": "pre-fix",
"hypothesisId": "F",
"location": "modules.py:AdvancedSiameseDataset:_get_embedding",
"message": "Base embedding shape + tail-channel stats before appending mutation indicator",
"data": {
"L": int(L),
"D": int(D),
"mut_positions_n": int(len(mut_positions) if mut_positions is not None else -1),
"mut_positions_first5": (mut_positions[:5] if mut_positions else []),
"base_last1": last1_stats,
"base_last2": last2_stats,
},
"timestamp": int(time.time() * 1000),
}
with open("/Users/supantha/Documents/code_v2/protein/.cursor/debug.log", "a") as f:
f.write(json.dumps(payload, default=str) + "\n")
# Also emit a concise line to stdout/logs (useful on cluster runs)
print(f"[AGENTLOG EMB] D={D} mut_n={len(mut_positions) if mut_positions else 0} last1={last1_stats} last2={last2_stats}")
except Exception:
pass
#endregion
# Create mutation indicator channel (just one channel)
# FIX FOR DOUBLE-INDICATOR BUG: Check if base_emb already has indicator (D=1153)
# If D=1153, the cached embedding already has an old indicator channel - OVERWRITE it
# If D=1152, this is a fresh ESM embedding - APPEND indicator channel
D = base_emb.shape[-1]
L = base_emb.shape[0]
if D == 1153:
# Already has indicator channel (from cache) - overwrite it with correct mutation positions
new_emb = base_emb.clone()
new_emb[:, -1] = 0.0 # Reset old indicator
for pos in mut_positions:
if isinstance(pos, int) and 0 <= pos < L:
new_emb[pos, -1] = 1.0
print(f"[AGENTLOG INDICATOR-FIX] D=1153 OVERWRITING last channel with {len(mut_positions)} positions")
else:
# Fresh ESM embedding (D=1152) - append indicator channel
chan = torch.zeros((L, 1), dtype=base_emb.dtype, device=base_emb.device)
for pos in mut_positions:
if isinstance(pos, int) and 0 <= pos < L:
chan[pos, 0] = 1.0
new_emb = torch.cat([base_emb, chan], dim=-1)
print(f"[AGENTLOG INDICATOR-FIX] D={D} APPENDING indicator channel with {len(mut_positions)} positions")
return new_emb
def _get_or_create_embedding(self, seq: str) -> torch.Tensor:
# Check LRU cache first (limited size to control memory)
if seq in self._embedding_cache:
self._cache_hits += 1
return self._embedding_cache[seq].clone()
seq_hash = hashlib.md5(seq.encode()).hexdigest()
pt_file = self.embedding_dir / f"{seq_hash}.pt"
npy_file = self.embedding_dir / f"{seq_hash}.npy"
emb = None
load_source = None # Track where embedding came from
# Try .npy first (pre-computed), then .pt
if npy_file.is_file():
try:
import numpy as np
emb = torch.from_numpy(np.load(npy_file))
load_source = "npy"
except Exception:
pass
if emb is None and pt_file.is_file():
try:
emb = torch.load(pt_file, map_location="cpu")
load_source = "pt"
except Exception:
pt_file.unlink(missing_ok=True) # Delete corrupted file
if emb is None:
# On-the-fly embedding generation for missing sequences (e.g., inferred WT)
# This is slower but ensures accurate embeddings
try:
emb = self.featurizer.transform(seq) # [L, 1152]
# Save for future use
torch.save(emb, pt_file)
load_source = "generated"
# Track on-the-fly generation stats
if not hasattr(self, '_on_the_fly_count'):
self._on_the_fly_count = 0
self._on_the_fly_count += 1
# Log first few on-the-fly generations
if self._on_the_fly_count <= 5:
print(f"[EMBEDDING] Generated on-the-fly #{self._on_the_fly_count}: len={len(seq)}, saved to {pt_file.name}")
elif self._on_the_fly_count == 6:
print(f"[EMBEDDING] Generated 5+ embeddings on-the-fly (suppressing further logs)")
except Exception as e:
raise RuntimeError(
f"Embedding not found and on-the-fly generation failed for sequence (len={len(seq)}): {e}"
)
#region agent log
try:
if not hasattr(self, "_agent_embload_counter"):
self._agent_embload_counter = 0
if self._agent_embload_counter < 8:
self._agent_embload_counter += 1
shape = tuple(int(x) for x in emb.shape)
D = int(shape[1]) if len(shape) == 2 else None
payload = {
"sessionId": "debug-session",
"runId": "pre-fix",
"hypothesisId": "A",
"location": "modules.py:AdvancedSiameseDataset:_get_or_create_embedding",
"message": "Loaded embedding tensor (source + shape) before any indicator is appended",
"data": {
"load_source": load_source,
"seq_len": int(len(seq)),
"shape": shape,
"D": D,
"looks_like_has_indicator": bool(D is not None and D >= 1153),
"file_pt_exists": bool(pt_file.is_file()),
"file_npy_exists": bool(npy_file.is_file()),
},
"timestamp": int(time.time() * 1000),
}
with open("/Users/supantha/Documents/code_v2/protein/.cursor/debug.log", "a") as f:
f.write(json.dumps(payload, default=str) + "\n")
print(f"[AGENTLOG EMBLOAD] src={load_source} shape={shape} D={D}")
except Exception:
pass
#endregion
# SAFETY: Ensure embedding has valid shape (at least 5 residues for interpolation)
if emb.shape[0] < 5:
# Pad to minimum length of 5 by repeating
repeats = (5 // emb.shape[0]) + 1
emb = emb.repeat(repeats, 1)[:5] # Ensure exactly 5 rows
# Track cache miss
self._cache_misses += 1
# Add to LRU cache (evict oldest if full)
if len(self._embedding_cache) >= self._cache_max_size:
# Remove oldest entry (first key in dict)
oldest_key = next(iter(self._embedding_cache))
del self._embedding_cache[oldest_key]
self._embedding_cache[seq] = emb
return emb.clone() # Return clone to avoid mutation issues
def get_cache_stats(self):
    """Summarise embedding-cache activity as a dictionary.

    Returns hit/miss counts, their sum, the hit rate in percent (0 when no
    lookups happened), the current and maximum cache size, and the optional
    ``_on_the_fly_count`` / ``_wt_missing_count`` counters (0 when unset).
    """
    hits = self._cache_hits
    misses = self._cache_misses
    lookups = hits + misses
    if lookups > 0:
        percent_hit = hits / lookups * 100
    else:
        percent_hit = 0

    summary = {
        "hits": hits,
        "misses": misses,
        "total": lookups,
        "hit_rate": percent_hit,
    }
    summary["cache_size"] = len(self._embedding_cache)
    summary["cache_max"] = self._cache_max_size
    # These counters only exist once the corresponding event has happened
    summary["on_the_fly_generated"] = getattr(self, '_on_the_fly_count', 0)
    summary["wt_embedding_failed"] = getattr(self, '_wt_missing_count', 0)
    return summary
def print_cache_stats(self):
    """Print the numbers from ``get_cache_stats()`` in a human-readable form.

    The on-the-fly and WT-failure lines are printed only when their counters
    are greater than zero.
    """
    stats = self.get_cache_stats()
    generated = stats['on_the_fly_generated']
    wt_failed = stats['wt_embedding_failed']

    summary_line = (
        f" [Cache] Hits: {stats['hits']:,} | Misses: {stats['misses']:,} | "
        f"Hit Rate: {stats['hit_rate']:.1f}% | Size: {stats['cache_size']:,}/{stats['cache_max']:,}"
    )
    print(summary_line)
    if generated > 0:
        print(f" [Cache] On-the-fly generated: {generated:,} embeddings")
    if wt_failed > 0:
        print(f" [Cache] ⚠️ WT embedding failures: {wt_failed:,} (excluded from ddG training)")
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
item = self.samples[idx]
# DEBUG: Track sequence difference statistics
if not hasattr(self, '_seq_diff_stats'):
self._seq_diff_stats = {'same': 0, 'different': 0, 'no_wt': 0}
if not hasattr(self, '_mutpos_stats'):
self._mutpos_stats = {'has_mutpos': 0, 'no_mutpos': 0}
# LAZY LOADING: Load embeddings on-demand
b1_mutpos = self._parse_mutpos(item["b1_mutpos_str"])
b2_mutpos = self._parse_mutpos(item["b2_mutpos_str"])
#region agent log
try:
if not hasattr(self, "_agent_mutpos_getitem_counter"):
self._agent_mutpos_getitem_counter = 0
if self._agent_mutpos_getitem_counter < 20:
self._agent_mutpos_getitem_counter += 1
payload = {
"sessionId": "debug-session",
"runId": "pre-fix",
"hypothesisId": "G",
"location": "modules.py:AdvancedSiameseDataset:__getitem__",
"message": "Parsed mut_positions passed to _get_embedding",
"data": {
"idx": int(idx),
"pdb": str(item.get("pdb")),
"is_wt": bool(item.get("is_wt")),
"b1_mutpos_str": str(item.get("b1_mutpos_str")),
"b2_mutpos_str": str(item.get("b2_mutpos_str")),
"b1_mutpos_n": int(len(b1_mutpos)),
"b2_mutpos_n": int(len(b2_mutpos)),
"b1_mutpos_first5": b1_mutpos[:5],
"b2_mutpos_first5": b2_mutpos[:5],
},
"timestamp": int(time.time() * 1000),
}
with open("/Users/supantha/Documents/code_v2/protein/.cursor/debug.log", "a") as f:
f.write(json.dumps(payload, default=str) + "\n")
print(f"[AGENTLOG MUTPOSGET] idx={idx} b1n={len(b1_mutpos)} b2n={len(b2_mutpos)} b1str={item.get('b1_mutpos_str')} b2str={item.get('b2_mutpos_str')}")
except Exception:
pass
#endregion
# Track mutation position statistics
if len(b1_mutpos) > 0 or len(b2_mutpos) > 0:
self._mutpos_stats['has_mutpos'] += 1
else:
self._mutpos_stats['no_mutpos'] += 1
# Log mutation position stats periodically
total = sum(self._mutpos_stats.values())
if total in [100, 1000, 10000]:
has_mp = self._mutpos_stats['has_mutpos']
no_mp = self._mutpos_stats['no_mutpos']
print(f" [MUTPOS] After {total} samples: {has_mp} have mutation positions ({100*has_mp/total:.1f}%), "
f"{no_mp} have NO mutation positions ({100*no_mp/total:.1f}%)")
c1_emb = self._get_embedding(item["seq1"], b1_mutpos)
c2_emb = self._get_embedding(item["seq2"], b2_mutpos)
if self.normalize:
c1_emb[:, :-1] = torch.nn.functional.normalize(c1_emb[:, :-1], p=2, dim=-1)
c2_emb[:, :-1] = torch.nn.functional.normalize(c2_emb[:, :-1], p=2, dim=-1)
# Load WT embeddings if available
if item["seq1_wt"] is not None:
# DEBUG: Track sequence differences
seq1_same = (item["seq1"] == item["seq1_wt"])
seq2_same = (item["seq2"] == item["seq2_wt"])
if seq1_same and seq2_same:
self._seq_diff_stats['same'] += 1
else:
self._seq_diff_stats['different'] += 1
# Periodic logging
total_samples = sum(self._seq_diff_stats.values())
if total_samples in [100, 1000, 10000, 50000]:
same = self._seq_diff_stats['same']
diff = self._seq_diff_stats['different']
no_wt = self._seq_diff_stats['no_wt']
print(f" [SEQ DIFF] After {total_samples} samples: {same} same seq ({100*same/total_samples:.1f}%), "
f"{diff} different ({100*diff/total_samples:.1f}%), {no_wt} no WT")
b1_wtpos = self._parse_mutpos(item["b1_wtpos_str"])
b2_wtpos = self._parse_mutpos(item["b2_wtpos_str"])
#region agent log
try:
if not hasattr(self, "_agent_embed_call_counter_wt"):
self._agent_embed_call_counter_wt = 0
if self._agent_embed_call_counter_wt < 10:
self._agent_embed_call_counter_wt += 1
print(
f"[AGENTLOG EMBCALL] idx={idx} role=wt "
f"b1_wtpos_n={len(b1_wtpos)} b2_wtpos_n={len(b2_wtpos)} "
f"seq1_wt_len={len(item.get('seq1_wt','') or '')} seq2_wt_len={len(item.get('seq2_wt','') or '')}"
)
except Exception:
pass
#endregion
try:
cw1 = self._get_embedding(item["seq1_wt"], b1_wtpos)
cw2 = self._get_embedding(item["seq2_wt"], b2_wtpos)
except RuntimeError as e:
# WT embedding unavailable - mark as no WT for this sample
# DO NOT use mutant embedding as proxy - this corrupts the mutation signal!
# Instead, set cw1, cw2 to None and let training handle missing WT
cw1, cw2 = None, None
if not hasattr(self, '_wt_missing_count'):
self._wt_missing_count = 0
self._wt_missing_count += 1
if self._wt_missing_count <= 3: # Only log first 3 to avoid spam
print(f" [WARN] WT embedding missing #{self._wt_missing_count}, sample will be WT-less: {e}")
if cw1 is not None and self.normalize:
cw1[:, :-1] = torch.nn.functional.normalize(cw1[:, :-1], p=2, dim=-1)
cw2[:, :-1] = torch.nn.functional.normalize(cw2[:, :-1], p=2, dim=-1)
else:
cw1, cw2 = None, None
self._seq_diff_stats['no_wt'] += 1
data_tuple = (c1_emb, c2_emb, item["delg"],
cw1, cw2, item["delg_wt"])
meta = {
"pdb": item["pdb"],
"is_wt": item["is_wt"],
"has_real_wt": item["has_real_wt"],
"has_dg": item["has_dg"],
"has_ddg": item["has_ddg"], # Whether sample has valid explicit ddG value
"has_inferred_ddg": item["has_inferred_ddg"], # Whether sample has inferred ddG (dG_mut - dG_wt)
"has_both_dg_ddg": item["has_both_dg_ddg"],
"ddg": item["ddg"],
"ddg_inferred": item["ddg_inferred"], # Inferred ddG value (needed for Fix #1)
"has_any_wt": item["has_any_wt"], # Include inferred WT status (CRITICAL!)
"b1_mutpos": b1_mutpos,
"b2_mutpos": b2_mutpos,
"data_source": item["data_source"]
}
return (data_tuple, meta)
#########################################
# AffinityDataModule
#########################################
from sklearn.model_selection import GroupKFold
class AffinityDataModule(pl.LightningDataModule):
"""
Data module for protein binding affinity prediction.
Supports multiple splitting strategies:
1. split_indices_dir: Load pre-computed cluster-based splits (RECOMMENDED)
2. use_cluster_split: Create new cluster-based splits on the fly
3. split column: Use existing 'split' column in CSV (legacy)
4. num_folds > 1: GroupKFold on PDB IDs
"""
def __init__(
    self,
    data_csv: str,
    protein_featurizer: ESM3Featurizer,
    embedding_dir: str = "precomputed_esm",
    batch_size: int = 32,
    num_workers: int = 4,
    shuffle: bool = True,
    num_folds: int = 1,
    fold_index: int = 0,
    # New cluster-based splitting options
    split_indices_dir: str = None,  # Path to pre-computed split indices
    benchmark_indices_dir: str = None,  # Path to balanced benchmark subset indices (optional override)
    use_cluster_split: bool = False,  # Create cluster-based splits on the fly
    train_ratio: float = 0.70,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    random_state: int = 42
):
    """Store the configuration; no data is read and no dataset is built here.

    Args:
        data_csv: Path to the CSV with all samples.
        protein_featurizer: Featurizer handed to every dataset that is created.
        embedding_dir: Directory handed to every dataset for stored embeddings.
        batch_size: Batch size for all dataloaders.
        num_workers: Worker count for all dataloaders.
        shuffle: Whether the training dataloaders shuffle.
        num_folds: Number of GroupKFold folds (used by the legacy split path when > 1).
        fold_index: Which fold to use when ``num_folds > 1``.
        split_indices_dir: Directory with pre-computed split indices, or None.
        benchmark_indices_dir: Directory with benchmark indices that override
            the ddG val/test rows in dual-split mode, or None.
        use_cluster_split: Create cluster-based splits during ``setup``.
        train_ratio: Train fraction passed to the cluster splitter.
        val_ratio: Validation fraction passed to the cluster splitter.
        test_ratio: Test fraction passed to the cluster splitter.
        random_state: Seed passed to the cluster splitter.
    """
    super().__init__()

    # --- data source / featurisation ---
    self.data_csv = data_csv
    self.featurizer = protein_featurizer
    self.embedding_dir = embedding_dir

    # --- dataloader behaviour ---
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.shuffle = shuffle

    # --- fold selection ---
    self.num_folds = num_folds
    self.fold_index = fold_index

    # --- split configuration ---
    self.split_indices_dir = split_indices_dir
    self.benchmark_indices_dir = benchmark_indices_dir
    self.use_cluster_split = use_cluster_split
    self.train_ratio = train_ratio
    self.val_ratio = val_ratio
    self.test_ratio = test_ratio
    self.random_state = random_state

    # --- datasets, filled in by setup() ---
    self.train_dataset = self.val_dataset = self.test_dataset = None

    # Per-head datasets: dG train/val/test and ddG train/val/test
    self.dg_train_dataset = self.ddg_train_dataset = None
    self.dg_val_dataset = self.dg_test_dataset = None
    self.ddg_val_dataset = self.ddg_test_dataset = None
    self.use_dual_split = False
def prepare_data(self):
    """Verify that the configured CSV path exists.

    Raises:
        FileNotFoundError: if ``self.data_csv`` does not exist.
    """
    csv_path = self.data_csv
    if os.path.exists(csv_path):
        return
    raise FileNotFoundError(f"Data CSV not found => {csv_path}")
def setup(self, stage=None):
    """Read the CSV, split it, and build all datasets.

    The split is chosen in this order:
    0. Dual split: ``split_indices_dir`` contains ``dg_val_indices.csv``.
       Separate dG and ddG train/val/test datasets are built; the combined
       train set and the ddG val/test rows also serve as the default
       train/val/test datasets. ``benchmark_indices_dir`` may replace the
       ddG val/test rows.
    1. Single pre-computed split loaded from ``split_indices_dir``.
    2. Cluster split created on the fly (``use_cluster_split``).
    3. Legacy: the CSV's ``split`` column, optionally with GroupKFold on ``#Pdb``.

    When no per-head test datasets exist after that, the test rows are divided
    into WT rows (dG test) and mutant rows (ddG test) using the
    ``Mutation(s)_cleaned`` column.

    Args:
        stage: Accepted for the Lightning hook signature; not used here.
    """
    data = pd.read_csv(self.data_csv, low_memory=False)

    def rows(indices):
        # Positional row selection with a fresh 0..n-1 index
        return data.iloc[indices].reset_index(drop=True)

    def build(frame, **extra):
        # All datasets share featurizer/embedding dir and run without augmentation
        return AdvancedSiameseDataset(frame, self.featurizer, self.embedding_dir, augment=False, **extra)

    # Presence of this file marks a dual-split directory
    dual_split_file = os.path.join(self.split_indices_dir, 'dg_val_indices.csv') if self.split_indices_dir else None

    if self.split_indices_dir and dual_split_file and os.path.exists(dual_split_file):
        # ---- Strategy 0: dual splits (separate dG / ddG heads) ----
        from data_splitting import load_dual_splits
        print(f"\n[DataModule] Loading DUAL splits from {self.split_indices_dir}")
        splits = load_dual_splits(self.split_indices_dir)
        self.use_dual_split = True

        # Default train/val/test: combined train, ddG val, ddG test
        train_df = rows(splits['combined_train'])
        val_df = rows(splits['ddg']['val'])
        test_df = rows(splits['ddg']['test'])

        # Per-head frames
        dg_train_df = rows(splits['dg']['train'])
        ddg_train_df = rows(splits['ddg']['train'])
        dg_val_df = rows(splits['dg']['val'])
        dg_test_df = rows(splits['dg']['test'])
        ddg_val_df = rows(splits['ddg']['val'])
        ddg_test_df = rows(splits['ddg']['test'])

        print(f"\n[DataModule] Creating dG TRAIN dataset ({len(dg_train_df)} WT rows)...")
        self.dg_train_dataset = build(dg_train_df)
        print(f"[DataModule] Creating ddG TRAIN dataset ({len(ddg_train_df)} MT rows)...")
        self.ddg_train_dataset = build(ddg_train_df)

        # Optional override of the ddG val/test rows by benchmark index files
        if self.benchmark_indices_dir and os.path.exists(self.benchmark_indices_dir):
            print(f"\n[DataModule] Loading BALANCED BENCHMARK indices from {self.benchmark_indices_dir}")

            def bench_rows(file_name):
                # Returns None when the index file is absent
                path = os.path.join(self.benchmark_indices_dir, file_name)
                if not os.path.exists(path):
                    return None
                picked = pd.read_csv(path, header=None).iloc[:, 0].values.tolist()
                return rows(picked)

            bench_val = bench_rows('ddg_val_benchmark_indices.csv')
            if bench_val is not None:
                ddg_val_df = bench_val
                print(f" ddG val: {len(ddg_val_df)} rows (balanced benchmark)")
            bench_test = bench_rows('ddg_test_benchmark_indices.csv')
            if bench_test is not None:
                ddg_test_df = bench_test
                print(f" ddG test: {len(ddg_test_df)} rows (balanced benchmark)")

        # Evaluation datasets get the full table as WT reference
        print(f"\n[DataModule] Creating dG val dataset ({len(dg_val_df)} rows)...")
        self.dg_val_dataset = build(dg_val_df, wt_reference_df=data)
        print(f"\n[DataModule] Creating dG test dataset ({len(dg_test_df)} rows)...")
        self.dg_test_dataset = build(dg_test_df, wt_reference_df=data)
        print(f"\n[DataModule] Creating ddG val dataset ({len(ddg_val_df)} rows)...")
        self.ddg_val_dataset = build(ddg_val_df, wt_reference_df=data)
        print(f"\n[DataModule] Creating ddG test dataset ({len(ddg_test_df)} rows)...")
        self.ddg_test_dataset = build(ddg_test_df, wt_reference_df=data)

        print(f"\n[DataModule] Dual split datasets created:")
        print(f" dG train: {len(self.dg_train_dataset)} samples (WT-only for Stage A)")
        print(f" ddG train: {len(self.ddg_train_dataset)} samples (MT-only)")
        print(f" dG val: {len(self.dg_val_dataset)} samples")
        print(f" dG test: {len(self.dg_test_dataset)} samples")
        print(f" ddG val: {len(self.ddg_val_dataset)} samples")
        print(f" ddG test: {len(self.ddg_test_dataset)} samples")

    elif self.split_indices_dir and os.path.exists(self.split_indices_dir):
        # ---- Strategy 1: single pre-computed split ----
        from data_splitting import load_split_indices, verify_no_leakage
        train_idx, val_idx, test_idx = load_split_indices(self.split_indices_dir)
        train_df = rows(train_idx)
        val_df = rows(val_idx)
        test_df = rows(test_idx)
        verify_no_leakage(data, train_idx, val_idx, test_idx)

    elif self.use_cluster_split:
        # ---- Strategy 2: cluster split created now, saved next to the CSV ----
        from data_splitting import create_cluster_splits, verify_no_leakage
        splits_dir = os.path.join(os.path.dirname(self.data_csv), 'splits')
        train_idx, val_idx, test_idx = create_cluster_splits(
            data,
            train_ratio=self.train_ratio,
            val_ratio=self.val_ratio,
            test_ratio=self.test_ratio,
            random_state=self.random_state,
            save_dir=splits_dir
        )
        train_df = rows(train_idx)
        val_df = rows(val_idx)
        test_df = rows(test_idx)

    else:
        # ---- Strategy 3: legacy 'split' column ----
        bench_df = data[data["split"] == "Benchmark test"].copy()
        trainval_df = data[data["split"] != "Benchmark test"].copy()
        if self.num_folds > 1:
            # Folds are grouped by the '#Pdb' column
            gkf = GroupKFold(n_splits=self.num_folds)
            groups = trainval_df["#Pdb"].values
            folds = list(gkf.split(trainval_df, groups=groups))
            train_idx, val_idx = folds[self.fold_index]
            train_df = trainval_df.iloc[train_idx].reset_index(drop=True)
            val_df = trainval_df.iloc[val_idx].reset_index(drop=True)
        else:
            train_df = trainval_df[trainval_df["split"] == "train"].reset_index(drop=True)
            val_df = trainval_df[trainval_df["split"] == "val"].reset_index(drop=True)
        test_df = bench_df

    # ---- Default train / val / test datasets (all strategies) ----
    print(f"\n[DataModule] Creating TRAIN dataset ({len(train_df)} rows)...")
    self.train_dataset = build(train_df)
    print(f"\n[DataModule] Creating VAL dataset ({len(val_df)} rows)...")
    self.val_dataset = build(val_df, wt_reference_df=train_df)  # WTs come from the training rows
    print(f"\n[DataModule] Creating TEST dataset ({len(test_df)} rows)...")
    self.test_dataset = build(test_df, wt_reference_df=train_df)  # WTs come from the training rows

    # ---- Per-head test datasets when the dual-split path did not create them ----
    if self.dg_test_dataset is None and self.ddg_test_dataset is None:
        def is_wt_row(row):
            # Empty, NaN or 'WT' in Mutation(s)_cleaned marks a wildtype row
            label = str(row.get('Mutation(s)_cleaned', '')).strip()
            return label == '' or label.lower() == 'nan' or label == 'WT'

        wt_mask = test_df.apply(is_wt_row, axis=1)
        dg_test_df = test_df[wt_mask].reset_index(drop=True)
        ddg_test_df = test_df[~wt_mask].reset_index(drop=True)

        if len(dg_test_df) > 0:
            print(f"\n[DataModule] Creating dG TEST dataset ({len(dg_test_df)} WT rows)...")
            self.dg_test_dataset = build(dg_test_df, wt_reference_df=data)
        else:
            print(f"[DataModule] WARNING: No WT rows in test set for dG test dataset!")

        if len(ddg_test_df) > 0:
            print(f"\n[DataModule] Creating ddG TEST dataset ({len(ddg_test_df)} MT rows)...")
            self.ddg_test_dataset = build(ddg_test_df, wt_reference_df=data)
        else:
            print(f"[DataModule] WARNING: No MT rows in test set for ddG test dataset!")

    # ---- Summary ----
    print(f"\nDataset sizes:")
    print(f" Train: {len(self.train_dataset)} samples")
    print(f" Val: {len(self.val_dataset)} samples")
    print(f" Test: {len(self.test_dataset)} samples")
    if self.dg_test_dataset:
        print(f" dG Test: {len(self.dg_test_dataset)} samples (WT)")
    if self.ddg_test_dataset:
        print(f" ddG Test: {len(self.ddg_test_dataset)} samples (MT)")
def train_dataloader(self):
    """DataLoader over ``self.train_dataset``; shuffling follows ``self.shuffle``."""
    loader_options = dict(
        batch_size=self.batch_size,
        shuffle=self.shuffle,
        num_workers=self.num_workers,
        collate_fn=advanced_collate_fn,
    )
    return DataLoader(self.train_dataset, **loader_options)
def val_dataloader(self):
    """DataLoader over ``self.val_dataset``; never shuffled."""
    loader_options = dict(
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        collate_fn=advanced_collate_fn,
    )
    return DataLoader(self.val_dataset, **loader_options)
def test_dataloader(self):
    """Return the DataLoader over ``self.test_dataset``.

    Shuffling is always disabled here; batch size and worker count come from
    the module's configuration, and samples are batched with
    ``advanced_collate_fn``.
    """
    dataset = self.test_dataset
    loader_options = {
        "batch_size": self.batch_size,
        "shuffle": False,
        "num_workers": self.num_workers,
        "collate_fn": advanced_collate_fn,
    }
    return DataLoader(dataset, **loader_options)
# Dual-split training dataloaders for separate dG-only (Stage A) and ddG (Stage B) training
def dg_train_dataloader(self):
    """Training dataloader for the dG head (WT data only, Stage A pretraining).

    Returns:
        A DataLoader over ``self.dg_train_dataset`` using the configured
        batch size, shuffle flag and worker count, or ``None`` when that
        dataset has not been created.
    """
    dataset = self.dg_train_dataset
    if dataset is None:
        # No dG training split available; callers get None instead of a loader.
        return None

    loader_options = {
        "batch_size": self.batch_size,
        "shuffle": self.shuffle,
        "num_workers": self.num_workers,
        "collate_fn": advanced_collate_fn,
    }
    return DataLoader(dataset, **loader_options)
def ddg_train_dataloader(self):
    """Training dataloader for the ddG head (mutation data, Stage B training).

    Returns:
        A DataLoader over ``self.ddg_train_dataset`` using the configured
        batch size, shuffle flag and worker count, or ``None`` when that
        dataset has not been created.
    """
    dataset = self.ddg_train_dataset
    if dataset is None:
        # No ddG training split available; callers get None instead of a loader.
        return None

    loader_options = {
        "batch_size": self.batch_size,
        "shuffle": self.shuffle,
        "num_workers": self.num_workers,
        "collate_fn": advanced_collate_fn,
    }
    return DataLoader(dataset, **loader_options)
# Dual-split dataloaders for separate dG and ddG validation and testing
def dg_val_dataloader(self):
    """Validation dataloader for the dG head (WT data only).

    Returns:
        A non-shuffling DataLoader over ``self.dg_val_dataset``, or ``None``
        when that dataset has not been created.
    """
    dataset = self.dg_val_dataset
    if dataset is None:
        # No dG validation split available; callers get None instead of a loader.
        return None

    loader_options = {
        "batch_size": self.batch_size,
        "shuffle": False,
        "num_workers": self.num_workers,
        "collate_fn": advanced_collate_fn,
    }
    return DataLoader(dataset, **loader_options)
def dg_test_dataloader(self):
    """Test dataloader for the dG head (WT data only).

    Returns:
        A non-shuffling DataLoader over ``self.dg_test_dataset``, or ``None``
        when that dataset has not been created.
    """
    dataset = self.dg_test_dataset
    if dataset is None:
        # No dG test split available; callers get None instead of a loader.
        return None

    loader_options = {
        "batch_size": self.batch_size,
        "shuffle": False,
        "num_workers": self.num_workers,
        "collate_fn": advanced_collate_fn,
    }
    return DataLoader(dataset, **loader_options)
def ddg_val_dataloader(self):
    """Validation dataloader for the ddG head (mutation data including DMS).

    Returns:
        A non-shuffling DataLoader over ``self.ddg_val_dataset``, or ``None``
        when that dataset has not been created.
    """
    dataset = self.ddg_val_dataset
    if dataset is None:
        # No ddG validation split available; callers get None instead of a loader.
        return None

    loader_options = {
        "batch_size": self.batch_size,
        "shuffle": False,
        "num_workers": self.num_workers,
        "collate_fn": advanced_collate_fn,
    }
    return DataLoader(dataset, **loader_options)
def ddg_test_dataloader(self):
    """Test dataloader for the ddG head (mutation data including DMS).

    Returns:
        A non-shuffling DataLoader over ``self.ddg_test_dataset``, or ``None``
        when that dataset has not been created.
    """
    dataset = self.ddg_test_dataset
    if dataset is None:
        # No ddG test split available; callers get None instead of a loader.
        return None

    loader_options = {
        "batch_size": self.batch_size,
        "shuffle": False,
        "num_workers": self.num_workers,
        "collate_fn": advanced_collate_fn,
    }
    return DataLoader(dataset, **loader_options)