manpreet88 committed on
Commit
0111d29
·
1 Parent(s): 5dee3bf

Update GINE.py

Browse files
Files changed (1) hide show
  1. PolyFusion/GINE.py +413 -514
PolyFusion/GINE.py CHANGED
@@ -1,11 +1,14 @@
 
 
 
1
 
2
  import os
3
  import json
4
  import time
5
- import shutil
6
-
7
  import sys
8
  import csv
 
 
9
 
10
  # Increase max CSV field size limit
11
  csv.field_size_limit(sys.maxsize)
@@ -22,35 +25,38 @@ from transformers import TrainingArguments, Trainer
22
  from transformers.trainer_callback import TrainerCallback
23
  from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, mean_absolute_error
24
 
25
- # PyG
26
  from torch_geometric.nn import GINEConv
27
 
28
  # ---------------------------
29
  # Configuration / Constants
30
  # ---------------------------
31
  P_MASK = 0.15
32
- # Manual max atomic number (user requested)
33
  MAX_ATOMIC_Z = 85
34
- # Mask token id
35
  MASK_ATOM_ID = MAX_ATOMIC_Z + 1
36
 
37
  USE_LEARNED_WEIGHTING = True
38
 
39
- # GINE / embedding hyperparams requested
40
- NODE_EMB_DIM = 300 # node embedding dimension
41
- EDGE_EMB_DIM = 300 # edge embedding dimension
42
  NUM_GNN_LAYERS = 5
43
 
44
- # Other hyperparams
45
  K_ANCHORS = 6
46
- OUTPUT_DIR = "./gin_output_5M"
47
- BEST_MODEL_DIR = os.path.join(OUTPUT_DIR, "best")
48
- os.makedirs(OUTPUT_DIR, exist_ok=True)
49
 
50
- # Data reading settings
51
- csv_path = "polymer_structures_unified_processed.csv"
52
- TARGET_ROWS = 5000000
53
- CHUNKSIZE = 50000
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  # ---------------------------
56
  # Helper functions
@@ -59,73 +65,68 @@ CHUNKSIZE = 50000
59
  def safe_get(d: dict, key: str, default=None):
60
  return d[key] if (isinstance(d, dict) and key in d) else default
61
 
62
- def build_adj_list(edge_index, num_nodes):
 
 
63
  adj = [[] for _ in range(num_nodes)]
64
  if edge_index is None or edge_index.numel() == 0:
65
  return adj
66
- # edge_index shape [2, E]
67
  src = edge_index[0].tolist()
68
  dst = edge_index[1].tolist()
69
  for u, v in zip(src, dst):
70
- # ensure indices are within range
71
  if 0 <= u < num_nodes and 0 <= v < num_nodes:
72
  adj[u].append(v)
73
  return adj
74
 
75
- def shortest_path_lengths_hops(edge_index, num_nodes):
 
76
  """
77
- Compute all-pairs shortest path lengths in hops (BFS per node).
78
- Returns an (num_nodes, num_nodes) numpy array with int distances; unreachable -> large number (e.g., num_nodes+1)
79
  """
80
  adj = build_adj_list(edge_index, num_nodes)
81
  INF = num_nodes + 1
82
  dist_mat = np.full((num_nodes, num_nodes), INF, dtype=np.int32)
83
  for s in range(num_nodes):
84
- # BFS
85
  q = [s]
86
  dist_mat[s, s] = 0
87
  head = 0
88
  while head < len(q):
89
- u = q[head]; head += 1
 
90
  for v in adj[u]:
91
  if dist_mat[s, v] == INF:
92
  dist_mat[s, v] = dist_mat[s, u] + 1
93
  q.append(v)
94
  return dist_mat
95
 
96
- def match_edge_attr_to_index(edge_index: torch.Tensor, edge_attr: torch.Tensor, target_dim: int = 3):
 
97
  """
98
- Ensure edge_attr has shape [E_index, D]. Handles common mismatches:
99
- - If edge_attr is empty/None -> returns zeros of shape [E_index, target_dim].
100
- - If edge_attr.size(0) == edge_index.size(1) -> return as-is.
101
- - If edge_attr.size(0) * 2 == edge_index.size(1) -> duplicate (common when features only for undirected edges).
102
- - Otherwise repeat/truncate edge_attr to match E_index (safe fallback).
103
  """
104
  E_idx = edge_index.size(1) if (edge_index is not None and edge_index.numel() > 0) else 0
105
  if E_idx == 0:
106
  return torch.zeros((0, target_dim), dtype=torch.float)
107
  if edge_attr is None or edge_attr.numel() == 0:
108
  return torch.zeros((E_idx, target_dim), dtype=torch.float)
 
109
  E_attr = edge_attr.size(0)
110
  if E_attr == E_idx:
111
- # already matches
112
  if edge_attr.size(1) != target_dim:
113
- # pad/truncate feature dimension to target_dim
114
  D = edge_attr.size(1)
115
  if D < target_dim:
116
  pad = torch.zeros((E_attr, target_dim - D), dtype=torch.float, device=edge_attr.device)
117
  return torch.cat([edge_attr, pad], dim=1)
118
- else:
119
- return edge_attr[:, :target_dim]
120
  return edge_attr
121
- # common case: features provided for undirected edges while edge_index contains both directions
122
  if E_attr * 2 == E_idx:
123
  try:
124
  return torch.cat([edge_attr, edge_attr], dim=0)
125
  except Exception:
126
- # fallback to repeat below
127
  pass
128
- # fallback: repeat/truncate edge_attr to fit E_idx
129
  reps = (E_idx + E_attr - 1) // E_attr
130
  edge_rep = edge_attr.repeat(reps, 1)[:E_idx]
131
  if edge_rep.size(1) != target_dim:
@@ -137,201 +138,148 @@ def match_edge_attr_to_index(edge_index: torch.Tensor, edge_attr: torch.Tensor,
137
  edge_rep = edge_rep[:, :target_dim]
138
  return edge_rep
139
 
140
- # ---------------------------
141
- # 1. Load Data from `graph` column (chunked)
142
- # ---------------------------
143
- node_atomic_lists = []
144
- node_chirality_lists = []
145
- node_charge_lists = []
146
- edge_index_lists = []
147
- edge_attr_lists = []
148
- num_nodes_list = []
149
- rows_read = 0
150
-
151
- for chunk in pd.read_csv(csv_path, engine="python", chunksize=CHUNKSIZE):
152
- for idx, row in chunk.iterrows():
153
- # Prefer 'graph' column JSON (string) per user request
154
- graph_field = None
155
- if "graph" in row and not pd.isna(row["graph"]):
156
- try:
157
- graph_field = json.loads(row["graph"])
158
- except Exception:
159
- # If already parsed or other format
160
  try:
161
- graph_field = row["graph"]
162
  except Exception:
163
  graph_field = None
164
- else:
165
- # If no graph column, skip (user requested to use graph column)
166
- continue
167
-
168
- if graph_field is None:
169
- continue
170
-
171
- # NODE FEATURES
172
- node_features = safe_get(graph_field, "node_features", None)
173
- if not node_features:
174
- # skip graphs without node_features
175
- continue
176
-
177
- atomic_nums = []
178
- chirality_vals = []
179
- formal_charges = []
180
-
181
- for nf in node_features:
182
- # atomic number
183
- an = safe_get(nf, "atomic_num", None)
184
- if an is None:
185
- # try alternate keys
186
- an = safe_get(nf, "atomic_number", 0)
187
- # chirality (use 0 default)
188
- ch = safe_get(nf, "chirality", 0)
189
- # formal charge (use 0 default)
190
- fc = safe_get(nf, "formal_charge", 0)
191
- atomic_nums.append(int(an))
192
- chirality_vals.append(float(ch))
193
- formal_charges.append(float(fc))
194
-
195
- n_nodes = len(atomic_nums)
196
-
197
- # EDGE INDICES & FEATURES
198
- edge_indices_raw = safe_get(graph_field, "edge_indices", None)
199
- edge_features_raw = safe_get(graph_field, "edge_features", None)
200
-
201
- if edge_indices_raw is None:
202
- # try adjacency_matrix to infer edges
203
- adj_mat = safe_get(graph_field, "adjacency_matrix", None)
204
- if adj_mat:
205
- # adjacency_matrix is list of lists
206
- srcs = []
207
- dsts = []
208
- for i, row_adj in enumerate(adj_mat):
209
- for j, val in enumerate(row_adj):
210
- if val:
211
- srcs.append(i)
212
- dsts.append(j)
213
- edge_index = torch.tensor([srcs, dsts], dtype=torch.long)
214
- # no edge features available -> create zeros matching edges
215
- E = edge_index.size(1)
216
- edge_attr = torch.zeros((E, 3), dtype=torch.float)
217
  else:
218
- # no edges found -> skip this graph (GINE requires edges)
219
  continue
220
- else:
221
- # edge_indices_raw expected like [[u,v], [u2,v2], ...] or [[u1,u2,...],[v1,v2,...]]
222
- if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0 and isinstance(edge_indices_raw[0], list):
223
- # Could be list of pairs or list of lists
224
- if all(len(pair) == 2 and isinstance(pair[0], int) for pair in edge_indices_raw):
225
- # list of pairs
226
- srcs = [int(p[0]) for p in edge_indices_raw]
227
- dsts = [int(p[1]) for p in edge_indices_raw]
228
- elif isinstance(edge_indices_raw[0][0], int):
229
- # Possibly already in [[srcs],[dsts]] format
230
- try:
231
- srcs = [int(x) for x in edge_indices_raw[0]]
232
- dsts = [int(x) for x in edge_indices_raw[1]]
233
- except Exception:
234
- # fallback
235
- srcs = []
236
- dsts = []
237
- else:
238
- srcs = []
239
- dsts = []
240
- else:
241
- srcs = []
242
- dsts = []
243
 
244
- if len(srcs) == 0:
245
- # fallback: skip graph
246
  continue
247
 
248
- edge_index = torch.tensor([srcs, dsts], dtype=torch.long)
249
-
250
- # Build edge_attr matrix with 3 features: bond_type, stereo, is_conjugated (as float)
251
- if edge_features_raw and isinstance(edge_features_raw, list):
252
- bond_types = []
253
- stereos = []
254
- is_conjs = []
255
- for ef in edge_features_raw:
256
- bt = safe_get(ef, "bond_type", 0)
257
- st = safe_get(ef, "stereo", 0)
258
- ic = safe_get(ef, "is_conjugated", False)
259
- bond_types.append(float(bt))
260
- stereos.append(float(st))
261
- is_conjs.append(float(1.0 if ic else 0.0))
262
- edge_attr = torch.tensor(np.stack([bond_types, stereos, is_conjs], axis=1), dtype=torch.float)
263
- else:
264
- # no edge features -> zeros
265
- E = edge_index.size(1)
266
- edge_attr = torch.zeros((E, 3), dtype=torch.float)
267
-
268
- # Ensure edge_attr length matches edge_index (fix common mismatches)
269
- edge_attr = match_edge_attr_to_index(edge_index, edge_attr, target_dim=3)
270
 
271
- # NOTE: we explicitly DO NOT parse or use coordinates (geometry) anywhere.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
- # Save lists
274
- node_atomic_lists.append(torch.tensor(atomic_nums, dtype=torch.long))
275
- node_chirality_lists.append(torch.tensor(chirality_vals, dtype=torch.float))
276
- node_charge_lists.append(torch.tensor(formal_charges, dtype=torch.float))
277
- edge_index_lists.append(edge_index)
278
- edge_attr_lists.append(edge_attr)
279
- num_nodes_list.append(n_nodes)
280
 
281
- rows_read += 1
282
- if rows_read >= TARGET_ROWS:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  break
284
- if rows_read >= TARGET_ROWS:
285
- break
286
-
287
- if len(node_atomic_lists) == 0:
288
- raise RuntimeError("No graphs were parsed from the CSV 'graph' column. Check input file and format.")
289
 
290
- print(f"Parsed {len(node_atomic_lists)} graphs (using 'graph' column). Using manual max atomic Z = {MAX_ATOMIC_Z}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
- # ---------------------------
293
- # 2. Train/Val Split
294
- # ---------------------------
295
- indices = list(range(len(node_atomic_lists)))
296
- train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
297
-
298
- def subset(l, idxs):
299
- return [l[i] for i in idxs]
300
-
301
- train_atomic = subset(node_atomic_lists, train_idx)
302
- train_chirality = subset(node_chirality_lists, train_idx)
303
- train_charge = subset(node_charge_lists, train_idx)
304
- train_edge_index = subset(edge_index_lists, train_idx)
305
- train_edge_attr = subset(edge_attr_lists, train_idx)
306
- train_num_nodes = subset(num_nodes_list, train_idx)
307
-
308
- val_atomic = subset(node_atomic_lists, val_idx)
309
- val_chirality = subset(node_chirality_lists, val_idx)
310
- val_charge = subset(node_charge_lists, val_idx)
311
- val_edge_index = subset(edge_index_lists, val_idx)
312
- val_edge_attr = subset(edge_attr_lists, val_idx)
313
- val_num_nodes = subset(num_nodes_list, val_idx)
314
-
315
- # ---------------------------
316
- # Compute class weights (for weighted CE)
317
- # ---------------------------
318
- num_classes = MASK_ATOM_ID + 1
319
- counts = np.ones((num_classes,), dtype=np.float64)
320
- for z in train_atomic:
321
- vals = z.cpu().numpy().astype(int)
322
- for v in vals:
323
- if 0 <= v < num_classes:
324
- counts[v] += 1.0
325
- freq = counts / counts.sum()
326
- inv_freq = 1.0 / (freq + 1e-12)
327
- class_weights = inv_freq / inv_freq.mean()
328
- class_weights = torch.tensor(class_weights, dtype=torch.float)
329
- class_weights[MASK_ATOM_ID] = 1.0
330
 
331
- # ---------------------------
332
- # 3. Dataset and Collator (build MLM masks + invariant distance targets using hop counts only)
333
- # ---------------------------
334
  class PolymerDataset(Dataset):
 
335
  def __init__(self, atomic_list, chirality_list, charge_list, edge_index_list, edge_attr_list, num_nodes_list):
336
  self.atomic_list = atomic_list
337
  self.chirality_list = chirality_list
@@ -345,45 +293,31 @@ class PolymerDataset(Dataset):
345
 
346
  def __getitem__(self, idx):
347
  return {
348
- "z": self.atomic_list[idx], # [n_nodes]
349
- "chirality": self.chirality_list[idx], # [n_nodes] float
350
- "formal_charge": self.charge_list[idx], # [n_nodes]
351
- "edge_index": self.edge_index_list[idx], # [2, E]
352
- "edge_attr": self.edge_attr_list[idx], # [E, 3]
353
- "num_nodes": int(self.num_nodes_list[idx]) # int
354
  }
355
 
 
356
  def collate_batch(batch):
357
  """
358
- Builds a batched structure from a list of graph dicts.
359
- Returns:
360
- - z: [N_total] long (atomic numbers possibly masked)
361
- - chirality: [N_total] float
362
- - formal_charge: [N_total] float
363
- - edge_index: [2, E_total] long (node indices offset per graph)
364
- - edge_attr: [E_total, 3] float
365
- - batch: [N_total] long mapping node->graph idx
366
- - labels_z: [N_total] long (-100 for unselected)
367
- - labels_dists: [N_total, K_ANCHORS] float (hop counts)
368
- - labels_dists_mask: [N_total, K_ANCHORS] bool
369
- Distance targets:
370
- - Shortest-path hop distances computed from edge_index for every graph.
371
  """
372
- all_z = []
373
- all_ch = []
374
- all_fc = []
375
- all_labels_z = []
376
- all_labels_dists = []
377
- all_labels_dists_mask = []
378
  batch_idx = []
 
379
  edge_index_list_batched = []
380
  edge_attr_list_batched = []
381
  node_offset = 0
382
- total_nodes = 0
383
- total_edges = 0
384
 
385
  for i, g in enumerate(batch):
386
- z = g["z"] # tensor [n]
387
  n = z.size(0)
388
  if n == 0:
389
  continue
@@ -393,7 +327,6 @@ def collate_batch(batch):
393
  edge_index = g["edge_index"]
394
  edge_attr = g["edge_attr"]
395
 
396
- # Mask selection
397
  is_selected = torch.rand(n) < P_MASK
398
  if is_selected.all():
399
  is_selected[torch.randint(0, n, (1,))] = False
@@ -415,37 +348,27 @@ def collate_batch(batch):
415
  z_masked[sel_idx[mask_choice]] = MASK_ATOM_ID
416
  if rand_choice.any():
417
  z_masked[sel_idx[rand_choice]] = rand_atomic[rand_choice]
418
- # keep_choice -> do nothing
419
 
420
- # Build invariant distance targets using hop distances only
421
  visible_idx = torch.nonzero(~is_selected).squeeze(-1)
422
  if visible_idx.numel() == 0:
423
  visible_idx = torch.arange(n, dtype=torch.long)
424
 
425
- # compute hop distances via BFS using edge_index
426
- ei = edge_index.clone()
427
- num_nodes_local = n
428
- dist_mat = shortest_path_lengths_hops(ei, num_nodes_local) # numpy int matrix
429
  for a in torch.nonzero(is_selected).squeeze(-1).tolist():
430
- # distances to visible nodes
431
  vis = visible_idx.numpy()
432
  if vis.size == 0:
433
  continue
434
  dists = dist_mat[a, vis].astype(np.float32)
435
- # filter unreachable (INF = n+1)
436
- valid_mask = dists <= num_nodes_local
437
  if not valid_mask.any():
438
  continue
439
  dists_valid = dists[valid_mask]
440
- vis_valid = vis[valid_mask]
441
- # choose smallest hop distances
442
  k = min(K_ANCHORS, dists_valid.size)
443
  idx_sorted = np.argsort(dists_valid)[:k]
444
- selected_vals = dists_valid[idx_sorted]
445
- labels_dists[a, :k] = torch.tensor(selected_vals, dtype=torch.float)
446
  labels_dists_mask[a, :k] = True
447
 
448
- # Append node-level tensors to batched lists
449
  all_z.append(z_masked)
450
  all_ch.append(chir)
451
  all_fc.append(fc)
@@ -454,20 +377,15 @@ def collate_batch(batch):
454
  all_labels_dists_mask.append(labels_dists_mask)
455
  batch_idx.append(torch.full((n,), i, dtype=torch.long))
456
 
457
- # Offset edge indices and append
458
  if edge_index is not None and edge_index.numel() > 0:
459
  ei_offset = edge_index + node_offset
460
  edge_index_list_batched.append(ei_offset)
461
- # edge_attr already matched earlier; still ensure shapes here for safety
462
  edge_attr_matched = match_edge_attr_to_index(edge_index, edge_attr, target_dim=3)
463
  edge_attr_list_batched.append(edge_attr_matched)
464
- total_edges += edge_index.size(1)
465
 
466
  node_offset += n
467
- total_nodes += n
468
 
469
  if len(all_z) == 0:
470
- # Return empty structured batch
471
  return {
472
  "z": torch.tensor([], dtype=torch.long),
473
  "chirality": torch.tensor([], dtype=torch.float),
@@ -477,7 +395,7 @@ def collate_batch(batch):
477
  "batch": torch.tensor([], dtype=torch.long),
478
  "labels_z": torch.tensor([], dtype=torch.long),
479
  "labels_dists": torch.tensor([], dtype=torch.float).reshape(0, K_ANCHORS),
480
- "labels_dists_mask": torch.tensor([], dtype=torch.bool).reshape(0, K_ANCHORS)
481
  }
482
 
483
  z_batch = torch.cat(all_z, dim=0)
@@ -504,83 +422,53 @@ def collate_batch(batch):
504
  "batch": batch_batch,
505
  "labels_z": labels_z_batch,
506
  "labels_dists": labels_dists_batch,
507
- "labels_dists_mask": labels_dists_mask_batch
508
  }
509
 
510
- train_dataset = PolymerDataset(train_atomic, train_chirality, train_charge, train_edge_index, train_edge_attr, train_num_nodes)
511
- val_dataset = PolymerDataset(val_atomic, val_chirality, val_charge, val_edge_index, val_edge_attr, val_num_nodes)
512
-
513
- train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_batch)
514
- val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_batch)
515
 
516
- # ---------------------------
517
- # 4. Model Definition (GINE-based masked model)
518
- # ---------------------------
519
  class GineBlock(nn.Module):
520
- def __init__(self, node_dim):
 
521
  super().__init__()
522
- # MLP used by GINEConv: map (node_dim) -> node_dim
523
- self.mlp = nn.Sequential(
524
- nn.Linear(node_dim, node_dim),
525
- nn.ReLU(),
526
- nn.Linear(node_dim, node_dim)
527
- )
528
  self.conv = GINEConv(self.mlp)
529
  self.bn = nn.BatchNorm1d(node_dim)
530
  self.act = nn.ReLU()
531
 
532
  def forward(self, x, edge_index, edge_attr):
533
- # GINEConv accepts edge_attr; edge_attr should be same dim as x (or handled in MLP inside)
534
  x = self.conv(x, edge_index, edge_attr)
535
  x = self.bn(x)
536
  x = self.act(x)
537
  return x
538
 
 
539
  class MaskedGINE(nn.Module):
540
- def __init__(self,
541
- node_emb_dim=NODE_EMB_DIM,
542
- edge_emb_dim=EDGE_EMB_DIM,
543
- num_layers=NUM_GNN_LAYERS,
544
- max_atomic_z=MAX_ATOMIC_Z,
545
- class_weights=None):
 
 
546
  super().__init__()
547
  self.node_emb_dim = node_emb_dim
548
  self.edge_emb_dim = edge_emb_dim
549
  self.max_atomic_z = max_atomic_z
550
 
551
- # Embedding for atomic numbers (including MASK token)
552
  num_embeddings = MASK_ATOM_ID + 1
553
  self.atom_emb = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=node_emb_dim, padding_idx=None)
554
 
555
- # Small MLP to map numeric node attributes (chirality, formal_charge) -> node_emb_dim
556
- self.node_attr_proj = nn.Sequential(
557
- nn.Linear(2, node_emb_dim),
558
- nn.ReLU(),
559
- nn.Linear(node_emb_dim, node_emb_dim)
560
- )
561
-
562
- # Edge encoder: maps 3-dim raw edge features -> edge_emb_dim
563
- self.edge_encoder = nn.Sequential(
564
- nn.Linear(3, edge_emb_dim),
565
- nn.ReLU(),
566
- nn.Linear(edge_emb_dim, edge_emb_dim)
567
- )
568
-
569
- # Project edge_emb -> node_emb_dim if needed (registered in __init__ to avoid dynamic creation)
570
- if edge_emb_dim != node_emb_dim:
571
- self._edge_to_node_proj = nn.Linear(edge_emb_dim, node_emb_dim)
572
- else:
573
- self._edge_to_node_proj = None
574
 
575
- # GINE layers
576
  self.gnn_layers = nn.ModuleList([GineBlock(node_emb_dim) for _ in range(num_layers)])
577
 
578
- # Heads
579
- num_classes_local = MASK_ATOM_ID + 1
580
- self.atom_head = nn.Linear(node_emb_dim, num_classes_local)
581
  self.coord_head = nn.Linear(node_emb_dim, K_ANCHORS)
582
 
583
- # Learned uncertainty weighting
584
  if USE_LEARNED_WEIGHTING:
585
  self.log_var_z = nn.Parameter(torch.zeros(1))
586
  self.log_var_pos = nn.Parameter(torch.zeros(1))
@@ -593,50 +481,30 @@ class MaskedGINE(nn.Module):
593
  else:
594
  self.class_weights = None
595
 
596
- def forward(self, z, chirality, formal_charge, edge_index, edge_attr,
597
- batch=None, labels_z=None, labels_dists=None, labels_dists_mask=None):
598
- """
599
- z: [N] long (atomic numbers or MASK_ATOM_ID)
600
- chirality: [N] float
601
- formal_charge: [N] float
602
- edge_index: [2, E] long (global batched indices)
603
- edge_attr: [E, 3] float
604
- batch: [N] long mapping nodes->graph idx
605
- labels_*: optional supervision targets as in collate_batch
606
- """
607
  if batch is None:
608
  batch = torch.zeros(z.size(0), dtype=torch.long, device=z.device)
609
 
610
- # Node embedding
611
- atom_embedding = self.atom_emb(z) # [N, node_emb_dim]
612
- node_attr = torch.stack([chirality, formal_charge], dim=1) # [N,2]
613
- node_attr_emb = self.node_attr_proj(node_attr.to(atom_embedding.device)) # [N, node_emb_dim]
614
-
615
- x = atom_embedding + node_attr_emb # combine categorical and numeric node features
616
 
617
- # Edge embedding
618
  if edge_attr is None or edge_attr.numel() == 0:
619
- E = 0 if edge_attr is None else edge_attr.size(0)
620
- edge_emb = torch.zeros((E, self.edge_emb_dim), dtype=torch.float, device=x.device)
621
  else:
622
- edge_emb = self.edge_encoder(edge_attr.to(x.device)) # [E, edge_emb_dim]
623
 
624
- # For GINEConv, edge_attr should match node feature dim used inside GINE (GINE uses the provided nn to process x_j + edge_attr)
625
- # Project edge_emb -> node_emb_dim if dims differ (registered in __init__)
626
- if self._edge_to_node_proj is not None:
627
- edge_for_conv = self._edge_to_node_proj(edge_emb)
628
- else:
629
- edge_for_conv = edge_emb
630
 
631
- # Run GNN layers
632
  h = x
633
  for layer in self.gnn_layers:
634
  h = layer(h, edge_index.to(h.device), edge_for_conv)
635
 
636
- logits = self.atom_head(h) # [N, num_classes]
637
- dists_pred = self.coord_head(h) # [N, K_ANCHORS]
638
 
639
- # Compute loss if labels provided
640
  if labels_z is not None and labels_dists is not None and labels_dists_mask is not None:
641
  mask = labels_z != -100
642
  if mask.sum() == 0:
@@ -648,13 +516,13 @@ class MaskedGINE(nn.Module):
648
  labels_dists_masked = labels_dists[mask]
649
  labels_dists_mask_mask = labels_dists_mask[mask]
650
 
651
- # classification loss
652
  if self.class_weights is not None:
653
- loss_z = F.cross_entropy(logits_masked, labels_z_masked.to(logits_masked.device), weight=self.class_weights.to(logits_masked.device))
 
 
654
  else:
655
  loss_z = F.cross_entropy(logits_masked, labels_z_masked.to(logits_masked.device))
656
 
657
- # distance loss (only where mask true)
658
  if labels_dists_mask_mask.any():
659
  preds = dists_pred_masked[labels_dists_mask_mask]
660
  trues = labels_dists_masked[labels_dists_mask_mask].to(preds.device)
@@ -665,54 +533,23 @@ class MaskedGINE(nn.Module):
665
  if USE_LEARNED_WEIGHTING:
666
  lz = torch.exp(-self.log_var_z) * loss_z + self.log_var_z
667
  lp = torch.exp(-self.log_var_pos) * loss_pos + self.log_var_pos
668
- loss = 0.5 * (lz + lp)
669
- else:
670
- alpha = 1.0
671
- loss = loss_z + alpha * loss_pos
672
 
673
- return loss
674
 
675
  return logits, dists_pred
676
 
677
- # Instantiate model
678
- model = MaskedGINE(node_emb_dim=NODE_EMB_DIM,
679
- edge_emb_dim=EDGE_EMB_DIM,
680
- num_layers=NUM_GNN_LAYERS,
681
- max_atomic_z=MAX_ATOMIC_Z,
682
- class_weights=class_weights)
683
-
684
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
685
- model.to(device)
686
-
687
- # ---------------------------
688
- # 5. Training Setup (Hugging Face Trainer)
689
- # ---------------------------
690
- training_args = TrainingArguments(
691
- output_dir=OUTPUT_DIR,
692
- overwrite_output_dir=True,
693
- num_train_epochs=25,
694
- per_device_train_batch_size=16,
695
- per_device_eval_batch_size=8,
696
- gradient_accumulation_steps=4,
697
- eval_strategy="epoch",
698
- logging_steps=500,
699
- learning_rate=1e-4,
700
- weight_decay=0.01,
701
- fp16=torch.cuda.is_available(),
702
- save_strategy="no",
703
- disable_tqdm=False,
704
- logging_first_step=True,
705
- report_to=[],
706
- dataloader_num_workers=4,
707
- )
708
 
709
  class ValLossCallback(TrainerCallback):
710
- def __init__(self, trainer_ref=None):
 
711
  self.best_val_loss = float("inf")
712
  self.epochs_no_improve = 0
713
- self.patience = 10
714
  self.best_epoch = None
715
  self.trainer_ref = trainer_ref
 
 
716
 
717
  def on_epoch_end(self, args, state, control, **kwargs):
718
  epoch_num = int(state.epoch)
@@ -723,30 +560,25 @@ class ValLossCallback(TrainerCallback):
723
 
724
  def on_evaluate(self, args, state, control, metrics=None, **kwargs):
725
  epoch_num = int(state.epoch) + 1
 
726
  if self.trainer_ref is None:
727
  print(f"[Eval] Epoch {epoch_num} - metrics (trainer_ref missing): {metrics}")
728
  return
729
 
730
- metric_val_loss = None
731
- if metrics is not None:
732
- metric_val_loss = metrics.get("eval_loss")
733
 
734
  model_eval = self.trainer_ref.model
735
  model_eval.eval()
736
- device_local = next(model_eval.parameters()).device if any(p.numel() > 0 for p in model_eval.parameters()) else torch.device("cpu")
737
 
738
- preds_z_all = []
739
- true_z_all = []
740
- pred_dists_all = []
741
- true_dists_all = []
742
- total_loss = 0.0
743
- n_batches = 0
744
 
745
- logits_masked_list = []
746
- labels_masked_list = []
747
 
748
  with torch.no_grad():
749
- for batch in val_loader:
750
  z = batch["z"].to(device_local)
751
  chir = batch["chirality"].to(device_local)
752
  fc = batch["formal_charge"].to(device_local)
@@ -759,7 +591,7 @@ class ValLossCallback(TrainerCallback):
759
 
760
  try:
761
  loss = model_eval(z, chir, fc, edge_index, edge_attr, batch_idx, labels_z, labels_dists, labels_dists_mask)
762
- except Exception as e:
763
  loss = None
764
 
765
  if isinstance(loss, torch.Tensor):
@@ -767,7 +599,6 @@ class ValLossCallback(TrainerCallback):
767
  n_batches += 1
768
 
769
  logits, dists_pred = model_eval(z, chir, fc, edge_index, edge_attr, batch_idx)
770
-
771
  mask = labels_z != -100
772
  if mask.sum().item() == 0:
773
  continue
@@ -778,7 +609,6 @@ class ValLossCallback(TrainerCallback):
778
  pred_z = torch.argmax(logits[mask], dim=-1)
779
  true_z = labels_z[mask]
780
 
781
- # flatten valid distances
782
  pred_d = dists_pred[mask][labels_dists_mask[mask]]
783
  true_d = labels_dists[mask][labels_dists_mask[mask]]
784
 
@@ -801,9 +631,8 @@ class ValLossCallback(TrainerCallback):
801
  all_labels_masked = torch.cat(labels_masked_list, dim=0)
802
  cw = getattr(model_eval, "class_weights", None)
803
  if cw is not None:
804
- cw_device = cw.to(device_local)
805
  try:
806
- loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked, weight=cw_device)
807
  except Exception:
808
  loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked)
809
  else:
@@ -823,15 +652,14 @@ class ValLossCallback(TrainerCallback):
823
  print(f"Validation MAE (distances): {mae:.4f}")
824
  print(f"Validation Perplexity (classification head): {perplexity:.4f}")
825
 
826
- # Save best model by val loss
827
  if avg_val_loss is not None and not (isinstance(avg_val_loss, float) and np.isnan(avg_val_loss)) and avg_val_loss < self.best_val_loss - 1e-6:
828
  self.best_val_loss = avg_val_loss
829
  self.best_epoch = int(state.epoch)
830
  self.epochs_no_improve = 0
831
- os.makedirs(BEST_MODEL_DIR, exist_ok=True)
832
  try:
833
- torch.save(self.trainer_ref.model.state_dict(), os.path.join(BEST_MODEL_DIR, "pytorch_model.bin"))
834
- print(f"Saved new best model (epoch {epoch_num}) to {os.path.join(BEST_MODEL_DIR, 'pytorch_model.bin')}")
835
  except Exception as e:
836
  print(f"Failed to save best model at epoch {epoch_num}: {e}")
837
  else:
@@ -841,122 +669,193 @@ class ValLossCallback(TrainerCallback):
841
  print(f"Early stopping after {self.patience} epochs with no improvement.")
842
  control.should_training_stop = True
843
 
844
- # Create callback and Trainer
845
- callback = ValLossCallback()
846
- trainer = Trainer(
847
- model=model,
848
- args=training_args,
849
- train_dataset=train_dataset,
850
- eval_dataset=val_dataset,
851
- data_collator=collate_batch,
852
- callbacks=[callback]
853
- )
854
- callback.trainer_ref = trainer
855
 
856
- # ---------------------------
857
- # 6. Run training
858
- # ---------------------------
859
- start_time = time.time()
860
- trainer.train()
861
- total_time = time.time() - start_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
862
 
863
- # ---------------------------
864
- # 7. Final Evaluation (on best saved model)
865
- # ---------------------------
866
- best_model_path = os.path.join(BEST_MODEL_DIR, "pytorch_model.bin")
867
- if os.path.exists(best_model_path):
868
- try:
869
- model.load_state_dict(torch.load(best_model_path, map_location=device))
870
- print(f"\nLoaded best model from {best_model_path}")
871
- except Exception as e:
872
- print(f"\nFailed to load best model from {best_model_path}: {e}")
873
-
874
- model.eval()
875
- preds_z_all = []
876
- true_z_all = []
877
- pred_dists_all = []
878
- true_dists_all = []
879
-
880
- logits_masked_list_final = []
881
- labels_masked_list_final = []
882
-
883
- with torch.no_grad():
884
- for batch in val_loader:
885
- z = batch["z"].to(device)
886
- chir = batch["chirality"].to(device)
887
- fc = batch["formal_charge"].to(device)
888
- edge_index = batch["edge_index"].to(device)
889
- edge_attr = batch["edge_attr"].to(device)
890
- batch_idx = batch["batch"].to(device)
891
- labels_z = batch["labels_z"].to(device)
892
- labels_dists = batch["labels_dists"].to(device)
893
- labels_dists_mask = batch["labels_dists_mask"].to(device)
894
-
895
- logits, dists_pred = model(z, chir, fc, edge_index, edge_attr, batch_idx)
896
-
897
- mask = labels_z != -100
898
- if mask.sum().item() == 0:
899
- continue
900
 
901
- logits_masked_list_final.append(logits[mask])
902
- labels_masked_list_final.append(labels_z[mask])
903
 
904
- pred_z = torch.argmax(logits[mask], dim=-1)
905
- true_z = labels_z[mask]
906
 
907
- pred_d = dists_pred[mask][labels_dists_mask[mask]]
908
- true_d = labels_dists[mask][labels_dists_mask[mask]]
909
 
910
- if pred_d.numel() > 0:
911
- pred_dists_all.extend(pred_d.cpu().tolist())
912
- true_dists_all.extend(true_d.cpu().tolist())
913
 
914
- preds_z_all.extend(pred_z.cpu().tolist())
915
- true_z_all.extend(true_z.cpu().tolist())
916
 
917
- accuracy = accuracy_score(true_z_all, preds_z_all) if len(true_z_all) > 0 else 0.0
918
- f1 = f1_score(true_z_all, preds_z_all, average="weighted") if len(true_z_all) > 0 else 0.0
919
- rmse = np.sqrt(mean_squared_error(true_dists_all, pred_dists_all)) if len(true_dists_all) > 0 else 0.0
920
- mae = mean_absolute_error(true_dists_all, pred_dists_all) if len(true_dists_all) > 0 else 0.0
921
 
922
- if len(logits_masked_list_final) > 0:
923
- all_logits_masked_final = torch.cat(logits_masked_list_final, dim=0)
924
- all_labels_masked_final = torch.cat(labels_masked_list_final, dim=0)
925
- cw_final = getattr(model, "class_weights", None)
926
- if cw_final is not None:
 
 
 
 
 
 
927
  try:
928
- loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final, weight=cw_final.to(device))
929
  except Exception:
930
- loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final)
931
  else:
932
- loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final)
933
- try:
934
- perplexity_final = float(torch.exp(loss_z_final).cpu().item())
935
- except Exception:
936
- perplexity_final = float(np.exp(float(loss_z_final.cpu().item())))
937
- else:
938
- perplexity_final = float("nan")
939
-
940
- best_val_loss = callback.best_val_loss if hasattr(callback, "best_val_loss") else float("nan")
941
- best_epoch_num = (int(callback.best_epoch) + 1) if callback.best_epoch is not None else None
942
-
943
- print(f"\n=== Final Results (evaluated on best saved model) ===")
944
- print(f"Total Training Time (s): {total_time:.2f}")
945
- if best_epoch_num is not None:
946
- print(f"Best Epoch (1-based): {best_epoch_num}")
947
- else:
948
- print("Best Epoch: (none saved)")
949
-
950
- print(f"Best Validation Loss: {best_val_loss:.4f}")
951
- print(f"Validation Accuracy: {accuracy:.4f}")
952
- print(f"Validation F1 (weighted): {f1:.4f}")
953
- print(f"Validation RMSE (distances): {rmse:.4f}")
954
- print(f"Validation MAE (distances): {mae:.4f}")
955
- print(f"Validation Perplexity (classification head): {perplexity_final:.4f}")
956
-
957
- total_params = sum(p.numel() for p in model.parameters())
958
- trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
959
- non_trainable_params = total_params - trainable_params
960
- print(f"Total Parameters: {total_params}")
961
- print(f"Trainable Parameters: {trainable_params}")
962
- print(f"Non-trainable Parameters: {non_trainable_params}")
 
1
+ """
2
+ GINE-based masked pretraining on polymer graphs.
3
+ """
4
 
5
  import os
6
  import json
7
  import time
 
 
8
  import sys
9
  import csv
10
+ import argparse
11
+ from typing import Any, Dict, List, Optional, Tuple
12
 
13
  # Increase max CSV field size limit
14
  csv.field_size_limit(sys.maxsize)
 
25
  from transformers.trainer_callback import TrainerCallback
26
  from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, mean_absolute_error
27
 
 
28
  from torch_geometric.nn import GINEConv
29
 
30
  # ---------------------------
31
  # Configuration / Constants
32
  # ---------------------------
33
  P_MASK = 0.15
 
34
  MAX_ATOMIC_Z = 85
 
35
  MASK_ATOM_ID = MAX_ATOMIC_Z + 1
36
 
37
  USE_LEARNED_WEIGHTING = True
38
 
39
+ NODE_EMB_DIM = 300
40
+ EDGE_EMB_DIM = 300
 
41
  NUM_GNN_LAYERS = 5
42
 
 
43
  K_ANCHORS = 6
 
 
 
44
 
45
+
46
+ def parse_args() -> argparse.Namespace:
47
+ parser = argparse.ArgumentParser(description="GINE masked pretraining (graphs).")
48
+ parser.add_argument(
49
+ "--csv_path",
50
+ type=str,
51
+ default="/path/to/polymer_structures_unified_processed.csv",
52
+ help="Processed CSV containing a JSON 'graph' column.",
53
+ )
54
+ parser.add_argument("--target_rows", type=int, default=5_000_000, help="Max rows to parse.")
55
+ parser.add_argument("--chunksize", type=int, default=50_000, help="CSV chunksize.")
56
+ parser.add_argument("--output_dir", type=str, default="/path/to/gin_output_5M", help="Training output directory.")
57
+ parser.add_argument("--num_workers", type=int, default=4, help="PyTorch DataLoader num workers.")
58
+ return parser.parse_args()
59
+
60
 
61
  # ---------------------------
62
  # Helper functions
 
65
  def safe_get(d: dict, key: str, default=None):
66
  return d[key] if (isinstance(d, dict) and key in d) else default
67
 
68
+
69
+ def build_adj_list(edge_index: torch.Tensor, num_nodes: int) -> List[List[int]]:
70
+ """Adjacency list for BFS shortest paths."""
71
  adj = [[] for _ in range(num_nodes)]
72
  if edge_index is None or edge_index.numel() == 0:
73
  return adj
 
74
  src = edge_index[0].tolist()
75
  dst = edge_index[1].tolist()
76
  for u, v in zip(src, dst):
 
77
  if 0 <= u < num_nodes and 0 <= v < num_nodes:
78
  adj[u].append(v)
79
  return adj
80
 
81
+
82
+ def shortest_path_lengths_hops(edge_index: torch.Tensor, num_nodes: int) -> np.ndarray:
83
  """
84
+ All-pairs shortest path lengths in hops using BFS per node.
85
+ Unreachable pairs get distance INF=num_nodes+1.
86
  """
87
  adj = build_adj_list(edge_index, num_nodes)
88
  INF = num_nodes + 1
89
  dist_mat = np.full((num_nodes, num_nodes), INF, dtype=np.int32)
90
  for s in range(num_nodes):
 
91
  q = [s]
92
  dist_mat[s, s] = 0
93
  head = 0
94
  while head < len(q):
95
+ u = q[head]
96
+ head += 1
97
  for v in adj[u]:
98
  if dist_mat[s, v] == INF:
99
  dist_mat[s, v] = dist_mat[s, u] + 1
100
  q.append(v)
101
  return dist_mat
102
 
103
+
104
+ def match_edge_attr_to_index(edge_index: torch.Tensor, edge_attr: torch.Tensor, target_dim: int = 3) -> torch.Tensor:
105
  """
106
+ Ensure edge_attr has shape [E_index, target_dim], handling common mismatches.
 
 
 
 
107
  """
108
  E_idx = edge_index.size(1) if (edge_index is not None and edge_index.numel() > 0) else 0
109
  if E_idx == 0:
110
  return torch.zeros((0, target_dim), dtype=torch.float)
111
  if edge_attr is None or edge_attr.numel() == 0:
112
  return torch.zeros((E_idx, target_dim), dtype=torch.float)
113
+
114
  E_attr = edge_attr.size(0)
115
  if E_attr == E_idx:
 
116
  if edge_attr.size(1) != target_dim:
 
117
  D = edge_attr.size(1)
118
  if D < target_dim:
119
  pad = torch.zeros((E_attr, target_dim - D), dtype=torch.float, device=edge_attr.device)
120
  return torch.cat([edge_attr, pad], dim=1)
121
+ return edge_attr[:, :target_dim]
 
122
  return edge_attr
123
+
124
  if E_attr * 2 == E_idx:
125
  try:
126
  return torch.cat([edge_attr, edge_attr], dim=0)
127
  except Exception:
 
128
  pass
129
+
130
  reps = (E_idx + E_attr - 1) // E_attr
131
  edge_rep = edge_attr.repeat(reps, 1)[:E_idx]
132
  if edge_rep.size(1) != target_dim:
 
138
  edge_rep = edge_rep[:, :target_dim]
139
  return edge_rep
140
 
141
+
142
+ def parse_graphs_from_csv(csv_path: str, target_rows: int, chunksize: int):
143
+ """
144
+ Stream CSV and parse the JSON 'graph' field into graph tensors needed by the model.
145
+ Returns lists of per-graph tensors.
146
+ """
147
+ node_atomic_lists = []
148
+ node_chirality_lists = []
149
+ node_charge_lists = []
150
+ edge_index_lists = []
151
+ edge_attr_lists = []
152
+ num_nodes_list = []
153
+
154
+ rows_read = 0
155
+
156
+ for chunk in pd.read_csv(csv_path, engine="python", chunksize=chunksize):
157
+ for _, row in chunk.iterrows():
158
+ graph_field = None
159
+ if "graph" in row and not pd.isna(row["graph"]):
 
160
  try:
161
+ graph_field = json.loads(row["graph"]) if isinstance(row["graph"], str) else row["graph"]
162
  except Exception:
163
  graph_field = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  else:
 
165
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ if graph_field is None:
 
168
  continue
169
 
170
+ node_features = safe_get(graph_field, "node_features", None)
171
+ if not node_features:
172
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
+ atomic_nums = []
175
+ chirality_vals = []
176
+ formal_charges = []
177
+ for nf in node_features:
178
+ an = safe_get(nf, "atomic_num", safe_get(nf, "atomic_number", 0))
179
+ ch = safe_get(nf, "chirality", 0)
180
+ fc = safe_get(nf, "formal_charge", 0)
181
+ atomic_nums.append(int(an))
182
+ chirality_vals.append(float(ch))
183
+ formal_charges.append(float(fc))
184
+
185
+ n_nodes = len(atomic_nums)
186
+
187
+ edge_indices_raw = safe_get(graph_field, "edge_indices", None)
188
+ edge_features_raw = safe_get(graph_field, "edge_features", None)
189
+
190
+ if edge_indices_raw is None:
191
+ adj_mat = safe_get(graph_field, "adjacency_matrix", None)
192
+ if adj_mat:
193
+ srcs, dsts = [], []
194
+ for i, row_adj in enumerate(adj_mat):
195
+ for j, val in enumerate(row_adj):
196
+ if val:
197
+ srcs.append(i)
198
+ dsts.append(j)
199
+ edge_index = torch.tensor([srcs, dsts], dtype=torch.long)
200
+ E = edge_index.size(1)
201
+ edge_attr = torch.zeros((E, 3), dtype=torch.float)
202
+ else:
203
+ continue
204
+ else:
205
+ srcs, dsts = [], []
206
+ if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0 and isinstance(edge_indices_raw[0], list):
207
+ if all(len(pair) == 2 and isinstance(pair[0], int) for pair in edge_indices_raw):
208
+ srcs = [int(p[0]) for p in edge_indices_raw]
209
+ dsts = [int(p[1]) for p in edge_indices_raw]
210
+ elif isinstance(edge_indices_raw[0][0], int):
211
+ try:
212
+ srcs = [int(x) for x in edge_indices_raw[0]]
213
+ dsts = [int(x) for x in edge_indices_raw[1]]
214
+ except Exception:
215
+ srcs, dsts = [], []
216
+ if len(srcs) == 0:
217
+ continue
218
 
219
+ edge_index = torch.tensor([srcs, dsts], dtype=torch.long)
 
 
 
 
 
 
220
 
221
+ if edge_features_raw and isinstance(edge_features_raw, list):
222
+ bond_types, stereos, is_conjs = [], [], []
223
+ for ef in edge_features_raw:
224
+ bt = safe_get(ef, "bond_type", 0)
225
+ st = safe_get(ef, "stereo", 0)
226
+ ic = safe_get(ef, "is_conjugated", False)
227
+ bond_types.append(float(bt))
228
+ stereos.append(float(st))
229
+ is_conjs.append(float(1.0 if ic else 0.0))
230
+ edge_attr = torch.tensor(np.stack([bond_types, stereos, is_conjs], axis=1), dtype=torch.float)
231
+ else:
232
+ E = edge_index.size(1)
233
+ edge_attr = torch.zeros((E, 3), dtype=torch.float)
234
+
235
+ edge_attr = match_edge_attr_to_index(edge_index, edge_attr, target_dim=3)
236
+
237
+ node_atomic_lists.append(torch.tensor(atomic_nums, dtype=torch.long))
238
+ node_chirality_lists.append(torch.tensor(chirality_vals, dtype=torch.float))
239
+ node_charge_lists.append(torch.tensor(formal_charges, dtype=torch.float))
240
+ edge_index_lists.append(edge_index)
241
+ edge_attr_lists.append(edge_attr)
242
+ num_nodes_list.append(n_nodes)
243
+
244
+ rows_read += 1
245
+ if rows_read >= target_rows:
246
+ break
247
+ if rows_read >= target_rows:
248
  break
 
 
 
 
 
249
 
250
+ if len(node_atomic_lists) == 0:
251
+ raise RuntimeError("No graphs were parsed from the CSV 'graph' column. Check input file and format.")
252
+
253
+ print(f"Parsed {len(node_atomic_lists)} graphs (using 'graph' column). Using manual max atomic Z = {MAX_ATOMIC_Z}")
254
+ return (
255
+ node_atomic_lists,
256
+ node_chirality_lists,
257
+ node_charge_lists,
258
+ edge_index_lists,
259
+ edge_attr_lists,
260
+ num_nodes_list,
261
+ )
262
+
263
+
264
+ def compute_class_weights(train_atomic: List[torch.Tensor]) -> torch.Tensor:
265
+ """Compute inverse-frequency class weights for atomic number prediction."""
266
+ num_classes = MASK_ATOM_ID + 1
267
+ counts = np.ones((num_classes,), dtype=np.float64)
268
+ for z in train_atomic:
269
+ vals = z.cpu().numpy().astype(int)
270
+ for v in vals:
271
+ if 0 <= v < num_classes:
272
+ counts[v] += 1.0
273
+ freq = counts / counts.sum()
274
+ inv_freq = 1.0 / (freq + 1e-12)
275
+ class_weights = inv_freq / inv_freq.mean()
276
+ class_weights = torch.tensor(class_weights, dtype=torch.float)
277
+ class_weights[MASK_ATOM_ID] = 1.0
278
+ return class_weights
279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
 
 
 
281
  class PolymerDataset(Dataset):
282
+ """Holds per-graph tensors; collation builds a single batched graph with masking targets."""
283
  def __init__(self, atomic_list, chirality_list, charge_list, edge_index_list, edge_attr_list, num_nodes_list):
284
  self.atomic_list = atomic_list
285
  self.chirality_list = chirality_list
 
293
 
294
  def __getitem__(self, idx):
295
  return {
296
+ "z": self.atomic_list[idx],
297
+ "chirality": self.chirality_list[idx],
298
+ "formal_charge": self.charge_list[idx],
299
+ "edge_index": self.edge_index_list[idx],
300
+ "edge_attr": self.edge_attr_list[idx],
301
+ "num_nodes": int(self.num_nodes_list[idx]),
302
  }
303
 
304
+
305
  def collate_batch(batch):
306
  """
307
+ Build a single batched graph (node-concatenation with edge index offsets) and create:
308
+ - masked node labels (labels_z)
309
+ - hop-distance anchor targets (labels_dists) for masked nodes
 
 
 
 
 
 
 
 
 
 
310
  """
311
+ all_z, all_ch, all_fc = [], [], []
312
+ all_labels_z, all_labels_dists, all_labels_dists_mask = [], [], []
 
 
 
 
313
  batch_idx = []
314
+
315
  edge_index_list_batched = []
316
  edge_attr_list_batched = []
317
  node_offset = 0
 
 
318
 
319
  for i, g in enumerate(batch):
320
+ z = g["z"]
321
  n = z.size(0)
322
  if n == 0:
323
  continue
 
327
  edge_index = g["edge_index"]
328
  edge_attr = g["edge_attr"]
329
 
 
330
  is_selected = torch.rand(n) < P_MASK
331
  if is_selected.all():
332
  is_selected[torch.randint(0, n, (1,))] = False
 
348
  z_masked[sel_idx[mask_choice]] = MASK_ATOM_ID
349
  if rand_choice.any():
350
  z_masked[sel_idx[rand_choice]] = rand_atomic[rand_choice]
 
351
 
352
+ # Hop-distance targets for masked atoms (anchors = nearest visible nodes in hop distance)
353
  visible_idx = torch.nonzero(~is_selected).squeeze(-1)
354
  if visible_idx.numel() == 0:
355
  visible_idx = torch.arange(n, dtype=torch.long)
356
 
357
+ dist_mat = shortest_path_lengths_hops(edge_index.clone(), n)
 
 
 
358
  for a in torch.nonzero(is_selected).squeeze(-1).tolist():
 
359
  vis = visible_idx.numpy()
360
  if vis.size == 0:
361
  continue
362
  dists = dist_mat[a, vis].astype(np.float32)
363
+ valid_mask = dists <= n
 
364
  if not valid_mask.any():
365
  continue
366
  dists_valid = dists[valid_mask]
 
 
367
  k = min(K_ANCHORS, dists_valid.size)
368
  idx_sorted = np.argsort(dists_valid)[:k]
369
+ labels_dists[a, :k] = torch.tensor(dists_valid[idx_sorted], dtype=torch.float)
 
370
  labels_dists_mask[a, :k] = True
371
 
 
372
  all_z.append(z_masked)
373
  all_ch.append(chir)
374
  all_fc.append(fc)
 
377
  all_labels_dists_mask.append(labels_dists_mask)
378
  batch_idx.append(torch.full((n,), i, dtype=torch.long))
379
 
 
380
  if edge_index is not None and edge_index.numel() > 0:
381
  ei_offset = edge_index + node_offset
382
  edge_index_list_batched.append(ei_offset)
 
383
  edge_attr_matched = match_edge_attr_to_index(edge_index, edge_attr, target_dim=3)
384
  edge_attr_list_batched.append(edge_attr_matched)
 
385
 
386
  node_offset += n
 
387
 
388
  if len(all_z) == 0:
 
389
  return {
390
  "z": torch.tensor([], dtype=torch.long),
391
  "chirality": torch.tensor([], dtype=torch.float),
 
395
  "batch": torch.tensor([], dtype=torch.long),
396
  "labels_z": torch.tensor([], dtype=torch.long),
397
  "labels_dists": torch.tensor([], dtype=torch.float).reshape(0, K_ANCHORS),
398
+ "labels_dists_mask": torch.tensor([], dtype=torch.bool).reshape(0, K_ANCHORS),
399
  }
400
 
401
  z_batch = torch.cat(all_z, dim=0)
 
422
  "batch": batch_batch,
423
  "labels_z": labels_z_batch,
424
  "labels_dists": labels_dists_batch,
425
+ "labels_dists_mask": labels_dists_mask_batch,
426
  }
427
 
 
 
 
 
 
428
 
 
 
 
429
  class GineBlock(nn.Module):
430
+ """One GINEConv block (MLP + BN + ReLU)."""
431
+ def __init__(self, node_dim: int):
432
  super().__init__()
433
+ self.mlp = nn.Sequential(nn.Linear(node_dim, node_dim), nn.ReLU(), nn.Linear(node_dim, node_dim))
 
 
 
 
 
434
  self.conv = GINEConv(self.mlp)
435
  self.bn = nn.BatchNorm1d(node_dim)
436
  self.act = nn.ReLU()
437
 
438
  def forward(self, x, edge_index, edge_attr):
 
439
  x = self.conv(x, edge_index, edge_attr)
440
  x = self.bn(x)
441
  x = self.act(x)
442
  return x
443
 
444
+
445
  class MaskedGINE(nn.Module):
446
+ """
447
+ Masked GNN objective:
448
+ - predict masked atomic numbers (classification head)
449
+ - predict hop-distance anchors for masked nodes (regression head)
450
+ - optionally learned uncertainty weighting across the two losses
451
+ """
452
+ def __init__(self, node_emb_dim=NODE_EMB_DIM, edge_emb_dim=EDGE_EMB_DIM, num_layers=NUM_GNN_LAYERS,
453
+ max_atomic_z=MAX_ATOMIC_Z, class_weights=None):
454
  super().__init__()
455
  self.node_emb_dim = node_emb_dim
456
  self.edge_emb_dim = edge_emb_dim
457
  self.max_atomic_z = max_atomic_z
458
 
 
459
  num_embeddings = MASK_ATOM_ID + 1
460
  self.atom_emb = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=node_emb_dim, padding_idx=None)
461
 
462
+ self.node_attr_proj = nn.Sequential(nn.Linear(2, node_emb_dim), nn.ReLU(), nn.Linear(node_emb_dim, node_emb_dim))
463
+ self.edge_encoder = nn.Sequential(nn.Linear(3, edge_emb_dim), nn.ReLU(), nn.Linear(edge_emb_dim, edge_emb_dim))
464
+
465
+ self._edge_to_node_proj = nn.Linear(edge_emb_dim, node_emb_dim) if edge_emb_dim != node_emb_dim else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
 
 
467
  self.gnn_layers = nn.ModuleList([GineBlock(node_emb_dim) for _ in range(num_layers)])
468
 
469
+ self.atom_head = nn.Linear(node_emb_dim, MASK_ATOM_ID + 1)
 
 
470
  self.coord_head = nn.Linear(node_emb_dim, K_ANCHORS)
471
 
 
472
  if USE_LEARNED_WEIGHTING:
473
  self.log_var_z = nn.Parameter(torch.zeros(1))
474
  self.log_var_pos = nn.Parameter(torch.zeros(1))
 
481
  else:
482
  self.class_weights = None
483
 
484
+ def forward(self, z, chirality, formal_charge, edge_index, edge_attr, batch=None,
485
+ labels_z=None, labels_dists=None, labels_dists_mask=None):
 
 
 
 
 
 
 
 
 
486
  if batch is None:
487
  batch = torch.zeros(z.size(0), dtype=torch.long, device=z.device)
488
 
489
+ atom_embedding = self.atom_emb(z)
490
+ node_attr = torch.stack([chirality, formal_charge], dim=1)
491
+ node_attr_emb = self.node_attr_proj(node_attr.to(atom_embedding.device))
492
+ x = atom_embedding + node_attr_emb
 
 
493
 
 
494
  if edge_attr is None or edge_attr.numel() == 0:
495
+ edge_emb = torch.zeros((0, self.edge_emb_dim), dtype=torch.float, device=x.device)
 
496
  else:
497
+ edge_emb = self.edge_encoder(edge_attr.to(x.device))
498
 
499
+ edge_for_conv = self._edge_to_node_proj(edge_emb) if self._edge_to_node_proj is not None else edge_emb
 
 
 
 
 
500
 
 
501
  h = x
502
  for layer in self.gnn_layers:
503
  h = layer(h, edge_index.to(h.device), edge_for_conv)
504
 
505
+ logits = self.atom_head(h)
506
+ dists_pred = self.coord_head(h)
507
 
 
508
  if labels_z is not None and labels_dists is not None and labels_dists_mask is not None:
509
  mask = labels_z != -100
510
  if mask.sum() == 0:
 
516
  labels_dists_masked = labels_dists[mask]
517
  labels_dists_mask_mask = labels_dists_mask[mask]
518
 
 
519
  if self.class_weights is not None:
520
+ loss_z = F.cross_entropy(
521
+ logits_masked, labels_z_masked.to(logits_masked.device), weight=self.class_weights.to(logits_masked.device)
522
+ )
523
  else:
524
  loss_z = F.cross_entropy(logits_masked, labels_z_masked.to(logits_masked.device))
525
 
 
526
  if labels_dists_mask_mask.any():
527
  preds = dists_pred_masked[labels_dists_mask_mask]
528
  trues = labels_dists_masked[labels_dists_mask_mask].to(preds.device)
 
533
  if USE_LEARNED_WEIGHTING:
534
  lz = torch.exp(-self.log_var_z) * loss_z + self.log_var_z
535
  lp = torch.exp(-self.log_var_pos) * loss_pos + self.log_var_pos
536
+ return 0.5 * (lz + lp)
 
 
 
537
 
538
+ return loss_z + loss_pos
539
 
540
  return logits, dists_pred
541
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542
 
543
  class ValLossCallback(TrainerCallback):
544
+ """Evaluation callback: prints metrics, saves best model, and early-stops on val loss."""
545
+ def __init__(self, best_model_dir: str, val_loader: DataLoader, patience: int = 10, trainer_ref=None):
546
  self.best_val_loss = float("inf")
547
  self.epochs_no_improve = 0
548
+ self.patience = patience
549
  self.best_epoch = None
550
  self.trainer_ref = trainer_ref
551
+ self.best_model_dir = best_model_dir
552
+ self.val_loader = val_loader
553
 
554
  def on_epoch_end(self, args, state, control, **kwargs):
555
  epoch_num = int(state.epoch)
 
560
 
561
  def on_evaluate(self, args, state, control, metrics=None, **kwargs):
562
  epoch_num = int(state.epoch) + 1
563
+
564
  if self.trainer_ref is None:
565
  print(f"[Eval] Epoch {epoch_num} - metrics (trainer_ref missing): {metrics}")
566
  return
567
 
568
+ metric_val_loss = metrics.get("eval_loss") if metrics is not None else None
 
 
569
 
570
  model_eval = self.trainer_ref.model
571
  model_eval.eval()
572
+ device_local = next(model_eval.parameters()).device
573
 
574
+ preds_z_all, true_z_all = [], []
575
+ pred_dists_all, true_dists_all = [], []
576
+ total_loss, n_batches = 0.0, 0
 
 
 
577
 
578
+ logits_masked_list, labels_masked_list = [], []
 
579
 
580
  with torch.no_grad():
581
+ for batch in self.val_loader:
582
  z = batch["z"].to(device_local)
583
  chir = batch["chirality"].to(device_local)
584
  fc = batch["formal_charge"].to(device_local)
 
591
 
592
  try:
593
  loss = model_eval(z, chir, fc, edge_index, edge_attr, batch_idx, labels_z, labels_dists, labels_dists_mask)
594
+ except Exception:
595
  loss = None
596
 
597
  if isinstance(loss, torch.Tensor):
 
599
  n_batches += 1
600
 
601
  logits, dists_pred = model_eval(z, chir, fc, edge_index, edge_attr, batch_idx)
 
602
  mask = labels_z != -100
603
  if mask.sum().item() == 0:
604
  continue
 
609
  pred_z = torch.argmax(logits[mask], dim=-1)
610
  true_z = labels_z[mask]
611
 
 
612
  pred_d = dists_pred[mask][labels_dists_mask[mask]]
613
  true_d = labels_dists[mask][labels_dists_mask[mask]]
614
 
 
631
  all_labels_masked = torch.cat(labels_masked_list, dim=0)
632
  cw = getattr(model_eval, "class_weights", None)
633
  if cw is not None:
 
634
  try:
635
+ loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked, weight=cw.to(device_local))
636
  except Exception:
637
  loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked)
638
  else:
 
652
  print(f"Validation MAE (distances): {mae:.4f}")
653
  print(f"Validation Perplexity (classification head): {perplexity:.4f}")
654
 
 
655
  if avg_val_loss is not None and not (isinstance(avg_val_loss, float) and np.isnan(avg_val_loss)) and avg_val_loss < self.best_val_loss - 1e-6:
656
  self.best_val_loss = avg_val_loss
657
  self.best_epoch = int(state.epoch)
658
  self.epochs_no_improve = 0
659
+ os.makedirs(self.best_model_dir, exist_ok=True)
660
  try:
661
+ torch.save(self.trainer_ref.model.state_dict(), os.path.join(self.best_model_dir, "pytorch_model.bin"))
662
+ print(f"Saved new best model (epoch {epoch_num}) to {os.path.join(self.best_model_dir, 'pytorch_model.bin')}")
663
  except Exception as e:
664
  print(f"Failed to save best model at epoch {epoch_num}: {e}")
665
  else:
 
669
  print(f"Early stopping after {self.patience} epochs with no improvement.")
670
  control.should_training_stop = True
671
 
 
 
 
 
 
 
 
 
 
 
 
672
 
673
+ def build_datasets_and_loaders(parsed, batch_train: int = 16, batch_val: int = 8, num_workers: int = 4):
674
+ """Split indices into train/val and construct Dataset/DataLoader."""
675
+ (node_atomic_lists, node_chirality_lists, node_charge_lists, edge_index_lists, edge_attr_lists, num_nodes_list) = parsed
676
+
677
+ indices = list(range(len(node_atomic_lists)))
678
+ train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
679
+
680
+ def subset(l, idxs):
681
+ return [l[i] for i in idxs]
682
+
683
+ train_atomic = subset(node_atomic_lists, train_idx)
684
+ train_chirality = subset(node_chirality_lists, train_idx)
685
+ train_charge = subset(node_charge_lists, train_idx)
686
+ train_edge_index = subset(edge_index_lists, train_idx)
687
+ train_edge_attr = subset(edge_attr_lists, train_idx)
688
+ train_num_nodes = subset(num_nodes_list, train_idx)
689
+
690
+ val_atomic = subset(node_atomic_lists, val_idx)
691
+ val_chirality = subset(node_chirality_lists, val_idx)
692
+ val_charge = subset(node_charge_lists, val_idx)
693
+ val_edge_index = subset(edge_index_lists, val_idx)
694
+ val_edge_attr = subset(edge_attr_lists, val_idx)
695
+ val_num_nodes = subset(num_nodes_list, val_idx)
696
+
697
+ train_dataset = PolymerDataset(train_atomic, train_chirality, train_charge, train_edge_index, train_edge_attr, train_num_nodes)
698
+ val_dataset = PolymerDataset(val_atomic, val_chirality, val_charge, val_edge_index, val_edge_attr, val_num_nodes)
699
+
700
+ train_loader = DataLoader(train_dataset, batch_size=batch_train, shuffle=True, collate_fn=collate_batch, num_workers=num_workers)
701
+ val_loader = DataLoader(val_dataset, batch_size=batch_val, shuffle=False, collate_fn=collate_batch, num_workers=num_workers)
702
+ return train_dataset, val_dataset, train_loader, val_loader, train_atomic
703
+
704
+
705
+ def train_and_evaluate(args: argparse.Namespace) -> None:
706
+ """Main run: parse data, build model, train, reload best, final eval printout."""
707
+ output_dir = args.output_dir
708
+ best_model_dir = os.path.join(output_dir, "best")
709
+ os.makedirs(output_dir, exist_ok=True)
710
+
711
+ parsed = parse_graphs_from_csv(args.csv_path, args.target_rows, args.chunksize)
712
+ train_dataset, val_dataset, train_loader, val_loader, train_atomic = build_datasets_and_loaders(
713
+ parsed, batch_train=16, batch_val=8, num_workers=args.num_workers
714
+ )
715
+
716
+ class_weights = compute_class_weights(train_atomic)
717
+
718
+ model = MaskedGINE(
719
+ node_emb_dim=NODE_EMB_DIM,
720
+ edge_emb_dim=EDGE_EMB_DIM,
721
+ num_layers=NUM_GNN_LAYERS,
722
+ max_atomic_z=MAX_ATOMIC_Z,
723
+ class_weights=class_weights,
724
+ )
725
+
726
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
727
+ model.to(device)
728
+
729
+ training_args = TrainingArguments(
730
+ output_dir=output_dir,
731
+ overwrite_output_dir=True,
732
+ num_train_epochs=25,
733
+ per_device_train_batch_size=16,
734
+ per_device_eval_batch_size=8,
735
+ gradient_accumulation_steps=4,
736
+ eval_strategy="epoch",
737
+ logging_steps=500,
738
+ learning_rate=1e-4,
739
+ weight_decay=0.01,
740
+ fp16=torch.cuda.is_available(),
741
+ save_strategy="no",
742
+ disable_tqdm=False,
743
+ logging_first_step=True,
744
+ report_to=[],
745
+ dataloader_num_workers=args.num_workers,
746
+ )
747
+
748
+ callback = ValLossCallback(best_model_dir=best_model_dir, val_loader=val_loader, patience=10)
749
+ trainer = Trainer(
750
+ model=model,
751
+ args=training_args,
752
+ train_dataset=train_dataset,
753
+ eval_dataset=val_dataset,
754
+ data_collator=collate_batch,
755
+ callbacks=[callback],
756
+ )
757
+ callback.trainer_ref = trainer
758
+
759
+ start_time = time.time()
760
+ trainer.train()
761
+ total_time = time.time() - start_time
762
+
763
+ best_model_path = os.path.join(best_model_dir, "pytorch_model.bin")
764
+ if os.path.exists(best_model_path):
765
+ try:
766
+ model.load_state_dict(torch.load(best_model_path, map_location=device))
767
+ print(f"\nLoaded best model from {best_model_path}")
768
+ except Exception as e:
769
+ print(f"\nFailed to load best model from {best_model_path}: {e}")
770
+
771
+ # Final evaluation block preserved (same as original intent, using val_loader)
772
+ model.eval()
773
+ preds_z_all, true_z_all = [], []
774
+ pred_dists_all, true_dists_all = [], []
775
+ logits_masked_list_final, labels_masked_list_final = [], []
776
+
777
+ with torch.no_grad():
778
+ for batch in val_loader:
779
+ z = batch["z"].to(device)
780
+ chir = batch["chirality"].to(device)
781
+ fc = batch["formal_charge"].to(device)
782
+ edge_index = batch["edge_index"].to(device)
783
+ edge_attr = batch["edge_attr"].to(device)
784
+ batch_idx = batch["batch"].to(device)
785
+ labels_z = batch["labels_z"].to(device)
786
+ labels_dists = batch["labels_dists"].to(device)
787
+ labels_dists_mask = batch["labels_dists_mask"].to(device)
788
+
789
+ logits, dists_pred = model(z, chir, fc, edge_index, edge_attr, batch_idx)
790
 
791
+ mask = labels_z != -100
792
+ if mask.sum().item() == 0:
793
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
794
 
795
+ logits_masked_list_final.append(logits[mask])
796
+ labels_masked_list_final.append(labels_z[mask])
797
 
798
+ pred_z = torch.argmax(logits[mask], dim=-1)
799
+ true_z = labels_z[mask]
800
 
801
+ pred_d = dists_pred[mask][labels_dists_mask[mask]]
802
+ true_d = labels_dists[mask][labels_dists_mask[mask]]
803
 
804
+ if pred_d.numel() > 0:
805
+ pred_dists_all.extend(pred_d.cpu().tolist())
806
+ true_dists_all.extend(true_d.cpu().tolist())
807
 
808
+ preds_z_all.extend(pred_z.cpu().tolist())
809
+ true_z_all.extend(true_z.cpu().tolist())
810
 
811
+ accuracy = accuracy_score(true_z_all, preds_z_all) if len(true_z_all) > 0 else 0.0
812
+ f1 = f1_score(true_z_all, preds_z_all, average="weighted") if len(true_z_all) > 0 else 0.0
813
+ rmse = np.sqrt(mean_squared_error(true_dists_all, pred_dists_all)) if len(true_dists_all) > 0 else 0.0
814
+ mae = mean_absolute_error(true_dists_all, pred_dists_all) if len(true_dists_all) > 0 else 0.0
815
 
816
+ if len(logits_masked_list_final) > 0:
817
+ all_logits_masked_final = torch.cat(logits_masked_list_final, dim=0)
818
+ all_labels_masked_final = torch.cat(labels_masked_list_final, dim=0)
819
+ cw_final = getattr(model, "class_weights", None)
820
+ if cw_final is not None:
821
+ try:
822
+ loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final, weight=cw_final.to(device))
823
+ except Exception:
824
+ loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final)
825
+ else:
826
+ loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final)
827
  try:
828
+ perplexity_final = float(torch.exp(loss_z_final).cpu().item())
829
  except Exception:
830
+ perplexity_final = float(np.exp(float(loss_z_final.cpu().item())))
831
  else:
832
+ perplexity_final = float("nan")
833
+
834
+ best_val_loss = callback.best_val_loss if hasattr(callback, "best_val_loss") else float("nan")
835
+ best_epoch_num = (int(callback.best_epoch) + 1) if callback.best_epoch is not None else None
836
+
837
+ print(f"\n=== Final Results (evaluated on best saved model) ===")
838
+ print(f"Total Training Time (s): {total_time:.2f}")
839
+ print(f"Best Epoch (1-based): {best_epoch_num}" if best_epoch_num is not None else "Best Epoch: (none saved)")
840
+ print(f"Best Validation Loss: {best_val_loss:.4f}")
841
+ print(f"Validation Accuracy: {accuracy:.4f}")
842
+ print(f"Validation F1 (weighted): {f1:.4f}")
843
+ print(f"Validation RMSE (distances): {rmse:.4f}")
844
+ print(f"Validation MAE (distances): {mae:.4f}")
845
+ print(f"Validation Perplexity (classification head): {perplexity_final:.4f}")
846
+
847
+ total_params = sum(p.numel() for p in model.parameters())
848
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
849
+ non_trainable_params = total_params - trainable_params
850
+ print(f"Total Parameters: {total_params}")
851
+ print(f"Trainable Parameters: {trainable_params}")
852
+ print(f"Non-trainable Parameters: {non_trainable_params}")
853
+
854
+
855
def main():
    """CLI entry point: parse command-line arguments, then run training and evaluation."""
    train_and_evaluate(parse_args())
858
+
859
+
860
# Execute the training pipeline only when run as a script, not on import.
if __name__ == "__main__":
    main()