manpreet88 commited on
Commit
08a251c
·
1 Parent(s): 3a0d11d

Create SchNet.py

Browse files
Files changed (1) hide show
  1. PolyFusion/SchNet.py +737 -0
PolyFusion/SchNet.py ADDED
@@ -0,0 +1,737 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ import shutil
5
+
6
+ import sys
7
+ import csv
8
+
9
+ # Increase max CSV field size limit
10
+ csv.field_size_limit(sys.maxsize)
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+ import pandas as pd
17
+ from sklearn.model_selection import train_test_split
18
+ from torch.utils.data import Dataset, DataLoader
19
+
20
+ # PyG (SchNet)
21
+ from torch_geometric.nn import SchNet
22
+
23
+ from transformers import TrainingArguments, Trainer
24
+ from transformers.trainer_callback import TrainerCallback
25
+ from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, mean_absolute_error
26
+ from torch_geometric.nn import radius_graph
27
+
28
+ # ---------------------------
29
+ # Configuration / Constants
30
+ # ---------------------------
31
+ P_MASK = 0.15
32
+ # NOTE: do NOT infer max atomic number from the dataset; set it manually as requested.
33
+ # "At" (Astatine) atomic number = 85 — change this value if your actual maximum differs.
34
+ MAX_ATOMIC_Z = 85
35
+
36
+ # Use a dedicated MASK token index (not 0). We'll place it after the max atomic number.
37
+ MASK_ATOM_ID = MAX_ATOMIC_Z + 1
38
+
39
+ COORD_NOISE_SIGMA = 0.5 # Å (start value, can tune)
40
+ USE_LEARNED_WEIGHTING = True
41
+
42
+ # SchNet hyperparams requested by user:
43
+ SCHNET_NUM_GAUSSIANS = 50
44
+ SCHNET_NUM_INTERACTIONS = 6
45
+ SCHNET_CUTOFF = 10.0 # Å
46
+ SCHNET_MAX_NEIGHBORS = 64
47
+
48
+ # Number of anchor atoms to predict distances to (invariant objective)
49
+ K_ANCHORS = 6
50
+
51
+ # Output directory
52
+ OUTPUT_DIR = "./schnet_output_5M"
53
+ BEST_MODEL_DIR = os.path.join(OUTPUT_DIR, "best")
54
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
55
+
56
+ # ---------------------------
57
+ # 1. Load Data (chunked to avoid OOM)
58
+ # ---------------------------
59
+ csv_path = "Polymer_Foundational_Model/Datasets/polymer_structures_unified_processed.csv"
60
+ # target max rows to read (you previously used nrows=2000000)
61
+ TARGET_ROWS = 5000000
62
+ # choose a chunksize that fits your memory; adjust if needed
63
+ CHUNKSIZE = 50000
64
+
65
+ atomic_lists = []
66
+ coord_lists = []
67
+ rows_read = 0
68
+
69
+ # Read in chunks and parse geometry JSON for each chunk to avoid OOM
70
+ for chunk in pd.read_csv(csv_path, engine="python", chunksize=CHUNKSIZE):
71
+   # parse geometry column (JSON strings) in this chunk
72
+   geoms_chunk = chunk["geometry"].apply(json.loads)
73
+   for geom in geoms_chunk:
74
+   conf = geom["best_conformer"]
75
+   atomic_lists.append(conf["atomic_numbers"])
76
+   coord_lists.append(conf["coordinates"])
77
+
78
+   rows_read += len(chunk)
79
+   if rows_read >= TARGET_ROWS:
80
+   break
81
+
82
+ # Use manual maximum atomic number (do not compute from data)
83
+ max_atomic_z = MAX_ATOMIC_Z
84
+ print(f"Using manual max atomic number: {max_atomic_z} (MASK_ATOM_ID={MASK_ATOM_ID})")
85
+
86
+ # ---------------------------
87
+ # 2. Train/Val Split
88
+ # ---------------------------
89
+ train_idx, val_idx = train_test_split(list(range(len(atomic_lists))), test_size=0.2, random_state=42)
90
+ train_z = [torch.tensor(atomic_lists[i], dtype=torch.long) for i in train_idx]
91
+ train_pos = [torch.tensor(coord_lists[i], dtype=torch.float) for i in train_idx]
92
+ val_z = [torch.tensor(atomic_lists[i], dtype=torch.long) for i in val_idx]
93
+ val_pos = [torch.tensor(coord_lists[i], dtype=torch.float) for i in val_idx]
94
+
95
+ # ---------------------------
96
+ # Compute class weights (for weighted CE to mitigate element imbalance)
97
+ # ---------------------------
98
+ # We create weights for classes [0 .. max_atomic_z, MASK_ATOM_ID] where most labels will be in 1..max_atomic_z.
99
+ num_classes = MASK_ATOM_ID + 1 # (0 unused for typical atomic numbers; mask token at end)
100
+ counts = np.ones((num_classes,), dtype=np.float64) # init with 1 to avoid zero division
101
+
102
+ for z in train_z:
103
+   if z.numel() > 0:
104
+   vals = z.cpu().numpy().astype(int)
105
+   for v in vals:
106
+   if 0 <= v < num_classes:
107
+   counts[v] += 1.0
108
+
109
+ # Inverse frequency (normalized to mean 1.0)
110
+ freq = counts / counts.sum()
111
+ inv_freq = 1.0 / (freq + 1e-12)
112
+ class_weights = inv_freq / inv_freq.mean()
113
+ class_weights = torch.tensor(class_weights, dtype=torch.float)
114
+
115
+ # Set MASK token weight to 1.0 (it is not used as target in labels_z)
116
+ class_weights[MASK_ATOM_ID] = 1.0
117
+
118
+ # ---------------------------
119
+ # 3. Dataset and Collator
120
+ # ---------------------------
121
+ class PolymerDataset(Dataset):
122
+   def __init__(self, zs, pos_list):
123
+   self.zs = zs
124
+   self.pos_list = pos_list
125
+
126
+   def __len__(self):
127
+   return len(self.zs)
128
+
129
+   def __getitem__(self, idx):
130
+   return {"z": self.zs[idx], "pos": self.pos_list[idx]}
131
+
132
+ def collate_batch(batch):
133
+   """
134
+   Masking + create invariant distance targets:
135
+   - Select atoms for masking (P_MASK).
136
+   - For atomic numbers: 80/10/10 BERT-style corruption. Use MASK_ATOM_ID for mask token.
137
+   - For distances: for each masked atom, compute true distances to up to K_ANCHORS visible atoms
138
+   (nearest visible anchors). Produce labels_dists [N, K_ANCHORS] and anchors_exists mask [N, K_ANCHORS].
139
+   - Return labels_z (atomic targets, -100 for unselected) and labels_dists (+ anchors mask).
140
+   """
141
+ �� all_z = []
142
+   all_pos = []
143
+   all_labels_z = []
144
+   all_labels_dists = []
145
+   all_labels_dists_mask = []
146
+   batch_idx = []
147
+
148
+   for i, data in enumerate(batch):
149
+   z = data["z"] # [n_atoms]
150
+   pos = data["pos"] # [n_atoms,3]
151
+   n_atoms = z.size(0)
152
+   if n_atoms == 0:
153
+   continue
154
+
155
+   # 1) choose which atoms are selected for masking (15%)
156
+   is_selected = torch.rand(n_atoms) < P_MASK
157
+
158
+   # ensure not ALL atoms are selected (we need some visible anchors)
159
+   if is_selected.all():
160
+   # set one random atom to unselected
161
+   is_selected[torch.randint(0, n_atoms, (1,))] = False
162
+
163
+   # Prepare labels (only for selected atoms)
164
+   labels_z = torch.full((n_atoms,), -100, dtype=torch.long) # -100 ignored by CE
165
+   # labels_dists: per-atom K distances (0 padded) and mask indicating valid anchors
166
+   labels_dists = torch.zeros((n_atoms, K_ANCHORS), dtype=torch.float)
167
+   labels_dists_mask = torch.zeros((n_atoms, K_ANCHORS), dtype=torch.bool)
168
+
169
+   labels_z[is_selected] = z[is_selected] # true atomic numbers for selecteds
170
+
171
+   # 2) apply BERT-style corruption for atomic numbers
172
+   z_masked = z.clone()
173
+   if is_selected.any():
174
+   sel_idx = torch.nonzero(is_selected).squeeze(-1)
175
+   # sample random atomic numbers from 1..max_atomic_z (avoid 0 which is often unused)
176
+   rand_atomic = torch.randint(1, max_atomic_z + 1, (sel_idx.size(0),), dtype=torch.long)
177
+
178
+   probs = torch.rand(sel_idx.size(0))
179
+   mask_choice = probs < 0.8
180
+   rand_choice = (probs >= 0.8) & (probs < 0.9)
181
+   # keep_choice = probs >= 0.9
182
+
183
+   if mask_choice.any():
184
+   z_masked[sel_idx[mask_choice]] = MASK_ATOM_ID
185
+   if rand_choice.any():
186
+   z_masked[sel_idx[rand_choice]] = rand_atomic[rand_choice]
187
+   # 10% keep => do nothing
188
+
189
+   # 3) coordinate corruption for selected atoms (we still corrupt positions for training robust embeddings)
190
+   pos_masked = pos.clone()
191
+   if is_selected.any():
192
+   sel_idx = torch.nonzero(is_selected).squeeze(-1)
193
+   probs_c = torch.rand(sel_idx.size(0))
194
+   noisy_choice = probs_c < 0.8
195
+   randpos_choice = (probs_c >= 0.8) & (probs_c < 0.9)
196
+
197
+   if noisy_choice.any():
198
+   idx = sel_idx[noisy_choice]
199
+   noise = torch.randn((idx.size(0), 3)) * COORD_NOISE_SIGMA
200
+   pos_masked[idx] = pos_masked[idx] + noise
201
+
202
+   if randpos_choice.any():
203
+   idx = sel_idx[randpos_choice]
204
+   mins = pos.min(dim=0).values
205
+   maxs = pos.max(dim=0).values
206
+   randpos = (torch.rand((idx.size(0), 3)) * (maxs - mins)) + mins
207
+   pos_masked[idx] = randpos
208
+
209
+   # 4) Build invariant distance targets for masked atoms:
210
+   visible_idx = torch.nonzero(~is_selected).squeeze(-1)
211
+   # If for some reason no visible (shouldn't happen due to earlier guard), fall back to all atoms as visible
212
+   if visible_idx.numel() == 0:
213
+   visible_idx = torch.arange(n_atoms, dtype=torch.long)
214
+
215
+   # Precompute pairwise distances
216
+   # pos: [n_atoms,3], visible_pos: [V,3]
217
+   visible_pos = pos[visible_idx] # true positions for anchors
218
+   for a in torch.nonzero(is_selected).squeeze(-1).tolist():
219
+   # distances from atom a to all visible anchors
220
+   dists = torch.sqrt(((pos[a].unsqueeze(0) - visible_pos) ** 2).sum(dim=1) + 1e-12)
221
+   # find nearest anchors (ascending)
222
+   if dists.numel() > 0:
223
+   k = min(K_ANCHORS, dists.numel())
224
+   nearest_vals, nearest_idx = torch.topk(dists, k, largest=False)
225
+   labels_dists[a, :k] = nearest_vals
226
+   labels_dists_mask[a, :k] = True
227
+   # else leave zeros and mask False
228
+
229
+   all_z.append(z_masked)
230
+   all_pos.append(pos_masked)
231
+   all_labels_z.append(labels_z)
232
+   all_labels_dists.append(labels_dists)
233
+   all_labels_dists_mask.append(labels_dists_mask)
234
+   batch_idx.append(torch.full((n_atoms,), i, dtype=torch.long))
235
+
236
+   if len(all_z) == 0:
237
+   return {"z": torch.tensor([], dtype=torch.long),
238
+   "pos": torch.tensor([], dtype=torch.float).reshape(0, 3),
239
+   "batch": torch.tensor([], dtype=torch.long),
240
+   "labels_z": torch.tensor([], dtype=torch.long),
241
+   "labels_dists": torch.tensor([], dtype=torch.float).reshape(0, K_ANCHORS),
242
+   "labels_dists_mask": torch.tensor([], dtype=torch.bool).reshape(0, K_ANCHORS)}
243
+
244
+   z_batch = torch.cat(all_z, dim=0)
245
+   pos_batch = torch.cat(all_pos, dim=0)
246
+   labels_z_batch = torch.cat(all_labels_z, dim=0)
247
+   labels_dists_batch = torch.cat(all_labels_dists, dim=0)
248
+   labels_dists_mask_batch = torch.cat(all_labels_dists_mask, dim=0)
249
+   batch_batch = torch.cat(batch_idx, dim=0)
250
+
251
+   return {"z": z_batch, "pos": pos_batch, "batch": batch_batch,
252
+   "labels_z": labels_z_batch,
253
+   "labels_dists": labels_dists_batch,
254
+   "labels_dists_mask": labels_dists_mask_batch}
255
+
256
+ train_dataset = PolymerDataset(train_z, train_pos)
257
+ val_dataset = PolymerDataset(val_z, val_pos)
258
+ train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_batch)
259
+ val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_batch)
260
+
261
+ from torch_geometric.nn import SchNet as BaseSchNet
262
+ from torch_geometric.nn import radius_graph
263
+
264
+ class NodeSchNet(nn.Module):
265
+   """Custom SchNet that returns node embeddings instead of graph-level predictions"""
266
+
267
+   def __init__(self, hidden_channels=128, num_filters=128, num_interactions=6,
268
+   num_gaussians=50, cutoff=10.0, max_num_neighbors=32, readout='add'):
269
+   super().__init__()
270
+
271
+   self.hidden_channels = hidden_channels
272
+   self.cutoff = cutoff
273
+   self.max_num_neighbors = max_num_neighbors
274
+
275
+   # Initialize the base SchNet but we'll only use parts of it
276
+   self.base_schnet = BaseSchNet(
277
+   hidden_channels=hidden_channels,
278
+   num_filters=num_filters,
279
+   num_interactions=num_interactions,
280
+   num_gaussians=num_gaussians,
281
+   cutoff=cutoff,
282
+   max_num_neighbors=max_num_neighbors,
283
+   readout=readout
284
+   )
285
+
286
+   def forward(self, z, pos, batch=None):
287
+   """Return node embeddings, not graph-level predictions"""
288
+   if batch is None:
289
+   batch = torch.zeros(z.size(0), dtype=torch.long, device=z.device)
290
+
291
+   # Use the embedding and interaction layers from base SchNet
292
+   h = self.base_schnet.embedding(z)
293
+
294
+   # Build edge connectivity
295
+   edge_index = radius_graph(pos, r=self.cutoff, batch=batch,
296
+   max_num_neighbors=self.max_num_neighbors)
297
+
298
+   # Compute edge distances and expand with Gaussians
299
+   row, col = edge_index
300
+   edge_weight = (pos[row] - pos[col]).norm(dim=-1)
301
+   edge_attr = self.base_schnet.distance_expansion(edge_weight)
302
+
303
+   # Apply interaction blocks (message passing)
304
+   for interaction in self.base_schnet.interactions:
305
+   h = h + interaction(h, edge_index, edge_weight, edge_attr)
306
+
307
+   # STOP HERE - return node embeddings, don't do readout/final layers
308
+   return h # Shape: [num_nodes, hidden_channels]
309
+
310
+ # ---------------------------
311
+ # 4. Model Definition (SchNet + two heads + learned weighting)
312
+ # ---------------------------
313
+ class MaskedSchNet(nn.Module):
314
+   def __init__(self,
315
+   hidden_channels=600,
316
+   num_interactions=SCHNET_NUM_INTERACTIONS,
317
+   num_gaussians=SCHNET_NUM_GAUSSIANS,
318
+   cutoff=SCHNET_CUTOFF,
319
+   max_atomic_z=max_atomic_z,
320
+   max_num_neighbors=SCHNET_MAX_NEIGHBORS,
321
+   class_weights=None):
322
+   super().__init__()
323
+   self.hidden_channels = hidden_channels
324
+   self.cutoff = cutoff
325
+   self.max_num_neighbors = max_num_neighbors
326
+   self.max_atomic_z = max_atomic_z
327
+
328
+   # SchNet model from PyG
329
+   self.schnet = NodeSchNet(
330
+   hidden_channels=hidden_channels,
331
+   num_filters=hidden_channels,
332
+   num_interactions=num_interactions,
333
+   num_gaussians=num_gaussians,
334
+   cutoff=cutoff,
335
+   max_num_neighbors=max_num_neighbors
336
+   )
337
+
338
+   # Classification head for atomic number (classes 0..max_atomic_z and MASK token)
339
+   num_classes_local = MASK_ATOM_ID + 1
340
+   self.atom_head = nn.Linear(hidden_channels, num_classes_local)
341
+
342
+   # Distance-prediction head (predict K_ANCHORS scalar distances per node) -> invariant target
343
+   self.coord_head = nn.Linear(hidden_channels, K_ANCHORS)
344
+
345
+   # Learned uncertainty weighting (log-variances) if enabled
346
+   if USE_LEARNED_WEIGHTING:
347
+   self.log_var_z = nn.Parameter(torch.zeros(1))
348
+   self.log_var_pos = nn.Parameter(torch.zeros(1))
349
+   else:
350
+   self.log_var_z = None
351
+   self.log_var_pos = None
352
+
353
+   # Class weights for cross entropy
354
+   if class_weights is not None:
355
+   # register as buffer so it moves with .to(device)
356
+   self.register_buffer("class_weights", class_weights)
357
+   else:
358
+   self.class_weights = None
359
+
360
+   def forward(self, z, pos, batch, labels_z=None, labels_dists=None, labels_dists_mask=None):
361
+   """
362
+   z: [N] long (atomic numbers or MASK_ATOM_ID)
363
+   pos: [N,3] float (possibly corrupted)
364
+   batch: [N] long (graph indices)
365
+   labels_z: [N] long (-100 for unselected)
366
+   labels_dists: [N, K_ANCHORS] float (0 padded)
367
+   labels_dists_mask: [N, K_ANCHORS] bool (True where anchor exists)
368
+   """
369
+   # Let SchNet produce node embeddings. SchNet builds its own neighbor graph internally.
370
+   # SchNet's forward often accepts (z, pos, batch)
371
+   try:
372
+   h = self.schnet(z=z, pos=pos, batch=batch)
373
+   except TypeError:
374
+   # fallback if different signature
375
+   h = self.schnet(z=z, pos=pos)
376
+
377
+   # Node embeddings
378
+   logits = self.atom_head(h) # [N, num_classes]
379
+   dists_pred = self.coord_head(h) # [N, K_ANCHORS]
380
+
381
+   # If labels provided -> compute loss (aggregated only over masked atoms)
382
+   if labels_z is not None and labels_dists is not None and labels_dists_mask is not None:
383
+   mask = labels_z != -100 # which atoms were selected for supervision
384
+   if mask.sum() == 0:
385
+   # Nothing masked in this batch: return zero loss (avoid NaNs)
386
+   return torch.tensor(0.0, device=z.device)
387
+
388
+   logits_masked = logits[mask] # [M, num_classes]
389
+   dists_pred_masked = dists_pred[mask] # [M, K_ANCHORS]
390
+   labels_z_masked = labels_z[mask] # [M]
391
+   labels_dists_masked = labels_dists[mask] # [M, K_ANCHORS]
392
+   labels_dists_mask_mask = labels_dists_mask[mask] # [M, K_ANCHORS] bool
393
+
394
+   # classification loss (weighted cross entropy)
395
+   if self.class_weights is not None:
396
+   loss_z = F.cross_entropy(logits_masked, labels_z_masked, weight=self.class_weights)
397
+   else:
398
+   loss_z = F.cross_entropy(logits_masked, labels_z_masked)
399
+
400
+   # coordinate/distance loss: only over existing anchor distances
401
+   # flatten valid entries
402
+   if labels_dists_mask_mask.any():
403
+   preds = dists_pred_masked[labels_dists_mask_mask]
404
+   trues = labels_dists_masked[labels_dists_mask_mask]
405
+   loss_pos = F.mse_loss(preds, trues, reduction="mean")
406
+   else:
407
+   # no anchor distances present (shouldn't happen), set zero
408
+   loss_pos = torch.tensor(0.0, device=z.device)
409
+
410
+   if USE_LEARNED_WEIGHTING:
411
+   lz = torch.exp(-self.log_var_z) * loss_z + self.log_var_z
412
+   lp = torch.exp(-self.log_var_pos) * loss_pos + self.log_var_pos
413
+   loss = 0.5 * (lz + lp)
414
+   else:
415
+   alpha = 1.0
416
+   loss = loss_z + alpha * loss_pos
417
+
418
+   return loss
419
+
420
+   # Inference: return logits and predicted distances
421
+   return logits, dists_pred
422
+
423
+ # instantiate model with requested SchNet params and computed class weights
424
+ model = MaskedSchNet(hidden_channels=600,
425
+   num_interactions=SCHNET_NUM_INTERACTIONS,
426
+   num_gaussians=SCHNET_NUM_GAUSSIANS,
427
+   cutoff=SCHNET_CUTOFF,
428
+   max_atomic_z=max_atomic_z,
429
+   max_num_neighbors=SCHNET_MAX_NEIGHBORS,
430
+   class_weights=class_weights)
431
+
432
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
433
+ model.to(device)
434
+
435
+ # ---------------------------
436
+ # 5. Training Setup (Hugging Face Trainer)
437
+ # ---------------------------
438
+ training_args = TrainingArguments(
439
+   output_dir=OUTPUT_DIR,
440
+   overwrite_output_dir=True,
441
+   num_train_epochs=25,
442
+   per_device_train_batch_size=16,
443
+   per_device_eval_batch_size=8,
444
+   gradient_accumulation_steps=4,
445
+   eval_strategy="epoch",
446
+   logging_steps=500,
447
+   learning_rate=1e-4,
448
+   weight_decay=0.01,
449
+   fp16=torch.cuda.is_available(),
450
+   save_strategy="no", # we will let callback save best model
451
+   disable_tqdm=False,
452
+   logging_first_step=True,
453
+   report_to=[],
454
+   dataloader_num_workers=4,
455
+ )
456
+
457
+ class ValLossCallback(TrainerCallback):
458
+   def __init__(self, trainer_ref=None):
459
+   self.best_val_loss = float("inf")
460
+   self.epochs_no_improve = 0
461
+   self.patience = 10
462
+   self.best_epoch = None
463
+   self.trainer_ref = trainer_ref
464
+
465
+   def on_epoch_end(self, args, state, control, **kwargs):
466
+   # Print epoch starting from 1 instead of 0
467
+   epoch_num = int(state.epoch)
468
+   train_loss = next((x["loss"] for x in reversed(state.log_history) if "loss" in x), None)
469
+   print(f"\n=== Epoch {epoch_num}/{args.num_train_epochs} ===")
470
+   if train_loss is not None:
471
+   print(f"Train Loss: {train_loss:.4f}")
472
+
473
+   def on_evaluate(self, args, state, control, metrics=None, **kwargs):
474
+   """
475
+   When trainer runs evaluation, compute full validation metrics here (accuracy, f1, rmse, mae, perplexity)
476
+   using the provided val_loader and the trainer's model. Save the model when val_loss improves.
477
+   NOTE: Validation loss printed and used for the best-model decision is taken from the `metrics`
478
+   object provided by the Trainer when available, so it matches the Trainer's evaluation output.
479
+   """
480
+   # Compute epoch number for printing (1-based)
481
+   epoch_num = int(state.epoch) + 1
482
+
483
+   # If we don't have a trainer reference or val_loader, fallback to printing whatever metrics provided
484
+   if self.trainer_ref is None:
485
+   print(f"[Eval] Epoch {epoch_num} - metrics (trainer_ref missing): {metrics}")
486
+   return
487
+
488
+   # If trainer provided an eval_loss in metrics, prefer that value for printing and best-model decision
489
+   metric_val_loss = None
490
+   if metrics is not None:
491
+   metric_val_loss = metrics.get("eval_loss")
492
+
493
+   # Evaluate over val_loader to compute other metrics (accuracy, f1, rmse, mae, perplexity)
494
+   model_eval = self.trainer_ref.model
495
+   model_eval.eval()
496
+
497
+   device_local = next(model_eval.parameters()).device if any(p.numel() > 0 for p in model_eval.parameters()) else torch.device("cpu")
498
+
499
+   preds_z_all = []
500
+   true_z_all = []
501
+   pred_dists_all = []
502
+   true_dists_all = []
503
+   total_loss = 0.0
504
+   n_batches = 0
505
+
506
+   logits_masked_list = []
507
+   labels_masked_list = []
508
+
509
+   with torch.no_grad():
510
+   for batch in val_loader:
511
+   z = batch["z"].to(device_local)
512
+   pos = batch["pos"].to(device_local)
513
+   batch_idx = batch["batch"].to(device_local)
514
+   labels_z = batch["labels_z"].to(device_local)
515
+   labels_dists = batch["labels_dists"].to(device_local)
516
+   labels_dists_mask = batch["labels_dists_mask"].to(device_local)
517
+
518
+   # compute loss using labels (model returns loss when labels provided)
519
+   try:
520
+   loss = model_eval(z, pos, batch_idx, labels_z, labels_dists, labels_dists_mask)
521
+   except Exception as e:
522
+   # If model.forward signature is different, skip loss accumulation but still compute preds
523
+   loss = None
524
+
525
+   if isinstance(loss, torch.Tensor):
526
+   total_loss += loss.item()
527
+   n_batches += 1
528
+
529
+   # inference to get logits and distance preds
530
+   logits, dists_pred = model_eval(z, pos, batch_idx)
531
+
532
+   mask = labels_z != -100
533
+   if mask.sum().item() == 0:
534
+   continue
535
+
536
+   # collect masked logits/labels for perplexity
537
+   logits_masked_list.append(logits[mask])
538
+   labels_masked_list.append(labels_z[mask])
539
+
540
+   pred_z = torch.argmax(logits[mask], dim=-1)
541
+   true_z = labels_z[mask]
542
+
543
+   # flatten valid distances across anchors
544
+   pred_d = dists_pred[mask][labels_dists_mask[mask]]
545
+   true_d = labels_dists[mask][labels_dists_mask[mask]]
546
+
547
+   if pred_d.numel() > 0:
548
+   pred_dists_all.extend(pred_d.cpu().tolist())
549
+   true_dists_all.extend(true_d.cpu().tolist())
550
+
551
+   preds_z_all.extend(pred_z.cpu().tolist())
552
+   true_z_all.extend(true_z.cpu().tolist())
553
+
554
+   # If the trainer provided eval_loss, use it; otherwise fall back to the computed average loss
555
+   avg_val_loss = metric_val_loss if metric_val_loss is not None else ((total_loss / n_batches) if n_batches > 0 else float("nan"))
556
+
557
+   # Compute metrics (classification + distance regression)
558
+   accuracy = accuracy_score(true_z_all, preds_z_all) if len(true_z_all) > 0 else 0.0
559
+   f1 = f1_score(true_z_all, preds_z_all, average="weighted") if len(true_z_all) > 0 else 0.0
560
+   rmse = np.sqrt(mean_squared_error(true_dists_all, pred_dists_all)) if len(true_dists_all) > 0 else 0.0
561
+   mae = mean_absolute_error(true_dists_all, pred_dists_all) if len(true_dists_all) > 0 else 0.0
562
+
563
+   # Compute classification perplexity from masked-token cross-entropy, if available
564
+   if len(logits_masked_list) > 0:
565
+   all_logits_masked = torch.cat(logits_masked_list, dim=0)
566
+   all_labels_masked = torch.cat(labels_masked_list, dim=0)
567
+   # Use model's class_weights if present
568
+   cw = getattr(model_eval, "class_weights", None)
569
+   if cw is not None:
570
+   cw_device = cw.to(device_local)
571
+   try:
572
+   loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked, weight=cw_device)
573
+   except Exception:
574
+   # fallback without weight
575
+   loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked)
576
+   else:
577
+   loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked)
578
+   try:
579
+   perplexity = float(torch.exp(loss_z_all).cpu().item())
580
+   except Exception:
581
+   perplexity = float(np.exp(float(loss_z_all.cpu().item())))
582
+   else:
583
+   perplexity = float("nan")
584
+
585
+   print(f"\n--- Evaluation after Epoch {epoch_num} ---")
586
+   # Print validation loss that matches Trainer's evaluation when available
587
+   print(f"Validation Loss: {avg_val_loss:.4f}")
588
+   print(f"Validation Accuracy: {accuracy:.4f}")
589
+   print(f"Validation F1 (weighted): {f1:.4f}")
590
+   print(f"Validation RMSE (distances): {rmse:.4f}")
591
+   print(f"Validation MAE (distances): {mae:.4f}")
592
+   print(f"Validation Perplexity (classification head): {perplexity:.4f}")
593
+
594
+   # Check for improvement (use a small tolerance)
595
+   if avg_val_loss is not None and not (isinstance(avg_val_loss, float) and np.isnan(avg_val_loss)) and avg_val_loss < self.best_val_loss - 1e-6:
596
+   self.best_val_loss = avg_val_loss
597
+   self.best_epoch = int(state.epoch) # store 0-based internally
598
+   self.epochs_no_improve = 0
599
+   # Save best model state_dict
600
+   os.makedirs(BEST_MODEL_DIR, exist_ok=True)
601
+   try:
602
+   # Prefer trainer's model (which may be wrapped)
603
+   torch.save(self.trainer_ref.model.state_dict(), os.path.join(BEST_MODEL_DIR, "pytorch_model.bin"))
604
+   print(f"Saved new best model (epoch {epoch_num}) to {os.path.join(BEST_MODEL_DIR, 'pytorch_model.bin')}")
605
+   except Exception as e:
606
+   print(f"Failed to save best model at epoch {epoch_num}: {e}")
607
+   else:
608
+   self.epochs_no_improve += 1
609
+
610
+   if self.epochs_no_improve >= self.patience:
611
+   print(f"Early stopping after {self.patience} epochs with no improvement.")
612
+   control.should_training_stop = True
613
+
614
+ # Create callback and Trainer
615
+ callback = ValLossCallback()
616
+ trainer = Trainer(
617
+   model=model,
618
+   args=training_args,
619
+   train_dataset=train_dataset,
620
+   eval_dataset=val_dataset,
621
+   data_collator=collate_batch,
622
+   callbacks=[callback]
623
+ )
624
+ # attach trainer_ref so callback can save model
625
+ callback.trainer_ref = trainer
626
+
627
+ # ---------------------------
628
+ # 6. Run training
629
+ # ---------------------------
630
+ start_time = time.time()
631
+ trainer.train()
632
+ total_time = time.time() - start_time
633
+
634
+ # ---------------------------
635
+ # 7. Final Evaluation (metrics computed on masked atoms in validation set)
636
+ # -> NOTE: per request, we will evaluate the best-saved model (by least val loss)
637
+ # ---------------------------
638
+ # If a best model was saved by the callback, load it
639
+ best_model_path = os.path.join(BEST_MODEL_DIR, "pytorch_model.bin")
640
+ if os.path.exists(best_model_path):
641
+   try:
642
+   model.load_state_dict(torch.load(best_model_path, map_location=device))
643
+   print(f"\nLoaded best model from {best_model_path}")
644
+   except Exception as e:
645
+   print(f"\nFailed to load best model from {best_model_path}: {e}")
646
+
647
+ model.eval()
648
+ preds_z_all = []
649
+ true_z_all = []
650
+ pred_dists_all = []
651
+ true_dists_all = []
652
+
653
+ # For computing perplexity in final eval
654
+ logits_masked_list_final = []
655
+ labels_masked_list_final = []
656
+
657
+ with torch.no_grad():
658
+   for batch in val_loader:
659
+   z = batch["z"].to(device)
660
+   pos = batch["pos"].to(device)
661
+   batch_idx = batch["batch"].to(device)
662
+   labels_z = batch["labels_z"].to(device)
663
+   labels_dists = batch["labels_dists"].to(device)
664
+   labels_dists_mask = batch["labels_dists_mask"].to(device)
665
+
666
+   logits, dists_pred = model(z, pos, batch_idx) # inference mode returns (logits, dists_pred)
667
+
668
+   mask = labels_z != -100
669
+   if mask.sum().item() == 0:
670
+   continue
671
+
672
+   # collect masked logits/labels for perplexity
673
+   logits_masked_list_final.append(logits[mask])
674
+   labels_masked_list_final.append(labels_z[mask])
675
+
676
+   pred_z = torch.argmax(logits[mask], dim=-1)
677
+   true_z = labels_z[mask]
678
+
679
+   # flatten valid distances across anchors
680
+   pred_d = dists_pred[mask][labels_dists_mask[mask]]
681
+   true_d = labels_dists[mask][labels_dists_mask[mask]]
682
+
683
+   if pred_d.numel() > 0:
684
+   pred_dists_all.extend(pred_d.cpu().tolist())
685
+   true_dists_all.extend(true_d.cpu().tolist())
686
+
687
+   preds_z_all.extend(pred_z.cpu().tolist())
688
+   true_z_all.extend(true_z.cpu().tolist())
689
+
690
+ # Compute metrics (classification + distance regression)
691
+ accuracy = accuracy_score(true_z_all, preds_z_all) if len(true_z_all) > 0 else 0.0
692
+ f1 = f1_score(true_z_all, preds_z_all, average="weighted") if len(true_z_all) > 0 else 0.0
693
+ rmse = np.sqrt(mean_squared_error(true_dists_all, pred_dists_all)) if len(true_dists_all) > 0 else 0.0
694
+ mae = mean_absolute_error(true_dists_all, pred_dists_all) if len(true_dists_all) > 0 else 0.0
695
+
696
+ # Compute perplexity from collected masked logits/labels
697
+ if len(logits_masked_list_final) > 0:
698
+   all_logits_masked_final = torch.cat(logits_masked_list_final, dim=0)
699
+   all_labels_masked_final = torch.cat(labels_masked_list_final, dim=0)
700
+   cw_final = getattr(model, "class_weights", None)
701
+   if cw_final is not None:
702
+   try:
703
+   loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final, weight=cw_final.to(device))
704
+   except Exception:
705
+   loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final)
706
+   else:
707
+   loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final)
708
+   try:
709
+   perplexity_final = float(torch.exp(loss_z_final).cpu().item())
710
+   except Exception:
711
+   perplexity_final = float(np.exp(float(loss_z_final.cpu().item())))
712
+ else:
713
+   perplexity_final = float("nan")
714
+
715
+ best_val_loss = callback.best_val_loss if hasattr(callback, "best_val_loss") else float("nan")
716
+ best_epoch_num = (int(callback.best_epoch) + 1) if callback.best_epoch is not None else None
717
+
718
+ print(f"\n=== Final Results (evaluated on best saved model) ===")
719
+ print(f"Total Training Time (s): {total_time:.2f}")
720
+ if best_epoch_num is not None:
721
+   print(f"Best Epoch (1-based): {best_epoch_num}")
722
+ else:
723
+   print("Best Epoch: (none saved)")
724
+
725
+ print(f"Best Validation Loss: {best_val_loss:.4f}")
726
+ print(f"Validation Accuracy: {accuracy:.4f}")
727
+ print(f"Validation F1 (weighted): {f1:.4f}")
728
+ print(f"Validation RMSE (distances): {rmse:.4f}")
729
+ print(f"Validation MAE (distances): {mae:.4f}")
730
+ print(f"Validation Perplexity (classification head): {perplexity_final:.4f}")
731
+
732
+ total_params = sum(p.numel() for p in model.parameters())
733
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
734
+ non_trainable_params = total_params - trainable_params
735
+ print(f"Total Parameters: {total_params}")
736
+ print(f"Trainable Parameters: {trainable_params}")
737
+ print(f"Non-trainable Parameters: {non_trainable_params}")