|
|
from typing import Tuple, List, Dict, Optional |
|
|
from pathlib import Path |
|
|
import os |
|
|
import math |
|
|
import pandas as pd |
|
|
import pytorch_lightning as pl |
|
|
import torch |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
import numpy as np |
|
|
import hashlib |
|
|
import json |
|
|
import time |
|
|
from esm3bedding import ESM3Featurizer |
|
|
from utils import get_logger |
|
|
|
|
|
# Module-level logger from the project's utils.get_logger() helper.
# NOTE(review): nothing in the code visible here uses `logg`; progress and debug
# output in this section is emitted with print() instead.
logg = get_logger()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Data source name -> coarse source category. Built from one grouping per
# category so each source name appears exactly once; insertion order follows
# the grouping below.
SOURCE_TYPE_MAP = {
    source_name: category
    for category, source_names in (
        ('protein_complex', (
            'PDBbind', 'PPIKB', 'asd_biomap', 'asd_aae', 'asd_aatp', 'asd_osh',
        )),
        ('mutation', (
            'SKEMPI', 'BindingGym', 'asd_flab_koenig2017',
            'asd_flab_warszawski2019', 'asd_flab_rosace2023', 'PEPBI',
        )),
        ('antibody_cdr', (
            'asd_abbd', 'abdesign', 'asd_flab_hie2022',
            'asd_flab_shanehsazzadeh2023',
        )),
    )
    for source_name in source_names
}

# Integer id per category (used as a tensor of source-type ids by the collate fn).
SOURCE_TYPE_TO_ID = {
    category: type_id
    for type_id, category in enumerate(('protein_complex', 'mutation', 'antibody_cdr'))
}

# Category used for any data source not listed in SOURCE_TYPE_MAP.
DEFAULT_SOURCE_TYPE = 'mutation'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _pad_with_mask(seqs: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Right-pad per-sample tensors along dim 0 and build a matching validity mask.

    Args:
        seqs: Non-empty list of tensors whose first dimension is the per-sample
            length (the collate path below passes 2-D ``(length, dim)`` tensors).

    Returns:
        ``(padded, mask)``: ``padded`` is ``pad_sequence(seqs, batch_first=True)``;
        ``mask`` is a ``(batch, max_length)`` bool tensor, True on real positions.
    """
    padded = pad_sequence(seqs, batch_first=True)
    lengths = torch.tensor([s.shape[0] for s in seqs], dtype=torch.long)
    mask = torch.arange(padded.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)
    return padded, mask


def _meta_float(value) -> float:
    """Convert a metadata label to float, treating ``None`` as missing (NaN)."""
    return float('nan') if value is None else float(value)


def advanced_collate_fn(batch):
    """Collate ``(data, meta)`` samples into padded mutant/wild-type tensors and label masks.

    Each batch element is ``(data, meta)`` where ``data`` unpacks to
    ``(c1, c2, y, cw1, cw2, yw)``: mutant block-1/block-2 embeddings and label, then
    the wild-type counterparts (any of the WT parts may be ``None``). Embeddings are
    indexed as ``(length, dim)`` tensors (the zero WT placeholder assumes exactly 2-D).
    ``meta`` is a dict of per-sample flags/labels; missing keys fall back to defaults.

    Returns:
        dict with
          - ``"mutant"`` / ``"wildtype"``: ``(block1_padded, block1_mask,
            block2_padded, block2_mask, y)``;
          - bool flags ``has_wt``, ``has_valid_wt``, ``is_wt``, ``has_dg``,
            ``has_ddg``, ``has_inferred_ddg``, ``has_both_dg_ddg``;
          - float ``ddg_labels`` / ``ddg_inferred_labels`` (0.0 where missing);
          - ``data_source`` (list), ``source_type_ids`` (long), ``metadata`` (list).

        ``has_ddg`` / ``has_inferred_ddg`` are only True when a finite value backs
        the corresponding label, so a 0.0 placeholder is never flagged as a target.
    """
    mut_c1_list, mut_c2_list, mut_y_list = [], [], []
    wt_c1_list, wt_c2_list, wt_y_list = [], [], []
    has_valid_wt_list = []
    meta_list = []

    for data, meta in batch:
        c1, c2, y, cw1, cw2, yw = data
        mut_c1_list.append(c1)
        mut_c2_list.append(c2)
        mut_y_list.append(torch.tensor([y], dtype=torch.float32))

        wt_available = cw1 is not None and cw2 is not None and yw is not None
        if wt_available:
            wt_c1_list.append(cw1)
            wt_c2_list.append(cw2)
            wt_y_list.append(torch.tensor([yw], dtype=torch.float32))
        else:
            # One-row zero placeholders keep the WT batch rectangular; new_zeros
            # matches the mutant tensor's dtype/device so padding never mixes types.
            wt_c1_list.append(c1.new_zeros((1, c1.shape[1])))
            wt_c2_list.append(c2.new_zeros((1, c2.shape[1])))
            wt_y_list.append(torch.tensor([0.0], dtype=torch.float32))
        has_valid_wt_list.append(wt_available)
        meta_list.append(meta)

    c1_padded, c1_mask = _pad_with_mask(mut_c1_list)
    c2_padded, c2_mask = _pad_with_mask(mut_c2_list)
    y_mut = torch.cat(mut_y_list, dim=0)

    w1_padded, w1_mask = _pad_with_mask(wt_c1_list)
    w2_padded, w2_mask = _pad_with_mask(wt_c2_list)
    y_wt = torch.cat(wt_y_list, dim=0)

    B = c1_padded.shape[0]

    has_wt_list, is_wt_list, has_dg_list = [], [], []
    has_ddg_list, has_inferred_ddg_list, has_both_list = [], [], []
    ddg_list, ddg_inferred_list = [], []

    # Diagnostics for samples claiming an explicit ddG.
    n_has_ddg_true = 0
    n_ddg_zero = 0
    n_ddg_nan = 0

    for meta in meta_list:
        has_wt_list.append(meta.get("has_any_wt", meta.get("has_real_wt", False)))
        is_wt_list.append(meta.get("is_wt", False))
        has_dg_list.append(meta.get("has_dg", False))
        has_both_list.append(meta.get("has_both_dg_ddg", False))

        has_explicit_ddg = meta.get("has_ddg", False)
        has_inferred_flag = meta.get("has_inferred_ddg", False)

        ddg_val = _meta_float(meta.get("ddg", float('nan')))
        ddg_inf_val = _meta_float(meta.get("ddg_inferred", float('nan')))
        explicit_ok = not math.isnan(ddg_val)
        inferred_ok = not math.isnan(ddg_inf_val)

        if has_explicit_ddg:
            n_has_ddg_true += 1
            if not explicit_ok:
                n_ddg_nan += 1
            elif abs(ddg_val) < 1e-8:
                n_ddg_zero += 1

        # Prefer the measured ddG, fall back to the inferred one, else a 0.0 placeholder.
        if explicit_ok:
            ddg_list.append(ddg_val)
        elif inferred_ok:
            ddg_list.append(ddg_inf_val)
        else:
            ddg_list.append(0.0)
        ddg_inferred_list.append(ddg_inf_val if inferred_ok else 0.0)

        # A flag may only be set when a finite value backs the label; otherwise the
        # 0.0 placeholder above would be trained on as a real target.
        has_ddg_list.append(bool((has_explicit_ddg or has_inferred_flag) and (explicit_ok or inferred_ok)))
        has_inferred_ddg_list.append(bool(has_inferred_flag and inferred_ok))

    if n_has_ddg_true > 0 and (n_ddg_nan > 0 or n_ddg_zero > B // 2):
        print(f"[COLLATE DEBUG] Batch has_ddg stats: {n_has_ddg_true}/{B} have has_ddg=True, "
              f"{n_ddg_nan} have NaN ddg (BUG! has_ddg cleared unless an inferred ddG exists), "
              f"{n_ddg_zero} have ddg≈0")

    has_wt = torch.tensor(has_wt_list, dtype=torch.bool)
    has_valid_wt = torch.tensor(has_valid_wt_list, dtype=torch.bool)
    is_wt = torch.tensor(is_wt_list, dtype=torch.bool)
    has_dg = torch.tensor(has_dg_list, dtype=torch.bool)
    has_ddg = torch.tensor(has_ddg_list, dtype=torch.bool)
    has_inferred_ddg = torch.tensor(has_inferred_ddg_list, dtype=torch.bool)
    has_both_dg_ddg = torch.tensor(has_both_list, dtype=torch.bool)
    ddg_labels = torch.tensor(ddg_list, dtype=torch.float32)
    ddg_inferred_labels = torch.tensor(ddg_inferred_list, dtype=torch.float32)

    n_valid_wt = has_valid_wt.sum().item()
    n_has_wt = has_wt.sum().item()
    if n_has_wt > 0 and n_valid_wt < n_has_wt:
        print(f"[COLLATE DEBUG] WT validity: {n_valid_wt}/{n_has_wt} have valid WT embeddings "
              f"({n_has_wt - n_valid_wt} samples have zero-fallback and will be EXCLUDED from ddG training)")

    data_source_list = [meta.get("data_source", "unknown") for meta in meta_list]
    # Unknown sources map to DEFAULT_SOURCE_TYPE before the id lookup.
    source_type_ids = torch.tensor(
        [SOURCE_TYPE_TO_ID[SOURCE_TYPE_MAP.get(src, DEFAULT_SOURCE_TYPE)] for src in data_source_list],
        dtype=torch.long,
    )

    return {
        "mutant": (c1_padded, c1_mask, c2_padded, c2_mask, y_mut),
        "wildtype": (w1_padded, w1_mask, w2_padded, w2_mask, y_wt),
        "has_wt": has_wt,
        "has_valid_wt": has_valid_wt,
        "is_wt": is_wt,
        "has_dg": has_dg,
        "has_ddg": has_ddg,
        "has_inferred_ddg": has_inferred_ddg,
        "has_both_dg_ddg": has_both_dg_ddg,
        "ddg_labels": ddg_labels,
        "ddg_inferred_labels": ddg_inferred_labels,
        "data_source": data_source_list,
        "source_type_ids": source_type_ids,
        "metadata": meta_list,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AdvancedSiameseDataset(Dataset): |
|
|
""" |
|
|
Dataset that handles mutation positions with a simple indicator channel. |
|
|
|
|
|
Reads columns: |
|
|
#Pdb, block1_sequence, block1_mut_positions, block1_mutations, |
|
|
block2_sequence, block2_mut_positions, block2_mutations, del_g, ... |
|
|
""" |
|
|
def __init__(self, df: pd.DataFrame, featurizer: ESM3Featurizer, embedding_dir: str, |
|
|
normalize_embeddings=True, augment=False, max_len=1022, |
|
|
wt_reference_df: pd.DataFrame = None): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
|
|
|
self.wt_reference_df = wt_reference_df if wt_reference_df is not None else None |
|
|
initial_len = len(df) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
long_mask = (df["block1_sequence"].astype(str).str.len() > max_len) | \ |
|
|
(df["block2_sequence"].astype(str).str.len() > max_len) |
|
|
n_long = long_mask.sum() |
|
|
|
|
|
if n_long > 0: |
|
|
print(f" [Dataset] Truncating {n_long} samples with length > {max_len} to maintain index alignment (CRITICAL FIX).") |
|
|
|
|
|
|
|
|
df = df.copy() |
|
|
df.loc[long_mask, "block1_sequence"] = df.loc[long_mask, "block1_sequence"].astype(str).str.slice(0, max_len) |
|
|
df.loc[long_mask, "block2_sequence"] = df.loc[long_mask, "block2_sequence"].astype(str).str.slice(0, max_len) |
|
|
|
|
|
|
|
|
self.df = df.reset_index(drop=True) |
|
|
|
|
|
|
|
|
try: |
|
|
cols = set(self.df.columns.tolist()) |
|
|
need = {"block1_mut_positions", "block2_mut_positions", "Mutation(s)_PDB"} |
|
|
missing = sorted(list(need - cols)) |
|
|
payload = { |
|
|
"sessionId": "debug-session", |
|
|
"runId": "pre-fix", |
|
|
"hypothesisId": "G", |
|
|
"location": "modules.py:AdvancedSiameseDataset:__init__", |
|
|
"message": "Dataset columns presence check for mutation positions", |
|
|
"data": { |
|
|
"n_rows": int(len(self.df)), |
|
|
"has_block1_mut_positions": "block1_mut_positions" in cols, |
|
|
"has_block2_mut_positions": "block2_mut_positions" in cols, |
|
|
"has_mutation_pdb": "Mutation(s)_PDB" in cols, |
|
|
"missing": missing, |
|
|
}, |
|
|
"timestamp": int(time.time() * 1000), |
|
|
} |
|
|
with open("/Users/supantha/Documents/code_v2/protein/.cursor/debug.log", "a") as f: |
|
|
f.write(json.dumps(payload, default=str) + "\n") |
|
|
print(f"[AGENTLOG MUTPOSCOLS] missing={missing}") |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
if not hasattr(self, "_agent_embed_call_counter"): |
|
|
self._agent_embed_call_counter = 0 |
|
|
if self._agent_embed_call_counter < 10: |
|
|
self._agent_embed_call_counter += 1 |
|
|
print( |
|
|
f"[AGENTLOG EMBCALL] idx={idx} role=mut " |
|
|
f"b1_mutpos_n={len(b1_mutpos)} b2_mutpos_n={len(b2_mutpos)} " |
|
|
f"seq1_len={len(item.get('seq1',''))} seq2_len={len(item.get('seq2',''))}" |
|
|
) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
self.df = self._recover_antibody_wts(self.df) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if augment: |
|
|
|
|
|
mut_mask = self.df["Mutation(s)_PDB"].notna() & (self.df["Mutation(s)_PDB"].str.strip() != "") |
|
|
mutant_df = self.df[mut_mask].copy() |
|
|
|
|
|
if len(mutant_df) > 0: |
|
|
|
|
|
rev_df = mutant_df.copy() |
|
|
|
|
|
if "ddg" in rev_df.columns: |
|
|
rev_df["ddg"] = -rev_df["ddg"] |
|
|
rev_df["is_reverse"] = True |
|
|
|
|
|
|
|
|
self.df["is_reverse"] = False |
|
|
self.df = pd.concat([self.df, rev_df], ignore_index=True) |
|
|
print(f" [Dataset] Augmented: added {len(rev_df)} reversed mutant samples (antisymmetry training)") |
|
|
else: |
|
|
self.df["is_reverse"] = False |
|
|
else: |
|
|
self.df["is_reverse"] = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.df["pair_id"] = ( |
|
|
self.df["#Pdb"].astype(str) + "_" + |
|
|
self.df["Mutation(s)_cleaned"].fillna("") |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
self.featurizer = featurizer |
|
|
self.embedding_dir = Path(embedding_dir) |
|
|
self.embedding_dir.mkdir(exist_ok=True, parents=True) |
|
|
self.normalize = normalize_embeddings |
|
|
|
|
|
self.samples = [] |
|
|
self._embedding_cache = {} |
|
|
self._cache_max_size = 20000 |
|
|
self._cache_hits = 0 |
|
|
self._cache_misses = 0 |
|
|
|
|
|
|
|
|
print(f" [Dataset] Building WT index for {len(self.df)} rows...") |
|
|
self.pdb_to_wt = {} |
|
|
for i, row in self.df.iterrows(): |
|
|
pdb = row["#Pdb"] |
|
|
mut_str = row.get("Mutation(s)_PDB","") |
|
|
is_wt = (pd.isna(mut_str) or mut_str.strip()=="") |
|
|
if is_wt and pdb not in self.pdb_to_wt: |
|
|
self.pdb_to_wt[pdb] = i |
|
|
|
|
|
|
|
|
self.external_pdb_to_wt = {} |
|
|
if self.wt_reference_df is not None: |
|
|
print(f" [Dataset] Building external WT index from {len(self.wt_reference_df)} reference rows...") |
|
|
for i, row in self.wt_reference_df.iterrows(): |
|
|
|
|
|
mut_str = row.get("Mutation(s)_PDB","") |
|
|
is_wt = (pd.isna(mut_str) or mut_str.strip()=="") |
|
|
if 'is_wt' in row: |
|
|
is_wt = is_wt or row['is_wt'] |
|
|
|
|
|
pdb = row["#Pdb"] |
|
|
if is_wt and pdb not in self.external_pdb_to_wt: |
|
|
self.external_pdb_to_wt[pdb] = i |
|
|
print(f" [Dataset] Indexed {len(self.external_pdb_to_wt)} external WTs.") |
|
|
|
|
|
|
|
|
self.external_pdb_to_wt = {} |
|
|
if self.wt_reference_df is not None: |
|
|
print(f" [Dataset] Building external WT index from {len(self.wt_reference_df)} reference rows...") |
|
|
for i, row in self.wt_reference_df.iterrows(): |
|
|
|
|
|
mut_str = row.get("Mutation(s)_PDB","") |
|
|
is_wt = (pd.isna(mut_str) or mut_str.strip()=="") |
|
|
|
|
|
if 'is_wt' in row: |
|
|
is_wt = is_wt or row['is_wt'] |
|
|
|
|
|
pdb = row["#Pdb"] |
|
|
if is_wt and pdb not in self.external_pdb_to_wt: |
|
|
self.external_pdb_to_wt[pdb] = i |
|
|
print(f" [Dataset] Indexed {len(self.external_pdb_to_wt)} external WTs.") |
|
|
|
|
|
|
|
|
|
|
|
print(f" [Dataset] Building sample metadata for {len(self.df)} rows (lazy loading)...") |
|
|
from tqdm import tqdm |
|
|
for i, row in tqdm(self.df.iterrows(), total=len(self.df), desc=" Indexing"): |
|
|
|
|
|
if hasattr(self, '_last_computed_mutpos'): |
|
|
del self._last_computed_mutpos |
|
|
|
|
|
pdb = row["#Pdb"] |
|
|
seq1 = row["block1_sequence"] |
|
|
seq2 = row["block2_sequence"] |
|
|
|
|
|
|
|
|
data_source = row.get("data_source", "unknown") |
|
|
|
|
|
|
|
|
raw_delg = row["del_g"] |
|
|
delg = float(raw_delg) if pd.notna(raw_delg) and raw_delg != '' else float('nan') |
|
|
|
|
|
|
|
|
raw_ddg = row.get("ddg", None) |
|
|
ddg = float(raw_ddg) if pd.notna(raw_ddg) and raw_ddg != '' else float('nan') |
|
|
|
|
|
|
|
|
b1_mutpos_str = row.get("block1_mut_positions","[]") |
|
|
b2_mutpos_str = row.get("block2_mut_positions","[]") |
|
|
|
|
|
|
|
|
if i < 5: |
|
|
print(f"DEBUG ROW {i}: b1='{b1_mutpos_str}' ({type(b1_mutpos_str)}), b2='{b2_mutpos_str}' ({type(b2_mutpos_str)})") |
|
|
|
|
|
try: |
|
|
payload = { |
|
|
"sessionId": "debug-session", |
|
|
"runId": "pre-fix", |
|
|
"hypothesisId": "G", |
|
|
"location": "modules.py:AdvancedSiameseDataset:__init__:row0_4", |
|
|
"message": "Raw mutpos strings from df row (first few)", |
|
|
"data": { |
|
|
"i": int(i), |
|
|
"b1_mutpos_str": str(b1_mutpos_str), |
|
|
"b2_mutpos_str": str(b2_mutpos_str), |
|
|
"mutation_pdb": str(row.get("Mutation(s)_PDB", "")), |
|
|
}, |
|
|
"timestamp": int(time.time() * 1000), |
|
|
} |
|
|
with open("/Users/supantha/Documents/code_v2/protein/.cursor/debug.log", "a") as f: |
|
|
f.write(json.dumps(payload, default=str) + "\n") |
|
|
print(f"[AGENTLOG MUTPOSRAW] i={i} b1={b1_mutpos_str} b2={b2_mutpos_str} mut={row.get('Mutation(s)_PDB','')}") |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
b1_chains = str(row.get("block1_chains", "")).upper() |
|
|
b2_chains = str(row.get("block2_chains", "")).upper() |
|
|
|
|
|
mut_str = row.get("Mutation(s)_PDB","") |
|
|
is_wt = (pd.isna(mut_str) or mut_str.strip()=="") |
|
|
wt_idx = self.pdb_to_wt.get(pdb, None) |
|
|
|
|
|
|
|
|
row_wt = None |
|
|
wt_source = None |
|
|
|
|
|
if not hasattr(self, '_wt_source_stats'): |
|
|
self._wt_source_stats = {'internal': 0, 'external': 0} |
|
|
|
|
|
if wt_idx is not None: |
|
|
row_wt = self.df.iloc[wt_idx] |
|
|
wt_source = 'internal' |
|
|
self._wt_source_stats['internal'] += 1 |
|
|
elif pdb in self.external_pdb_to_wt: |
|
|
ext_idx = self.external_pdb_to_wt[pdb] |
|
|
row_wt = self.wt_reference_df.iloc[ext_idx] |
|
|
wt_source = 'external' |
|
|
self._wt_source_stats['external'] += 1 |
|
|
|
|
|
if row_wt is not None: |
|
|
seq1_wt = row_wt["block1_sequence"] |
|
|
seq2_wt = row_wt["block2_sequence"] |
|
|
raw_delg_wt = row_wt["del_g"] |
|
|
delg_wt = float(raw_delg_wt) if pd.notna(raw_delg_wt) and raw_delg_wt != '' else float('nan') |
|
|
b1_wtpos_str = row_wt.get("block1_mut_positions","[]") |
|
|
b2_wtpos_str = row_wt.get("block2_mut_positions","[]") |
|
|
|
|
|
|
|
|
|
|
|
if not is_wt and (b1_mutpos_str in ["[]", "", "nan", "None"] and b2_mutpos_str in ["[]", "", "nan", "None"]): |
|
|
|
|
|
|
|
|
|
|
|
self._infer_wt_sequences( |
|
|
seq1, seq2, mut_str, "[]", "[]", |
|
|
b1_chains, b2_chains |
|
|
) |
|
|
|
|
|
|
|
|
if hasattr(self, '_last_computed_mutpos'): |
|
|
comp_b1, comp_b2 = self._last_computed_mutpos |
|
|
if b1_mutpos_str in ["[]", "", "nan", "None"] and comp_b1: |
|
|
b1_mutpos_str = str(comp_b1) |
|
|
if b2_mutpos_str in ["[]", "", "nan", "None"] and comp_b2: |
|
|
b2_mutpos_str = str(comp_b2) |
|
|
|
|
|
else: |
|
|
|
|
|
|
|
|
seq1_wt, seq2_wt = self._infer_wt_sequences( |
|
|
seq1, seq2, mut_str, b1_mutpos_str, b2_mutpos_str, |
|
|
b1_chains, b2_chains |
|
|
) |
|
|
delg_wt = float('nan') |
|
|
b1_wtpos_str, b2_wtpos_str = "[]", "[]" |
|
|
|
|
|
|
|
|
if hasattr(self, '_last_computed_mutpos'): |
|
|
comp_b1, comp_b2 = self._last_computed_mutpos |
|
|
if b1_mutpos_str in ["[]", "", "nan", "None"] and comp_b1: |
|
|
b1_mutpos_str = str(comp_b1) |
|
|
if b2_mutpos_str in ["[]", "", "nan", "None"] and comp_b2: |
|
|
b2_mutpos_str = str(comp_b2) |
|
|
|
|
|
|
|
|
has_dg = not (delg != delg) |
|
|
has_ddg = not (ddg != ddg) |
|
|
has_both = has_dg and has_ddg |
|
|
|
|
|
|
|
|
|
|
|
has_dg_wt = not (delg_wt != delg_wt) |
|
|
has_inferred_ddg = has_dg and has_dg_wt and (not has_ddg) |
|
|
if has_inferred_ddg: |
|
|
ddg_inferred = delg - delg_wt |
|
|
else: |
|
|
ddg_inferred = float('nan') |
|
|
|
|
|
|
|
|
has_real_wt = (wt_idx is not None) |
|
|
has_inferred_wt = (wt_idx is None and seq1_wt is not None and seq2_wt is not None) |
|
|
has_any_wt = has_real_wt or has_inferred_wt |
|
|
|
|
|
|
|
|
is_reverse = row.get("is_reverse", False) |
|
|
|
|
|
|
|
|
if is_reverse: |
|
|
|
|
|
if seq1_wt is not None and seq2_wt is not None: |
|
|
seq1, seq1_wt = seq1_wt, seq1 |
|
|
seq2, seq2_wt = seq2_wt, seq2 |
|
|
|
|
|
delg, delg_wt = delg_wt, delg |
|
|
|
|
|
if not math.isnan(ddg_inferred): |
|
|
ddg_inferred = -ddg_inferred |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.samples.append({ |
|
|
"pdb": pdb, |
|
|
"is_wt": is_wt, |
|
|
"is_reverse": is_reverse, |
|
|
"seq1": seq1, "seq2": seq2, "delg": delg, |
|
|
"seq1_wt": seq1_wt, "seq2_wt": seq2_wt, "delg_wt": delg_wt, |
|
|
"ddg": ddg, |
|
|
"ddg_inferred": ddg_inferred, |
|
|
"has_dg": has_dg, |
|
|
"has_ddg": has_ddg, |
|
|
"has_inferred_ddg": has_inferred_ddg, |
|
|
"has_both_dg_ddg": has_both, |
|
|
"has_real_wt": has_real_wt, |
|
|
"has_inferred_wt": has_inferred_wt, |
|
|
"has_any_wt": has_any_wt, |
|
|
"b1_mutpos_str": b1_mutpos_str, |
|
|
"b2_mutpos_str": b2_mutpos_str, |
|
|
"b1_wtpos_str": b1_wtpos_str, |
|
|
"b2_wtpos_str": b2_wtpos_str, |
|
|
"data_source": data_source |
|
|
}) |
|
|
|
|
|
|
|
|
n_real_wt = sum(1 for s in self.samples if s["has_real_wt"]) |
|
|
n_inferred_wt = sum(1 for s in self.samples if s["has_inferred_wt"]) |
|
|
n_no_wt = len(self.samples) - n_real_wt - n_inferred_wt |
|
|
|
|
|
|
|
|
if hasattr(self, '_wt_source_stats'): |
|
|
n_internal = self._wt_source_stats.get('internal', 0) |
|
|
n_external = self._wt_source_stats.get('external', 0) |
|
|
source_msg = f" (Internal: {n_internal}, External: {n_external})" |
|
|
else: |
|
|
source_msg = "" |
|
|
|
|
|
print(f" [Dataset] Ready! {len(self.samples)} samples indexed (embeddings loaded on-demand)") |
|
|
print(f" [Dataset] WT stats: {n_real_wt} real WT{source_msg}, {n_inferred_wt} inferred WT, {n_no_wt} no WT") |
|
|
|
|
|
|
|
|
if hasattr(self, '_wt_inference_failures') and hasattr(self, '_wt_inference_fail_count'): |
|
|
print(f" [Dataset] ⚠️ WT inference failed for {self._wt_inference_fail_count} samples:") |
|
|
fail_dict = self._wt_inference_failures |
|
|
|
|
|
|
|
|
n_no_pdb = len(fail_dict.get('no_pdb', [])) |
|
|
n_del_ins = len(fail_dict.get('del_ins_only', [])) |
|
|
n_parse = len(fail_dict.get('parse_fail', [])) |
|
|
|
|
|
if n_no_pdb > 0: |
|
|
print(f" - ANTIBODY samples (no PDB structure): {self._wt_inference_fail_count} samples") |
|
|
print(f" (These are antibody design samples without original PDB - only dG usable)") |
|
|
elif n_del_ins > 0 or n_parse > 0: |
|
|
print(f" - DEL/INS/stop-codon (can't reverse): counted") |
|
|
print(f" - Parsing failed (unknown format): counted") |
|
|
|
|
|
|
|
|
if fail_dict.get('parse_fail') and n_no_pdb == 0: |
|
|
print(f" Sample parse failures:") |
|
|
for mut in fail_dict['parse_fail'][:5]: |
|
|
print(f" '{mut}'") |
|
|
|
|
|
def _parse_mutpos(self, pos_str) -> List[int]: |
|
|
""" |
|
|
pos_str might be '[]' or '[170, 172]' etc. |
|
|
We'll do a simple parse. |
|
|
""" |
|
|
|
|
|
if pos_str is None or (isinstance(pos_str, float) and str(pos_str) == 'nan'): |
|
|
return [] |
|
|
if not isinstance(pos_str, str): |
|
|
pos_str = str(pos_str) |
|
|
pos_str = pos_str.strip() |
|
|
if pos_str.startswith("[") and pos_str.endswith("]"): |
|
|
inside = pos_str[1:-1].strip() |
|
|
if not inside: |
|
|
return [] |
|
|
|
|
|
arr = inside.split(",") |
|
|
out = [] |
|
|
for x in arr: |
|
|
x_ = x.strip() |
|
|
if x_: |
|
|
out.append(int(x_)) |
|
|
return out |
|
|
return [] |
|
|
|
|
|
def _recover_antibody_wts(self, df: pd.DataFrame) -> pd.DataFrame: |
|
|
""" |
|
|
Recover WT information for antibody samples (ANTIBODY_MUTATION) |
|
|
by finding the closest-to-consensus sequence in each antigen group. |
|
|
|
|
|
Strategy: |
|
|
1. Identify samples with 'ANTIBODY_MUTATION' |
|
|
2. Group by antigen (block2_sequence) |
|
|
3. Assign unique Pseudo-PDB ID to each group (e.g. ANTIBODY_GRP_xxx) |
|
|
4. For same-length groups: find sequence closest to consensus as WT |
|
|
5. For variable-length groups: fallback to best binder (lowest del_g) |
|
|
6. Mark selected sequence as WT (clear mutation string) |
|
|
""" |
|
|
from collections import Counter |
|
|
|
|
|
|
|
|
mask = df['Mutation(s)_PDB'].astype(str).str.contains('ANTIBODY_MUTATION', na=False) |
|
|
|
|
|
if not mask.any(): |
|
|
return df |
|
|
|
|
|
print(f" [Dataset] Attempting to recover WT for {mask.sum()} antibody samples...") |
|
|
|
|
|
recovered_count = 0 |
|
|
n_groups = 0 |
|
|
n_consensus = 0 |
|
|
n_median = 0 |
|
|
n_fallback = 0 |
|
|
|
|
|
|
|
|
df = df.copy() |
|
|
|
|
|
|
|
|
df['temp_antigen_hash'] = df['block2_sequence'].apply(lambda x: hashlib.md5(str(x).encode()).hexdigest()) |
|
|
|
|
|
|
|
|
ab_hashes = df.loc[mask, 'temp_antigen_hash'].unique() |
|
|
|
|
|
for h in ab_hashes: |
|
|
|
|
|
grp_mask = mask & (df['temp_antigen_hash'] == h) |
|
|
grp_indices = df.index[grp_mask] |
|
|
|
|
|
if len(grp_indices) == 0: |
|
|
continue |
|
|
|
|
|
n_groups += 1 |
|
|
|
|
|
|
|
|
pseudo_pdb = f"ANTIBODY_GRP_{h[:8]}" |
|
|
df.loc[grp_indices, '#Pdb'] = pseudo_pdb |
|
|
|
|
|
|
|
|
seqs = df.loc[grp_indices, 'block1_sequence'].tolist() |
|
|
seq_lens = set(len(s) for s in seqs) |
|
|
|
|
|
wt_idx = None |
|
|
|
|
|
if len(seq_lens) == 1: |
|
|
|
|
|
seq_len = list(seq_lens)[0] |
|
|
|
|
|
|
|
|
consensus = [] |
|
|
for pos in range(seq_len): |
|
|
residues = [s[pos] for s in seqs] |
|
|
counts = Counter(residues) |
|
|
most_common = counts.most_common(1)[0][0] |
|
|
consensus.append(most_common) |
|
|
consensus_seq = ''.join(consensus) |
|
|
|
|
|
|
|
|
min_dist = float('inf') |
|
|
for idx in grp_indices: |
|
|
seq = df.at[idx, 'block1_sequence'] |
|
|
dist = sum(c1 != c2 for c1, c2 in zip(seq, consensus_seq)) |
|
|
if dist < min_dist: |
|
|
min_dist = dist |
|
|
wt_idx = idx |
|
|
|
|
|
n_consensus += 1 |
|
|
else: |
|
|
|
|
|
if 'del_g' in df.columns: |
|
|
delg_vals = pd.to_numeric(df.loc[grp_indices, 'del_g'], errors='coerce').dropna() |
|
|
if len(delg_vals) > 0: |
|
|
|
|
|
median_val = delg_vals.median() |
|
|
median_idx = (delg_vals - median_val).abs().idxmin() |
|
|
wt_idx = median_idx |
|
|
n_median += 1 |
|
|
|
|
|
|
|
|
if wt_idx is None and len(grp_indices) > 0: |
|
|
wt_idx = grp_indices[0] |
|
|
n_fallback += 1 |
|
|
|
|
|
|
|
|
if wt_idx is not None: |
|
|
df.at[wt_idx, 'Mutation(s)_PDB'] = "" |
|
|
recovered_count += len(grp_indices) |
|
|
|
|
|
|
|
|
df.drop(columns=['temp_antigen_hash'], inplace=True, errors='ignore') |
|
|
|
|
|
print(f" [Dataset] Recovered {recovered_count} antibody samples ({n_groups} groups):") |
|
|
print(f" - {n_consensus} groups via closest-to-consensus") |
|
|
print(f" - {n_median} groups via median-binder (variable-length)") |
|
|
if n_fallback > 0: |
|
|
print(f" - {n_fallback} groups via first-sample fallback (no dG data)") |
|
|
return df |
|
|
|
|
|
def _infer_wt_sequences(self, mut_seq1: str, mut_seq2: str, mutation_str: str, |
|
|
b1_mutpos_str: str, b2_mutpos_str: str, |
|
|
b1_chains: str = "", b2_chains: str = "") -> Tuple[Optional[str], Optional[str]]: |
|
|
""" |
|
|
Infer wildtype sequences by reversing mutations in the mutant sequences. |
|
|
|
|
|
IMPROVED: Instead of relying on PDB positions (which don't match 0-indexed |
|
|
sequence positions), this version searches for the mutant residue and |
|
|
reverses it. Also computes actual mutation positions as byproduct. |
|
|
|
|
|
Mutations are in formats like: |
|
|
- BindingGym: "H:P53L" or "H:P53L,H:Y57C" (chain:WTresPOSmutres) |
|
|
- SKEMPI: "HP53L" or "CA182A" (chainWTresPOSmutres) |
|
|
|
|
|
Args: |
|
|
mut_seq1: Mutant sequence for block1 |
|
|
mut_seq2: Mutant sequence for block2 |
|
|
mutation_str: Raw mutation string from data |
|
|
b1_mutpos_str: Mutation positions for block1 (e.g., "[52, 56]") |
|
|
b2_mutpos_str: Mutation positions for block2 |
|
|
b1_chains: Chain letters in block1 (e.g., "AB") |
|
|
b2_chains: Chain letters in block2 (e.g., "HL") |
|
|
|
|
|
Returns: |
|
|
Tuple of (wt_seq1, wt_seq2) or (None, None) if inference fails |
|
|
""" |
|
|
import re |
|
|
|
|
|
if pd.isna(mutation_str) or str(mutation_str).strip() == '': |
|
|
|
|
|
return mut_seq1, mut_seq2 |
|
|
|
|
|
|
|
|
mutation_str_upper = str(mutation_str).strip().upper() |
|
|
if 'ANTIBODY_MUTATION' in mutation_str_upper or mutation_str_upper == 'ANTIBODY_MUTATION': |
|
|
if not hasattr(self, '_wt_inference_failures'): |
|
|
self._wt_inference_failures = {'parse_fail': [], 'del_ins_only': [], 'no_pdb': [], 'other': []} |
|
|
self._wt_inference_fail_count = 0 |
|
|
self._wt_inference_fail_count += 1 |
|
|
if len(self._wt_inference_failures['no_pdb']) < 5: |
|
|
self._wt_inference_failures['no_pdb'].append(mutation_str[:80]) |
|
|
return None, None |
|
|
|
|
|
try: |
|
|
|
|
|
mutations = [] |
|
|
mutation_str = str(mutation_str).strip() |
|
|
|
|
|
|
|
|
parts = re.split(r'[,;]', mutation_str) |
|
|
|
|
|
for part in parts: |
|
|
part = part.strip().strip('"\'') |
|
|
if not part: |
|
|
continue |
|
|
|
|
|
|
|
|
if 'DEL' in part.upper() or 'INS' in part.upper() or '*' in part: |
|
|
continue |
|
|
|
|
|
|
|
|
if ':' in part: |
|
|
chain_mut = part.split(':') |
|
|
if len(chain_mut) >= 2: |
|
|
chain = chain_mut[0].strip().upper() |
|
|
for mut_part in chain_mut[1:]: |
|
|
mut_part = mut_part.strip() |
|
|
if not mut_part: |
|
|
continue |
|
|
match = re.match(r'([A-Z])(\d+)([A-Z])', mut_part) |
|
|
if match: |
|
|
wt_aa = match.group(1) |
|
|
pos = int(match.group(2)) |
|
|
mut_aa = match.group(3) |
|
|
mutations.append((chain, pos, wt_aa, mut_aa)) |
|
|
else: |
|
|
|
|
|
|
|
|
|
|
|
match = re.match(r'([A-Z])([A-Z])(-?\d+[a-z]?)([A-Z])', part) |
|
|
if match: |
|
|
wt_aa = match.group(1) |
|
|
chain = match.group(2).upper() |
|
|
pos_str = match.group(3) |
|
|
pos = int(re.match(r'-?\d+', pos_str).group()) |
|
|
mut_aa = match.group(4) |
|
|
mutations.append((chain, pos, wt_aa, mut_aa)) |
|
|
else: |
|
|
|
|
|
|
|
|
match = re.match(r'([A-Z])(\d+)([A-Z])', part) |
|
|
if match: |
|
|
wt_aa = match.group(1) |
|
|
pos = int(match.group(2)) |
|
|
mut_aa = match.group(3) |
|
|
|
|
|
mutations.append(('?', pos, wt_aa, mut_aa)) |
|
|
|
|
|
if not mutations: |
|
|
if not hasattr(self, '_wt_inference_failures'): |
|
|
self._wt_inference_failures = {'parse_fail': [], 'del_ins_only': [], 'other': []} |
|
|
self._wt_inference_fail_count = 0 |
|
|
self._wt_inference_fail_count += 1 |
|
|
|
|
|
if 'DEL' in mutation_str.upper() or 'INS' in mutation_str.upper() or '*' in mutation_str: |
|
|
category = 'del_ins_only' |
|
|
else: |
|
|
category = 'parse_fail' |
|
|
|
|
|
if len(self._wt_inference_failures.get(category, [])) < 10: |
|
|
self._wt_inference_failures.setdefault(category, []).append(mutation_str[:80]) |
|
|
|
|
|
return None, None |
|
|
|
|
|
|
|
|
wt_seq1_list = list(mut_seq1) if mut_seq1 else [] |
|
|
wt_seq2_list = list(mut_seq2) if mut_seq2 else [] |
|
|
|
|
|
|
|
|
b1_chain_set = set(b1_chains.upper()) if b1_chains else set() |
|
|
b2_chain_set = set(b2_chains.upper()) if b2_chains else set() |
|
|
|
|
|
|
|
|
|
|
|
precomputed_b1_positions = self._parse_mutpos(b1_mutpos_str) |
|
|
precomputed_b2_positions = self._parse_mutpos(b2_mutpos_str) |
|
|
|
|
|
|
|
|
if not hasattr(self, '_wt_inference_stats'): |
|
|
self._wt_inference_stats = {'reversed': 0, 'not_found': 0, 'total': 0} |
|
|
|
|
|
|
|
|
found_positions_b1 = [] |
|
|
found_positions_b2 = [] |
|
|
|
|
|
|
|
|
|
|
|
if precomputed_b1_positions or precomputed_b2_positions: |
|
|
pos_idx = 0 |
|
|
for chain, pdb_pos, wt_aa, mut_aa in mutations: |
|
|
self._wt_inference_stats['total'] += 1 |
|
|
reversed_this = False |
|
|
|
|
|
|
|
|
if chain in b2_chain_set: |
|
|
|
|
|
if pos_idx < len(precomputed_b2_positions): |
|
|
seq_idx = precomputed_b2_positions[pos_idx] |
|
|
if 0 <= seq_idx < len(wt_seq2_list) and wt_seq2_list[seq_idx] == mut_aa: |
|
|
wt_seq2_list[seq_idx] = wt_aa |
|
|
reversed_this = True |
|
|
found_positions_b2.append(seq_idx) |
|
|
elif chain in b1_chain_set: |
|
|
|
|
|
if pos_idx < len(precomputed_b1_positions): |
|
|
seq_idx = precomputed_b1_positions[pos_idx] |
|
|
if 0 <= seq_idx < len(wt_seq1_list) and wt_seq1_list[seq_idx] == mut_aa: |
|
|
wt_seq1_list[seq_idx] = wt_aa |
|
|
reversed_this = True |
|
|
found_positions_b1.append(seq_idx) |
|
|
else: |
|
|
|
|
|
if pos_idx < len(precomputed_b1_positions): |
|
|
seq_idx = precomputed_b1_positions[pos_idx] |
|
|
if 0 <= seq_idx < len(wt_seq1_list) and wt_seq1_list[seq_idx] == mut_aa: |
|
|
wt_seq1_list[seq_idx] = wt_aa |
|
|
reversed_this = True |
|
|
found_positions_b1.append(seq_idx) |
|
|
if not reversed_this and pos_idx < len(precomputed_b2_positions): |
|
|
seq_idx = precomputed_b2_positions[pos_idx] |
|
|
if 0 <= seq_idx < len(wt_seq2_list) and wt_seq2_list[seq_idx] == mut_aa: |
|
|
wt_seq2_list[seq_idx] = wt_aa |
|
|
reversed_this = True |
|
|
found_positions_b2.append(seq_idx) |
|
|
|
|
|
if reversed_this: |
|
|
self._wt_inference_stats['reversed'] += 1 |
|
|
else: |
|
|
self._wt_inference_stats['not_found'] += 1 |
|
|
pos_idx += 1 |
|
|
|
|
|
self._last_computed_mutpos = (found_positions_b1, found_positions_b2) |
|
|
return ''.join(wt_seq1_list), ''.join(wt_seq2_list) |
|
|
|
|
|
|
|
|
for chain, pdb_pos, wt_aa, mut_aa in mutations: |
|
|
self._wt_inference_stats['total'] += 1 |
|
|
reversed_this = False |
|
|
found_idx = None |
|
|
|
|
|
|
|
|
chain_known = chain in b1_chain_set or chain in b2_chain_set |
|
|
|
|
|
if chain in b1_chain_set: |
|
|
blocks_to_try = [(wt_seq1_list, True, found_positions_b1)] |
|
|
elif chain in b2_chain_set: |
|
|
blocks_to_try = [(wt_seq2_list, False, found_positions_b2)] |
|
|
else: |
|
|
|
|
|
blocks_to_try = [ |
|
|
(wt_seq1_list, True, found_positions_b1), |
|
|
(wt_seq2_list, False, found_positions_b2) |
|
|
] |
|
|
|
|
|
for target_seq, is_block1, pos_list in blocks_to_try: |
|
|
if reversed_this: |
|
|
break |
|
|
|
|
|
guess_idx = pdb_pos - 1 |
|
|
|
|
|
|
|
|
if 0 <= guess_idx < len(target_seq) and target_seq[guess_idx] == mut_aa: |
|
|
found_idx = guess_idx |
|
|
else: |
|
|
|
|
|
search_start = max(0, pdb_pos - 50) |
|
|
search_end = min(len(target_seq), pdb_pos + 50) |
|
|
for idx in range(search_start, search_end): |
|
|
if target_seq[idx] == mut_aa: |
|
|
found_idx = idx |
|
|
break |
|
|
|
|
|
|
|
|
|
|
|
if found_idx is None and not chain_known: |
|
|
if guess_idx >= len(target_seq) or guess_idx < 0: |
|
|
|
|
|
for idx in range(len(target_seq)): |
|
|
if target_seq[idx] == mut_aa: |
|
|
found_idx = idx |
|
|
break |
|
|
|
|
|
if found_idx is not None: |
|
|
target_seq[found_idx] = wt_aa |
|
|
reversed_this = True |
|
|
pos_list.append(found_idx) |
|
|
|
|
|
if reversed_this: |
|
|
self._wt_inference_stats['reversed'] += 1 |
|
|
else: |
|
|
self._wt_inference_stats['not_found'] += 1 |
|
|
|
|
|
|
|
|
|
|
|
self._last_computed_mutpos = (found_positions_b1, found_positions_b2) |
|
|
|
|
|
return ''.join(wt_seq1_list), ''.join(wt_seq2_list) |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
return None, None |
|
|
|
|
|
def _get_embedding(self, seq: str, mut_positions: Optional[List[int]]) -> torch.Tensor:
    """Return the embedding of ``seq`` with a mutation-indicator channel.

    The base embedding comes from ``self._get_or_create_embedding(seq)`` and is
    moved to CPU. In the result, the last channel is 1.0 at every mutated
    position and 0.0 elsewhere:

    * if the base embedding has exactly 1153 channels, its last channel is
      overwritten with the indicator (it is treated as an indicator slot that
      is already present), so the width stays 1153;
    * otherwise a new indicator channel is appended (width ``D + 1``).

    Args:
        seq: The protein sequence.
        mut_positions: 0-indexed mutated positions. ``None`` is treated as
            "no mutations". Entries that are not integers (Python or NumPy)
            or that fall outside ``[0, L)`` are ignored.

    Returns:
        A new ``(L, D')`` CPU tensor; the base embedding is not modified.
    """
    # Width of an embedding that already carries an indicator column.
    # NOTE(review): presumably 1152 model channels + 1 indicator; confirm
    # against the featurizer / precompute script.
    dim_with_indicator = 1153

    base_emb = self._get_or_create_embedding(seq).cpu()
    L, D = base_emb.shape

    # Keep only integer, in-range positions (NumPy ints included).
    valid_positions = [
        int(pos)
        for pos in (mut_positions or [])
        if isinstance(pos, (int, np.integer)) and 0 <= pos < L
    ]

    if D == dim_with_indicator:
        # Reuse the existing last column as the indicator, reset to zero first.
        new_emb = base_emb.clone()
        new_emb[:, -1] = 0.0
    else:
        # Append a fresh all-zero indicator column (same dtype/device).
        new_emb = torch.cat([base_emb, base_emb.new_zeros((L, 1))], dim=-1)

    if valid_positions:
        new_emb[valid_positions, -1] = 1.0
    return new_emb
|
|
|
|
|
def _get_or_create_embedding(self, seq: str) -> torch.Tensor:
    """Return the base per-residue embedding of ``seq`` (always a fresh clone).

    Lookup order:

    1. the in-memory cache ``self._embedding_cache``;
    2. ``<embedding_dir>/<md5(seq)>.npy``, then ``<md5(seq)>.pt`` (an
       unreadable ``.pt`` is deleted so it gets regenerated);
    3. on-the-fly generation with ``self.featurizer.transform(seq)``. The
       result is persisted as ``<md5(seq)>.pt`` via a temp file plus atomic
       rename. A failed save is reported but does not discard the embedding.

    Embeddings with fewer than 5 rows are tiled (repeated) to exactly 5 rows
    before caching. The cache holds at most ``self._cache_max_size`` entries
    and evicts the first key in iteration order (insertion order for a
    ``dict``); a size <= 0 disables caching.

    Args:
        seq: Protein sequence.

    Returns:
        The embedding tensor. It is a clone, so callers may modify it in place.

    Raises:
        RuntimeError: if nothing is on disk and generation fails, or if the
            embedding has zero rows.
    """
    if seq in self._embedding_cache:
        self._cache_hits += 1
        return self._embedding_cache[seq].clone()

    # On-disk files are keyed by the MD5 of the sequence text.
    seq_hash = hashlib.md5(seq.encode()).hexdigest()
    pt_file = self.embedding_dir / f"{seq_hash}.pt"
    npy_file = self.embedding_dir / f"{seq_hash}.npy"

    emb = None
    if npy_file.is_file():
        try:
            emb = torch.from_numpy(np.load(npy_file))
        except Exception:
            emb = None  # unreadable .npy: fall back to .pt / regeneration
    if emb is None and pt_file.is_file():
        try:
            emb = torch.load(pt_file, map_location="cpu")
        except Exception:
            # Treat an unreadable .pt as corrupt and remove it.
            pt_file.unlink(missing_ok=True)

    if emb is None:
        try:
            emb = self.featurizer.transform(seq)
        except Exception as e:
            raise RuntimeError(
                f"Embedding not found and on-the-fly generation failed for sequence (len={len(seq)}): {e}"
            ) from e

        # Write to a per-process temp file, then atomically rename. Concurrent
        # DataLoader workers therefore never load a half-written .pt and
        # delete it as "corrupt".
        tmp_file = pt_file.with_name(f"{pt_file.name}.{os.getpid()}.tmp")
        save_error = None
        try:
            torch.save(emb, tmp_file)
            os.replace(tmp_file, pt_file)
        except Exception as e:
            save_error = e
            tmp_file.unlink(missing_ok=True)

        self._on_the_fly_count = getattr(self, "_on_the_fly_count", 0) + 1
        if self._on_the_fly_count <= 5:
            if save_error is None:
                print(f"[EMBEDDING] Generated on-the-fly #{self._on_the_fly_count}: len={len(seq)}, saved to {pt_file.name}")
            else:
                print(f"[EMBEDDING] Generated on-the-fly #{self._on_the_fly_count}: len={len(seq)}, could not save {pt_file.name}: {save_error}")
        elif self._on_the_fly_count == 6:
            print(f"[EMBEDDING] Generated 5+ embeddings on-the-fly (suppressing further logs)")

    n_rows = emb.shape[0]
    if n_rows == 0:
        raise RuntimeError(f"Embedding for sequence (len={len(seq)}) has zero rows")
    if n_rows < 5:
        # Tile short embeddings up to exactly 5 rows.
        repeats = (5 // n_rows) + 1
        emb = emb.repeat(repeats, 1)[:5]

    self._cache_misses += 1

    if self._cache_max_size > 0:
        while len(self._embedding_cache) >= self._cache_max_size:
            del self._embedding_cache[next(iter(self._embedding_cache))]
        self._embedding_cache[seq] = emb

    return emb.clone()
|
|
|
|
|
def get_cache_stats(self):
    """Summarise embedding-cache activity as a plain dict.

    Returns:
        dict with, in order: ``hits``, ``misses``, ``total`` (hits + misses),
        ``hit_rate`` (percentage, or ``0`` before any lookup), ``cache_size``,
        ``cache_max``, ``on_the_fly_generated`` and ``wt_embedding_failed``.
        The last two fall back to 0 when their counters were never created.
    """
    hits = self._cache_hits
    misses = self._cache_misses
    lookups = hits + misses
    if lookups > 0:
        rate = hits / lookups * 100
    else:
        rate = 0

    stats = dict(
        hits=hits,
        misses=misses,
        total=lookups,
        hit_rate=rate,
        cache_size=len(self._embedding_cache),
        cache_max=self._cache_max_size,
        on_the_fly_generated=getattr(self, '_on_the_fly_count', 0),
        wt_embedding_failed=getattr(self, '_wt_missing_count', 0),
    )
    return stats
|
|
|
|
|
def print_cache_stats(self):
    """Print the numbers from ``get_cache_stats()`` to stdout.

    The hits/misses/hit-rate/size line is always printed. The on-the-fly and
    WT-failure lines are printed only when their counts are positive.
    """
    s = self.get_cache_stats()

    summary = (f" [Cache] Hits: {s['hits']:,} | Misses: {s['misses']:,} | "
               f"Hit Rate: {s['hit_rate']:.1f}% | Size: {s['cache_size']:,}/{s['cache_max']:,}")
    print(summary)

    n_generated = s['on_the_fly_generated']
    if n_generated > 0:
        print(f" [Cache] On-the-fly generated: {n_generated:,} embeddings")

    n_wt_failed = s['wt_embedding_failed']
    if n_wt_failed > 0:
        print(f" [Cache] ⚠️ WT embedding failures: {n_wt_failed:,} (excluded from ddG training)")
|
|
|
|
|
def __len__(self):
    """Number of prepared samples (the length of ``self.samples``)."""
    n_samples = len(self.samples)
    return n_samples
|
|
|
|
|
def __getitem__(self, idx):
    """Assemble one (mutant, wild-type) sample.

    Args:
        idx: Index into ``self.samples``.

    Returns:
        ``(data_tuple, meta)``.

        ``data_tuple`` is ``(c1_emb, c2_emb, delg, cw1, cw2, delg_wt)``:

        * ``c1_emb``/``c2_emb`` embed ``seq1``/``seq2`` with their
          mutation-position indicators.
        * ``cw1``/``cw2`` embed the wild-type sequences. Both are ``None``
          when either WT sequence is missing or embedding raised
          ``RuntimeError``.
        * When ``self.normalize`` is set, every channel except the last
          (indicator) channel is L2-normalised per residue.

        ``meta`` carries the item's bookkeeping flags, targets, data source
        and parsed mutation positions.
    """
    item = self.samples[idx]

    # Diagnostic counters, created lazily on first use.
    if not hasattr(self, '_seq_diff_stats'):
        self._seq_diff_stats = {'same': 0, 'different': 0, 'no_wt': 0}
    if not hasattr(self, '_mutpos_stats'):
        self._mutpos_stats = {'has_mutpos': 0, 'no_mutpos': 0}

    b1_mutpos = self._parse_mutpos(item["b1_mutpos_str"])
    b2_mutpos = self._parse_mutpos(item["b2_mutpos_str"])

    if len(b1_mutpos) > 0 or len(b2_mutpos) > 0:
        self._mutpos_stats['has_mutpos'] += 1
    else:
        self._mutpos_stats['no_mutpos'] += 1
    total = sum(self._mutpos_stats.values())
    if total in [100, 1000, 10000]:
        has_mp = self._mutpos_stats['has_mutpos']
        no_mp = self._mutpos_stats['no_mutpos']
        print(f" [MUTPOS] After {total} samples: {has_mp} have mutation positions ({100*has_mp/total:.1f}%), "
              f"{no_mp} have NO mutation positions ({100*no_mp/total:.1f}%)")

    c1_emb = self._get_embedding(item["seq1"], b1_mutpos)
    c2_emb = self._get_embedding(item["seq2"], b2_mutpos)
    if self.normalize:
        # Normalise feature channels only; the last channel is the indicator.
        c1_emb[:, :-1] = torch.nn.functional.normalize(c1_emb[:, :-1], p=2, dim=-1)
        c2_emb[:, :-1] = torch.nn.functional.normalize(c2_emb[:, :-1], p=2, dim=-1)

    cw1, cw2 = None, None
    # Both WT chains are required; a None sequence cannot be embedded.
    if item["seq1_wt"] is not None and item["seq2_wt"] is not None:
        if item["seq1"] == item["seq1_wt"] and item["seq2"] == item["seq2_wt"]:
            self._seq_diff_stats['same'] += 1
        else:
            self._seq_diff_stats['different'] += 1

        total_samples = sum(self._seq_diff_stats.values())
        if total_samples in [100, 1000, 10000, 50000]:
            same = self._seq_diff_stats['same']
            diff = self._seq_diff_stats['different']
            no_wt = self._seq_diff_stats['no_wt']
            print(f" [SEQ DIFF] After {total_samples} samples: {same} same seq ({100*same/total_samples:.1f}%), "
                  f"{diff} different ({100*diff/total_samples:.1f}%), {no_wt} no WT")

        b1_wtpos = self._parse_mutpos(item["b1_wtpos_str"])
        b2_wtpos = self._parse_mutpos(item["b2_wtpos_str"])
        try:
            cw1 = self._get_embedding(item["seq1_wt"], b1_wtpos)
            cw2 = self._get_embedding(item["seq2_wt"], b2_wtpos)
        except RuntimeError as e:
            # The WT embedding is unavailable; keep the sample but make it WT-less.
            cw1, cw2 = None, None
            self._wt_missing_count = getattr(self, '_wt_missing_count', 0) + 1
            if self._wt_missing_count <= 3:
                print(f" [WARN] WT embedding missing #{self._wt_missing_count}, sample will be WT-less: {e}")

        if cw1 is not None and self.normalize:
            cw1[:, :-1] = torch.nn.functional.normalize(cw1[:, :-1], p=2, dim=-1)
            cw2[:, :-1] = torch.nn.functional.normalize(cw2[:, :-1], p=2, dim=-1)
    else:
        self._seq_diff_stats['no_wt'] += 1

    data_tuple = (c1_emb, c2_emb, item["delg"],
                  cw1, cw2, item["delg_wt"])
    meta = {
        "pdb": item["pdb"],
        "is_wt": item["is_wt"],
        "has_real_wt": item["has_real_wt"],
        "has_dg": item["has_dg"],
        "has_ddg": item["has_ddg"],
        "has_inferred_ddg": item["has_inferred_ddg"],
        "has_both_dg_ddg": item["has_both_dg_ddg"],
        "ddg": item["ddg"],
        "ddg_inferred": item["ddg_inferred"],
        "has_any_wt": item["has_any_wt"],
        "b1_mutpos": b1_mutpos,
        "b2_mutpos": b2_mutpos,
        "data_source": item["data_source"]
    }
    return (data_tuple, meta)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from sklearn.model_selection import GroupKFold |
|
|
|
|
|
class AffinityDataModule(pl.LightningDataModule):
    """Lightning data module for protein binding-affinity prediction.

    ``setup`` picks one way to split ``data_csv``:

    1. ``split_indices_dir`` - pre-computed cluster-based splits (recommended).
       A directory that also holds ``dg_val_indices.csv`` is loaded as a dual
       dG/ddG split.
    2. ``use_cluster_split`` - cluster-based splits created on the fly.
    3. The CSV's own ``split`` column (legacy).
    4. ``num_folds > 1`` - with the legacy split, ``GroupKFold`` over PDB IDs.
    """

    def __init__(
        self,
        data_csv: str,
        protein_featurizer: ESM3Featurizer,
        embedding_dir: str = "precomputed_esm",
        batch_size: int = 32,
        num_workers: int = 4,
        shuffle: bool = True,
        num_folds: int = 1,
        fold_index: int = 0,
        split_indices_dir: str = None,
        benchmark_indices_dir: str = None,
        use_cluster_split: bool = False,
        train_ratio: float = 0.70,
        val_ratio: float = 0.15,
        test_ratio: float = 0.15,
        random_state: int = 42
    ):
        """Record configuration only; datasets are built later by ``setup``.

        Args:
            data_csv: Path to the CSV holding every sample.
            protein_featurizer: Featurizer handed to every dataset.
            embedding_dir: Embedding directory handed to every dataset.
            batch_size: Batch size for every DataLoader.
            num_workers: Worker count for every DataLoader.
            shuffle: Whether the training loaders shuffle.
            num_folds: Number of GroupKFold folds for the legacy split.
            fold_index: Which fold to use for the legacy split.
            split_indices_dir: Directory of pre-computed split indices.
            benchmark_indices_dir: Optional directory of balanced benchmark
                indices (consulted only for dual splits).
            use_cluster_split: Create cluster-based splits on the fly.
            train_ratio: Train fraction for the on-the-fly cluster split.
            val_ratio: Validation fraction for the on-the-fly cluster split.
            test_ratio: Test fraction for the on-the-fly cluster split.
            random_state: Seed for the on-the-fly cluster split.
        """
        super().__init__()

        # Data source and featurisation.
        self.data_csv = data_csv
        self.featurizer = protein_featurizer
        self.embedding_dir = embedding_dir

        # DataLoader settings.
        self.batch_size, self.num_workers, self.shuffle = batch_size, num_workers, shuffle

        # Splitting configuration.
        self.num_folds, self.fold_index = num_folds, fold_index
        self.split_indices_dir = split_indices_dir
        self.benchmark_indices_dir = benchmark_indices_dir
        self.use_cluster_split = use_cluster_split
        self.train_ratio, self.val_ratio, self.test_ratio = train_ratio, val_ratio, test_ratio
        self.random_state = random_state

        # Filled in by setup(); None until then or when a split is unavailable.
        for attr_name in (
            "train_dataset", "val_dataset", "test_dataset",
            "dg_train_dataset", "ddg_train_dataset",
            "dg_val_dataset", "dg_test_dataset",
            "ddg_val_dataset", "ddg_test_dataset",
        ):
            setattr(self, attr_name, None)
        self.use_dual_split = False
|
|
|
|
def prepare_data(self):
    """Fail fast when the configured CSV file is missing.

    Raises:
        FileNotFoundError: if ``self.data_csv`` does not exist.
    """
    csv_path = self.data_csv
    if os.path.exists(csv_path):
        return
    raise FileNotFoundError(f"Data CSV not found => {csv_path}")
|
|
|
|
|
def setup(self, stage=None):
    """Read ``self.data_csv`` and build the datasets served by this module.

    One splitting strategy is used, chosen in this priority order:

    1. Dual split - ``split_indices_dir`` contains ``dg_val_indices.csv``.
       Indices come from ``data_splitting.load_dual_splits``, and separate dG
       and ddG train/val/test datasets are built. The ddG val/test frames are
       replaced by ``ddg_{val,test}_benchmark_indices.csv`` from
       ``benchmark_indices_dir`` when those files exist. The generic
       train/val/test frames are the combined train split and the original
       ddG val/test splits.
    2. ``split_indices_dir`` exists - indices come from
       ``data_splitting.load_split_indices`` and are checked with
       ``verify_no_leakage``.
    3. ``use_cluster_split`` - indices come from
       ``data_splitting.create_cluster_splits`` and are saved under
       ``<dir of data_csv>/splits``.
    4. Legacy - rows with ``split == "Benchmark test"`` form the test set. The
       remaining rows are split with ``GroupKFold`` on ``#Pdb`` when
       ``num_folds > 1``, otherwise by ``split == "train"`` / ``"val"``.

    Afterwards, ``train_dataset``, ``val_dataset`` and ``test_dataset`` are
    always built; val and test use the train frame as WT reference. If no
    dG/ddG test datasets exist yet, the test frame is partitioned into WT rows
    (dG) and mutant rows (ddG) using ``Mutation(s)_cleaned``.

    Args:
        stage: Lightning stage name; unused.
    """
    data = pd.read_csv(self.data_csv, low_memory=False)

    # dg_val_indices.csv marks a dual (dG/ddG) split directory.
    dual_split_file = os.path.join(self.split_indices_dir, 'dg_val_indices.csv') if self.split_indices_dir else None

    if dual_split_file and os.path.exists(dual_split_file):
        from data_splitting import load_dual_splits
        print(f"\n[DataModule] Loading DUAL splits from {self.split_indices_dir}")
        splits = load_dual_splits(self.split_indices_dir)
        self.use_dual_split = True

        def take(indices):
            # Positional row selection with a fresh 0..n-1 index.
            return data.iloc[indices].reset_index(drop=True)

        # Generic frames: combined train, ddG val/test.
        train_df = take(splits['combined_train'])
        val_df = take(splits['ddg']['val'])
        test_df = take(splits['ddg']['test'])

        dg_train_df = take(splits['dg']['train'])
        ddg_train_df = take(splits['ddg']['train'])
        dg_val_df = take(splits['dg']['val'])
        dg_test_df = take(splits['dg']['test'])
        ddg_val_df = take(splits['ddg']['val'])
        ddg_test_df = take(splits['ddg']['test'])

        print(f"\n[DataModule] Creating dG TRAIN dataset ({len(dg_train_df)} WT rows)...")
        self.dg_train_dataset = AdvancedSiameseDataset(dg_train_df, self.featurizer, self.embedding_dir, augment=False)
        print(f"[DataModule] Creating ddG TRAIN dataset ({len(ddg_train_df)} MT rows)...")
        self.ddg_train_dataset = AdvancedSiameseDataset(ddg_train_df, self.featurizer, self.embedding_dir, augment=False)

        # Optionally swap ddG val/test for balanced benchmark subsets.
        if self.benchmark_indices_dir and os.path.exists(self.benchmark_indices_dir):
            print(f"\n[DataModule] Loading BALANCED BENCHMARK indices from {self.benchmark_indices_dir}")

            ddg_val_bench_file = os.path.join(self.benchmark_indices_dir, 'ddg_val_benchmark_indices.csv')
            if os.path.exists(ddg_val_bench_file):
                bench_val_idx = pd.read_csv(ddg_val_bench_file, header=None).iloc[:, 0].values.tolist()
                ddg_val_df = take(bench_val_idx)
                print(f" ddG val: {len(ddg_val_df)} rows (balanced benchmark)")

            ddg_test_bench_file = os.path.join(self.benchmark_indices_dir, 'ddg_test_benchmark_indices.csv')
            if os.path.exists(ddg_test_bench_file):
                bench_test_idx = pd.read_csv(ddg_test_bench_file, header=None).iloc[:, 0].values.tolist()
                ddg_test_df = take(bench_test_idx)
                print(f" ddG test: {len(ddg_test_df)} rows (balanced benchmark)")

        print(f"\n[DataModule] Creating dG val dataset ({len(dg_val_df)} rows)...")
        self.dg_val_dataset = AdvancedSiameseDataset(
            dg_val_df, self.featurizer, self.embedding_dir, augment=False,
            wt_reference_df=data
        )
        print(f"\n[DataModule] Creating dG test dataset ({len(dg_test_df)} rows)...")
        self.dg_test_dataset = AdvancedSiameseDataset(
            dg_test_df, self.featurizer, self.embedding_dir, augment=False,
            wt_reference_df=data
        )
        print(f"\n[DataModule] Creating ddG val dataset ({len(ddg_val_df)} rows)...")
        self.ddg_val_dataset = AdvancedSiameseDataset(
            ddg_val_df, self.featurizer, self.embedding_dir, augment=False,
            wt_reference_df=data
        )
        print(f"\n[DataModule] Creating ddG test dataset ({len(ddg_test_df)} rows)...")
        self.ddg_test_dataset = AdvancedSiameseDataset(
            ddg_test_df, self.featurizer, self.embedding_dir, augment=False,
            wt_reference_df=data
        )

        print(f"\n[DataModule] Dual split datasets created:")
        print(f" dG train: {len(self.dg_train_dataset)} samples (WT-only for Stage A)")
        print(f" ddG train: {len(self.ddg_train_dataset)} samples (MT-only)")
        print(f" dG val: {len(self.dg_val_dataset)} samples")
        print(f" dG test: {len(self.dg_test_dataset)} samples")
        print(f" ddG val: {len(self.ddg_val_dataset)} samples")
        print(f" ddG test: {len(self.ddg_test_dataset)} samples")

    elif self.split_indices_dir and os.path.exists(self.split_indices_dir):
        from data_splitting import load_split_indices, verify_no_leakage
        train_idx, val_idx, test_idx = load_split_indices(self.split_indices_dir)
        train_df = data.iloc[train_idx].reset_index(drop=True)
        val_df = data.iloc[val_idx].reset_index(drop=True)
        test_df = data.iloc[test_idx].reset_index(drop=True)
        verify_no_leakage(data, train_idx, val_idx, test_idx)

    elif self.use_cluster_split:
        from data_splitting import create_cluster_splits, verify_no_leakage
        splits_dir = os.path.join(os.path.dirname(self.data_csv), 'splits')
        train_idx, val_idx, test_idx = create_cluster_splits(
            data,
            train_ratio=self.train_ratio,
            val_ratio=self.val_ratio,
            test_ratio=self.test_ratio,
            random_state=self.random_state,
            save_dir=splits_dir
        )
        train_df = data.iloc[train_idx].reset_index(drop=True)
        val_df = data.iloc[val_idx].reset_index(drop=True)
        test_df = data.iloc[test_idx].reset_index(drop=True)

    else:
        # Legacy: the CSV's own 'split' column. The benchmark frame gets a
        # fresh 0..n-1 index like every other split frame.
        bench_df = data[data["split"] == "Benchmark test"].reset_index(drop=True)
        trainval_df = data[data["split"] != "Benchmark test"].copy()

        if self.num_folds > 1:
            gkf = GroupKFold(n_splits=self.num_folds)
            groups = trainval_df["#Pdb"].values
            folds = list(gkf.split(trainval_df, groups=groups))
            train_idx, val_idx = folds[self.fold_index]
            train_df = trainval_df.iloc[train_idx].reset_index(drop=True)
            val_df = trainval_df.iloc[val_idx].reset_index(drop=True)
        else:
            train_df = trainval_df[trainval_df["split"] == "train"].reset_index(drop=True)
            val_df = trainval_df[trainval_df["split"] == "val"].reset_index(drop=True)

        test_df = bench_df

    print(f"\n[DataModule] Creating TRAIN dataset ({len(train_df)} rows)...")
    self.train_dataset = AdvancedSiameseDataset(
        train_df, self.featurizer, self.embedding_dir, augment=False
    )
    print(f"\n[DataModule] Creating VAL dataset ({len(val_df)} rows)...")
    self.val_dataset = AdvancedSiameseDataset(
        val_df, self.featurizer, self.embedding_dir, augment=False,
        wt_reference_df=train_df
    )
    print(f"\n[DataModule] Creating TEST dataset ({len(test_df)} rows)...")
    self.test_dataset = AdvancedSiameseDataset(
        test_df, self.featurizer, self.embedding_dir, augment=False,
        wt_reference_df=train_df
    )

    if self.dg_test_dataset is None and self.ddg_test_dataset is None:
        # A test row is WT when 'Mutation(s)_cleaned' is absent, empty,
        # 'nan' (any case) or 'WT' after stripping whitespace.
        mut_col = 'Mutation(s)_cleaned'
        if mut_col in test_df.columns:
            mut_str = test_df[mut_col].astype(str).str.strip()
            test_is_wt = (mut_str == '') | (mut_str.str.lower() == 'nan') | (mut_str == 'WT')
        else:
            test_is_wt = pd.Series(True, index=test_df.index, dtype=bool)

        dg_test_df = test_df[test_is_wt].reset_index(drop=True)
        ddg_test_df = test_df[~test_is_wt].reset_index(drop=True)

        if len(dg_test_df) > 0:
            print(f"\n[DataModule] Creating dG TEST dataset ({len(dg_test_df)} WT rows)...")
            self.dg_test_dataset = AdvancedSiameseDataset(
                dg_test_df, self.featurizer, self.embedding_dir, augment=False,
                wt_reference_df=data
            )
        else:
            print(f"[DataModule] WARNING: No WT rows in test set for dG test dataset!")

        if len(ddg_test_df) > 0:
            print(f"\n[DataModule] Creating ddG TEST dataset ({len(ddg_test_df)} MT rows)...")
            self.ddg_test_dataset = AdvancedSiameseDataset(
                ddg_test_df, self.featurizer, self.embedding_dir, augment=False,
                wt_reference_df=data
            )
        else:
            print(f"[DataModule] WARNING: No MT rows in test set for ddG test dataset!")

    print(f"\nDataset sizes:")
    print(f" Train: {len(self.train_dataset)} samples")
    print(f" Val: {len(self.val_dataset)} samples")
    print(f" Test: {len(self.test_dataset)} samples")
    if self.dg_test_dataset:
        print(f" dG Test: {len(self.dg_test_dataset)} samples (WT)")
    if self.ddg_test_dataset:
        print(f" ddG Test: {len(self.ddg_test_dataset)} samples (MT)")
|
|
|
|
|
|
|
|
def train_dataloader(self): |
|
|
return DataLoader( |
|
|
self.train_dataset, |
|
|
batch_size=self.batch_size, |
|
|
shuffle=self.shuffle, |
|
|
num_workers=self.num_workers, |
|
|
collate_fn=advanced_collate_fn |
|
|
) |
|
|
|
|
|
def val_dataloader(self): |
|
|
return DataLoader( |
|
|
self.val_dataset, |
|
|
batch_size=self.batch_size, |
|
|
shuffle=False, |
|
|
num_workers=self.num_workers, |
|
|
collate_fn=advanced_collate_fn |
|
|
) |
|
|
|
|
|
def test_dataloader(self): |
|
|
return DataLoader( |
|
|
self.test_dataset, |
|
|
batch_size=self.batch_size, |
|
|
shuffle=False, |
|
|
num_workers=self.num_workers, |
|
|
collate_fn=advanced_collate_fn |
|
|
) |
|
|
|
|
|
|
|
|
def dg_train_dataloader(self): |
|
|
"""Training dataloader for dG head (WT data only for Stage A pretraining).""" |
|
|
if self.dg_train_dataset is None: |
|
|
return None |
|
|
return DataLoader( |
|
|
self.dg_train_dataset, |
|
|
batch_size=self.batch_size, |
|
|
shuffle=self.shuffle, |
|
|
num_workers=self.num_workers, |
|
|
collate_fn=advanced_collate_fn |
|
|
) |
|
|
|
|
|
def ddg_train_dataloader(self): |
|
|
"""Training dataloader for ddG head (mutation data for Stage B training).""" |
|
|
if self.ddg_train_dataset is None: |
|
|
return None |
|
|
return DataLoader( |
|
|
self.ddg_train_dataset, |
|
|
batch_size=self.batch_size, |
|
|
shuffle=self.shuffle, |
|
|
num_workers=self.num_workers, |
|
|
collate_fn=advanced_collate_fn |
|
|
) |
|
|
|
|
|
|
|
|
def dg_val_dataloader(self): |
|
|
"""Validation dataloader for dG head (WT data only).""" |
|
|
if self.dg_val_dataset is None: |
|
|
return None |
|
|
return DataLoader( |
|
|
self.dg_val_dataset, |
|
|
batch_size=self.batch_size, |
|
|
shuffle=False, |
|
|
num_workers=self.num_workers, |
|
|
collate_fn=advanced_collate_fn |
|
|
) |
|
|
|
|
|
def dg_test_dataloader(self): |
|
|
"""Test dataloader for dG head (WT data only).""" |
|
|
if self.dg_test_dataset is None: |
|
|
return None |
|
|
return DataLoader( |
|
|
self.dg_test_dataset, |
|
|
batch_size=self.batch_size, |
|
|
shuffle=False, |
|
|
num_workers=self.num_workers, |
|
|
collate_fn=advanced_collate_fn |
|
|
) |
|
|
|
|
|
def ddg_val_dataloader(self): |
|
|
"""Validation dataloader for ddG head (mutation data including DMS).""" |
|
|
if self.ddg_val_dataset is None: |
|
|
return None |
|
|
return DataLoader( |
|
|
self.ddg_val_dataset, |
|
|
batch_size=self.batch_size, |
|
|
shuffle=False, |
|
|
num_workers=self.num_workers, |
|
|
collate_fn=advanced_collate_fn |
|
|
) |
|
|
|
|
|
def ddg_test_dataloader(self): |
|
|
"""Test dataloader for ddG head (mutation data including DMS).""" |
|
|
if self.ddg_test_dataset is None: |
|
|
return None |
|
|
return DataLoader( |
|
|
self.ddg_test_dataset, |
|
|
batch_size=self.batch_size, |
|
|
shuffle=False, |
|
|
num_workers=self.num_workers, |
|
|
collate_fn=advanced_collate_fn |
|
|
) |