manpreet88 commited on
Commit
3a0d11d
·
1 Parent(s): d80a1c4

Delete gine.py

Browse files
Files changed (1) hide show
  1. gine.py +0 -961
gine.py DELETED
@@ -1,961 +0,0 @@
1
- import os
2
- import json
3
- import time
4
- import shutil
5
-
6
- import sys
7
- import csv
8
-
9
- # Increase max CSV field size limit
10
- csv.field_size_limit(sys.maxsize)
11
-
12
- import torch
13
- import torch.nn as nn
14
- import torch.nn.functional as F
15
- import numpy as np
16
- import pandas as pd
17
- from sklearn.model_selection import train_test_split
18
- from torch.utils.data import Dataset, DataLoader
19
-
20
- from transformers import TrainingArguments, Trainer
21
- from transformers.trainer_callback import TrainerCallback
22
- from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, mean_absolute_error
23
-
24
- # PyG
25
- from torch_geometric.nn import GINEConv
26
-
27
- # ---------------------------
28
- # Configuration / Constants
29
- # ---------------------------
30
- P_MASK = 0.15
31
- # Manual max atomic number (user requested)
32
- MAX_ATOMIC_Z = 85
33
- # Mask token id
34
- MASK_ATOM_ID = MAX_ATOMIC_Z + 1
35
-
36
- USE_LEARNED_WEIGHTING = True
37
-
38
- # GINE / embedding hyperparams requested
39
- NODE_EMB_DIM = 300 # node embedding dimension
40
- EDGE_EMB_DIM = 300 # edge embedding dimension
41
- NUM_GNN_LAYERS = 5
42
-
43
- # Other hyperparams
44
- K_ANCHORS = 6
45
- OUTPUT_DIR = "./gin_output_5M"
46
- BEST_MODEL_DIR = os.path.join(OUTPUT_DIR, "best")
47
- os.makedirs(OUTPUT_DIR, exist_ok=True)
48
-
49
- # Data reading settings
50
- csv_path = "polymer_structures_unified_processed.csv"
51
- TARGET_ROWS = 5000000
52
- CHUNKSIZE = 50000
53
-
54
- # ---------------------------
55
- # Helper functions
56
- # ---------------------------
57
-
58
- def safe_get(d: dict, key: str, default=None):
59
- return d[key] if (isinstance(d, dict) and key in d) else default
60
-
61
- def build_adj_list(edge_index, num_nodes):
62
- adj = [[] for _ in range(num_nodes)]
63
- if edge_index is None or edge_index.numel() == 0:
64
- return adj
65
- # edge_index shape [2, E]
66
- src = edge_index[0].tolist()
67
- dst = edge_index[1].tolist()
68
- for u, v in zip(src, dst):
69
- # ensure indices are within range
70
- if 0 <= u < num_nodes and 0 <= v < num_nodes:
71
- adj[u].append(v)
72
- return adj
73
-
74
- def shortest_path_lengths_hops(edge_index, num_nodes):
75
- """
76
- Compute all-pairs shortest path lengths in hops (BFS per node).
77
- Returns an (num_nodes, num_nodes) numpy array with int distances; unreachable -> large number (e.g., num_nodes+1)
78
- """
79
- adj = build_adj_list(edge_index, num_nodes)
80
- INF = num_nodes + 1
81
- dist_mat = np.full((num_nodes, num_nodes), INF, dtype=np.int32)
82
- for s in range(num_nodes):
83
- # BFS
84
- q = [s]
85
- dist_mat[s, s] = 0
86
- head = 0
87
- while head < len(q):
88
- u = q[head]; head += 1
89
- for v in adj[u]:
90
- if dist_mat[s, v] == INF:
91
- dist_mat[s, v] = dist_mat[s, u] + 1
92
- q.append(v)
93
- return dist_mat
94
-
95
- def match_edge_attr_to_index(edge_index: torch.Tensor, edge_attr: torch.Tensor, target_dim: int = 3):
96
- """
97
- Ensure edge_attr has shape [E_index, D]. Handles common mismatches:
98
- - If edge_attr is empty/None -> returns zeros of shape [E_index, target_dim].
99
- - If edge_attr.size(0) == edge_index.size(1) -> return as-is.
100
- - If edge_attr.size(0) * 2 == edge_index.size(1) -> duplicate (common when features only for undirected edges).
101
- - Otherwise repeat/truncate edge_attr to match E_index (safe fallback).
102
- """
103
- E_idx = edge_index.size(1) if (edge_index is not None and edge_index.numel() > 0) else 0
104
- if E_idx == 0:
105
- return torch.zeros((0, target_dim), dtype=torch.float)
106
- if edge_attr is None or edge_attr.numel() == 0:
107
- return torch.zeros((E_idx, target_dim), dtype=torch.float)
108
- E_attr = edge_attr.size(0)
109
- if E_attr == E_idx:
110
- # already matches
111
- if edge_attr.size(1) != target_dim:
112
- # pad/truncate feature dimension to target_dim
113
- D = edge_attr.size(1)
114
- if D < target_dim:
115
- pad = torch.zeros((E_attr, target_dim - D), dtype=torch.float, device=edge_attr.device)
116
- return torch.cat([edge_attr, pad], dim=1)
117
- else:
118
- return edge_attr[:, :target_dim]
119
- return edge_attr
120
- # common case: features provided for undirected edges while edge_index contains both directions
121
- if E_attr * 2 == E_idx:
122
- try:
123
- return torch.cat([edge_attr, edge_attr], dim=0)
124
- except Exception:
125
- # fallback to repeat below
126
- pass
127
- # fallback: repeat/truncate edge_attr to fit E_idx
128
- reps = (E_idx + E_attr - 1) // E_attr
129
- edge_rep = edge_attr.repeat(reps, 1)[:E_idx]
130
- if edge_rep.size(1) != target_dim:
131
- D = edge_rep.size(1)
132
- if D < target_dim:
133
- pad = torch.zeros((E_idx, target_dim - D), dtype=torch.float, device=edge_rep.device)
134
- edge_rep = torch.cat([edge_rep, pad], dim=1)
135
- else:
136
- edge_rep = edge_rep[:, :target_dim]
137
- return edge_rep
138
-
139
- # ---------------------------
140
- # 1. Load Data from `graph` column (chunked)
141
- # ---------------------------
142
- node_atomic_lists = []
143
- node_chirality_lists = []
144
- node_charge_lists = []
145
- edge_index_lists = []
146
- edge_attr_lists = []
147
- num_nodes_list = []
148
- rows_read = 0
149
-
150
- for chunk in pd.read_csv(csv_path, engine="python", chunksize=CHUNKSIZE):
151
- for idx, row in chunk.iterrows():
152
- # Prefer 'graph' column JSON (string) per user request
153
- graph_field = None
154
- if "graph" in row and not pd.isna(row["graph"]):
155
- try:
156
- graph_field = json.loads(row["graph"])
157
- except Exception:
158
- # If already parsed or other format
159
- try:
160
- graph_field = row["graph"]
161
- except Exception:
162
- graph_field = None
163
- else:
164
- # If no graph column, skip (user requested to use graph column)
165
- continue
166
-
167
- if graph_field is None:
168
- continue
169
-
170
- # NODE FEATURES
171
- node_features = safe_get(graph_field, "node_features", None)
172
- if not node_features:
173
- # skip graphs without node_features
174
- continue
175
-
176
- atomic_nums = []
177
- chirality_vals = []
178
- formal_charges = []
179
-
180
- for nf in node_features:
181
- # atomic number
182
- an = safe_get(nf, "atomic_num", None)
183
- if an is None:
184
- # try alternate keys
185
- an = safe_get(nf, "atomic_number", 0)
186
- # chirality (use 0 default)
187
- ch = safe_get(nf, "chirality", 0)
188
- # formal charge (use 0 default)
189
- fc = safe_get(nf, "formal_charge", 0)
190
- atomic_nums.append(int(an))
191
- chirality_vals.append(float(ch))
192
- formal_charges.append(float(fc))
193
-
194
- n_nodes = len(atomic_nums)
195
-
196
- # EDGE INDICES & FEATURES
197
- edge_indices_raw = safe_get(graph_field, "edge_indices", None)
198
- edge_features_raw = safe_get(graph_field, "edge_features", None)
199
-
200
- if edge_indices_raw is None:
201
- # try adjacency_matrix to infer edges
202
- adj_mat = safe_get(graph_field, "adjacency_matrix", None)
203
- if adj_mat:
204
- # adjacency_matrix is list of lists
205
- srcs = []
206
- dsts = []
207
- for i, row_adj in enumerate(adj_mat):
208
- for j, val in enumerate(row_adj):
209
- if val:
210
- srcs.append(i)
211
- dsts.append(j)
212
- edge_index = torch.tensor([srcs, dsts], dtype=torch.long)
213
- # no edge features available -> create zeros matching edges
214
- E = edge_index.size(1)
215
- edge_attr = torch.zeros((E, 3), dtype=torch.float)
216
- else:
217
- # no edges found -> skip this graph (GINE requires edges)
218
- continue
219
- else:
220
- # edge_indices_raw expected like [[u,v], [u2,v2], ...] or [[u1,u2,...],[v1,v2,...]]
221
- if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0 and isinstance(edge_indices_raw[0], list):
222
- # Could be list of pairs or list of lists
223
- if all(len(pair) == 2 and isinstance(pair[0], int) for pair in edge_indices_raw):
224
- # list of pairs
225
- srcs = [int(p[0]) for p in edge_indices_raw]
226
- dsts = [int(p[1]) for p in edge_indices_raw]
227
- elif isinstance(edge_indices_raw[0][0], int):
228
- # Possibly already in [[srcs],[dsts]] format
229
- try:
230
- srcs = [int(x) for x in edge_indices_raw[0]]
231
- dsts = [int(x) for x in edge_indices_raw[1]]
232
- except Exception:
233
- # fallback
234
- srcs = []
235
- dsts = []
236
- else:
237
- srcs = []
238
- dsts = []
239
- else:
240
- srcs = []
241
- dsts = []
242
-
243
- if len(srcs) == 0:
244
- # fallback: skip graph
245
- continue
246
-
247
- edge_index = torch.tensor([srcs, dsts], dtype=torch.long)
248
-
249
- # Build edge_attr matrix with 3 features: bond_type, stereo, is_conjugated (as float)
250
- if edge_features_raw and isinstance(edge_features_raw, list):
251
- bond_types = []
252
- stereos = []
253
- is_conjs = []
254
- for ef in edge_features_raw:
255
- bt = safe_get(ef, "bond_type", 0)
256
- st = safe_get(ef, "stereo", 0)
257
- ic = safe_get(ef, "is_conjugated", False)
258
- bond_types.append(float(bt))
259
- stereos.append(float(st))
260
- is_conjs.append(float(1.0 if ic else 0.0))
261
- edge_attr = torch.tensor(np.stack([bond_types, stereos, is_conjs], axis=1), dtype=torch.float)
262
- else:
263
- # no edge features -> zeros
264
- E = edge_index.size(1)
265
- edge_attr = torch.zeros((E, 3), dtype=torch.float)
266
-
267
- # Ensure edge_attr length matches edge_index (fix common mismatches)
268
- edge_attr = match_edge_attr_to_index(edge_index, edge_attr, target_dim=3)
269
-
270
- # NOTE: we explicitly DO NOT parse or use coordinates (geometry) anywhere.
271
-
272
- # Save lists
273
- node_atomic_lists.append(torch.tensor(atomic_nums, dtype=torch.long))
274
- node_chirality_lists.append(torch.tensor(chirality_vals, dtype=torch.float))
275
- node_charge_lists.append(torch.tensor(formal_charges, dtype=torch.float))
276
- edge_index_lists.append(edge_index)
277
- edge_attr_lists.append(edge_attr)
278
- num_nodes_list.append(n_nodes)
279
-
280
- rows_read += 1
281
- if rows_read >= TARGET_ROWS:
282
- break
283
- if rows_read >= TARGET_ROWS:
284
- break
285
-
286
- if len(node_atomic_lists) == 0:
287
- raise RuntimeError("No graphs were parsed from the CSV 'graph' column. Check input file and format.")
288
-
289
- print(f"Parsed {len(node_atomic_lists)} graphs (using 'graph' column). Using manual max atomic Z = {MAX_ATOMIC_Z}")
290
-
291
- # ---------------------------
292
- # 2. Train/Val Split
293
- # ---------------------------
294
- indices = list(range(len(node_atomic_lists)))
295
- train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
296
-
297
- def subset(l, idxs):
298
- return [l[i] for i in idxs]
299
-
300
- train_atomic = subset(node_atomic_lists, train_idx)
301
- train_chirality = subset(node_chirality_lists, train_idx)
302
- train_charge = subset(node_charge_lists, train_idx)
303
- train_edge_index = subset(edge_index_lists, train_idx)
304
- train_edge_attr = subset(edge_attr_lists, train_idx)
305
- train_num_nodes = subset(num_nodes_list, train_idx)
306
-
307
- val_atomic = subset(node_atomic_lists, val_idx)
308
- val_chirality = subset(node_chirality_lists, val_idx)
309
- val_charge = subset(node_charge_lists, val_idx)
310
- val_edge_index = subset(edge_index_lists, val_idx)
311
- val_edge_attr = subset(edge_attr_lists, val_idx)
312
- val_num_nodes = subset(num_nodes_list, val_idx)
313
-
314
- # ---------------------------
315
- # Compute class weights (for weighted CE)
316
- # ---------------------------
317
- num_classes = MASK_ATOM_ID + 1
318
- counts = np.ones((num_classes,), dtype=np.float64)
319
- for z in train_atomic:
320
- vals = z.cpu().numpy().astype(int)
321
- for v in vals:
322
- if 0 <= v < num_classes:
323
- counts[v] += 1.0
324
- freq = counts / counts.sum()
325
- inv_freq = 1.0 / (freq + 1e-12)
326
- class_weights = inv_freq / inv_freq.mean()
327
- class_weights = torch.tensor(class_weights, dtype=torch.float)
328
- class_weights[MASK_ATOM_ID] = 1.0
329
-
330
- # ---------------------------
331
- # 3. Dataset and Collator (build MLM masks + invariant distance targets using hop counts only)
332
- # ---------------------------
333
- class PolymerDataset(Dataset):
334
- def __init__(self, atomic_list, chirality_list, charge_list, edge_index_list, edge_attr_list, num_nodes_list):
335
- self.atomic_list = atomic_list
336
- self.chirality_list = chirality_list
337
- self.charge_list = charge_list
338
- self.edge_index_list = edge_index_list
339
- self.edge_attr_list = edge_attr_list
340
- self.num_nodes_list = num_nodes_list
341
-
342
- def __len__(self):
343
- return len(self.atomic_list)
344
-
345
- def __getitem__(self, idx):
346
- return {
347
- "z": self.atomic_list[idx], # [n_nodes]
348
- "chirality": self.chirality_list[idx], # [n_nodes] float
349
- "formal_charge": self.charge_list[idx], # [n_nodes]
350
- "edge_index": self.edge_index_list[idx], # [2, E]
351
- "edge_attr": self.edge_attr_list[idx], # [E, 3]
352
- "num_nodes": int(self.num_nodes_list[idx]) # int
353
- }
354
-
355
- def collate_batch(batch):
356
- """
357
- Builds a batched structure from a list of graph dicts.
358
- Returns:
359
- - z: [N_total] long (atomic numbers possibly masked)
360
- - chirality: [N_total] float
361
- - formal_charge: [N_total] float
362
- - edge_index: [2, E_total] long (node indices offset per graph)
363
- - edge_attr: [E_total, 3] float
364
- - batch: [N_total] long mapping node->graph idx
365
- - labels_z: [N_total] long (-100 for unselected)
366
- - labels_dists: [N_total, K_ANCHORS] float (hop counts)
367
- - labels_dists_mask: [N_total, K_ANCHORS] bool
368
- Distance targets:
369
- - Shortest-path hop distances computed from edge_index for every graph.
370
- """
371
- all_z = []
372
- all_ch = []
373
- all_fc = []
374
- all_labels_z = []
375
- all_labels_dists = []
376
- all_labels_dists_mask = []
377
- batch_idx = []
378
- edge_index_list_batched = []
379
- edge_attr_list_batched = []
380
- node_offset = 0
381
- total_nodes = 0
382
- total_edges = 0
383
-
384
- for i, g in enumerate(batch):
385
- z = g["z"] # tensor [n]
386
- n = z.size(0)
387
- if n == 0:
388
- continue
389
-
390
- chir = g["chirality"]
391
- fc = g["formal_charge"]
392
- edge_index = g["edge_index"]
393
- edge_attr = g["edge_attr"]
394
-
395
- # Mask selection
396
- is_selected = torch.rand(n) < P_MASK
397
- if is_selected.all():
398
- is_selected[torch.randint(0, n, (1,))] = False
399
-
400
- labels_z = torch.full((n,), -100, dtype=torch.long)
401
- labels_dists = torch.zeros((n, K_ANCHORS), dtype=torch.float)
402
- labels_dists_mask = torch.zeros((n, K_ANCHORS), dtype=torch.bool)
403
- labels_z[is_selected] = z[is_selected]
404
-
405
- # BERT-style corruption on atomic numbers
406
- z_masked = z.clone()
407
- if is_selected.any():
408
- sel_idx = torch.nonzero(is_selected).squeeze(-1)
409
- rand_atomic = torch.randint(1, MAX_ATOMIC_Z + 1, (sel_idx.size(0),), dtype=torch.long)
410
- probs = torch.rand(sel_idx.size(0))
411
- mask_choice = probs < 0.8
412
- rand_choice = (probs >= 0.8) & (probs < 0.9)
413
- if mask_choice.any():
414
- z_masked[sel_idx[mask_choice]] = MASK_ATOM_ID
415
- if rand_choice.any():
416
- z_masked[sel_idx[rand_choice]] = rand_atomic[rand_choice]
417
- # keep_choice -> do nothing
418
-
419
- # Build invariant distance targets using hop distances only
420
- visible_idx = torch.nonzero(~is_selected).squeeze(-1)
421
- if visible_idx.numel() == 0:
422
- visible_idx = torch.arange(n, dtype=torch.long)
423
-
424
- # compute hop distances via BFS using edge_index
425
- ei = edge_index.clone()
426
- num_nodes_local = n
427
- dist_mat = shortest_path_lengths_hops(ei, num_nodes_local) # numpy int matrix
428
- for a in torch.nonzero(is_selected).squeeze(-1).tolist():
429
- # distances to visible nodes
430
- vis = visible_idx.numpy()
431
- if vis.size == 0:
432
- continue
433
- dists = dist_mat[a, vis].astype(np.float32)
434
- # filter unreachable (INF = n+1)
435
- valid_mask = dists <= num_nodes_local
436
- if not valid_mask.any():
437
- continue
438
- dists_valid = dists[valid_mask]
439
- vis_valid = vis[valid_mask]
440
- # choose smallest hop distances
441
- k = min(K_ANCHORS, dists_valid.size)
442
- idx_sorted = np.argsort(dists_valid)[:k]
443
- selected_vals = dists_valid[idx_sorted]
444
- labels_dists[a, :k] = torch.tensor(selected_vals, dtype=torch.float)
445
- labels_dists_mask[a, :k] = True
446
-
447
- # Append node-level tensors to batched lists
448
- all_z.append(z_masked)
449
- all_ch.append(chir)
450
- all_fc.append(fc)
451
- all_labels_z.append(labels_z)
452
- all_labels_dists.append(labels_dists)
453
- all_labels_dists_mask.append(labels_dists_mask)
454
- batch_idx.append(torch.full((n,), i, dtype=torch.long))
455
-
456
- # Offset edge indices and append
457
- if edge_index is not None and edge_index.numel() > 0:
458
- ei_offset = edge_index + node_offset
459
- edge_index_list_batched.append(ei_offset)
460
- # edge_attr already matched earlier; still ensure shapes here for safety
461
- edge_attr_matched = match_edge_attr_to_index(edge_index, edge_attr, target_dim=3)
462
- edge_attr_list_batched.append(edge_attr_matched)
463
- total_edges += edge_index.size(1)
464
-
465
- node_offset += n
466
- total_nodes += n
467
-
468
- if len(all_z) == 0:
469
- # Return empty structured batch
470
- return {
471
- "z": torch.tensor([], dtype=torch.long),
472
- "chirality": torch.tensor([], dtype=torch.float),
473
- "formal_charge": torch.tensor([], dtype=torch.float),
474
- "edge_index": torch.tensor([[], []], dtype=torch.long),
475
- "edge_attr": torch.tensor([], dtype=torch.float).reshape(0, 3),
476
- "batch": torch.tensor([], dtype=torch.long),
477
- "labels_z": torch.tensor([], dtype=torch.long),
478
- "labels_dists": torch.tensor([], dtype=torch.float).reshape(0, K_ANCHORS),
479
- "labels_dists_mask": torch.tensor([], dtype=torch.bool).reshape(0, K_ANCHORS)
480
- }
481
-
482
- z_batch = torch.cat(all_z, dim=0)
483
- chir_batch = torch.cat(all_ch, dim=0)
484
- fc_batch = torch.cat(all_fc, dim=0)
485
- labels_z_batch = torch.cat(all_labels_z, dim=0)
486
- labels_dists_batch = torch.cat(all_labels_dists, dim=0)
487
- labels_dists_mask_batch = torch.cat(all_labels_dists_mask, dim=0)
488
- batch_batch = torch.cat(batch_idx, dim=0)
489
-
490
- if len(edge_index_list_batched) > 0:
491
- edge_index_batched = torch.cat(edge_index_list_batched, dim=1)
492
- edge_attr_batched = torch.cat(edge_attr_list_batched, dim=0)
493
- else:
494
- edge_index_batched = torch.tensor([[], []], dtype=torch.long)
495
- edge_attr_batched = torch.tensor([], dtype=torch.float).reshape(0, 3)
496
-
497
- return {
498
- "z": z_batch,
499
- "chirality": chir_batch,
500
- "formal_charge": fc_batch,
501
- "edge_index": edge_index_batched,
502
- "edge_attr": edge_attr_batched,
503
- "batch": batch_batch,
504
- "labels_z": labels_z_batch,
505
- "labels_dists": labels_dists_batch,
506
- "labels_dists_mask": labels_dists_mask_batch
507
- }
508
-
509
- train_dataset = PolymerDataset(train_atomic, train_chirality, train_charge, train_edge_index, train_edge_attr, train_num_nodes)
510
- val_dataset = PolymerDataset(val_atomic, val_chirality, val_charge, val_edge_index, val_edge_attr, val_num_nodes)
511
-
512
- train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_batch)
513
- val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_batch)
514
-
515
- # ---------------------------
516
- # 4. Model Definition (GINE-based masked model)
517
- # ---------------------------
518
- class GineBlock(nn.Module):
519
- def __init__(self, node_dim):
520
- super().__init__()
521
- # MLP used by GINEConv: map (node_dim) -> node_dim
522
- self.mlp = nn.Sequential(
523
- nn.Linear(node_dim, node_dim),
524
- nn.ReLU(),
525
- nn.Linear(node_dim, node_dim)
526
- )
527
- self.conv = GINEConv(self.mlp)
528
- self.bn = nn.BatchNorm1d(node_dim)
529
- self.act = nn.ReLU()
530
-
531
- def forward(self, x, edge_index, edge_attr):
532
- # GINEConv accepts edge_attr; edge_attr should be same dim as x (or handled in MLP inside)
533
- x = self.conv(x, edge_index, edge_attr)
534
- x = self.bn(x)
535
- x = self.act(x)
536
- return x
537
-
538
- class MaskedGINE(nn.Module):
539
- def __init__(self,
540
- node_emb_dim=NODE_EMB_DIM,
541
- edge_emb_dim=EDGE_EMB_DIM,
542
- num_layers=NUM_GNN_LAYERS,
543
- max_atomic_z=MAX_ATOMIC_Z,
544
- class_weights=None):
545
- super().__init__()
546
- self.node_emb_dim = node_emb_dim
547
- self.edge_emb_dim = edge_emb_dim
548
- self.max_atomic_z = max_atomic_z
549
-
550
- # Embedding for atomic numbers (including MASK token)
551
- num_embeddings = MASK_ATOM_ID + 1
552
- self.atom_emb = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=node_emb_dim, padding_idx=None)
553
-
554
- # Small MLP to map numeric node attributes (chirality, formal_charge) -> node_emb_dim
555
- self.node_attr_proj = nn.Sequential(
556
- nn.Linear(2, node_emb_dim),
557
- nn.ReLU(),
558
- nn.Linear(node_emb_dim, node_emb_dim)
559
- )
560
-
561
- # Edge encoder: maps 3-dim raw edge features -> edge_emb_dim
562
- self.edge_encoder = nn.Sequential(
563
- nn.Linear(3, edge_emb_dim),
564
- nn.ReLU(),
565
- nn.Linear(edge_emb_dim, edge_emb_dim)
566
- )
567
-
568
- # Project edge_emb -> node_emb_dim if needed (registered in __init__ to avoid dynamic creation)
569
- if edge_emb_dim != node_emb_dim:
570
- self._edge_to_node_proj = nn.Linear(edge_emb_dim, node_emb_dim)
571
- else:
572
- self._edge_to_node_proj = None
573
-
574
- # GINE layers
575
- self.gnn_layers = nn.ModuleList([GineBlock(node_emb_dim) for _ in range(num_layers)])
576
-
577
- # Heads
578
- num_classes_local = MASK_ATOM_ID + 1
579
- self.atom_head = nn.Linear(node_emb_dim, num_classes_local)
580
- self.coord_head = nn.Linear(node_emb_dim, K_ANCHORS)
581
-
582
- # Learned uncertainty weighting
583
- if USE_LEARNED_WEIGHTING:
584
- self.log_var_z = nn.Parameter(torch.zeros(1))
585
- self.log_var_pos = nn.Parameter(torch.zeros(1))
586
- else:
587
- self.log_var_z = None
588
- self.log_var_pos = None
589
-
590
- if class_weights is not None:
591
- self.register_buffer("class_weights", class_weights)
592
- else:
593
- self.class_weights = None
594
-
595
- def forward(self, z, chirality, formal_charge, edge_index, edge_attr,
596
- batch=None, labels_z=None, labels_dists=None, labels_dists_mask=None):
597
- """
598
- z: [N] long (atomic numbers or MASK_ATOM_ID)
599
- chirality: [N] float
600
- formal_charge: [N] float
601
- edge_index: [2, E] long (global batched indices)
602
- edge_attr: [E, 3] float
603
- batch: [N] long mapping nodes->graph idx
604
- labels_*: optional supervision targets as in collate_batch
605
- """
606
- if batch is None:
607
- batch = torch.zeros(z.size(0), dtype=torch.long, device=z.device)
608
-
609
- # Node embedding
610
- atom_embedding = self.atom_emb(z) # [N, node_emb_dim]
611
- node_attr = torch.stack([chirality, formal_charge], dim=1) # [N,2]
612
- node_attr_emb = self.node_attr_proj(node_attr.to(atom_embedding.device)) # [N, node_emb_dim]
613
-
614
- x = atom_embedding + node_attr_emb # combine categorical and numeric node features
615
-
616
- # Edge embedding
617
- if edge_attr is None or edge_attr.numel() == 0:
618
- E = 0 if edge_attr is None else edge_attr.size(0)
619
- edge_emb = torch.zeros((E, self.edge_emb_dim), dtype=torch.float, device=x.device)
620
- else:
621
- edge_emb = self.edge_encoder(edge_attr.to(x.device)) # [E, edge_emb_dim]
622
-
623
- # For GINEConv, edge_attr should match node feature dim used inside GINE (GINE uses the provided nn to process x_j + edge_attr)
624
- # Project edge_emb -> node_emb_dim if dims differ (registered in __init__)
625
- if self._edge_to_node_proj is not None:
626
- edge_for_conv = self._edge_to_node_proj(edge_emb)
627
- else:
628
- edge_for_conv = edge_emb
629
-
630
- # Run GNN layers
631
- h = x
632
- for layer in self.gnn_layers:
633
- h = layer(h, edge_index.to(h.device), edge_for_conv)
634
-
635
- logits = self.atom_head(h) # [N, num_classes]
636
- dists_pred = self.coord_head(h) # [N, K_ANCHORS]
637
-
638
- # Compute loss if labels provided
639
- if labels_z is not None and labels_dists is not None and labels_dists_mask is not None:
640
- mask = labels_z != -100
641
- if mask.sum() == 0:
642
- return torch.tensor(0.0, device=z.device)
643
-
644
- logits_masked = logits[mask]
645
- dists_pred_masked = dists_pred[mask]
646
- labels_z_masked = labels_z[mask]
647
- labels_dists_masked = labels_dists[mask]
648
- labels_dists_mask_mask = labels_dists_mask[mask]
649
-
650
- # classification loss
651
- if self.class_weights is not None:
652
- loss_z = F.cross_entropy(logits_masked, labels_z_masked.to(logits_masked.device), weight=self.class_weights.to(logits_masked.device))
653
- else:
654
- loss_z = F.cross_entropy(logits_masked, labels_z_masked.to(logits_masked.device))
655
-
656
- # distance loss (only where mask true)
657
- if labels_dists_mask_mask.any():
658
- preds = dists_pred_masked[labels_dists_mask_mask]
659
- trues = labels_dists_masked[labels_dists_mask_mask].to(preds.device)
660
- loss_pos = F.mse_loss(preds, trues, reduction="mean")
661
- else:
662
- loss_pos = torch.tensor(0.0, device=z.device)
663
-
664
- if USE_LEARNED_WEIGHTING:
665
- lz = torch.exp(-self.log_var_z) * loss_z + self.log_var_z
666
- lp = torch.exp(-self.log_var_pos) * loss_pos + self.log_var_pos
667
- loss = 0.5 * (lz + lp)
668
- else:
669
- alpha = 1.0
670
- loss = loss_z + alpha * loss_pos
671
-
672
- return loss
673
-
674
- return logits, dists_pred
675
-
676
- # Instantiate model
677
- model = MaskedGINE(node_emb_dim=NODE_EMB_DIM,
678
- edge_emb_dim=EDGE_EMB_DIM,
679
- num_layers=NUM_GNN_LAYERS,
680
- max_atomic_z=MAX_ATOMIC_Z,
681
- class_weights=class_weights)
682
-
683
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
684
- model.to(device)
685
-
686
- # ---------------------------
687
- # 5. Training Setup (Hugging Face Trainer)
688
- # ---------------------------
689
- training_args = TrainingArguments(
690
- output_dir=OUTPUT_DIR,
691
- overwrite_output_dir=True,
692
- num_train_epochs=25,
693
- per_device_train_batch_size=16,
694
- per_device_eval_batch_size=8,
695
- gradient_accumulation_steps=4,
696
- eval_strategy="epoch",
697
- logging_steps=500,
698
- learning_rate=1e-4,
699
- weight_decay=0.01,
700
- fp16=torch.cuda.is_available(),
701
- save_strategy="no",
702
- disable_tqdm=False,
703
- logging_first_step=True,
704
- report_to=[],
705
- dataloader_num_workers=4,
706
- )
707
-
708
- class ValLossCallback(TrainerCallback):
709
- def __init__(self, trainer_ref=None):
710
- self.best_val_loss = float("inf")
711
- self.epochs_no_improve = 0
712
- self.patience = 10
713
- self.best_epoch = None
714
- self.trainer_ref = trainer_ref
715
-
716
- def on_epoch_end(self, args, state, control, **kwargs):
717
- epoch_num = int(state.epoch)
718
- train_loss = next((x["loss"] for x in reversed(state.log_history) if "loss" in x), None)
719
- print(f"\n=== Epoch {epoch_num}/{args.num_train_epochs} ===")
720
- if train_loss is not None:
721
- print(f"Train Loss: {train_loss:.4f}")
722
-
723
- def on_evaluate(self, args, state, control, metrics=None, **kwargs):
724
- epoch_num = int(state.epoch) + 1
725
- if self.trainer_ref is None:
726
- print(f"[Eval] Epoch {epoch_num} - metrics (trainer_ref missing): {metrics}")
727
- return
728
-
729
- metric_val_loss = None
730
- if metrics is not None:
731
- metric_val_loss = metrics.get("eval_loss")
732
-
733
- model_eval = self.trainer_ref.model
734
- model_eval.eval()
735
- device_local = next(model_eval.parameters()).device if any(p.numel() > 0 for p in model_eval.parameters()) else torch.device("cpu")
736
-
737
- preds_z_all = []
738
- true_z_all = []
739
- pred_dists_all = []
740
- true_dists_all = []
741
- total_loss = 0.0
742
- n_batches = 0
743
-
744
- logits_masked_list = []
745
- labels_masked_list = []
746
-
747
- with torch.no_grad():
748
- for batch in val_loader:
749
- z = batch["z"].to(device_local)
750
- chir = batch["chirality"].to(device_local)
751
- fc = batch["formal_charge"].to(device_local)
752
- edge_index = batch["edge_index"].to(device_local)
753
- edge_attr = batch["edge_attr"].to(device_local)
754
- batch_idx = batch["batch"].to(device_local)
755
- labels_z = batch["labels_z"].to(device_local)
756
- labels_dists = batch["labels_dists"].to(device_local)
757
- labels_dists_mask = batch["labels_dists_mask"].to(device_local)
758
-
759
- try:
760
- loss = model_eval(z, chir, fc, edge_index, edge_attr, batch_idx, labels_z, labels_dists, labels_dists_mask)
761
- except Exception as e:
762
- loss = None
763
-
764
- if isinstance(loss, torch.Tensor):
765
- total_loss += loss.item()
766
- n_batches += 1
767
-
768
- logits, dists_pred = model_eval(z, chir, fc, edge_index, edge_attr, batch_idx)
769
-
770
- mask = labels_z != -100
771
- if mask.sum().item() == 0:
772
- continue
773
-
774
- logits_masked_list.append(logits[mask])
775
- labels_masked_list.append(labels_z[mask])
776
-
777
- pred_z = torch.argmax(logits[mask], dim=-1)
778
- true_z = labels_z[mask]
779
-
780
- # flatten valid distances
781
- pred_d = dists_pred[mask][labels_dists_mask[mask]]
782
- true_d = labels_dists[mask][labels_dists_mask[mask]]
783
-
784
- if pred_d.numel() > 0:
785
- pred_dists_all.extend(pred_d.cpu().tolist())
786
- true_dists_all.extend(true_d.cpu().tolist())
787
-
788
- preds_z_all.extend(pred_z.cpu().tolist())
789
- true_z_all.extend(true_z.cpu().tolist())
790
-
791
- avg_val_loss = metric_val_loss if metric_val_loss is not None else ((total_loss / n_batches) if n_batches > 0 else float("nan"))
792
-
793
- accuracy = accuracy_score(true_z_all, preds_z_all) if len(true_z_all) > 0 else 0.0
794
- f1 = f1_score(true_z_all, preds_z_all, average="weighted") if len(true_z_all) > 0 else 0.0
795
- rmse = np.sqrt(mean_squared_error(true_dists_all, pred_dists_all)) if len(true_dists_all) > 0 else 0.0
796
- mae = mean_absolute_error(true_dists_all, pred_dists_all) if len(true_dists_all) > 0 else 0.0
797
-
798
- if len(logits_masked_list) > 0:
799
- all_logits_masked = torch.cat(logits_masked_list, dim=0)
800
- all_labels_masked = torch.cat(labels_masked_list, dim=0)
801
- cw = getattr(model_eval, "class_weights", None)
802
- if cw is not None:
803
- cw_device = cw.to(device_local)
804
- try:
805
- loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked, weight=cw_device)
806
- except Exception:
807
- loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked)
808
- else:
809
- loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked)
810
- try:
811
- perplexity = float(torch.exp(loss_z_all).cpu().item())
812
- except Exception:
813
- perplexity = float(np.exp(float(loss_z_all.cpu().item())))
814
- else:
815
- perplexity = float("nan")
816
-
817
- print(f"\n--- Evaluation after Epoch {epoch_num} ---")
818
- print(f"Validation Loss: {avg_val_loss:.4f}")
819
- print(f"Validation Accuracy: {accuracy:.4f}")
820
- print(f"Validation F1 (weighted): {f1:.4f}")
821
- print(f"Validation RMSE (distances): {rmse:.4f}")
822
- print(f"Validation MAE (distances): {mae:.4f}")
823
- print(f"Validation Perplexity (classification head): {perplexity:.4f}")
824
-
825
- # Save best model by val loss
826
- if avg_val_loss is not None and not (isinstance(avg_val_loss, float) and np.isnan(avg_val_loss)) and avg_val_loss < self.best_val_loss - 1e-6:
827
- self.best_val_loss = avg_val_loss
828
- self.best_epoch = int(state.epoch)
829
- self.epochs_no_improve = 0
830
- os.makedirs(BEST_MODEL_DIR, exist_ok=True)
831
- try:
832
- torch.save(self.trainer_ref.model.state_dict(), os.path.join(BEST_MODEL_DIR, "pytorch_model.bin"))
833
- print(f"Saved new best model (epoch {epoch_num}) to {os.path.join(BEST_MODEL_DIR, 'pytorch_model.bin')}")
834
- except Exception as e:
835
- print(f"Failed to save best model at epoch {epoch_num}: {e}")
836
- else:
837
- self.epochs_no_improve += 1
838
-
839
- if self.epochs_no_improve >= self.patience:
840
- print(f"Early stopping after {self.patience} epochs with no improvement.")
841
- control.should_training_stop = True
842
-
843
- # Create callback and Trainer
844
- callback = ValLossCallback()
845
- trainer = Trainer(
846
- model=model,
847
- args=training_args,
848
- train_dataset=train_dataset,
849
- eval_dataset=val_dataset,
850
- data_collator=collate_batch,
851
- callbacks=[callback]
852
- )
853
- callback.trainer_ref = trainer
854
-
855
- # ---------------------------
856
- # 6. Run training
857
- # ---------------------------
858
- start_time = time.time()
859
- trainer.train()
860
- total_time = time.time() - start_time
861
-
862
- # ---------------------------
863
- # 7. Final Evaluation (on best saved model)
864
- # ---------------------------
865
- best_model_path = os.path.join(BEST_MODEL_DIR, "pytorch_model.bin")
866
- if os.path.exists(best_model_path):
867
- try:
868
- model.load_state_dict(torch.load(best_model_path, map_location=device))
869
- print(f"\nLoaded best model from {best_model_path}")
870
- except Exception as e:
871
- print(f"\nFailed to load best model from {best_model_path}: {e}")
872
-
873
- model.eval()
874
- preds_z_all = []
875
- true_z_all = []
876
- pred_dists_all = []
877
- true_dists_all = []
878
-
879
- logits_masked_list_final = []
880
- labels_masked_list_final = []
881
-
882
- with torch.no_grad():
883
- for batch in val_loader:
884
- z = batch["z"].to(device)
885
- chir = batch["chirality"].to(device)
886
- fc = batch["formal_charge"].to(device)
887
- edge_index = batch["edge_index"].to(device)
888
- edge_attr = batch["edge_attr"].to(device)
889
- batch_idx = batch["batch"].to(device)
890
- labels_z = batch["labels_z"].to(device)
891
- labels_dists = batch["labels_dists"].to(device)
892
- labels_dists_mask = batch["labels_dists_mask"].to(device)
893
-
894
- logits, dists_pred = model(z, chir, fc, edge_index, edge_attr, batch_idx)
895
-
896
- mask = labels_z != -100
897
- if mask.sum().item() == 0:
898
- continue
899
-
900
- logits_masked_list_final.append(logits[mask])
901
- labels_masked_list_final.append(labels_z[mask])
902
-
903
- pred_z = torch.argmax(logits[mask], dim=-1)
904
- true_z = labels_z[mask]
905
-
906
- pred_d = dists_pred[mask][labels_dists_mask[mask]]
907
- true_d = labels_dists[mask][labels_dists_mask[mask]]
908
-
909
- if pred_d.numel() > 0:
910
- pred_dists_all.extend(pred_d.cpu().tolist())
911
- true_dists_all.extend(true_d.cpu().tolist())
912
-
913
- preds_z_all.extend(pred_z.cpu().tolist())
914
- true_z_all.extend(true_z.cpu().tolist())
915
-
916
- accuracy = accuracy_score(true_z_all, preds_z_all) if len(true_z_all) > 0 else 0.0
917
- f1 = f1_score(true_z_all, preds_z_all, average="weighted") if len(true_z_all) > 0 else 0.0
918
- rmse = np.sqrt(mean_squared_error(true_dists_all, pred_dists_all)) if len(true_dists_all) > 0 else 0.0
919
- mae = mean_absolute_error(true_dists_all, pred_dists_all) if len(true_dists_all) > 0 else 0.0
920
-
921
- if len(logits_masked_list_final) > 0:
922
- all_logits_masked_final = torch.cat(logits_masked_list_final, dim=0)
923
- all_labels_masked_final = torch.cat(labels_masked_list_final, dim=0)
924
- cw_final = getattr(model, "class_weights", None)
925
- if cw_final is not None:
926
- try:
927
- loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final, weight=cw_final.to(device))
928
- except Exception:
929
- loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final)
930
- else:
931
- loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final)
932
- try:
933
- perplexity_final = float(torch.exp(loss_z_final).cpu().item())
934
- except Exception:
935
- perplexity_final = float(np.exp(float(loss_z_final.cpu().item())))
936
- else:
937
- perplexity_final = float("nan")
938
-
939
- best_val_loss = callback.best_val_loss if hasattr(callback, "best_val_loss") else float("nan")
940
- best_epoch_num = (int(callback.best_epoch) + 1) if callback.best_epoch is not None else None
941
-
942
- print(f"\n=== Final Results (evaluated on best saved model) ===")
943
- print(f"Total Training Time (s): {total_time:.2f}")
944
- if best_epoch_num is not None:
945
- print(f"Best Epoch (1-based): {best_epoch_num}")
946
- else:
947
- print("Best Epoch: (none saved)")
948
-
949
- print(f"Best Validation Loss: {best_val_loss:.4f}")
950
- print(f"Validation Accuracy: {accuracy:.4f}")
951
- print(f"Validation F1 (weighted): {f1:.4f}")
952
- print(f"Validation RMSE (distances): {rmse:.4f}")
953
- print(f"Validation MAE (distances): {mae:.4f}")
954
- print(f"Validation Perplexity (classification head): {perplexity_final:.4f}")
955
-
956
- total_params = sum(p.numel() for p in model.parameters())
957
- trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
958
- non_trainable_params = total_params - trainable_params
959
- print(f"Total Parameters: {total_params}")
960
- print(f"Trainable Parameters: {trainable_params}")
961
- print(f"Non-trainable Parameters: {non_trainable_params}")