manpreet88 commited on
Commit
4b734a1
·
1 Parent(s): 9ae5583

Update Data_Modalities.py

Browse files
Files changed (1) hide show
  1. Data_Modalities.py +293 -185
Data_Modalities.py CHANGED
@@ -2,202 +2,253 @@ import pandas as pd
2
  import numpy as np
3
  from rdkit import Chem
4
  from rdkit.Chem import AllChem, Descriptors, rdMolDescriptors, rdDepictor
5
- from rdkit.Chem.Draw import rdMolDraw2D
6
  from rdkit.Chem import Crippen, Descriptors3D
7
- from rdkit.Chem.Scaffolds import MurckoScaffold
8
  from rdkit.Chem import rdFingerprintGenerator
9
- import networkx as nx
10
- import requests
11
- from pathlib import Path
12
- import argparse
13
- import time
14
- import json
15
- from typing import List, Dict, Tuple, Optional
16
  import warnings
17
- warnings.filterwarnings('ignore')
18
  from rdkit import RDLogger
19
- RDLogger.DisableLog('rdApp.*')
 
20
  import os
 
 
21
  import multiprocessing as mp
 
 
22
 
23
  # ----------------------------------------------------------------------
24
- # -------------- STAR ATOM HANDLING (robust original style) ------------
25
  # ----------------------------------------------------------------------
26
 
27
- def process_star_atoms(mol):
 
 
 
 
 
 
 
 
 
 
28
  """
29
- Replace all wildcard atoms (* or atomicNum == 0) with **Astatine (At)**.
30
- Astatine has atomic number 85.
 
 
 
 
 
31
  """
32
- ATOMIC_NUM_AT = 85 # Astatine
 
33
 
34
  for atom in mol.GetAtoms():
35
- if atom.GetAtomicNum() == 0:
36
  atom.SetAtomicNum(ATOMIC_NUM_AT)
37
- for atom in mol.GetAtoms():
38
- if atom.GetSymbol() == "*":
39
- atom.SetAtomicNum(ATOMIC_NUM_AT)
40
-
41
  return mol
42
 
 
43
  # ----------------------------------------------------------------------
44
- # -------------- SINGLE POLYMER PROCESSING -----------------------------
45
  # ----------------------------------------------------------------------
46
 
47
- def process_single_polymer(args):
 
 
 
 
 
 
 
48
  idx, row_dict, extractor = args
49
  polymer_data = None
50
  failed_info = None
 
51
  try:
52
- smiles = row_dict['psmiles']
53
- source = row_dict['source']
54
 
55
  if pd.isna(smiles) or not isinstance(smiles, str) or len(smiles.strip()) == 0:
56
- failed_info = {'index': idx, 'smiles': str(smiles), 'error': 'Empty or invalid SMILES'}
57
  return polymer_data, failed_info
58
 
59
  canonical_smiles = extractor.validate_and_standardize_smiles(smiles)
60
  if canonical_smiles is None:
61
- failed_info = {'index': idx, 'smiles': smiles, 'error': 'Invalid SMILES or contains wildcards'}
62
  return polymer_data, failed_info
63
 
64
  polymer_data = {
65
- 'original_index': idx,
66
- 'psmiles': canonical_smiles,
67
- 'source': source,
68
- 'smiles': canonical_smiles
69
  }
70
 
 
71
  try:
72
- graph_data = extractor.generate_molecular_graph(canonical_smiles)
73
- polymer_data['graph'] = graph_data
74
  except Exception:
75
- polymer_data['graph'] = {}
76
 
 
77
  try:
78
- geometry_data = extractor.optimize_3d_geometry(canonical_smiles)
79
- polymer_data['geometry'] = geometry_data
80
  except Exception:
81
- polymer_data['geometry'] = {}
82
 
 
83
  try:
84
- fingerprint_data = extractor.calculate_morgan_fingerprints(canonical_smiles)
85
- polymer_data['fingerprints'] = fingerprint_data
86
  except Exception:
87
- polymer_data['fingerprints'] = {}
88
 
89
  return polymer_data, failed_info
90
 
91
  except Exception as e:
92
- failed_info = {'index': idx, 'smiles': row_dict.get('psmiles', ''), 'error': str(e)}
93
  return polymer_data, failed_info
94
 
 
95
  # ----------------------------------------------------------------------
96
- # -------------- MAIN MULTIMODAL EXTRACTOR -----------------------------
97
  # ----------------------------------------------------------------------
98
 
99
  class AdvancedPolymerMultimodalExtractor:
 
 
 
 
 
 
 
 
 
 
 
100
  def __init__(self, csv_file: str):
101
- self.csv_file = csv_file
102
 
103
- # ---------- SMILES VALIDATION & STANDARDIZATION ----------
 
 
104
  def validate_and_standardize_smiles(self, smiles: str) -> Optional[str]:
 
 
 
 
105
  try:
106
  if not smiles or pd.isna(smiles):
107
  return None
 
108
  mol = Chem.MolFromSmiles(smiles, sanitize=False)
109
- mol = process_star_atoms(mol) # CONVERT * to Astatine (At)
 
 
 
110
  Chem.SanitizeMol(mol)
111
- mol = process_star_atoms(mol) # SECOND PASS (robust)
112
  canonical_smiles = Chem.MolToSmiles(mol, canonical=True)
113
- if len(canonical_smiles) == 0:
 
114
  return None
115
  return canonical_smiles
 
116
  except Exception:
117
  return None
118
 
119
- # ---------- POLYMER VALIDITY CHECKS ----------
120
- def _has_invalid_polymer_features(self, mol) -> bool:
121
- try:
122
- if mol.GetNumAtoms() > 200:
123
- return True
124
- for atom in mol.GetAtoms():
125
- if atom.GetFormalCharge() > 5 or atom.GetFormalCharge() < -5:
126
- return True
127
- return False
128
- except:
129
- return True
130
-
131
- def _is_valid_polymer(self, mol) -> bool:
132
- num_atoms = mol.GetNumAtoms()
133
- num_rings = rdMolDescriptors.CalcNumRings(mol)
134
- return num_atoms > 10 or num_rings > 1
135
-
136
- # ---------- MOLECULAR GRAPH GENERATION ----------
137
  def generate_molecular_graph(self, smiles: str) -> Dict:
 
 
 
 
138
  mol = Chem.MolFromSmiles(smiles)
139
- mol = process_star_atoms(mol) # Ensure no stars left
140
  if mol is None:
141
  return {}
142
 
143
- mol = Chem.AddHs(mol) # Explicit hydrogens (unchanged)
 
144
 
145
  node_features = []
146
  for atom in mol.GetAtoms():
147
- node_features.append({
148
- 'atomic_num': atom.GetAtomicNum(),
149
- 'degree': atom.GetDegree(),
150
- 'formal_charge': atom.GetFormalCharge(),
151
- 'hybridization': int(atom.GetHybridization()),
152
- 'is_aromatic': atom.GetIsAromatic(),
153
- 'is_in_ring': atom.IsInRing(),
154
- 'chirality': int(atom.GetChiralTag()),
155
- 'mass': atom.GetMass(),
156
- 'valence': atom.GetTotalValence(),
157
- 'num_radical_electrons': atom.GetNumRadicalElectrons()
158
- })
 
 
 
159
  edge_features = []
160
  edge_indices = []
161
  for bond in mol.GetBonds():
162
- start_atom = bond.GetBeginAtomIdx()
163
- end_atom = bond.GetEndAtomIdx()
164
- edge_features.append({
165
- 'bond_type': int(bond.GetBondType()),
166
- 'is_aromatic': bond.GetIsAromatic(),
167
- 'is_in_ring': bond.IsInRing(),
168
- 'stereo': int(bond.GetStereo()),
169
- 'is_conjugated': bond.GetIsConjugated()
170
- })
171
- edge_indices.extend([[start_atom, end_atom], [end_atom, start_atom]])
 
 
 
 
 
 
172
  graph_features = {
173
- 'num_atoms': mol.GetNumAtoms(),
174
- 'num_bonds': mol.GetNumBonds(),
175
- 'num_rings': rdMolDescriptors.CalcNumRings(mol),
176
- 'molecular_weight': Descriptors.MolWt(mol),
177
- 'logp': Crippen.MolLogP(mol),
178
- 'tpsa': Descriptors.TPSA(mol),
179
- 'num_rotatable_bonds': Descriptors.NumRotatableBonds(mol),
180
- 'num_h_acceptors': rdMolDescriptors.CalcNumHBA(mol),
181
- 'num_h_donors': rdMolDescriptors.CalcNumHBD(mol)
182
  }
 
183
  adj = Chem.GetAdjacencyMatrix(mol).tolist()
 
184
  return {
185
- 'node_features': node_features,
186
- 'edge_features': edge_features,
187
- 'edge_indices': edge_indices,
188
- 'graph_features': graph_features,
189
- 'adjacency_matrix': adj
190
  }
191
 
192
- # ---------- 3-D GEOMETRY ----------
 
 
193
  def optimize_3d_geometry(self, smiles: str, num_conformers: int = 10) -> Dict:
 
 
 
 
 
 
194
  mol = Chem.MolFromSmiles(smiles)
195
  if mol is None or mol.GetNumAtoms() > 200:
196
  return {}
 
197
  mol = process_star_atoms(mol)
198
- mol_h = Chem.AddHs(mol) # explicit hydrogens
199
 
200
- # Collect atomic numbers (matches the order in coordinates)
201
  atomic_numbers = [atom.GetAtomicNum() for atom in mol_h.GetAtoms()]
202
 
203
  try:
@@ -208,120 +259,142 @@ class AdvancedPolymerMultimodalExtractor:
208
  conformer_ids = []
209
 
210
  best_conformer = None
211
- best_energy = float('inf')
212
 
213
  for conf_id in conformer_ids:
214
  try:
215
  mmff_ok = AllChem.MMFFHasAllMoleculeParams(mol_h)
 
216
  if mmff_ok:
217
  AllChem.MMFFOptimizeMolecule(mol_h, confId=conf_id)
218
  props = AllChem.MMFFGetMoleculeProperties(mol_h)
219
  ff = AllChem.MMFFGetMoleculeForceField(mol_h, props, confId=conf_id)
220
- energy = ff.CalcEnergy() if ff is not None else None
221
  else:
222
  AllChem.UFFOptimizeMolecule(mol_h, confId=conf_id)
223
  ff = AllChem.UFFGetMoleculeForceField(mol_h, confId=conf_id)
224
- energy = ff.CalcEnergy() if ff is not None else None
225
-
226
- if energy is not None and energy < best_energy:
227
- conf = mol_h.GetConformer(conf_id)
228
- coords = [
229
- [conf.GetAtomPosition(i).x,
230
- conf.GetAtomPosition(i).y,
231
- conf.GetAtomPosition(i).z]
232
- for i in range(mol_h.GetNumAtoms())
233
- ]
234
- descriptors_3d = {}
235
- try:
236
- descriptors_3d = {
237
- 'asphericity': Descriptors3D.Asphericity(mol_h, confId=conf_id),
238
- 'eccentricity': Descriptors3D.Eccentricity(mol_h, confId=conf_id),
239
- 'inertial_shape_factor': Descriptors3D.InertialShapeFactor(mol_h, confId=conf_id),
240
- 'radius_of_gyration': Descriptors3D.RadiusOfGyration(mol_h, confId=conf_id),
241
- 'spherocity_index': Descriptors3D.SpherocityIndex(mol_h, confId=conf_id)
242
- }
243
- except Exception:
244
- pass
245
-
246
- best_conformer = {
247
- 'conformer_id': conf_id,
248
- 'coordinates': coords,
249
- 'atomic_numbers': atomic_numbers,
250
- 'energy': energy,
251
- 'descriptors_3d': descriptors_3d
252
  }
253
- best_energy = energy
 
 
 
 
 
 
 
 
 
 
254
 
255
  except Exception:
256
  continue
257
 
258
  if best_conformer is not None:
259
  return {
260
- 'best_conformer': best_conformer,
261
- 'num_conformers_generated': len(conformer_ids),
262
- 'converted_smiles': Chem.MolToSmiles(mol)
263
  }
264
 
265
- # Fallback 2-D coordinates
266
  try:
267
  rdDepictor.Compute2DCoords(mol)
268
  coords_2d = mol.GetConformer().GetPositions().tolist()
269
- # match the atomic_numbers to 2D duplicate (should have same order)
270
  atomic_numbers_2d = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
271
  return {
272
- 'best_conformer': {
273
- 'conformer_id': -1,
274
- 'coordinates': coords_2d,
275
- 'atomic_numbers': atomic_numbers_2d,
276
- 'energy': None,
277
- 'descriptors_3d': {},
278
  },
279
- 'num_conformers_generated': 0,
280
- 'converted_smiles': Chem.MolToSmiles(mol)
281
  }
282
  except Exception:
283
  return {}
284
 
285
- # ---------- MORGAN FINGERPRINTS ----------
 
 
286
  def calculate_morgan_fingerprints(self, smiles: str, radius: int = 3, n_bits: int = 2048) -> Dict:
 
 
 
 
 
 
287
  mol = Chem.MolFromSmiles(smiles)
288
  mol = process_star_atoms(mol)
289
  if mol is None:
290
  return {}
 
 
 
 
291
  generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits)
292
  fp_bitvect = generator.GetFingerprint(mol)
293
- fingerprints = {
294
- f'morgan_r{radius}_bits': list(fp_bitvect.ToBitString()),
295
- f'morgan_r{radius}_counts': dict(AllChem.GetMorganFingerprint(mol, radius).GetNonzeroElements()),
296
- }
297
- # Extended multi-radius support
298
  for r in range(1, radius):
299
  gen = rdFingerprintGenerator.GetMorganGenerator(radius=r, fpSize=n_bits)
300
  bitvect = gen.GetFingerprint(mol)
301
- fingerprints[f'morgan_r{r}_bits'] = list(bitvect.ToBitString())
302
- counts = AllChem.GetMorganFingerprint(mol, r).GetNonzeroElements()
303
- fingerprints[f'morgan_r{r}_counts'] = dict(counts)
304
  return fingerprints
305
 
306
- # ---------- PARALLEL PROCESSING ----------
307
- def process_all_polymers_parallel(self, chunk_size: int = 100, num_workers: int = 40):
308
- chunk_iterator = pd.read_csv(self.csv_file, chunksize=chunk_size, engine='python')
 
 
 
 
 
 
 
 
309
 
310
  for chunk in chunk_iterator:
311
- for col in ['graph', 'geometry', 'fingerprints']:
 
312
  if col not in chunk.columns:
313
  chunk[col] = None
314
  chunk[col] = chunk[col].astype(object)
315
 
316
- chunk_to_process = chunk[
317
- chunk[['graph', 'geometry', 'fingerprints']].isnull().any(axis=1)
318
- ].copy()
 
319
  if len(chunk_to_process) == 0:
320
  self.save_chunk_to_csv(chunk)
321
  continue
322
 
323
  rows = list(chunk_to_process.iterrows())
324
  argslist = [(i, row.to_dict(), self) for i, row in rows]
 
325
  with mp.Pool(num_workers) as pool:
326
  results = pool.map(process_single_polymer, argslist)
327
 
@@ -329,9 +402,9 @@ class AdvancedPolymerMultimodalExtractor:
329
  for n, (output, fail) in enumerate(results):
330
  idx = rows[n][0]
331
  if output:
332
- chunk.at[idx, 'graph'] = json.dumps(output['graph'])
333
- chunk.at[idx, 'geometry'] = json.dumps(output['geometry'])
334
- chunk.at[idx, 'fingerprints'] = json.dumps(output['fingerprints'])
335
  if fail:
336
  failed_list.append(fail)
337
 
@@ -340,47 +413,82 @@ class AdvancedPolymerMultimodalExtractor:
340
 
341
  return "Processing Done"
342
 
343
- # ---------- SAVE HELPERS ----------
344
- def save_chunk_to_csv(self, chunk):
345
- out_csv = self.csv_file.replace('.csv', '_processed.csv')
 
 
 
 
 
346
  if not os.path.exists(out_csv):
347
- chunk.to_csv(out_csv, index=False, mode='w')
348
  else:
349
- chunk.to_csv(out_csv, index=False, mode='a', header=False)
350
 
351
- def save_failed_to_json(self, failed_list):
 
 
 
352
  if not failed_list:
353
  return
354
- fail_json = self.csv_file.replace('.csv', '_failures.jsonl')
355
- with open(fail_json, 'a') as f:
356
  for fail in failed_list:
357
  json.dump(fail, f)
358
- f.write('\n')
359
 
360
- # ---------- OPTIONAL RESULT SAVER (stub) ----------
361
- def save_results(self, output_file: str = 'polymer_multimodal_data.json'):
362
  pass
363
 
364
- # ---------- OPTIONAL SUMMARY (stub) ----------
365
  def generate_summary_statistics(self) -> Dict:
366
  return {}
367
 
 
368
  # ----------------------------------------------------------------------
369
- # -------------- SCRIPT ENTRY POINT ------------------------------------
370
  # ----------------------------------------------------------------------
371
 
372
- def main():
373
- csv_file = "Polymer_Foundational_Model/polymer_structures_unified.csv"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
  extractor = AdvancedPolymerMultimodalExtractor(csv_file)
375
  try:
376
- extractor.process_all_polymers_parallel(chunk_size=1000, num_workers=24)
377
  except KeyboardInterrupt:
378
  return extractor, None
379
  except Exception as e:
380
  print(f"CRASH! Error: {e}")
381
  return extractor, None
 
382
  print("\n=== Processing Complete ===")
383
  return extractor, None
384
 
 
385
  if __name__ == "__main__":
386
  extractor, results = main()
 
2
  import numpy as np
3
  from rdkit import Chem
4
  from rdkit.Chem import AllChem, Descriptors, rdMolDescriptors, rdDepictor
 
5
  from rdkit.Chem import Crippen, Descriptors3D
 
6
  from rdkit.Chem import rdFingerprintGenerator
 
 
 
 
 
 
 
7
  import warnings
8
+ warnings.filterwarnings("ignore")
9
  from rdkit import RDLogger
10
+ RDLogger.DisableLog("rdApp.*")
11
+
12
  import os
13
+ import json
14
+ import argparse
15
  import multiprocessing as mp
16
+ from pathlib import Path
17
+ from typing import Dict, Optional, Tuple
18
 
19
  # ----------------------------------------------------------------------
20
+ # Logging / RDKit hygiene
21
  # ----------------------------------------------------------------------
22
 
23
+ # RDKit can be chatty; we silence logs above via RDLogger.DisableLog.
24
+ # We also suppress Python warnings (set above).
25
+
26
+ # ----------------------------------------------------------------------
27
+ # Wildcard ("*") handling utilities
28
+ # ----------------------------------------------------------------------
29
+
30
+ ATOMIC_NUM_AT = 85 # Astatine (At) used as a placeholder for wildcard atoms
31
+
32
+
33
+ def process_star_atoms(mol: Chem.Mol) -> Chem.Mol:
34
  """
35
+ Replace all wildcard atoms ("*" or atomicNum == 0) with Astatine (At, Z=85).
36
+
37
+ Rationale:
38
+ - Polymer SMILES often contain '*' to indicate attachment points.
39
+ - Many RDKit operations fail or sanitize differently with atomicNum == 0.
40
+ - Mapping '*' -> At allows sanitization and downstream featurization while
41
+ keeping a consistent placeholder identity.
42
  """
43
+ if mol is None:
44
+ return mol
45
 
46
  for atom in mol.GetAtoms():
47
+ if atom.GetAtomicNum() == 0 or atom.GetSymbol() == "*":
48
  atom.SetAtomicNum(ATOMIC_NUM_AT)
 
 
 
 
49
  return mol
50
 
51
+
52
  # ----------------------------------------------------------------------
53
+ # Per-polymer worker function
54
  # ----------------------------------------------------------------------
55
 
56
+ def process_single_polymer(args) -> Tuple[Optional[Dict], Optional[Dict]]:
57
+ """
58
+ Worker that processes one row (one polymer) and returns:
59
+ (polymer_data, failed_info)
60
+
61
+ polymer_data is a dict containing serialized multimodal outputs.
62
+ failed_info is a dict with index/smiles/error if anything fails.
63
+ """
64
  idx, row_dict, extractor = args
65
  polymer_data = None
66
  failed_info = None
67
+
68
  try:
69
+ smiles = row_dict.get("psmiles", None)
70
+ source = row_dict.get("source", None)
71
 
72
  if pd.isna(smiles) or not isinstance(smiles, str) or len(smiles.strip()) == 0:
73
+ failed_info = {"index": idx, "smiles": str(smiles), "error": "Empty or invalid SMILES"}
74
  return polymer_data, failed_info
75
 
76
  canonical_smiles = extractor.validate_and_standardize_smiles(smiles)
77
  if canonical_smiles is None:
78
+ failed_info = {"index": idx, "smiles": smiles, "error": "Invalid SMILES or cannot be standardized"}
79
  return polymer_data, failed_info
80
 
81
  polymer_data = {
82
+ "original_index": idx,
83
+ "psmiles": canonical_smiles,
84
+ "source": source,
85
+ "smiles": canonical_smiles,
86
  }
87
 
88
+ # Graph
89
  try:
90
+ polymer_data["graph"] = extractor.generate_molecular_graph(canonical_smiles)
 
91
  except Exception:
92
+ polymer_data["graph"] = {}
93
 
94
+ # Geometry
95
  try:
96
+ polymer_data["geometry"] = extractor.optimize_3d_geometry(canonical_smiles)
 
97
  except Exception:
98
+ polymer_data["geometry"] = {}
99
 
100
+ # Fingerprints
101
  try:
102
+ polymer_data["fingerprints"] = extractor.calculate_morgan_fingerprints(canonical_smiles)
 
103
  except Exception:
104
+ polymer_data["fingerprints"] = {}
105
 
106
  return polymer_data, failed_info
107
 
108
  except Exception as e:
109
+ failed_info = {"index": idx, "smiles": row_dict.get("psmiles", ""), "error": str(e)}
110
  return polymer_data, failed_info
111
 
112
+
113
  # ----------------------------------------------------------------------
114
+ # Main extractor class
115
  # ----------------------------------------------------------------------
116
 
117
  class AdvancedPolymerMultimodalExtractor:
118
+ """
119
+ Multimodal extractor that reads a CSV of polymers and adds:
120
+ - graph: node/edge features + adjacency + summary graph features
121
+ - geometry: best 3D conformer (or fallback 2D coords) + 3D descriptors
122
+ - fingerprints: Morgan fingerprints (bitstrings + counts) for multiple radii
123
+
124
+ Output:
125
+ - <input>_processed.csv (appended chunk-by-chunk)
126
+ - <input>_failures.jsonl (one JSON per failure)
127
+ """
128
+
129
  def __init__(self, csv_file: str):
130
+ self.csv_file = str(csv_file)
131
 
132
+ # ------------------------------
133
+ # SMILES validation/standardization
134
+ # ------------------------------
135
  def validate_and_standardize_smiles(self, smiles: str) -> Optional[str]:
136
+ """
137
+ Parse, sanitize, replace '*' with At, and return canonical SMILES.
138
+ Returns None if parsing/sanitization fails.
139
+ """
140
  try:
141
  if not smiles or pd.isna(smiles):
142
  return None
143
+
144
  mol = Chem.MolFromSmiles(smiles, sanitize=False)
145
+ if mol is None:
146
+ return None
147
+
148
+ mol = process_star_atoms(mol) # pass 1
149
  Chem.SanitizeMol(mol)
150
+ mol = process_star_atoms(mol) # pass 2 (robust)
151
  canonical_smiles = Chem.MolToSmiles(mol, canonical=True)
152
+
153
+ if not canonical_smiles:
154
  return None
155
  return canonical_smiles
156
+
157
  except Exception:
158
  return None
159
 
160
+ # ------------------------------
161
+ # Molecular graph (RDKit -> JSONable dict)
162
+ # ------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  def generate_molecular_graph(self, smiles: str) -> Dict:
164
+ """
165
+ Build a molecular graph representation with atom/bond features and
166
+ global graph descriptors.
167
+ """
168
  mol = Chem.MolFromSmiles(smiles)
169
+ mol = process_star_atoms(mol)
170
  if mol is None:
171
  return {}
172
 
173
+ # Explicit hydrogens for atom-level features
174
+ mol = Chem.AddHs(mol)
175
 
176
  node_features = []
177
  for atom in mol.GetAtoms():
178
+ node_features.append(
179
+ {
180
+ "atomic_num": atom.GetAtomicNum(),
181
+ "degree": atom.GetDegree(),
182
+ "formal_charge": atom.GetFormalCharge(),
183
+ "hybridization": int(atom.GetHybridization()),
184
+ "is_aromatic": atom.GetIsAromatic(),
185
+ "is_in_ring": atom.IsInRing(),
186
+ "chirality": int(atom.GetChiralTag()),
187
+ "mass": atom.GetMass(),
188
+ "valence": atom.GetTotalValence(),
189
+ "num_radical_electrons": atom.GetNumRadicalElectrons(),
190
+ }
191
+ )
192
+
193
  edge_features = []
194
  edge_indices = []
195
  for bond in mol.GetBonds():
196
+ i = bond.GetBeginAtomIdx()
197
+ j = bond.GetEndAtomIdx()
198
+
199
+ edge_features.append(
200
+ {
201
+ "bond_type": int(bond.GetBondType()),
202
+ "is_aromatic": bond.GetIsAromatic(),
203
+ "is_in_ring": bond.IsInRing(),
204
+ "stereo": int(bond.GetStereo()),
205
+ "is_conjugated": bond.GetIsConjugated(),
206
+ }
207
+ )
208
+
209
+ # Undirected -> store both directions for GNN-style edge lists
210
+ edge_indices.extend([[i, j], [j, i]])
211
+
212
  graph_features = {
213
+ "num_atoms": mol.GetNumAtoms(),
214
+ "num_bonds": mol.GetNumBonds(),
215
+ "num_rings": rdMolDescriptors.CalcNumRings(mol),
216
+ "molecular_weight": Descriptors.MolWt(mol),
217
+ "logp": Crippen.MolLogP(mol),
218
+ "tpsa": Descriptors.TPSA(mol),
219
+ "num_rotatable_bonds": Descriptors.NumRotatableBonds(mol),
220
+ "num_h_acceptors": rdMolDescriptors.CalcNumHBA(mol),
221
+ "num_h_donors": rdMolDescriptors.CalcNumHBD(mol),
222
  }
223
+
224
  adj = Chem.GetAdjacencyMatrix(mol).tolist()
225
+
226
  return {
227
+ "node_features": node_features,
228
+ "edge_features": edge_features,
229
+ "edge_indices": edge_indices,
230
+ "graph_features": graph_features,
231
+ "adjacency_matrix": adj,
232
  }
233
 
234
+ # ------------------------------
235
+ # 3D geometry (ETKDG + MMFF/UFF) with fallback 2D coords
236
+ # ------------------------------
237
  def optimize_3d_geometry(self, smiles: str, num_conformers: int = 10) -> Dict:
238
+ """
239
+ Generate multiple conformers, optimize (MMFF if available else UFF),
240
+ and return the lowest-energy conformer coordinates + 3D descriptors.
241
+
242
+ If no conformer is generated/optimized, fall back to 2D coordinates.
243
+ """
244
  mol = Chem.MolFromSmiles(smiles)
245
  if mol is None or mol.GetNumAtoms() > 200:
246
  return {}
247
+
248
  mol = process_star_atoms(mol)
249
+ mol_h = Chem.AddHs(mol)
250
 
251
+ # Atomic numbers aligned to coordinate ordering (mol_h atoms)
252
  atomic_numbers = [atom.GetAtomicNum() for atom in mol_h.GetAtoms()]
253
 
254
  try:
 
259
  conformer_ids = []
260
 
261
  best_conformer = None
262
+ best_energy = float("inf")
263
 
264
  for conf_id in conformer_ids:
265
  try:
266
  mmff_ok = AllChem.MMFFHasAllMoleculeParams(mol_h)
267
+
268
  if mmff_ok:
269
  AllChem.MMFFOptimizeMolecule(mol_h, confId=conf_id)
270
  props = AllChem.MMFFGetMoleculeProperties(mol_h)
271
  ff = AllChem.MMFFGetMoleculeForceField(mol_h, props, confId=conf_id)
 
272
  else:
273
  AllChem.UFFOptimizeMolecule(mol_h, confId=conf_id)
274
  ff = AllChem.UFFGetMoleculeForceField(mol_h, confId=conf_id)
275
+
276
+ energy = ff.CalcEnergy() if ff is not None else None
277
+ if energy is None or energy >= best_energy:
278
+ continue
279
+
280
+ conf = mol_h.GetConformer(conf_id)
281
+ coords = [
282
+ [conf.GetAtomPosition(i).x, conf.GetAtomPosition(i).y, conf.GetAtomPosition(i).z]
283
+ for i in range(mol_h.GetNumAtoms())
284
+ ]
285
+
286
+ descriptors_3d = {}
287
+ try:
288
+ descriptors_3d = {
289
+ "asphericity": Descriptors3D.Asphericity(mol_h, confId=conf_id),
290
+ "eccentricity": Descriptors3D.Eccentricity(mol_h, confId=conf_id),
291
+ "inertial_shape_factor": Descriptors3D.InertialShapeFactor(mol_h, confId=conf_id),
292
+ "radius_of_gyration": Descriptors3D.RadiusOfGyration(mol_h, confId=conf_id),
293
+ "spherocity_index": Descriptors3D.SpherocityIndex(mol_h, confId=conf_id),
 
 
 
 
 
 
 
 
 
294
  }
295
+ except Exception:
296
+ pass
297
+
298
+ best_conformer = {
299
+ "conformer_id": int(conf_id),
300
+ "coordinates": coords,
301
+ "atomic_numbers": atomic_numbers,
302
+ "energy": float(energy),
303
+ "descriptors_3d": descriptors_3d,
304
+ }
305
+ best_energy = energy
306
 
307
  except Exception:
308
  continue
309
 
310
  if best_conformer is not None:
311
  return {
312
+ "best_conformer": best_conformer,
313
+ "num_conformers_generated": int(len(conformer_ids)),
314
+ "converted_smiles": Chem.MolToSmiles(mol),
315
  }
316
 
317
+ # Fallback: 2D coordinates
318
  try:
319
  rdDepictor.Compute2DCoords(mol)
320
  coords_2d = mol.GetConformer().GetPositions().tolist()
 
321
  atomic_numbers_2d = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
322
  return {
323
+ "best_conformer": {
324
+ "conformer_id": -1,
325
+ "coordinates": coords_2d,
326
+ "atomic_numbers": atomic_numbers_2d,
327
+ "energy": None,
328
+ "descriptors_3d": {},
329
  },
330
+ "num_conformers_generated": 0,
331
+ "converted_smiles": Chem.MolToSmiles(mol),
332
  }
333
  except Exception:
334
  return {}
335
 
336
+ # ------------------------------
337
+ # Morgan fingerprints (multi-radius)
338
+ # ------------------------------
339
  def calculate_morgan_fingerprints(self, smiles: str, radius: int = 3, n_bits: int = 2048) -> Dict:
340
+ """
341
+ Compute Morgan fingerprints:
342
+ - bitstring (as list of '0'/'1' chars) at radius=radius
343
+ - counts (as dict) at radius=radius
344
+ Also includes all radii r in [1, radius-1].
345
+ """
346
  mol = Chem.MolFromSmiles(smiles)
347
  mol = process_star_atoms(mol)
348
  if mol is None:
349
  return {}
350
+
351
+ fingerprints = {}
352
+
353
+ # Main radius
354
  generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits)
355
  fp_bitvect = generator.GetFingerprint(mol)
356
+ fingerprints[f"morgan_r{radius}_bits"] = list(fp_bitvect.ToBitString())
357
+ fingerprints[f"morgan_r{radius}_counts"] = dict(AllChem.GetMorganFingerprint(mol, radius).GetNonzeroElements())
358
+
359
+ # Additional radii
 
360
  for r in range(1, radius):
361
  gen = rdFingerprintGenerator.GetMorganGenerator(radius=r, fpSize=n_bits)
362
  bitvect = gen.GetFingerprint(mol)
363
+ fingerprints[f"morgan_r{r}_bits"] = list(bitvect.ToBitString())
364
+ fingerprints[f"morgan_r{r}_counts"] = dict(AllChem.GetMorganFingerprint(mol, r).GetNonzeroElements())
365
+
366
  return fingerprints
367
 
368
+ # ------------------------------
369
+ # Chunked parallel processing over CSV
370
+ # ------------------------------
371
+ def process_all_polymers_parallel(self, chunk_size: int = 100, num_workers: int = 40) -> str:
372
+ """
373
+ Read the input CSV in chunks, fill missing multimodal columns, and process
374
+ only rows that are missing any of: graph/geometry/fingerprints.
375
+
376
+ Appends processed chunks to <input>_processed.csv and failures to <input>_failures.jsonl.
377
+ """
378
+ chunk_iterator = pd.read_csv(self.csv_file, chunksize=chunk_size, engine="python")
379
 
380
  for chunk in chunk_iterator:
381
+ # Ensure expected output columns exist and are object dtype (for JSON strings)
382
+ for col in ["graph", "geometry", "fingerprints"]:
383
  if col not in chunk.columns:
384
  chunk[col] = None
385
  chunk[col] = chunk[col].astype(object)
386
 
387
+ # Only process rows missing any modality
388
+ chunk_to_process = chunk[chunk[["graph", "geometry", "fingerprints"]].isnull().any(axis=1)].copy()
389
+
390
+ # If all rows already done, just persist chunk and continue
391
  if len(chunk_to_process) == 0:
392
  self.save_chunk_to_csv(chunk)
393
  continue
394
 
395
  rows = list(chunk_to_process.iterrows())
396
  argslist = [(i, row.to_dict(), self) for i, row in rows]
397
+
398
  with mp.Pool(num_workers) as pool:
399
  results = pool.map(process_single_polymer, argslist)
400
 
 
402
  for n, (output, fail) in enumerate(results):
403
  idx = rows[n][0]
404
  if output:
405
+ chunk.at[idx, "graph"] = json.dumps(output["graph"])
406
+ chunk.at[idx, "geometry"] = json.dumps(output["geometry"])
407
+ chunk.at[idx, "fingerprints"] = json.dumps(output["fingerprints"])
408
  if fail:
409
  failed_list.append(fail)
410
 
 
413
 
414
  return "Processing Done"
415
 
416
+ # ------------------------------
417
+ # Output helpers
418
+ # ------------------------------
419
+ def save_chunk_to_csv(self, chunk: pd.DataFrame) -> None:
420
+ """
421
+ Append processed chunk to <input>_processed.csv.
422
+ """
423
+ out_csv = self.csv_file.replace(".csv", "_processed.csv")
424
  if not os.path.exists(out_csv):
425
+ chunk.to_csv(out_csv, index=False, mode="w")
426
  else:
427
+ chunk.to_csv(out_csv, index=False, mode="a", header=False)
428
 
429
+ def save_failed_to_json(self, failed_list) -> None:
430
+ """
431
+ Append failures to <input>_failures.jsonl (JSON lines).
432
+ """
433
  if not failed_list:
434
  return
435
+ fail_json = self.csv_file.replace(".csv", "_failures.jsonl")
436
+ with open(fail_json, "a", encoding="utf-8") as f:
437
  for fail in failed_list:
438
  json.dump(fail, f)
439
+ f.write("\n")
440
 
441
+ # Optional stubs preserved (no functional change)
442
+ def save_results(self, output_file: str = "polymer_multimodal_data.json"):
443
  pass
444
 
 
445
  def generate_summary_statistics(self) -> Dict:
446
  return {}
447
 
448
+
449
  # ----------------------------------------------------------------------
450
+ # CLI / entry-point helpers
451
  # ----------------------------------------------------------------------
452
 
453
+ def parse_args() -> argparse.Namespace:
454
+ """
455
+ Command-line arguments:
456
+ --csv_file: path to input CSV (required)
457
+ --chunk_size: rows per chunk
458
+ --num_workers: multiprocessing workers
459
+ """
460
+ parser = argparse.ArgumentParser(description="Polymer multimodal feature extraction (RDKit).")
461
+ parser.add_argument(
462
+ "--csv_file",
463
+ type=str,
464
+ default="/path/to/polymer_structures_unified.csv",
465
+ help="Path to the input CSV file containing at least a 'psmiles' column.",
466
+ )
467
+ parser.add_argument("--chunk_size", type=int, default=1000, help="Rows per chunk for streaming CSV processing.")
468
+ parser.add_argument("--num_workers", type=int, default=24, help="Number of parallel worker processes.")
469
+ return parser.parse_args()
470
+
471
+
472
+ def main() -> Tuple[AdvancedPolymerMultimodalExtractor, Optional[object]]:
473
+ """
474
+ Script entry point.
475
+ Reads arguments, constructs the extractor, and runs chunked parallel processing.
476
+ """
477
+ args = parse_args()
478
+ csv_file = args.csv_file
479
+
480
  extractor = AdvancedPolymerMultimodalExtractor(csv_file)
481
  try:
482
+ extractor.process_all_polymers_parallel(chunk_size=args.chunk_size, num_workers=args.num_workers)
483
  except KeyboardInterrupt:
484
  return extractor, None
485
  except Exception as e:
486
  print(f"CRASH! Error: {e}")
487
  return extractor, None
488
+
489
  print("\n=== Processing Complete ===")
490
  return extractor, None
491
 
492
+
493
  if __name__ == "__main__":
494
  extractor, results = main()