manpreet88 commited on
Commit
1de49d3
·
1 Parent(s): 0111d29

Update SchNet.py

Browse files
Files changed (1) hide show
  1. PolyFusion/SchNet.py +569 -684
PolyFusion/SchNet.py CHANGED
@@ -1,10 +1,14 @@
 
 
 
 
1
  import os
2
  import json
3
  import time
4
- import shutil
5
-
6
  import sys
7
  import csv
 
 
8
 
9
  # Increase max CSV field size limit
10
  csv.field_size_limit(sys.maxsize)
@@ -17,721 +21,602 @@ import pandas as pd
17
  from sklearn.model_selection import train_test_split
18
  from torch.utils.data import Dataset, DataLoader
19
 
20
- # PyG (SchNet)
21
- from torch_geometric.nn import SchNet
22
 
23
  from transformers import TrainingArguments, Trainer
24
  from transformers.trainer_callback import TrainerCallback
25
  from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, mean_absolute_error
26
- from torch_geometric.nn import radius_graph
27
 
28
# ---------------------------
# Configuration / Constants
# ---------------------------
# Fraction of atoms selected for BERT-style masking in each molecule.
P_MASK = 0.15
# NOTE: do NOT infer max atomic number from the dataset; set it manually as requested.
# "At" (Astatine) atomic number = 85 — change this value if your actual maximum differs.
MAX_ATOMIC_Z = 85

# Use a dedicated MASK token index (not 0). We'll place it after the max atomic number.
MASK_ATOM_ID = MAX_ATOMIC_Z + 1

COORD_NOISE_SIGMA = 0.5 # Å (start value, can tune)
# If True, combine the two losses with learned log-variance weighting instead of a fixed alpha.
USE_LEARNED_WEIGHTING = True

# SchNet hyperparams requested by user:
SCHNET_NUM_GAUSSIANS = 50
SCHNET_NUM_INTERACTIONS = 6
SCHNET_CUTOFF = 10.0 # Å
SCHNET_MAX_NEIGHBORS = 64

# Number of anchor atoms to predict distances to (invariant objective)
K_ANCHORS = 6
50
 
51
# Output directory
OUTPUT_DIR = "./schnet_output_5M"
BEST_MODEL_DIR = os.path.join(OUTPUT_DIR, "best")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ---------------------------
# 1. Load Data (chunked to avoid OOM)
# ---------------------------
csv_path = "Polymer_Foundational_Model/Datasets/polymer_structures_unified_processed.csv"
# target max rows to read (you previously used nrows=2000000)
TARGET_ROWS = 5000000
# choose a chunksize that fits your memory; adjust if needed
CHUNKSIZE = 50000

# Parallel lists: atomic_lists[i] is the element list of molecule i,
# coord_lists[i] its matching 3D coordinates.
atomic_lists = []
coord_lists = []
rows_read = 0

# Read in chunks and parse geometry JSON for each chunk to avoid OOM
for chunk in pd.read_csv(csv_path, engine="python", chunksize=CHUNKSIZE):
    # parse geometry column (JSON strings) in this chunk
    geoms_chunk = chunk["geometry"].apply(json.loads)
    for geom in geoms_chunk:
        # assumes each geometry JSON has a 'best_conformer' with
        # 'atomic_numbers' and 'coordinates' keys — TODO confirm against CSV schema
        conf = geom["best_conformer"]
        atomic_lists.append(conf["atomic_numbers"])
        coord_lists.append(conf["coordinates"])

    rows_read += len(chunk)
    if rows_read >= TARGET_ROWS:
        break

# Use manual maximum atomic number (do not compute from data)
max_atomic_z = MAX_ATOMIC_Z
print(f"Using manual max atomic number: {max_atomic_z} (MASK_ATOM_ID={MASK_ATOM_ID})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
# ---------------------------
# 2. Train/Val Split
# ---------------------------
# 80/20 split on molecule indices; tensors are materialized per split.
train_idx, val_idx = train_test_split(list(range(len(atomic_lists))), test_size=0.2, random_state=42)
train_z = [torch.tensor(atomic_lists[i], dtype=torch.long) for i in train_idx]
train_pos = [torch.tensor(coord_lists[i], dtype=torch.float) for i in train_idx]
val_z = [torch.tensor(atomic_lists[i], dtype=torch.long) for i in val_idx]
val_pos = [torch.tensor(coord_lists[i], dtype=torch.float) for i in val_idx]

# ---------------------------
# Compute class weights (for weighted CE to mitigate element imbalance)
# ---------------------------
# We create weights for classes [0 .. max_atomic_z, MASK_ATOM_ID] where most labels will be in 1..max_atomic_z.
num_classes = MASK_ATOM_ID + 1 # (0 unused for typical atomic numbers; mask token at end)
counts = np.ones((num_classes,), dtype=np.float64) # init with 1 to avoid zero division

# Count element occurrences over the TRAINING split only.
for z in train_z:
    if z.numel() > 0:
        vals = z.cpu().numpy().astype(int)
        for v in vals:
            if 0 <= v < num_classes:
                counts[v] += 1.0

# Inverse frequency (normalized to mean 1.0)
freq = counts / counts.sum()
inv_freq = 1.0 / (freq + 1e-12)
class_weights = inv_freq / inv_freq.mean()
class_weights = torch.tensor(class_weights, dtype=torch.float)

# Set MASK token weight to 1.0 (it is not used as target in labels_z)
class_weights[MASK_ATOM_ID] = 1.0

# ---------------------------
# 3. Dataset and Collator
# ---------------------------
121
class PolymerDataset(Dataset):
    """Map-style dataset over per-molecule atomic numbers and 3D coordinates.

    Holds two parallel lists: ``zs[i]`` is a LongTensor of atomic numbers and
    ``pos_list[i]`` the matching ``[n_atoms, 3]`` FloatTensor of coordinates.
    """

    def __init__(self, zs, pos_list):
        # Store the parallel lists as-is; no copying or validation here.
        self.zs = zs
        self.pos_list = pos_list

    def __len__(self):
        # One sample per molecule.
        return len(self.zs)

    def __getitem__(self, idx):
        # Keys "z"/"pos" are the contract expected by collate_batch.
        sample = {"z": self.zs[idx], "pos": self.pos_list[idx]}
        return sample
131
 
132
def collate_batch(batch):
    """
    Masking + create invariant distance targets:
    - Select atoms for masking (P_MASK).
    - For atomic numbers: 80/10/10 BERT-style corruption. Use MASK_ATOM_ID for mask token.
    - For distances: for each masked atom, compute true distances to up to K_ANCHORS visible atoms
      (nearest visible anchors). Produce labels_dists [N, K_ANCHORS] and anchors_exists mask [N, K_ANCHORS].
    - Return labels_z (atomic targets, -100 for unselected) and labels_dists (+ anchors mask).

    Relies on module-level constants P_MASK, MASK_ATOM_ID, K_ANCHORS,
    COORD_NOISE_SIGMA and the global max_atomic_z.
    """
    all_z = []
    all_pos = []
    all_labels_z = []
    all_labels_dists = []
    all_labels_dists_mask = []
    batch_idx = []

    for i, data in enumerate(batch):
        z = data["z"] # [n_atoms]
        pos = data["pos"] # [n_atoms,3]
        n_atoms = z.size(0)
        if n_atoms == 0:
            # Skip empty molecules entirely; they contribute nothing to the batch.
            continue

        # 1) choose which atoms are selected for masking (15%)
        is_selected = torch.rand(n_atoms) < P_MASK

        # ensure not ALL atoms are selected (we need some visible anchors)
        if is_selected.all():
            # set one random atom to unselected
            is_selected[torch.randint(0, n_atoms, (1,))] = False

        # Prepare labels (only for selected atoms)
        labels_z = torch.full((n_atoms,), -100, dtype=torch.long) # -100 ignored by CE
        # labels_dists: per-atom K distances (0 padded) and mask indicating valid anchors
        labels_dists = torch.zeros((n_atoms, K_ANCHORS), dtype=torch.float)
        labels_dists_mask = torch.zeros((n_atoms, K_ANCHORS), dtype=torch.bool)

        labels_z[is_selected] = z[is_selected] # true atomic numbers for selecteds

        # 2) apply BERT-style corruption for atomic numbers
        z_masked = z.clone()
        if is_selected.any():
            sel_idx = torch.nonzero(is_selected).squeeze(-1)
            # sample random atomic numbers from 1..max_atomic_z (avoid 0 which is often unused)
            rand_atomic = torch.randint(1, max_atomic_z + 1, (sel_idx.size(0),), dtype=torch.long)

            probs = torch.rand(sel_idx.size(0))
            mask_choice = probs < 0.8
            rand_choice = (probs >= 0.8) & (probs < 0.9)
            # keep_choice = probs >= 0.9

            if mask_choice.any():
                z_masked[sel_idx[mask_choice]] = MASK_ATOM_ID
            if rand_choice.any():
                z_masked[sel_idx[rand_choice]] = rand_atomic[rand_choice]
            # 10% keep => do nothing

        # 3) coordinate corruption for selected atoms (we still corrupt positions for training robust embeddings)
        # NOTE: an independent 80/10/10 draw is used here, so an atom's coordinate
        # corruption is decoupled from its atomic-number corruption above.
        pos_masked = pos.clone()
        if is_selected.any():
            sel_idx = torch.nonzero(is_selected).squeeze(-1)
            probs_c = torch.rand(sel_idx.size(0))
            noisy_choice = probs_c < 0.8
            randpos_choice = (probs_c >= 0.8) & (probs_c < 0.9)

            if noisy_choice.any():
                idx = sel_idx[noisy_choice]
                # Gaussian jitter with std COORD_NOISE_SIGMA (Å).
                noise = torch.randn((idx.size(0), 3)) * COORD_NOISE_SIGMA
                pos_masked[idx] = pos_masked[idx] + noise

            if randpos_choice.any():
                idx = sel_idx[randpos_choice]
                # Uniform resample inside the molecule's axis-aligned bounding box.
                mins = pos.min(dim=0).values
                maxs = pos.max(dim=0).values
                randpos = (torch.rand((idx.size(0), 3)) * (maxs - mins)) + mins
                pos_masked[idx] = randpos

        # 4) Build invariant distance targets for masked atoms:
        visible_idx = torch.nonzero(~is_selected).squeeze(-1)
        # If for some reason no visible (shouldn't happen due to earlier guard), fall back to all atoms as visible
        if visible_idx.numel() == 0:
            visible_idx = torch.arange(n_atoms, dtype=torch.long)

        # Precompute pairwise distances
        # pos: [n_atoms,3], visible_pos: [V,3]
        visible_pos = pos[visible_idx] # true positions for anchors
        for a in torch.nonzero(is_selected).squeeze(-1).tolist():
            # distances from atom a to all visible anchors
            # (+1e-12 under the sqrt avoids a zero-gradient/NaN at distance 0)
            dists = torch.sqrt(((pos[a].unsqueeze(0) - visible_pos) ** 2).sum(dim=1) + 1e-12)
            # find nearest anchors (ascending)
            if dists.numel() > 0:
                k = min(K_ANCHORS, dists.numel())
                nearest_vals, nearest_idx = torch.topk(dists, k, largest=False)
                labels_dists[a, :k] = nearest_vals
                labels_dists_mask[a, :k] = True
            # else leave zeros and mask False

        all_z.append(z_masked)
        all_pos.append(pos_masked)
        all_labels_z.append(labels_z)
        all_labels_dists.append(labels_dists)
        all_labels_dists_mask.append(labels_dists_mask)
        batch_idx.append(torch.full((n_atoms,), i, dtype=torch.long))

    if len(all_z) == 0:
        # Degenerate batch (every molecule empty): emit correctly-shaped empty tensors.
        return {"z": torch.tensor([], dtype=torch.long),
                "pos": torch.tensor([], dtype=torch.float).reshape(0, 3),
                "batch": torch.tensor([], dtype=torch.long),
                "labels_z": torch.tensor([], dtype=torch.long),
                "labels_dists": torch.tensor([], dtype=torch.float).reshape(0, K_ANCHORS),
                "labels_dists_mask": torch.tensor([], dtype=torch.bool).reshape(0, K_ANCHORS)}

    # Flatten all molecules into one node list; 'batch' maps node -> molecule index.
    z_batch = torch.cat(all_z, dim=0)
    pos_batch = torch.cat(all_pos, dim=0)
    labels_z_batch = torch.cat(all_labels_z, dim=0)
    labels_dists_batch = torch.cat(all_labels_dists, dim=0)
    labels_dists_mask_batch = torch.cat(all_labels_dists_mask, dim=0)
    batch_batch = torch.cat(batch_idx, dim=0)

    return {"z": z_batch, "pos": pos_batch, "batch": batch_batch,
            "labels_z": labels_z_batch,
            "labels_dists": labels_dists_batch,
            "labels_dists_mask": labels_dists_mask_batch}
255
-
256
# Wrap the split tensors in datasets and build the loaders.
# NOTE(review): val batch size (8) intentionally differs from train (16) — confirm.
train_dataset = PolymerDataset(train_z, train_pos)
val_dataset = PolymerDataset(val_z, val_pos)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_batch)
260
 
261
- from torch_geometric.nn import SchNet as BaseSchNet
262
- from torch_geometric.nn import radius_graph
263
 
264
class NodeSchNet(nn.Module):
    """SchNet variant that stops before readout and yields per-node embeddings.

    Reuses the embedding, Gaussian distance expansion and interaction blocks of
    PyG's SchNet; the base model's graph-level readout/output layers are never
    invoked.
    """

    def __init__(self, hidden_channels=128, num_filters=128, num_interactions=6,
                 num_gaussians=50, cutoff=10.0, max_num_neighbors=32, readout='add'):
        super().__init__()

        self.hidden_channels = hidden_channels
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors

        # The full base model is constructed, but forward() only uses its
        # embedding / distance_expansion / interactions sub-modules.
        self.base_schnet = BaseSchNet(
            hidden_channels=hidden_channels,
            num_filters=num_filters,
            num_interactions=num_interactions,
            num_gaussians=num_gaussians,
            cutoff=cutoff,
            max_num_neighbors=max_num_neighbors,
            readout=readout,
        )

    def forward(self, z, pos, batch=None):
        """Return node embeddings of shape [num_nodes, hidden_channels]."""
        # Single-graph fallback when no batch vector is supplied.
        if batch is None:
            batch = torch.zeros(z.size(0), dtype=torch.long, device=z.device)

        # Initial per-atom features from the element embedding table.
        node_feats = self.base_schnet.embedding(z)

        # Neighbor graph within the cutoff radius, built per molecule.
        edge_index = radius_graph(
            pos, r=self.cutoff, batch=batch,
            max_num_neighbors=self.max_num_neighbors,
        )

        # Edge lengths, expanded into Gaussian radial basis features.
        src, dst = edge_index
        edge_weight = (pos[src] - pos[dst]).norm(dim=-1)
        edge_attr = self.base_schnet.distance_expansion(edge_weight)

        # Residual message-passing updates, exactly as in the base SchNet.
        for block in self.base_schnet.interactions:
            node_feats = node_feats + block(node_feats, edge_index, edge_weight, edge_attr)

        # Stop before readout: node-level features are the result.
        return node_feats
309
 
310
- # ---------------------------
311
- # 4. Model Definition (SchNet + two heads + learned weighting)
312
- # ---------------------------
313
class MaskedSchNet(nn.Module):
    """Masked-atom pretraining model: NodeSchNet backbone plus two heads.

    Heads:
      - atom_head: per-node logits over atomic-number classes (0..max_atomic_z
        plus the MASK token class).
      - coord_head: per-node prediction of distances to K_ANCHORS anchor atoms
        (rotation/translation-invariant geometry target).

    When USE_LEARNED_WEIGHTING is True, the two losses are combined with
    learned log-variance terms; otherwise they are summed with a fixed alpha.
    """

    def __init__(self,
                 hidden_channels=600,
                 num_interactions=SCHNET_NUM_INTERACTIONS,
                 num_gaussians=SCHNET_NUM_GAUSSIANS,
                 cutoff=SCHNET_CUTOFF,
                 max_atomic_z=max_atomic_z,
                 max_num_neighbors=SCHNET_MAX_NEIGHBORS,
                 class_weights=None):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors
        self.max_atomic_z = max_atomic_z

        # SchNet model from PyG (node-embedding variant defined above)
        self.schnet = NodeSchNet(
            hidden_channels=hidden_channels,
            num_filters=hidden_channels,
            num_interactions=num_interactions,
            num_gaussians=num_gaussians,
            cutoff=cutoff,
            max_num_neighbors=max_num_neighbors
        )

        # Classification head for atomic number (classes 0..max_atomic_z and MASK token)
        num_classes_local = MASK_ATOM_ID + 1
        self.atom_head = nn.Linear(hidden_channels, num_classes_local)

        # Distance-prediction head (predict K_ANCHORS scalar distances per node) -> invariant target
        self.coord_head = nn.Linear(hidden_channels, K_ANCHORS)

        # Learned uncertainty weighting (log-variances) if enabled
        if USE_LEARNED_WEIGHTING:
            self.log_var_z = nn.Parameter(torch.zeros(1))
            self.log_var_pos = nn.Parameter(torch.zeros(1))
        else:
            self.log_var_z = None
            self.log_var_pos = None

        # Class weights for cross entropy
        if class_weights is not None:
            # register as buffer so it moves with .to(device)
            self.register_buffer("class_weights", class_weights)
        else:
            self.class_weights = None

    def forward(self, z, pos, batch, labels_z=None, labels_dists=None, labels_dists_mask=None):
        """
        z: [N] long (atomic numbers or MASK_ATOM_ID)
        pos: [N,3] float (possibly corrupted)
        batch: [N] long (graph indices)
        labels_z: [N] long (-100 for unselected)
        labels_dists: [N, K_ANCHORS] float (0 padded)
        labels_dists_mask: [N, K_ANCHORS] bool (True where anchor exists)

        Returns a scalar loss tensor when all three label arguments are given,
        otherwise the tuple (logits [N, num_classes], dists_pred [N, K_ANCHORS]).
        """
        # Let SchNet produce node embeddings. SchNet builds its own neighbor graph internally.
        # SchNet's forward often accepts (z, pos, batch)
        try:
            h = self.schnet(z=z, pos=pos, batch=batch)
        except TypeError:
            # fallback if different signature
            h = self.schnet(z=z, pos=pos)

        # Node embeddings -> head outputs
        logits = self.atom_head(h) # [N, num_classes]
        dists_pred = self.coord_head(h) # [N, K_ANCHORS]

        # If labels provided -> compute loss (aggregated only over masked atoms)
        if labels_z is not None and labels_dists is not None and labels_dists_mask is not None:
            mask = labels_z != -100 # which atoms were selected for supervision
            if mask.sum() == 0:
                # Nothing masked in this batch: return zero loss (avoid NaNs)
                # NOTE(review): this zero tensor carries no grad_fn — confirm the
                # training loop tolerates a non-differentiable loss for such batches.
                return torch.tensor(0.0, device=z.device)

            logits_masked = logits[mask] # [M, num_classes]
            dists_pred_masked = dists_pred[mask] # [M, K_ANCHORS]
            labels_z_masked = labels_z[mask] # [M]
            labels_dists_masked = labels_dists[mask] # [M, K_ANCHORS]
            labels_dists_mask_mask = labels_dists_mask[mask] # [M, K_ANCHORS] bool

            # classification loss (weighted cross entropy)
            if self.class_weights is not None:
                loss_z = F.cross_entropy(logits_masked, labels_z_masked, weight=self.class_weights)
            else:
                loss_z = F.cross_entropy(logits_masked, labels_z_masked)

            # coordinate/distance loss: only over existing anchor distances
            # flatten valid entries
            if labels_dists_mask_mask.any():
                preds = dists_pred_masked[labels_dists_mask_mask]
                trues = labels_dists_masked[labels_dists_mask_mask]
                loss_pos = F.mse_loss(preds, trues, reduction="mean")
            else:
                # no anchor distances present (shouldn't happen), set zero
                loss_pos = torch.tensor(0.0, device=z.device)

            if USE_LEARNED_WEIGHTING:
                # Uncertainty-style weighting: exp(-s)*L + s per task.
                lz = torch.exp(-self.log_var_z) * loss_z + self.log_var_z
                lp = torch.exp(-self.log_var_pos) * loss_pos + self.log_var_pos
                loss = 0.5 * (lz + lp)
            else:
                alpha = 1.0
                loss = loss_z + alpha * loss_pos

            return loss

        # Inference: return logits and predicted distances
        return logits, dists_pred
422
-
423
# instantiate model with requested SchNet params and computed class weights
model = MaskedSchNet(hidden_channels=600,
                     num_interactions=SCHNET_NUM_INTERACTIONS,
                     num_gaussians=SCHNET_NUM_GAUSSIANS,
                     cutoff=SCHNET_CUTOFF,
                     max_atomic_z=max_atomic_z,
                     max_num_neighbors=SCHNET_MAX_NEIGHBORS,
                     class_weights=class_weights)

# Single-GPU (cuda:0) when available, else CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
434
 
435
# ---------------------------
# 5. Training Setup (Hugging Face Trainer)
# ---------------------------
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    num_train_epochs=25,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    # effective train batch = 16 * 4 = 64 per device
    gradient_accumulation_steps=4,
    eval_strategy="epoch",
    logging_steps=500,
    learning_rate=1e-4,
    weight_decay=0.01,
    # mixed precision only when a GPU is present
    fp16=torch.cuda.is_available(),
    save_strategy="no", # we will let callback save best model
    disable_tqdm=False,
    logging_first_step=True,
    report_to=[],
    dataloader_num_workers=4,
)
456
 
457
class ValLossCallback(TrainerCallback):
    """Trainer callback: per-epoch logging, full validation metrics,
    best-model checkpointing on validation loss, and early stopping.

    Uses the module-level `val_loader` and `BEST_MODEL_DIR` globals; the
    Trainer instance is attached after construction via `trainer_ref`.
    """

    def __init__(self, trainer_ref=None):
        # Best validation loss seen so far (lower is better).
        self.best_val_loss = float("inf")
        # Consecutive evaluations without improvement (early-stopping counter).
        self.epochs_no_improve = 0
        self.patience = 10
        # 0-based epoch index of the best model (printed 1-based later).
        self.best_epoch = None
        self.trainer_ref = trainer_ref

    def on_epoch_end(self, args, state, control, **kwargs):
        # Print epoch starting from 1 instead of 0
        epoch_num = int(state.epoch)
        # Most recent training loss entry in the Trainer's log history, if any.
        train_loss = next((x["loss"] for x in reversed(state.log_history) if "loss" in x), None)
        print(f"\n=== Epoch {epoch_num}/{args.num_train_epochs} ===")
        if train_loss is not None:
            print(f"Train Loss: {train_loss:.4f}")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """
        When trainer runs evaluation, compute full validation metrics here (accuracy, f1, rmse, mae, perplexity)
        using the provided val_loader and the trainer's model. Save the model when val_loss improves.
        NOTE: Validation loss printed and used for the best-model decision is taken from the `metrics`
        object provided by the Trainer when available, so it matches the Trainer's evaluation output.
        """
        # Compute epoch number for printing (1-based)
        epoch_num = int(state.epoch) + 1

        # If we don't have a trainer reference or val_loader, fallback to printing whatever metrics provided
        if self.trainer_ref is None:
            print(f"[Eval] Epoch {epoch_num} - metrics (trainer_ref missing): {metrics}")
            return

        # If trainer provided an eval_loss in metrics, prefer that value for printing and best-model decision
        metric_val_loss = None
        if metrics is not None:
            metric_val_loss = metrics.get("eval_loss")

        # Evaluate over val_loader to compute other metrics (accuracy, f1, rmse, mae, perplexity)
        model_eval = self.trainer_ref.model
        model_eval.eval()

        # Infer the model's device from its parameters (CPU fallback for a parameter-less model).
        device_local = next(model_eval.parameters()).device if any(p.numel() > 0 for p in model_eval.parameters()) else torch.device("cpu")

        preds_z_all = []
        true_z_all = []
        pred_dists_all = []
        true_dists_all = []
        total_loss = 0.0
        n_batches = 0

        # Accumulated masked logits/labels for the perplexity computation below.
        logits_masked_list = []
        labels_masked_list = []

        with torch.no_grad():
            for batch in val_loader:
                z = batch["z"].to(device_local)
                pos = batch["pos"].to(device_local)
                batch_idx = batch["batch"].to(device_local)
                labels_z = batch["labels_z"].to(device_local)
                labels_dists = batch["labels_dists"].to(device_local)
                labels_dists_mask = batch["labels_dists_mask"].to(device_local)

                # compute loss using labels (model returns loss when labels provided)
                try:
                    loss = model_eval(z, pos, batch_idx, labels_z, labels_dists, labels_dists_mask)
                except Exception as e:
                    # If model.forward signature is different, skip loss accumulation but still compute preds
                    loss = None

                if isinstance(loss, torch.Tensor):
                    total_loss += loss.item()
                    n_batches += 1

                # inference to get logits and distance preds
                logits, dists_pred = model_eval(z, pos, batch_idx)

                mask = labels_z != -100
                if mask.sum().item() == 0:
                    continue

                # collect masked logits/labels for perplexity
                logits_masked_list.append(logits[mask])
                labels_masked_list.append(labels_z[mask])

                pred_z = torch.argmax(logits[mask], dim=-1)
                true_z = labels_z[mask]

                # flatten valid distances across anchors
                pred_d = dists_pred[mask][labels_dists_mask[mask]]
                true_d = labels_dists[mask][labels_dists_mask[mask]]

                if pred_d.numel() > 0:
                    pred_dists_all.extend(pred_d.cpu().tolist())
                    true_dists_all.extend(true_d.cpu().tolist())

                preds_z_all.extend(pred_z.cpu().tolist())
                true_z_all.extend(true_z.cpu().tolist())

        # If the trainer provided eval_loss, use it; otherwise fall back to the computed average loss
        avg_val_loss = metric_val_loss if metric_val_loss is not None else ((total_loss / n_batches) if n_batches > 0 else float("nan"))

        # Compute metrics (classification + distance regression)
        accuracy = accuracy_score(true_z_all, preds_z_all) if len(true_z_all) > 0 else 0.0
        f1 = f1_score(true_z_all, preds_z_all, average="weighted") if len(true_z_all) > 0 else 0.0
        rmse = np.sqrt(mean_squared_error(true_dists_all, pred_dists_all)) if len(true_dists_all) > 0 else 0.0
        mae = mean_absolute_error(true_dists_all, pred_dists_all) if len(true_dists_all) > 0 else 0.0

        # Compute classification perplexity from masked-token cross-entropy, if available
        if len(logits_masked_list) > 0:
            all_logits_masked = torch.cat(logits_masked_list, dim=0)
            all_labels_masked = torch.cat(labels_masked_list, dim=0)
            # Use model's class_weights if present
            cw = getattr(model_eval, "class_weights", None)
            if cw is not None:
                cw_device = cw.to(device_local)
                try:
                    loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked, weight=cw_device)
                except Exception:
                    # fallback without weight
                    loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked)
            else:
                loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked)
            try:
                perplexity = float(torch.exp(loss_z_all).cpu().item())
            except Exception:
                perplexity = float(np.exp(float(loss_z_all.cpu().item())))
        else:
            perplexity = float("nan")

        print(f"\n--- Evaluation after Epoch {epoch_num} ---")
        # Print validation loss that matches Trainer's evaluation when available
        print(f"Validation Loss: {avg_val_loss:.4f}")
        print(f"Validation Accuracy: {accuracy:.4f}")
        print(f"Validation F1 (weighted): {f1:.4f}")
        print(f"Validation RMSE (distances): {rmse:.4f}")
        print(f"Validation MAE (distances): {mae:.4f}")
        print(f"Validation Perplexity (classification head): {perplexity:.4f}")

        # Check for improvement (use a small tolerance)
        if avg_val_loss is not None and not (isinstance(avg_val_loss, float) and np.isnan(avg_val_loss)) and avg_val_loss < self.best_val_loss - 1e-6:
            self.best_val_loss = avg_val_loss
            self.best_epoch = int(state.epoch) # store 0-based internally
            self.epochs_no_improve = 0
            # Save best model state_dict
            os.makedirs(BEST_MODEL_DIR, exist_ok=True)
            try:
                # Prefer trainer's model (which may be wrapped)
                torch.save(self.trainer_ref.model.state_dict(), os.path.join(BEST_MODEL_DIR, "pytorch_model.bin"))
                print(f"Saved new best model (epoch {epoch_num}) to {os.path.join(BEST_MODEL_DIR, 'pytorch_model.bin')}")
            except Exception as e:
                print(f"Failed to save best model at epoch {epoch_num}: {e}")
        else:
            self.epochs_no_improve += 1

        if self.epochs_no_improve >= self.patience:
            print(f"Early stopping after {self.patience} epochs with no improvement.")
            control.should_training_stop = True
613
-
614
# Create callback and Trainer
callback = ValLossCallback()
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_batch,
    callbacks=[callback]
)
# attach trainer_ref so callback can save model
# (set after construction because the callback is created before the Trainer exists)
callback.trainer_ref = trainer

# ---------------------------
# 6. Run training
# ---------------------------
# Wall-clock timing around the full training run.
start_time = time.time()
trainer.train()
total_time = time.time() - start_time
633
-
634
# ---------------------------
# 7. Final Evaluation (metrics computed on masked atoms in validation set)
# -> NOTE: per request, we will evaluate the best-saved model (by least val loss)
# ---------------------------
# If a best model was saved by the callback, load it
best_model_path = os.path.join(BEST_MODEL_DIR, "pytorch_model.bin")
if os.path.exists(best_model_path):
    try:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        print(f"\nLoaded best model from {best_model_path}")
    except Exception as e:
        # Fall through and evaluate the last-epoch weights instead.
        print(f"\nFailed to load best model from {best_model_path}: {e}")

model.eval()
preds_z_all = []
true_z_all = []
pred_dists_all = []
true_dists_all = []

# For computing perplexity in final eval
logits_masked_list_final = []
labels_masked_list_final = []

with torch.no_grad():
    for batch in val_loader:
        z = batch["z"].to(device)
        pos = batch["pos"].to(device)
        batch_idx = batch["batch"].to(device)
        labels_z = batch["labels_z"].to(device)
        labels_dists = batch["labels_dists"].to(device)
        labels_dists_mask = batch["labels_dists_mask"].to(device)

        logits, dists_pred = model(z, pos, batch_idx) # inference mode returns (logits, dists_pred)

        mask = labels_z != -100
        if mask.sum().item() == 0:
            continue

        # collect masked logits/labels for perplexity
        logits_masked_list_final.append(logits[mask])
        labels_masked_list_final.append(labels_z[mask])

        pred_z = torch.argmax(logits[mask], dim=-1)
        true_z = labels_z[mask]

        # flatten valid distances across anchors
        pred_d = dists_pred[mask][labels_dists_mask[mask]]
        true_d = labels_dists[mask][labels_dists_mask[mask]]

        if pred_d.numel() > 0:
            pred_dists_all.extend(pred_d.cpu().tolist())
            true_dists_all.extend(true_d.cpu().tolist())

        preds_z_all.extend(pred_z.cpu().tolist())
        true_z_all.extend(true_z.cpu().tolist())

# Compute metrics (classification + distance regression)
accuracy = accuracy_score(true_z_all, preds_z_all) if len(true_z_all) > 0 else 0.0
f1 = f1_score(true_z_all, preds_z_all, average="weighted") if len(true_z_all) > 0 else 0.0
rmse = np.sqrt(mean_squared_error(true_dists_all, pred_dists_all)) if len(true_dists_all) > 0 else 0.0
mae = mean_absolute_error(true_dists_all, pred_dists_all) if len(true_dists_all) > 0 else 0.0

# Compute perplexity from collected masked logits/labels
if len(logits_masked_list_final) > 0:
    all_logits_masked_final = torch.cat(logits_masked_list_final, dim=0)
    all_labels_masked_final = torch.cat(labels_masked_list_final, dim=0)
    cw_final = getattr(model, "class_weights", None)
    if cw_final is not None:
        try:
            loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final, weight=cw_final.to(device))
        except Exception:
            loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final)
    else:
        loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final)
    try:
        perplexity_final = float(torch.exp(loss_z_final).cpu().item())
    except Exception:
        perplexity_final = float(np.exp(float(loss_z_final.cpu().item())))
else:
    perplexity_final = float("nan")

best_val_loss = callback.best_val_loss if hasattr(callback, "best_val_loss") else float("nan")
# callback.best_epoch is stored 0-based; print 1-based.
best_epoch_num = (int(callback.best_epoch) + 1) if callback.best_epoch is not None else None

print(f"\n=== Final Results (evaluated on best saved model) ===")
print(f"Total Training Time (s): {total_time:.2f}")
if best_epoch_num is not None:
    print(f"Best Epoch (1-based): {best_epoch_num}")
else:
    print("Best Epoch: (none saved)")

print(f"Best Validation Loss: {best_val_loss:.4f}")
print(f"Validation Accuracy: {accuracy:.4f}")
print(f"Validation F1 (weighted): {f1:.4f}")
print(f"Validation RMSE (distances): {rmse:.4f}")
print(f"Validation MAE (distances): {mae:.4f}")
print(f"Validation Perplexity (classification head): {perplexity_final:.4f}")

# Parameter counts for reporting.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
non_trainable_params = total_params - trainable_params
print(f"Total Parameters: {total_params}")
print(f"Trainable Parameters: {trainable_params}")
print(f"Non-trainable Parameters: {non_trainable_params}")
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SchNet-based masked pretraining on polymer conformer geometry.
3
+ """
4
+
5
  import os
6
  import json
7
  import time
 
 
8
  import sys
9
  import csv
10
+ import argparse
11
+ from typing import List
12
 
13
  # Increase max CSV field size limit
14
  csv.field_size_limit(sys.maxsize)
 
21
  from sklearn.model_selection import train_test_split
22
  from torch.utils.data import Dataset, DataLoader
23
 
24
+ from torch_geometric.nn import SchNet as BaseSchNet
25
+ from torch_geometric.nn import radius_graph
26
 
27
  from transformers import TrainingArguments, Trainer
28
  from transformers.trainer_callback import TrainerCallback
29
  from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, mean_absolute_error
 
30
 
31
# ---------------------------
# Configuration / Constants
# ---------------------------
# Fraction of atoms selected for the MLM-style masking objective.
P_MASK = 0.15

# Maximum atomic number handled by the vocabulary. Set manually (85 = Astatine)
# rather than inferred from the dataset; adjust if your data contains heavier
# elements.
MAX_ATOMIC_Z = 85

# Dedicated MASK token id, placed one past the largest real atomic number so it
# never collides with an actual element.
MASK_ATOM_ID = MAX_ATOMIC_Z + 1

# Std-dev (in Angstrom) of Gaussian noise added to coordinates of masked atoms.
COORD_NOISE_SIGMA = 0.5
# If True, combine the two losses with learned homoscedastic-uncertainty weights.
USE_LEARNED_WEIGHTING = True

# SchNet hyperparameters.
SCHNET_NUM_GAUSSIANS = 50
SCHNET_NUM_INTERACTIONS = 6
# Radius-graph cutoff in Angstrom.
SCHNET_CUTOFF = 10.0
SCHNET_MAX_NEIGHBORS = 64

# Number of nearest visible "anchor" atoms used for the invariant
# distance-regression target.
K_ANCHORS = 6
47
 
 
 
 
 
48
 
49
def parse_args() -> argparse.Namespace:
    """Parse the command-line options controlling data loading and output paths."""
    # (flag, type, default, help) — declared as data so the registration loop
    # below stays uniform.
    arg_specs = [
        ("--csv_path", str, "/path/to/polymer_structures_unified_processed.csv",
         "Processed CSV containing a JSON 'geometry' column."),
        ("--target_rows", int, 5_000_000, "Max rows to read."),
        ("--chunksize", int, 50_000, "CSV chunksize."),
        ("--output_dir", str, "/path/to/schnet_output_5M", "Training output directory."),
        ("--num_workers", int, 4, "PyTorch DataLoader num workers."),
    ]

    parser = argparse.ArgumentParser(description="SchNet masked pretraining (geometry).")
    for flag, arg_type, default_value, help_text in arg_specs:
        parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)
    return parser.parse_args()
62
+
63
+
64
def load_geometry_from_csv(csv_path: str, target_rows: int, chunksize: int):
    """
    Stream the processed CSV and extract per-row conformer geometry.

    Each row's 'geometry' column holds a JSON string; atoms and coordinates are
    taken from geometry['best_conformer'].

    Args:
        csv_path: Path to the processed CSV file.
        target_rows: Maximum number of rows to load (exact upper bound).
        chunksize: Number of CSV rows to read per chunk.

    Returns:
        (atomic_lists, coord_lists): parallel lists with, per molecule, the
        atomic numbers and the 3D coordinates of the best conformer.
    """
    atomic_lists = []
    coord_lists = []

    for chunk in pd.read_csv(csv_path, engine="python", chunksize=chunksize):
        # Parse the geometry JSON strings in this chunk.
        geoms_chunk = chunk["geometry"].apply(json.loads)
        for geom in geoms_chunk:
            conf = geom["best_conformer"]
            atomic_lists.append(conf["atomic_numbers"])
            coord_lists.append(conf["coordinates"])
            # Stop exactly at target_rows. (Checking only after a whole chunk,
            # as before, could overshoot by up to chunksize - 1 rows.)
            if len(atomic_lists) >= target_rows:
                break
        if len(atomic_lists) >= target_rows:
            break

    print(f"Using manual max atomic number: {MAX_ATOMIC_Z} (MASK_ATOM_ID={MASK_ATOM_ID})")
    return atomic_lists, coord_lists
89
+
90
+
91
def compute_class_weights(train_z: List[torch.Tensor]) -> torch.Tensor:
    """
    Inverse-frequency class weights for atomic-number classification.

    Counts start at 1 per class (Laplace smoothing) so unseen elements still get
    a finite weight; the MASK token's weight is pinned to 1.0 so the artificial
    class is not over-weighted.

    Args:
        train_z: Per-molecule 1-D tensors of atomic numbers (training split).

    Returns:
        Float tensor of shape (MASK_ATOM_ID + 1,), mean-normalized inverse
        frequencies.
    """
    num_classes = MASK_ATOM_ID + 1
    counts = np.ones((num_classes,), dtype=np.float64)

    for z in train_z:
        if z.numel() > 0:
            vals = z.cpu().numpy().astype(int)
            # Vectorized histogram instead of a per-atom Python loop; values
            # outside [0, num_classes) are dropped, matching the old range check.
            vals = vals[(vals >= 0) & (vals < num_classes)]
            counts += np.bincount(vals, minlength=num_classes)

    freq = counts / counts.sum()
    inv_freq = 1.0 / (freq + 1e-12)
    class_weights = inv_freq / inv_freq.mean()
    class_weights = torch.tensor(class_weights, dtype=torch.float)
    class_weights[MASK_ATOM_ID] = 1.0
    return class_weights
109
 
 
 
 
 
 
 
 
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
class PolymerDataset(Dataset):
    """Torch dataset yielding one conformer per item as {'z': ..., 'pos': ...}."""

    def __init__(self, zs: List[torch.Tensor], pos_list: List[torch.Tensor]):
        # Parallel lists: per-molecule atomic numbers and matching coordinates.
        self.zs = zs
        self.pos_list = pos_list

    def __len__(self):
        return len(self.zs)

    def __getitem__(self, idx):
        sample = {"z": self.zs[idx], "pos": self.pos_list[idx]}
        return sample
122
 
 
 
123
 
124
def collate_batch(batch):
    """
    Collate conformers into a concatenated node set with a 'batch' vector, while applying:
      - atomic number masking (MLM-style)
      - coordinate corruption for masked atoms
      - invariant distance targets to nearest visible anchors (K_ANCHORS)

    Returns a dict of flat tensors: 'z'/'pos'/'batch' over all atoms plus
    'labels_z' (-100 for unmasked atoms, i.e. ignored by cross_entropy) and the
    anchor-distance targets with their validity mask.
    """
    all_z, all_pos = [], []
    all_labels_z, all_labels_dists, all_labels_dists_mask = [], [], []
    batch_idx = []

    for i, data in enumerate(batch):
        z = data["z"]
        pos = data["pos"]
        n_atoms = z.size(0)
        if n_atoms == 0:
            continue

        # Select ~P_MASK of atoms; keep at least one atom visible so the
        # anchor-distance targets have a reference.
        is_selected = torch.rand(n_atoms) < P_MASK
        if is_selected.all():
            is_selected[torch.randint(0, n_atoms, (1,))] = False

        labels_z = torch.full((n_atoms,), -100, dtype=torch.long)
        labels_dists = torch.zeros((n_atoms, K_ANCHORS), dtype=torch.float)
        labels_dists_mask = torch.zeros((n_atoms, K_ANCHORS), dtype=torch.bool)
        labels_z[is_selected] = z[is_selected]

        # Atomic number corruption
        # BERT-style 80/10/10 split: 80% -> MASK token, 10% -> random element,
        # remaining 10% keep the original atomic number.
        z_masked = z.clone()
        if is_selected.any():
            sel_idx = torch.nonzero(is_selected).squeeze(-1)
            rand_atomic = torch.randint(1, MAX_ATOMIC_Z + 1, (sel_idx.size(0),), dtype=torch.long)

            probs = torch.rand(sel_idx.size(0))
            mask_choice = probs < 0.8
            rand_choice = (probs >= 0.8) & (probs < 0.9)

            if mask_choice.any():
                z_masked[sel_idx[mask_choice]] = MASK_ATOM_ID
            if rand_choice.any():
                z_masked[sel_idx[rand_choice]] = rand_atomic[rand_choice]

        # Coordinate corruption (noise/random position)
        # Analogous 80/10/10: 80% Gaussian noise, 10% uniform random position
        # inside the molecule's bounding box, 10% left untouched.
        pos_masked = pos.clone()
        if is_selected.any():
            sel_idx = torch.nonzero(is_selected).squeeze(-1)
            probs_c = torch.rand(sel_idx.size(0))
            noisy_choice = probs_c < 0.8
            randpos_choice = (probs_c >= 0.8) & (probs_c < 0.9)

            if noisy_choice.any():
                idx = sel_idx[noisy_choice]
                noise = torch.randn((idx.size(0), 3)) * COORD_NOISE_SIGMA
                pos_masked[idx] = pos_masked[idx] + noise

            if randpos_choice.any():
                idx = sel_idx[randpos_choice]
                mins = pos.min(dim=0).values
                maxs = pos.max(dim=0).values
                randpos = (torch.rand((idx.size(0), 3)) * (maxs - mins)) + mins
                pos_masked[idx] = randpos

        # Anchor-distance targets for masked atoms (using true positions as reference)
        visible_idx = torch.nonzero(~is_selected).squeeze(-1)
        if visible_idx.numel() == 0:
            # Defensive fallback; should not trigger given the guard above.
            visible_idx = torch.arange(n_atoms, dtype=torch.long)

        visible_pos = pos[visible_idx]
        for a in torch.nonzero(is_selected).squeeze(-1).tolist():
            # Distances from the masked atom's TRUE position to all visible
            # atoms; +1e-12 guards sqrt against exact zeros.
            dists = torch.sqrt(((pos[a].unsqueeze(0) - visible_pos) ** 2).sum(dim=1) + 1e-12)
            if dists.numel() > 0:
                k = min(K_ANCHORS, dists.numel())
                nearest_vals, _ = torch.topk(dists, k, largest=False)
                labels_dists[a, :k] = nearest_vals
                labels_dists_mask[a, :k] = True

        all_z.append(z_masked)
        all_pos.append(pos_masked)
        all_labels_z.append(labels_z)
        all_labels_dists.append(labels_dists)
        all_labels_dists_mask.append(labels_dists_mask)
        # Graph-membership vector (PyG-style 'batch' index per node).
        batch_idx.append(torch.full((n_atoms,), i, dtype=torch.long))

    if len(all_z) == 0:
        # All molecules were empty: return correctly-shaped empty tensors.
        return {
            "z": torch.tensor([], dtype=torch.long),
            "pos": torch.tensor([], dtype=torch.float).reshape(0, 3),
            "batch": torch.tensor([], dtype=torch.long),
            "labels_z": torch.tensor([], dtype=torch.long),
            "labels_dists": torch.tensor([], dtype=torch.float).reshape(0, K_ANCHORS),
            "labels_dists_mask": torch.tensor([], dtype=torch.bool).reshape(0, K_ANCHORS),
        }

    return {
        "z": torch.cat(all_z, dim=0),
        "pos": torch.cat(all_pos, dim=0),
        "batch": torch.cat(batch_idx, dim=0),
        "labels_z": torch.cat(all_labels_z, dim=0),
        "labels_dists": torch.cat(all_labels_dists, dim=0),
        "labels_dists_mask": torch.cat(all_labels_dists_mask, dim=0),
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
 
 
226
 
227
class NodeSchNet(nn.Module):
    """SchNet variant that exposes per-node embeddings instead of a pooled readout."""

    def __init__(self, hidden_channels=128, num_filters=128, num_interactions=6,
                 num_gaussians=50, cutoff=10.0, max_num_neighbors=32, readout="add"):
        super().__init__()
        # Kept on self so forward() can rebuild the radius graph per batch.
        self.hidden_channels = hidden_channels
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors
        # Wrap the stock PyG SchNet; we reuse its embedding, distance expansion
        # and interaction blocks but skip its readout/output layers.
        self.base_schnet = BaseSchNet(
            hidden_channels=hidden_channels,
            num_filters=num_filters,
            num_interactions=num_interactions,
            num_gaussians=num_gaussians,
            cutoff=cutoff,
            max_num_neighbors=max_num_neighbors,
            readout=readout,
        )

    def forward(self, z, pos, batch=None):
        """Return node embeddings of shape (num_nodes, hidden_channels)."""
        if batch is None:
            # Single-graph fallback: every atom belongs to graph 0.
            batch = torch.zeros(z.size(0), dtype=torch.long, device=z.device)

        # Initial per-atom features from the element embedding table.
        x = self.base_schnet.embedding(z)

        # Radius graph within the cutoff, then Gaussian-expanded edge distances.
        edge_index = radius_graph(
            pos, r=self.cutoff, batch=batch, max_num_neighbors=self.max_num_neighbors
        )
        src, dst = edge_index
        edge_weight = (pos[src] - pos[dst]).norm(dim=-1)
        edge_attr = self.base_schnet.distance_expansion(edge_weight)

        # Residual message-passing stack (same additive update PyG's SchNet uses).
        for block in self.base_schnet.interactions:
            x = x + block(x, edge_index, edge_weight, edge_attr)

        return x
260
+
 
 
 
 
 
 
 
 
 
 
 
261
 
 
 
 
262
class MaskedSchNet(nn.Module):
    """
    Masked pretraining objectives on top of SchNet node embeddings.

    Heads:
      - atom_head: classifies the (masked) atomic number per node.
      - coord_head: regresses distances to the K_ANCHORS nearest visible atoms.

    With labels, forward() returns a 1-element loss tensor (so HF Trainer's
    ``outputs[0]`` indexing works); without labels it returns (logits, dists_pred).
    """

    def __init__(self, hidden_channels=600, num_interactions=SCHNET_NUM_INTERACTIONS, num_gaussians=SCHNET_NUM_GAUSSIANS,
                 cutoff=SCHNET_CUTOFF, max_atomic_z=MAX_ATOMIC_Z, max_num_neighbors=SCHNET_MAX_NEIGHBORS, class_weights=None):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors
        self.max_atomic_z = max_atomic_z

        self.schnet = NodeSchNet(
            hidden_channels=hidden_channels,
            num_filters=hidden_channels,
            num_interactions=num_interactions,
            num_gaussians=num_gaussians,
            cutoff=cutoff,
            max_num_neighbors=max_num_neighbors,
        )

        # +1 so the MASK token id itself is a valid class index.
        self.atom_head = nn.Linear(hidden_channels, MASK_ATOM_ID + 1)
        self.coord_head = nn.Linear(hidden_channels, K_ANCHORS)

        # Learned homoscedastic-uncertainty weights between the two tasks
        # (Kendall et al. style: exp(-s) * L + s).
        if USE_LEARNED_WEIGHTING:
            self.log_var_z = nn.Parameter(torch.zeros(1))
            self.log_var_pos = nn.Parameter(torch.zeros(1))
        else:
            self.log_var_z = None
            self.log_var_pos = None

        if class_weights is not None:
            # Buffer so the weights follow .to(device) and are checkpointed.
            self.register_buffer("class_weights", class_weights)
        else:
            self.class_weights = None

    def forward(self, z, pos, batch, labels_z=None, labels_dists=None, labels_dists_mask=None):
        h = self.schnet(z=z, pos=pos, batch=batch)
        logits = self.atom_head(h)
        dists_pred = self.coord_head(h)

        if labels_z is not None and labels_dists is not None and labels_dists_mask is not None:
            mask = labels_z != -100
            if mask.sum() == 0:
                # No masked atoms in this batch. Return a zero loss that is
                # still connected to the graph — a bare torch.tensor(0.0) has
                # no grad_fn, so Trainer's loss.backward() would raise — and
                # keep a 1-element shape so Trainer's outputs[0] works.
                return (logits.sum() + dists_pred.sum()).reshape(1) * 0.0

            logits_masked = logits[mask]
            dists_pred_masked = dists_pred[mask]
            labels_z_masked = labels_z[mask]
            labels_dists_masked = labels_dists[mask]
            labels_dists_mask_mask = labels_dists_mask[mask]

            if self.class_weights is not None:
                loss_z = F.cross_entropy(logits_masked, labels_z_masked, weight=self.class_weights)
            else:
                loss_z = F.cross_entropy(logits_masked, labels_z_masked)

            if labels_dists_mask_mask.any():
                preds = dists_pred_masked[labels_dists_mask_mask]
                trues = labels_dists_masked[labels_dists_mask_mask]
                loss_pos = F.mse_loss(preds, trues, reduction="mean")
            else:
                loss_pos = torch.tensor(0.0, device=z.device)

            if USE_LEARNED_WEIGHTING:
                # Shape (1,) because log_var_* are 1-element parameters.
                lz = torch.exp(-self.log_var_z) * loss_z + self.log_var_z
                lp = torch.exp(-self.log_var_pos) * loss_pos + self.log_var_pos
                return 0.5 * (lz + lp)

            # Reshape to 1 element for the same Trainer outputs[0] indexing as
            # the learned-weighting branch.
            return (loss_z + loss_pos).reshape(1)

        return logits, dists_pred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
class ValLossCallback(TrainerCallback):
    """Evaluation callback: computes metrics on val_loader, saves best, early-stops on val loss."""
    def __init__(self, best_model_dir: str, val_loader: DataLoader, patience: int = 10, trainer_ref=None):
        # Best validation loss seen so far (lower is better).
        self.best_val_loss = float("inf")
        # Consecutive evaluations without improvement (drives early stopping).
        self.epochs_no_improve = 0
        self.patience = patience
        # 0-based epoch index of the best checkpoint; None until one is saved.
        self.best_epoch = None
        # Set after Trainer construction so we can reach the live model.
        self.trainer_ref = trainer_ref
        self.best_model_dir = best_model_dir
        self.val_loader = val_loader

    def on_epoch_end(self, args, state, control, **kwargs):
        # NOTE(review): here the epoch is printed 0-based via int(state.epoch),
        # while on_evaluate prints int(state.epoch) + 1 — confirm intended.
        epoch_num = int(state.epoch)
        # Most recent training loss entry from the Trainer's log history, if any.
        train_loss = next((x["loss"] for x in reversed(state.log_history) if "loss" in x), None)
        print(f"\n=== Epoch {epoch_num}/{args.num_train_epochs} ===")
        if train_loss is not None:
            print(f"Train Loss: {train_loss:.4f}")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Run a manual metrics pass over val_loader, save best model, early-stop."""
        epoch_num = int(state.epoch) + 1
        if self.trainer_ref is None:
            print(f"[Eval] Epoch {epoch_num} - metrics (trainer_ref missing): {metrics}")
            return

        # Prefer the Trainer-computed eval_loss; fall back to our own average.
        metric_val_loss = metrics.get("eval_loss") if metrics is not None else None
        model_eval = self.trainer_ref.model
        model_eval.eval()
        device_local = next(model_eval.parameters()).device

        preds_z_all, true_z_all = [], []
        pred_dists_all, true_dists_all = [], []
        total_loss, n_batches = 0.0, 0
        logits_masked_list, labels_masked_list = [], []

        with torch.no_grad():
            for batch in self.val_loader:
                z = batch["z"].to(device_local)
                pos = batch["pos"].to(device_local)
                batch_idx = batch["batch"].to(device_local)
                labels_z = batch["labels_z"].to(device_local)
                labels_dists = batch["labels_dists"].to(device_local)
                labels_dists_mask = batch["labels_dists_mask"].to(device_local)

                # Best-effort loss computation; a failing batch is skipped
                # rather than aborting the whole evaluation.
                try:
                    loss = model_eval(z, pos, batch_idx, labels_z, labels_dists, labels_dists_mask)
                except Exception:
                    loss = None

                if isinstance(loss, torch.Tensor):
                    total_loss += loss.item()
                    n_batches += 1

                # Second forward without labels to get raw predictions.
                logits, dists_pred = model_eval(z, pos, batch_idx)

                mask = labels_z != -100
                if mask.sum().item() == 0:
                    continue

                logits_masked_list.append(logits[mask])
                labels_masked_list.append(labels_z[mask])

                pred_z = torch.argmax(logits[mask], dim=-1)
                true_z = labels_z[mask]

                # Distance predictions only where the anchor slot is valid.
                pred_d = dists_pred[mask][labels_dists_mask[mask]]
                true_d = labels_dists[mask][labels_dists_mask[mask]]

                if pred_d.numel() > 0:
                    pred_dists_all.extend(pred_d.cpu().tolist())
                    true_dists_all.extend(true_d.cpu().tolist())

                preds_z_all.extend(pred_z.cpu().tolist())
                true_z_all.extend(true_z.cpu().tolist())

        avg_val_loss = metric_val_loss if metric_val_loss is not None else ((total_loss / n_batches) if n_batches > 0 else float("nan"))

        accuracy = accuracy_score(true_z_all, preds_z_all) if len(true_z_all) > 0 else 0.0
        f1 = f1_score(true_z_all, preds_z_all, average="weighted") if len(true_z_all) > 0 else 0.0
        rmse = np.sqrt(mean_squared_error(true_dists_all, pred_dists_all)) if len(true_dists_all) > 0 else 0.0
        mae = mean_absolute_error(true_dists_all, pred_dists_all) if len(true_dists_all) > 0 else 0.0

        # Perplexity of the atomic-number classification head over all masked atoms.
        if len(logits_masked_list) > 0:
            all_logits_masked = torch.cat(logits_masked_list, dim=0)
            all_labels_masked = torch.cat(labels_masked_list, dim=0)
            cw = getattr(model_eval, "class_weights", None)
            if cw is not None:
                try:
                    loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked, weight=cw.to(device_local))
                except Exception:
                    # Fall back to unweighted CE if the weight tensor is unusable.
                    loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked)
            else:
                loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked)
            try:
                perplexity = float(torch.exp(loss_z_all).cpu().item())
            except Exception:
                perplexity = float(np.exp(float(loss_z_all.cpu().item())))
        else:
            perplexity = float("nan")

        print(f"\n--- Evaluation after Epoch {epoch_num} ---")
        print(f"Validation Loss: {avg_val_loss:.4f}")
        print(f"Validation Accuracy: {accuracy:.4f}")
        print(f"Validation F1 (weighted): {f1:.4f}")
        print(f"Validation RMSE (distances): {rmse:.4f}")
        print(f"Validation MAE (distances): {mae:.4f}")
        print(f"Validation Perplexity (classification head): {perplexity:.4f}")

        # Improvement requires beating the best loss by a small epsilon.
        if avg_val_loss is not None and not (isinstance(avg_val_loss, float) and np.isnan(avg_val_loss)) and avg_val_loss < self.best_val_loss - 1e-6:
            self.best_val_loss = avg_val_loss
            self.best_epoch = int(state.epoch)
            self.epochs_no_improve = 0
            os.makedirs(self.best_model_dir, exist_ok=True)
            try:
                torch.save(self.trainer_ref.model.state_dict(), os.path.join(self.best_model_dir, "pytorch_model.bin"))
                print(f"Saved new best model (epoch {epoch_num}) to {os.path.join(self.best_model_dir, 'pytorch_model.bin')}")
            except Exception as e:
                print(f"Failed to save best model at epoch {epoch_num}: {e}")
        else:
            self.epochs_no_improve += 1

        if self.epochs_no_improve >= self.patience:
            print(f"Early stopping after {self.patience} epochs with no improvement.")
            control.should_training_stop = True
457
+
458
+
459
def train_and_eval(args: argparse.Namespace) -> None:
    """
    Full pipeline: load geometry data, train MaskedSchNet with HF Trainer,
    reload the best checkpoint saved by ValLossCallback, and print final
    validation metrics and parameter counts.
    """
    output_dir = args.output_dir
    best_model_dir = os.path.join(output_dir, "best")
    os.makedirs(output_dir, exist_ok=True)

    atomic_lists, coord_lists = load_geometry_from_csv(args.csv_path, args.target_rows, args.chunksize)

    # 80/20 split on molecule indices; fixed seed for reproducibility.
    train_idx, val_idx = train_test_split(list(range(len(atomic_lists))), test_size=0.2, random_state=42)
    train_z = [torch.tensor(atomic_lists[i], dtype=torch.long) for i in train_idx]
    train_pos = [torch.tensor(coord_lists[i], dtype=torch.float) for i in train_idx]
    val_z = [torch.tensor(atomic_lists[i], dtype=torch.long) for i in val_idx]
    val_pos = [torch.tensor(coord_lists[i], dtype=torch.float) for i in val_idx]

    # Class weights are computed on the training split only.
    class_weights = compute_class_weights(train_z)

    train_dataset = PolymerDataset(train_z, train_pos)
    val_dataset = PolymerDataset(val_z, val_pos)

    # val_loader is used by the callback and the final-eval loop below; the
    # Trainer builds its own loaders from the datasets.
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_batch, num_workers=args.num_workers)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_batch, num_workers=args.num_workers)

    model = MaskedSchNet(
        hidden_channels=600,
        num_interactions=SCHNET_NUM_INTERACTIONS,
        num_gaussians=SCHNET_NUM_GAUSSIANS,
        cutoff=SCHNET_CUTOFF,
        max_atomic_z=MAX_ATOMIC_Z,
        max_num_neighbors=SCHNET_MAX_NEIGHBORS,
        class_weights=class_weights,
    )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=25,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=4,
        # NOTE(review): 'eval_strategy' is the newer transformers spelling
        # (older releases use 'evaluation_strategy') — confirm pinned version.
        eval_strategy="epoch",
        logging_steps=500,
        learning_rate=1e-4,
        weight_decay=0.01,
        fp16=torch.cuda.is_available(),
        # Checkpointing is handled by ValLossCallback, not the Trainer.
        save_strategy="no",
        disable_tqdm=False,
        logging_first_step=True,
        report_to=[],
        dataloader_num_workers=args.num_workers,
    )

    callback = ValLossCallback(best_model_dir=best_model_dir, val_loader=val_loader, patience=10)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collate_batch,
        callbacks=[callback],
    )
    # Give the callback a handle to the Trainer (needed to reach the model).
    callback.trainer_ref = trainer

    start_time = time.time()
    trainer.train()
    total_time = time.time() - start_time

    # Restore the best checkpoint saved by the callback, if any.
    best_model_path = os.path.join(best_model_dir, "pytorch_model.bin")
    if os.path.exists(best_model_path):
        try:
            model.load_state_dict(torch.load(best_model_path, map_location=device))
            print(f"\nLoaded best model from {best_model_path}")
        except Exception as e:
            print(f"\nFailed to load best model from {best_model_path}: {e}")

    # Final evaluation
    model.eval()
    preds_z_all, true_z_all = [], []
    pred_dists_all, true_dists_all = [], []
    logits_masked_list_final, labels_masked_list_final = [], []

    with torch.no_grad():
        for batch in val_loader:
            z = batch["z"].to(device)
            pos = batch["pos"].to(device)
            batch_idx = batch["batch"].to(device)
            labels_z = batch["labels_z"].to(device)
            labels_dists = batch["labels_dists"].to(device)
            labels_dists_mask = batch["labels_dists_mask"].to(device)

            logits, dists_pred = model(z, pos, batch_idx)

            # -100 marks unmasked atoms; skip batches with nothing masked.
            mask = labels_z != -100
            if mask.sum().item() == 0:
                continue

            logits_masked_list_final.append(logits[mask])
            labels_masked_list_final.append(labels_z[mask])

            pred_z = torch.argmax(logits[mask], dim=-1)
            true_z = labels_z[mask]

            pred_d = dists_pred[mask][labels_dists_mask[mask]]
            true_d = labels_dists[mask][labels_dists_mask[mask]]

            if pred_d.numel() > 0:
                pred_dists_all.extend(pred_d.cpu().tolist())
                true_dists_all.extend(true_d.cpu().tolist())

            preds_z_all.extend(pred_z.cpu().tolist())
            true_z_all.extend(true_z.cpu().tolist())

    accuracy = accuracy_score(true_z_all, preds_z_all) if len(true_z_all) > 0 else 0.0
    f1 = f1_score(true_z_all, preds_z_all, average="weighted") if len(true_z_all) > 0 else 0.0
    rmse = np.sqrt(mean_squared_error(true_dists_all, pred_dists_all)) if len(true_dists_all) > 0 else 0.0
    mae = mean_absolute_error(true_dists_all, pred_dists_all) if len(true_dists_all) > 0 else 0.0

    # Perplexity of the classification head over all masked atoms.
    if len(logits_masked_list_final) > 0:
        all_logits_masked_final = torch.cat(logits_masked_list_final, dim=0)
        all_labels_masked_final = torch.cat(labels_masked_list_final, dim=0)
        cw_final = getattr(model, "class_weights", None)
        if cw_final is not None:
            try:
                loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final, weight=cw_final.to(device))
            except Exception:
                loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final)
        else:
            loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final)
        try:
            perplexity_final = float(torch.exp(loss_z_final).cpu().item())
        except Exception:
            perplexity_final = float(np.exp(float(loss_z_final.cpu().item())))
    else:
        perplexity_final = float("nan")

    best_val_loss = callback.best_val_loss if hasattr(callback, "best_val_loss") else float("nan")
    # best_epoch is stored 0-based; report 1-based.
    best_epoch_num = (int(callback.best_epoch) + 1) if callback.best_epoch is not None else None

    print(f"\n=== Final Results (evaluated on best saved model) ===")
    print(f"Total Training Time (s): {total_time:.2f}")
    print(f"Best Epoch (1-based): {best_epoch_num}" if best_epoch_num is not None else "Best Epoch: (none saved)")
    print(f"Best Validation Loss: {best_val_loss:.4f}")
    print(f"Validation Accuracy: {accuracy:.4f}")
    print(f"Validation F1 (weighted): {f1:.4f}")
    print(f"Validation RMSE (distances): {rmse:.4f}")
    print(f"Validation MAE (distances): {mae:.4f}")
    print(f"Validation Perplexity (classification head): {perplexity_final:.4f}")

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = total_params - trainable_params
    print(f"Total Parameters: {total_params}")
    print(f"Trainable Parameters: {trainable_params}")
    print(f"Non-trainable Parameters: {non_trainable_params}")
614
+
615
+
616
def main():
    """CLI entry point: parse arguments, then run training and evaluation."""
    train_and_eval(parse_args())
619
+
620
+
621
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()