Spaces:
Running
Running
| import os | |
| import random | |
| import time | |
| from pathlib import Path | |
| import math | |
| import json | |
| import shutil | |
| import numpy as np | |
| import pandas as pd | |
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error | |
| from sklearn.model_selection import train_test_split, KFold | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
| import sys | |
| import csv | |
| import copy | |
| from typing import List, Dict, Optional, Tuple, Any | |
# Raise the CSV field size limit as far as the platform allows.
# csv.field_size_limit() stores the limit in a C long, so passing sys.maxsize
# raises OverflowError on platforms where a C long is narrower than
# sys.maxsize (e.g. 64-bit Windows). Halve the request until it is accepted.
_csv_field_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_csv_field_limit)
        break
    except OverflowError:
        _csv_field_limit //= 2
| # ============================================================================= | |
| # Imports: Shared encoders/helpers from PolyFusion | |
| # ============================================================================= | |
| from PolyFusion.GINE import GineEncoder, match_edge_attr_to_index, safe_get | |
| from PolyFusion.SchNet import NodeSchNetWrapper | |
| from PolyFusion.Transformer import PooledFingerprintEncoder as FingerprintEncoder | |
| from PolyFusion.DeBERTav2 import PSMILESDebertaEncoder, build_psmiles_tokenizer | |
| # ============================================================================= | |
| # Configuration | |
| # ============================================================================= | |
# --- Filesystem locations ----------------------------------------------------
BASE_DIR = "/path/to/Polymer_Foundational_Model"
POLYINFO_PATH = "/path/to/polyinfo_with_modalities.csv"

# Checkpoint directories of the pretrained encoders
PRETRAINED_MULTIMODAL_DIR = "/path/to/multimodal_output/best"
BEST_GINE_DIR = "/path/to/gin_output/best"
BEST_SCHNET_DIR = "/path/to/schnet_output/best"
BEST_FP_DIR = "/path/to/fingerprint_mlm_output/best"
BEST_PSMILES_DIR = "/path/to/polybert_output/best"

# Results log: per-run JSON lines plus a per-property aggregated summary
OUTPUT_RESULTS = "/path/to/multimodal_downstream_results.txt"

# Where the best checkpoint bundle of each property (best CV run) is stored
BEST_WEIGHTS_DIR = "/path/to/multimodal_downstream_bestweights"

# --- Encoder sizes / dimensions ----------------------------------------------
MAX_ATOMIC_Z = 85
MASK_ATOM_ID = MAX_ATOMIC_Z + 1  # first id past the largest atomic number

# GINE (graph encoder)
NODE_EMB_DIM = 300
EDGE_EMB_DIM = 300
NUM_GNN_LAYERS = 5

# SchNet (3D geometry encoder)
SCHNET_NUM_GAUSSIANS = 50
SCHNET_NUM_INTERACTIONS = 6
SCHNET_CUTOFF = 10.0
SCHNET_MAX_NEIGHBORS = 64
SCHNET_HIDDEN = 600

# Fingerprint encoder
FP_LENGTH = 2048
MASK_TOKEN_ID_FP = 2
VOCAB_SIZE_FP = 3

# Shared contrastive embedding width
CL_EMB_DIM = 600

# PSMILES / DeBERTa encoder
DEBERTA_HIDDEN = 600
PSMILES_MAX_LEN = 128

# --- Fusion + regression head hyperparameters --------------------------------
POLYF_EMB_DIM = 600
POLYF_ATTN_HEADS = 8
POLYF_DROPOUT = 0.1
POLYF_FF_MULT = 4  # feed-forward hidden width = 4 * d

# --- Fine-tuning parameters (one single-task model per property) -------------
MAX_LEN = 128
BATCH_SIZE = 32
NUM_EPOCHS = 100
PATIENCE = 10
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Properties to evaluate (matched against dataset column names)
REQUESTED_PROPERTIES = [
    "density",
    "glass transition",
    "melting",
    "thermal decomposition",
]

# True K-fold evaluation to match "fivefold per property"
NUM_RUNS = 5
TEST_SIZE = 0.10
VAL_SIZE_WITHIN_TRAINVAL = 0.10  # share of the train+val part held out for validation

# Identifier columns tried, in order, when aggregating duplicate entries
AGG_KEYS_PREFERENCE = [
    "polymer_id",
    "PolymerID",
    "poly_id",
    "psmiles",
    "smiles",
    "canonical_smiles",
]
| # ============================================================================= | |
| # Utilities | |
| # ============================================================================= | |
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs so that folds are reproducible.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN: reproducible, but may reduce throughput.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def _json_safe_key(key):
    """Return a dict key that is hashable and accepted by ``json.dumps``.

    The key is converted like any other value; when the converted form is not
    a JSON-legal key (for example a tuple that became a list), ``str(key)`` is
    used instead so that building the dict cannot fail with an unhashable key.
    """
    converted = make_json_serializable(key)
    if converted is None or isinstance(converted, (str, int, float, bool)):
        return converted
    return str(key)


def make_json_serializable(obj):
    """Convert common numpy/torch/pandas objects into JSON-safe Python types.

    Conversions:
        - dict: keys via ``_json_safe_key``, values recursively
        - list / tuple / set: list with recursively converted items
        - numpy array: ``tolist()``
        - numpy scalar: ``item()`` (falls back to ``float``)
        - torch tensor: detached CPU ``tolist()`` (``None`` if that fails)
        - pandas Timestamp / Timedelta: ``str``

    Any other object is returned unchanged.
    """
    if isinstance(obj, dict):
        return {_json_safe_key(k): make_json_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return [make_json_serializable(x) for x in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        try:
            return obj.item()
        except Exception:
            return float(obj)
    if isinstance(obj, torch.Tensor):
        try:
            return obj.detach().cpu().tolist()
        except Exception:
            # Best effort: a tensor that cannot be materialised is logged as null.
            return None
    if isinstance(obj, (pd.Timestamp, pd.Timedelta)):
        return str(obj)
    return obj
def summarize_state_dict_load(full_state: dict, model_state: dict, filtered_state: dict):
    """
    Print a short report about a partial checkpoint load.

    Reported numbers:
    - number of keys in the checkpoint (full_state)
    - number of keys in the model (model_state)
    - number of keys that will be loaded (filtered_state)
    - checkpoint keys skipped because the model lacks them
    - checkpoint keys skipped because their shape differs from the model's
    Up to ten example keys are listed for each skip reason.
    """
    not_in_model = []
    mismatched = []
    for key, ckpt_value in full_state.items():
        if key not in model_state:
            not_in_model.append(key)
        elif hasattr(ckpt_value, "shape") and tuple(ckpt_value.shape) != tuple(model_state[key].shape):
            mismatched.append(key)

    header = (
        "\n[CKPT LOAD SUMMARY]",
        f" ckpt keys: {len(full_state)}",
        f" model keys: {len(model_state)}",
        f" loaded keys: {len(filtered_state)}",
        f" skipped (not in model): {len(not_in_model)}",
        f" skipped (shape mismatch): {len(mismatched)}",
    )
    for line in header:
        print(line)

    if not_in_model:
        print(" examples skipped (not in model):", not_in_model[:10])
    if mismatched:
        print(" examples skipped (shape mismatch):")
        for key in mismatched[:10]:
            print(f" {key}: ckpt={tuple(full_state[key].shape)} model={tuple(model_state[key].shape)}")
    print("")
def find_property_columns(columns, requested: Optional[List[str]] = None):
    """
    Map each requested property name to a dataset column.

    Matching is case-insensitive and done in two passes per property:
      1. Token pass: the column name equals the request, or every word of the
         request appears as a whole word of the column name (words are split on
         whitespace and underscores). Multi-word requests such as
         "glass transition" therefore match "Glass_Transition_Temp".
      2. Substring pass (fallback): the request occurs anywhere in the column
         name. When this is not exactly one column, the first candidate (or
         None) is chosen and the decision is printed.
    For "density", columns containing "cohesive" are always ignored so that
    "cohesive energy density" is never picked.

    Args:
        columns: Iterable of column labels.
        requested: Property names to resolve; defaults to REQUESTED_PROPERTIES.

    Returns:
        Dict mapping each requested name to the chosen original column label,
        or None when nothing matched.
    """
    if requested is None:
        requested = REQUESTED_PROPERTIES

    # Lower-cased name -> original label (str() tolerates non-string labels).
    lowered = {str(c).lower(): c for c in columns}
    column_tokens = {c_low: set(c_low.replace("_", " ").split()) for c_low in lowered}

    found = {}
    for req in requested:
        req_low = req.lower().strip()
        req_tokens = set(req_low.replace("_", " ").split())
        skip_cohesive = req_low == "density"

        def _excluded(c_low: str) -> bool:
            return skip_cohesive and "cohesive" in c_low

        # Pass 1: whole-word match (safer than substring match)
        exact = None
        for c_low, c_orig in lowered.items():
            if _excluded(c_low):
                continue
            # The req_tokens guard stops an empty request matching every column.
            if c_low == req_low or (req_tokens and req_tokens <= column_tokens[c_low]):
                exact = c_orig
                break
        if exact is not None:
            found[req] = exact
            continue

        # Pass 2: substring match as fallback
        candidates = [
            c_orig for c_low, c_orig in lowered.items()
            if req_low in c_low and not _excluded(c_low)
        ]
        if len(candidates) == 1:
            found[req] = candidates[0]
            continue

        chosen = candidates[0] if candidates else None
        found[req] = chosen
        print(f"[COLMAP] Requested '{req}' -> chosen column: {chosen}")
        if candidates:
            print(f"[COLMAP] Candidates for '{req}': {candidates}")
        else:
            print(f"[COLMAP][WARN] No candidates found for '{req}' using substring search.")
    return found
def choose_aggregation_key(df: pd.DataFrame) -> Optional[str]:
    """Return the first name from AGG_KEYS_PREFERENCE that is a column of df.

    Returns None when none of the preferred identifier columns exists.
    """
    available = df.columns
    return next((name for name in AGG_KEYS_PREFERENCE if name in available), None)
def aggregate_polyinfo_duplicates(df: pd.DataFrame, modality_cols: List[str], property_cols: List[str]) -> pd.DataFrame:
    """
    Optional noise reduction: group duplicate polymer entries and average properties.

    - The grouping key is picked by choose_aggregation_key(df).
    - Modalities are taken as "first" (they should be consistent per polymer key).
    - Properties are averaged (mean).
    - Rows whose key is missing or blank are dropped; they cannot be attributed
      to a polymer and must not be merged with each other.

    Args:
        df: Input table.
        modality_cols: Columns aggregated with "first".
        property_cols: Columns aggregated with "mean".

    Returns:
        The aggregated frame (key column plus the aggregated columns), or df
        unchanged when no key column exists or every key is missing/blank.
    """
    key = choose_aggregation_key(df)
    if key is None:
        print("[AGG] No aggregation key found; skipping duplicate aggregation.")
        return df

    # Filter missing keys BEFORE casting: astype(str) would turn NaN/None into
    # the literal "nan"/"None" and lump all such rows into one bogus group.
    has_key = df[key].notna()
    n_missing = int((~has_key).sum())
    if n_missing:
        print(f"[AGG] Dropping {n_missing} rows with missing '{key}' before aggregation.")
    df2 = df[has_key].copy()
    df2[key] = df2[key].astype(str)
    df2 = df2[df2[key].str.strip() != ""].copy()
    if len(df2) == 0:
        print("[AGG] Aggregation key exists but is empty; skipping duplicate aggregation.")
        return df

    # The key itself must not be aggregated: with as_index=False it is already
    # emitted as a column (the key may also be listed as a modality, e.g. "psmiles").
    agg_dict = {}
    for mc in modality_cols:
        if mc in df2.columns and mc != key:
            agg_dict[mc] = "first"
    for pc in property_cols:
        if pc in df2.columns and pc != key:
            agg_dict[pc] = "mean"

    if agg_dict:
        grouped = df2.groupby(key, as_index=False).agg(agg_dict)
    else:
        # Nothing to aggregate: return one row per key, sorted like groupby does.
        grouped = df2[[key]].drop_duplicates().sort_values(key).reset_index(drop=True)
    print(f"[AGG] Grouped by '{key}': {len(df)} rows -> {len(grouped)} unique keys")
    return grouped
def _sanitize_name(s: str) -> str:
    """Turn a property name into a filesystem-safe directory name.

    The text is stripped and lower-cased, every non-alphanumeric character is
    replaced by an underscore, runs of underscores are collapsed and leading or
    trailing underscores removed. An empty result becomes "property".
    """
    lowered = str(s).strip().lower()
    mapped = "".join(ch if ch.isalnum() else "_" for ch in lowered)
    # Splitting on "_" and dropping empty pieces both collapses repeated
    # underscores and trims them from the ends.
    collapsed = "_".join(piece for piece in mapped.split("_") if piece)
    return collapsed if collapsed else "property"
| # ============================================================================= | |
| # Multimodal backbone: encode + project + modality-aware fusion | |
| # ============================================================================= | |
class MultimodalContrastiveModel(nn.Module):
    """
    Multimodal encoder wrapper.

    1) Runs each available modality encoder:
       - GINE (graph)
       - SchNet (3D geometry)
       - Transformer FP encoder (Morgan bit sequence)
       - DeBERTa-based PSMILES encoder (sequence)
    2) Projects each modality embedding to a shared dim (emb_dim).
    3) L2-normalizes each modality embedding, applies dropout, then fuses via
       a masked mean across the modalities that are present for each sample.
    4) L2-normalizes the final fused embedding.

    Expected downstream usage:
        z = model(batch_mods, modality_mask=modality_mask)  # (B, emb_dim)
    """

    _KNOWN_MODALITIES = ("gine", "schnet", "fp", "psmiles")

    def __init__(
        self,
        gine_encoder: Optional[nn.Module] = None,
        schnet_encoder: Optional[nn.Module] = None,
        fp_encoder: Optional[nn.Module] = None,
        psmiles_encoder: Optional[nn.Module] = None,
        *,
        emb_dim: int = CL_EMB_DIM,
        modalities: Optional[List[str]] = None,
        dropout: float = 0.1,
        psmiles_tokenizer: Optional[Any] = None,
    ):
        """
        Args:
            gine_encoder / schnet_encoder / fp_encoder / psmiles_encoder:
                Modality encoders; None disables the modality.
            emb_dim: Width of the shared embedding space.
            modalities: Explicit list of enabled modality names; unknown names
                are ignored. None enables every modality whose encoder is given.
            dropout: Dropout probability applied to each modality embedding.
            psmiles_tokenizer: Tokenizer used by the encode_* helpers.
        """
        super().__init__()
        self.gine = gine_encoder
        self.schnet = schnet_encoder
        self.fp = fp_encoder
        self.psmiles = psmiles_encoder
        self.psm_tok = psmiles_tokenizer
        self.emb_dim = int(emb_dim)
        self.out_dim = self.emb_dim
        self.dropout = nn.Dropout(float(dropout))

        # Determine which modalities are enabled
        if modalities is None:
            encoders = {"gine": self.gine, "schnet": self.schnet, "fp": self.fp, "psmiles": self.psmiles}
            self.modalities = [name for name in self._KNOWN_MODALITIES if encoders[name] is not None]
        else:
            self.modalities = [m for m in modalities if m in self._KNOWN_MODALITIES]

        # Projection heads into the shared embedding space
        self.proj_gine = nn.Linear(NODE_EMB_DIM, self.emb_dim) if self.gine is not None else None
        self.proj_schnet = nn.Linear(SCHNET_HIDDEN, self.emb_dim) if self.schnet is not None else None
        # NOTE(review): 256 is presumably the fingerprint encoder's pooled output
        # width; confirm against the FP encoder before changing.
        self.proj_fp = nn.Linear(256, self.emb_dim) if self.fp is not None else None

        # Infer PSMILES hidden size if possible; fall back to DEBERTA_HIDDEN
        psm_in = None
        if self.psmiles is not None:
            if hasattr(self.psmiles, "out_dim"):
                try:
                    psm_in = int(self.psmiles.out_dim)
                except Exception:
                    psm_in = None
            if psm_in is None and hasattr(self.psmiles, "model") and hasattr(self.psmiles.model, "config"):
                try:
                    psm_in = int(self.psmiles.model.config.hidden_size)
                except Exception:
                    psm_in = None
        if psm_in is None:
            psm_in = int(DEBERTA_HIDDEN)
        self.proj_psmiles = nn.Linear(psm_in, self.emb_dim) if (self.psmiles is not None) else None

    def freeze_cl_encoders(self):
        """Freeze all modality encoders (optional for evaluation-only usage)."""
        for enc in (self.gine, self.schnet, self.fp, self.psmiles):
            if enc is None:
                continue
            enc.eval()
            for p in enc.parameters():
                p.requires_grad = False

    def _embed(self, emb: torch.Tensor, proj: Optional[nn.Module]) -> torch.Tensor:
        """Project (when a head exists), L2-normalize and apply dropout."""
        z = proj(emb) if proj is not None else emb
        z = F.normalize(z, dim=-1)
        return self.dropout(z)

    def _masked_mean_combine(self, zs: List[torch.Tensor], masks: List[torch.Tensor]) -> torch.Tensor:
        """
        Compute the sample-wise mean over available modalities.

        zs: list of modality embeddings, each (B, D)
        masks: list of modality presence masks, each (B,) bool
        returns: (B, D); rows with no present modality are all zeros.
        """
        if not zs:
            device = next(self.parameters()).device
            return torch.zeros((1, self.emb_dim), device=device)
        device = zs[0].device
        B = zs[0].size(0)
        sum_z = torch.zeros((B, self.emb_dim), device=device)
        count = torch.zeros((B, 1), device=device)
        for z, m in zip(zs, masks):
            present = m.to(device).view(B, 1).bool()
            # torch.where instead of z * mask: a NaN/Inf produced for an absent
            # modality must not leak into the fused vector (NaN * 0 is NaN).
            sum_z = sum_z + torch.where(present, z, torch.zeros_like(z))
            count = count + present.float()
        # clamp avoids division by zero for samples without any modality
        return sum_z / count.clamp(min=1.0)

    def forward(self, batch_mods: dict, modality_mask: Optional[dict] = None) -> torch.Tensor:
        """
        Encode and fuse a batch.

        batch_mods keys: 'gine', 'schnet', 'fp', 'psmiles' (each a dict of
            tensors or None).
        modality_mask: dict {modality_name: (B,) bool} describing presence;
            a modality without an entry is treated as present for every sample.

        Returns the L2-normalized fused embedding (B, emb_dim). When the batch
        size cannot be inferred a (1, emb_dim) zero tensor is returned; when no
        modality could be encoded a (B, emb_dim) zero tensor is returned.
        """
        device = next(self.parameters()).device

        # Infer batch size B: from the masks first, then from fp / psmiles ids
        B = None
        if modality_mask is not None:
            for v in modality_mask.values():
                if isinstance(v, torch.Tensor) and v.numel() > 0:
                    B = int(v.size(0))
                    break
        if B is None:
            for name in ("fp", "psmiles"):
                mod = batch_mods.get(name, None)
                ids = mod.get("input_ids", None) if mod is not None else None
                if isinstance(ids, torch.Tensor):
                    B = int(ids.size(0))
                    break
        if B is None:
            return torch.zeros((1, self.emb_dim), device=device)

        def _get_mask(name: str) -> torch.Tensor:
            if modality_mask is not None and name in modality_mask and isinstance(modality_mask[name], torch.Tensor):
                return modality_mask[name].to(device).bool()
            return torch.ones((B,), device=device, dtype=torch.bool)

        def _opt(t):
            """Move an optional tensor to the model device (None for non-tensors)."""
            return t.to(device) if isinstance(t, torch.Tensor) else None

        def _active(name: str, encoder, key: str):
            """Return the modality dict if it is enabled and carries data, else None."""
            if name not in self.modalities or encoder is None:
                return None
            mod = batch_mods.get(name, None)
            if mod is None:
                return None
            probe = mod.get(key, None)
            if isinstance(probe, torch.Tensor) and probe.numel() > 0:
                return mod
            return None

        zs = []
        ms = []

        # -------------------------
        # GINE (graph modality)
        # -------------------------
        g = _active("gine", self.gine, "z")
        if g is not None:
            edge_index = _opt(g.get("edge_index", None))
            if edge_index is None:
                edge_index = torch.empty((2, 0), dtype=torch.long, device=device)
            edge_attr = _opt(g.get("edge_attr", None))
            if edge_attr is None:
                edge_attr = torch.zeros((0, 3), dtype=torch.float, device=device)
            emb_g = self.gine(
                g["z"].to(device),
                _opt(g.get("chirality", None)),
                _opt(g.get("formal_charge", None)),
                edge_index,
                edge_attr,
                _opt(g.get("batch", None)),
            )
            zs.append(self._embed(emb_g, self.proj_gine))
            ms.append(_get_mask("gine"))

        # -------------------------
        # SchNet (3D geometry)
        # -------------------------
        s = _active("schnet", self.schnet, "z")
        if s is not None:
            pos = _opt(s.get("pos", None))
            if pos is None:
                pos = torch.zeros((0, 3), device=device)
            emb_s = self.schnet(s["z"].to(device), pos, _opt(s.get("batch", None)))
            zs.append(self._embed(emb_s, self.proj_schnet))
            ms.append(_get_mask("schnet"))

        # -------------------------
        # Fingerprint modality
        # -------------------------
        f = _active("fp", self.fp, "input_ids")
        if f is not None:
            emb_f = self.fp(f["input_ids"].to(device), _opt(f.get("attention_mask", None)))
            zs.append(self._embed(emb_f, self.proj_fp))
            ms.append(_get_mask("fp"))

        # -------------------------
        # PSMILES text modality
        # -------------------------
        p = _active("psmiles", self.psmiles, "input_ids")
        if p is not None:
            emb_p = self.psmiles(p["input_ids"].to(device), _opt(p.get("attention_mask", None)))
            zs.append(self._embed(emb_p, self.proj_psmiles))
            ms.append(_get_mask("psmiles"))

        # Fuse and normalize
        if not zs:
            return torch.zeros((B, self.emb_dim), device=device)
        z = self._masked_mean_combine(zs, ms)
        return F.normalize(z, dim=-1)

    def encode_psmiles(
        self,
        psmiles_list: List[str],
        max_len: int = PSMILES_MAX_LEN,
        batch_size: int = 64,
        device: str = DEVICE
    ) -> np.ndarray:
        """
        Embed PSMILES strings with the text encoder only.

        Puts the model in eval mode, moves it to `device` and runs without
        gradient tracking.

        Returns:
            Array (len(psmiles_list), emb_dim) of L2-normalized projected
            embeddings; shape (0, emb_dim) for an empty list.

        Raises:
            RuntimeError: tokenizer, PSMILES encoder or projection is missing.
        """
        self.eval()
        if self.psm_tok is None or self.psmiles is None or self.proj_psmiles is None:
            raise RuntimeError("PSMILES tokenizer/encoder/projection not available.")
        # Inputs are moved to `device`, so the model must live there as well
        # (same convention as encode_multimodal).
        dev = torch.device(device)
        self.to(dev)
        outs = []
        with torch.no_grad():  # inference only: no autograd graph needed
            for i in range(0, len(psmiles_list), batch_size):
                chunk = [str(x) for x in psmiles_list[i:i + batch_size]]
                enc = self.psm_tok(chunk, truncation=True, padding="max_length", max_length=max_len, return_tensors="pt")
                input_ids = enc["input_ids"].to(dev)
                attn = enc["attention_mask"].to(dev).bool()
                emb_p = self.psmiles(input_ids, attn)
                z = F.normalize(self.proj_psmiles(emb_p), dim=-1)
                outs.append(z.detach().cpu().numpy())
        return np.concatenate(outs, axis=0) if outs else np.zeros((0, self.emb_dim), dtype=np.float32)

    @staticmethod
    def _has_text(value) -> bool:
        """True when a record field holds a non-blank value."""
        return value is not None and str(value).strip() != ""

    def encode_multimodal(
        self,
        records: List[dict],
        batch_size: int = 32,
        device: str = DEVICE
    ) -> np.ndarray:
        """
        Convenience: multimodal embedding for records carrying
        'graph', 'geometry', 'fingerprints' and 'psmiles' fields.

        Missing modalities are handled sample-wise: a presence mask is built
        for every modality and passed to forward(), so a sample without e.g. a
        parsable graph is fused from its remaining modalities only.

        Puts the model in eval mode, moves it to `device` and runs without
        gradient tracking.

        Returns:
            Array (len(records), emb_dim); shape (0, emb_dim) for no records.
        """
        self.eval()
        dev = torch.device(device)
        self.to(dev)
        outs = []
        with torch.no_grad():  # inference only: no autograd graph needed
            for i in range(0, len(records), batch_size):
                chunk = records[i:i + batch_size]
                n_chunk = len(chunk)

                # PSMILES batch
                psmiles_texts = [str(r.get("psmiles", "")) for r in chunk]
                psm_present = torch.tensor([self._has_text(r.get("psmiles", None)) for r in chunk], dtype=torch.bool)
                p_enc = None
                if self.psm_tok is not None:
                    p_enc = self.psm_tok(psmiles_texts, truncation=True, padding="max_length", max_length=PSMILES_MAX_LEN, return_tensors="pt")

                # FP batch (always stacked)
                fp_ids, fp_attn = [], []
                for r in chunk:
                    f = _parse_fingerprints(r.get("fingerprints", None), fp_len=FP_LENGTH)
                    fp_ids.append(f["input_ids"])
                    fp_attn.append(f["attention_mask"])
                fp_ids = torch.stack(fp_ids, dim=0)
                fp_attn = torch.stack(fp_attn, dim=0)
                # Assumes _parse_fingerprints returns an all-zero attention mask
                # for a missing fingerprint - TODO confirm. If it never does,
                # this mask is all True and fp is always counted as present.
                fp_present = fp_attn.reshape(n_chunk, -1).bool().any(dim=1)

                # GINE packed batching: node tensors are concatenated and edge
                # indices shifted by the number of nodes packed so far.
                gine_all = {"z": [], "chirality": [], "formal_charge": [], "edge_index": [], "edge_attr": [], "batch": []}
                gine_present = torch.zeros(n_chunk, dtype=torch.bool)
                node_offset = 0
                for bi, r in enumerate(chunk):
                    g = _parse_graph_for_gine(r.get("graph", None))
                    if g is None or g["z"].numel() == 0:
                        continue
                    gine_present[bi] = True
                    n = g["z"].size(0)
                    gine_all["z"].append(g["z"])
                    gine_all["chirality"].append(g["chirality"])
                    gine_all["formal_charge"].append(g["formal_charge"])
                    gine_all["batch"].append(torch.full((n,), bi, dtype=torch.long))
                    ei = g["edge_index"]
                    ea = g["edge_attr"]
                    if ei is not None and ei.numel() > 0:
                        gine_all["edge_index"].append(ei + node_offset)
                        gine_all["edge_attr"].append(ea)
                    node_offset += n
                gine_batch = None
                if len(gine_all["z"]) > 0:
                    if len(gine_all["edge_index"]) > 0:
                        ei_b = torch.cat(gine_all["edge_index"], dim=1)
                        ea_b = torch.cat(gine_all["edge_attr"], dim=0)
                    else:
                        ei_b = torch.empty((2, 0), dtype=torch.long)
                        ea_b = torch.zeros((0, 3), dtype=torch.float)
                    gine_batch = {
                        "z": torch.cat(gine_all["z"], dim=0),
                        "chirality": torch.cat(gine_all["chirality"], dim=0),
                        "formal_charge": torch.cat(gine_all["formal_charge"], dim=0),
                        "edge_index": ei_b,
                        "edge_attr": ea_b,
                        "batch": torch.cat(gine_all["batch"], dim=0),
                    }

                # SchNet packed batching
                sch_all_z, sch_all_pos, sch_all_batch = [], [], []
                schnet_present = torch.zeros(n_chunk, dtype=torch.bool)
                for bi, r in enumerate(chunk):
                    s = _parse_geometry_for_schnet(r.get("geometry", None))
                    if s is None or s["z"].numel() == 0:
                        continue
                    schnet_present[bi] = True
                    n = s["z"].size(0)
                    sch_all_z.append(s["z"])
                    sch_all_pos.append(s["pos"])
                    sch_all_batch.append(torch.full((n,), bi, dtype=torch.long))
                schnet_batch = None
                if len(sch_all_z) > 0:
                    schnet_batch = {
                        "z": torch.cat(sch_all_z, dim=0),
                        "pos": torch.cat(sch_all_pos, dim=0),
                        "batch": torch.cat(sch_all_batch, dim=0),
                    }

                batch_mods = {
                    "gine": gine_batch,
                    "schnet": schnet_batch,
                    "fp": {"input_ids": fp_ids, "attention_mask": fp_attn},
                    "psmiles": {"input_ids": p_enc["input_ids"], "attention_mask": p_enc["attention_mask"]} if p_enc is not None else None
                }
                modality_mask = {
                    "gine": gine_present,
                    "schnet": schnet_present,
                    "fp": fp_present,
                    "psmiles": psm_present,
                }
                # NOTE: This script uses forward() as the encoder entry point.
                z = self.forward(batch_mods, modality_mask=modality_mask)
                outs.append(z.detach().cpu().numpy())
        return np.concatenate(outs, axis=0) if outs else np.zeros((0, self.emb_dim), dtype=np.float32)
| # ============================================================================= | |
| # Tokenizer setup | |
| # ============================================================================= | |
# Tokenizer model file handed to build_psmiles_tokenizer
# (presumably a SentencePiece model, judging by the name - confirm).
SPM_MODEL = "/path/to/spm.model"

# Module-level PSMILES tokenizer, built once at import time.
tokenizer = build_psmiles_tokenizer(
    spm_path=SPM_MODEL,
    max_len=PSMILES_MAX_LEN,
)
| # ============================================================================= | |
| # Dataset: single-task property prediction (with modality parsing) | |
| # ============================================================================= | |
class PolymerPropertyDataset(Dataset):
    """
    Dataset that prepares one sample with up to four modalities:
    - graph (for GINE)
    - geometry (for SchNet)
    - fingerprints (for FP transformer)
    - psmiles text (for DeBERTa encoder)

    Each record is a dict-like object with optional 'graph', 'geometry',
    'fingerprints' and 'psmiles' fields plus 'target_scaled'. Parsing is best
    effort: a modality that is absent or cannot be parsed is replaced by an
    empty/zero placeholder instead of raising.

    Target is a single scalar per sample (already scaled externally).
    """

    def __init__(self, data_list, tokenizer, max_length=128):
        """
        Args:
            data_list: Sequence of records.
            tokenizer: PSMILES tokenizer; None disables the text modality.
            max_length: Padded/truncated PSMILES token length.
        """
        self.data_list = data_list
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data_list)

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _bit(value) -> int:
        """Map one raw fingerprint entry to 0 or 1.

        Strings count as 1 when they read "1", "True" or "true" (surrounding
        quotes ignored); numbers and booleans count as 1 when non-zero (NaN is
        0); everything else is 0.
        """
        if isinstance(value, str):
            cleaned = value.strip().strip('"').strip("'")
            return 1 if cleaned in ("1", "True", "true") else 0
        if isinstance(value, (bool, np.bool_, int, np.integer)):
            return 1 if int(value) != 0 else 0
        if isinstance(value, (float, np.floating)):
            return 1 if (value == value and value != 0) else 0
        return 0

    @staticmethod
    def _normalize_bits(bits, length) -> List[int]:
        """Return exactly `length` 0/1 bits: truncated, or zero-padded at the end."""
        out = []
        for b in bits:
            if len(out) >= length:
                break
            out.append(PolymerPropertyDataset._bit(b))
        out.extend([0] * (length - len(out)))
        return out

    @staticmethod
    def _parse_edge_indices(raw) -> Tuple[List[int], List[int]]:
        """Parse edge indices given as [[srcs], [dsts]] or as a list of [src, dst] pairs.

        A non-empty list whose first element is a 2-element list starting with
        an int is read as a list of pairs; any other non-empty list is read as
        [[srcs], [dsts]]. Returns two empty lists when parsing fails.
        NOTE(review): a [[srcs], [dsts]] input with exactly two edges is
        indistinguishable from two pairs and is read as pairs.
        """
        if not (isinstance(raw, list) and len(raw) > 0):
            return [], []
        first = raw[0]
        try:
            if isinstance(first, list) and len(first) == 2 and isinstance(first[0], int):
                return [int(p[0]) for p in raw], [int(p[1]) for p in raw]
            return [int(x) for x in raw[0]], [int(x) for x in raw[1]]
        except Exception:
            return [], []

    # ------------------------------------------------------------------
    # Per-modality parsers (each returns a tensor dict or None)
    # ------------------------------------------------------------------
    def _build_gine(self, data):
        """Graph -> GINE tensors; None when absent, unparsable or edge-less."""
        if not ('graph' in data and data['graph']):
            return None
        try:
            graph_field = json.loads(data['graph']) if isinstance(data['graph'], str) else data['graph']
            node_features = safe_get(graph_field, "node_features", None)
            if not node_features:
                return None

            atomic_nums = []
            chirality_vals = []
            formal_charges = []
            for nf in node_features:
                an = safe_get(nf, "atomic_num", None)
                if an is None:
                    an = safe_get(nf, "atomic_number", 0)
                ch = safe_get(nf, "chirality", 0)
                fc = safe_get(nf, "formal_charge", 0)
                try:
                    atomic_nums.append(int(an))
                except Exception:
                    atomic_nums.append(0)
                chirality_vals.append(float(ch))
                formal_charges.append(float(fc))

            edge_indices_raw = safe_get(graph_field, "edge_indices", None)
            edge_features_raw = safe_get(graph_field, "edge_features", None)
            edge_index = None
            edge_attr = None
            if edge_indices_raw is None:
                # Fallback: derive edges from the adjacency matrix
                adj_mat = safe_get(graph_field, "adjacency_matrix", None)
                if adj_mat:
                    srcs = []
                    dsts = []
                    for i_r, row_adj in enumerate(adj_mat):
                        for j, val2 in enumerate(row_adj):
                            if val2:
                                srcs.append(i_r)
                                dsts.append(j)
                    if len(srcs) > 0:
                        edge_index = [srcs, dsts]
                        edge_attr = [[0.0, 0.0, 0.0] for _ in range(len(srcs))]
            else:
                srcs, dsts = self._parse_edge_indices(edge_indices_raw)
                if len(srcs) > 0:
                    edge_index = [srcs, dsts]
                    # edge_features: map the known fields; otherwise zeros
                    if edge_features_raw and isinstance(edge_features_raw, list):
                        edge_attr = []
                        for ef in edge_features_raw:
                            bt = safe_get(ef, "bond_type", 0)
                            st = safe_get(ef, "stereo", 0)
                            ic = safe_get(ef, "is_conjugated", False)
                            edge_attr.append((float(bt), float(st), float(1.0 if ic else 0.0)))
                    else:
                        edge_attr = [[0.0, 0.0, 0.0] for _ in range(len(srcs))]

            if edge_index is None:
                return None
            return {
                'z': torch.tensor(atomic_nums, dtype=torch.long),
                'chirality': torch.tensor(chirality_vals, dtype=torch.float),
                'formal_charge': torch.tensor(formal_charges, dtype=torch.float),
                'edge_index': torch.tensor(edge_index, dtype=torch.long),
                'edge_attr': torch.tensor(edge_attr, dtype=torch.float)
            }
        except Exception:
            # Deliberate best effort: a malformed graph only disables this modality.
            return None

    def _build_schnet(self, data):
        """Geometry -> SchNet tensors from the "best_conformer" entry, or None."""
        if not ('geometry' in data and data['geometry']):
            return None
        try:
            geom = json.loads(data['geometry']) if isinstance(data['geometry'], str) else data['geometry']
            conf = geom.get("best_conformer") if isinstance(geom, dict) else None
            if not conf:
                return None
            atomic = conf.get("atomic_numbers", [])
            coords = conf.get("coordinates", [])
            if len(atomic) == len(coords) and len(atomic) > 0:
                return {
                    'z': torch.tensor(atomic, dtype=torch.long),
                    'pos': torch.tensor(coords, dtype=torch.float)
                }
            return None
        except Exception:
            return None

    def _build_fp(self, data):
        """Fingerprints -> FP transformer inputs (FP_LENGTH bits of 0/1), or None.

        Accepts a dict with "morgan_r3_bits", a list of bits, a JSON string of
        either, or a comma-separated string. Bits are normalised to 0/1 in
        every form, so ids never collide with the mask token id.
        """
        if not ('fingerprints' in data and data['fingerprints']):
            return None
        try:
            fpval = data['fingerprints']
            if fpval is None or (isinstance(fpval, str) and fpval.strip() == ""):
                return None
            try:
                fp_json = json.loads(fpval) if isinstance(fpval, str) else fpval
            except Exception:
                try:
                    # Tolerate Python-repr style strings with single quotes
                    fp_json = json.loads(str(fpval).replace("'", '"'))
                except Exception:
                    # Last resort: plain comma-separated bits
                    fp_json = [p.strip() for p in str(fpval).split(",")]

            if isinstance(fp_json, dict):
                raw_bits = safe_get(fp_json, "morgan_r3_bits", None)
                bits = [0] * FP_LENGTH if raw_bits is None else self._normalize_bits(raw_bits, FP_LENGTH)
            elif isinstance(fp_json, list):
                bits = self._normalize_bits(fp_json, FP_LENGTH)
            else:
                bits = [0] * FP_LENGTH
            return {
                'input_ids': torch.tensor(bits, dtype=torch.long),
                'attention_mask': torch.ones(FP_LENGTH, dtype=torch.bool)
            }
        except Exception:
            return None

    def _build_psmiles(self, data):
        """PSMILES -> tokenizer inputs padded/truncated to self.max_length, or None."""
        if not ('psmiles' in data and data['psmiles'] and self.tokenizer is not None):
            return None
        try:
            enc = self.tokenizer(
                str(data['psmiles']),
                truncation=True,
                padding="max_length",
                max_length=self.max_length
            )
            return {
                'input_ids': torch.tensor(enc["input_ids"], dtype=torch.long),
                'attention_mask': torch.tensor(enc["attention_mask"], dtype=torch.bool)
            }
        except Exception:
            return None

    def __getitem__(self, idx):
        """Return the sample dict with keys 'gine', 'schnet', 'fp', 'psmiles', 'target'."""
        data = self.data_list[idx]

        gine_data = self._build_gine(data)
        schnet_data = self._build_schnet(data)
        fp_data = self._build_fp(data)
        psmiles_data = self._build_psmiles(data)

        # ---------------------------------------------------------------------
        # Fill defaults for missing modalities
        # ---------------------------------------------------------------------
        if gine_data is None:
            gine_data = {
                'z': torch.tensor([], dtype=torch.long),
                'chirality': torch.tensor([], dtype=torch.float),
                'formal_charge': torch.tensor([], dtype=torch.float),
                'edge_index': torch.tensor([[], []], dtype=torch.long),
                'edge_attr': torch.zeros((0, 3), dtype=torch.float)
            }
        if schnet_data is None:
            schnet_data = {
                'z': torch.tensor([], dtype=torch.long),
                'pos': torch.tensor([], dtype=torch.float)
            }
        if fp_data is None:
            fp_data = {
                'input_ids': torch.zeros(FP_LENGTH, dtype=torch.long),
                'attention_mask': torch.zeros(FP_LENGTH, dtype=torch.bool)
            }
        if psmiles_data is None:
            # Same length as real encodings so samples can be stacked together
            psmiles_data = {
                'input_ids': torch.zeros(self.max_length, dtype=torch.long),
                'attention_mask': torch.zeros(self.max_length, dtype=torch.bool)
            }

        # Single-task regression target (already scaled)
        target_scaled = float(data.get("target_scaled", 0.0))
        return {
            'gine': gine_data,
            'schnet': schnet_data,
            'fp': fp_data,
            'psmiles': psmiles_data,
            'target': torch.tensor(target_scaled, dtype=torch.float32),
        }
| # ============================================================================= | |
| # Collate: pack variable-sized graph/3D into batch tensors + modality masks | |
| # ============================================================================= | |
def multimodal_collate_fn(batch):
    """
    Merge a list of per-sample dicts into one minibatch dict.

    - GINE: node tensors of all samples are concatenated; `batch` holds the
      sample index of every node and edge indices are shifted by the running
      node offset.
    - SchNet: atoms/coordinates of samples that have any atoms are concatenated;
      `batch` holds the sample index of every atom.
    - FP / PSMILES: `input_ids` and `attention_mask` are stacked along dim 0.
    - modality_mask: one boolean per sample and modality (True = available).
    - target: per-sample targets stacked along dim 0.
    """
    num_samples = len(batch)

    # -------------------------
    # GINE packing
    # -------------------------
    z_parts, chirality_parts, charge_parts = [], [], []
    edge_index_parts, edge_attr_parts = [], []
    node_to_sample = []
    gine_present = []
    offset = 0
    for sample_idx, sample in enumerate(batch):
        graph = sample["gine"]
        num_nodes = graph["z"].size(0)
        gine_present.append(num_nodes > 0)
        z_parts.append(graph["z"])
        chirality_parts.append(graph["chirality"])
        charge_parts.append(graph["formal_charge"])
        node_to_sample.append(torch.full((num_nodes,), sample_idx, dtype=torch.long))
        edge_index = graph["edge_index"]
        if edge_index is not None and edge_index.numel() > 0:
            # Shift node ids so they index into the concatenated node tensors.
            edge_index_parts.append(edge_index + offset)
            edge_attr_parts.append(
                match_edge_attr_to_index(edge_index, graph["edge_attr"], target_dim=3)
            )
        offset += num_nodes

    if z_parts:
        z_batch = torch.cat(z_parts, dim=0)
        ch_batch = torch.cat(chirality_parts, dim=0)
        fc_batch = torch.cat(charge_parts, dim=0)
        batch_batch = torch.cat(node_to_sample, dim=0)
    else:
        z_batch = torch.tensor([], dtype=torch.long)
        ch_batch = torch.tensor([], dtype=torch.float)
        fc_batch = torch.tensor([], dtype=torch.float)
        batch_batch = torch.tensor([], dtype=torch.long)

    if edge_index_parts:
        edge_index_batched = torch.cat(edge_index_parts, dim=1)
        edge_attr_batched = torch.cat(edge_attr_parts, dim=0)
    else:
        edge_index_batched = torch.empty((2, 0), dtype=torch.long)
        edge_attr_batched = torch.zeros((0, 3), dtype=torch.float)

    # -------------------------
    # SchNet packing
    # -------------------------
    atom_parts, pos_parts, atom_to_sample = [], [], []
    schnet_present = [False] * num_samples
    for sample_idx, sample in enumerate(batch):
        atoms, coords = sample["schnet"]["z"], sample["schnet"]["pos"]
        if atoms.numel() == 0:
            continue
        schnet_present[sample_idx] = True
        atom_parts.append(atoms)
        pos_parts.append(coords)
        atom_to_sample.append(torch.full((atoms.size(0),), sample_idx, dtype=torch.long))

    if atom_parts:
        s_z_batch = torch.cat(atom_parts, dim=0)
        s_pos_batch = torch.cat(pos_parts, dim=0)
        s_batch_batch = torch.cat(atom_to_sample, dim=0)
    else:
        s_z_batch = torch.tensor([], dtype=torch.long)
        s_pos_batch = torch.tensor([], dtype=torch.float)
        s_batch_batch = torch.tensor([], dtype=torch.long)

    # -------------------------
    # FP / PSMILES stacking
    # -------------------------
    def _stack(modality, field):
        return torch.stack([sample[modality][field] for sample in batch], dim=0)

    fp_ids = _stack("fp", "input_ids")
    fp_attn = _stack("fp", "attention_mask")
    p_ids = _stack("psmiles", "input_ids")
    p_attn = _stack("psmiles", "attention_mask")

    # A sample counts as having the modality when any attention position is on.
    fp_present = (fp_attn.sum(dim=1) > 0).tolist()
    psmiles_present = (p_attn.sum(dim=1) > 0).tolist()

    # Target
    target = torch.stack([sample["target"] for sample in batch], dim=0)

    # Presence mask for fusion (per-sample modality availability)
    modality_mask = {
        name: torch.tensor(flags, dtype=torch.bool)
        for name, flags in (
            ("gine", gine_present),
            ("schnet", schnet_present),
            ("fp", fp_present),
            ("psmiles", psmiles_present),
        )
    }

    return {
        "gine": {
            "z": z_batch,
            "chirality": ch_batch,
            "formal_charge": fc_batch,
            "edge_index": edge_index_batched,
            "edge_attr": edge_attr_batched,
            "batch": batch_batch,
        },
        "schnet": {"z": s_z_batch, "pos": s_pos_batch, "batch": s_batch_batch},
        "fp": {"input_ids": fp_ids, "attention_mask": fp_attn},
        "psmiles": {"input_ids": p_ids, "attention_mask": p_attn},
        "target": target,
        "modality_mask": modality_mask,
    }
| # ============================================================================= | |
| # Single-task regressor head | |
| # ============================================================================= | |
class PolyFPropertyRegressor(nn.Module):
    """
    Regression head stacked on the multimodal encoder.

    The wrapped `polyf_model` produces the fused embedding; a two-layer MLP
    (Linear -> ReLU -> Dropout -> Linear) maps it to one scalar per sample.
    Submodule names (`polyf`, `head`) are part of the saved state_dict keys.
    """

    def __init__(self, polyf_model: MultimodalContrastiveModel, emb_dim: int = POLYF_EMB_DIM, dropout: float = 0.1):
        super().__init__()
        hidden_dim = emb_dim // 2
        self.polyf = polyf_model
        layers = [
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, batch_mods, modality_mask=None):
        """Encode the batch and return the head output with its last dim squeezed."""
        fused = self.polyf(batch_mods, modality_mask=modality_mask)  # expected (B, d)
        return self.head(fused).squeeze(-1)  # (B,)
| # ============================================================================= | |
| # Training / evaluation helpers | |
| # ============================================================================= | |
def compute_metrics(y_true, y_pred):
    """
    Return MSE, RMSE, MAE and R2 of `y_pred` against `y_true` as plain floats.

    Metrics are reported in whatever units the inputs are given in; callers in
    this file pass values already converted back to original units.
    """
    mse = mean_squared_error(y_true, y_pred)
    scores = {
        "mse": mse,
        "rmse": math.sqrt(mse),
        "mae": mean_absolute_error(y_true, y_pred),
        "r2": r2_score(y_true, y_pred),
    }
    # Cast numpy scalars to builtin floats so the result is JSON-friendly.
    return {name: float(value) for name, value in scores.items()}
def _move_batch_to_device(batch, device):
    """
    Return a copy of a collated batch with every tensor moved to `device`.

    `batch["target"]` is a tensor; every other entry (modalities and
    `modality_mask`) is a dict whose tensor values are moved and whose
    non-tensor values are passed through unchanged.
    """
    moved = {}
    for key, value in batch.items():
        if key == "target":
            moved[key] = value.to(device)
        else:
            moved[key] = {
                name: item.to(device) if isinstance(item, torch.Tensor) else item
                for name, item in value.items()
            }
    return moved


def _forward_loss(model, batch):
    """Run the model on a device-resident batch; return (pred, target, mse_loss)."""
    y = batch["target"]  # (B,)
    modality_mask = batch.get("modality_mask", None)
    batch_mods = {k: v for k, v in batch.items() if k not in ("target", "modality_mask")}
    pred = model(batch_mods, modality_mask=modality_mask)
    return pred, y, F.mse_loss(pred, y)


def train_one_epoch(model, dataloader, optimizer, device):
    """
    Run one epoch of supervised regression training (MSE loss).

    Returns the sample-weighted mean training loss, or 0.0 when the
    dataloader yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_n = 0
    for batch in dataloader:
        batch = _move_batch_to_device(batch, device)
        _, y, loss = _forward_loss(model, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        bs = int(y.size(0))
        # Weight by batch size so a short last batch does not skew the mean.
        total_loss += float(loss.item()) * bs
        total_n += bs
    return total_loss / max(1, total_n)


def evaluate(model, dataloader, device):
    """
    Evaluate the model on a dataloader without tracking gradients.

    Returns `(avg_loss, preds, trues)` where `preds`/`trues` are numpy arrays
    of scaled predictions/targets concatenated over all batches, or
    `(None, None, None)` when the dataloader yields no samples.
    """
    model.eval()
    preds = []
    trues = []
    total_loss = 0.0
    total_n = 0
    # No autograd graph is needed for evaluation; building one only costs memory.
    with torch.no_grad():
        for batch in dataloader:
            batch = _move_batch_to_device(batch, device)
            pred, y, loss = _forward_loss(model, batch)
            bs = int(y.size(0))
            total_loss += float(loss.item()) * bs
            total_n += bs
            preds.append(pred.detach().cpu().numpy())
            trues.append(y.detach().cpu().numpy())
    if total_n == 0:
        return None, None, None
    preds = np.concatenate(preds, axis=0)
    trues = np.concatenate(trues, axis=0)
    return float(total_loss / total_n), preds, trues
| # ============================================================================= | |
| # Pretrained loading helpers | |
| # ============================================================================= | |
def _load_encoder_checkpoint(encoder, ckpt_dir, loaded_label, failed_label):
    """Best-effort, non-strict load of `<ckpt_dir>/pytorch_model.bin` into `encoder`."""
    ckpt = os.path.join(ckpt_dir, "pytorch_model.bin")
    if not os.path.exists(ckpt):
        return
    try:
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        encoder.load_state_dict(torch.load(ckpt, map_location="cpu"), strict=False)
        print(f"[LOAD] {loaded_label}: {ckpt}")
    except Exception as e:
        # A failed load is deliberately non-fatal: the encoder keeps its fresh init.
        print(f"[LOAD][WARN] Could not load {failed_label}: {e}")


def load_pretrained_multimodal(pretrained_path: str) -> MultimodalContrastiveModel:
    """
    Build the four modality encoders, wrap them in a MultimodalContrastiveModel
    and load whatever pretrained weights are available.

    Weights are taken, in this order, from:
      1. the per-modality checkpoints under the BEST_*_DIR directories;
      2. `pretrained_path/pytorch_model.bin`, restricted to keys that exist in
         the model with a matching shape.
    Every load is best-effort: failures are printed and the affected module
    keeps its current parameters.
    """
    # ------------------------- GINE encoder -------------------------
    gine_encoder = GineEncoder(
        node_emb_dim=NODE_EMB_DIM,
        edge_emb_dim=EDGE_EMB_DIM,
        num_layers=NUM_GNN_LAYERS,
        max_atomic_z=MAX_ATOMIC_Z,
    )
    _load_encoder_checkpoint(gine_encoder, BEST_GINE_DIR, "GINE weights", "GINE weights")

    # ------------------------- SchNet encoder -------------------------
    schnet_encoder = NodeSchNetWrapper(
        hidden_channels=SCHNET_HIDDEN,
        num_interactions=SCHNET_NUM_INTERACTIONS,
        num_gaussians=SCHNET_NUM_GAUSSIANS,
        cutoff=SCHNET_CUTOFF,
        max_num_neighbors=SCHNET_MAX_NEIGHBORS,
    )
    _load_encoder_checkpoint(schnet_encoder, BEST_SCHNET_DIR, "SchNet weights", "SchNet weights")

    # ------------------------- Fingerprint encoder -------------------------
    fp_encoder = FingerprintEncoder(
        vocab_size=VOCAB_SIZE_FP,
        hidden_dim=256,
        seq_len=FP_LENGTH,
        num_layers=4,
        nhead=8,
        dim_feedforward=1024,
        dropout=0.1,
    )
    _load_encoder_checkpoint(fp_encoder, BEST_FP_DIR, "FP encoder weights", "fingerprint weights")

    # ------------------------- PSMILES encoder -------------------------
    psmiles_encoder = None
    if os.path.isdir(BEST_PSMILES_DIR):
        try:
            psmiles_encoder = PSMILESDebertaEncoder(model_dir_or_name=BEST_PSMILES_DIR)
            print(f"[LOAD] PSMILES encoder: {BEST_PSMILES_DIR}")
        except Exception as e:
            print(f"[LOAD][WARN] Could not load PSMILES encoder from dir: {e}")
    if psmiles_encoder is None:
        # Fallback: fresh encoder sized from the module-level tokenizer's vocab.
        try:
            fallback_vocab = int(getattr(tokenizer, "vocab_size", 300))
            psmiles_encoder = PSMILESDebertaEncoder(
                model_dir_or_name=None,
                vocab_fallback=fallback_vocab,
            )
            print("[LOAD] PSMILES encoder: initialized fallback (no pretrained dir).")
        except Exception as e:
            print(f"[LOAD][WARN] Could not initialize PSMILES encoder: {e}")

    # Build multimodal wrapper
    multimodal_model = MultimodalContrastiveModel(
        gine_encoder,
        schnet_encoder,
        fp_encoder,
        psmiles_encoder,
        emb_dim=POLYF_EMB_DIM,
        modalities=["gine", "schnet", "fp", "psmiles"],
    )

    # ------------------------- Optional: full multimodal checkpoint -------------------------
    ckpt_path = os.path.join(pretrained_path, "pytorch_model.bin")
    if not os.path.isfile(ckpt_path):
        print(f"[LOAD] No multimodal checkpoint found at: {ckpt_path}")
        return multimodal_model

    try:
        state = torch.load(ckpt_path, map_location="cpu")
        model_state = multimodal_model.state_dict()
        # Keep only entries the model knows about and whose shapes agree.
        filtered_state = {
            key: tensor
            for key, tensor in state.items()
            if key in model_state and model_state[key].shape == tensor.shape
        }
        summarize_state_dict_load(state, model_state, filtered_state)
        missing, unexpected = multimodal_model.load_state_dict(filtered_state, strict=False)
        print(f"[LOAD] Multimodal checkpoint: {ckpt_path}")
        print(f"[LOAD] load_state_dict -> missing={len(missing)} unexpected={len(unexpected)}")
        if missing:
            print("[LOAD] Missing keys (sample):", missing[:50])
        if unexpected:
            print("[LOAD] Unexpected keys (sample):", unexpected[:50])
    except Exception as e:
        print(f"[LOAD][WARN] Failed to load multimodal pretrained weights: {e}")
    return multimodal_model
| # ============================================================================= | |
| # Downstream: sample construction + CV training loop | |
| # ============================================================================= | |
def _is_missing_cell(value) -> bool:
    """Return True for None, NaN/NA/NaT, empty or whitespace-only strings and other falsy cells."""
    if value is None:
        return True
    # NaN is truthy and stringifies to "nan", so it must be detected explicitly.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return True
    return (not value) or str(value).strip() == ""


def build_samples_for_property(df: pd.DataFrame, prop_col: str) -> List[dict]:
    """
    Construct training samples for a single property.

    A row is kept when
      - at least one of the modality columns ('graph', 'geometry',
        'fingerprints', 'psmiles') holds a non-missing value, and
      - the value in `prop_col` converts to a finite float.

    Missing modality cells (absent column, None, NaN, blank string) are stored
    as '' so downstream truthiness checks see them as absent rather than as
    the string "nan". The raw target is stored under 'target_raw'; scaling is
    left to the caller.
    """
    modality_cols = ('graph', 'geometry', 'fingerprints', 'psmiles')
    samples = []
    for _, row in df.iterrows():
        sample = {}
        has_modality = False
        for col in modality_cols:
            cell = row.get(col, '')
            if _is_missing_cell(cell):
                sample[col] = ''
            else:
                sample[col] = cell
                has_modality = True
        if not has_modality:
            continue

        try:
            y = float(row.get(prop_col, np.nan))
        except (TypeError, ValueError, OverflowError):
            continue
        # Rejects NaN as well as +/-inf (e.g. parsed from the strings "nan"/"inf").
        if not math.isfinite(y):
            continue

        sample['target_raw'] = y
        samples.append(sample)
    return samples
def _copy_with_scaled_targets(samples, indices, scaler):
    """Deep-copy `samples[i]` for each index and attach `target_scaled` from the fitted scaler."""
    subset = [copy.deepcopy(samples[i]) for i in indices]
    if subset:
        raw = np.array([s["target_raw"] for s in subset]).reshape(-1, 1)
        # One vectorised transform instead of one sklearn call per sample.
        for s, scaled in zip(subset, scaler.transform(raw).ravel()):
            s["target_scaled"] = float(scaled)
    return subset


def _train_with_early_stopping(model, dl_train, dl_val, optimizer, scheduler, pname, fold_no):
    """Train for up to NUM_EPOCHS with patience on val loss; return (best_val, best_state)."""
    best_val = float("inf")
    best_state = None
    no_improve = 0
    for epoch in range(1, NUM_EPOCHS + 1):
        tr_loss = train_one_epoch(model, dl_train, optimizer, DEVICE)
        va_loss, _, _ = evaluate(model, dl_val, DEVICE)
        # An empty validation loader yields None; treat it as "never improves".
        va_loss = va_loss if va_loss is not None else float("inf")
        scheduler.step()
        print(f"[{pname}] fold {fold_no}/{NUM_RUNS} epoch {epoch:03d} | train={tr_loss:.6f} | val={va_loss:.6f}")
        if va_loss < best_val - 1e-8:
            best_val = va_loss
            no_improve = 0
            # Snapshot on CPU so the best weights do not occupy device memory.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            no_improve += 1
            if no_improve >= PATIENCE:
                print(f"[{pname}] fold {fold_no}: early stopping (patience={PATIENCE}) at epoch {epoch}.")
                break
    return best_val, best_state


def _aggregate_fold_metrics(run_metrics):
    """Return {metric: {"mean", "std"}} over folds (population std), or None if there are no folds."""
    if not run_metrics:
        return None
    agg = {}
    for name in ("r2", "mae", "rmse", "mse"):
        values = [m[name] for m in run_metrics]
        agg[name] = {"mean": float(np.mean(values)), "std": float(np.std(values, ddof=0))}
    return agg


def _save_best_fold(pname, payload):
    """Write the best fold's checkpoint bundle (.pt) and its metadata (.json) under BEST_WEIGHTS_DIR."""
    os.makedirs(BEST_WEIGHTS_DIR, exist_ok=True)
    prop_dir = os.path.join(BEST_WEIGHTS_DIR, _sanitize_name(pname))
    os.makedirs(prop_dir, exist_ok=True)
    torch.save(payload, os.path.join(prop_dir, "best_run_checkpoint.pt"))
    # The metadata file carries everything except the (large) weights.
    meta = {k: v for k, v in payload.items() if k != "model_state_dict"}
    with open(os.path.join(prop_dir, "best_run_metadata.json"), "w") as fh:
        fh.write(json.dumps(make_json_serializable(meta), indent=2))
    print(f"[BEST] Saved best fold for '{pname}' -> {prop_dir}")
    print(f"[BEST] best_run={payload['best_run']} best_test_r2={payload['test_metrics'].get('r2', None)}")


def run_polyf_downstream(property_list: List[str], property_cols: List[str], df_raw: pd.DataFrame,
                         pretrained_path: str, output_file: str):
    """
    Downstream evaluation of the pretrained multimodal model.

    For each (property name, column) pair:
      - build samples from the de-duplicated dataframe; skip the property when
        fewer than 50 samples are available;
      - run NUM_RUNS-fold cross-validation: each fold splits into
        trainval/test, trainval into train/val, fits a StandardScaler on the
        train targets only, fine-tunes encoder + head with early stopping on
        validation loss, and scores the held-out fold in original units;
      - append one JSON line per fold and one `AGG_PROPERTY:` line per
        property to `output_file`;
      - save the fold with the highest test R2 under BEST_WEIGHTS_DIR.

    Returns a dict with per-property fold records and mean/std aggregates.
    """
    os.makedirs(pretrained_path, exist_ok=True)

    # Optional duplicate aggregation (noise reduction)
    modality_cols = ["graph", "geometry", "fingerprints", "psmiles"]
    df_proc = aggregate_polyinfo_duplicates(df_raw, modality_cols=modality_cols, property_cols=property_cols)

    all_results = {"per_property": {}, "mode": "POLYF_MATCHED_SINGLE_TASK"}

    for pname, pcol in zip(property_list, property_cols):
        samples = build_samples_for_property(df_proc, pcol)
        print(f"[DATA] {pname}: n_samples={len(samples)}")
        if len(samples) < 200:
            print(f"[DATA][WARN] '{pname}' has <200 samples; results may be noisy.")
        if len(samples) < 50:
            print(f"[DATA][WARN] Skipping '{pname}' (insufficient samples).")
            continue

        run_metrics = []
        run_records = []

        # Track best-performing fold for this property (by test R2)
        best_overall_r2 = -1e18
        best_overall_payload = None

        cv = KFold(n_splits=NUM_RUNS, shuffle=True, random_state=42)
        for run_idx, (trainval_idx, test_idx) in enumerate(cv.split(np.arange(len(samples)))):
            fold_no = run_idx + 1
            seed = 42 + run_idx
            set_seed(seed)
            print(f"\n--- [CV] {pname}: fold {fold_no}/{NUM_RUNS} | seed={seed} ---")

            # Split trainval into train/val (positions within trainval_idx)
            tr_pos, va_pos = train_test_split(
                np.arange(len(trainval_idx)),
                test_size=VAL_SIZE_WITHIN_TRAINVAL,
                random_state=seed,
                shuffle=True,
            )
            train_ids = trainval_idx[tr_pos]
            val_ids = trainval_idx[va_pos]

            # Standardize target using training fold only (prevents leakage)
            sc = StandardScaler()
            sc.fit(np.array([samples[i]["target_raw"] for i in train_ids]).reshape(-1, 1))

            # Each split gets its own copies so `samples` is never mutated across folds.
            train = _copy_with_scaled_targets(samples, train_ids, sc)
            val = _copy_with_scaled_targets(samples, val_ids, sc)
            test = _copy_with_scaled_targets(samples, test_idx, sc)

            ds_train = PolymerPropertyDataset(train, tokenizer, max_length=MAX_LEN)
            ds_val = PolymerPropertyDataset(val, tokenizer, max_length=MAX_LEN)
            ds_test = PolymerPropertyDataset(test, tokenizer, max_length=MAX_LEN)

            dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, collate_fn=multimodal_collate_fn)
            dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, collate_fn=multimodal_collate_fn)
            dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, collate_fn=multimodal_collate_fn)

            print(f"[SPLIT] train={len(ds_train)} val={len(ds_val)} test={len(ds_test)}")

            # Fresh base model per fold to avoid any cross-fold leakage
            polyf_base = load_pretrained_multimodal(pretrained_path)
            model = PolyFPropertyRegressor(polyf_base, emb_dim=POLYF_EMB_DIM, dropout=POLYF_DROPOUT).to(DEVICE)

            optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

            best_val, best_state = _train_with_early_stopping(
                model, dl_train, dl_val, optimizer, scheduler, pname, fold_no
            )
            if best_state is None:
                print(f"[{pname}][WARN] No best checkpoint captured for fold {fold_no}; skipping fold.")
                continue

            # Restore best state and evaluate on test fold
            model.load_state_dict({k: v.to(DEVICE) for k, v in best_state.items()}, strict=True)
            _, pred_scaled, true_scaled = evaluate(model, dl_test, DEVICE)
            if pred_scaled is None:
                print(f"[{pname}][WARN] Test evaluation returned empty predictions for fold {fold_no}.")
                continue

            # Convert from scaled space back to original units
            pred = sc.inverse_transform(pred_scaled.reshape(-1, 1)).ravel()
            true = sc.inverse_transform(true_scaled.reshape(-1, 1)).ravel()

            m = compute_metrics(true, pred)
            run_metrics.append(m)
            print(f"[{pname}] fold {fold_no} TEST | r2={m['r2']:.4f} mae={m['mae']:.4f} rmse={m['rmse']:.4f}")

            record = {
                "property": pname,
                "property_col": pcol,
                "run": fold_no,
                "seed": seed,
                "n_train": len(ds_train),
                "n_val": len(ds_val),
                "n_test": len(ds_test),
                "best_val_loss": float(best_val),
                "test_metrics": m,
            }
            run_records.append(record)
            with open(output_file, "a") as fh:
                fh.write(json.dumps(make_json_serializable(record)) + "\n")

            # Update best fold bundle (by test R2)
            fold_r2 = float(m.get("r2", -1e18))
            if fold_r2 > float(best_overall_r2):
                best_overall_r2 = fold_r2
                best_overall_payload = {
                    "property": pname,
                    "property_col": pcol,
                    "best_run": int(fold_no),
                    "seed": int(seed),
                    "n_train": int(len(ds_train)),
                    "n_val": int(len(ds_val)),
                    "n_test": int(len(ds_test)),
                    "best_val_loss": float(best_val),
                    "test_metrics": make_json_serializable(m),
                    "scaler_mean": make_json_serializable(getattr(sc, "mean_", None)),
                    "scaler_scale": make_json_serializable(getattr(sc, "scale_", None)),
                    "scaler_var": make_json_serializable(getattr(sc, "var_", None)),
                    "scaler_n_samples_seen": make_json_serializable(getattr(sc, "n_samples_seen_", None)),
                    "model_state_dict": best_state,  # CPU tensors
                }

        # Save best fold weights + metadata per property
        if best_overall_payload is not None:
            _save_best_fold(pname, best_overall_payload)

        # Aggregate metrics across folds
        agg = _aggregate_fold_metrics(run_metrics)
        if agg is not None:
            print(f"[AGG] {pname} | r2={agg['r2']['mean']:.4f}±{agg['r2']['std']:.4f} mae={agg['mae']['mean']:.4f}±{agg['mae']['std']:.4f}")
        else:
            print(f"[AGG][WARN] No successful folds for '{pname}' (no aggregate computed).")

        all_results["per_property"][pname] = {
            "property_col": pcol,
            "n_samples": len(samples),
            "runs": run_records,
            "agg": agg,
        }
        with open(output_file, "a") as fh:
            fh.write("AGG_PROPERTY: " + json.dumps(make_json_serializable({pname: agg})) + "\n")

    return all_results
| # ============================================================================= | |
| # Main | |
| # ============================================================================= | |
def main():
    """
    Script entry point.

    Backs up and truncates OUTPUT_RESULTS, loads the PolyInfo CSV, resolves
    each requested property to a dataframe column, runs the downstream
    cross-validation and appends a FINAL_SUMMARY section with the
    per-property aggregates to OUTPUT_RESULTS.

    Raises FileNotFoundError when POLYINFO_PATH does not point to a file.
    """
    # Start a fresh results file (back up old results if present)
    if os.path.exists(OUTPUT_RESULTS):
        backup = OUTPUT_RESULTS + ".bak"
        shutil.copy(OUTPUT_RESULTS, backup)
        print(f"[INIT] Existing results backed up: {backup}")
    with open(OUTPUT_RESULTS, "w"):
        pass  # opening in "w" mode is enough to truncate/create the file
    print(f"[INIT] Writing results to: {OUTPUT_RESULTS}")

    # Load PolyInfo
    if not os.path.isfile(POLYINFO_PATH):
        raise FileNotFoundError(f"PolyInfo file not found at {POLYINFO_PATH}")
    polyinfo_raw = pd.read_csv(POLYINFO_PATH, engine="python")
    print(f"[DATA] Loaded PolyInfo: n_rows={len(polyinfo_raw)} n_cols={len(polyinfo_raw.columns)}")

    # Map requested properties to dataframe columns
    prop_map = dict(find_property_columns(polyinfo_raw.columns).items())
    print(f"[COLMAP] Property-to-column map: {prop_map}")

    resolved = []
    for req in REQUESTED_PROPERTIES:
        col = prop_map.get(req)
        if col is None:
            print(f"[COLMAP][WARN] Could not find a column for '{req}'; skipping.")
            continue
        resolved.append((req, col))
    property_list = [req for req, _ in resolved]
    property_cols = [col for _, col in resolved]

    overall = run_polyf_downstream(
        property_list, property_cols, polyinfo_raw, PRETRAINED_MULTIMODAL_DIR, OUTPUT_RESULTS
    )

    # Write final summary (aggregated per property)
    final_agg = {}
    if overall and "per_property" in overall:
        final_agg = {
            pname: info.get("agg", None)
            for pname, info in overall["per_property"].items()
        }
    summary_json = json.dumps(make_json_serializable(final_agg), indent=2)
    with open(OUTPUT_RESULTS, "a") as fh:
        fh.write("\nFINAL_SUMMARY\n" + summary_json + "\n")

    print(f"\n Results appended to: {OUTPUT_RESULTS}")
    print(f" Best checkpoints saved under: {BEST_WEIGHTS_DIR}")


if __name__ == "__main__":
    main()