ym59 commited on
Commit
453d2cd
Β·
verified Β·
1 Parent(s): daf4247

Upload src/data/loader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/data/loader.py +243 -0
src/data/loader.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/data/loader.py
2
+ #
3
+ # Loads LP-PDBBind and CASF-2016 into clean DataFrames.
4
+ # Output columns: pdb_id, seq, smiles, label
5
+
6
+ import pandas as pd
7
+ from pathlib import Path
8
+
9
+
10
+ def load_lppdb(csv_path: Path,
11
+ exclude_ids: set = None) -> pd.DataFrame:
12
+ """
13
+ Load LP-PDBBind flat CSV.
14
+
15
+ Relevant columns:
16
+ pdb_id β€” PDB identifier
17
+ seq β€” protein sequence
18
+ smiles β€” ligand SMILES
19
+ value β€” pAffinity (already normalized from Kd/Ki/IC50)
20
+
21
+ Args:
22
+ csv_path: path to LP_PDBBind.csv
23
+ exclude_ids: set of lowercase PDB IDs to remove before training
24
+ (pass your CASF IDs here to prevent leakage)
25
+
26
+ Drops rows with missing seq, smiles, or label.
27
+ Strips whitespace from sequences and SMILES.
28
+ """
29
+ df = pd.read_csv(csv_path)
30
+
31
+ df = df[['pdb_id', 'seq', 'smiles', 'value']].copy()
32
+ df.columns = ['pdb_id', 'seq', 'smiles', 'label']
33
+
34
+ before = len(df)
35
+ df = df.dropna(subset=['seq', 'smiles', 'label'])
36
+ df['seq'] = df['seq'].str.strip().str.upper()
37
+ df['smiles'] = df['smiles'].str.strip()
38
+ df['pdb_id'] = df['pdb_id'].str.lower().str.strip()
39
+ df = df[df['seq'].str.len() > 0]
40
+ df = df[df['smiles'].str.len() > 0]
41
+
42
+ after_clean = len(df)
43
+
44
+ # Remove CASF complexes to prevent data leakage
45
+ if exclude_ids:
46
+ before_excl = len(df)
47
+ df = df[~df['pdb_id'].isin(exclude_ids)]
48
+ n_removed = before_excl - len(df)
49
+ print(f" Removed {n_removed} CASF complexes from training (leakage prevention)")
50
+
51
+ df = df.reset_index(drop=True)
52
+ print(f"LP-PDBBind: {before} β†’ {after_clean} (after cleaning) "
53
+ f"β†’ {len(df)} (after CASF removal)")
54
+ return df
55
+
56
+
57
+ def load_casf(casf_dir: Path) -> pd.DataFrame:
58
+ """
59
+ Load CASF-2016 CoreSet.
60
+
61
+ Reads CoreSet.dat for pdb_ids and labels.
62
+ Reads protein sequences from <pdb_id>/<pdb_id>_protein.pdb SEQRES records.
63
+ Reads ligand SMILES from <pdb_id>/<pdb_id>_ligand.mol2 via RDKit.
64
+
65
+ Returns DataFrame with same columns as load_lppdb.
66
+ """
67
+ from rdkit import Chem
68
+ from rdkit import RDLogger
69
+ RDLogger.DisableLog('rdApp.*')
70
+
71
+ coreset_dat = casf_dir / "power_scoring" / "CoreSet.dat"
72
+ coreset_dir = casf_dir / "coreset"
73
+
74
+ # Parse CoreSet.dat β€” tab/space separated, first col = pdb_id, last = -logKd
75
+ records = []
76
+ with open(coreset_dat) as f:
77
+ for line in f:
78
+ line = line.strip()
79
+ if not line or line.startswith('#'):
80
+ continue
81
+ parts = line.split()
82
+ pdb_id = parts[0].lower()
83
+ label = float(parts[-3])
84
+ records.append({'pdb_id': pdb_id, 'label': label})
85
+
86
+ dat_df = pd.DataFrame(records)
87
+ print(f"CASF CoreSet.dat: {len(dat_df)} entries")
88
+
89
+ rows = []
90
+ dropped = []
91
+
92
+ for _, row in dat_df.iterrows():
93
+ pid = row['pdb_id']
94
+ label = row['label']
95
+ folder = coreset_dir / pid
96
+
97
+ # Protein sequence from SEQRES
98
+ seq = _parse_seqres(folder / f"{pid}_protein.pdb")
99
+
100
+ # Ligand SMILES β€” try mol2 first, then sdf
101
+ smiles = _parse_ligand_smiles(folder, pid)
102
+
103
+ if seq is None or smiles is None:
104
+ dropped.append((pid, "seq missing" if seq is None else "smiles missing"))
105
+ continue
106
+
107
+ rows.append({'pdb_id': pid, 'seq': seq, 'smiles': smiles, 'label': label})
108
+
109
+ df = pd.DataFrame(rows)
110
+ print(f"CASF parsed: {len(df)} complexes | dropped: {len(dropped)}")
111
+ for pid, reason in dropped:
112
+ print(f" dropped {pid}: {reason}")
113
+
114
+ return df, dropped
115
+
116
+
117
+ def load_casf2013(casf13_dir: Path) -> pd.DataFrame:
118
+ """
119
+ Load CASF-2013 CoreSet. Identical structure to CASF-2016:
120
+ power_scoring/CoreSet.dat β€” labels
121
+ coreset/<pid>/ β€” PDB + mol2/sdf files
122
+ Returns same (df, dropped) as load_casf.
123
+ """
124
+ from rdkit import Chem
125
+ from rdkit import RDLogger
126
+ RDLogger.DisableLog('rdApp.*')
127
+
128
+ coreset_dat = casf13_dir / "power_scoring" / "CoreSet.dat"
129
+ coreset_dir = casf13_dir / "coreset"
130
+
131
+ records = []
132
+ with open(coreset_dat) as f:
133
+ for line in f:
134
+ line = line.strip()
135
+ if not line or line.startswith('#'):
136
+ continue
137
+ parts = line.split()
138
+ pdb_id = parts[0].lower()
139
+ label = float(parts[-3])
140
+ records.append({'pdb_id': pdb_id, 'label': label})
141
+
142
+ dat_df = pd.DataFrame(records)
143
+ print(f"CASF-2013 CoreSet.dat: {len(dat_df)} entries")
144
+
145
+ rows, dropped = [], []
146
+ for _, row in dat_df.iterrows():
147
+ pid = row['pdb_id']
148
+ label = row['label']
149
+ folder = coreset_dir / pid
150
+
151
+ seq = _parse_seqres(folder / f"{pid}_protein.pdb")
152
+ smiles = _parse_ligand_smiles(folder, pid)
153
+
154
+ if seq is None or smiles is None:
155
+ dropped.append((pid, "seq missing" if seq is None else "smiles missing"))
156
+ continue
157
+
158
+ rows.append({'pdb_id': pid, 'seq': seq, 'smiles': smiles, 'label': label})
159
+
160
+ df = pd.DataFrame(rows)
161
+ print(f"CASF-2013 parsed: {len(df)} complexes | dropped: {len(dropped)}")
162
+ for pid, reason in dropped:
163
+ print(f" dropped {pid}: {reason}")
164
+
165
+ return df, dropped
166
+
167
+
168
+ # ── Private helpers ───────────────────────────────────────────────────
169
+
170
+ _AA3TO1 = {
171
+ 'ALA':'A','ARG':'R','ASN':'N','ASP':'D','CYS':'C',
172
+ 'GLN':'Q','GLU':'E','GLY':'G','HIS':'H','ILE':'I',
173
+ 'LEU':'L','LYS':'K','MET':'M','PHE':'F','PRO':'P',
174
+ 'SER':'S','THR':'T','TRP':'W','TYR':'Y','VAL':'V',
175
+ # common non-standard β†’ closest standard
176
+ 'MSE':'M','SEP':'S','TPO':'T','PTR':'Y','HYP':'P',
177
+ }
178
+
179
+
180
+ def _parse_seqres(pdb_path: Path) -> str | None:
181
+ if not pdb_path.exists():
182
+ return None
183
+
184
+ # Try SEQRES records first (canonical, includes all residues)
185
+ seq_by_chain = {}
186
+ with open(pdb_path) as f:
187
+ for line in f:
188
+ if line.startswith('SEQRES'):
189
+ chain = line[11]
190
+ residues = line[19:].split()
191
+ seq_by_chain.setdefault(chain, []).extend(residues)
192
+
193
+ if seq_by_chain:
194
+ chain = max(seq_by_chain, key=lambda c: len(seq_by_chain[c]))
195
+ residues = seq_by_chain[chain]
196
+ seq = ''.join(_AA3TO1.get(r, 'X') for r in residues)
197
+ seq = seq.replace('X', '')
198
+ if seq:
199
+ return seq
200
+
201
+ # Fallback: parse ATOM records (some PDB files lack SEQRES)
202
+ # Collects unique residues in order of appearance
203
+ atom_by_chain = {}
204
+ with open(pdb_path) as f:
205
+ for line in f:
206
+ if not line.startswith('ATOM'):
207
+ continue
208
+ chain = line[21]
209
+ res_name = line[17:20].strip()
210
+ res_seq = line[22:26].strip() # residue sequence number
211
+ atom_by_chain.setdefault(chain, {})[res_seq] = res_name
212
+
213
+ if not atom_by_chain:
214
+ return None
215
+
216
+ chain = max(atom_by_chain, key=lambda c: len(atom_by_chain[c]))
217
+ residues = [atom_by_chain[chain][k]
218
+ for k in sorted(atom_by_chain[chain],
219
+ key=lambda x: int(x) if x.lstrip('-').isdigit() else 0)]
220
+ seq = ''.join(_AA3TO1.get(r, 'X') for r in residues)
221
+ seq = seq.replace('X', '')
222
+ return seq if seq else None
223
+
224
+
225
+ def _parse_ligand_smiles(folder: Path, pid: str) -> str | None:
226
+ from rdkit import Chem
227
+
228
+ # Try mol2
229
+ mol2_path = folder / f"{pid}_ligand.mol2"
230
+ if mol2_path.exists():
231
+ mol = Chem.MolFromMol2File(str(mol2_path), removeHs=True)
232
+ if mol:
233
+ return Chem.MolToSmiles(mol)
234
+
235
+ # Try sdf
236
+ sdf_path = folder / f"{pid}_ligand.sdf"
237
+ if sdf_path.exists():
238
+ suppl = Chem.SDMolSupplier(str(sdf_path), removeHs=True)
239
+ for mol in suppl:
240
+ if mol:
241
+ return Chem.MolToSmiles(mol)
242
+
243
+ return None