manpreet88 commited on
Commit
1a6c805
·
0 Parent(s):

Create Data_Modalities.py

Browse files
Files changed (1) hide show
  1. Data_Modalities.py +384 -0
Data_Modalities.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from rdkit import Chem
4
+ from rdkit.Chem import AllChem, Descriptors, rdMolDescriptors, rdDepictor
5
+ from rdkit.Chem.Draw import rdMolDraw2D
6
+ from rdkit.Chem import Crippen, Descriptors3D
7
+ from rdkit.Chem.Scaffolds import MurckoScaffold
8
+ from rdkit.Chem import rdFingerprintGenerator
9
+ import networkx as nx
10
+ import requests
11
+ import time
12
+ import json
13
+ from typing import List, Dict, Tuple, Optional
14
+ import warnings
15
+ warnings.filterwarnings('ignore')
16
+ from rdkit import RDLogger
17
+ RDLogger.DisableLog('rdApp.*')
18
+ import os
19
+ import multiprocessing as mp
20
+
21
+ # ----------------------------------------------------------------------
22
+ # -------------- STAR ATOM HANDLING (robust original style) ------------
23
+ # ----------------------------------------------------------------------
24
+
25
+ def process_star_atoms(mol):
26
+ """
27
+ Replace all wildcard atoms (‘*’ or atomicNum == 0) with **Astatine (At)**.
28
+ Astatine has atomic number 85.
29
+ """
30
+ ATOMIC_NUM_AT = 85 # Astatine
31
+
32
+ for atom in mol.GetAtoms():
33
+ if atom.GetAtomicNum() == 0:
34
+ atom.SetAtomicNum(ATOMIC_NUM_AT)
35
+ for atom in mol.GetAtoms():
36
+ if atom.GetSymbol() == "*":
37
+ atom.SetAtomicNum(ATOMIC_NUM_AT)
38
+
39
+ return mol
40
+
41
+ # ----------------------------------------------------------------------
42
+ # -------------- SINGLE POLYMER PROCESSING -----------------------------
43
+ # ----------------------------------------------------------------------
44
+
45
+ def process_single_polymer(args):
46
+ idx, row_dict, extractor = args
47
+ polymer_data = None
48
+ failed_info = None
49
+ try:
50
+ smiles = row_dict['psmiles']
51
+ source = row_dict['source']
52
+
53
+ if pd.isna(smiles) or not isinstance(smiles, str) or len(smiles.strip()) == 0:
54
+ failed_info = {'index': idx, 'smiles': str(smiles), 'error': 'Empty or invalid SMILES'}
55
+ return polymer_data, failed_info
56
+
57
+ canonical_smiles = extractor.validate_and_standardize_smiles(smiles)
58
+ if canonical_smiles is None:
59
+ failed_info = {'index': idx, 'smiles': smiles, 'error': 'Invalid SMILES or contains wildcards'}
60
+ return polymer_data, failed_info
61
+
62
+ polymer_data = {
63
+ 'original_index': idx,
64
+ 'psmiles': canonical_smiles,
65
+ 'source': source,
66
+ 'smiles': canonical_smiles
67
+ }
68
+
69
+ try:
70
+ graph_data = extractor.generate_molecular_graph(canonical_smiles)
71
+ polymer_data['graph'] = graph_data
72
+ except Exception:
73
+ polymer_data['graph'] = {}
74
+
75
+ try:
76
+ geometry_data = extractor.optimize_3d_geometry(canonical_smiles)
77
+ polymer_data['geometry'] = geometry_data
78
+ except Exception:
79
+ polymer_data['geometry'] = {}
80
+
81
+ try:
82
+ fingerprint_data = extractor.calculate_morgan_fingerprints(canonical_smiles)
83
+ polymer_data['fingerprints'] = fingerprint_data
84
+ except Exception:
85
+ polymer_data['fingerprints'] = {}
86
+
87
+ return polymer_data, failed_info
88
+
89
+ except Exception as e:
90
+ failed_info = {'index': idx, 'smiles': row_dict.get('psmiles', ''), 'error': str(e)}
91
+ return polymer_data, failed_info
92
+
93
+ # ----------------------------------------------------------------------
94
+ # -------------- MAIN MULTIMODAL EXTRACTOR -----------------------------
95
+ # ----------------------------------------------------------------------
96
+
97
+ class AdvancedPolymerMultimodalExtractor:
98
+ def __init__(self, csv_file: str):
99
+ self.csv_file = csv_file
100
+
101
+ # ---------- SMILES VALIDATION & STANDARDIZATION ----------
102
+ def validate_and_standardize_smiles(self, smiles: str) -> Optional[str]:
103
+ try:
104
+ if not smiles or pd.isna(smiles):
105
+ return None
106
+ mol = Chem.MolFromSmiles(smiles, sanitize=False)
107
+ mol = process_star_atoms(mol) # CONVERT * to Astatine (At)
108
+ Chem.SanitizeMol(mol)
109
+ mol = process_star_atoms(mol) # SECOND PASS (robust)
110
+ canonical_smiles = Chem.MolToSmiles(mol, canonical=True)
111
+ if len(canonical_smiles) == 0:
112
+ return None
113
+ return canonical_smiles
114
+ except Exception:
115
+ return None
116
+
117
+ # ---------- POLYMER VALIDITY CHECKS ----------
118
+ def _has_invalid_polymer_features(self, mol) -> bool:
119
+ try:
120
+ if mol.GetNumAtoms() > 200:
121
+ return True
122
+ for atom in mol.GetAtoms():
123
+ if atom.GetFormalCharge() > 5 or atom.GetFormalCharge() < -5:
124
+ return True
125
+ return False
126
+ except:
127
+ return True
128
+
129
+ def _is_valid_polymer(self, mol) -> bool:
130
+ num_atoms = mol.GetNumAtoms()
131
+ num_rings = rdMolDescriptors.CalcNumRings(mol)
132
+ return num_atoms > 10 or num_rings > 1
133
+
134
+ # ---------- MOLECULAR GRAPH GENERATION ----------
135
+ def generate_molecular_graph(self, smiles: str) -> Dict:
136
+ mol = Chem.MolFromSmiles(smiles)
137
+ mol = process_star_atoms(mol) # Ensure no stars left
138
+ if mol is None:
139
+ return {}
140
+
141
+ mol = Chem.AddHs(mol) # Explicit hydrogens (unchanged)
142
+
143
+ node_features = []
144
+ for atom in mol.GetAtoms():
145
+ node_features.append({
146
+ 'atomic_num': atom.GetAtomicNum(),
147
+ 'degree': atom.GetDegree(),
148
+ 'formal_charge': atom.GetFormalCharge(),
149
+ 'hybridization': int(atom.GetHybridization()),
150
+ 'is_aromatic': atom.GetIsAromatic(),
151
+ 'is_in_ring': atom.IsInRing(),
152
+ 'chirality': int(atom.GetChiralTag()),
153
+ 'mass': atom.GetMass(),
154
+ 'valence': atom.GetTotalValence(),
155
+ 'num_radical_electrons': atom.GetNumRadicalElectrons()
156
+ })
157
+ edge_features = []
158
+ edge_indices = []
159
+ for bond in mol.GetBonds():
160
+ start_atom = bond.GetBeginAtomIdx()
161
+ end_atom = bond.GetEndAtomIdx()
162
+ edge_features.append({
163
+ 'bond_type': int(bond.GetBondType()),
164
+ 'is_aromatic': bond.GetIsAromatic(),
165
+ 'is_in_ring': bond.IsInRing(),
166
+ 'stereo': int(bond.GetStereo()),
167
+ 'is_conjugated': bond.GetIsConjugated()
168
+ })
169
+ edge_indices.extend([[start_atom, end_atom], [end_atom, start_atom]])
170
+ graph_features = {
171
+ 'num_atoms': mol.GetNumAtoms(),
172
+ 'num_bonds': mol.GetNumBonds(),
173
+ 'num_rings': rdMolDescriptors.CalcNumRings(mol),
174
+ 'molecular_weight': Descriptors.MolWt(mol),
175
+ 'logp': Crippen.MolLogP(mol),
176
+ 'tpsa': Descriptors.TPSA(mol),
177
+ 'num_rotatable_bonds': Descriptors.NumRotatableBonds(mol),
178
+ 'num_h_acceptors': rdMolDescriptors.CalcNumHBA(mol),
179
+ 'num_h_donors': rdMolDescriptors.CalcNumHBD(mol)
180
+ }
181
+ adj = Chem.GetAdjacencyMatrix(mol).tolist()
182
+ return {
183
+ 'node_features': node_features,
184
+ 'edge_features': edge_features,
185
+ 'edge_indices': edge_indices,
186
+ 'graph_features': graph_features,
187
+ 'adjacency_matrix': adj
188
+ }
189
+
190
+ # ---------- 3-D GEOMETRY ----------
191
+ def optimize_3d_geometry(self, smiles: str, num_conformers: int = 10) -> Dict:
192
+ mol = Chem.MolFromSmiles(smiles)
193
+ if mol is None or mol.GetNumAtoms() > 200:
194
+ return {}
195
+ mol = process_star_atoms(mol)
196
+ mol_h = Chem.AddHs(mol) # explicit hydrogens
197
+
198
+ # Collect atomic numbers (matches the order in coordinates)
199
+ atomic_numbers = [atom.GetAtomicNum() for atom in mol_h.GetAtoms()]
200
+
201
+ try:
202
+ params = AllChem.ETKDGv3()
203
+ params.randomSeed = 42
204
+ conformer_ids = AllChem.EmbedMultipleConfs(mol_h, numConfs=num_conformers, params=params)
205
+ except Exception:
206
+ conformer_ids = []
207
+
208
+ best_conformer = None
209
+ best_energy = float('inf')
210
+
211
+ for conf_id in conformer_ids:
212
+ try:
213
+ mmff_ok = AllChem.MMFFHasAllMoleculeParams(mol_h)
214
+ if mmff_ok:
215
+ AllChem.MMFFOptimizeMolecule(mol_h, confId=conf_id)
216
+ props = AllChem.MMFFGetMoleculeProperties(mol_h)
217
+ ff = AllChem.MMFFGetMoleculeForceField(mol_h, props, confId=conf_id)
218
+ energy = ff.CalcEnergy() if ff is not None else None
219
+ else:
220
+ AllChem.UFFOptimizeMolecule(mol_h, confId=conf_id)
221
+ ff = AllChem.UFFGetMoleculeForceField(mol_h, confId=conf_id)
222
+ energy = ff.CalcEnergy() if ff is not None else None
223
+
224
+ if energy is not None and energy < best_energy:
225
+ conf = mol_h.GetConformer(conf_id)
226
+ coords = [
227
+ [conf.GetAtomPosition(i).x,
228
+ conf.GetAtomPosition(i).y,
229
+ conf.GetAtomPosition(i).z]
230
+ for i in range(mol_h.GetNumAtoms())
231
+ ]
232
+ descriptors_3d = {}
233
+ try:
234
+ descriptors_3d = {
235
+ 'asphericity': Descriptors3D.Asphericity(mol_h, confId=conf_id),
236
+ 'eccentricity': Descriptors3D.Eccentricity(mol_h, confId=conf_id),
237
+ 'inertial_shape_factor': Descriptors3D.InertialShapeFactor(mol_h, confId=conf_id),
238
+ 'radius_of_gyration': Descriptors3D.RadiusOfGyration(mol_h, confId=conf_id),
239
+ 'spherocity_index': Descriptors3D.SpherocityIndex(mol_h, confId=conf_id)
240
+ }
241
+ except Exception:
242
+ pass
243
+
244
+ best_conformer = {
245
+ 'conformer_id': conf_id,
246
+ 'coordinates': coords,
247
+ 'atomic_numbers': atomic_numbers,
248
+ 'energy': energy,
249
+ 'descriptors_3d': descriptors_3d
250
+ }
251
+ best_energy = energy
252
+
253
+ except Exception:
254
+ continue
255
+
256
+ if best_conformer is not None:
257
+ return {
258
+ 'best_conformer': best_conformer,
259
+ 'num_conformers_generated': len(conformer_ids),
260
+ 'converted_smiles': Chem.MolToSmiles(mol)
261
+ }
262
+
263
+ # Fallback 2-D coordinates
264
+ try:
265
+ rdDepictor.Compute2DCoords(mol)
266
+ coords_2d = mol.GetConformer().GetPositions().tolist()
267
+ # match the atomic_numbers to 2D duplicate (should have same order)
268
+ atomic_numbers_2d = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
269
+ return {
270
+ 'best_conformer': {
271
+ 'conformer_id': -1,
272
+ 'coordinates': coords_2d,
273
+ 'atomic_numbers': atomic_numbers_2d,
274
+ 'energy': None,
275
+ 'descriptors_3d': {},
276
+ },
277
+ 'num_conformers_generated': 0,
278
+ 'converted_smiles': Chem.MolToSmiles(mol)
279
+ }
280
+ except Exception:
281
+ return {}
282
+
283
+ # ---------- MORGAN FINGERPRINTS ----------
284
+ def calculate_morgan_fingerprints(self, smiles: str, radius: int = 3, n_bits: int = 2048) -> Dict:
285
+ mol = Chem.MolFromSmiles(smiles)
286
+ mol = process_star_atoms(mol)
287
+ if mol is None:
288
+ return {}
289
+ generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits)
290
+ fp_bitvect = generator.GetFingerprint(mol)
291
+ fingerprints = {
292
+ f'morgan_r{radius}_bits': list(fp_bitvect.ToBitString()),
293
+ f'morgan_r{radius}_counts': dict(AllChem.GetMorganFingerprint(mol, radius).GetNonzeroElements()),
294
+ }
295
+ # Extended multi-radius support
296
+ for r in range(1, radius):
297
+ gen = rdFingerprintGenerator.GetMorganGenerator(radius=r, fpSize=n_bits)
298
+ bitvect = gen.GetFingerprint(mol)
299
+ fingerprints[f'morgan_r{r}_bits'] = list(bitvect.ToBitString())
300
+ counts = AllChem.GetMorganFingerprint(mol, r).GetNonzeroElements()
301
+ fingerprints[f'morgan_r{r}_counts'] = dict(counts)
302
+ return fingerprints
303
+
304
+ # ---------- PARALLEL PROCESSING ----------
305
+ def process_all_polymers_parallel(self, chunk_size: int = 100, num_workers: int = 40):
306
+ chunk_iterator = pd.read_csv(self.csv_file, chunksize=chunk_size, engine='python')
307
+
308
+ for chunk in chunk_iterator:
309
+ for col in ['graph', 'geometry', 'fingerprints']:
310
+ if col not in chunk.columns:
311
+ chunk[col] = None
312
+ chunk[col] = chunk[col].astype(object)
313
+
314
+ chunk_to_process = chunk[
315
+ chunk[['graph', 'geometry', 'fingerprints']].isnull().any(axis=1)
316
+ ].copy()
317
+ if len(chunk_to_process) == 0:
318
+ self.save_chunk_to_csv(chunk)
319
+ continue
320
+
321
+ rows = list(chunk_to_process.iterrows())
322
+ argslist = [(i, row.to_dict(), self) for i, row in rows]
323
+ with mp.Pool(num_workers) as pool:
324
+ results = pool.map(process_single_polymer, argslist)
325
+
326
+ failed_list = []
327
+ for n, (output, fail) in enumerate(results):
328
+ idx = rows[n][0]
329
+ if output:
330
+ chunk.at[idx, 'graph'] = json.dumps(output['graph'])
331
+ chunk.at[idx, 'geometry'] = json.dumps(output['geometry'])
332
+ chunk.at[idx, 'fingerprints'] = json.dumps(output['fingerprints'])
333
+ if fail:
334
+ failed_list.append(fail)
335
+
336
+ self.save_chunk_to_csv(chunk)
337
+ self.save_failed_to_json(failed_list)
338
+
339
+ return "Processing Done"
340
+
341
+ # ---------- SAVE HELPERS ----------
342
+ def save_chunk_to_csv(self, chunk):
343
+ out_csv = self.csv_file.replace('.csv', '_processed.csv')
344
+ if not os.path.exists(out_csv):
345
+ chunk.to_csv(out_csv, index=False, mode='w')
346
+ else:
347
+ chunk.to_csv(out_csv, index=False, mode='a', header=False)
348
+
349
+ def save_failed_to_json(self, failed_list):
350
+ if not failed_list:
351
+ return
352
+ fail_json = self.csv_file.replace('.csv', '_failures.jsonl')
353
+ with open(fail_json, 'a') as f:
354
+ for fail in failed_list:
355
+ json.dump(fail, f)
356
+ f.write('\n')
357
+
358
+ # ---------- OPTIONAL RESULT SAVER (stub) ----------
359
+ def save_results(self, output_file: str = 'polymer_multimodal_data.json'):
360
+ pass
361
+
362
+ # ---------- OPTIONAL SUMMARY (stub) ----------
363
+ def generate_summary_statistics(self) -> Dict:
364
+ return {}
365
+
366
+ # ----------------------------------------------------------------------
367
+ # -------------- SCRIPT ENTRY POINT ------------------------------------
368
+ # ----------------------------------------------------------------------
369
+
370
+ def main():
371
+ csv_file = "Polymer_Foundational_Model/polymer_structures_unified.csv"
372
+ extractor = AdvancedPolymerMultimodalExtractor(csv_file)
373
+ try:
374
+ extractor.process_all_polymers_parallel(chunk_size=1000, num_workers=24)
375
+ except KeyboardInterrupt:
376
+ return extractor, None
377
+ except Exception as e:
378
+ print(f"CRASH! Error: {e}")
379
+ return extractor, None
380
+ print("\n=== Processing Complete ===")
381
+ return extractor, None
382
+
383
+ if __name__ == "__main__":
384
+ extractor, results = main()