manpreet88 commited on
Commit
21839e7
·
1 Parent(s): 6a3741a

Update CL.py

Browse files
Files changed (1) hide show
  1. PolyFusion/CL.py +826 -511
PolyFusion/CL.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import os
2
  import sys
3
  import csv
@@ -22,7 +27,7 @@ import torch.nn as nn
22
  import torch.nn.functional as F
23
  from torch.utils.data import Dataset, DataLoader
24
 
25
- # Shared model utilities
26
  from GINE import GineEncoder, match_edge_attr_to_index, safe_get
27
  from SchNet import NodeSchNetWrapper
28
  from Transformer import PooledFingerprintEncoder as FingerprintEncoder
@@ -30,15 +35,15 @@ from DeBERTav2 import PSMILESDebertaEncoder, build_psmiles_tokenizer
30
 
31
  # HF Trainer & Transformers
32
  from transformers import TrainingArguments, Trainer
33
- from transformers import DataCollatorForLanguageModeling
34
  from transformers.trainer_callback import TrainerCallback
35
 
36
  from sklearn.model_selection import train_test_split
37
- from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, mean_absolute_error
 
 
 
 
38
 
39
- # ---------------------------
40
- # Config / Hyperparams (paths are placeholders; update for your environment)
41
- # ---------------------------
42
  P_MASK = 0.15
43
  MAX_ATOMIC_Z = 85
44
  MASK_ATOM_ID = MAX_ATOMIC_Z + 1
@@ -48,104 +53,112 @@ NODE_EMB_DIM = 300
48
  EDGE_EMB_DIM = 300
49
  NUM_GNN_LAYERS = 5
50
 
51
- # SchNet params
52
  SCHNET_NUM_GAUSSIANS = 50
53
  SCHNET_NUM_INTERACTIONS = 6
54
  SCHNET_CUTOFF = 10.0
55
  SCHNET_MAX_NEIGHBORS = 64
56
  SCHNET_HIDDEN = 600
57
 
58
- # Transformer params
59
  FP_LENGTH = 2048
60
- MASK_TOKEN_ID_FP = 2
61
  VOCAB_SIZE_FP = 3
62
 
63
- # DeBERTav2 params
64
  DEBERTA_HIDDEN = 600
65
  PSMILES_MAX_LEN = 128
66
 
67
  # Contrastive params
68
  TEMPERATURE = 0.07
 
 
 
 
 
 
 
69
 
70
- # Reconstruction loss weight (balance between contrastive and reconstruction objectives)
71
- REC_LOSS_WEIGHT = 1.0 # could be tuned (e.g., 0.5, 1.0)
72
 
73
- # Training args
74
  OUTPUT_DIR = "/path/to/multimodal_output"
75
- os.makedirs(OUTPUT_DIR, exist_ok=True)
76
  BEST_GINE_DIR = "/path/to/gin_output/best"
77
  BEST_SCHNET_DIR = "/path/to/schnet_output/best"
78
  BEST_FP_DIR = "/path/to/fingerprint_mlm_output/best"
79
  BEST_PSMILES_DIR = "/path/to/polybert_output/best"
80
 
81
- training_args = TrainingArguments(
82
- output_dir=OUTPUT_DIR,
83
- overwrite_output_dir=True,
84
- num_train_epochs=25,
85
- per_device_train_batch_size=16,
86
- per_device_eval_batch_size=8,
87
- gradient_accumulation_steps=4,
88
- eval_strategy="epoch",
89
- logging_steps=100,
90
- learning_rate=1e-4,
91
- weight_decay=0.01,
92
- eval_accumulation_steps=1000,
93
- fp16=torch.cuda.is_available(),
94
- save_strategy="epoch",
95
- save_steps=500,
96
- disable_tqdm=False,
97
- logging_first_step=True,
98
- report_to=[],
99
- dataloader_num_workers=0,
100
- load_best_model_at_end=True,
101
- metric_for_best_model="eval_loss",
102
- greater_is_better=False,
103
- )
104
-
105
- # ======== robust device selection =========
106
- USE_CUDA = torch.cuda.is_available()
107
- if USE_CUDA:
108
- device = torch.device("cuda") # respects CUDA_VISIBLE_DEVICES
109
- else:
110
- device = torch.device("cpu")
111
- print("Device:", device)
112
-
113
- # ======== deterministic seeds =========
114
- SEED = 42
115
- random.seed(SEED)
116
- np.random.seed(SEED)
117
- torch.manual_seed(SEED)
118
- if USE_CUDA:
119
- torch.cuda.manual_seed_all(SEED)
120
-
121
- # ---------------------------
122
- # Utility / small helpers
123
- # ---------------------------
124
-
125
- # Optimized BFS to compute distances to visible anchors
126
- def bfs_distances_to_visible(edge_index: torch.Tensor, num_nodes: int, masked_idx: np.ndarray, visible_idx: np.ndarray, k_anchors: int):
127
- INF = num_nodes + 1
128
  selected_dists = np.zeros((num_nodes, k_anchors), dtype=np.float32)
129
  selected_mask = np.zeros((num_nodes, k_anchors), dtype=np.bool_)
 
130
  if edge_index is None or edge_index.numel() == 0:
131
  return selected_dists, selected_mask
 
132
  src = edge_index[0].tolist()
133
  dst = edge_index[1].tolist()
 
134
  adj = [[] for _ in range(num_nodes)]
135
  for u, v in zip(src, dst):
136
  if 0 <= u < num_nodes and 0 <= v < num_nodes:
137
  adj[u].append(v)
138
- visible_set = set(visible_idx.tolist()) if isinstance(visible_idx, (np.ndarray, list)) else set(visible_idx.cpu().tolist())
 
 
 
 
 
 
139
  for a in np.atleast_1d(masked_idx).tolist():
140
  if a < 0 or a >= num_nodes:
141
  continue
 
142
  q = [a]
143
  visited = [-1] * num_nodes
144
  visited[a] = 0
145
  head = 0
146
  found = []
 
147
  while head < len(q) and len(found) < k_anchors:
148
- u = q[head]; head += 1
 
149
  for v in adj[u]:
150
  if visited[v] == -1:
151
  visited[v] = visited[u] + 1
@@ -154,137 +167,162 @@ def bfs_distances_to_visible(edge_index: torch.Tensor, num_nodes: int, masked_id
154
  found.append((visited[v], v))
155
  if len(found) >= k_anchors:
156
  break
 
157
  if len(found) > 0:
158
  found.sort(key=lambda x: x[0])
159
  k = min(k_anchors, len(found))
160
  for i in range(k):
161
  selected_dists[a, i] = float(found[i][0])
162
  selected_mask[a, i] = True
 
163
  return selected_dists, selected_mask
164
 
165
- # ---------------------------
166
- # Data loading / preprocessing
167
- # ---------------------------
168
- CSV_PATH = "/path/to/polymer_structures_unified_processed.csv"
169
- TARGET_ROWS = 2000000
170
- CHUNKSIZE = 50000
171
 
172
- PREPROC_DIR = "/path/to/preprocessed_samples"
173
- os.makedirs(PREPROC_DIR, exist_ok=True)
174
-
175
- # The per-sample file format: torch.save(sample_dict, sample_path)
176
- # sample_dict keys: 'gine', 'schnet', 'fp', 'psmiles_raw'
177
- # 'gine' -> dict with: node_atomic (list/int tensor), chirality (list/float tensor), formal_charge (list/float tensor), edge_index (2xE list), edge_attr (E x 3 list)
178
- # 'schnet' -> dict with: atomic (list), coords (list of [x,y,z])
179
- # 'fp' -> list of length FP_LENGTH (0/1 ints)
180
- # 'psmiles_raw' -> raw psmiles string
181
-
182
- def prepare_or_load_data_streaming():
183
- # If PREPROC_DIR already contains per-sample files, reuse them
184
- existing = sorted([p for p in Path(PREPROC_DIR).glob("sample_*.pt")])
 
 
 
 
 
 
 
 
 
 
 
185
  if len(existing) > 0:
186
- print(f"Found {len(existing)} preprocessed sample files in {PREPROC_DIR}; reusing those (no reparse).")
187
  return [str(p) for p in existing]
188
 
189
  print("No existing per-sample preprocessed folder found. Parsing CSV chunked and writing per-sample files (streaming).")
190
- rows_read = 0
191
  sample_idx = 0
192
 
193
- for chunk in pd.read_csv(CSV_PATH, engine="python", chunksize=CHUNKSIZE):
194
- # Pre-extract columns presence
195
  has_graph = "graph" in chunk.columns
196
  has_geometry = "geometry" in chunk.columns
197
  has_fp = "fingerprints" in chunk.columns
198
  has_psmiles = "psmiles" in chunk.columns
199
 
200
  for i_row in range(len(chunk)):
201
- if rows_read >= TARGET_ROWS:
202
  break
 
203
  row = chunk.iloc[i_row]
204
 
205
- # Prepare placeholders
206
  gine_sample = None
207
  schnet_sample = None
208
  fp_sample = None
209
  psmiles_raw = None
210
 
211
- # Parse graph
212
  if has_graph:
213
  val = row.get("graph", "")
214
  try:
215
- graph_field = json.loads(val) if isinstance(val, str) and val.strip() != "" else (val if not isinstance(val, str) else None)
 
 
 
 
216
  except Exception:
217
  graph_field = None
 
218
  if graph_field:
219
  node_features = safe_get(graph_field, "node_features", None)
220
  if node_features:
221
  atomic_nums = []
222
  chirality_vals = []
223
  formal_charges = []
 
224
  for nf in node_features:
225
  an = safe_get(nf, "atomic_num", None)
226
  if an is None:
227
  an = safe_get(nf, "atomic_number", 0)
228
  ch = safe_get(nf, "chirality", 0)
229
  fc = safe_get(nf, "formal_charge", 0)
 
230
  try:
231
  atomic_nums.append(int(an))
232
  except Exception:
233
  atomic_nums.append(0)
 
234
  chirality_vals.append(float(ch))
235
  formal_charges.append(float(fc))
236
- n_nodes = len(atomic_nums)
237
  edge_indices_raw = safe_get(graph_field, "edge_indices", None)
238
  edge_features_raw = safe_get(graph_field, "edge_features", None)
 
239
  edge_index = None
240
  edge_attr = None
 
 
241
  if edge_indices_raw is None:
242
  adj_mat = safe_get(graph_field, "adjacency_matrix", None)
243
  if adj_mat:
244
- srcs = []
245
- dsts = []
246
  for i_r, row_adj in enumerate(adj_mat):
247
  for j, val2 in enumerate(row_adj):
248
  if val2:
249
- srcs.append(i_r); dsts.append(j)
 
250
  if len(srcs) > 0:
251
  edge_index = [srcs, dsts]
252
  E = len(srcs)
253
  edge_attr = [[0.0, 0.0, 0.0] for _ in range(E)]
254
  else:
 
 
 
255
  srcs, dsts = [], []
256
- # handle multiple formats
257
  if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0 and isinstance(edge_indices_raw[0], list):
258
- # either list of pairs or two lists
259
  first = edge_indices_raw[0]
260
  if len(first) == 2 and isinstance(first[0], int):
261
- # maybe list of pairs
262
  try:
263
  srcs = [int(p[0]) for p in edge_indices_raw]
264
  dsts = [int(p[1]) for p in edge_indices_raw]
265
  except Exception:
266
  srcs, dsts = [], []
267
  else:
268
- # maybe [[srcs],[dsts]]
269
  try:
270
  srcs = [int(x) for x in edge_indices_raw[0]]
271
  dsts = [int(x) for x in edge_indices_raw[1]]
272
  except Exception:
273
  srcs, dsts = [], []
274
- if len(srcs) == 0 and isinstance(edge_indices_raw, list) and all(isinstance(p, (list, tuple)) and len(p) == 2 for p in edge_indices_raw):
 
 
 
275
  srcs = [int(p[0]) for p in edge_indices_raw]
276
  dsts = [int(p[1]) for p in edge_indices_raw]
 
277
  if len(srcs) > 0:
278
  edge_index = [srcs, dsts]
 
279
  if edge_features_raw and isinstance(edge_features_raw, list):
280
- bond_types = []
281
- stereos = []
282
- is_conjs = []
283
  for ef in edge_features_raw:
284
  bt = safe_get(ef, "bond_type", 0)
285
  st = safe_get(ef, "stereo", 0)
286
  ic = safe_get(ef, "is_conjugated", False)
287
- bond_types.append(float(bt)); stereos.append(float(st)); is_conjs.append(float(1.0 if ic else 0.0))
 
 
288
  edge_attr = list(zip(bond_types, stereos, is_conjs))
289
  else:
290
  E = len(srcs)
@@ -299,11 +337,15 @@ def prepare_or_load_data_streaming():
299
  "edge_attr": edge_attr,
300
  }
301
 
302
- # Parse geometry for SchNet
303
  if has_geometry and schnet_sample is None:
304
  val = row.get("geometry", "")
305
  try:
306
- geom = json.loads(val) if isinstance(val, str) and val.strip() != "" else (val if not isinstance(val, str) else None)
 
 
 
 
307
  conf = geom.get("best_conformer") if isinstance(geom, dict) else None
308
  if conf:
309
  atomic = conf.get("atomic_numbers", [])
@@ -313,12 +355,13 @@ def prepare_or_load_data_streaming():
313
  except Exception:
314
  schnet_sample = None
315
 
316
- # Parse fingerprints
317
  if has_fp:
318
  fpval = row.get("fingerprints", "")
319
  if fpval is None or (isinstance(fpval, str) and fpval.strip() == ""):
320
  fp_sample = [0] * FP_LENGTH
321
  else:
 
322
  try:
323
  fp_json = json.loads(fpval) if isinstance(fpval, str) else fpval
324
  except Exception:
@@ -330,8 +373,13 @@ def prepare_or_load_data_streaming():
330
  if len(bits) < FP_LENGTH:
331
  bits += [0] * (FP_LENGTH - len(bits))
332
  fp_sample = bits
 
333
  if fp_sample is None:
334
- bits = safe_get(fp_json, "morgan_r3_bits", None) if isinstance(fp_json, dict) else (fp_json if isinstance(fp_json, list) else None)
 
 
 
 
335
  if bits is None:
336
  fp_sample = [0] * FP_LENGTH
337
  else:
@@ -344,107 +392,122 @@ def prepare_or_load_data_streaming():
344
  normalized.append(1 if int(b) != 0 else 0)
345
  else:
346
  normalized.append(0)
 
347
  if len(normalized) >= FP_LENGTH:
348
  break
 
349
  if len(normalized) < FP_LENGTH:
350
  normalized.extend([0] * (FP_LENGTH - len(normalized)))
351
  fp_sample = normalized[:FP_LENGTH]
352
 
353
- # Parse psmiles
354
  if has_psmiles:
355
  s = row.get("psmiles", "")
356
- if s is None:
357
- psmiles_raw = ""
358
- else:
359
- psmiles_raw = str(s)
360
 
361
- # If we have at least two modalities (prefer all four), write the sample
362
- # For safety, we require psmiles and fp at minimum OR graph+psmiles etc.
363
- modalities_present = sum([1 if x is not None else 0 for x in [gine_sample, schnet_sample, fp_sample, psmiles_raw]])
 
364
  if modalities_present >= 2:
365
  sample = {
366
  "gine": gine_sample,
367
  "schnet": schnet_sample,
368
  "fp": fp_sample,
369
- "psmiles_raw": psmiles_raw
370
  }
371
- sample_path = os.path.join(PREPROC_DIR, f"sample_{sample_idx:08d}.pt")
 
372
  try:
373
  torch.save(sample, sample_path)
374
  except Exception as save_e:
375
  print("Warning: failed to torch.save sample:", save_e)
376
- # fallback to json write small dict (safe)
377
  try:
378
  with open(sample_path + ".json", "w") as fjson:
379
  json.dump(sample, fjson)
380
- # indicate via filename with .json
381
- sample_path = sample_path + ".json"
382
  except Exception:
383
  pass
384
 
385
  sample_idx += 1
386
- rows_read += 1
387
 
388
- # continue to next row
389
- if rows_read >= TARGET_ROWS:
390
  break
391
 
392
- print(f"Wrote {sample_idx} sample files to {PREPROC_DIR}.")
393
- return [str(p) for p in sorted(Path(PREPROC_DIR).glob("sample_*.pt"))]
394
 
395
- sample_files = prepare_or_load_data_streaming()
396
 
397
- # ---------------------------
398
- # Prepare tokenizer for psmiles
399
- # ---------------------------
400
- SPM_MODEL = "/path/to/spm.model"
401
- tokenizer = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)
402
 
403
- # ---------------------------
404
- # Lazy dataset: loads per-sample file on demand and tokenizes psmiles on-the-fly
405
- # ---------------------------
406
  class LazyMultimodalDataset(Dataset):
407
- def __init__(self, sample_file_list: List[str], tokenizer, fp_length=FP_LENGTH, psmiles_max_len=PSMILES_MAX_LEN):
 
 
 
 
 
 
 
 
 
 
408
  self.files = sample_file_list
409
  self.tokenizer = tokenizer
410
  self.fp_length = fp_length
411
  self.psmiles_max_len = psmiles_max_len
412
 
413
- def __len__(self):
414
  return len(self.files)
415
 
416
- def __getitem__(self, idx):
417
  sample_path = self.files[idx]
418
- # prefer torch.load if .pt, else try json
 
419
  if sample_path.endswith(".pt"):
420
  sample = torch.load(sample_path, map_location="cpu")
421
  else:
422
- # fallback json load
423
  with open(sample_path, "r") as f:
424
  sample = json.load(f)
425
 
426
- # GINE: convert lists to tensors (still on CPU)
427
  gine_raw = sample.get("gine", None)
428
- gine_item = None
429
  if gine_raw:
430
  node_atomic = torch.tensor(gine_raw.get("node_atomic", []), dtype=torch.long)
431
  node_chirality = torch.tensor(gine_raw.get("node_chirality", []), dtype=torch.float)
432
  node_charge = torch.tensor(gine_raw.get("node_charge", []), dtype=torch.float)
 
433
  if gine_raw.get("edge_index", None) is not None:
434
- ei = gine_raw["edge_index"]
435
- edge_index = torch.tensor(ei, dtype=torch.long)
436
  else:
437
  edge_index = torch.tensor([[], []], dtype=torch.long)
 
438
  ea_raw = gine_raw.get("edge_attr", None)
439
  if ea_raw:
440
  edge_attr = torch.tensor(ea_raw, dtype=torch.float)
441
  else:
442
  edge_attr = torch.zeros((edge_index.size(1), 3), dtype=torch.float)
443
- gine_item = {"z": node_atomic, "chirality": node_chirality, "formal_charge": node_charge, "edge_index": edge_index, "edge_attr": edge_attr}
444
- else:
445
- gine_item = {"z": torch.tensor([], dtype=torch.long), "chirality": torch.tensor([], dtype=torch.float), "formal_charge": torch.tensor([], dtype=torch.float), "edge_index": torch.tensor([[], []], dtype=torch.long), "edge_attr": torch.zeros((0, 3), dtype=torch.float)}
446
 
447
- # SchNet
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  schnet_raw = sample.get("schnet", None)
449
  if schnet_raw:
450
  s_z = torch.tensor(schnet_raw.get("atomic", []), dtype=torch.long)
@@ -453,12 +516,11 @@ class LazyMultimodalDataset(Dataset):
453
  else:
454
  schnet_item = {"z": torch.tensor([], dtype=torch.long), "pos": torch.tensor([], dtype=torch.float)}
455
 
456
- # Fingerprint stored as list of ints; convert to tensor here
457
  fp_raw = sample.get("fp", None)
458
  if fp_raw is None:
459
  fp_vec = torch.zeros((self.fp_length,), dtype=torch.long)
460
  else:
461
- # if fp_raw is already tensor-like, handle it
462
  if isinstance(fp_raw, (list, tuple)):
463
  arr = list(fp_raw)[:self.fp_length]
464
  if len(arr) < self.fp_length:
@@ -467,85 +529,87 @@ class LazyMultimodalDataset(Dataset):
467
  elif isinstance(fp_raw, torch.Tensor):
468
  fp_vec = fp_raw.clone().to(torch.long)
469
  else:
470
- # fallback
471
  fp_vec = torch.zeros((self.fp_length,), dtype=torch.long)
472
 
473
- # PSMILES: raw string, tokenize now
474
- psm_raw = sample.get("psmiles_raw", "")
475
- if psm_raw is None:
476
- psm_raw = ""
477
  enc = self.tokenizer(psm_raw, truncation=True, padding="max_length", max_length=self.psmiles_max_len)
478
  p_input_ids = torch.tensor(enc["input_ids"], dtype=torch.long)
479
  p_attn = torch.tensor(enc["attention_mask"], dtype=torch.bool)
480
 
481
  return {
482
- "gine": {"z": gine_item["z"], "chirality": gine_item["chirality"], "formal_charge": gine_item["formal_charge"], "edge_index": gine_item["edge_index"], "edge_attr": gine_item["edge_attr"], "num_nodes": int(gine_item["z"].size(0)) if gine_item["z"].numel() > 0 else 0},
 
 
 
 
 
 
 
483
  "schnet": {"z": schnet_item["z"], "pos": schnet_item["pos"]},
484
  "fp": {"input_ids": fp_vec},
485
- "psmiles": {"input_ids": p_input_ids, "attention_mask": p_attn}
486
  }
487
 
488
- # instantiate dataset lazily
489
- dataset = LazyMultimodalDataset(sample_files, tokenizer, fp_length=FP_LENGTH, psmiles_max_len=PSMILES_MAX_LEN)
490
 
491
- # train/val split
492
- train_idx, val_idx = train_test_split(list(range(len(dataset))), test_size=0.2, random_state=42)
493
- train_subset = torch.utils.data.Subset(dataset, train_idx)
494
- val_subset = torch.utils.data.Subset(dataset, val_idx)
495
-
496
- # For manual evaluation (used by evaluate_multimodal), create a DataLoader with num_workers=0
497
- def multimodal_collate(batch_list):
498
  """
499
- Given a list of items as returned by MultimodalDataset.__getitem__, build a batched mini-batch
500
- that the encoders accept.
 
 
 
 
 
501
  """
502
- B = len(batch_list)
503
- # GINE batching
504
- all_z = []
505
- all_ch = []
506
- all_fc = []
507
- all_edge_index = []
508
- all_edge_attr = []
509
  batch_mapping = []
510
  node_offset = 0
 
511
  for i, item in enumerate(batch_list):
512
  g = item["gine"]
513
  z = g["z"]
514
  n = z.size(0)
 
515
  all_z.append(z)
516
  all_ch.append(g["chirality"])
517
  all_fc.append(g["formal_charge"])
518
  batch_mapping.append(torch.full((n,), i, dtype=torch.long))
 
519
  if g["edge_index"] is not None and g["edge_index"].numel() > 0:
520
  ei_offset = g["edge_index"] + node_offset
521
  all_edge_index.append(ei_offset)
 
 
522
  ea = match_edge_attr_to_index(g["edge_index"], g["edge_attr"], target_dim=3)
523
  all_edge_attr.append(ea)
 
524
  node_offset += n
 
525
  if len(all_z) == 0:
526
- # create zero-length placeholders for empty batch
527
  z_batch = torch.tensor([], dtype=torch.long)
528
  ch_batch = torch.tensor([], dtype=torch.float)
529
  fc_batch = torch.tensor([], dtype=torch.float)
530
  batch_batch = torch.tensor([], dtype=torch.long)
531
- edge_index_batched = torch.empty((2,0), dtype=torch.long)
532
- edge_attr_batched = torch.zeros((0,3), dtype=torch.float)
533
  else:
534
  z_batch = torch.cat(all_z, dim=0)
535
  ch_batch = torch.cat(all_ch, dim=0)
536
  fc_batch = torch.cat(all_fc, dim=0)
537
  batch_batch = torch.cat(batch_mapping, dim=0)
 
538
  if len(all_edge_index) > 0:
539
  edge_index_batched = torch.cat(all_edge_index, dim=1)
540
  edge_attr_batched = torch.cat(all_edge_attr, dim=0)
541
  else:
542
- edge_index_batched = torch.empty((2,0), dtype=torch.long)
543
- edge_attr_batched = torch.zeros((0,3), dtype=torch.float)
544
 
545
- # SchNet batching: concat nodes and create batch indices
546
- all_sz = []
547
- all_pos = []
548
- schnet_batch = []
549
  for i, item in enumerate(batch_list):
550
  s = item["schnet"]
551
  s_z = s["z"]
@@ -555,6 +619,7 @@ def multimodal_collate(batch_list):
555
  all_sz.append(s_z)
556
  all_pos.append(s_pos)
557
  schnet_batch.append(torch.full((s_z.size(0),), i, dtype=torch.long))
 
558
  if len(all_sz) == 0:
559
  s_z_batch = torch.tensor([], dtype=torch.long)
560
  s_pos_batch = torch.tensor([], dtype=torch.float)
@@ -564,375 +629,467 @@ def multimodal_collate(batch_list):
564
  s_pos_batch = torch.cat(all_pos, dim=0)
565
  s_batch_batch = torch.cat(schnet_batch, dim=0)
566
 
567
- # FP batching: each fp is vector [L] (long 0/1). We make attention_mask all ones.
568
- fp_ids = torch.stack([item["fp"]["input_ids"] if isinstance(item["fp"]["input_ids"], torch.Tensor) else torch.tensor(item["fp"]["input_ids"], dtype=torch.long) for item in batch_list], dim=0)
 
 
 
 
 
 
 
569
  fp_attn = torch.ones_like(fp_ids, dtype=torch.bool)
570
 
571
- # PSMILES
572
  p_ids = torch.stack([item["psmiles"]["input_ids"] for item in batch_list], dim=0)
573
  p_attn = torch.stack([item["psmiles"]["attention_mask"] for item in batch_list], dim=0)
574
 
575
  return {
576
- "gine": {"z": z_batch, "chirality": ch_batch, "formal_charge": fc_batch, "edge_index": edge_index_batched, "edge_attr": edge_attr_batched, "batch": batch_batch},
 
 
 
 
 
 
 
577
  "schnet": {"z": s_z_batch, "pos": s_pos_batch, "batch": s_batch_batch},
578
  "fp": {"input_ids": fp_ids, "attention_mask": fp_attn},
579
- "psmiles": {"input_ids": p_ids, "attention_mask": p_attn}
580
  }
581
 
582
- train_loader = DataLoader(train_subset, batch_size=training_args.per_device_train_batch_size, shuffle=True, collate_fn=multimodal_collate, num_workers=0, drop_last=False)
583
- val_loader = DataLoader(val_subset, batch_size=training_args.per_device_eval_batch_size, shuffle=False, collate_fn=multimodal_collate, num_workers=0, drop_last=False)
584
 
585
- # ---------------------------
586
- # Multimodal wrapper & loss
587
- # ---------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588
  class MultimodalContrastiveModel(nn.Module):
589
- def __init__(self,
590
- gine_encoder: Optional[GineEncoder],
591
- schnet_encoder: Optional[NodeSchNetWrapper],
592
- fp_encoder: Optional[FingerprintEncoder],
593
- psmiles_encoder: Optional[PSMILESDebertaEncoder],
594
- emb_dim: int = 600):
 
 
 
 
 
 
 
 
595
  super().__init__()
596
  self.gine = gine_encoder
597
  self.schnet = schnet_encoder
598
  self.fp = fp_encoder
599
  self.psmiles = psmiles_encoder
 
600
  self.proj_gine = nn.Linear(getattr(self.gine, "pool_proj").out_features if self.gine is not None else emb_dim, emb_dim) if self.gine is not None else None
601
  self.proj_schnet = nn.Linear(getattr(self.schnet, "pool_proj").out_features if self.schnet is not None else emb_dim, emb_dim) if self.schnet is not None else None
602
  self.proj_fp = nn.Linear(getattr(self.fp, "pool_proj").out_features if self.fp is not None else emb_dim, emb_dim) if self.fp is not None else None
603
  self.proj_psmiles = nn.Linear(getattr(self.psmiles, "pool_proj").out_features if self.psmiles is not None else emb_dim, emb_dim) if self.psmiles is not None else None
 
604
  self.temperature = TEMPERATURE
605
- self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
606
 
607
  def encode(self, batch_mods: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
608
- device = next(self.parameters()).device
609
  embs = {}
610
- B = None
611
- if 'gine' in batch_mods and self.gine is not None:
612
- g = batch_mods['gine']
613
- emb_g = self.gine(g['z'], g['chirality'], g['formal_charge'], g['edge_index'], g['edge_attr'], g.get('batch', None))
614
- embs['gine'] = F.normalize(self.proj_gine(emb_g), dim=-1)
615
- B = embs['gine'].size(0) if B is None else B
616
- if 'schnet' in batch_mods and self.schnet is not None:
617
- s = batch_mods['schnet']
618
- emb_s = self.schnet(s['z'], s['pos'], s.get('batch', None))
619
- embs['schnet'] = F.normalize(self.proj_schnet(emb_s), dim=-1)
620
- B = embs['schnet'].size(0) if B is None else B
621
- if 'fp' in batch_mods and self.fp is not None:
622
- f = batch_mods['fp']
623
- emb_f = self.fp(f['input_ids'], f.get('attention_mask', None))
624
- embs['fp'] = F.normalize(self.proj_fp(emb_f), dim=-1)
625
- B = embs['fp'].size(0) if B is None else B
626
- if 'psmiles' in batch_mods and self.psmiles is not None:
627
- p = batch_mods['psmiles']
628
- emb_p = self.psmiles(p['input_ids'], p.get('attention_mask', None))
629
- embs['psmiles'] = F.normalize(self.proj_psmiles(emb_p), dim=-1)
630
- B = embs['psmiles'].size(0) if B is None else B
631
  return embs
632
 
633
  def forward(self, batch_mods: Dict[str, torch.Tensor], mask_target: str):
 
 
 
634
  device = next(self.parameters()).device
635
  embs = self.encode(batch_mods)
636
  info = {}
 
637
  if mask_target not in embs:
638
  return torch.tensor(0.0, device=device), {"batch_size": 0}
 
639
  target = embs[mask_target]
640
  other_keys = [k for k in embs.keys() if k != mask_target]
641
  if len(other_keys) == 0:
642
  return torch.tensor(0.0, device=device), {"batch_size": target.size(0)}
 
643
  anchor = torch.stack([embs[k] for k in other_keys], dim=0).mean(dim=0)
644
  logits = torch.matmul(anchor, target.T) / self.temperature
645
  B = logits.size(0)
646
  labels = torch.arange(B, device=logits.device)
647
  info_nce_loss = F.cross_entropy(logits, labels)
648
- info['info_nce_loss'] = float(info_nce_loss.detach().cpu().item())
649
 
 
650
  rec_losses = []
651
  rec_details = {}
652
 
 
653
  try:
654
- if 'gine' in batch_mods and self.gine is not None:
655
- gm = batch_mods['gine']
656
- labels_nodes = gm.get('labels', None)
657
  if labels_nodes is not None:
658
- node_logits = self.gine.node_logits(gm['z'], gm['chirality'], gm['formal_charge'], gm['edge_index'], gm['edge_attr'])
659
  if labels_nodes.dim() == 1 and node_logits.size(0) == labels_nodes.size(0):
660
  loss_gine = self.ce_loss(node_logits, labels_nodes.to(node_logits.device))
661
  rec_losses.append(loss_gine)
662
- rec_details['gine_rec_loss'] = float(loss_gine.detach().cpu().item())
663
  except Exception as e:
664
  print("Warning: GINE reconstruction loss computation failed:", e)
665
 
 
666
  try:
667
- if 'schnet' in batch_mods and self.schnet is not None:
668
- sm = batch_mods['schnet']
669
- labels_nodes = sm.get('labels', None)
670
  if labels_nodes is not None:
671
- node_logits = self.schnet.node_logits(sm['z'], sm['pos'], sm.get('batch', None))
672
  if labels_nodes.dim() == 1 and node_logits.size(0) == labels_nodes.size(0):
673
  loss_schnet = self.ce_loss(node_logits, labels_nodes.to(node_logits.device))
674
  rec_losses.append(loss_schnet)
675
- rec_details['schnet_rec_loss'] = float(loss_schnet.detach().cpu().item())
676
  except Exception as e:
677
  print("Warning: SchNet reconstruction loss computation failed:", e)
678
 
 
679
  try:
680
- if 'fp' in batch_mods and self.fp is not None:
681
- fm = batch_mods['fp']
682
- labels_fp = fm.get('labels', None)
683
  if labels_fp is not None:
684
- token_logits = self.fp.token_logits(fm['input_ids'], fm.get('attention_mask', None))
685
  Bf, Lf, V = token_logits.shape
686
  logits2 = token_logits.view(-1, V)
687
  labels2 = labels_fp.view(-1).to(logits2.device)
688
  loss_fp = self.ce_loss(logits2, labels2)
689
  rec_losses.append(loss_fp)
690
- rec_details['fp_rec_loss'] = float(loss_fp.detach().cpu().item())
691
  except Exception as e:
692
  print("Warning: FP reconstruction loss computation failed:", e)
693
 
 
694
  try:
695
- if 'psmiles' in batch_mods and self.psmiles is not None:
696
- pm = batch_mods['psmiles']
697
- labels_ps = pm.get('labels', None)
698
- if labels_ps is not None and tokenizer is not None:
699
- loss_ps = self.psmiles.token_logits(pm['input_ids'], pm.get('attention_mask', None), labels=labels_ps)
700
  if isinstance(loss_ps, torch.Tensor):
701
  rec_losses.append(loss_ps)
702
- rec_details['psmiles_mlm_loss'] = float(loss_ps.detach().cpu().item())
703
  except Exception as e:
704
  print("Warning: PSMILES MLM loss computation failed:", e)
705
 
706
  if len(rec_losses) > 0:
707
  rec_loss_total = sum(rec_losses) / len(rec_losses)
708
- info['reconstruction_loss'] = float(rec_loss_total.detach().cpu().item())
709
  total_loss = info_nce_loss + REC_LOSS_WEIGHT * rec_loss_total
710
- info['total_loss'] = float(total_loss.detach().cpu().item())
711
  info.update(rec_details)
712
  else:
713
  total_loss = info_nce_loss
714
- info['reconstruction_loss'] = 0.0
715
- info['total_loss'] = float(total_loss.detach().cpu().item())
716
 
717
  return total_loss, info
718
 
719
- # ---------------------------
720
- # Instantiate encoders (load weights if available) and move to device with .to(device)
721
- gine_encoder = GineEncoder(node_emb_dim=NODE_EMB_DIM, edge_emb_dim=EDGE_EMB_DIM, num_layers=NUM_GNN_LAYERS, max_atomic_z=MAX_ATOMIC_Z)
722
- if os.path.exists(os.path.join(BEST_GINE_DIR, "pytorch_model.bin")):
723
- try:
724
- gine_encoder.load_state_dict(torch.load(os.path.join(BEST_GINE_DIR, "pytorch_model.bin"), map_location="cpu"), strict=False)
725
- print("Loaded GINE best weights from", BEST_GINE_DIR)
726
- except Exception as e:
727
- print("Could not load GINE best weights:", e)
728
- gine_encoder.to(device)
729
 
730
- schnet_encoder = NodeSchNetWrapper(hidden_channels=SCHNET_HIDDEN, num_interactions=SCHNET_NUM_INTERACTIONS, num_gaussians=SCHNET_NUM_GAUSSIANS, cutoff=SCHNET_CUTOFF, max_num_neighbors=SCHNET_MAX_NEIGHBORS)
731
- if os.path.exists(os.path.join(BEST_SCHNET_DIR, "pytorch_model.bin")):
732
- try:
733
- schnet_encoder.load_state_dict(torch.load(os.path.join(BEST_SCHNET_DIR, "pytorch_model.bin"), map_location="cpu"), strict=False)
734
- print("Loaded SchNet best weights from", BEST_SCHNET_DIR)
735
- except Exception as e:
736
- print("Could not load SchNet best weights:", e)
737
- schnet_encoder.to(device)
738
-
739
- fp_encoder = FingerprintEncoder(vocab_size=VOCAB_SIZE_FP, hidden_dim=256, seq_len=FP_LENGTH, num_layers=4, nhead=8, dim_feedforward=1024, dropout=0.1)
740
- if os.path.exists(os.path.join(BEST_FP_DIR, "pytorch_model.bin")):
741
- try:
742
- fp_encoder.load_state_dict(torch.load(os.path.join(BEST_FP_DIR, "pytorch_model.bin"), map_location="cpu"), strict=False)
743
- print("Loaded fingerprint encoder best weights from", BEST_FP_DIR)
744
- except Exception as e:
745
- print("Could not load fingerprint best weights:", e)
746
- fp_encoder.to(device)
747
 
748
- psmiles_encoder = None
749
- if os.path.isdir(BEST_PSMILES_DIR):
750
- try:
751
- psmiles_encoder = PSMILESDebertaEncoder(model_dir_or_name=BEST_PSMILES_DIR)
752
- print("Loaded Deberta (PSMILES) from", BEST_PSMILES_DIR)
753
- except Exception as e:
754
- print("Failed to load Deberta from saved directory:", e)
755
- if psmiles_encoder is None:
756
- try:
757
- psmiles_encoder = PSMILESDebertaEncoder(model_dir_or_name=None)
758
- except Exception as e:
759
- print("Failed to instantiate Deberta encoder:", e)
760
- psmiles_encoder.to(device)
761
 
762
- multimodal_model = MultimodalContrastiveModel(gine_encoder, schnet_encoder, fp_encoder, psmiles_encoder, emb_dim=600)
763
- multimodal_model.to(device)
 
 
 
 
 
 
764
 
765
- # ---------------------------
766
- # Helper to sample masked variant for modalities: (kept same, device-safe)
767
- def mask_batch_for_modality(batch: dict, modality: str, p_mask: float = P_MASK):
768
- b = {}
769
- # GINE:
770
- if 'gine' in batch:
771
- z = batch['gine']['z'].clone()
772
- chir = batch['gine']['chirality'].clone()
773
- fc = batch['gine']['formal_charge'].clone()
774
- edge_index = batch['gine']['edge_index']
775
- edge_attr = batch['gine']['edge_attr']
776
- batch_map = batch['gine'].get('batch', None)
777
  n_nodes = z.size(0)
778
  dev = z.device
779
  is_selected = torch.rand(n_nodes, device=dev) < p_mask
780
  if is_selected.numel() > 0 and is_selected.all():
781
  is_selected[torch.randint(0, n_nodes, (1,), device=dev)] = False
 
782
  labels_z = torch.full_like(z, fill_value=-100)
783
  if is_selected.any():
784
  sel_idx = torch.nonzero(is_selected).squeeze(-1)
785
  if sel_idx.dim() == 0:
786
  sel_idx = sel_idx.unsqueeze(0)
 
787
  labels_z[is_selected] = z[is_selected]
788
- rand_atomic = torch.randint(1, MAX_ATOMIC_Z+1, (sel_idx.size(0),), dtype=torch.long, device=dev)
 
789
  probs = torch.rand(sel_idx.size(0), device=dev)
790
  mask_choice = probs < 0.8
791
  rand_choice = (probs >= 0.8) & (probs < 0.9)
 
792
  if mask_choice.any():
793
  z[sel_idx[mask_choice]] = MASK_ATOM_ID
794
  if rand_choice.any():
795
  z[sel_idx[rand_choice]] = rand_atomic[rand_choice]
796
- b['gine'] = {"z": z, "chirality": chir, "formal_charge": fc, "edge_index": edge_index, "edge_attr": edge_attr, "batch": batch_map, "labels": labels_z}
797
 
798
- # SchNet:
799
- if 'schnet' in batch:
800
- z = batch['schnet']['z'].clone()
801
- pos = batch['schnet']['pos'].clone()
802
- batch_map = batch['schnet'].get('batch', None)
 
 
 
 
 
 
 
 
 
 
 
803
  n_nodes = z.size(0)
804
  dev = z.device
805
  is_selected = torch.rand(n_nodes, device=dev) < p_mask
806
  if is_selected.numel() > 0 and is_selected.all():
807
  is_selected[torch.randint(0, n_nodes, (1,), device=dev)] = False
 
808
  labels_z = torch.full((n_nodes,), -100, dtype=torch.long, device=dev)
809
  if is_selected.any():
810
  sel_idx = torch.nonzero(is_selected).squeeze(-1)
811
  if sel_idx.dim() == 0:
812
  sel_idx = sel_idx.unsqueeze(0)
 
813
  labels_z[is_selected] = z[is_selected]
814
  probs_c = torch.rand(sel_idx.size(0), device=dev)
815
  noisy_choice = probs_c < 0.8
816
  randpos_choice = (probs_c >= 0.8) & (probs_c < 0.9)
 
817
  if noisy_choice.any():
818
  idx = sel_idx[noisy_choice]
819
  noise = torch.randn((idx.size(0), 3), device=pos.device) * 0.5
820
  pos[idx] = pos[idx] + noise
 
821
  if randpos_choice.any():
822
  idx = sel_idx[randpos_choice]
823
  mins = pos.min(dim=0).values
824
  maxs = pos.max(dim=0).values
825
  randpos = (torch.rand((idx.size(0), 3), device=pos.device) * (maxs - mins)) + mins
826
  pos[idx] = randpos
827
- b['schnet'] = {"z": z, "pos": pos, "batch": batch_map, "labels": labels_z}
828
 
829
- # FP:
830
- if 'fp' in batch:
831
- input_ids = batch['fp']['input_ids'].clone()
832
- attn = batch['fp'].get('attention_mask', torch.ones_like(input_ids, dtype=torch.bool))
 
 
 
833
  B, L = input_ids.shape
834
  dev = input_ids.device
835
  labels_z = torch.full_like(input_ids, -100)
 
836
  for i in range(B):
837
  sel = torch.rand(L, device=dev) < p_mask
838
  if sel.numel() > 0 and sel.all():
839
  sel[torch.randint(0, L, (1,), device=dev)] = False
 
840
  sel_idx = torch.nonzero(sel).squeeze(-1)
841
  if sel_idx.numel() > 0:
842
  if sel_idx.dim() == 0:
843
  sel_idx = sel_idx.unsqueeze(0)
 
844
  labels_z[i, sel_idx] = input_ids[i, sel_idx]
 
845
  probs = torch.rand(sel_idx.size(0), device=dev)
846
  mask_choice = probs < 0.8
847
  rand_choice = (probs >= 0.8) & (probs < 0.9)
 
848
  if mask_choice.any():
849
  input_ids[i, sel_idx[mask_choice]] = MASK_TOKEN_ID_FP
850
  if rand_choice.any():
851
  rand_bits = torch.randint(0, 2, (rand_choice.sum().item(),), dtype=torch.long, device=dev)
852
  input_ids[i, sel_idx[rand_choice]] = rand_bits
853
- b['fp'] = {"input_ids": input_ids, "attention_mask": attn, "labels": labels_z}
854
 
855
- # PSMILES:
856
- if 'psmiles' in batch:
857
- input_ids = batch['psmiles']['input_ids'].clone()
858
- attn = batch['psmiles']['attention_mask'].clone()
 
 
 
859
  B, L = input_ids.shape
860
  dev = input_ids.device
861
  labels_z = torch.full_like(input_ids, -100)
 
 
862
  if tokenizer is None:
863
- b['psmiles'] = {"input_ids": input_ids, "attention_mask": attn, "labels": labels_z}
864
  else:
865
  mask_token_id = tokenizer.mask_token_id if getattr(tokenizer, "mask_token_id", None) is not None else getattr(tokenizer, "vocab", {}).get("<mask>", 1)
 
866
  for i in range(B):
867
  sel = torch.rand(L, device=dev) < p_mask
868
  if sel.numel() > 0 and sel.all():
869
  sel[torch.randint(0, L, (1,), device=dev)] = False
 
870
  sel_idx = torch.nonzero(sel).squeeze(-1)
871
  if sel_idx.numel() > 0:
872
  if sel_idx.dim() == 0:
873
  sel_idx = sel_idx.unsqueeze(0)
 
874
  labels_z[i, sel_idx] = input_ids[i, sel_idx]
 
875
  probs = torch.rand(sel_idx.size(0), device=dev)
876
  mask_choice = probs < 0.8
877
  rand_choice = (probs >= 0.8) & (probs < 0.9)
 
878
  if mask_choice.any():
879
  input_ids[i, sel_idx[mask_choice]] = mask_token_id
880
  if rand_choice.any():
881
  rand_ids = torch.randint(0, getattr(tokenizer, "vocab_size", 300), (rand_choice.sum().item(),), dtype=torch.long, device=dev)
882
  input_ids[i, sel_idx[rand_choice]] = rand_ids
883
- b['psmiles'] = {"input_ids": input_ids, "attention_mask": attn, "labels": labels_z}
 
884
 
885
  return b
886
 
887
- def mm_batch_to_model_input(masked_batch):
 
 
 
 
 
888
  mm = {}
889
- if 'gine' in masked_batch:
890
- gm = masked_batch['gine']
891
- mm['gine'] = {"z": gm['z'], "chirality": gm['chirality'], "formal_charge": gm['formal_charge'], "edge_index": gm['edge_index'], "edge_attr": gm['edge_attr'], "batch": gm.get('batch', None), "labels": gm.get('labels', None)}
892
- if 'schnet' in masked_batch:
893
- sm = masked_batch['schnet']
894
- mm['schnet'] = {"z": sm['z'], "pos": sm['pos'], "batch": sm.get('batch', None), "labels": sm.get('labels', None)}
895
- if 'fp' in masked_batch:
896
- fm = masked_batch['fp']
897
- mm['fp'] = {"input_ids": fm['input_ids'], "attention_mask": fm.get('attention_mask', None), "labels": fm.get('labels', None)}
898
- if 'psmiles' in masked_batch:
899
- pm = masked_batch['psmiles']
900
- mm['psmiles'] = {"input_ids": pm['input_ids'], "attention_mask": pm.get('attention_mask', None), "labels": pm.get('labels', None)}
 
 
 
 
 
 
 
 
901
  return mm
902
 
903
- # ---------------------------
904
- # Evaluation function (keeps same semantics)
905
- def evaluate_multimodal(model: MultimodalContrastiveModel, val_loader, device, mask_target="fp"):
 
 
 
 
 
 
 
 
 
906
  model.eval()
907
  total_loss = 0.0
908
  total_examples = 0
909
  acc_sum = 0.0
910
- top5_sum = 0.0
911
- mrr_sum = 0.0
912
- mean_pos_logit_sum = 0.0
913
- mean_neg_logit_sum = 0.0
914
  f1_sum = 0.0
915
 
916
  with torch.no_grad():
917
  for batch in val_loader:
918
- masked_batch = mask_batch_for_modality(batch, mask_target, p_mask=P_MASK)
919
- # move to device
 
920
  for k in masked_batch:
921
  for subk in masked_batch[k]:
922
  if isinstance(masked_batch[k][subk], torch.Tensor):
923
  masked_batch[k][subk] = masked_batch[k][subk].to(device)
 
924
  mm_in = mm_batch_to_model_input(masked_batch)
925
  embs = model.encode(mm_in)
 
926
  if mask_target not in embs:
927
  continue
 
928
  target = embs[mask_target]
929
  other_keys = [k for k in embs.keys() if k != mask_target]
930
  if len(other_keys) == 0:
931
  continue
 
932
  anchor = torch.stack([embs[k] for k in other_keys], dim=0).mean(dim=0)
933
  logits = torch.matmul(anchor, target.T) / model.temperature
 
934
  B = logits.size(0)
935
  labels = torch.arange(B, device=logits.device)
 
936
  loss = F.cross_entropy(logits, labels)
937
  total_loss += loss.item() * B
938
  total_examples += B
@@ -941,46 +1098,14 @@ def evaluate_multimodal(model: MultimodalContrastiveModel, val_loader, device, m
941
  acc = (preds == labels).float().mean().item()
942
  acc_sum += acc * B
943
 
944
- if B >= 5:
945
- topk = min(5, B)
946
- topk_indices = torch.topk(logits, k=topk, dim=1).indices
947
- hits_topk = (topk_indices == labels.unsqueeze(1)).any(dim=1).float().mean().item()
948
- top5_sum += hits_topk * B
949
- else:
950
- top5_sum += acc * B
951
-
952
- sorted_desc = torch.argsort(logits, dim=1, descending=True)
953
- positions = (sorted_desc == labels.unsqueeze(1)).nonzero(as_tuple=False)
954
- ranks = torch.zeros(B, device=logits.device).float()
955
- if positions.numel() > 0:
956
- for p in positions:
957
- i, pos = int(p[0].item()), int(p[1].item())
958
- ranks[i] = pos + 1.0
959
- ranks_nonzero = ranks.clone()
960
- ranks_nonzero[ranks_nonzero == 0] = float('inf')
961
- mrr = (1.0 / ranks_nonzero).mean().item()
962
- mrr_sum += mrr * B
963
-
964
- pos_logits = logits[torch.arange(B), labels]
965
- neg_logits = logits.clone()
966
- neg_logits[torch.arange(B), labels] = float('-inf')
967
- neg_mask = neg_logits != float('-inf')
968
- if neg_mask.any():
969
- row_counts = neg_mask.sum(dim=1).clamp(min=1).float()
970
- sum_neg_per_row = neg_logits.masked_fill(~neg_mask, 0.0).sum(dim=1)
971
- mean_neg = (sum_neg_per_row / row_counts).mean().item()
972
- else:
973
- mean_neg = 0.0
974
- mean_pos_logit_sum += pos_logits.mean().item() * B
975
- mean_neg_logit_sum += mean_neg * B
976
-
977
  try:
978
  labels_np = labels.cpu().numpy()
979
  preds_np = preds.cpu().numpy()
980
  if len(np.unique(labels_np)) < 2:
981
  batch_f1 = float(acc)
982
  else:
983
- batch_f1 = f1_score(labels_np, preds_np, average='weighted')
984
  except Exception:
985
  batch_f1 = float(acc)
986
  f1_sum += batch_f1 * B
@@ -988,18 +1113,29 @@ def evaluate_multimodal(model: MultimodalContrastiveModel, val_loader, device, m
988
  if total_examples == 0:
989
  return {"eval_loss": float("nan"), "eval_accuracy": 0.0, "eval_f1_weighted": 0.0}
990
 
991
- avg_loss = total_loss / total_examples
992
- accuracy = acc_sum / total_examples
993
- f1_weighted = f1_sum / total_examples
 
 
 
994
 
995
- return {"eval_loss": avg_loss, "eval_accuracy": accuracy, "eval_f1_weighted": f1_weighted}
 
 
996
 
997
- # ---------------------------
998
- # HF wrapper / Trainer integration (kept same as your part 2, uses lazy loaders)
999
  class HFMultimodalModule(nn.Module):
1000
- def __init__(self, mm_model: MultimodalContrastiveModel):
 
 
 
 
 
 
 
1001
  super().__init__()
1002
  self.mm = mm_model
 
1003
 
1004
  def forward(self, **kwargs):
1005
  if "batch" in kwargs:
@@ -1012,16 +1148,23 @@ class HFMultimodalModule(nn.Module):
1012
  batch = {k: found[k] for k in found}
1013
  mask_target = kwargs.get("mask_target", "fp")
1014
  else:
1015
- raise ValueError("HFMultimodalModule.forward could not find 'batch' nor modality keys in inputs. Inputs keys: {}".format(list(kwargs.keys())))
1016
- masked_batch = mask_batch_for_modality(batch, mask_target, p_mask=P_MASK)
 
 
 
 
 
1017
  device = next(self.parameters()).device
1018
  for k in masked_batch:
1019
  for subk in list(masked_batch[k].keys()):
1020
  val = masked_batch[k][subk]
1021
  if isinstance(val, torch.Tensor):
1022
  masked_batch[k][subk] = val.to(device)
 
1023
  mm_in = mm_batch_to_model_input(masked_batch)
1024
  loss, info = self.mm(mm_in, mask_target)
 
1025
  logits = None
1026
  labels = None
1027
  try:
@@ -1039,6 +1182,7 @@ class HFMultimodalModule(nn.Module):
1039
  print("Warning: failed to compute logits/labels inside HFMultimodalModule.forward:", e)
1040
  logits = None
1041
  labels = None
 
1042
  eval_loss = loss.detach() if isinstance(loss, torch.Tensor) else torch.tensor(float(loss), device=device)
1043
  out = {"loss": loss, "eval_loss": eval_loss}
1044
  if logits is not None:
@@ -1048,11 +1192,15 @@ class HFMultimodalModule(nn.Module):
1048
  out["mm_info"] = info
1049
  return out
1050
 
1051
- hf_model = HFMultimodalModule(multimodal_model)
1052
- hf_model.to(device)
1053
 
1054
  class ContrastiveDataCollator:
1055
- def __init__(self, mask_prob=P_MASK, modalities: Optional[List[str]] = None):
 
 
 
 
 
 
1056
  self.mask_prob = mask_prob
1057
  self.modalities = modalities if modalities is not None else ["gine", "schnet", "fp", "psmiles"]
1058
 
@@ -1061,22 +1209,30 @@ class ContrastiveDataCollator:
1061
  collated = features
1062
  mask_target = random.choice([m for m in self.modalities if m in collated])
1063
  return {"batch": collated, "mask_target": mask_target}
 
1064
  if isinstance(features, (list, tuple)) and len(features) > 0:
1065
  first = features[0]
1066
- if isinstance(first, dict) and 'gine' in first:
1067
  collated = multimodal_collate(list(features))
1068
  mask_target = random.choice([m for m in self.modalities if m in collated])
1069
  return {"batch": collated, "mask_target": mask_target}
1070
- if isinstance(first, dict) and 'batch' in first:
1071
- collated = first['batch']
 
1072
  mask_target = first.get("mask_target", random.choice([m for m in self.modalities if m in collated]))
1073
  return {"batch": collated, "mask_target": mask_target}
 
1074
  print("ContrastiveDataCollator received unexpected 'features' shape/type.")
1075
  raise ValueError("ContrastiveDataCollator could not collate input. Expected list[dict] with 'gine' key or already-collated dict.")
1076
 
1077
- data_collator = ContrastiveDataCollator(mask_prob=P_MASK)
1078
 
1079
  class VerboseTrainingCallback(TrainerCallback):
 
 
 
 
 
 
1080
  def __init__(self, patience: int = 10):
1081
  self.start_time = time.time()
1082
  self.epoch_start_time = time.time()
@@ -1104,42 +1260,45 @@ class VerboseTrainingCallback(TrainerCallback):
1104
 
1105
  def on_train_begin(self, args, state, control, **kwargs):
1106
  self.start_time = time.time()
1107
- print("" + "="*80)
1108
- print("🚀 STARTING MULTIMODAL CONTRASTIVE LEARNING TRAINING")
1109
- print("="*80)
1110
- model = kwargs.get('model')
 
1111
  if model is not None:
1112
  total_params = sum(p.numel() for p in model.parameters())
1113
  trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
1114
  non_trainable_params = total_params - trainable_params
1115
- print(f"📊 MODEL PARAMETERS:")
1116
  print(f" Total Parameters: {total_params:,}")
1117
  print(f" Trainable Parameters: {trainable_params:,}")
1118
  print(f" Non-trainable Parameters: {non_trainable_params:,}")
1119
  print(f" Training Progress: 0/{args.num_train_epochs} epochs")
1120
- print("="*80)
 
1121
 
1122
  def on_epoch_begin(self, args, state, control, **kwargs):
1123
  self.epoch_start_time = time.time()
1124
  current_epoch = state.epoch if state is not None else 0.0
1125
- print(f"📈 Epoch {current_epoch + 1:.1f}/{args.num_train_epochs} Starting...")
1126
 
1127
  def on_epoch_end(self, args, state, control, **kwargs):
1128
  train_loss = None
1129
  for log in reversed(state.log_history):
1130
- if isinstance(log, dict) and 'loss' in log and float(log.get('loss', 0)) != 0.0:
1131
- train_loss = log['loss']
1132
  break
1133
  self._last_train_loss = train_loss
1134
 
1135
  def on_log(self, args, state, control, logs=None, **kwargs):
1136
- if logs is not None and 'loss' in logs:
1137
  current_step = state.global_step
1138
  current_epoch = state.epoch
1139
  try:
1140
  steps_per_epoch = max(1, len(train_loader) // args.gradient_accumulation_steps)
1141
  except Exception:
1142
  steps_per_epoch = 1
 
1143
  if current_step % max(1, steps_per_epoch // 10) == 0:
1144
  progress = current_epoch + (current_step % steps_per_epoch) / steps_per_epoch
1145
  print(f" Step {current_step:4d} | Epoch {progress:.1f} | Train Loss: {logs['loss']:.6f}")
@@ -1147,54 +1306,63 @@ class VerboseTrainingCallback(TrainerCallback):
1147
  def on_evaluate(self, args, state, control, metrics=None, **kwargs):
1148
  current_epoch = state.epoch if state is not None else 0.0
1149
  epoch_time = time.time() - self.epoch_start_time
1150
- hf_metrics = metrics if metrics is not None else kwargs.get('metrics', None)
 
1151
  hf_eval_loss = None
1152
  hf_train_loss = self._last_train_loss
 
1153
  if hf_metrics is not None:
1154
- hf_eval_loss = hf_metrics.get('eval_loss', hf_metrics.get('loss', None))
1155
  if hf_train_loss is None:
1156
- hf_train_loss = hf_metrics.get('train_loss', hf_train_loss)
 
1157
  cl_metrics = {}
1158
  try:
1159
- model = kwargs.get('model', None)
1160
  if model is not None:
1161
  cl_model = model.mm if hasattr(model, "mm") else model
1162
- cl_metrics = evaluate_multimodal(cl_model, val_loader, device, mask_target="fp")
1163
  else:
1164
- cl_metrics = evaluate_multimodal(multimodal_model, val_loader, device, mask_target="fp")
1165
  except Exception as e:
1166
  print("Warning: evaluate_multimodal inside callback failed:", e)
 
1167
  if hf_eval_loss is None:
1168
- hf_eval_loss = cl_metrics.get('eval_loss', None)
1169
- val_acc = cl_metrics.get('eval_accuracy', 'N/A')
1170
- val_f1 = cl_metrics.get('eval_f1_weighted', 'N/A')
1171
- print(f"🔍 EPOCH {current_epoch + 1:.1f} RESULTS:")
 
 
1172
  if hf_train_loss is not None:
1173
  try:
1174
  print(f" Train Loss (HF reported): {hf_train_loss:.6f}")
1175
  except Exception:
1176
  print(f" Train Loss (HF reported): {hf_train_loss}")
1177
  else:
1178
- print(f" Train Loss (HF reported): N/A")
 
1179
  if hf_eval_loss is not None:
1180
  try:
1181
  print(f" Eval Loss (HF reported): {hf_eval_loss:.6f}")
1182
  except Exception:
1183
  print(f" Eval Loss (HF reported): {hf_eval_loss}")
1184
  else:
1185
- print(f" Eval Loss (HF reported): N/A")
 
1186
  if isinstance(val_acc, float):
1187
  print(f" Eval Acc (CL evaluator): {val_acc:.6f}")
1188
  else:
1189
  print(f" Eval Acc (CL evaluator): {val_acc}")
 
1190
  if isinstance(val_f1, float):
1191
  print(f" Eval F1 Weighted (CL evaluator): {val_f1:.6f}")
1192
  else:
1193
  print(f" Eval F1 Weighted (CL evaluator): {val_f1}")
1194
- current_val = hf_eval_loss if hf_eval_loss is not None else float('inf')
1195
- improved = False
 
1196
  if current_val < self.best_val_loss - 1e-6:
1197
- improved = True
1198
  self.best_val_loss = current_val
1199
  self.best_epoch = current_epoch
1200
  self.epochs_no_improve = 0
@@ -1204,9 +1372,11 @@ class VerboseTrainingCallback(TrainerCallback):
1204
  print("Warning: saving best model failed:", e)
1205
  else:
1206
  self.epochs_no_improve += 1
 
1207
  if self.epochs_no_improve >= self.patience:
1208
  print(f"Early stopping: no improvement in val_loss for {self.patience} epochs.")
1209
  control.should_training_stop = True
 
1210
  print(f" Epoch Training Time: {epoch_time:.2f}s")
1211
  print(f" Best Val Loss so far: {self.best_val_loss}")
1212
  print(f" Epochs since improvement: {self.epochs_no_improve}/{self.patience}")
@@ -1214,20 +1384,26 @@ class VerboseTrainingCallback(TrainerCallback):
1214
 
1215
  def on_train_end(self, args, state, control, **kwargs):
1216
  total_time = time.time() - self.start_time
1217
- print("" + "="*80)
1218
- print("🏁 TRAINING COMPLETED")
1219
- print("="*80)
1220
  print(f" Total Training Time: {total_time:.2f}s")
1221
  if state is not None:
1222
  try:
1223
  print(f" Total Epochs Completed: {state.epoch + 1:.1f}")
1224
  except Exception:
1225
  pass
1226
- print("="*80)
 
1227
 
1228
- from transformers import Trainer as HfTrainer
 
 
 
 
 
 
1229
 
1230
- class CLTrainer(HfTrainer):
1231
  def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
1232
  try:
1233
  metrics = super().evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) or {}
@@ -1237,7 +1413,7 @@ class CLTrainer(HfTrainer):
1237
  traceback.print_exc()
1238
  try:
1239
  cl_model = self.model.mm if hasattr(self.model, "mm") else self.model
1240
- cl_metrics = evaluate_multimodal(cl_model, val_loader, device, mask_target="fp")
1241
  metrics = {k: float(v) if isinstance(v, (float, int, np.floating, np.integer)) else v for k, v in cl_metrics.items()}
1242
  metrics["epoch"] = float(self.state.epoch) if getattr(self.state, "epoch", None) is not None else metrics.get("epoch", 0.0)
1243
  except Exception as e2:
@@ -1245,32 +1421,39 @@ class CLTrainer(HfTrainer):
1245
  traceback.print_exc()
1246
  metrics = {"eval_loss": float("nan"), "epoch": float(self.state.epoch) if getattr(self.state, "epoch", None) is not None else 0.0}
1247
  return metrics
 
1248
  try:
1249
  cl_model = self.model.mm if hasattr(self.model, "mm") else self.model
1250
- cl_metrics = evaluate_multimodal(cl_model, val_loader, device, mask_target="fp")
1251
  except Exception as e:
1252
  print("Warning: evaluate_multimodal failed inside CLTrainer.evaluate():", e)
1253
  cl_metrics = {}
 
1254
  for k, v in cl_metrics.items():
1255
  try:
1256
  metrics[k] = float(v)
1257
  except Exception:
1258
  metrics[k] = v
1259
- if 'eval_loss' not in metrics and 'eval_loss' in cl_metrics:
 
1260
  try:
1261
- metrics['eval_loss'] = float(cl_metrics['eval_loss'])
1262
  except Exception:
1263
- metrics['eval_loss'] = cl_metrics['eval_loss']
 
1264
  if "epoch" not in metrics:
1265
  metrics["epoch"] = float(self.state.epoch) if getattr(self.state, "epoch", None) is not None else metrics.get("epoch", 0.0)
 
1266
  return metrics
1267
 
1268
  def _save(self, output_dir: str):
1269
  os.makedirs(output_dir, exist_ok=True)
 
1270
  try:
1271
  self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
1272
  except Exception:
1273
  pass
 
1274
  try:
1275
  model_to_save = self.model.mm if hasattr(self.model, "mm") else self.model
1276
  torch.save(model_to_save.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
@@ -1282,10 +1465,12 @@ class CLTrainer(HfTrainer):
1282
  raise e
1283
  except Exception as e2:
1284
  print("Warning: failed to save model state_dict:", e2)
 
1285
  try:
1286
  torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
1287
  except Exception:
1288
  pass
 
1289
  try:
1290
  torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
1291
  except Exception:
@@ -1295,14 +1480,17 @@ class CLTrainer(HfTrainer):
1295
  best_ckpt = self.state.best_model_checkpoint
1296
  if not best_ckpt:
1297
  return
 
1298
  candidate = os.path.join(best_ckpt, "pytorch_model.bin")
1299
  if not os.path.exists(candidate):
1300
  candidate = os.path.join(best_ckpt, "model.bin")
1301
  if not os.path.exists(candidate):
1302
  candidate = None
 
1303
  if candidate is None:
1304
  print(f"CLTrainer._load_best_model(): no compatible pytorch_model.bin found in {best_ckpt}; skipping load.")
1305
  return
 
1306
  try:
1307
  state_dict = torch.load(candidate, map_location=self.args.device)
1308
  model_to_load = self.model.mm if hasattr(self.model, "mm") else self.model
@@ -1312,97 +1500,224 @@ class CLTrainer(HfTrainer):
1312
  print("CLTrainer._load_best_model: failed to load state_dict using torch.load:", e)
1313
  return
1314
 
1315
- callback = VerboseTrainingCallback(patience=10)
1316
 
1317
- trainer = CLTrainer(
1318
- model=hf_model,
1319
- args=training_args,
1320
- train_dataset=train_subset,
1321
- eval_dataset=val_subset,
1322
- data_collator=data_collator,
1323
- callbacks=[callback],
1324
- )
1325
 
1326
- callback.trainer_ref = trainer
1327
-
1328
- # Force HF Trainer to use our prebuilt PyTorch DataLoaders
1329
- trainer.get_train_dataloader = lambda dataset=None: train_loader
1330
- trainer.get_eval_dataloader = lambda eval_dataset=None: val_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1331
 
1332
- training_args.metric_for_best_model = "eval_loss"
1333
- training_args.greater_is_better = False
1334
 
1335
- optimizer = torch.optim.AdamW(multimodal_model.parameters(), lr=training_args.learning_rate, weight_decay=training_args.weight_decay)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1336
 
1337
- total_params = sum(p.numel() for p in multimodal_model.parameters())
1338
- trainable_params = sum(p.numel() for p in multimodal_model.parameters() if p.requires_grad)
1339
- non_trainable_params = total_params - trainable_params
 
1340
 
1341
- print(f"\n📊 MODEL PARAMETERS:")
1342
- print(f" Total Parameters: {total_params:,}")
1343
- print(f" Trainable Parameters: {trainable_params:,}")
1344
- print(f" Non-trainable Parameters: {non_trainable_params:,}")
1345
 
1346
- def compute_metrics_wrapper(eval_pred):
1347
- return evaluate_multimodal(multimodal_model, val_loader, device, mask_target="fp")
 
 
 
 
 
 
 
1348
 
1349
- # ---------------------------
1350
- # Clear any cached GPU memory before starting (helpful)
1351
- if USE_CUDA:
1352
  try:
1353
- torch.cuda.empty_cache()
1354
- except Exception:
1355
- pass
 
 
 
 
 
1356
 
1357
- # ---------------------------
1358
- # Start training
1359
- training_start_time = time.time()
1360
- trainer.train()
1361
- training_end_time = time.time()
1362
 
1363
- # Save best
1364
- BEST_MULTIMODAL_DIR = os.path.join(OUTPUT_DIR, "best")
1365
- os.makedirs(BEST_MULTIMODAL_DIR, exist_ok=True)
1366
 
1367
- try:
1368
- best_ckpt = trainer.state.best_model_checkpoint
1369
- if best_ckpt:
1370
- multimodal_model.load_state_dict(torch.load(os.path.join(best_ckpt, "pytorch_model.bin"), map_location=device), strict=False)
1371
- print(f"Loaded best checkpoint from {best_ckpt} into multimodal_model for final evaluation.")
1372
- torch.save(multimodal_model.state_dict(), os.path.join(BEST_MULTIMODAL_DIR, "pytorch_model.bin"))
1373
- print(f" Saved best multimodal model to {os.path.join(BEST_MULTIMODAL_DIR, 'pytorch_model.bin')}")
1374
- except Exception as e:
1375
- print("Warning: failed to load/save best model from Trainer:", e)
1376
-
1377
- # Final evaluation
1378
- final_metrics = {}
1379
- try:
1380
- if trainer.state.best_model_checkpoint:
1381
- trainer._load_best_model()
1382
- final_metrics = trainer.evaluate(eval_dataset=val_subset)
1383
- else:
1384
- final_metrics = evaluate_multimodal(multimodal_model, val_loader, device, mask_target="fp")
1385
- except Exception as e:
1386
- print("Warning: final evaluation via trainer.evaluate failed, falling back to direct evaluate_multimodal:", e)
1387
- final_metrics = evaluate_multimodal(multimodal_model, val_loader, device, mask_target="fp")
1388
-
1389
- print("\n" + "="*80)
1390
- print("🏁 FINAL TRAINING RESULTS")
1391
- print("="*80)
1392
- training_time = training_end_time - training_start_time
1393
- print(f"Total Training Time: {training_time:.2f}s")
1394
- best_ckpt = trainer.state.best_model_checkpoint if hasattr(trainer.state, 'best_model_checkpoint') else None
1395
- if best_ckpt:
1396
- print(f"Best Checkpoint: {best_ckpt}")
1397
- else:
1398
- print("Best Checkpoint: (none saved)")
1399
-
1400
- hf_eval_loss = final_metrics.get('eval_loss', float('nan'))
1401
- hf_eval_acc = final_metrics.get('eval_accuracy', 0.0)
1402
- hf_eval_f1 = final_metrics.get('eval_f1_weighted', 0.0)
1403
- print(f"Val Loss (HF reported / trainer.evaluate): {hf_eval_loss:.4f}")
1404
- print(f"Val Acc (CL evaluator): {hf_eval_acc:.4f}")
1405
- print(f"Val F1 Weighted (CL evaluator): {hf_eval_f1:.4f}")
1406
- print(f"Total Trainable Params: {trainable_params:,}")
1407
- print(f"Total Non-trainable Params: {non_trainable_params:,}")
1408
- print("="*80)
 
1
+ """
2
+ PolyFusion - CL.py
3
+ Multimodal contrastive pretraining script (DeBERTaV2 + GINE + SchNet + Transformer).
4
+ """
5
+
6
  import os
7
  import sys
8
  import csv
 
27
  import torch.nn.functional as F
28
  from torch.utils.data import Dataset, DataLoader
29
 
30
+ # Shared model utilities
31
  from GINE import GineEncoder, match_edge_attr_to_index, safe_get
32
  from SchNet import NodeSchNetWrapper
33
  from Transformer import PooledFingerprintEncoder as FingerprintEncoder
 
35
 
36
  # HF Trainer & Transformers
37
  from transformers import TrainingArguments, Trainer
 
38
  from transformers.trainer_callback import TrainerCallback
39
 
40
  from sklearn.model_selection import train_test_split
41
+ from sklearn.metrics import f1_score
42
+
43
+ # =============================================================================
44
+ # Configuration (paths are placeholders; update for your environment)
45
+ # =============================================================================
46
 
 
 
 
47
  P_MASK = 0.15
48
  MAX_ATOMIC_Z = 85
49
  MASK_ATOM_ID = MAX_ATOMIC_Z + 1
 
53
  EDGE_EMB_DIM = 300
54
  NUM_GNN_LAYERS = 5
55
 
56
+ # SchNet params
57
  SCHNET_NUM_GAUSSIANS = 50
58
  SCHNET_NUM_INTERACTIONS = 6
59
  SCHNET_CUTOFF = 10.0
60
  SCHNET_MAX_NEIGHBORS = 64
61
  SCHNET_HIDDEN = 600
62
 
63
+ # Fingerprint Transformer params
64
  FP_LENGTH = 2048
65
+ MASK_TOKEN_ID_FP = 2
66
  VOCAB_SIZE_FP = 3
67
 
68
+ # DeBERTaV2 params
69
  DEBERTA_HIDDEN = 600
70
  PSMILES_MAX_LEN = 128
71
 
72
  # Contrastive params
73
  TEMPERATURE = 0.07
74
+ REC_LOSS_WEIGHT = 1.0 # Reconstruction loss weight
75
+
76
+ # Data / preprocessing
77
+ CSV_PATH = "/path/to/polymer_structures_unified_processed.csv"
78
+ TARGET_ROWS = 2000000
79
+ CHUNKSIZE = 50000
80
+ PREPROC_DIR = "/path/to/preprocessed_samples"
81
 
82
+ # Tokenizer assets
83
+ SPM_MODEL = "/path/to/spm.model"
84
 
85
+ # Outputs / checkpoints
86
  OUTPUT_DIR = "/path/to/multimodal_output"
 
87
  BEST_GINE_DIR = "/path/to/gin_output/best"
88
  BEST_SCHNET_DIR = "/path/to/schnet_output/best"
89
  BEST_FP_DIR = "/path/to/fingerprint_mlm_output/best"
90
  BEST_PSMILES_DIR = "/path/to/polybert_output/best"
91
 
92
+
93
+ # =============================================================================
94
+ # Reproducibility + device
95
+ # =============================================================================
96
+
97
+ def get_device() -> torch.device:
98
+ """Select CUDA if available (respects CUDA_VISIBLE_DEVICES), else CPU."""
99
+ return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
100
+
101
+
102
+ def set_seed(seed: int = 42) -> None:
103
+ """Set Python/Numpy/Torch seeds for deterministic-ish behavior."""
104
+ random.seed(seed)
105
+ np.random.seed(seed)
106
+ torch.manual_seed(seed)
107
+ if torch.cuda.is_available():
108
+ torch.cuda.manual_seed_all(seed)
109
+
110
+
111
+ # =============================================================================
112
+ # Optional helper
113
+ # =============================================================================
114
+
115
+ def bfs_distances_to_visible(
116
+ edge_index: torch.Tensor,
117
+ num_nodes: int,
118
+ masked_idx: np.ndarray,
119
+ visible_idx: np.ndarray,
120
+ k_anchors: int
121
+ ):
122
+ """
123
+ Compute shortest-path distances from masked nodes to nearest visible anchors.
124
+
125
+ This helper was present in your original script. It remains here unchanged in
126
+ behavior (and may be useful for future masking/anchor variants), but it is
127
+ not required for the current CL objective.
128
+ """
 
 
 
 
 
 
 
 
 
 
129
  selected_dists = np.zeros((num_nodes, k_anchors), dtype=np.float32)
130
  selected_mask = np.zeros((num_nodes, k_anchors), dtype=np.bool_)
131
+
132
  if edge_index is None or edge_index.numel() == 0:
133
  return selected_dists, selected_mask
134
+
135
  src = edge_index[0].tolist()
136
  dst = edge_index[1].tolist()
137
+
138
  adj = [[] for _ in range(num_nodes)]
139
  for u, v in zip(src, dst):
140
  if 0 <= u < num_nodes and 0 <= v < num_nodes:
141
  adj[u].append(v)
142
+
143
+ visible_set = (
144
+ set(visible_idx.tolist())
145
+ if isinstance(visible_idx, (np.ndarray, list))
146
+ else set(visible_idx.cpu().tolist())
147
+ )
148
+
149
  for a in np.atleast_1d(masked_idx).tolist():
150
  if a < 0 or a >= num_nodes:
151
  continue
152
+
153
  q = [a]
154
  visited = [-1] * num_nodes
155
  visited[a] = 0
156
  head = 0
157
  found = []
158
+
159
  while head < len(q) and len(found) < k_anchors:
160
+ u = q[head]
161
+ head += 1
162
  for v in adj[u]:
163
  if visited[v] == -1:
164
  visited[v] = visited[u] + 1
 
167
  found.append((visited[v], v))
168
  if len(found) >= k_anchors:
169
  break
170
+
171
  if len(found) > 0:
172
  found.sort(key=lambda x: x[0])
173
  k = min(k_anchors, len(found))
174
  for i in range(k):
175
  selected_dists[a, i] = float(found[i][0])
176
  selected_mask[a, i] = True
177
+
178
  return selected_dists, selected_mask
179
 
 
 
 
 
 
 
180
 
181
+ # =============================================================================
182
+ # Preprocessing (streaming to disk to avoid large memory spikes)
183
+ # =============================================================================
184
+
185
def ensure_dir(path: str) -> None:
    """Create `path` (including any missing parent directories) if absent; no-op if it already exists."""
    Path(path).mkdir(parents=True, exist_ok=True)
188
+
189
+
190
+ def prepare_or_load_data_streaming(
191
+ csv_path: str,
192
+ preproc_dir: str,
193
+ target_rows: int = TARGET_ROWS,
194
+ chunksize: int = CHUNKSIZE
195
+ ) -> List[str]:
196
+ """
197
+ Prepare per-sample serialized files (torch .pt) for lazy loading.
198
+
199
+ - If `preproc_dir` already contains `sample_*.pt`, reuse them.
200
+ - Else stream the CSV in chunks and write `sample_{idx:08d}.pt` files.
201
+ """
202
+ ensure_dir(preproc_dir)
203
+
204
+ existing = sorted([p for p in Path(preproc_dir).glob("sample_*.pt")])
205
  if len(existing) > 0:
206
+ print(f"Found {len(existing)} preprocessed sample files in {preproc_dir}; reusing those (no reparse).")
207
  return [str(p) for p in existing]
208
 
209
  print("No existing per-sample preprocessed folder found. Parsing CSV chunked and writing per-sample files (streaming).")
210
+ rows_written = 0
211
  sample_idx = 0
212
 
213
+ for chunk in pd.read_csv(csv_path, engine="python", chunksize=chunksize):
 
214
  has_graph = "graph" in chunk.columns
215
  has_geometry = "geometry" in chunk.columns
216
  has_fp = "fingerprints" in chunk.columns
217
  has_psmiles = "psmiles" in chunk.columns
218
 
219
  for i_row in range(len(chunk)):
220
+ if rows_written >= target_rows:
221
  break
222
+
223
  row = chunk.iloc[i_row]
224
 
225
+ # Per-row modality payloads (None if missing)
226
  gine_sample = None
227
  schnet_sample = None
228
  fp_sample = None
229
  psmiles_raw = None
230
 
231
+ # -------- Graph / GINE modality --------
232
  if has_graph:
233
  val = row.get("graph", "")
234
  try:
235
+ graph_field = (
236
+ json.loads(val)
237
+ if isinstance(val, str) and val.strip() != ""
238
+ else (val if not isinstance(val, str) else None)
239
+ )
240
  except Exception:
241
  graph_field = None
242
+
243
  if graph_field:
244
  node_features = safe_get(graph_field, "node_features", None)
245
  if node_features:
246
  atomic_nums = []
247
  chirality_vals = []
248
  formal_charges = []
249
+
250
  for nf in node_features:
251
  an = safe_get(nf, "atomic_num", None)
252
  if an is None:
253
  an = safe_get(nf, "atomic_number", 0)
254
  ch = safe_get(nf, "chirality", 0)
255
  fc = safe_get(nf, "formal_charge", 0)
256
+
257
  try:
258
  atomic_nums.append(int(an))
259
  except Exception:
260
  atomic_nums.append(0)
261
+
262
  chirality_vals.append(float(ch))
263
  formal_charges.append(float(fc))
264
+
265
  edge_indices_raw = safe_get(graph_field, "edge_indices", None)
266
  edge_features_raw = safe_get(graph_field, "edge_features", None)
267
+
268
  edge_index = None
269
  edge_attr = None
270
+
271
+ # Handle missing edge_indices via adjacency_matrix
272
  if edge_indices_raw is None:
273
  adj_mat = safe_get(graph_field, "adjacency_matrix", None)
274
  if adj_mat:
275
+ srcs, dsts = [], []
 
276
  for i_r, row_adj in enumerate(adj_mat):
277
  for j, val2 in enumerate(row_adj):
278
  if val2:
279
+ srcs.append(i_r)
280
+ dsts.append(j)
281
  if len(srcs) > 0:
282
  edge_index = [srcs, dsts]
283
  E = len(srcs)
284
  edge_attr = [[0.0, 0.0, 0.0] for _ in range(E)]
285
  else:
286
+ # edge_indices_raw can be:
287
+ # - list of pairs [[u,v], ...]
288
+ # - two lists [[srcs], [dsts]]
289
  srcs, dsts = [], []
290
+
291
  if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0 and isinstance(edge_indices_raw[0], list):
 
292
  first = edge_indices_raw[0]
293
  if len(first) == 2 and isinstance(first[0], int):
294
+ # list of pairs
295
  try:
296
  srcs = [int(p[0]) for p in edge_indices_raw]
297
  dsts = [int(p[1]) for p in edge_indices_raw]
298
  except Exception:
299
  srcs, dsts = [], []
300
  else:
301
+ # two lists
302
  try:
303
  srcs = [int(x) for x in edge_indices_raw[0]]
304
  dsts = [int(x) for x in edge_indices_raw[1]]
305
  except Exception:
306
  srcs, dsts = [], []
307
+
308
+ if len(srcs) == 0 and isinstance(edge_indices_raw, list) and all(
309
+ isinstance(p, (list, tuple)) and len(p) == 2 for p in edge_indices_raw
310
+ ):
311
  srcs = [int(p[0]) for p in edge_indices_raw]
312
  dsts = [int(p[1]) for p in edge_indices_raw]
313
+
314
  if len(srcs) > 0:
315
  edge_index = [srcs, dsts]
316
+
317
  if edge_features_raw and isinstance(edge_features_raw, list):
318
+ bond_types, stereos, is_conjs = [], [], []
 
 
319
  for ef in edge_features_raw:
320
  bt = safe_get(ef, "bond_type", 0)
321
  st = safe_get(ef, "stereo", 0)
322
  ic = safe_get(ef, "is_conjugated", False)
323
+ bond_types.append(float(bt))
324
+ stereos.append(float(st))
325
+ is_conjs.append(float(1.0 if ic else 0.0))
326
  edge_attr = list(zip(bond_types, stereos, is_conjs))
327
  else:
328
  E = len(srcs)
 
337
  "edge_attr": edge_attr,
338
  }
339
 
340
+ # -------- Geometry / SchNet modality --------
341
  if has_geometry and schnet_sample is None:
342
  val = row.get("geometry", "")
343
  try:
344
+ geom = (
345
+ json.loads(val)
346
+ if isinstance(val, str) and val.strip() != ""
347
+ else (val if not isinstance(val, str) else None)
348
+ )
349
  conf = geom.get("best_conformer") if isinstance(geom, dict) else None
350
  if conf:
351
  atomic = conf.get("atomic_numbers", [])
 
355
  except Exception:
356
  schnet_sample = None
357
 
358
+ # -------- Fingerprint modality --------
359
  if has_fp:
360
  fpval = row.get("fingerprints", "")
361
  if fpval is None or (isinstance(fpval, str) and fpval.strip() == ""):
362
  fp_sample = [0] * FP_LENGTH
363
  else:
364
+ fp_json = None
365
  try:
366
  fp_json = json.loads(fpval) if isinstance(fpval, str) else fpval
367
  except Exception:
 
373
  if len(bits) < FP_LENGTH:
374
  bits += [0] * (FP_LENGTH - len(bits))
375
  fp_sample = bits
376
+
377
  if fp_sample is None:
378
+ bits = (
379
+ safe_get(fp_json, "morgan_r3_bits", None)
380
+ if isinstance(fp_json, dict)
381
+ else (fp_json if isinstance(fp_json, list) else None)
382
+ )
383
  if bits is None:
384
  fp_sample = [0] * FP_LENGTH
385
  else:
 
392
  normalized.append(1 if int(b) != 0 else 0)
393
  else:
394
  normalized.append(0)
395
+
396
  if len(normalized) >= FP_LENGTH:
397
  break
398
+
399
  if len(normalized) < FP_LENGTH:
400
  normalized.extend([0] * (FP_LENGTH - len(normalized)))
401
  fp_sample = normalized[:FP_LENGTH]
402
 
403
+ # -------- PSMILES modality --------
404
  if has_psmiles:
405
  s = row.get("psmiles", "")
406
+ psmiles_raw = "" if s is None else str(s)
 
 
 
407
 
408
+ # Require at least 2 modalities to keep sample (same logic as your original)
409
+ modalities_present = sum(
410
+ [1 if x is not None else 0 for x in [gine_sample, schnet_sample, fp_sample, psmiles_raw]]
411
+ )
412
  if modalities_present >= 2:
413
  sample = {
414
  "gine": gine_sample,
415
  "schnet": schnet_sample,
416
  "fp": fp_sample,
417
+ "psmiles_raw": psmiles_raw,
418
  }
419
+
420
+ sample_path = os.path.join(preproc_dir, f"sample_{sample_idx:08d}.pt")
421
  try:
422
  torch.save(sample, sample_path)
423
  except Exception as save_e:
424
  print("Warning: failed to torch.save sample:", save_e)
425
+ # fallback JSON for debugging (kept from your original)
426
  try:
427
  with open(sample_path + ".json", "w") as fjson:
428
  json.dump(sample, fjson)
 
 
429
  except Exception:
430
  pass
431
 
432
  sample_idx += 1
433
+ rows_written += 1
434
 
435
+ if rows_written >= target_rows:
 
436
  break
437
 
438
+ print(f"Wrote {sample_idx} sample files to {preproc_dir}.")
439
+ return [str(p) for p in sorted(Path(preproc_dir).glob("sample_*.pt"))]
440
 
 
441
 
442
+ # =============================================================================
443
+ # Dataset + collate
444
+ # =============================================================================
 
 
445
 
 
 
 
446
  class LazyMultimodalDataset(Dataset):
447
+ """
448
+ Lazily loads per-sample files from disk and converts them into tensors.
449
+
450
+ Each sample file is expected to contain:
451
+ - gine: dict or None
452
+ - schnet: dict or None
453
+ - fp: list[int] or tensor
454
+ - psmiles_raw: str
455
+ """
456
+
457
+ def __init__(self, sample_file_list: List[str], tokenizer, fp_length: int = FP_LENGTH, psmiles_max_len: int = PSMILES_MAX_LEN):
458
  self.files = sample_file_list
459
  self.tokenizer = tokenizer
460
  self.fp_length = fp_length
461
  self.psmiles_max_len = psmiles_max_len
462
 
463
+ def __len__(self) -> int:
464
  return len(self.files)
465
 
466
+ def __getitem__(self, idx: int) -> Dict[str, Dict[str, torch.Tensor]]:
467
  sample_path = self.files[idx]
468
+
469
+ # prefer torch.load if .pt, else try json (kept behavior)
470
  if sample_path.endswith(".pt"):
471
  sample = torch.load(sample_path, map_location="cpu")
472
  else:
 
473
  with open(sample_path, "r") as f:
474
  sample = json.load(f)
475
 
476
+ # ---- GINE tensors ----
477
  gine_raw = sample.get("gine", None)
 
478
  if gine_raw:
479
  node_atomic = torch.tensor(gine_raw.get("node_atomic", []), dtype=torch.long)
480
  node_chirality = torch.tensor(gine_raw.get("node_chirality", []), dtype=torch.float)
481
  node_charge = torch.tensor(gine_raw.get("node_charge", []), dtype=torch.float)
482
+
483
  if gine_raw.get("edge_index", None) is not None:
484
+ edge_index = torch.tensor(gine_raw["edge_index"], dtype=torch.long)
 
485
  else:
486
  edge_index = torch.tensor([[], []], dtype=torch.long)
487
+
488
  ea_raw = gine_raw.get("edge_attr", None)
489
  if ea_raw:
490
  edge_attr = torch.tensor(ea_raw, dtype=torch.float)
491
  else:
492
  edge_attr = torch.zeros((edge_index.size(1), 3), dtype=torch.float)
 
 
 
493
 
494
+ gine_item = {
495
+ "z": node_atomic,
496
+ "chirality": node_chirality,
497
+ "formal_charge": node_charge,
498
+ "edge_index": edge_index,
499
+ "edge_attr": edge_attr,
500
+ }
501
+ else:
502
+ gine_item = {
503
+ "z": torch.tensor([], dtype=torch.long),
504
+ "chirality": torch.tensor([], dtype=torch.float),
505
+ "formal_charge": torch.tensor([], dtype=torch.float),
506
+ "edge_index": torch.tensor([[], []], dtype=torch.long),
507
+ "edge_attr": torch.zeros((0, 3), dtype=torch.float),
508
+ }
509
+
510
+ # ---- SchNet tensors ----
511
  schnet_raw = sample.get("schnet", None)
512
  if schnet_raw:
513
  s_z = torch.tensor(schnet_raw.get("atomic", []), dtype=torch.long)
 
516
  else:
517
  schnet_item = {"z": torch.tensor([], dtype=torch.long), "pos": torch.tensor([], dtype=torch.float)}
518
 
519
+ # ---- Fingerprint tensors ----
520
  fp_raw = sample.get("fp", None)
521
  if fp_raw is None:
522
  fp_vec = torch.zeros((self.fp_length,), dtype=torch.long)
523
  else:
 
524
  if isinstance(fp_raw, (list, tuple)):
525
  arr = list(fp_raw)[:self.fp_length]
526
  if len(arr) < self.fp_length:
 
529
  elif isinstance(fp_raw, torch.Tensor):
530
  fp_vec = fp_raw.clone().to(torch.long)
531
  else:
 
532
  fp_vec = torch.zeros((self.fp_length,), dtype=torch.long)
533
 
534
+ # ---- PSMILES tensors ----
535
+ psm_raw = sample.get("psmiles_raw", "") or ""
 
 
536
  enc = self.tokenizer(psm_raw, truncation=True, padding="max_length", max_length=self.psmiles_max_len)
537
  p_input_ids = torch.tensor(enc["input_ids"], dtype=torch.long)
538
  p_attn = torch.tensor(enc["attention_mask"], dtype=torch.bool)
539
 
540
  return {
541
+ "gine": {
542
+ "z": gine_item["z"],
543
+ "chirality": gine_item["chirality"],
544
+ "formal_charge": gine_item["formal_charge"],
545
+ "edge_index": gine_item["edge_index"],
546
+ "edge_attr": gine_item["edge_attr"],
547
+ "num_nodes": int(gine_item["z"].size(0)) if gine_item["z"].numel() > 0 else 0,
548
+ },
549
  "schnet": {"z": schnet_item["z"], "pos": schnet_item["pos"]},
550
  "fp": {"input_ids": fp_vec},
551
+ "psmiles": {"input_ids": p_input_ids, "attention_mask": p_attn},
552
  }
553
 
 
 
554
 
555
+ def multimodal_collate(batch_list: List[Dict[str, Dict[str, torch.Tensor]]]) -> Dict[str, Dict[str, torch.Tensor]]:
 
 
 
 
 
 
556
  """
557
+ Collate a list of LazyMultimodalDataset samples into a single multimodal batch.
558
+
559
+ Output keys:
560
+ - gine: {z, chirality, formal_charge, edge_index, edge_attr, batch}
561
+ - schnet: {z, pos, batch}
562
+ - fp: {input_ids, attention_mask}
563
+ - psmiles: {input_ids, attention_mask}
564
  """
565
+ # ---- GINE batching ----
566
+ all_z, all_ch, all_fc = [], [], []
567
+ all_edge_index, all_edge_attr = [], []
 
 
 
 
568
  batch_mapping = []
569
  node_offset = 0
570
+
571
  for i, item in enumerate(batch_list):
572
  g = item["gine"]
573
  z = g["z"]
574
  n = z.size(0)
575
+
576
  all_z.append(z)
577
  all_ch.append(g["chirality"])
578
  all_fc.append(g["formal_charge"])
579
  batch_mapping.append(torch.full((n,), i, dtype=torch.long))
580
+
581
  if g["edge_index"] is not None and g["edge_index"].numel() > 0:
582
  ei_offset = g["edge_index"] + node_offset
583
  all_edge_index.append(ei_offset)
584
+
585
+ # REUSED helper from GINE.py
586
  ea = match_edge_attr_to_index(g["edge_index"], g["edge_attr"], target_dim=3)
587
  all_edge_attr.append(ea)
588
+
589
  node_offset += n
590
+
591
  if len(all_z) == 0:
 
592
  z_batch = torch.tensor([], dtype=torch.long)
593
  ch_batch = torch.tensor([], dtype=torch.float)
594
  fc_batch = torch.tensor([], dtype=torch.float)
595
  batch_batch = torch.tensor([], dtype=torch.long)
596
+ edge_index_batched = torch.empty((2, 0), dtype=torch.long)
597
+ edge_attr_batched = torch.zeros((0, 3), dtype=torch.float)
598
  else:
599
  z_batch = torch.cat(all_z, dim=0)
600
  ch_batch = torch.cat(all_ch, dim=0)
601
  fc_batch = torch.cat(all_fc, dim=0)
602
  batch_batch = torch.cat(batch_mapping, dim=0)
603
+
604
  if len(all_edge_index) > 0:
605
  edge_index_batched = torch.cat(all_edge_index, dim=1)
606
  edge_attr_batched = torch.cat(all_edge_attr, dim=0)
607
  else:
608
+ edge_index_batched = torch.empty((2, 0), dtype=torch.long)
609
+ edge_attr_batched = torch.zeros((0, 3), dtype=torch.float)
610
 
611
+ # ---- SchNet batching ----
612
+ all_sz, all_pos, schnet_batch = [], [], []
 
 
613
  for i, item in enumerate(batch_list):
614
  s = item["schnet"]
615
  s_z = s["z"]
 
619
  all_sz.append(s_z)
620
  all_pos.append(s_pos)
621
  schnet_batch.append(torch.full((s_z.size(0),), i, dtype=torch.long))
622
+
623
  if len(all_sz) == 0:
624
  s_z_batch = torch.tensor([], dtype=torch.long)
625
  s_pos_batch = torch.tensor([], dtype=torch.float)
 
629
  s_pos_batch = torch.cat(all_pos, dim=0)
630
  s_batch_batch = torch.cat(schnet_batch, dim=0)
631
 
632
+ # ---- FP batching ----
633
+ fp_ids = torch.stack(
634
+ [
635
+ item["fp"]["input_ids"] if isinstance(item["fp"]["input_ids"], torch.Tensor)
636
+ else torch.tensor(item["fp"]["input_ids"], dtype=torch.long)
637
+ for item in batch_list
638
+ ],
639
+ dim=0
640
+ )
641
  fp_attn = torch.ones_like(fp_ids, dtype=torch.bool)
642
 
643
+ # ---- PSMILES batching ----
644
  p_ids = torch.stack([item["psmiles"]["input_ids"] for item in batch_list], dim=0)
645
  p_attn = torch.stack([item["psmiles"]["attention_mask"] for item in batch_list], dim=0)
646
 
647
  return {
648
+ "gine": {
649
+ "z": z_batch,
650
+ "chirality": ch_batch,
651
+ "formal_charge": fc_batch,
652
+ "edge_index": edge_index_batched,
653
+ "edge_attr": edge_attr_batched,
654
+ "batch": batch_batch,
655
+ },
656
  "schnet": {"z": s_z_batch, "pos": s_pos_batch, "batch": s_batch_batch},
657
  "fp": {"input_ids": fp_ids, "attention_mask": fp_attn},
658
+ "psmiles": {"input_ids": p_ids, "attention_mask": p_attn},
659
  }
660
 
 
 
661
 
662
def build_dataloaders(
    sample_files: List[str],
    tokenizer,
    train_batch_size: int,
    eval_batch_size: int,
    seed: int = 42,
    val_fraction: float = 0.2,
    num_workers: int = 0,
) -> Tuple[DataLoader, DataLoader, torch.utils.data.Subset, torch.utils.data.Subset]:
    """
    Create train/val subsets and corresponding DataLoaders.

    Args:
        sample_files: Paths to preprocessed per-sample files (``sample_*.pt``).
        tokenizer: PSMILES tokenizer, forwarded to LazyMultimodalDataset.
        train_batch_size: Batch size of the shuffled training loader.
        eval_batch_size: Batch size of the (ordered) validation loader.
        seed: Random state for the train/val split, for reproducibility.
        val_fraction: Fraction of samples held out for validation
            (default 0.2, matching the previous hard-coded split).
        num_workers: DataLoader worker processes (default 0 = main process,
            matching the previous hard-coded value).

    Returns:
        (train_loader, val_loader, train_subset, val_subset)

    Raises:
        ValueError: If ``val_fraction`` is not strictly between 0 and 1.
    """
    if not 0.0 < val_fraction < 1.0:
        raise ValueError(f"val_fraction must be in (0, 1), got {val_fraction}")

    dataset = LazyMultimodalDataset(sample_files, tokenizer, fp_length=FP_LENGTH, psmiles_max_len=PSMILES_MAX_LEN)

    train_idx, val_idx = train_test_split(list(range(len(dataset))), test_size=val_fraction, random_state=seed)
    train_subset = torch.utils.data.Subset(dataset, train_idx)
    val_subset = torch.utils.data.Subset(dataset, val_idx)

    train_loader = DataLoader(
        train_subset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=multimodal_collate,
        num_workers=num_workers,
        drop_last=False,
    )
    val_loader = DataLoader(
        val_subset,
        batch_size=eval_batch_size,
        shuffle=False,
        collate_fn=multimodal_collate,
        num_workers=num_workers,
        drop_last=False,
    )
    return train_loader, val_loader, train_subset, val_subset
695
+
696
+
697
+ # =============================================================================
698
+ # Multimodal contrastive model
699
+ # =============================================================================
700
+
701
class MultimodalContrastiveModel(nn.Module):
    """
    Wraps unimodal encoders and computes:
      - InfoNCE between masked modality embedding vs mean anchor of other modalities
      - Optional reconstruction losses for masked tokens/atoms when labels are present

    Any encoder may be None; a missing encoder simply removes that modality
    from both the contrastive and reconstruction objectives.
    """

    def __init__(
        self,
        gine_encoder: Optional[GineEncoder],
        schnet_encoder: Optional[NodeSchNetWrapper],
        fp_encoder: Optional[FingerprintEncoder],
        psmiles_encoder: Optional[PSMILESDebertaEncoder],
        emb_dim: int = 600,
    ):
        super().__init__()
        self.gine = gine_encoder
        self.schnet = schnet_encoder
        self.fp = fp_encoder
        self.psmiles = psmiles_encoder

        # One projection head per available encoder, mapping each encoder's
        # pooled output into a shared emb_dim-dimensional contrastive space.
        # NOTE(review): assumes every encoder exposes a `pool_proj` Linear
        # whose `out_features` equals its pooled width — confirm against the
        # encoder classes in GINE.py / SchNet.py / Transformer.py / DeBERTav2.py.
        self.proj_gine = nn.Linear(getattr(self.gine, "pool_proj").out_features if self.gine is not None else emb_dim, emb_dim) if self.gine is not None else None
        self.proj_schnet = nn.Linear(getattr(self.schnet, "pool_proj").out_features if self.schnet is not None else emb_dim, emb_dim) if self.schnet is not None else None
        self.proj_fp = nn.Linear(getattr(self.fp, "pool_proj").out_features if self.fp is not None else emb_dim, emb_dim) if self.fp is not None else None
        self.proj_psmiles = nn.Linear(getattr(self.psmiles, "pool_proj").out_features if self.psmiles is not None else emb_dim, emb_dim) if self.psmiles is not None else None

        # Module-level constants from the config section at the top of the file.
        self.temperature = TEMPERATURE
        # ignore_index=-100 matches the label convention used by
        # mask_batch_for_modality (unmasked positions carry -100).
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")

    def encode(self, batch_mods: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Compute normalized projected embeddings for available modalities.

        A modality contributes an entry only if BOTH its tensors are present in
        `batch_mods` AND its encoder was configured; each embedding is
        L2-normalized after projection so dot products are cosine similarities.
        """
        embs = {}

        if "gine" in batch_mods and self.gine is not None:
            g = batch_mods["gine"]
            emb_g = self.gine(g["z"], g["chirality"], g["formal_charge"], g["edge_index"], g["edge_attr"], g.get("batch", None))
            embs["gine"] = F.normalize(self.proj_gine(emb_g), dim=-1)

        if "schnet" in batch_mods and self.schnet is not None:
            s = batch_mods["schnet"]
            emb_s = self.schnet(s["z"], s["pos"], s.get("batch", None))
            embs["schnet"] = F.normalize(self.proj_schnet(emb_s), dim=-1)

        if "fp" in batch_mods and self.fp is not None:
            f = batch_mods["fp"]
            emb_f = self.fp(f["input_ids"], f.get("attention_mask", None))
            embs["fp"] = F.normalize(self.proj_fp(emb_f), dim=-1)

        if "psmiles" in batch_mods and self.psmiles is not None:
            p = batch_mods["psmiles"]
            emb_p = self.psmiles(p["input_ids"], p.get("attention_mask", None))
            embs["psmiles"] = F.normalize(self.proj_psmiles(emb_p), dim=-1)

        return embs

    def forward(self, batch_mods: Dict[str, torch.Tensor], mask_target: str):
        """
        Compute total loss = InfoNCE + REC_LOSS_WEIGHT * reconstruction_loss (if any labels exist).

        Returns:
            (total_loss, info) where `info` is a dict of scalar diagnostics
            (info_nce_loss, reconstruction_loss, total_loss, per-modality
            reconstruction terms when available).
        """
        device = next(self.parameters()).device
        embs = self.encode(batch_mods)
        info = {}

        # Degenerate cases: masked modality absent, or no other modality to
        # anchor against. NOTE(review): the returned zero tensor carries no
        # gradient — callers see a no-op step for such batches.
        if mask_target not in embs:
            return torch.tensor(0.0, device=device), {"batch_size": 0}

        target = embs[mask_target]
        other_keys = [k for k in embs.keys() if k != mask_target]
        if len(other_keys) == 0:
            return torch.tensor(0.0, device=device), {"batch_size": target.size(0)}

        # Anchor = mean of the remaining modality embeddings; InfoNCE with
        # in-batch negatives (diagonal entries are the positives).
        anchor = torch.stack([embs[k] for k in other_keys], dim=0).mean(dim=0)
        logits = torch.matmul(anchor, target.T) / self.temperature
        B = logits.size(0)
        labels = torch.arange(B, device=logits.device)
        info_nce_loss = F.cross_entropy(logits, labels)
        info["info_nce_loss"] = float(info_nce_loss.detach().cpu().item())

        # Optional reconstruction terms — each block is best-effort: a failure
        # in one modality's head only prints a warning and skips that term.
        rec_losses = []
        rec_details = {}

        # GINE node reconstruction (atomic ids) if labels present
        try:
            if "gine" in batch_mods and self.gine is not None:
                gm = batch_mods["gine"]
                labels_nodes = gm.get("labels", None)
                if labels_nodes is not None:
                    node_logits = self.gine.node_logits(gm["z"], gm["chirality"], gm["formal_charge"], gm["edge_index"], gm["edge_attr"])
                    # Shape guard: only score when logits/labels are per-node aligned.
                    if labels_nodes.dim() == 1 and node_logits.size(0) == labels_nodes.size(0):
                        loss_gine = self.ce_loss(node_logits, labels_nodes.to(node_logits.device))
                        rec_losses.append(loss_gine)
                        rec_details["gine_rec_loss"] = float(loss_gine.detach().cpu().item())
        except Exception as e:
            print("Warning: GINE reconstruction loss computation failed:", e)

        # SchNet node reconstruction if labels present
        try:
            if "schnet" in batch_mods and self.schnet is not None:
                sm = batch_mods["schnet"]
                labels_nodes = sm.get("labels", None)
                if labels_nodes is not None:
                    node_logits = self.schnet.node_logits(sm["z"], sm["pos"], sm.get("batch", None))
                    if labels_nodes.dim() == 1 and node_logits.size(0) == labels_nodes.size(0):
                        loss_schnet = self.ce_loss(node_logits, labels_nodes.to(node_logits.device))
                        rec_losses.append(loss_schnet)
                        rec_details["schnet_rec_loss"] = float(loss_schnet.detach().cpu().item())
        except Exception as e:
            print("Warning: SchNet reconstruction loss computation failed:", e)

        # FP token reconstruction if labels present
        try:
            if "fp" in batch_mods and self.fp is not None:
                fm = batch_mods["fp"]
                labels_fp = fm.get("labels", None)
                if labels_fp is not None:
                    token_logits = self.fp.token_logits(fm["input_ids"], fm.get("attention_mask", None))
                    # Flatten (batch, length, vocab) -> (batch*length, vocab) for CE.
                    Bf, Lf, V = token_logits.shape
                    logits2 = token_logits.view(-1, V)
                    labels2 = labels_fp.view(-1).to(logits2.device)
                    loss_fp = self.ce_loss(logits2, labels2)
                    rec_losses.append(loss_fp)
                    rec_details["fp_rec_loss"] = float(loss_fp.detach().cpu().item())
        except Exception as e:
            print("Warning: FP reconstruction loss computation failed:", e)

        # PSMILES MLM loss if labels present
        try:
            if "psmiles" in batch_mods and self.psmiles is not None:
                pm = batch_mods["psmiles"]
                labels_ps = pm.get("labels", None)
                if labels_ps is not None:
                    # NOTE(review): unlike the other heads, token_logits here is
                    # expected to RETURN the MLM loss when `labels` is passed —
                    # confirm against PSMILESDebertaEncoder.
                    loss_ps = self.psmiles.token_logits(pm["input_ids"], pm.get("attention_mask", None), labels=labels_ps)
                    if isinstance(loss_ps, torch.Tensor):
                        rec_losses.append(loss_ps)
                        rec_details["psmiles_mlm_loss"] = float(loss_ps.detach().cpu().item())
        except Exception as e:
            print("Warning: PSMILES MLM loss computation failed:", e)

        if len(rec_losses) > 0:
            # Average the available reconstruction terms so the weight of the
            # reconstruction objective does not grow with modality count.
            rec_loss_total = sum(rec_losses) / len(rec_losses)
            info["reconstruction_loss"] = float(rec_loss_total.detach().cpu().item())
            total_loss = info_nce_loss + REC_LOSS_WEIGHT * rec_loss_total
            info["total_loss"] = float(total_loss.detach().cpu().item())
            info.update(rec_details)
        else:
            total_loss = info_nce_loss
            info["reconstruction_loss"] = 0.0
            info["total_loss"] = float(total_loss.detach().cpu().item())

        return total_loss, info
852
 
 
 
 
 
 
 
 
 
 
 
853
 
854
+ # =============================================================================
855
+ # Masking utilities
856
+ # =============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
857
 
858
def mask_batch_for_modality(batch: dict, modality: str, tokenizer, p_mask: float = P_MASK) -> dict:
    """
    Apply BERT-style masking (80% mask token / 10% random replacement / 10% keep)
    and attach `labels` holding the original values at masked positions (-100
    elsewhere, matching CrossEntropyLoss(ignore_index=-100)).

    NOTE(review): despite the name and the `modality` argument, EVERY modality
    present in `batch` is masked — `modality` is never read anywhere in this
    body. If only the target modality should be corrupted (as the function
    name suggests), each per-modality branch below would need to be gated on
    `modality`; confirm the intended training setup before changing behavior.

    Mutated tensors (`z`, `pos`, `input_ids`, ...) are cloned first;
    `edge_index`/`edge_attr` are passed through by reference unchanged.
    """
    b = {}

    # ---------------- GINE ----------------
    if "gine" in batch:
        z = batch["gine"]["z"].clone()
        chir = batch["gine"]["chirality"].clone()
        fc = batch["gine"]["formal_charge"].clone()
        edge_index = batch["gine"]["edge_index"]
        edge_attr = batch["gine"]["edge_attr"]
        batch_map = batch["gine"].get("batch", None)

        n_nodes = z.size(0)
        dev = z.device
        is_selected = torch.rand(n_nodes, device=dev) < p_mask
        # Never mask every node: if all were selected, un-select one at random.
        if is_selected.numel() > 0 and is_selected.all():
            is_selected[torch.randint(0, n_nodes, (1,), device=dev)] = False

        labels_z = torch.full_like(z, fill_value=-100)
        if is_selected.any():
            sel_idx = torch.nonzero(is_selected).squeeze(-1)
            if sel_idx.dim() == 0:
                sel_idx = sel_idx.unsqueeze(0)

            # Remember original atomic numbers as reconstruction targets.
            labels_z[is_selected] = z[is_selected]
            rand_atomic = torch.randint(1, MAX_ATOMIC_Z + 1, (sel_idx.size(0),), dtype=torch.long, device=dev)

            # BERT split: <0.8 -> MASK_ATOM_ID, 0.8-0.9 -> random atom, rest keep.
            probs = torch.rand(sel_idx.size(0), device=dev)
            mask_choice = probs < 0.8
            rand_choice = (probs >= 0.8) & (probs < 0.9)

            if mask_choice.any():
                z[sel_idx[mask_choice]] = MASK_ATOM_ID
            if rand_choice.any():
                z[sel_idx[rand_choice]] = rand_atomic[rand_choice]

        b["gine"] = {
            "z": z,
            "chirality": chir,
            "formal_charge": fc,
            "edge_index": edge_index,
            "edge_attr": edge_attr,
            "batch": batch_map,
            "labels": labels_z,
        }

    # ---------------- SchNet ----------------
    if "schnet" in batch:
        z = batch["schnet"]["z"].clone()
        pos = batch["schnet"]["pos"].clone()
        batch_map = batch["schnet"].get("batch", None)

        n_nodes = z.size(0)
        dev = z.device
        is_selected = torch.rand(n_nodes, device=dev) < p_mask
        # Same "never mask everything" guard as the GINE branch.
        if is_selected.numel() > 0 and is_selected.all():
            is_selected[torch.randint(0, n_nodes, (1,), device=dev)] = False

        labels_z = torch.full((n_nodes,), -100, dtype=torch.long, device=dev)
        if is_selected.any():
            sel_idx = torch.nonzero(is_selected).squeeze(-1)
            if sel_idx.dim() == 0:
                sel_idx = sel_idx.unsqueeze(0)

            # Labels are atomic numbers; the corruption below perturbs POSITIONS
            # (80% Gaussian noise, 10% uniform resample in the bounding box,
            # 10% untouched) while z itself is left intact.
            labels_z[is_selected] = z[is_selected]
            probs_c = torch.rand(sel_idx.size(0), device=dev)
            noisy_choice = probs_c < 0.8
            randpos_choice = (probs_c >= 0.8) & (probs_c < 0.9)

            if noisy_choice.any():
                idx = sel_idx[noisy_choice]
                noise = torch.randn((idx.size(0), 3), device=pos.device) * 0.5
                pos[idx] = pos[idx] + noise

            if randpos_choice.any():
                idx = sel_idx[randpos_choice]
                # Uniform resampling within the molecule's axis-aligned bounds.
                mins = pos.min(dim=0).values
                maxs = pos.max(dim=0).values
                randpos = (torch.rand((idx.size(0), 3), device=pos.device) * (maxs - mins)) + mins
                pos[idx] = randpos

        b["schnet"] = {"z": z, "pos": pos, "batch": batch_map, "labels": labels_z}

    # ---------------- FP ----------------
    if "fp" in batch:
        input_ids = batch["fp"]["input_ids"].clone()
        attn = batch["fp"].get("attention_mask", torch.ones_like(input_ids, dtype=torch.bool))

        B, L = input_ids.shape
        dev = input_ids.device
        labels_z = torch.full_like(input_ids, -100)

        # Per-row masking over the fingerprint bit sequence.
        for i in range(B):
            sel = torch.rand(L, device=dev) < p_mask
            if sel.numel() > 0 and sel.all():
                sel[torch.randint(0, L, (1,), device=dev)] = False

            sel_idx = torch.nonzero(sel).squeeze(-1)
            if sel_idx.numel() > 0:
                if sel_idx.dim() == 0:
                    sel_idx = sel_idx.unsqueeze(0)

                labels_z[i, sel_idx] = input_ids[i, sel_idx]

                # BERT split: <0.8 -> MASK_TOKEN_ID_FP, 0.8-0.9 -> random bit, rest keep.
                probs = torch.rand(sel_idx.size(0), device=dev)
                mask_choice = probs < 0.8
                rand_choice = (probs >= 0.8) & (probs < 0.9)

                if mask_choice.any():
                    input_ids[i, sel_idx[mask_choice]] = MASK_TOKEN_ID_FP
                if rand_choice.any():
                    rand_bits = torch.randint(0, 2, (rand_choice.sum().item(),), dtype=torch.long, device=dev)
                    input_ids[i, sel_idx[rand_choice]] = rand_bits

        b["fp"] = {"input_ids": input_ids, "attention_mask": attn, "labels": labels_z}

    # ---------------- PSMILES ----------------
    if "psmiles" in batch:
        input_ids = batch["psmiles"]["input_ids"].clone()
        attn = batch["psmiles"]["attention_mask"].clone()

        B, L = input_ids.shape
        dev = input_ids.device
        labels_z = torch.full_like(input_ids, -100)

        # If tokenizer is unavailable, keep labels=-100 (no MLM loss)
        if tokenizer is None:
            b["psmiles"] = {"input_ids": input_ids, "attention_mask": attn, "labels": labels_z}
        else:
            # Fall back to the "<mask>" vocab entry (or id 1) when the tokenizer
            # does not define mask_token_id.
            mask_token_id = tokenizer.mask_token_id if getattr(tokenizer, "mask_token_id", None) is not None else getattr(tokenizer, "vocab", {}).get("<mask>", 1)

            for i in range(B):
                sel = torch.rand(L, device=dev) < p_mask
                if sel.numel() > 0 and sel.all():
                    sel[torch.randint(0, L, (1,), device=dev)] = False

                sel_idx = torch.nonzero(sel).squeeze(-1)
                if sel_idx.numel() > 0:
                    if sel_idx.dim() == 0:
                        sel_idx = sel_idx.unsqueeze(0)

                    labels_z[i, sel_idx] = input_ids[i, sel_idx]

                    # BERT split: <0.8 -> mask token, 0.8-0.9 -> random vocab id, rest keep.
                    # NOTE(review): padding/special positions are NOT excluded from
                    # selection here — verify whether that is intended.
                    probs = torch.rand(sel_idx.size(0), device=dev)
                    mask_choice = probs < 0.8
                    rand_choice = (probs >= 0.8) & (probs < 0.9)

                    if mask_choice.any():
                        input_ids[i, sel_idx[mask_choice]] = mask_token_id
                    if rand_choice.any():
                        rand_ids = torch.randint(0, getattr(tokenizer, "vocab_size", 300), (rand_choice.sum().item(),), dtype=torch.long, device=dev)
                        input_ids[i, sel_idx[rand_choice]] = rand_ids

            b["psmiles"] = {"input_ids": input_ids, "attention_mask": attn, "labels": labels_z}

    return b
1018
 
1019
+
1020
def mm_batch_to_model_input(masked_batch: dict) -> dict:
    """
    Normalize a masked batch dict into the exact structure expected by
    MultimodalContrastiveModel (same keys, optional entries default to None).
    """
    mm = {}

    if "gine" in masked_batch:
        src = masked_batch["gine"]
        gine_entry = {k: src[k] for k in ("z", "chirality", "formal_charge", "edge_index", "edge_attr")}
        gine_entry["batch"] = src.get("batch", None)
        gine_entry["labels"] = src.get("labels", None)
        mm["gine"] = gine_entry

    if "schnet" in masked_batch:
        src = masked_batch["schnet"]
        mm["schnet"] = {
            "z": src["z"],
            "pos": src["pos"],
            "batch": src.get("batch", None),
            "labels": src.get("labels", None),
        }

    if "fp" in masked_batch:
        src = masked_batch["fp"]
        mm["fp"] = {
            "input_ids": src["input_ids"],
            "attention_mask": src.get("attention_mask", None),
            "labels": src.get("labels", None),
        }

    if "psmiles" in masked_batch:
        src = masked_batch["psmiles"]
        mm["psmiles"] = {
            "input_ids": src["input_ids"],
            "attention_mask": src.get("attention_mask", None),
            "labels": src.get("labels", None),
        }

    return mm
1047
 
1048
+
1049
+ # =============================================================================
1050
+ # Evaluation
1051
+ # =============================================================================
1052
+
1053
def evaluate_multimodal(model: MultimodalContrastiveModel, val_loader: DataLoader, device: torch.device, tokenizer, mask_target: str = "fp") -> Dict[str, float]:
    """
    Contrastive-only evaluation loop.

    For each validation batch this:
      - masks the ``mask_target`` modality,
      - encodes all modalities and averages the non-target embeddings into an anchor,
      - computes InfoNCE logits = anchor @ target.T / temperature,
      - accumulates cross-entropy loss, top-1 in-batch retrieval accuracy,
        and weighted F1.

    Args:
        model: Multimodal contrastive model exposing ``encode`` and ``temperature``.
        val_loader: DataLoader yielding collated multimodal batches.
        device: Device tensors are moved to before encoding.
        tokenizer: Tokenizer forwarded to the masking helper (psmiles masking).
        mask_target: Modality whose embedding acts as the retrieval target.

    Returns:
        Dict with ``eval_loss``, ``eval_accuracy`` and ``eval_f1_weighted``
        (NaN loss / zero metrics when no batch produced usable embeddings).
    """
    model.eval()
    total_loss = 0.0
    total_examples = 0
    acc_sum = 0.0
    f1_sum = 0.0

    with torch.no_grad():
        for batch in val_loader:
            masked_batch = mask_batch_for_modality(batch, mask_target, tokenizer=tokenizer, p_mask=P_MASK)

            # Move every tensor in the nested batch dict to the eval device.
            for k in masked_batch:
                for subk in masked_batch[k]:
                    if isinstance(masked_batch[k][subk], torch.Tensor):
                        masked_batch[k][subk] = masked_batch[k][subk].to(device)

            mm_in = mm_batch_to_model_input(masked_batch)
            embs = model.encode(mm_in)

            # Skip batches where the target modality produced no embedding.
            if mask_target not in embs:
                continue

            target = embs[mask_target]
            other_keys = [k for k in embs.keys() if k != mask_target]
            if len(other_keys) == 0:
                continue

            # Anchor = mean of all non-target modality embeddings.
            anchor = torch.stack([embs[k] for k in other_keys], dim=0).mean(dim=0)
            logits = torch.matmul(anchor, target.T) / model.temperature

            # In-batch retrieval: the positive for row i is column i.
            B = logits.size(0)
            labels = torch.arange(B, device=logits.device)

            loss = F.cross_entropy(logits, labels)
            total_loss += loss.item() * B
            total_examples += B

            # Fix: `preds` was referenced below without ever being assigned in
            # the visible code path; top-1 retrieval prediction is the argmax row.
            preds = logits.argmax(dim=1)
            acc = (preds == labels).float().mean().item()
            acc_sum += acc * B

            # Weighted F1 over the in-batch retrieval labels; degenerate
            # single-class batches fall back to plain accuracy.
            try:
                labels_np = labels.cpu().numpy()
                preds_np = preds.cpu().numpy()
                if len(np.unique(labels_np)) < 2:
                    batch_f1 = float(acc)
                else:
                    batch_f1 = f1_score(labels_np, preds_np, average="weighted")
            except Exception:
                batch_f1 = float(acc)
            f1_sum += batch_f1 * B

    if total_examples == 0:
        return {"eval_loss": float("nan"), "eval_accuracy": 0.0, "eval_f1_weighted": 0.0}

    return {
        "eval_loss": total_loss / total_examples,
        "eval_accuracy": acc_sum / total_examples,
        "eval_f1_weighted": f1_sum / total_examples,
    }
1121
+
1122
 
1123
+ # =============================================================================
1124
+ # HF wrapper + collator + trainer
1125
+ # =============================================================================
1126
 
 
 
1127
class HFMultimodalModule(nn.Module):
    """
    HuggingFace Trainer-facing wrapper:
    - Receives a full multimodal batch
    - Randomly masks one modality (provided by collator) inside forward
    - Returns a dict compatible with Trainer (loss, logits, labels)
    """

    def __init__(self, mm_model: MultimodalContrastiveModel, tokenizer):
        super().__init__()
        # Wrapped contrastive model; exposed as `.mm` so the callback/trainer
        # code can reach the underlying model (e.g. `model.mm` in evaluate paths).
        self.mm = mm_model
        # Tokenizer forwarded to mask_batch_for_modality() inside forward()
        # when the psmiles modality is the masking target.
        self._tokenizer = tokenizer
1140
  def forward(self, **kwargs):
1141
  if "batch" in kwargs:
 
1148
  batch = {k: found[k] for k in found}
1149
  mask_target = kwargs.get("mask_target", "fp")
1150
  else:
1151
+ raise ValueError(
1152
+ "HFMultimodalModule.forward could not find 'batch' nor modality keys in inputs. "
1153
+ f"Inputs keys: {list(kwargs.keys())}"
1154
+ )
1155
+
1156
+ masked_batch = mask_batch_for_modality(batch, mask_target, tokenizer=self._tokenizer, p_mask=P_MASK)
1157
+
1158
  device = next(self.parameters()).device
1159
  for k in masked_batch:
1160
  for subk in list(masked_batch[k].keys()):
1161
  val = masked_batch[k][subk]
1162
  if isinstance(val, torch.Tensor):
1163
  masked_batch[k][subk] = val.to(device)
1164
+
1165
  mm_in = mm_batch_to_model_input(masked_batch)
1166
  loss, info = self.mm(mm_in, mask_target)
1167
+
1168
  logits = None
1169
  labels = None
1170
  try:
 
1182
  print("Warning: failed to compute logits/labels inside HFMultimodalModule.forward:", e)
1183
  logits = None
1184
  labels = None
1185
+
1186
  eval_loss = loss.detach() if isinstance(loss, torch.Tensor) else torch.tensor(float(loss), device=device)
1187
  out = {"loss": loss, "eval_loss": eval_loss}
1188
  if logits is not None:
 
1192
  out["mm_info"] = info
1193
  return out
1194
 
 
 
1195
 
1196
class ContrastiveDataCollator:
    """
    Collator used by Trainer:
    - If given raw samples (list of dicts), it calls multimodal_collate
    - Then selects a random modality to mask (mask_target)
    """

    def __init__(self, mask_prob: float = P_MASK, modalities: Optional[List[str]] = None):
        # Masking probability forwarded to downstream masking logic.
        self.mask_prob = mask_prob
        # Default to all four modalities unless the caller restricts the set.
        if modalities is None:
            modalities = ["gine", "schnet", "fp", "psmiles"]
        self.modalities = modalities
 
 
1209
  collated = features
1210
  mask_target = random.choice([m for m in self.modalities if m in collated])
1211
  return {"batch": collated, "mask_target": mask_target}
1212
+
1213
  if isinstance(features, (list, tuple)) and len(features) > 0:
1214
  first = features[0]
1215
+ if isinstance(first, dict) and "gine" in first:
1216
  collated = multimodal_collate(list(features))
1217
  mask_target = random.choice([m for m in self.modalities if m in collated])
1218
  return {"batch": collated, "mask_target": mask_target}
1219
+
1220
+ if isinstance(first, dict) and "batch" in first:
1221
+ collated = first["batch"]
1222
  mask_target = first.get("mask_target", random.choice([m for m in self.modalities if m in collated]))
1223
  return {"batch": collated, "mask_target": mask_target}
1224
+
1225
  print("ContrastiveDataCollator received unexpected 'features' shape/type.")
1226
  raise ValueError("ContrastiveDataCollator could not collate input. Expected list[dict] with 'gine' key or already-collated dict.")
1227
 
 
1228
 
1229
  class VerboseTrainingCallback(TrainerCallback):
1230
+ """
1231
+ Console-first training callback with early stopping on eval_loss.
1232
+
1233
+ Behavior is kept consistent with your original callback; changes are comment/structure only.
1234
+ """
1235
+
1236
  def __init__(self, patience: int = 10):
1237
  self.start_time = time.time()
1238
  self.epoch_start_time = time.time()
 
1260
 
1261
    def on_train_begin(self, args, state, control, **kwargs):
        """Reset the wall-clock timer and print a banner plus parameter counts."""
        self.start_time = time.time()
        print("=" * 80)
        print(" STARTING MULTIMODAL CONTRASTIVE LEARNING TRAINING")
        print("=" * 80)

        # Trainer passes the model via kwargs; counts are only printed when present.
        model = kwargs.get("model")
        if model is not None:
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            non_trainable_params = total_params - trainable_params
            print(" MODEL PARAMETERS:")
            print(f" Total Parameters: {total_params:,}")
            print(f" Trainable Parameters: {trainable_params:,}")
            print(f" Non-trainable Parameters: {non_trainable_params:,}")
            print(f" Training Progress: 0/{args.num_train_epochs} epochs")

        print("=" * 80)
1279
 
1280
    def on_epoch_begin(self, args, state, control, **kwargs):
        """Record the epoch start time and announce the epoch on the console."""
        self.epoch_start_time = time.time()
        current_epoch = state.epoch if state is not None else 0.0
        print(f" Epoch {current_epoch + 1:.1f}/{args.num_train_epochs} Starting...")
1284
 
1285
    def on_epoch_end(self, args, state, control, **kwargs):
        """Cache the most recent non-zero training loss from the HF log history."""
        train_loss = None
        # Walk the log history backwards so the latest reported loss wins;
        # zero-valued entries are treated as "no real loss logged".
        for log in reversed(state.log_history):
            if isinstance(log, dict) and "loss" in log and float(log.get("loss", 0)) != 0.0:
                train_loss = log["loss"]
                break
        self._last_train_loss = train_loss  # read later by on_evaluate
1292
 
1293
    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print training-loss progress roughly ten times per epoch."""
        if logs is not None and "loss" in logs:
            current_step = state.global_step
            current_epoch = state.epoch
            try:
                # NOTE(review): relies on the module-level global `train_loader`
                # assigned in main(); any failure falls back to 1 step/epoch.
                steps_per_epoch = max(1, len(train_loader) // args.gradient_accumulation_steps)
            except Exception:
                steps_per_epoch = 1

            # Throttle to ~10 progress lines per epoch.
            if current_step % max(1, steps_per_epoch // 10) == 0:
                progress = current_epoch + (current_step % steps_per_epoch) / steps_per_epoch
                print(f" Step {current_step:4d} | Epoch {progress:.1f} | Train Loss: {logs['loss']:.6f}")
 
1306
  def on_evaluate(self, args, state, control, metrics=None, **kwargs):
1307
  current_epoch = state.epoch if state is not None else 0.0
1308
  epoch_time = time.time() - self.epoch_start_time
1309
+
1310
+ hf_metrics = metrics if metrics is not None else kwargs.get("metrics", None)
1311
  hf_eval_loss = None
1312
  hf_train_loss = self._last_train_loss
1313
+
1314
  if hf_metrics is not None:
1315
+ hf_eval_loss = hf_metrics.get("eval_loss", hf_metrics.get("loss", None))
1316
  if hf_train_loss is None:
1317
+ hf_train_loss = hf_metrics.get("train_loss", hf_train_loss)
1318
+
1319
  cl_metrics = {}
1320
  try:
1321
+ model = kwargs.get("model", None)
1322
  if model is not None:
1323
  cl_model = model.mm if hasattr(model, "mm") else model
1324
+ cl_metrics = evaluate_multimodal(cl_model, val_loader, device, tokenizer, mask_target="fp")
1325
  else:
1326
+ cl_metrics = evaluate_multimodal(multimodal_model, val_loader, device, tokenizer, mask_target="fp")
1327
  except Exception as e:
1328
  print("Warning: evaluate_multimodal inside callback failed:", e)
1329
+
1330
  if hf_eval_loss is None:
1331
+ hf_eval_loss = cl_metrics.get("eval_loss", None)
1332
+
1333
+ val_acc = cl_metrics.get("eval_accuracy", "N/A")
1334
+ val_f1 = cl_metrics.get("eval_f1_weighted", "N/A")
1335
+
1336
+ print(f" EPOCH {current_epoch + 1:.1f} RESULTS:")
1337
  if hf_train_loss is not None:
1338
  try:
1339
  print(f" Train Loss (HF reported): {hf_train_loss:.6f}")
1340
  except Exception:
1341
  print(f" Train Loss (HF reported): {hf_train_loss}")
1342
  else:
1343
+ print(" Train Loss (HF reported): N/A")
1344
+
1345
  if hf_eval_loss is not None:
1346
  try:
1347
  print(f" Eval Loss (HF reported): {hf_eval_loss:.6f}")
1348
  except Exception:
1349
  print(f" Eval Loss (HF reported): {hf_eval_loss}")
1350
  else:
1351
+ print(" Eval Loss (HF reported): N/A")
1352
+
1353
  if isinstance(val_acc, float):
1354
  print(f" Eval Acc (CL evaluator): {val_acc:.6f}")
1355
  else:
1356
  print(f" Eval Acc (CL evaluator): {val_acc}")
1357
+
1358
  if isinstance(val_f1, float):
1359
  print(f" Eval F1 Weighted (CL evaluator): {val_f1:.6f}")
1360
  else:
1361
  print(f" Eval F1 Weighted (CL evaluator): {val_f1}")
1362
+
1363
+ current_val = hf_eval_loss if hf_eval_loss is not None else float("inf")
1364
+
1365
  if current_val < self.best_val_loss - 1e-6:
 
1366
  self.best_val_loss = current_val
1367
  self.best_epoch = current_epoch
1368
  self.epochs_no_improve = 0
 
1372
  print("Warning: saving best model failed:", e)
1373
  else:
1374
  self.epochs_no_improve += 1
1375
+
1376
  if self.epochs_no_improve >= self.patience:
1377
  print(f"Early stopping: no improvement in val_loss for {self.patience} epochs.")
1378
  control.should_training_stop = True
1379
+
1380
  print(f" Epoch Training Time: {epoch_time:.2f}s")
1381
  print(f" Best Val Loss so far: {self.best_val_loss}")
1382
  print(f" Epochs since improvement: {self.epochs_no_improve}/{self.patience}")
 
1384
 
1385
    def on_train_end(self, args, state, control, **kwargs):
        """Print a closing banner with total wall-clock training time."""
        total_time = time.time() - self.start_time
        print("=" * 80)
        print(" TRAINING COMPLETED")
        print("=" * 80)
        print(f" Total Training Time: {total_time:.2f}s")
        if state is not None:
            try:
                # state.epoch can be None/float depending on where training stopped.
                print(f" Total Epochs Completed: {state.epoch + 1:.1f}")
            except Exception:
                pass
        print("=" * 80)
1397
+
1398
 
1399
+ class CLTrainer(Trainer):
1400
+ """
1401
+ Custom Trainer:
1402
+ - evaluate(): merges HF eval with contrastive evaluator (same behavior)
1403
+ - _save(): saves a state_dict under pytorch_model.bin
1404
+ - _load_best_model(): loads best pytorch_model.bin
1405
+ """
1406
 
 
1407
  def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
1408
  try:
1409
  metrics = super().evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) or {}
 
1413
  traceback.print_exc()
1414
  try:
1415
  cl_model = self.model.mm if hasattr(self.model, "mm") else self.model
1416
+ cl_metrics = evaluate_multimodal(cl_model, val_loader, device, tokenizer, mask_target="fp")
1417
  metrics = {k: float(v) if isinstance(v, (float, int, np.floating, np.integer)) else v for k, v in cl_metrics.items()}
1418
  metrics["epoch"] = float(self.state.epoch) if getattr(self.state, "epoch", None) is not None else metrics.get("epoch", 0.0)
1419
  except Exception as e2:
 
1421
  traceback.print_exc()
1422
  metrics = {"eval_loss": float("nan"), "epoch": float(self.state.epoch) if getattr(self.state, "epoch", None) is not None else 0.0}
1423
  return metrics
1424
+
1425
  try:
1426
  cl_model = self.model.mm if hasattr(self.model, "mm") else self.model
1427
+ cl_metrics = evaluate_multimodal(cl_model, val_loader, device, tokenizer, mask_target="fp")
1428
  except Exception as e:
1429
  print("Warning: evaluate_multimodal failed inside CLTrainer.evaluate():", e)
1430
  cl_metrics = {}
1431
+
1432
  for k, v in cl_metrics.items():
1433
  try:
1434
  metrics[k] = float(v)
1435
  except Exception:
1436
  metrics[k] = v
1437
+
1438
+ if "eval_loss" not in metrics and "eval_loss" in cl_metrics:
1439
  try:
1440
+ metrics["eval_loss"] = float(cl_metrics["eval_loss"])
1441
  except Exception:
1442
+ metrics["eval_loss"] = cl_metrics["eval_loss"]
1443
+
1444
  if "epoch" not in metrics:
1445
  metrics["epoch"] = float(self.state.epoch) if getattr(self.state, "epoch", None) is not None else metrics.get("epoch", 0.0)
1446
+
1447
  return metrics
1448
 
1449
  def _save(self, output_dir: str):
1450
  os.makedirs(output_dir, exist_ok=True)
1451
+
1452
  try:
1453
  self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
1454
  except Exception:
1455
  pass
1456
+
1457
  try:
1458
  model_to_save = self.model.mm if hasattr(self.model, "mm") else self.model
1459
  torch.save(model_to_save.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
 
1465
  raise e
1466
  except Exception as e2:
1467
  print("Warning: failed to save model state_dict:", e2)
1468
+
1469
  try:
1470
  torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
1471
  except Exception:
1472
  pass
1473
+
1474
  try:
1475
  torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
1476
  except Exception:
 
1480
  best_ckpt = self.state.best_model_checkpoint
1481
  if not best_ckpt:
1482
  return
1483
+
1484
  candidate = os.path.join(best_ckpt, "pytorch_model.bin")
1485
  if not os.path.exists(candidate):
1486
  candidate = os.path.join(best_ckpt, "model.bin")
1487
  if not os.path.exists(candidate):
1488
  candidate = None
1489
+
1490
  if candidate is None:
1491
  print(f"CLTrainer._load_best_model(): no compatible pytorch_model.bin found in {best_ckpt}; skipping load.")
1492
  return
1493
+
1494
  try:
1495
  state_dict = torch.load(candidate, map_location=self.args.device)
1496
  model_to_load = self.model.mm if hasattr(self.model, "mm") else self.model
 
1500
  print("CLTrainer._load_best_model: failed to load state_dict using torch.load:", e)
1501
  return
1502
 
 
1503
 
1504
+ # =============================================================================
1505
+ # Model construction + weight loading
1506
+ # =============================================================================
 
 
 
 
 
1507
 
1508
def load_state_dict_if_present(model: nn.Module, ckpt_dir: str, filename: str = "pytorch_model.bin") -> None:
    """Best-effort load of saved weights into ``model``.

    Looks for ``ckpt_dir/filename``; when present, loads it non-strictly onto
    CPU and reports the outcome on stdout. A missing file or a load failure is
    not fatal — the model is simply left untouched.
    """
    path = os.path.join(ckpt_dir, filename)
    if not os.path.exists(path):
        return
    try:
        state = torch.load(path, map_location="cpu")
        model.load_state_dict(state, strict=False)
        print(f"Loaded weights from {path}")
    except Exception as exc:
        print(f"Could not load weights from {path}: {exc}")
1517
+
1518
+
1519
def build_models(device: torch.device) -> Tuple[MultimodalContrastiveModel, PSMILESDebertaEncoder]:
    """Instantiate unimodal encoders, optionally load best checkpoints, and assemble the multimodal model.

    Args:
        device: Target device for every encoder and the assembled model.

    Returns:
        Tuple of (multimodal contrastive model, PSMILES/DeBERTa encoder).
    """
    # GINE graph encoder (2D-topology modality); warm-started best-effort.
    gine_encoder = GineEncoder(node_emb_dim=NODE_EMB_DIM, edge_emb_dim=EDGE_EMB_DIM, num_layers=NUM_GNN_LAYERS, max_atomic_z=MAX_ATOMIC_Z)
    load_state_dict_if_present(gine_encoder, BEST_GINE_DIR)
    gine_encoder.to(device)

    # SchNet (3D-geometry modality).
    schnet_encoder = NodeSchNetWrapper(
        hidden_channels=SCHNET_HIDDEN,
        num_interactions=SCHNET_NUM_INTERACTIONS,
        num_gaussians=SCHNET_NUM_GAUSSIANS,
        cutoff=SCHNET_CUTOFF,
        max_num_neighbors=SCHNET_MAX_NEIGHBORS,
    )
    load_state_dict_if_present(schnet_encoder, BEST_SCHNET_DIR)
    schnet_encoder.to(device)

    # Fingerprint encoder (bit-vector modality).
    fp_encoder = FingerprintEncoder(
        vocab_size=VOCAB_SIZE_FP,
        hidden_dim=256,
        seq_len=FP_LENGTH,
        num_layers=4,
        nhead=8,
        dim_feedforward=1024,
        dropout=0.1,
    )
    load_state_dict_if_present(fp_encoder, BEST_FP_DIR)
    fp_encoder.to(device)

    # PSMILES / DeBERTa (string modality): prefer a locally saved checkpoint,
    # otherwise fall back to a freshly initialized encoder below.
    psmiles_encoder = None
    if os.path.isdir(BEST_PSMILES_DIR):
        try:
            psmiles_encoder = PSMILESDebertaEncoder(model_dir_or_name=BEST_PSMILES_DIR)
            print("Loaded Deberta (PSMILES) from", BEST_PSMILES_DIR)
        except Exception as e:
            print("Failed to load Deberta from saved directory:", e)

    if psmiles_encoder is None:
        psmiles_encoder = PSMILESDebertaEncoder(model_dir_or_name=None)

    psmiles_encoder.to(device)

    # Assemble all four encoders behind a single contrastive head.
    multimodal_model = MultimodalContrastiveModel(gine_encoder, schnet_encoder, fp_encoder, psmiles_encoder, emb_dim=600)
    multimodal_model.to(device)

    return multimodal_model, psmiles_encoder
1568
+
1569
+
1570
+ # =============================================================================
1571
+ # Main execution
1572
+ # =============================================================================
1573
+
1574
def main():
    """Run the full contrastive pretraining pipeline: data, models, training, final eval."""
    # ---- setup ----
    ensure_dir(OUTPUT_DIR)
    ensure_dir(PREPROC_DIR)

    device_local = get_device()
    print("Device:", device_local)

    set_seed(42)

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        overwrite_output_dir=True,
        num_train_epochs=25,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=4,
        eval_strategy="epoch",
        logging_steps=100,
        learning_rate=1e-4,
        weight_decay=0.01,
        eval_accumulation_steps=1000,
        fp16=torch.cuda.is_available(),
        save_strategy="epoch",
        save_steps=500,
        disable_tqdm=False,
        logging_first_step=True,
        report_to=[],
        dataloader_num_workers=0,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )

    # ---- data ----
    sample_files = prepare_or_load_data_streaming(
        csv_path=CSV_PATH,
        preproc_dir=PREPROC_DIR,
        target_rows=TARGET_ROWS,
        chunksize=CHUNKSIZE,
    )

    tokenizer_local = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)

    # Module-level globals are deliberately assigned here: the callback and
    # CLTrainer.evaluate() reference train_loader/val_loader/device/tokenizer
    # at module scope rather than receiving them as arguments.
    global train_loader, val_loader, multimodal_model, device, tokenizer  # kept for callback references (same behavior)
    tokenizer = tokenizer_local
    device = device_local

    train_loader, val_loader, train_subset, val_subset = build_dataloaders(
        sample_files=sample_files,
        tokenizer=tokenizer_local,
        train_batch_size=training_args.per_device_train_batch_size,
        eval_batch_size=training_args.per_device_eval_batch_size,
        seed=42,
    )

    # ---- models ----
    multimodal_model, _psmiles_encoder = build_models(device_local)

    hf_model = HFMultimodalModule(multimodal_model, tokenizer_local).to(device_local)
    data_collator = ContrastiveDataCollator(mask_prob=P_MASK)

    callback = VerboseTrainingCallback(patience=10)

    trainer = CLTrainer(
        model=hf_model,
        args=training_args,
        train_dataset=train_subset,
        eval_dataset=val_subset,
        data_collator=data_collator,
        callbacks=[callback],
    )
    callback.trainer_ref = trainer

    # Force HF Trainer to use our prebuilt PyTorch DataLoaders
    trainer.get_train_dataloader = lambda dataset=None: train_loader
    trainer.get_eval_dataloader = lambda eval_dataset=None: val_loader

    # Optimizer (kept as in original script)
    # NOTE(review): this optimizer is never handed to the Trainer, which builds
    # its own AdamW internally; `_optimizer` appears unused — confirm intent.
    _optimizer = torch.optim.AdamW(multimodal_model.parameters(), lr=training_args.learning_rate, weight_decay=training_args.weight_decay)

    total_params = sum(p.numel() for p in multimodal_model.parameters())
    trainable_params = sum(p.numel() for p in multimodal_model.parameters() if p.requires_grad)
    non_trainable_params = total_params - trainable_params

    print("\n MODEL PARAMETERS:")
    print(f" Total Parameters: {total_params:,}")
    print(f" Trainable Parameters: {trainable_params:,}")
    print(f" Non-trainable Parameters: {non_trainable_params:,}")

    # Clear GPU cache
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass

    # ---- train ----
    training_start_time = time.time()
    trainer.train()
    training_end_time = time.time()

    # ---- save best ----
    best_dir = os.path.join(OUTPUT_DIR, "best")
    os.makedirs(best_dir, exist_ok=True)

    try:
        best_ckpt = trainer.state.best_model_checkpoint
        if best_ckpt:
            # Reload best weights into the standalone multimodal model, then
            # persist them under OUTPUT_DIR/best for downstream use.
            multimodal_model.load_state_dict(torch.load(os.path.join(best_ckpt, "pytorch_model.bin"), map_location=device_local), strict=False)
            print(f"Loaded best checkpoint from {best_ckpt} into multimodal_model for final evaluation.")
            torch.save(multimodal_model.state_dict(), os.path.join(best_dir, "pytorch_model.bin"))
            print(f" Saved best multimodal model to {os.path.join(best_dir, 'pytorch_model.bin')}")
    except Exception as e:
        print("Warning: failed to load/save best model from Trainer:", e)

    # ---- final evaluation ----
    final_metrics = {}
    try:
        if trainer.state.best_model_checkpoint:
            trainer._load_best_model()
            final_metrics = trainer.evaluate(eval_dataset=val_subset)
        else:
            final_metrics = evaluate_multimodal(multimodal_model, val_loader, device_local, tokenizer_local, mask_target="fp")
    except Exception as e:
        print("Warning: final evaluation via trainer.evaluate failed, falling back to direct evaluate_multimodal:", e)
        final_metrics = evaluate_multimodal(multimodal_model, val_loader, device_local, tokenizer_local, mask_target="fp")

    print("\n" + "=" * 80)
    print(" FINAL TRAINING RESULTS")
    print("=" * 80)
    print(f"Total Training Time: {training_end_time - training_start_time:.2f}s")

    best_ckpt = trainer.state.best_model_checkpoint if hasattr(trainer.state, "best_model_checkpoint") else None
    print(f"Best Checkpoint: {best_ckpt if best_ckpt else '(none saved)'}")

    hf_eval_loss = final_metrics.get("eval_loss", float("nan"))
    hf_eval_acc = final_metrics.get("eval_accuracy", 0.0)
    hf_eval_f1 = final_metrics.get("eval_f1_weighted", 0.0)

    print(f"Val Loss (HF reported / trainer.evaluate): {hf_eval_loss:.4f}")
    print(f"Val Acc (CL evaluator): {hf_eval_acc:.4f}")
    print(f"Val F1 Weighted (CL evaluator): {hf_eval_f1:.4f}")
    print(f"Total Trainable Params: {trainable_params:,}")
    print(f"Total Non-trainable Params: {non_trainable_params:,}")
    print("=" * 80)
1720
+
1721
+
1722
# Script entry point: run the full contrastive pretraining pipeline.
if __name__ == "__main__":
    main()