manpreet88 commited on
Commit
a5954f7
·
1 Parent(s): 231c5c8

Update Property_Prediction.py

Browse files
Files changed (1) hide show
  1. Downstream Tasks/Property_Prediction.py +331 -181
Downstream Tasks/Property_Prediction.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import os
3
  import random
4
  import time
@@ -20,66 +19,77 @@ import csv
20
  import copy
21
  from typing import List, Dict, Optional, Tuple, Any
22
 
 
23
  csv.field_size_limit(sys.maxsize)
24
 
25
- # Shared encoders/helpers from PolyFusion
 
 
26
  from PolyFusion.GINE import GineEncoder, match_edge_attr_to_index, safe_get
27
  from PolyFusion.SchNet import NodeSchNetWrapper
28
  from PolyFusion.Transformer import PooledFingerprintEncoder as FingerprintEncoder
29
  from PolyFusion.DeBERTav2 import PSMILESDebertaEncoder, build_psmiles_tokenizer
30
 
31
- # -----------------------------
32
  # Configuration
33
- # -----------------------------
34
  BASE_DIR = "/path/to/Polymer_Foundational_Model"
35
  POLYINFO_PATH = "/path/to/polyinfo_with_modalities.csv"
36
 
37
- # Pretrained model directories
38
  PRETRAINED_MULTIMODAL_DIR = "/path/to/multimodal_output/best"
39
  BEST_GINE_DIR = "/path/to/gin_output/best"
40
  BEST_SCHNET_DIR = "/path/to/schnet_output/best"
41
  BEST_FP_DIR = "/path/to/fingerprint_mlm_output/best"
42
  BEST_PSMILES_DIR = "/path/to/polybert_output/best"
43
 
 
44
  OUTPUT_RESULTS = "/path/to/multimodal_downstream_results.txt"
45
 
46
- # NEW: directory to save best-performing weights per property (best run/"fold")
47
  BEST_WEIGHTS_DIR = "/path/to/multimodal_downstream_bestweights"
48
 
49
- # Model hyperparameters (matching multimodal training)
 
 
50
  MAX_ATOMIC_Z = 85
51
  MASK_ATOM_ID = MAX_ATOMIC_Z + 1
52
 
53
- # GINE params
54
  NODE_EMB_DIM = 300
55
  EDGE_EMB_DIM = 300
56
  NUM_GNN_LAYERS = 5
57
 
58
- # SchNet params
59
  SCHNET_NUM_GAUSSIANS = 50
60
  SCHNET_NUM_INTERACTIONS = 6
61
  SCHNET_CUTOFF = 10.0
62
  SCHNET_MAX_NEIGHBORS = 64
63
  SCHNET_HIDDEN = 600
64
 
65
- # Fingerprint params
66
  FP_LENGTH = 2048
67
  MASK_TOKEN_ID_FP = 2
68
  VOCAB_SIZE_FP = 3
69
 
 
70
  CL_EMB_DIM = 600
71
 
72
- # PSMILES/Deberta params
73
  DEBERTA_HIDDEN = 600
74
  PSMILES_MAX_LEN = 128
75
 
76
- # Uni-Poly fusion params
77
- UNIPOLY_EMB_DIM = 600
78
- UNIPOLY_ATTN_HEADS = 8
79
- UNIPOLY_DROPOUT = 0.1
80
- UNIPOLY_FF_MULT = 4 # FFN hidden = 4*d (common transformer practice)
81
-
82
- # Training parameters (Uni-Poly fine-tuning)
 
 
 
 
83
  MAX_LEN = 128
84
  BATCH_SIZE = 32
85
  NUM_EPOCHS = 100
@@ -89,6 +99,7 @@ WEIGHT_DECAY = 0.0
89
 
90
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
91
 
 
92
  REQUESTED_PROPERTIES = [
93
  "density",
94
  "glass transition",
@@ -97,28 +108,31 @@ REQUESTED_PROPERTIES = [
97
  "thermal decomposition"
98
  ]
99
 
100
- # Uni-Poly reports stats over fivefold CV; we implement true 5-fold CV here
101
  NUM_RUNS = 5
102
  TEST_SIZE = 0.10
103
- VAL_SIZE_WITHIN_TRAINVAL = 0.10 # of trainval (i.e., 9% overall when TEST_SIZE=0.10)
104
 
105
- # Optional: PolyInfo duplicate aggregation key preference order
106
  AGG_KEYS_PREFERENCE = ["polymer_id", "PolymerID", "poly_id", "psmiles", "smiles", "canonical_smiles"]
107
 
108
- # -----------------------------
109
  # Utilities
110
- # -----------------------------
111
def set_seed(seed: int):
    """Seed every RNG (Python, NumPy, PyTorch CPU/CUDA) for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for deterministic kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
119
 
120
 
121
  def make_json_serializable(obj):
 
122
  if isinstance(obj, dict):
123
  return {make_json_serializable(k): make_json_serializable(v) for k, v in obj.items()}
124
  if isinstance(obj, (list, tuple, set)):
@@ -144,7 +158,15 @@ def make_json_serializable(obj):
144
  pass
145
  return obj
146
 
 
147
  def summarize_state_dict_load(full_state: dict, model_state: dict, filtered_state: dict):
 
 
 
 
 
 
 
148
  n_ckpt = len(full_state)
149
  n_model = len(model_state)
150
  n_loaded = len(filtered_state)
@@ -160,13 +182,13 @@ def summarize_state_dict_load(full_state: dict, model_state: dict, filtered_stat
160
  print(f" ckpt keys: {n_ckpt}")
161
  print(f" model keys: {n_model}")
162
  print(f" loaded keys: {n_loaded}")
163
- print(f" skipped (not in model): {len(missing_in_model)}")
164
- print(f" skipped (shape mismatch): {len(shape_mismatch)}")
165
 
166
  if missing_in_model:
167
- print(" examples missing_in_model:", missing_in_model[:10])
168
  if shape_mismatch:
169
- print(" examples shape_mismatch:")
170
  for k in shape_mismatch[:10]:
171
  print(f" {k}: ckpt={tuple(full_state[k].shape)} model={tuple(model_state[k].shape)}")
172
  print("")
@@ -174,10 +196,10 @@ def summarize_state_dict_load(full_state: dict, model_state: dict, filtered_stat
174
 
175
  def find_property_columns(columns):
176
  """
177
- Safer property column matching:
178
- - Prefer exact token matches (word-level).
179
- - For 'density' explicitly avoid matching columns that contain 'cohesive' (e.g., 'Cohesive energy density').
180
- - Log candidates when ambiguity exists.
181
  """
182
  lowered = {c.lower(): c for c in columns}
183
  found = {}
@@ -186,6 +208,7 @@ def find_property_columns(columns):
186
  req_low = req.lower().strip()
187
  exact = None
188
 
 
189
  for c_low, c_orig in lowered.items():
190
  tokens = set(c_low.replace('_', ' ').split())
191
  if req_low in tokens or c_low == req_low:
@@ -198,6 +221,7 @@ def find_property_columns(columns):
198
  found[req] = exact
199
  continue
200
 
 
201
  candidates = [c_orig for c_low, c_orig in lowered.items() if req_low in c_low]
202
  if req_low == "density":
203
  candidates = [c for c in candidates if "cohesive" not in c.lower() and "cohesive energy" not in c.lower()]
@@ -207,14 +231,16 @@ def find_property_columns(columns):
207
  else:
208
  chosen = candidates[0] if candidates else None
209
  found[req] = chosen
210
- print(f"[INFO] Requested property '{req}' -> chosen column: {chosen}")
211
  if candidates:
212
- print(f"[INFO] Candidates for '{req}': {candidates}")
213
  else:
214
- print(f"[WARN] No candidates found for '{req}' using substring search.")
215
  return found
216
 
 
217
  def choose_aggregation_key(df: pd.DataFrame) -> Optional[str]:
 
218
  for k in AGG_KEYS_PREFERENCE:
219
  if k in df.columns:
220
  return k
@@ -223,38 +249,38 @@ def choose_aggregation_key(df: pd.DataFrame) -> Optional[str]:
223
 
224
  def aggregate_polyinfo_duplicates(df: pd.DataFrame, modality_cols: List[str], property_cols: List[str]) -> pd.DataFrame:
225
  """
226
- Uni-Poly describes grouping samples for the same monomer and averaging properties to reduce noise.
227
- We apply that here if we can identify a stable key (polymer_id / psmiles / smiles).
 
 
228
  """
229
  key = choose_aggregation_key(df)
230
  if key is None:
231
- print("[INFO] No aggregation key found; skipping duplicate aggregation.")
232
  return df
233
 
234
- # Keep only rows with a non-empty key
235
  df2 = df.copy()
236
  df2[key] = df2[key].astype(str)
237
  df2 = df2[df2[key].str.strip() != ""].copy()
238
  if len(df2) == 0:
239
- print("[INFO] Aggregation key exists but is empty; skipping duplicate aggregation.")
240
  return df
241
 
242
  agg_dict = {}
243
- # modalities: take first non-null string (they should be identical per polymer id in your dataset)
244
  for mc in modality_cols:
245
  if mc in df2.columns:
246
  agg_dict[mc] = "first"
247
- # properties: mean
248
  for pc in property_cols:
249
  if pc in df2.columns:
250
  agg_dict[pc] = "mean"
251
 
252
  grouped = df2.groupby(key, as_index=False).agg(agg_dict)
253
- print(f"[INFO] Aggregated duplicates by '{key}': {len(df)} rows -> {len(grouped)} unique keys")
254
  return grouped
255
 
256
 
257
  def _sanitize_name(s: str) -> str:
 
258
  s2 = str(s).strip().lower()
259
  keep = []
260
  for ch in s2:
@@ -270,20 +296,29 @@ def _sanitize_name(s: str) -> str:
270
  out = out.strip("_")
271
  return out or "property"
272
 
 
 
 
 
273
  class MultimodalContrastiveModel(nn.Module):
274
  """
275
- Multimodal encoder wrapper that:
276
- - Takes pretrained modality encoders (GINE / SchNet / FP / PSMILES)
277
- - Projects each modality to emb_dim
278
- - L2-normalizes per modality embedding, then masked-mean combines across available modalities
279
- - L2-normalizes final combined embedding
280
-
281
- This version matches your downstream instantiation:
282
- MultimodalContrastiveModel(gine_encoder, schnet_encoder, fp_encoder, psmiles_encoder,
283
- emb_dim=..., modalities=[...])
284
-
285
- And your downstream forward usage:
286
- z = model(batch_mods, modality_mask=modality_mask)
 
 
 
 
 
287
  """
288
 
289
  def __init__(
@@ -310,7 +345,7 @@ class MultimodalContrastiveModel(nn.Module):
310
  self.out_dim = self.emb_dim
311
  self.dropout = nn.Dropout(float(dropout))
312
 
313
- # Decide which modalities are enabled
314
  if modalities is None:
315
  mods = []
316
  if self.gine is not None:
@@ -325,12 +360,12 @@ class MultimodalContrastiveModel(nn.Module):
325
  else:
326
  self.modalities = [m for m in modalities if m in ("gine", "schnet", "fp", "psmiles")]
327
 
328
- # Projection heads (match dims used elsewhere in your script)
329
  self.proj_gine = nn.Linear(NODE_EMB_DIM, self.emb_dim) if self.gine is not None else None
330
  self.proj_schnet = nn.Linear(SCHNET_HIDDEN, self.emb_dim) if self.schnet is not None else None
331
  self.proj_fp = nn.Linear(256, self.emb_dim) if self.fp is not None else None
332
 
333
- # PSMILES encoder may not always be exactly DEBERTA_HIDDEN; try to infer
334
  psm_in = None
335
  if self.psmiles is not None:
336
  if hasattr(self.psmiles, "out_dim"):
@@ -349,7 +384,7 @@ class MultimodalContrastiveModel(nn.Module):
349
  self.proj_psmiles = nn.Linear(psm_in, self.emb_dim) if (self.psmiles is not None) else None
350
 
351
  def freeze_cl_encoders(self):
352
- """Optional helper to freeze modality encoders."""
353
  for enc in (self.gine, self.schnet, self.fp, self.psmiles):
354
  if enc is None:
355
  continue
@@ -359,9 +394,11 @@ class MultimodalContrastiveModel(nn.Module):
359
 
360
  def _masked_mean_combine(self, zs: List[torch.Tensor], masks: List[torch.Tensor]) -> torch.Tensor:
361
  """
362
- zs: list of (B,D)
363
- masks:list of (B,) boolean, True where modality is present
364
- returns: (B,D) masked mean
 
 
365
  """
366
  if not zs:
367
  device = next(self.parameters()).device
@@ -384,22 +421,21 @@ class MultimodalContrastiveModel(nn.Module):
384
  def forward(self, batch_mods: dict, modality_mask: Optional[dict] = None) -> torch.Tensor:
385
  """
386
  batch_mods keys: 'gine', 'schnet', 'fp', 'psmiles'
387
- modality_mask: dict with boolean tensors of shape (B,) for each modality
388
  """
389
  device = next(self.parameters()).device
390
 
391
  zs = []
392
  ms = []
393
 
394
- # Determine batch size B from any available modality tensor
395
  B = None
396
  if modality_mask is not None:
397
- for k, v in modality_mask.items():
398
  if isinstance(v, torch.Tensor) and v.numel() > 0:
399
  B = int(v.size(0))
400
  break
401
 
402
- # Fallback: infer B from fp/psmiles tensors
403
  if B is None:
404
  if "fp" in batch_mods and batch_mods["fp"] is not None and isinstance(batch_mods["fp"].get("input_ids", None), torch.Tensor):
405
  B = int(batch_mods["fp"]["input_ids"].size(0))
@@ -409,13 +445,14 @@ class MultimodalContrastiveModel(nn.Module):
409
  if B is None:
410
  return torch.zeros((1, self.emb_dim), device=device)
411
 
412
- # Helper to get modality present mask
413
  def _get_mask(name: str) -> torch.Tensor:
414
  if modality_mask is not None and name in modality_mask and isinstance(modality_mask[name], torch.Tensor):
415
  return modality_mask[name].to(device).bool()
416
  return torch.ones((B,), device=device, dtype=torch.bool)
417
 
418
- # GINE
 
 
419
  if "gine" in self.modalities and self.gine is not None and batch_mods.get("gine", None) is not None:
420
  g = batch_mods["gine"]
421
  if isinstance(g.get("z", None), torch.Tensor) and g["z"].numel() > 0:
@@ -433,7 +470,9 @@ class MultimodalContrastiveModel(nn.Module):
433
  zs.append(z)
434
  ms.append(_get_mask("gine"))
435
 
436
- # SchNet
 
 
437
  if "schnet" in self.modalities and self.schnet is not None and batch_mods.get("schnet", None) is not None:
438
  s = batch_mods["schnet"]
439
  if isinstance(s.get("z", None), torch.Tensor) and s["z"].numel() > 0:
@@ -448,7 +487,9 @@ class MultimodalContrastiveModel(nn.Module):
448
  zs.append(z)
449
  ms.append(_get_mask("schnet"))
450
 
451
- # FP
 
 
452
  if "fp" in self.modalities and self.fp is not None and batch_mods.get("fp", None) is not None:
453
  f = batch_mods["fp"]
454
  if isinstance(f.get("input_ids", None), torch.Tensor) and f["input_ids"].numel() > 0:
@@ -462,7 +503,9 @@ class MultimodalContrastiveModel(nn.Module):
462
  zs.append(z)
463
  ms.append(_get_mask("fp"))
464
 
465
- # PSMILES
 
 
466
  if "psmiles" in self.modalities and self.psmiles is not None and batch_mods.get("psmiles", None) is not None:
467
  p = batch_mods["psmiles"]
468
  if isinstance(p.get("input_ids", None), torch.Tensor) and p["input_ids"].numel() > 0:
@@ -476,7 +519,7 @@ class MultimodalContrastiveModel(nn.Module):
476
  zs.append(z)
477
  ms.append(_get_mask("psmiles"))
478
 
479
- # Combine (masked mean), then normalize
480
  if not zs:
481
  return torch.zeros((B, self.emb_dim), device=device)
482
 
@@ -484,7 +527,6 @@ class MultimodalContrastiveModel(nn.Module):
484
  z = F.normalize(z, dim=-1)
485
  return z
486
 
487
-
488
  @torch.no_grad()
489
  def encode_psmiles(
490
  self,
@@ -494,7 +536,7 @@ class MultimodalContrastiveModel(nn.Module):
494
  device: str = DEVICE
495
  ) -> np.ndarray:
496
  """
497
- PSMILES-only embeddings (used for generated candidates + oracle).
498
  """
499
  self.eval()
500
  if self.psm_tok is None or self.psmiles is None or self.proj_psmiles is None:
@@ -519,9 +561,9 @@ class MultimodalContrastiveModel(nn.Module):
519
  device: str = DEVICE
520
  ) -> np.ndarray:
521
  """
522
- Full multimodal embeddings for records that carry:
523
- graph, geometry, fingerprints, psmiles
524
- Missing modalities are skipped per-sample (combine is mean over available).
525
  """
526
  self.eval()
527
  dev = torch.device(device)
@@ -531,14 +573,13 @@ class MultimodalContrastiveModel(nn.Module):
531
  for i in range(0, len(records), batch_size):
532
  chunk = records[i:i + batch_size]
533
 
534
- # Build per-sample modality tensors; do minimal batching (stack where possible)
535
  # PSMILES batch
536
  psmiles_texts = [str(r.get("psmiles", "")) for r in chunk]
537
  p_enc = None
538
  if self.psm_tok is not None:
539
  p_enc = self.psm_tok(psmiles_texts, truncation=True, padding="max_length", max_length=PSMILES_MAX_LEN, return_tensors="pt")
540
 
541
- # FP batch
542
  fp_ids, fp_attn = [], []
543
  for r in chunk:
544
  f = _parse_fingerprints(r.get("fingerprints", None), fp_len=FP_LENGTH)
@@ -547,8 +588,7 @@ class MultimodalContrastiveModel(nn.Module):
547
  fp_ids = torch.stack(fp_ids, dim=0)
548
  fp_attn = torch.stack(fp_attn, dim=0)
549
 
550
- # GINE and SchNet: variable-size; we keep as "packed" with simple batch vectors
551
- # (mean pooling per sample already handled in encoder via batch vectors)
552
  gine_all = {"z": [], "chirality": [], "formal_charge": [], "edge_index": [], "edge_attr": [], "batch": []}
553
  node_offset = 0
554
  for bi, r in enumerate(chunk):
@@ -606,21 +646,32 @@ class MultimodalContrastiveModel(nn.Module):
606
  "psmiles": {"input_ids": p_enc["input_ids"], "attention_mask": p_enc["attention_mask"]} if p_enc is not None else None
607
  }
608
 
609
- z = self.forward_multimodal(batch_mods)
 
610
  outs.append(z.detach().cpu().numpy())
611
 
612
  return np.concatenate(outs, axis=0) if outs else np.zeros((0, self.emb_dim), dtype=np.float32)
613
 
614
- # -----------------------------
 
615
  # Tokenizer setup
616
- # -----------------------------
617
  SPM_MODEL = "/path/to/spm.model"
618
  tokenizer = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)
619
 
620
- # -----------------------------
621
- # Dataset (single-task, Uni-Poly fine-tuning style)
622
- # -----------------------------
623
  class PolymerPropertyDataset(Dataset):
 
 
 
 
 
 
 
 
 
624
  def __init__(self, data_list, tokenizer, max_length=128):
625
  self.data_list = data_list
626
  self.tokenizer = tokenizer
@@ -632,7 +683,9 @@ class PolymerPropertyDataset(Dataset):
632
  def __getitem__(self, idx):
633
  data = self.data_list[idx]
634
 
635
- # Parse graph data for GINE
 
 
636
  gine_data = None
637
  if 'graph' in data and data['graph']:
638
  try:
@@ -661,6 +714,7 @@ class PolymerPropertyDataset(Dataset):
661
  edge_index = None
662
  edge_attr = None
663
 
 
664
  if edge_indices_raw is None:
665
  adj_mat = safe_get(graph_field, "adjacency_matrix", None)
666
  if adj_mat:
@@ -676,6 +730,7 @@ class PolymerPropertyDataset(Dataset):
676
  E = len(srcs)
677
  edge_attr = [[0.0, 0.0, 0.0] for _ in range(E)]
678
  else:
 
679
  srcs, dsts = [], []
680
  if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0:
681
  if isinstance(edge_indices_raw[0], list):
@@ -702,6 +757,7 @@ class PolymerPropertyDataset(Dataset):
702
  if len(srcs) > 0:
703
  edge_index = [srcs, dsts]
704
 
 
705
  if edge_features_raw and isinstance(edge_features_raw, list):
706
  bond_types = []
707
  stereos = []
@@ -729,7 +785,9 @@ class PolymerPropertyDataset(Dataset):
729
  except Exception:
730
  gine_data = None
731
 
732
- # Parse geometry data for SchNet
 
 
733
  schnet_data = None
734
  if 'geometry' in data and data['geometry']:
735
  try:
@@ -746,7 +804,9 @@ class PolymerPropertyDataset(Dataset):
746
  except Exception:
747
  schnet_data = None
748
 
749
- # Parse fingerprints
 
 
750
  fp_data = None
751
  if 'fingerprints' in data and data['fingerprints']:
752
  try:
@@ -797,7 +857,9 @@ class PolymerPropertyDataset(Dataset):
797
  except Exception:
798
  fp_data = None
799
 
800
- # Parse PSMILES
 
 
801
  psmiles_data = None
802
  if 'psmiles' in data and data['psmiles'] and self.tokenizer is not None:
803
  try:
@@ -815,7 +877,9 @@ class PolymerPropertyDataset(Dataset):
815
  except Exception:
816
  psmiles_data = None
817
 
818
- # Defaults for missing modalities
 
 
819
  if gine_data is None:
820
  gine_data = {
821
  'z': torch.tensor([], dtype=torch.long),
@@ -843,7 +907,7 @@ class PolymerPropertyDataset(Dataset):
843
  'attention_mask': torch.zeros(PSMILES_MAX_LEN, dtype=torch.bool)
844
  }
845
 
846
- # Single-task target (scaled)
847
  target_scaled = float(data.get("target_scaled", 0.0))
848
 
849
  return {
@@ -855,10 +919,23 @@ class PolymerPropertyDataset(Dataset):
855
  }
856
 
857
 
 
 
 
858
  def multimodal_collate_fn(batch):
 
 
 
 
 
 
 
 
859
  B = len(batch)
860
 
861
- # GINE batching
 
 
862
  all_z = []
863
  all_ch = []
864
  all_fc = []
@@ -905,7 +982,9 @@ def multimodal_collate_fn(batch):
905
  edge_index_batched = torch.empty((2, 0), dtype=torch.long)
906
  edge_attr_batched = torch.zeros((0, 3), dtype=torch.float)
907
 
908
- # SchNet batching
 
 
909
  all_sz = []
910
  all_pos = []
911
  schnet_batch = []
@@ -930,12 +1009,16 @@ def multimodal_collate_fn(batch):
930
  s_pos_batch = torch.cat(all_pos, dim=0)
931
  s_batch_batch = torch.cat(schnet_batch, dim=0)
932
 
933
- # FP batching
 
 
934
  fp_ids = torch.stack([item["fp"]["input_ids"] for item in batch], dim=0)
935
  fp_attn = torch.stack([item["fp"]["attention_mask"] for item in batch], dim=0)
936
  fp_present = (fp_attn.sum(dim=1) > 0).cpu().numpy().tolist()
937
 
938
- # PSMILES batching
 
 
939
  p_ids = torch.stack([item["psmiles"]["input_ids"] for item in batch], dim=0)
940
  p_attn = torch.stack([item["psmiles"]["attention_mask"] for item in batch], dim=0)
941
  psmiles_present = (p_attn.sum(dim=1) > 0).cpu().numpy().tolist()
@@ -943,6 +1026,7 @@ def multimodal_collate_fn(batch):
943
  # Target
944
  target = torch.stack([item["target"] for item in batch], dim=0) # (B,)
945
 
 
946
  modality_mask = {
947
  "gine": torch.tensor(gine_present, dtype=torch.bool),
948
  "schnet": torch.tensor(schnet_present, dtype=torch.bool),
@@ -977,13 +1061,17 @@ def multimodal_collate_fn(batch):
977
  }
978
 
979
 
980
- # -----------------------------
981
- # Property regression head (single task)
982
- # -----------------------------
983
- class UniPolyPropertyRegressor(nn.Module):
984
- def __init__(self, unipoly_model: MultimodalContrastiveModel, emb_dim: int = UNIPOLY_EMB_DIM, dropout: float = 0.1):
 
 
 
 
985
  super().__init__()
986
- self.unipoly = unipoly_model
987
  self.head = nn.Sequential(
988
  nn.Linear(emb_dim, emb_dim // 2),
989
  nn.ReLU(),
@@ -992,15 +1080,16 @@ class UniPolyPropertyRegressor(nn.Module):
992
  )
993
 
994
  def forward(self, batch_mods, modality_mask=None):
995
- emb = self.unipoly(batch_mods, modality_mask=modality_mask) # (B,d)
996
  y = self.head(emb).squeeze(-1) # (B,)
997
  return y
998
 
999
 
1000
- # -----------------------------
1001
- # Training / evaluation
1002
- # -----------------------------
1003
  def compute_metrics(y_true, y_pred):
 
1004
  mse = mean_squared_error(y_true, y_pred)
1005
  rmse = math.sqrt(mse)
1006
  mae = mean_absolute_error(y_true, y_pred)
@@ -1009,11 +1098,13 @@ def compute_metrics(y_true, y_pred):
1009
 
1010
 
1011
  def train_one_epoch(model, dataloader, optimizer, device):
 
1012
  model.train()
1013
  total_loss = 0.0
1014
  total_n = 0
1015
 
1016
  for batch in dataloader:
 
1017
  for k in batch:
1018
  if k == "target":
1019
  batch[k] = batch[k].to(device)
@@ -1046,6 +1137,10 @@ def train_one_epoch(model, dataloader, optimizer, device):
1046
 
1047
  @torch.no_grad()
1048
  def evaluate(model, dataloader, device):
 
 
 
 
1049
  model.eval()
1050
  preds = []
1051
  trues = []
@@ -1053,6 +1148,7 @@ def evaluate(model, dataloader, device):
1053
  total_n = 0
1054
 
1055
  for batch in dataloader:
 
1056
  for k in batch:
1057
  if k == "target":
1058
  batch[k] = batch[k].to(device)
@@ -1088,26 +1184,37 @@ def evaluate(model, dataloader, device):
1088
  return float(avg_loss), preds, trues
1089
 
1090
 
1091
- # -----------------------------
1092
- # Pretrained loading
1093
- # -----------------------------
1094
  def load_pretrained_multimodal(pretrained_path: str) -> MultimodalContrastiveModel:
 
 
 
 
 
 
 
 
 
 
1095
  gine_encoder = GineEncoder(
1096
  node_emb_dim=NODE_EMB_DIM,
1097
  edge_emb_dim=EDGE_EMB_DIM,
1098
  num_layers=NUM_GNN_LAYERS,
1099
  max_atomic_z=MAX_ATOMIC_Z
1100
  )
1101
- if os.path.exists(os.path.join(BEST_GINE_DIR, "pytorch_model.bin")):
 
1102
  try:
1103
- gine_encoder.load_state_dict(
1104
- torch.load(os.path.join(BEST_GINE_DIR, "pytorch_model.bin"), map_location="cpu"),
1105
- strict=False
1106
- )
1107
- print(f"Loaded GINE weights from {BEST_GINE_DIR}")
1108
  except Exception as e:
1109
- print(f"Warning: Could not load GINE weights: {e}")
1110
 
 
 
 
1111
  schnet_encoder = NodeSchNetWrapper(
1112
  hidden_channels=SCHNET_HIDDEN,
1113
  num_interactions=SCHNET_NUM_INTERACTIONS,
@@ -1115,16 +1222,17 @@ def load_pretrained_multimodal(pretrained_path: str) -> MultimodalContrastiveMod
1115
  cutoff=SCHNET_CUTOFF,
1116
  max_num_neighbors=SCHNET_MAX_NEIGHBORS
1117
  )
1118
- if os.path.exists(os.path.join(BEST_SCHNET_DIR, "pytorch_model.bin")):
 
1119
  try:
1120
- schnet_encoder.load_state_dict(
1121
- torch.load(os.path.join(BEST_SCHNET_DIR, "pytorch_model.bin"), map_location="cpu"),
1122
- strict=False
1123
- )
1124
- print(f"Loaded SchNet weights from {BEST_SCHNET_DIR}")
1125
  except Exception as e:
1126
- print(f"Warning: Could not load SchNet weights: {e}")
1127
 
 
 
 
1128
  fp_encoder = FingerprintEncoder(
1129
  vocab_size=VOCAB_SIZE_FP,
1130
  hidden_dim=256,
@@ -1134,42 +1242,49 @@ def load_pretrained_multimodal(pretrained_path: str) -> MultimodalContrastiveMod
1134
  dim_feedforward=1024,
1135
  dropout=0.1
1136
  )
1137
- if os.path.exists(os.path.join(BEST_FP_DIR, "pytorch_model.bin")):
 
1138
  try:
1139
- fp_encoder.load_state_dict(
1140
- torch.load(os.path.join(BEST_FP_DIR, "pytorch_model.bin"), map_location="cpu"),
1141
- strict=False
1142
- )
1143
- print(f"Loaded fingerprint encoder weights from {BEST_FP_DIR}")
1144
  except Exception as e:
1145
- print(f"Warning: Could not load fingerprint weights: {e}")
1146
 
 
 
 
1147
  psmiles_encoder = None
1148
  if os.path.isdir(BEST_PSMILES_DIR):
1149
  try:
1150
  psmiles_encoder = PSMILESDebertaEncoder(model_dir_or_name=BEST_PSMILES_DIR)
1151
- print(f"Loaded PSMILES encoder from {BEST_PSMILES_DIR}")
1152
  except Exception as e:
1153
- print(f"Warning: Could not load PSMILES encoder: {e}")
1154
 
 
1155
  if psmiles_encoder is None:
1156
  try:
1157
  psmiles_encoder = PSMILESDebertaEncoder(
1158
  model_dir_or_name=None,
1159
  vocab_fallback=int(getattr(tokenizer, "vocab_size", 300))
1160
  )
 
1161
  except Exception as e:
1162
- print(f"Warning: Could not initialize PSMILES encoder: {e}")
1163
 
 
1164
  multimodal_model = MultimodalContrastiveModel(
1165
  gine_encoder,
1166
  schnet_encoder,
1167
  fp_encoder,
1168
  psmiles_encoder,
1169
- emb_dim=UNIPOLY_EMB_DIM,
1170
  modalities=["gine", "schnet", "fp", "psmiles"]
1171
  )
1172
 
 
 
 
1173
  ckpt_path = os.path.join(pretrained_path, "pytorch_model.bin")
1174
  if os.path.isfile(ckpt_path):
1175
  try:
@@ -1185,26 +1300,34 @@ def load_pretrained_multimodal(pretrained_path: str) -> MultimodalContrastiveMod
1185
 
1186
  summarize_state_dict_load(state, model_state, filtered_state)
1187
  missing, unexpected = multimodal_model.load_state_dict(filtered_state, strict=False)
1188
- print(f"[INFO] Loaded multimodal pretrained weights from {ckpt_path}")
1189
- print(f"[INFO] load_state_dict() -> missing={len(missing)} unexpected={len(unexpected)}")
1190
  if missing:
1191
- print("[INFO] Missing keys (sample):", missing[:50])
1192
  if unexpected:
1193
- print("[INFO] Unexpected keys (sample):", unexpected[:50])
1194
 
1195
  except Exception as e:
1196
- print(f"Warning: Failed to load multimodal pretrained weights: {e}")
 
 
1197
 
1198
  return multimodal_model
1199
 
1200
 
1201
- # -----------------------------
1202
- # Downstream run (Uni-Poly matched)
1203
- # -----------------------------
1204
  def build_samples_for_property(df: pd.DataFrame, prop_col: str) -> List[dict]:
 
 
 
 
 
 
1205
  samples = []
1206
  for _, row in df.iterrows():
1207
- # Require at least one modality
1208
  has_modality = False
1209
  for col in ['graph', 'geometry', 'fingerprints', 'psmiles']:
1210
  if col in row and row[col] and str(row[col]).strip() != "":
@@ -1232,32 +1355,47 @@ def build_samples_for_property(df: pd.DataFrame, prop_col: str) -> List[dict]:
1232
  return samples
1233
 
1234
 
1235
- def run_unipoly_downstream(property_list: List[str], property_cols: List[str], df_raw: pd.DataFrame,
1236
  pretrained_path: str, output_file: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1237
  os.makedirs(pretrained_path, exist_ok=True)
1238
 
1239
- # Aggregate duplicates (Uni-Poly noise reduction) if possible
1240
  modality_cols = ["graph", "geometry", "fingerprints", "psmiles"]
1241
  df_proc = aggregate_polyinfo_duplicates(df_raw, modality_cols=modality_cols, property_cols=property_cols)
1242
 
1243
- all_results = {"per_property": {}, "mode": "UNIPOLY_MATCHED_SINGLE_TASK"}
1244
 
1245
- # Load base pretrained model once per property-run (we reload inside each run to avoid leakage across runs)
1246
  for pname, pcol in zip(property_list, property_cols):
1247
- print(f"\n=== Uni-Poly matched fine-tuning: {pname} (col='{pcol}') ===")
1248
  samples = build_samples_for_property(df_proc, pcol)
 
 
1249
  if len(samples) < 200:
1250
- print(f"[WARN] Too few samples for '{pname}': {len(samples)}. Results may be unstable.")
1251
  if len(samples) < 50:
1252
- print(f"[WARN] Skipping '{pname}' due to insufficient samples.")
1253
  continue
1254
 
1255
  run_metrics = []
1256
  run_records = []
1257
 
1258
- # NEW: track which run ("fold") gives best performance and save its weights
1259
  best_overall_r2 = -1e18
1260
- best_overall_payload = None # contains model_state_dict (cpu), scaler stats, and metadata
1261
 
1262
  idxs = np.arange(len(samples))
1263
  cv = KFold(n_splits=NUM_RUNS, shuffle=True, random_state=42)
@@ -1266,10 +1404,12 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
1266
  seed = 42 + run_idx
1267
  set_seed(seed)
1268
 
 
 
1269
  trainval = [copy.deepcopy(samples[i]) for i in trainval_idx]
1270
  test = [copy.deepcopy(samples[i]) for i in test_idx]
1271
 
1272
- # train/val split within trainval
1273
  tr_idx, va_idx = train_test_split(
1274
  np.arange(len(trainval)),
1275
  test_size=VAL_SIZE_WITHIN_TRAINVAL,
@@ -1279,7 +1419,7 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
1279
  train = [copy.deepcopy(trainval[i]) for i in tr_idx]
1280
  val = [copy.deepcopy(trainval[i]) for i in va_idx]
1281
 
1282
- # Z-score normalization (StandardScaler) fit on train
1283
  sc = StandardScaler()
1284
  sc.fit(np.array([s["target_raw"] for s in train]).reshape(-1, 1))
1285
 
@@ -1299,9 +1439,11 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
1299
  dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, collate_fn=multimodal_collate_fn)
1300
  dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, collate_fn=multimodal_collate_fn)
1301
 
1302
- # Load pretrained base, then attach Uni-Poly head; fine-tune end-to-end
1303
- unipoly_base = load_pretrained_multimodal(pretrained_path)
1304
- model = UniPolyPropertyRegressor(unipoly_base, emb_dim=UNIPOLY_EMB_DIM, dropout=UNIPOLY_DROPOUT).to(DEVICE)
 
 
1305
 
1306
  optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
1307
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
@@ -1310,6 +1452,7 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
1310
  best_state = None
1311
  no_improve = 0
1312
 
 
1313
  for epoch in range(1, NUM_EPOCHS + 1):
1314
  tr_loss = train_one_epoch(model, dl_train, optimizer, DEVICE)
1315
  va_loss, _, _ = evaluate(model, dl_val, DEVICE)
@@ -1317,7 +1460,7 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
1317
 
1318
  scheduler.step()
1319
 
1320
- print(f"[{pname}] fold {run_idx+1}/{NUM_RUNS} epoch {epoch:03d} train={tr_loss:.6f} val={va_loss:.6f}")
1321
 
1322
  if va_loss < best_val - 1e-8:
1323
  best_val = va_loss
@@ -1326,25 +1469,29 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
1326
  else:
1327
  no_improve += 1
1328
  if no_improve >= PATIENCE:
1329
- print(f"[{pname}] Early stopping at epoch {epoch} (patience={PATIENCE}).")
1330
  break
1331
 
1332
  if best_state is None:
1333
- print(f"[WARN] No best state saved for {pname} fold {run_idx+1}; skipping.")
1334
  continue
1335
 
 
1336
  model.load_state_dict({k: v.to(DEVICE) for k, v in best_state.items()}, strict=True)
1337
  _, pred_scaled, true_scaled = evaluate(model, dl_test, DEVICE)
1338
  if pred_scaled is None:
 
1339
  continue
1340
 
1341
- # Inverse transform to original units
1342
  pred = sc.inverse_transform(pred_scaled.reshape(-1, 1)).ravel()
1343
  true = sc.inverse_transform(true_scaled.reshape(-1, 1)).ravel()
1344
 
1345
  m = compute_metrics(true, pred)
1346
  run_metrics.append(m)
1347
 
 
 
1348
  record = {
1349
  "property": pname,
1350
  "property_col": pcol,
@@ -1361,7 +1508,7 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
1361
  with open(output_file, "a") as fh:
1362
  fh.write(json.dumps(make_json_serializable(record)) + "\n")
1363
 
1364
- # NEW: if this run is best for the property (by test R2), keep its weights + scaler
1365
  if float(m.get("r2", -1e18)) > float(best_overall_r2):
1366
  best_overall_r2 = float(m.get("r2", -1e18))
1367
  best_overall_payload = {
@@ -1378,20 +1525,16 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
1378
  "scaler_scale": make_json_serializable(getattr(sc, "scale_", None)),
1379
  "scaler_var": make_json_serializable(getattr(sc, "var_", None)),
1380
  "scaler_n_samples_seen": make_json_serializable(getattr(sc, "n_samples_seen_", None)),
1381
- "model_state_dict": best_state, # already CPU tensors
1382
  }
1383
 
1384
- # NEW: save best run ("fold") weights for this property into a new directory
1385
  if best_overall_payload is not None and "model_state_dict" in best_overall_payload:
1386
  os.makedirs(BEST_WEIGHTS_DIR, exist_ok=True)
1387
  prop_dir = os.path.join(BEST_WEIGHTS_DIR, _sanitize_name(pname))
1388
  os.makedirs(prop_dir, exist_ok=True)
1389
 
1390
- # Save a single checkpoint bundle + a lightweight metadata json
1391
- ckpt_bundle = {
1392
- k: v for k, v in best_overall_payload.items()
1393
- if k != "test_metrics"
1394
- }
1395
  ckpt_bundle["test_metrics"] = best_overall_payload["test_metrics"]
1396
 
1397
  torch.save(ckpt_bundle, os.path.join(prop_dir, "best_run_checkpoint.pt"))
@@ -1400,10 +1543,10 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
1400
  with open(os.path.join(prop_dir, "best_run_metadata.json"), "w") as fh:
1401
  fh.write(json.dumps(make_json_serializable(meta), indent=2))
1402
 
1403
- print(f"[INFO] Saved best weights for '{pname}' to: {prop_dir}")
1404
- print(f"[INFO] best_run={best_overall_payload['best_run']} best_test_r2={best_overall_payload['test_metrics'].get('r2', None)}")
1405
 
1406
- # Aggregate across runs
1407
  if run_metrics:
1408
  r2s = [x["r2"] for x in run_metrics]
1409
  maes = [x["mae"] for x in run_metrics]
@@ -1415,8 +1558,10 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
1415
  "rmse": {"mean": float(np.mean(rmses)), "std": float(np.std(rmses, ddof=0))},
1416
  "mse": {"mean": float(np.mean(mses)), "std": float(np.std(mses, ddof=0))},
1417
  }
 
1418
  else:
1419
  agg = None
 
1420
 
1421
  all_results["per_property"][pname] = {
1422
  "property_col": pcol,
@@ -1431,38 +1576,42 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
1431
  return all_results
1432
 
1433
 
1434
- # -----------------------------
1435
  # Main
1436
- # -----------------------------
1437
  def main():
 
1438
  if os.path.exists(OUTPUT_RESULTS):
1439
  backup = OUTPUT_RESULTS + ".bak"
1440
  shutil.copy(OUTPUT_RESULTS, backup)
1441
- print(f"[INFO] Existing {OUTPUT_RESULTS} backed up to {backup}")
1442
  open(OUTPUT_RESULTS, "w").close()
 
1443
 
 
1444
  if not os.path.isfile(POLYINFO_PATH):
1445
  raise FileNotFoundError(f"PolyInfo file not found at {POLYINFO_PATH}")
1446
  polyinfo_raw = pd.read_csv(POLYINFO_PATH, engine="python")
 
1447
 
 
1448
  found = find_property_columns(polyinfo_raw.columns)
1449
  prop_map = {req: col for req, col in found.items()}
1450
- print(f"[INFO] Property-to-column map: {prop_map}")
1451
 
1452
  property_list = []
1453
  property_cols = []
1454
  for req in REQUESTED_PROPERTIES:
1455
  col = prop_map.get(req)
1456
  if col is None:
1457
- print(f"[WARN] Could not find a column for requested property '{req}'. Skipping.")
1458
  continue
1459
  property_list.append(req)
1460
  property_cols.append(col)
1461
 
1462
- print(f"\n=== Running Uni-Poly MATCHED downstream fine-tuning for properties: {property_list} ===")
1463
- overall = run_unipoly_downstream(property_list, property_cols, polyinfo_raw, PRETRAINED_MULTIMODAL_DIR, OUTPUT_RESULTS)
1464
 
1465
- # Write final summary (aggregated per property)
1466
  final_agg = {}
1467
  if overall and "per_property" in overall:
1468
  for pname, info in overall["per_property"].items():
@@ -1473,7 +1622,8 @@ def main():
1473
  fh.write(json.dumps(make_json_serializable(final_agg), indent=2))
1474
  fh.write("\n")
1475
 
1476
- print(f"\n[DONE] Results appended to {OUTPUT_RESULTS}")
 
1477
 
1478
 
1479
  if __name__ == "__main__":
 
 
1
  import os
2
  import random
3
  import time
 
19
  import copy
20
  from typing import List, Dict, Optional, Tuple, Any
21
 
22
+ # Increase CSV field size limit (PolyInfo modality JSON fields can be large).
23
  csv.field_size_limit(sys.maxsize)
24
 
25
+ # =============================================================================
26
+ # Imports: Shared encoders/helpers from PolyFusion
27
+ # =============================================================================
28
  from PolyFusion.GINE import GineEncoder, match_edge_attr_to_index, safe_get
29
  from PolyFusion.SchNet import NodeSchNetWrapper
30
  from PolyFusion.Transformer import PooledFingerprintEncoder as FingerprintEncoder
31
  from PolyFusion.DeBERTav2 import PSMILESDebertaEncoder, build_psmiles_tokenizer
32
 
33
+ # =============================================================================
34
  # Configuration
35
+ # =============================================================================
36
  BASE_DIR = "/path/to/Polymer_Foundational_Model"
37
  POLYINFO_PATH = "/path/to/polyinfo_with_modalities.csv"
38
 
39
+ # Pretrained encoder directories (update these placeholders)
40
  PRETRAINED_MULTIMODAL_DIR = "/path/to/multimodal_output/best"
41
  BEST_GINE_DIR = "/path/to/gin_output/best"
42
  BEST_SCHNET_DIR = "/path/to/schnet_output/best"
43
  BEST_FP_DIR = "/path/to/fingerprint_mlm_output/best"
44
  BEST_PSMILES_DIR = "/path/to/polybert_output/best"
45
 
46
+ # Output log file (per-run json lines + per-property aggregated summary)
47
  OUTPUT_RESULTS = "/path/to/multimodal_downstream_results.txt"
48
 
49
+ # Directory to save best-performing checkpoint bundle per property (best CV run)
50
  BEST_WEIGHTS_DIR = "/path/to/multimodal_downstream_bestweights"
51
 
52
+ # -----------------------------------------------------------------------------
53
+ # Model sizes / dims (must match your pretrained multimodal encoder settings)
54
+ # -----------------------------------------------------------------------------
55
  MAX_ATOMIC_Z = 85
56
  MASK_ATOM_ID = MAX_ATOMIC_Z + 1
57
 
58
+ # GINE
59
  NODE_EMB_DIM = 300
60
  EDGE_EMB_DIM = 300
61
  NUM_GNN_LAYERS = 5
62
 
63
+ # SchNet
64
  SCHNET_NUM_GAUSSIANS = 50
65
  SCHNET_NUM_INTERACTIONS = 6
66
  SCHNET_CUTOFF = 10.0
67
  SCHNET_MAX_NEIGHBORS = 64
68
  SCHNET_HIDDEN = 600
69
 
70
+ # Fingerprints
71
  FP_LENGTH = 2048
72
  MASK_TOKEN_ID_FP = 2
73
  VOCAB_SIZE_FP = 3
74
 
75
+ # Contrastive embedding dim
76
  CL_EMB_DIM = 600
77
 
78
+ # PSMILES/DeBERTa
79
  DEBERTA_HIDDEN = 600
80
  PSMILES_MAX_LEN = 128
81
 
82
+ # -----------------------------------------------------------------------------
83
+ # Uni-Poly style fusion + regression head hyperparameters
84
+ # -----------------------------------------------------------------------------
85
+ POLYF_EMB_DIM = 600
86
+ POLYF_ATTN_HEADS = 8
87
+ POLYF_DROPOUT = 0.1
88
+ POLYF_FF_MULT = 4 # FFN hidden = 4*d (common transformer practice)
89
+
90
+ # -----------------------------------------------------------------------------
91
+ # Fine-tuning parameters (single-task per property)
92
+ # -----------------------------------------------------------------------------
93
  MAX_LEN = 128
94
  BATCH_SIZE = 32
95
  NUM_EPOCHS = 100
 
99
 
100
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
101
 
102
+ # Properties to evaluate (order preserved)
103
  REQUESTED_PROPERTIES = [
104
  "density",
105
  "glass transition",
 
108
  "thermal decomposition"
109
  ]
110
 
111
+ # True K-fold evaluation to match "fivefold per property"
112
  NUM_RUNS = 5
113
  TEST_SIZE = 0.10
114
+ VAL_SIZE_WITHIN_TRAINVAL = 0.10 # fraction of trainval reserved for val split
115
 
116
+ # Duplicate aggregation (noise reduction) key preference order
117
  AGG_KEYS_PREFERENCE = ["polymer_id", "PolymerID", "poly_id", "psmiles", "smiles", "canonical_smiles"]
118
 
119
+ # =============================================================================
120
  # Utilities
121
+ # =============================================================================
122
  def set_seed(seed: int):
123
+ """Set all relevant RNG seeds for reproducible folds."""
124
  random.seed(seed)
125
  np.random.seed(seed)
126
  torch.manual_seed(seed)
127
  if torch.cuda.is_available():
128
  torch.cuda.manual_seed_all(seed)
129
+ # Deterministic settings: reproducible but may reduce throughput.
130
  torch.backends.cudnn.deterministic = True
131
  torch.backends.cudnn.benchmark = False
132
 
133
 
134
  def make_json_serializable(obj):
135
+ """Convert common numpy/torch/pandas objects into JSON-safe Python types."""
136
  if isinstance(obj, dict):
137
  return {make_json_serializable(k): make_json_serializable(v) for k, v in obj.items()}
138
  if isinstance(obj, (list, tuple, set)):
 
158
  pass
159
  return obj
160
 
161
+
162
  def summarize_state_dict_load(full_state: dict, model_state: dict, filtered_state: dict):
163
+ """
164
+ Print a concise load report:
165
+ - how many checkpoint keys exist
166
+ - how many model keys exist
167
+ - how many keys will be loaded (intersection with matching shapes)
168
+ - common reasons for skipped keys
169
+ """
170
  n_ckpt = len(full_state)
171
  n_model = len(model_state)
172
  n_loaded = len(filtered_state)
 
182
  print(f" ckpt keys: {n_ckpt}")
183
  print(f" model keys: {n_model}")
184
  print(f" loaded keys: {n_loaded}")
185
+ print(f" skipped (not in model): {len(missing_in_model)}")
186
+ print(f" skipped (shape mismatch): {len(shape_mismatch)}")
187
 
188
  if missing_in_model:
189
+ print(" examples skipped (not in model):", missing_in_model[:10])
190
  if shape_mismatch:
191
+ print(" examples skipped (shape mismatch):")
192
  for k in shape_mismatch[:10]:
193
  print(f" {k}: ckpt={tuple(full_state[k].shape)} model={tuple(model_state[k].shape)}")
194
  print("")
 
196
 
197
  def find_property_columns(columns):
198
  """
199
+ Robust property column matching with guardrails:
200
+ - Prefer word-level (token) matches over substring matches.
201
+ - For 'density', avoid confusing with 'cohesive energy density' columns.
202
+ - Log chosen column and competing candidates when ambiguous.
203
  """
204
  lowered = {c.lower(): c for c in columns}
205
  found = {}
 
208
  req_low = req.lower().strip()
209
  exact = None
210
 
211
+ # Pass 1: token-level exactness (safer than substring match)
212
  for c_low, c_orig in lowered.items():
213
  tokens = set(c_low.replace('_', ' ').split())
214
  if req_low in tokens or c_low == req_low:
 
221
  found[req] = exact
222
  continue
223
 
224
+ # Pass 2: substring match as fallback
225
  candidates = [c_orig for c_low, c_orig in lowered.items() if req_low in c_low]
226
  if req_low == "density":
227
  candidates = [c for c in candidates if "cohesive" not in c.lower() and "cohesive energy" not in c.lower()]
 
231
  else:
232
  chosen = candidates[0] if candidates else None
233
  found[req] = chosen
234
+ print(f"[COLMAP] Requested '{req}' -> chosen column: {chosen}")
235
  if candidates:
236
+ print(f"[COLMAP] Candidates for '{req}': {candidates}")
237
  else:
238
+ print(f"[COLMAP][WARN] No candidates found for '{req}' using substring search.")
239
  return found
240
 
241
+
242
  def choose_aggregation_key(df: pd.DataFrame) -> Optional[str]:
243
+ """Pick the most stable identifier available for duplicate aggregation."""
244
  for k in AGG_KEYS_PREFERENCE:
245
  if k in df.columns:
246
  return k
 
249
 
250
  def aggregate_polyinfo_duplicates(df: pd.DataFrame, modality_cols: List[str], property_cols: List[str]) -> pd.DataFrame:
251
  """
252
+ Optional noise reduction: group duplicate polymer entries and average properties.
253
+
254
+ - Modalities are taken as "first" (they should be consistent per polymer key).
255
+ - Properties are averaged (mean).
256
  """
257
  key = choose_aggregation_key(df)
258
  if key is None:
259
+ print("[AGG] No aggregation key found; skipping duplicate aggregation.")
260
  return df
261
 
 
262
  df2 = df.copy()
263
  df2[key] = df2[key].astype(str)
264
  df2 = df2[df2[key].str.strip() != ""].copy()
265
  if len(df2) == 0:
266
+ print("[AGG] Aggregation key exists but is empty; skipping duplicate aggregation.")
267
  return df
268
 
269
  agg_dict = {}
 
270
  for mc in modality_cols:
271
  if mc in df2.columns:
272
  agg_dict[mc] = "first"
 
273
  for pc in property_cols:
274
  if pc in df2.columns:
275
  agg_dict[pc] = "mean"
276
 
277
  grouped = df2.groupby(key, as_index=False).agg(agg_dict)
278
+ print(f"[AGG] Grouped by '{key}': {len(df)} rows -> {len(grouped)} unique keys")
279
  return grouped
280
 
281
 
282
  def _sanitize_name(s: str) -> str:
283
+ """Create a filesystem-safe name for property directories."""
284
  s2 = str(s).strip().lower()
285
  keep = []
286
  for ch in s2:
 
296
  out = out.strip("_")
297
  return out or "property"
298
 
299
+
300
+ # =============================================================================
301
+ # Multimodal backbone: encode + project + modality-aware fusion
302
+ # =============================================================================
303
  class MultimodalContrastiveModel(nn.Module):
304
  """
305
+ Multimodal encoder wrapper:
306
+
307
+ 1) Runs each available modality encoder:
308
+ - GINE (graph)
309
+ - SchNet (3D geometry)
310
+ - Transformer FP encoder (Morgan bit sequence)
311
+ - DeBERTa-based PSMILES encoder (sequence)
312
+
313
+ 2) Projects each modality embedding to a shared dim (emb_dim).
314
+
315
+ 3) Normalizes each modality embedding (L2), drops out, then fuses via
316
+ a masked mean across modalities that are present for each sample.
317
+
318
+ 4) Normalizes the final fused embedding (L2).
319
+
320
+ Expected downstream usage:
321
+ z = model(batch_mods, modality_mask=modality_mask) # (B, emb_dim)
322
  """
323
 
324
  def __init__(
 
345
  self.out_dim = self.emb_dim
346
  self.dropout = nn.Dropout(float(dropout))
347
 
348
+ # Determine which modalities are enabled
349
  if modalities is None:
350
  mods = []
351
  if self.gine is not None:
 
360
  else:
361
  self.modalities = [m for m in modalities if m in ("gine", "schnet", "fp", "psmiles")]
362
 
363
+ # Projection heads into shared embedding space
364
  self.proj_gine = nn.Linear(NODE_EMB_DIM, self.emb_dim) if self.gine is not None else None
365
  self.proj_schnet = nn.Linear(SCHNET_HIDDEN, self.emb_dim) if self.schnet is not None else None
366
  self.proj_fp = nn.Linear(256, self.emb_dim) if self.fp is not None else None
367
 
368
+ # Infer PSMILES hidden size if possible; fallback to DEBERTA_HIDDEN
369
  psm_in = None
370
  if self.psmiles is not None:
371
  if hasattr(self.psmiles, "out_dim"):
 
384
  self.proj_psmiles = nn.Linear(psm_in, self.emb_dim) if (self.psmiles is not None) else None
385
 
386
  def freeze_cl_encoders(self):
387
+ """Freeze all modality encoders (optional for evaluation-only usage)."""
388
  for enc in (self.gine, self.schnet, self.fp, self.psmiles):
389
  if enc is None:
390
  continue
 
394
 
395
  def _masked_mean_combine(self, zs: List[torch.Tensor], masks: List[torch.Tensor]) -> torch.Tensor:
396
  """
397
+ Compute sample-wise mean over available modalities.
398
+
399
+ zs: list of modality embeddings, each (B,D)
400
+ masks: list of modality presence masks, each (B,) bool
401
+ returns: (B,D)
402
  """
403
  if not zs:
404
  device = next(self.parameters()).device
 
421
  def forward(self, batch_mods: dict, modality_mask: Optional[dict] = None) -> torch.Tensor:
422
  """
423
  batch_mods keys: 'gine', 'schnet', 'fp', 'psmiles'
424
+ modality_mask: dict {modality_name: (B,) bool} describing presence.
425
  """
426
  device = next(self.parameters()).device
427
 
428
  zs = []
429
  ms = []
430
 
431
+ # Infer batch size B
432
  B = None
433
  if modality_mask is not None:
434
+ for _, v in modality_mask.items():
435
  if isinstance(v, torch.Tensor) and v.numel() > 0:
436
  B = int(v.size(0))
437
  break
438
 
 
439
  if B is None:
440
  if "fp" in batch_mods and batch_mods["fp"] is not None and isinstance(batch_mods["fp"].get("input_ids", None), torch.Tensor):
441
  B = int(batch_mods["fp"]["input_ids"].size(0))
 
445
  if B is None:
446
  return torch.zeros((1, self.emb_dim), device=device)
447
 
 
448
  def _get_mask(name: str) -> torch.Tensor:
449
  if modality_mask is not None and name in modality_mask and isinstance(modality_mask[name], torch.Tensor):
450
  return modality_mask[name].to(device).bool()
451
  return torch.ones((B,), device=device, dtype=torch.bool)
452
 
453
+ # -------------------------
454
+ # GINE (graph modality)
455
+ # -------------------------
456
  if "gine" in self.modalities and self.gine is not None and batch_mods.get("gine", None) is not None:
457
  g = batch_mods["gine"]
458
  if isinstance(g.get("z", None), torch.Tensor) and g["z"].numel() > 0:
 
470
  zs.append(z)
471
  ms.append(_get_mask("gine"))
472
 
473
+ # -------------------------
474
+ # SchNet (3D geometry)
475
+ # -------------------------
476
  if "schnet" in self.modalities and self.schnet is not None and batch_mods.get("schnet", None) is not None:
477
  s = batch_mods["schnet"]
478
  if isinstance(s.get("z", None), torch.Tensor) and s["z"].numel() > 0:
 
487
  zs.append(z)
488
  ms.append(_get_mask("schnet"))
489
 
490
+ # -------------------------
491
+ # Fingerprint modality
492
+ # -------------------------
493
  if "fp" in self.modalities and self.fp is not None and batch_mods.get("fp", None) is not None:
494
  f = batch_mods["fp"]
495
  if isinstance(f.get("input_ids", None), torch.Tensor) and f["input_ids"].numel() > 0:
 
503
  zs.append(z)
504
  ms.append(_get_mask("fp"))
505
 
506
+ # -------------------------
507
+ # PSMILES text modality
508
+ # -------------------------
509
  if "psmiles" in self.modalities and self.psmiles is not None and batch_mods.get("psmiles", None) is not None:
510
  p = batch_mods["psmiles"]
511
  if isinstance(p.get("input_ids", None), torch.Tensor) and p["input_ids"].numel() > 0:
 
519
  zs.append(z)
520
  ms.append(_get_mask("psmiles"))
521
 
522
+ # Fuse and normalize
523
  if not zs:
524
  return torch.zeros((B, self.emb_dim), device=device)
525
 
 
527
  z = F.normalize(z, dim=-1)
528
  return z
529
 
 
530
  @torch.no_grad()
531
  def encode_psmiles(
532
  self,
 
536
  device: str = DEVICE
537
  ) -> np.ndarray:
538
  """
539
+ Convenience: PSMILES-only embeddings (used for fast bulk encoding tasks).
540
  """
541
  self.eval()
542
  if self.psm_tok is None or self.psmiles is None or self.proj_psmiles is None:
 
561
  device: str = DEVICE
562
  ) -> np.ndarray:
563
  """
564
+ Convenience: multimodal embedding for records carrying:
565
+ - graph, geometry, fingerprints, psmiles
566
+ Missing modalities are handled sample-wise via modality masking.
567
  """
568
  self.eval()
569
  dev = torch.device(device)
 
573
  for i in range(0, len(records), batch_size):
574
  chunk = records[i:i + batch_size]
575
 
 
576
  # PSMILES batch
577
  psmiles_texts = [str(r.get("psmiles", "")) for r in chunk]
578
  p_enc = None
579
  if self.psm_tok is not None:
580
  p_enc = self.psm_tok(psmiles_texts, truncation=True, padding="max_length", max_length=PSMILES_MAX_LEN, return_tensors="pt")
581
 
582
+ # FP batch (always stack; missing handled by attention_mask downstream)
583
  fp_ids, fp_attn = [], []
584
  for r in chunk:
585
  f = _parse_fingerprints(r.get("fingerprints", None), fp_len=FP_LENGTH)
 
588
  fp_ids = torch.stack(fp_ids, dim=0)
589
  fp_attn = torch.stack(fp_attn, dim=0)
590
 
591
+ # GINE + SchNet packed batching
 
592
  gine_all = {"z": [], "chirality": [], "formal_charge": [], "edge_index": [], "edge_attr": [], "batch": []}
593
  node_offset = 0
594
  for bi, r in enumerate(chunk):
 
646
  "psmiles": {"input_ids": p_enc["input_ids"], "attention_mask": p_enc["attention_mask"]} if p_enc is not None else None
647
  }
648
 
649
+ # NOTE: This script uses forward() as the encoder entry point.
650
+ z = self.forward(batch_mods, modality_mask=None)
651
  outs.append(z.detach().cpu().numpy())
652
 
653
  return np.concatenate(outs, axis=0) if outs else np.zeros((0, self.emb_dim), dtype=np.float32)
654
 
655
+
656
+ # =============================================================================
657
  # Tokenizer setup
658
+ # =============================================================================
659
  SPM_MODEL = "/path/to/spm.model"
660
  tokenizer = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)
661
 
662
+ # =============================================================================
663
+ # Dataset: single-task property prediction (with modality parsing)
664
+ # =============================================================================
665
  class PolymerPropertyDataset(Dataset):
666
+ """
667
+ Dataset that prepares one sample with up to four modalities:
668
+ - graph (for GINE)
669
+ - geometry (for SchNet)
670
+ - fingerprints (for FP transformer)
671
+ - psmiles text (for DeBERTa encoder)
672
+
673
+ Target is a single scalar per sample (already scaled externally).
674
+ """
675
  def __init__(self, data_list, tokenizer, max_length=128):
676
  self.data_list = data_list
677
  self.tokenizer = tokenizer
 
683
  def __getitem__(self, idx):
684
  data = self.data_list[idx]
685
 
686
+ # ---------------------------------------------------------------------
687
+ # Graph -> GINE tensors (robust parsing of stored JSON fields)
688
+ # ---------------------------------------------------------------------
689
  gine_data = None
690
  if 'graph' in data and data['graph']:
691
  try:
 
714
  edge_index = None
715
  edge_attr = None
716
 
717
+ # Fallback: adjacency matrix if edge_indices missing
718
  if edge_indices_raw is None:
719
  adj_mat = safe_get(graph_field, "adjacency_matrix", None)
720
  if adj_mat:
 
730
  E = len(srcs)
731
  edge_attr = [[0.0, 0.0, 0.0] for _ in range(E)]
732
  else:
733
+ # edge_indices can be [[srcs],[dsts]] or list of pairs
734
  srcs, dsts = [], []
735
  if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0:
736
  if isinstance(edge_indices_raw[0], list):
 
757
  if len(srcs) > 0:
758
  edge_index = [srcs, dsts]
759
 
760
+ # edge_features: attempt to map known fields; otherwise zeros
761
  if edge_features_raw and isinstance(edge_features_raw, list):
762
  bond_types = []
763
  stereos = []
 
785
  except Exception:
786
  gine_data = None
787
 
788
+ # ---------------------------------------------------------------------
789
+ # Geometry -> SchNet tensors (best conformer)
790
+ # ---------------------------------------------------------------------
791
  schnet_data = None
792
  if 'geometry' in data and data['geometry']:
793
  try:
 
804
  except Exception:
805
  schnet_data = None
806
 
807
+ # ---------------------------------------------------------------------
808
+ # Fingerprints -> FP transformer inputs (bit sequence)
809
+ # ---------------------------------------------------------------------
810
  fp_data = None
811
  if 'fingerprints' in data and data['fingerprints']:
812
  try:
 
857
  except Exception:
858
  fp_data = None
859
 
860
+ # ---------------------------------------------------------------------
861
+ # PSMILES -> DeBERTa tokenizer inputs
862
+ # ---------------------------------------------------------------------
863
  psmiles_data = None
864
  if 'psmiles' in data and data['psmiles'] and self.tokenizer is not None:
865
  try:
 
877
  except Exception:
878
  psmiles_data = None
879
 
880
+ # ---------------------------------------------------------------------
881
+ # Fill defaults for missing modalities (keeps collate simpler)
882
+ # ---------------------------------------------------------------------
883
  if gine_data is None:
884
  gine_data = {
885
  'z': torch.tensor([], dtype=torch.long),
 
907
  'attention_mask': torch.zeros(PSMILES_MAX_LEN, dtype=torch.bool)
908
  }
909
 
910
+ # Single-task regression target (already scaled)
911
  target_scaled = float(data.get("target_scaled", 0.0))
912
 
913
  return {
 
919
  }
920
 
921
 
922
+ # =============================================================================
923
+ # Collate: pack variable-sized graph/3D into batch tensors + modality masks
924
+ # =============================================================================
925
  def multimodal_collate_fn(batch):
926
+ """
927
+ Collate samples into a single minibatch.
928
+
929
+ - GINE: concatenate nodes across samples and build a `batch` vector.
930
+ - SchNet: concatenate atoms/coords across samples and build a `batch` vector.
931
+ - FP/PSMILES: stack to (B, L).
932
+ - modality_mask: per-sample boolean flags indicating availability.
933
+ """
934
  B = len(batch)
935
 
936
+ # -------------------------
937
+ # GINE packing
938
+ # -------------------------
939
  all_z = []
940
  all_ch = []
941
  all_fc = []
 
982
  edge_index_batched = torch.empty((2, 0), dtype=torch.long)
983
  edge_attr_batched = torch.zeros((0, 3), dtype=torch.float)
984
 
985
+ # -------------------------
986
+ # SchNet packing
987
+ # -------------------------
988
  all_sz = []
989
  all_pos = []
990
  schnet_batch = []
 
1009
  s_pos_batch = torch.cat(all_pos, dim=0)
1010
  s_batch_batch = torch.cat(schnet_batch, dim=0)
1011
 
1012
+ # -------------------------
1013
+ # FP stacking
1014
+ # -------------------------
1015
  fp_ids = torch.stack([item["fp"]["input_ids"] for item in batch], dim=0)
1016
  fp_attn = torch.stack([item["fp"]["attention_mask"] for item in batch], dim=0)
1017
  fp_present = (fp_attn.sum(dim=1) > 0).cpu().numpy().tolist()
1018
 
1019
+ # -------------------------
1020
+ # PSMILES stacking
1021
+ # -------------------------
1022
  p_ids = torch.stack([item["psmiles"]["input_ids"] for item in batch], dim=0)
1023
  p_attn = torch.stack([item["psmiles"]["attention_mask"] for item in batch], dim=0)
1024
  psmiles_present = (p_attn.sum(dim=1) > 0).cpu().numpy().tolist()
 
1026
  # Target
1027
  target = torch.stack([item["target"] for item in batch], dim=0) # (B,)
1028
 
1029
+ # Presence mask for fusion (per-sample modality availability)
1030
  modality_mask = {
1031
  "gine": torch.tensor(gine_present, dtype=torch.bool),
1032
  "schnet": torch.tensor(schnet_present, dtype=torch.bool),
 
1061
  }
1062
 
1063
 
1064
+ # =============================================================================
1065
+ # Single-task regressor head (Uni-Poly-style fine-tuning)
1066
+ # =============================================================================
1067
+ class PolyFPropertyRegressor(nn.Module):
1068
+ """
1069
+ Simple MLP head on top of the multimodal fused embedding.
1070
+ Predicts one scalar (scaled target) per sample.
1071
+ """
1072
+ def __init__(self, polyf_model: MultimodalContrastiveModel, emb_dim: int = POLYF_EMB_DIM, dropout: float = 0.1):
1073
  super().__init__()
1074
+ self.polyf = polyf_model
1075
  self.head = nn.Sequential(
1076
  nn.Linear(emb_dim, emb_dim // 2),
1077
  nn.ReLU(),
 
1080
  )
1081
 
1082
  def forward(self, batch_mods, modality_mask=None):
1083
+ emb = self.polyf(batch_mods, modality_mask=modality_mask) # (B,d)
1084
  y = self.head(emb).squeeze(-1) # (B,)
1085
  return y
1086
 
1087
 
1088
+ # =============================================================================
1089
+ # Training / evaluation helpers
1090
+ # =============================================================================
1091
  def compute_metrics(y_true, y_pred):
1092
+ """Compute standard regression metrics in original units."""
1093
  mse = mean_squared_error(y_true, y_pred)
1094
  rmse = math.sqrt(mse)
1095
  mae = mean_absolute_error(y_true, y_pred)
 
1098
 
1099
 
1100
  def train_one_epoch(model, dataloader, optimizer, device):
1101
+ """One epoch of supervised regression training (MSE loss)."""
1102
  model.train()
1103
  total_loss = 0.0
1104
  total_n = 0
1105
 
1106
  for batch in dataloader:
1107
+ # Move nested batch dict to device
1108
  for k in batch:
1109
  if k == "target":
1110
  batch[k] = batch[k].to(device)
 
1137
 
1138
  @torch.no_grad()
1139
  def evaluate(model, dataloader, device):
1140
+ """
1141
+ Evaluate on a dataloader:
1142
+ - returns avg loss, predicted scaled values, true scaled values
1143
+ """
1144
  model.eval()
1145
  preds = []
1146
  trues = []
 
1148
  total_n = 0
1149
 
1150
  for batch in dataloader:
1151
+ # Move nested batch dict to device
1152
  for k in batch:
1153
  if k == "target":
1154
  batch[k] = batch[k].to(device)
 
1184
  return float(avg_loss), preds, trues
1185
 
1186
 
1187
+ # =============================================================================
1188
+ # Pretrained loading helpers
1189
+ # =============================================================================
1190
  def load_pretrained_multimodal(pretrained_path: str) -> MultimodalContrastiveModel:
1191
+ """
1192
+ Construct modality encoders and load any available pretrained weights:
1193
+ - modality-specific checkpoints (BEST_*_DIR)
1194
+ - full multimodal checkpoint from `pretrained_path/pytorch_model.bin`
1195
+
1196
+ Returns a ready-to-fine-tune MultimodalContrastiveModel.
1197
+ """
1198
+ # -------------------------
1199
+ # GINE encoder
1200
+ # -------------------------
1201
  gine_encoder = GineEncoder(
1202
  node_emb_dim=NODE_EMB_DIM,
1203
  edge_emb_dim=EDGE_EMB_DIM,
1204
  num_layers=NUM_GNN_LAYERS,
1205
  max_atomic_z=MAX_ATOMIC_Z
1206
  )
1207
+ gine_ckpt = os.path.join(BEST_GINE_DIR, "pytorch_model.bin")
1208
+ if os.path.exists(gine_ckpt):
1209
  try:
1210
+ gine_encoder.load_state_dict(torch.load(gine_ckpt, map_location="cpu"), strict=False)
1211
+ print(f"[LOAD] GINE weights: {gine_ckpt}")
 
 
 
1212
  except Exception as e:
1213
+ print(f"[LOAD][WARN] Could not load GINE weights: {e}")
1214
 
1215
+ # -------------------------
1216
+ # SchNet encoder
1217
+ # -------------------------
1218
  schnet_encoder = NodeSchNetWrapper(
1219
  hidden_channels=SCHNET_HIDDEN,
1220
  num_interactions=SCHNET_NUM_INTERACTIONS,
 
1222
  cutoff=SCHNET_CUTOFF,
1223
  max_num_neighbors=SCHNET_MAX_NEIGHBORS
1224
  )
1225
+ sch_ckpt = os.path.join(BEST_SCHNET_DIR, "pytorch_model.bin")
1226
+ if os.path.exists(sch_ckpt):
1227
  try:
1228
+ schnet_encoder.load_state_dict(torch.load(sch_ckpt, map_location="cpu"), strict=False)
1229
+ print(f"[LOAD] SchNet weights: {sch_ckpt}")
 
 
 
1230
  except Exception as e:
1231
+ print(f"[LOAD][WARN] Could not load SchNet weights: {e}")
1232
 
1233
+ # -------------------------
1234
+ # Fingerprint encoder
1235
+ # -------------------------
1236
  fp_encoder = FingerprintEncoder(
1237
  vocab_size=VOCAB_SIZE_FP,
1238
  hidden_dim=256,
 
1242
  dim_feedforward=1024,
1243
  dropout=0.1
1244
  )
1245
+ fp_ckpt = os.path.join(BEST_FP_DIR, "pytorch_model.bin")
1246
+ if os.path.exists(fp_ckpt):
1247
  try:
1248
+ fp_encoder.load_state_dict(torch.load(fp_ckpt, map_location="cpu"), strict=False)
1249
+ print(f"[LOAD] FP encoder weights: {fp_ckpt}")
 
 
 
1250
  except Exception as e:
1251
+ print(f"[LOAD][WARN] Could not load fingerprint weights: {e}")
1252
 
1253
+ # -------------------------
1254
+ # PSMILES encoder
1255
+ # -------------------------
1256
  psmiles_encoder = None
1257
  if os.path.isdir(BEST_PSMILES_DIR):
1258
  try:
1259
  psmiles_encoder = PSMILESDebertaEncoder(model_dir_or_name=BEST_PSMILES_DIR)
1260
+ print(f"[LOAD] PSMILES encoder: {BEST_PSMILES_DIR}")
1261
  except Exception as e:
1262
+ print(f"[LOAD][WARN] Could not load PSMILES encoder from dir: {e}")
1263
 
1264
+ # Fallback: initialize with vocab fallback (still functional, but not your finetuned weights)
1265
  if psmiles_encoder is None:
1266
  try:
1267
  psmiles_encoder = PSMILESDebertaEncoder(
1268
  model_dir_or_name=None,
1269
  vocab_fallback=int(getattr(tokenizer, "vocab_size", 300))
1270
  )
1271
+ print("[LOAD] PSMILES encoder: initialized fallback (no pretrained dir).")
1272
  except Exception as e:
1273
+ print(f"[LOAD][WARN] Could not initialize PSMILES encoder: {e}")
1274
 
1275
+ # Build multimodal wrapper
1276
  multimodal_model = MultimodalContrastiveModel(
1277
  gine_encoder,
1278
  schnet_encoder,
1279
  fp_encoder,
1280
  psmiles_encoder,
1281
+ emb_dim=POLYF_EMB_DIM,
1282
  modalities=["gine", "schnet", "fp", "psmiles"]
1283
  )
1284
 
1285
+ # -------------------------
1286
+ # Optional: load full multimodal checkpoint (projects/fusion, etc.)
1287
+ # -------------------------
1288
  ckpt_path = os.path.join(pretrained_path, "pytorch_model.bin")
1289
  if os.path.isfile(ckpt_path):
1290
  try:
 
1300
 
1301
  summarize_state_dict_load(state, model_state, filtered_state)
1302
  missing, unexpected = multimodal_model.load_state_dict(filtered_state, strict=False)
1303
+ print(f"[LOAD] Multimodal checkpoint: {ckpt_path}")
1304
+ print(f"[LOAD] load_state_dict -> missing={len(missing)} unexpected={len(unexpected)}")
1305
  if missing:
1306
+ print("[LOAD] Missing keys (sample):", missing[:50])
1307
  if unexpected:
1308
+ print("[LOAD] Unexpected keys (sample):", unexpected[:50])
1309
 
1310
  except Exception as e:
1311
+ print(f"[LOAD][WARN] Failed to load multimodal pretrained weights: {e}")
1312
+ else:
1313
+ print(f"[LOAD] No multimodal checkpoint found at: {ckpt_path}")
1314
 
1315
  return multimodal_model
1316
 
1317
 
1318
+ # =============================================================================
1319
+ # Downstream: sample construction + CV training loop
1320
+ # =============================================================================
1321
  def build_samples_for_property(df: pd.DataFrame, prop_col: str) -> List[dict]:
1322
+ """
1323
+ Construct training samples for a single property:
1324
+ - Keep rows that have at least one modality present.
1325
+ - Keep rows with a finite property value in `prop_col`.
1326
+ - Store raw target (will be scaled per fold).
1327
+ """
1328
  samples = []
1329
  for _, row in df.iterrows():
1330
+ # Require at least one modality present
1331
  has_modality = False
1332
  for col in ['graph', 'geometry', 'fingerprints', 'psmiles']:
1333
  if col in row and row[col] and str(row[col]).strip() != "":
 
1355
  return samples
1356
 
1357
 
1358
+ def run_polyf_downstream(property_list: List[str], property_cols: List[str], df_raw: pd.DataFrame,
1359
  pretrained_path: str, output_file: str):
1360
+ """
1361
+ Uni-Poly-matched downstream evaluation:
1362
+
1363
+ For each property:
1364
+ - Build samples from PolyInfo
1365
+ - 5-fold CV:
1366
+ - Split into trainval/test (by KFold)
1367
+ - Split trainval into train/val
1368
+ - Fit StandardScaler on train targets
1369
+ - Fine-tune encoder+head end-to-end with early stopping by val loss
1370
+ - Evaluate on held-out test fold in original units
1371
+ - Save per-fold results and per-property mean±std
1372
+ - Save best fold checkpoint bundle (by test R2) for later reuse
1373
+ """
1374
  os.makedirs(pretrained_path, exist_ok=True)
1375
 
1376
+ # Optional duplicate aggregation (noise reduction)
1377
  modality_cols = ["graph", "geometry", "fingerprints", "psmiles"]
1378
  df_proc = aggregate_polyinfo_duplicates(df_raw, modality_cols=modality_cols, property_cols=property_cols)
1379
 
1380
+ all_results = {"per_property": {}, "mode": "POLYF_MATCHED_SINGLE_TASK"}
1381
 
 
1382
  for pname, pcol in zip(property_list, property_cols):
1383
+ print(f"\n=== [DOWNSTREAM] Uni-Poly matched fine-tuning: {pname} (col='{pcol}') ===")
1384
  samples = build_samples_for_property(df_proc, pcol)
1385
+
1386
+ print(f"[DATA] {pname}: n_samples={len(samples)}")
1387
  if len(samples) < 200:
1388
+ print(f"[DATA][WARN] '{pname}' has <200 samples; results may be noisy.")
1389
  if len(samples) < 50:
1390
+ print(f"[DATA][WARN] Skipping '{pname}' (insufficient samples).")
1391
  continue
1392
 
1393
  run_metrics = []
1394
  run_records = []
1395
 
1396
+ # Track best-performing fold for this property (by test R2)
1397
  best_overall_r2 = -1e18
1398
+ best_overall_payload = None
1399
 
1400
  idxs = np.arange(len(samples))
1401
  cv = KFold(n_splits=NUM_RUNS, shuffle=True, random_state=42)
 
1404
  seed = 42 + run_idx
1405
  set_seed(seed)
1406
 
1407
+ print(f"\n--- [CV] {pname}: fold {run_idx+1}/{NUM_RUNS} | seed={seed} ---")
1408
+
1409
  trainval = [copy.deepcopy(samples[i]) for i in trainval_idx]
1410
  test = [copy.deepcopy(samples[i]) for i in test_idx]
1411
 
1412
+ # Split trainval into train/val
1413
  tr_idx, va_idx = train_test_split(
1414
  np.arange(len(trainval)),
1415
  test_size=VAL_SIZE_WITHIN_TRAINVAL,
 
1419
  train = [copy.deepcopy(trainval[i]) for i in tr_idx]
1420
  val = [copy.deepcopy(trainval[i]) for i in va_idx]
1421
 
1422
+ # Standardize target using training fold only (prevents leakage)
1423
  sc = StandardScaler()
1424
  sc.fit(np.array([s["target_raw"] for s in train]).reshape(-1, 1))
1425
 
 
1439
  dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, collate_fn=multimodal_collate_fn)
1440
  dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, collate_fn=multimodal_collate_fn)
1441
 
1442
+ print(f"[SPLIT] train={len(ds_train)} val={len(ds_val)} test={len(ds_test)}")
1443
+
1444
+ # Fresh base model per fold to avoid any cross-fold leakage
1445
+ polyf_base = load_pretrained_multimodal(pretrained_path)
1446
+ model = PolyFPropertyRegressor(polyf_base, emb_dim=POLYF_EMB_DIM, dropout=POLYF_DROPOUT).to(DEVICE)
1447
 
1448
  optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
1449
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
 
1452
  best_state = None
1453
  no_improve = 0
1454
 
1455
+ # Train with early stopping on validation loss
1456
  for epoch in range(1, NUM_EPOCHS + 1):
1457
  tr_loss = train_one_epoch(model, dl_train, optimizer, DEVICE)
1458
  va_loss, _, _ = evaluate(model, dl_val, DEVICE)
 
1460
 
1461
  scheduler.step()
1462
 
1463
+ print(f"[{pname}] fold {run_idx+1}/{NUM_RUNS} epoch {epoch:03d} | train={tr_loss:.6f} | val={va_loss:.6f}")
1464
 
1465
  if va_loss < best_val - 1e-8:
1466
  best_val = va_loss
 
1469
  else:
1470
  no_improve += 1
1471
  if no_improve >= PATIENCE:
1472
+ print(f"[{pname}] fold {run_idx+1}: early stopping (patience={PATIENCE}) at epoch {epoch}.")
1473
  break
1474
 
1475
  if best_state is None:
1476
+ print(f"[{pname}][WARN] No best checkpoint captured for fold {run_idx+1}; skipping fold.")
1477
  continue
1478
 
1479
+ # Restore best state and evaluate on test fold
1480
  model.load_state_dict({k: v.to(DEVICE) for k, v in best_state.items()}, strict=True)
1481
  _, pred_scaled, true_scaled = evaluate(model, dl_test, DEVICE)
1482
  if pred_scaled is None:
1483
+ print(f"[{pname}][WARN] Test evaluation returned empty predictions for fold {run_idx+1}.")
1484
  continue
1485
 
1486
+ # Convert from scaled space back to original units
1487
  pred = sc.inverse_transform(pred_scaled.reshape(-1, 1)).ravel()
1488
  true = sc.inverse_transform(true_scaled.reshape(-1, 1)).ravel()
1489
 
1490
  m = compute_metrics(true, pred)
1491
  run_metrics.append(m)
1492
 
1493
+ print(f"[{pname}] fold {run_idx+1} TEST | r2={m['r2']:.4f} mae={m['mae']:.4f} rmse={m['rmse']:.4f}")
1494
+
1495
  record = {
1496
  "property": pname,
1497
  "property_col": pcol,
 
1508
  with open(output_file, "a") as fh:
1509
  fh.write(json.dumps(make_json_serializable(record)) + "\n")
1510
 
1511
+ # Update best fold bundle (by test R2)
1512
  if float(m.get("r2", -1e18)) > float(best_overall_r2):
1513
  best_overall_r2 = float(m.get("r2", -1e18))
1514
  best_overall_payload = {
 
1525
  "scaler_scale": make_json_serializable(getattr(sc, "scale_", None)),
1526
  "scaler_var": make_json_serializable(getattr(sc, "var_", None)),
1527
  "scaler_n_samples_seen": make_json_serializable(getattr(sc, "n_samples_seen_", None)),
1528
+ "model_state_dict": best_state, # CPU tensors
1529
  }
1530
 
1531
+ # Save best fold weights + metadata per property
1532
  if best_overall_payload is not None and "model_state_dict" in best_overall_payload:
1533
  os.makedirs(BEST_WEIGHTS_DIR, exist_ok=True)
1534
  prop_dir = os.path.join(BEST_WEIGHTS_DIR, _sanitize_name(pname))
1535
  os.makedirs(prop_dir, exist_ok=True)
1536
 
1537
+ ckpt_bundle = {k: v for k, v in best_overall_payload.items() if k != "test_metrics"}
 
 
 
 
1538
  ckpt_bundle["test_metrics"] = best_overall_payload["test_metrics"]
1539
 
1540
  torch.save(ckpt_bundle, os.path.join(prop_dir, "best_run_checkpoint.pt"))
 
1543
  with open(os.path.join(prop_dir, "best_run_metadata.json"), "w") as fh:
1544
  fh.write(json.dumps(make_json_serializable(meta), indent=2))
1545
 
1546
+ print(f"[BEST] Saved best fold for '{pname}' -> {prop_dir}")
1547
+ print(f"[BEST] best_run={best_overall_payload['best_run']} best_test_r2={best_overall_payload['test_metrics'].get('r2', None)}")
1548
 
1549
+ # Aggregate metrics across folds
1550
  if run_metrics:
1551
  r2s = [x["r2"] for x in run_metrics]
1552
  maes = [x["mae"] for x in run_metrics]
 
1558
  "rmse": {"mean": float(np.mean(rmses)), "std": float(np.std(rmses, ddof=0))},
1559
  "mse": {"mean": float(np.mean(mses)), "std": float(np.std(mses, ddof=0))},
1560
  }
1561
+ print(f"[AGG] {pname} | r2={agg['r2']['mean']:.4f}±{agg['r2']['std']:.4f} mae={agg['mae']['mean']:.4f}±{agg['mae']['std']:.4f}")
1562
  else:
1563
  agg = None
1564
+ print(f"[AGG][WARN] No successful folds for '{pname}' (no aggregate computed).")
1565
 
1566
  all_results["per_property"][pname] = {
1567
  "property_col": pcol,
 
1576
  return all_results
1577
 
1578
 
1579
+ # =============================================================================
1580
  # Main
1581
+ # =============================================================================
1582
  def main():
1583
+ # Start a fresh results file (back up old results if present)
1584
  if os.path.exists(OUTPUT_RESULTS):
1585
  backup = OUTPUT_RESULTS + ".bak"
1586
  shutil.copy(OUTPUT_RESULTS, backup)
1587
+ print(f"[INIT] Existing results backed up: {backup}")
1588
  open(OUTPUT_RESULTS, "w").close()
1589
+ print(f"[INIT] Writing results to: {OUTPUT_RESULTS}")
1590
 
1591
+ # Load PolyInfo
1592
  if not os.path.isfile(POLYINFO_PATH):
1593
  raise FileNotFoundError(f"PolyInfo file not found at {POLYINFO_PATH}")
1594
  polyinfo_raw = pd.read_csv(POLYINFO_PATH, engine="python")
1595
+ print(f"[DATA] Loaded PolyInfo: n_rows={len(polyinfo_raw)} n_cols={len(polyinfo_raw.columns)}")
1596
 
1597
+ # Map requested properties to dataframe columns
1598
  found = find_property_columns(polyinfo_raw.columns)
1599
  prop_map = {req: col for req, col in found.items()}
1600
+ print(f"[COLMAP] Property-to-column map: {prop_map}")
1601
 
1602
  property_list = []
1603
  property_cols = []
1604
  for req in REQUESTED_PROPERTIES:
1605
  col = prop_map.get(req)
1606
  if col is None:
1607
+ print(f"[COLMAP][WARN] Could not find a column for '{req}'; skipping.")
1608
  continue
1609
  property_list.append(req)
1610
  property_cols.append(col)
1611
 
1612
+ overall = run_polyf_downstream(property_list, property_cols, polyinfo_raw, PRETRAINED_MULTIMODAL_DIR, OUTPUT_RESULTS)
 
1613
 
1614
+ # Write final summary (aggregated per property)
1615
  final_agg = {}
1616
  if overall and "per_property" in overall:
1617
  for pname, info in overall["per_property"].items():
 
1622
  fh.write(json.dumps(make_json_serializable(final_agg), indent=2))
1623
  fh.write("\n")
1624
 
1625
+ print(f"\n Results appended to: {OUTPUT_RESULTS}")
1626
+ print(f" Best checkpoints saved under: {BEST_WEIGHTS_DIR}")
1627
 
1628
 
1629
  if __name__ == "__main__":