# PolyFusionAgent / Downstream Tasks / Property_Prediction.py
# Source: Hugging Face repository (user kaurm43), commit 11e4265
# ("Update Downstream Tasks/Property_Prediction.py").
import os
import random
import time
from pathlib import Path
import math
import json
import shutil
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split, KFold
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import sys
import csv
import copy
from typing import List, Dict, Optional, Tuple, Any
# Increase CSV field size limit: modality columns hold large serialized JSON
# blobs (graphs/geometries) that exceed the default _csv field limit.
csv.field_size_limit(sys.maxsize)
# =============================================================================
# Imports: Shared encoders/helpers from PolyFusion
# =============================================================================
from PolyFusion.GINE import GineEncoder, match_edge_attr_to_index, safe_get
from PolyFusion.SchNet import NodeSchNetWrapper
from PolyFusion.Transformer import PooledFingerprintEncoder as FingerprintEncoder
from PolyFusion.DeBERTav2 import PSMILESDebertaEncoder, build_psmiles_tokenizer
# =============================================================================
# Configuration
# =============================================================================
# Root of the project checkout and the polyinfo CSV with all modality columns.
BASE_DIR = "/path/to/Polymer_Foundational_Model"
POLYINFO_PATH = "/path/to/polyinfo_with_modalities.csv"
# Pretrained encoder directories
PRETRAINED_MULTIMODAL_DIR = "/path/to/multimodal_output/best"
BEST_GINE_DIR = "/path/to/gin_output/best"
BEST_SCHNET_DIR = "/path/to/schnet_output/best"
BEST_FP_DIR = "/path/to/fingerprint_mlm_output/best"
BEST_PSMILES_DIR = "/path/to/polybert_output/best"
# Output log file (per-run json lines + per-property aggregated summary)
OUTPUT_RESULTS = "/path/to/multimodal_downstream_results.txt"
# Directory to save best-performing checkpoint bundle per property (best CV run)
BEST_WEIGHTS_DIR = "/path/to/multimodal_downstream_bestweights"
# -----------------------------------------------------------------------------
# Model sizes / dims
# -----------------------------------------------------------------------------
# Atomic numbers 0..MAX_ATOMIC_Z are valid; one extra id is reserved for masking.
MAX_ATOMIC_Z = 85
MASK_ATOM_ID = MAX_ATOMIC_Z + 1
# GINE
NODE_EMB_DIM = 300
EDGE_EMB_DIM = 300
NUM_GNN_LAYERS = 5
# SchNet
SCHNET_NUM_GAUSSIANS = 50
SCHNET_NUM_INTERACTIONS = 6
SCHNET_CUTOFF = 10.0
SCHNET_MAX_NEIGHBORS = 64
SCHNET_HIDDEN = 600
# Fingerprints (Morgan bit vector length; FP "vocab" is {0, 1, mask})
FP_LENGTH = 2048
MASK_TOKEN_ID_FP = 2
VOCAB_SIZE_FP = 3
# Contrastive embedding dim (shared projection space for all modalities)
CL_EMB_DIM = 600
# PSMILES/DeBERTa
DEBERTA_HIDDEN = 600
PSMILES_MAX_LEN = 128
# -----------------------------------------------------------------------------
# Fusion + regression head hyperparameters
# -----------------------------------------------------------------------------
POLYF_EMB_DIM = 600
POLYF_ATTN_HEADS = 8
POLYF_DROPOUT = 0.1
POLYF_FF_MULT = 4  # FFN hidden = 4*d
# -----------------------------------------------------------------------------
# Fine-tuning parameters (single-task per property)
# -----------------------------------------------------------------------------
MAX_LEN = 128
BATCH_SIZE = 32
NUM_EPOCHS = 100
PATIENCE = 10
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Properties to evaluate
REQUESTED_PROPERTIES = [
    "density",
    "glass transition",
    "melting",
    "thermal decomposition"
]
# True K-fold evaluation to match "fivefold per property"
NUM_RUNS = 5
TEST_SIZE = 0.10
VAL_SIZE_WITHIN_TRAINVAL = 0.10  # fraction of trainval reserved for val split
# Duplicate aggregation (noise reduction) key preference order
AGG_KEYS_PREFERENCE = ["polymer_id", "PolymerID", "poly_id", "psmiles", "smiles", "canonical_smiles"]
# =============================================================================
# Utilities
# =============================================================================
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs so every fold is reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade throughput for determinism: fixed cuDNN kernels, no autotuning.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def make_json_serializable(obj):
    """Recursively coerce numpy/torch/pandas values into JSON-safe Python types."""
    if isinstance(obj, dict):
        return {
            make_json_serializable(key): make_json_serializable(val)
            for key, val in obj.items()
        }
    if isinstance(obj, (list, tuple, set)):
        return [make_json_serializable(elem) for elem in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # Scalar numpy types unwrap to native Python scalars.
        try:
            return obj.item()
        except Exception:
            return float(obj)
    if isinstance(obj, torch.Tensor):
        try:
            return obj.detach().cpu().tolist()
        except Exception:
            return None
    if isinstance(obj, (pd.Timestamp, pd.Timedelta)):
        return str(obj)
    # Native JSON primitives pass through; anything else is returned as-is
    # and left for json.dumps to reject.
    if isinstance(obj, (float, int, str, bool, type(None))):
        return obj
    return obj
def summarize_state_dict_load(full_state: dict, model_state: dict, filtered_state: dict):
    """
    Print a short report of how a checkpoint maps onto a model:
    key counts on both sides, how many keys will actually be loaded, and
    example keys skipped because they are absent from the model or have
    mismatched tensor shapes.
    """
    missing_in_model = [k for k in full_state if k not in model_state]
    shape_mismatch = []
    for k, v in full_state.items():
        if k not in model_state or not hasattr(v, "shape"):
            continue
        if tuple(v.shape) != tuple(model_state[k].shape):
            shape_mismatch.append(k)
    print("\n[CKPT LOAD SUMMARY]")
    print(f" ckpt keys: {len(full_state)}")
    print(f" model keys: {len(model_state)}")
    print(f" loaded keys: {len(filtered_state)}")
    print(f" skipped (not in model): {len(missing_in_model)}")
    print(f" skipped (shape mismatch): {len(shape_mismatch)}")
    if missing_in_model:
        print(" examples skipped (not in model):", missing_in_model[:10])
    if shape_mismatch:
        print(" examples skipped (shape mismatch):")
        for k in shape_mismatch[:10]:
            print(f" {k}: ckpt={tuple(full_state[k].shape)} model={tuple(model_state[k].shape)}")
    print("")
def find_property_columns(columns, requested: Optional[List[str]] = None):
    """
    Map each requested property name to the best-matching dataframe column.

    Matching strategy (guardrails against false positives):
      1. Token-level match: the request equals the column name, or appears as a
         whole word after normalizing '_' to ' ' (safer than substring match).
      2. Substring match as a fallback.
    For 'density', columns mentioning 'cohesive' are excluded so that
    'cohesive energy density' is never mistaken for mass density.
    Ambiguous or missing matches are logged to stdout.

    Args:
        columns: iterable of dataframe column names.
        requested: property names to look up; defaults to the module-level
            REQUESTED_PROPERTIES list (kept for backward compatibility).

    Returns:
        dict mapping each requested name to a column name, or None when no
        acceptable candidate exists.
    """
    if requested is None:
        requested = REQUESTED_PROPERTIES
    lowered = {c.lower(): c for c in columns}
    found = {}
    for req in requested:
        req_low = req.lower().strip()
        exact = None
        # Pass 1: token-level exactness (safer than substring match)
        for c_low, c_orig in lowered.items():
            tokens = set(c_low.replace('_', ' ').split())
            if req_low in tokens or c_low == req_low:
                # Guardrail: never let 'cohesive energy density' satisfy 'density'.
                if req_low == "density" and "cohesive" in c_low:
                    continue
                exact = c_orig
                break
        if exact is not None:
            found[req] = exact
            continue
        # Pass 2: substring match as fallback
        candidates = [c_orig for c_low, c_orig in lowered.items() if req_low in c_low]
        if req_low == "density":
            candidates = [c for c in candidates if "cohesive" not in c.lower()]
        if len(candidates) == 1:
            found[req] = candidates[0]
        else:
            # Ambiguous (several candidates) or missing (none): pick the first
            # if any, and log so the mapping can be audited.
            chosen = candidates[0] if candidates else None
            found[req] = chosen
            print(f"[COLMAP] Requested '{req}' -> chosen column: {chosen}")
            if candidates:
                print(f"[COLMAP] Candidates for '{req}': {candidates}")
            else:
                print(f"[COLMAP][WARN] No candidates found for '{req}' using substring search.")
    return found
def choose_aggregation_key(df: pd.DataFrame, preference: Optional[List[str]] = None) -> Optional[str]:
    """
    Pick the most stable identifier available for duplicate aggregation.

    Args:
        df: dataframe whose columns are searched.
        preference: ordered list of candidate key columns; defaults to the
            module-level AGG_KEYS_PREFERENCE (kept for backward compatibility).

    Returns:
        The first preferred column present in `df`, or None if none exists.
    """
    if preference is None:
        preference = AGG_KEYS_PREFERENCE
    for key in preference:
        if key in df.columns:
            return key
    return None
def aggregate_polyinfo_duplicates(df: pd.DataFrame, modality_cols: List[str], property_cols: List[str]) -> pd.DataFrame:
    """
    Collapse duplicate polymer rows to reduce label noise.

    Rows sharing the same aggregation key are grouped: modality columns keep
    their first value (assumed consistent per polymer key) while property
    columns are averaged. If no usable key exists, or the key column is
    entirely blank, the original dataframe is returned unchanged.
    """
    key = choose_aggregation_key(df)
    if key is None:
        print("[AGG] No aggregation key found; skipping duplicate aggregation.")
        return df
    work = df.copy()
    work[key] = work[key].astype(str)
    work = work[work[key].str.strip() != ""].copy()
    if len(work) == 0:
        print("[AGG] Aggregation key exists but is empty; skipping duplicate aggregation.")
        return df
    # Build the per-column aggregation plan: modalities first, then properties
    # (a column listed in both ends up averaged, matching the original order).
    plan = {c: "first" for c in modality_cols if c in work.columns}
    plan.update({c: "mean" for c in property_cols if c in work.columns})
    grouped = work.groupby(key, as_index=False).agg(plan)
    print(f"[AGG] Grouped by '{key}': {len(df)} rows -> {len(grouped)} unique keys")
    return grouped
def _sanitize_name(s: str) -> str:
"""Create a filesystem-safe name for property directories."""
s2 = str(s).strip().lower()
keep = []
for ch in s2:
if ch.isalnum():
keep.append(ch)
elif ch in (" ", "-", "_", "."):
keep.append("_")
else:
keep.append("_")
out = "".join(keep)
while "__" in out:
out = out.replace("__", "_")
out = out.strip("_")
return out or "property"
# =============================================================================
# Multimodal backbone: encode + project + modality-aware fusion
# =============================================================================
class MultimodalContrastiveModel(nn.Module):
    """
    Multimodal encoder wrapper:
    1) Runs each available modality encoder:
       - GINE (graph)
       - SchNet (3D geometry)
       - Transformer FP encoder (Morgan bit sequence)
       - DeBERTa-based PSMILES encoder (sequence)
    2) Projects each modality embedding to a shared dim (emb_dim).
    3) Normalizes each modality embedding (L2), drops out, then fuses via
       a masked mean across modalities that are present for each sample.
    4) Normalizes the final fused embedding (L2).
    Expected downstream usage:
        z = model(batch_mods, modality_mask=modality_mask)  # (B, emb_dim)
    """
    def __init__(
        self,
        gine_encoder: Optional[nn.Module] = None,
        schnet_encoder: Optional[nn.Module] = None,
        fp_encoder: Optional[nn.Module] = None,
        psmiles_encoder: Optional[nn.Module] = None,
        *,
        emb_dim: int = CL_EMB_DIM,
        modalities: Optional[List[str]] = None,
        dropout: float = 0.1,
        psmiles_tokenizer: Optional[Any] = None,
    ):
        super().__init__()
        # Any encoder may be None; its modality is then simply disabled.
        self.gine = gine_encoder
        self.schnet = schnet_encoder
        self.fp = fp_encoder
        self.psmiles = psmiles_encoder
        self.psm_tok = psmiles_tokenizer
        self.emb_dim = int(emb_dim)
        self.out_dim = self.emb_dim
        self.dropout = nn.Dropout(float(dropout))
        # Determine which modalities are enabled
        if modalities is None:
            # Default: enable every modality whose encoder was provided.
            mods = []
            if self.gine is not None:
                mods.append("gine")
            if self.schnet is not None:
                mods.append("schnet")
            if self.fp is not None:
                mods.append("fp")
            if self.psmiles is not None:
                mods.append("psmiles")
            self.modalities = mods
        else:
            # Explicit list: silently drop unknown names.
            self.modalities = [m for m in modalities if m in ("gine", "schnet", "fp", "psmiles")]
        # Projection heads into shared embedding space
        self.proj_gine = nn.Linear(NODE_EMB_DIM, self.emb_dim) if self.gine is not None else None
        self.proj_schnet = nn.Linear(SCHNET_HIDDEN, self.emb_dim) if self.schnet is not None else None
        # NOTE(review): hard-coded 256 assumes the FP encoder pools to a
        # 256-dim vector — confirm against PooledFingerprintEncoder.
        self.proj_fp = nn.Linear(256, self.emb_dim) if self.fp is not None else None
        # Infer PSMILES hidden size if possible; fallback to DEBERTA_HIDDEN
        psm_in = None
        if self.psmiles is not None:
            if hasattr(self.psmiles, "out_dim"):
                try:
                    psm_in = int(self.psmiles.out_dim)
                except Exception:
                    psm_in = None
            if psm_in is None and hasattr(self.psmiles, "model") and hasattr(self.psmiles.model, "config"):
                try:
                    psm_in = int(self.psmiles.model.config.hidden_size)
                except Exception:
                    psm_in = None
            if psm_in is None:
                psm_in = int(DEBERTA_HIDDEN)
        self.proj_psmiles = nn.Linear(psm_in, self.emb_dim) if (self.psmiles is not None) else None
    def freeze_cl_encoders(self):
        """Freeze all modality encoders (optional for evaluation-only usage)."""
        # Projection heads stay trainable; only the backbone encoders freeze.
        for enc in (self.gine, self.schnet, self.fp, self.psmiles):
            if enc is None:
                continue
            enc.eval()
            for p in enc.parameters():
                p.requires_grad = False
    def _masked_mean_combine(self, zs: List[torch.Tensor], masks: List[torch.Tensor]) -> torch.Tensor:
        """
        Compute sample-wise mean over available modalities.
        zs: list of modality embeddings, each (B,D)
        masks: list of modality presence masks, each (B,) bool
        returns: (B,D)
        """
        if not zs:
            # No modality produced anything: emit a single zero row.
            device = next(self.parameters()).device
            return torch.zeros((1, self.emb_dim), device=device)
        device = zs[0].device
        B = zs[0].size(0)
        sum_z = torch.zeros((B, self.emb_dim), device=device)
        count = torch.zeros((B, 1), device=device)
        for z, m in zip(zs, masks):
            m = m.to(device).view(B, 1).float()
            sum_z = sum_z + z * m
            count = count + m
        # Samples with zero present modalities keep a zero embedding
        # (dividing by a clamped count of 1 instead of 0).
        count = count.clamp(min=1.0)
        return sum_z / count
    def forward(self, batch_mods: dict, modality_mask: Optional[dict] = None) -> torch.Tensor:
        """
        batch_mods keys: 'gine', 'schnet', 'fp', 'psmiles'
        modality_mask: dict {modality_name: (B,) bool} describing presence.
        Returns the fused, L2-normalized embedding of shape (B, emb_dim).
        """
        device = next(self.parameters()).device
        zs = []
        ms = []
        # Infer batch size B: prefer any non-empty mask, then fall back to the
        # fixed-shape (B, L) modalities (fp/psmiles input_ids).
        B = None
        if modality_mask is not None:
            for _, v in modality_mask.items():
                if isinstance(v, torch.Tensor) and v.numel() > 0:
                    B = int(v.size(0))
                    break
        if B is None:
            if "fp" in batch_mods and batch_mods["fp"] is not None and isinstance(batch_mods["fp"].get("input_ids", None), torch.Tensor):
                B = int(batch_mods["fp"]["input_ids"].size(0))
            elif "psmiles" in batch_mods and batch_mods["psmiles"] is not None and isinstance(batch_mods["psmiles"].get("input_ids", None), torch.Tensor):
                B = int(batch_mods["psmiles"]["input_ids"].size(0))
        if B is None:
            # Batch size undeterminable: degrade to a single zero embedding.
            return torch.zeros((1, self.emb_dim), device=device)
        def _get_mask(name: str) -> torch.Tensor:
            # Missing mask entry means "present for every sample".
            if modality_mask is not None and name in modality_mask and isinstance(modality_mask[name], torch.Tensor):
                return modality_mask[name].to(device).bool()
            return torch.ones((B,), device=device, dtype=torch.bool)
        # -------------------------
        # GINE (graph modality)
        # -------------------------
        if "gine" in self.modalities and self.gine is not None and batch_mods.get("gine", None) is not None:
            g = batch_mods["gine"]
            if isinstance(g.get("z", None), torch.Tensor) and g["z"].numel() > 0:
                # Each optional tensor is moved to the model device, with typed
                # empty fallbacks for edge_index/edge_attr.
                emb_g = self.gine(
                    g["z"].to(device),
                    g.get("chirality", None).to(device) if isinstance(g.get("chirality", None), torch.Tensor) else None,
                    g.get("formal_charge", None).to(device) if isinstance(g.get("formal_charge", None), torch.Tensor) else None,
                    g.get("edge_index", torch.empty((2, 0), dtype=torch.long)).to(device) if isinstance(g.get("edge_index", None), torch.Tensor) else torch.empty((2, 0), dtype=torch.long, device=device),
                    g.get("edge_attr", torch.zeros((0, 3), dtype=torch.float)).to(device) if isinstance(g.get("edge_attr", None), torch.Tensor) else torch.zeros((0, 3), dtype=torch.float, device=device),
                    g.get("batch", None).to(device) if isinstance(g.get("batch", None), torch.Tensor) else None
                )
                z = self.proj_gine(emb_g) if self.proj_gine is not None else emb_g
                z = F.normalize(z, dim=-1)
                z = self.dropout(z)
                zs.append(z)
                ms.append(_get_mask("gine"))
        # -------------------------
        # SchNet (3D geometry)
        # -------------------------
        if "schnet" in self.modalities and self.schnet is not None and batch_mods.get("schnet", None) is not None:
            s = batch_mods["schnet"]
            if isinstance(s.get("z", None), torch.Tensor) and s["z"].numel() > 0:
                emb_s = self.schnet(
                    s["z"].to(device),
                    s["pos"].to(device) if isinstance(s.get("pos", None), torch.Tensor) else torch.zeros((0, 3), device=device),
                    s.get("batch", None).to(device) if isinstance(s.get("batch", None), torch.Tensor) else None
                )
                z = self.proj_schnet(emb_s) if self.proj_schnet is not None else emb_s
                z = F.normalize(z, dim=-1)
                z = self.dropout(z)
                zs.append(z)
                ms.append(_get_mask("schnet"))
        # -------------------------
        # Fingerprint modality
        # -------------------------
        if "fp" in self.modalities and self.fp is not None and batch_mods.get("fp", None) is not None:
            f = batch_mods["fp"]
            if isinstance(f.get("input_ids", None), torch.Tensor) and f["input_ids"].numel() > 0:
                emb_f = self.fp(
                    f["input_ids"].to(device),
                    f.get("attention_mask", None).to(device) if isinstance(f.get("attention_mask", None), torch.Tensor) else None
                )
                z = self.proj_fp(emb_f) if self.proj_fp is not None else emb_f
                z = F.normalize(z, dim=-1)
                z = self.dropout(z)
                zs.append(z)
                ms.append(_get_mask("fp"))
        # -------------------------
        # PSMILES text modality
        # -------------------------
        if "psmiles" in self.modalities and self.psmiles is not None and batch_mods.get("psmiles", None) is not None:
            p = batch_mods["psmiles"]
            if isinstance(p.get("input_ids", None), torch.Tensor) and p["input_ids"].numel() > 0:
                emb_p = self.psmiles(
                    p["input_ids"].to(device),
                    p.get("attention_mask", None).to(device) if isinstance(p.get("attention_mask", None), torch.Tensor) else None
                )
                z = self.proj_psmiles(emb_p) if self.proj_psmiles is not None else emb_p
                z = F.normalize(z, dim=-1)
                z = self.dropout(z)
                zs.append(z)
                ms.append(_get_mask("psmiles"))
        # Fuse and normalize
        if not zs:
            return torch.zeros((B, self.emb_dim), device=device)
        z = self._masked_mean_combine(zs, ms)
        z = F.normalize(z, dim=-1)
        return z
    @torch.no_grad()
    def encode_psmiles(
        self,
        psmiles_list: List[str],
        max_len: int = PSMILES_MAX_LEN,
        batch_size: int = 64,
        device: str = DEVICE
    ) -> np.ndarray:
        """
        Encode a list of PSMILES strings into L2-normalized projected
        embeddings, returned as a (N, emb_dim) numpy array.

        NOTE(review): inputs are moved to `device` but the module itself is
        not — assumes the caller has already placed the model on `device`;
        confirm at call sites.
        """
        self.eval()
        if self.psm_tok is None or self.psmiles is None or self.proj_psmiles is None:
            raise RuntimeError("PSMILES tokenizer/encoder/projection not available.")
        outs = []
        for i in range(0, len(psmiles_list), batch_size):
            chunk = [str(x) for x in psmiles_list[i:i + batch_size]]
            enc = self.psm_tok(chunk, truncation=True, padding="max_length", max_length=max_len, return_tensors="pt")
            input_ids = enc["input_ids"].to(device)
            attn = enc["attention_mask"].to(device).bool()
            emb_p = self.psmiles(input_ids, attn)
            z = F.normalize(self.proj_psmiles(emb_p), dim=-1)
            outs.append(z.detach().cpu().numpy())
        return np.concatenate(outs, axis=0) if outs else np.zeros((0, self.emb_dim), dtype=np.float32)
    @torch.no_grad()
    def encode_multimodal(
        self,
        records: List[dict],
        batch_size: int = 32,
        device: str = DEVICE
    ) -> np.ndarray:
        """
        Convenience: multimodal embedding for records carrying:
        - graph, geometry, fingerprints, psmiles
        Missing modalities are handled sample-wise via modality masking.

        NOTE(review): relies on module-level helpers _parse_fingerprints,
        _parse_graph_for_gine and _parse_geometry_for_schnet, which are not
        visible in this chunk — presumably defined later in this file.
        """
        self.eval()
        dev = torch.device(device)
        self.to(dev)
        outs = []
        for i in range(0, len(records), batch_size):
            chunk = records[i:i + batch_size]
            # PSMILES batch
            psmiles_texts = [str(r.get("psmiles", "")) for r in chunk]
            p_enc = None
            if self.psm_tok is not None:
                p_enc = self.psm_tok(psmiles_texts, truncation=True, padding="max_length", max_length=PSMILES_MAX_LEN, return_tensors="pt")
            # FP batch (always stack; missing handled by attention_mask downstream)
            fp_ids, fp_attn = [], []
            for r in chunk:
                f = _parse_fingerprints(r.get("fingerprints", None), fp_len=FP_LENGTH)
                fp_ids.append(f["input_ids"])
                fp_attn.append(f["attention_mask"])
            fp_ids = torch.stack(fp_ids, dim=0)
            fp_attn = torch.stack(fp_attn, dim=0)
            # GINE + SchNet packed batching
            gine_all = {"z": [], "chirality": [], "formal_charge": [], "edge_index": [], "edge_attr": [], "batch": []}
            node_offset = 0
            for bi, r in enumerate(chunk):
                g = _parse_graph_for_gine(r.get("graph", None))
                if g is None or g["z"].numel() == 0:
                    continue
                n = g["z"].size(0)
                gine_all["z"].append(g["z"])
                gine_all["chirality"].append(g["chirality"])
                gine_all["formal_charge"].append(g["formal_charge"])
                gine_all["batch"].append(torch.full((n,), bi, dtype=torch.long))
                ei = g["edge_index"]
                ea = g["edge_attr"]
                if ei is not None and ei.numel() > 0:
                    # Shift node indices so edges address the packed tensor.
                    gine_all["edge_index"].append(ei + node_offset)
                    gine_all["edge_attr"].append(ea)
                node_offset += n
            gine_batch = None
            if len(gine_all["z"]) > 0:
                z_b = torch.cat(gine_all["z"], dim=0)
                ch_b = torch.cat(gine_all["chirality"], dim=0)
                fc_b = torch.cat(gine_all["formal_charge"], dim=0)
                b_b = torch.cat(gine_all["batch"], dim=0)
                if len(gine_all["edge_index"]) > 0:
                    ei_b = torch.cat(gine_all["edge_index"], dim=1)
                    ea_b = torch.cat(gine_all["edge_attr"], dim=0)
                else:
                    ei_b = torch.empty((2, 0), dtype=torch.long)
                    ea_b = torch.zeros((0, 3), dtype=torch.float)
                gine_batch = {"z": z_b, "chirality": ch_b, "formal_charge": fc_b, "edge_index": ei_b, "edge_attr": ea_b, "batch": b_b}
            sch_all_z, sch_all_pos, sch_all_batch = [], [], []
            for bi, r in enumerate(chunk):
                s = _parse_geometry_for_schnet(r.get("geometry", None))
                if s is None or s["z"].numel() == 0:
                    continue
                n = s["z"].size(0)
                sch_all_z.append(s["z"])
                sch_all_pos.append(s["pos"])
                sch_all_batch.append(torch.full((n,), bi, dtype=torch.long))
            schnet_batch = None
            if len(sch_all_z) > 0:
                schnet_batch = {
                    "z": torch.cat(sch_all_z, dim=0),
                    "pos": torch.cat(sch_all_pos, dim=0),
                    "batch": torch.cat(sch_all_batch, dim=0),
                }
            batch_mods = {
                "gine": gine_batch,
                "schnet": schnet_batch,
                "fp": {"input_ids": fp_ids, "attention_mask": fp_attn},
                "psmiles": {"input_ids": p_enc["input_ids"], "attention_mask": p_enc["attention_mask"]} if p_enc is not None else None
            }
            # NOTE: This script uses forward() as the encoder entry point.
            z = self.forward(batch_mods, modality_mask=None)
            outs.append(z.detach().cpu().numpy())
        return np.concatenate(outs, axis=0) if outs else np.zeros((0, self.emb_dim), dtype=np.float32)
# =============================================================================
# Tokenizer setup
# =============================================================================
SPM_MODEL = "/path/to/spm.model"
# Built once at import time so the dataset and the encoders share one instance.
tokenizer = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)
# =============================================================================
# Dataset: single-task property prediction (with modality parsing)
# =============================================================================
class PolymerPropertyDataset(Dataset):
    """
    Dataset that prepares one sample with up to four modalities:
    - graph (for GINE)
    - geometry (for SchNet)
    - fingerprints (for FP transformer)
    - psmiles text (for DeBERTa encoder)
    Target is a single scalar per sample (already scaled externally).

    Every modality is parsed defensively from stored JSON/CSV fields; any
    parse failure degrades to "modality missing" (empty/zero tensors) rather
    than raising, so one corrupt row cannot break a training run.
    """
    def __init__(self, data_list, tokenizer, max_length=128):
        # data_list: list of dict-like rows carrying the modality fields and
        # 'target_scaled'. tokenizer: PSMILES tokenizer (may be None).
        self.data_list = data_list
        self.tokenizer = tokenizer
        self.max_length = max_length
    def __len__(self):
        return len(self.data_list)
    def __getitem__(self, idx):
        data = self.data_list[idx]
        # ---------------------------------------------------------------------
        # Graph -> GINE tensors (robust parsing of stored JSON fields)
        # ---------------------------------------------------------------------
        gine_data = None
        if 'graph' in data and data['graph']:
            try:
                graph_field = json.loads(data['graph']) if isinstance(data['graph'], str) else data['graph']
                node_features = safe_get(graph_field, "node_features", None)
                if node_features:
                    atomic_nums = []
                    chirality_vals = []
                    formal_charges = []
                    for nf in node_features:
                        # Accept either 'atomic_num' or 'atomic_number' keys.
                        an = safe_get(nf, "atomic_num", None)
                        if an is None:
                            an = safe_get(nf, "atomic_number", 0)
                        ch = safe_get(nf, "chirality", 0)
                        fc = safe_get(nf, "formal_charge", 0)
                        try:
                            atomic_nums.append(int(an))
                        except Exception:
                            atomic_nums.append(0)
                        chirality_vals.append(float(ch))
                        formal_charges.append(float(fc))
                    edge_indices_raw = safe_get(graph_field, "edge_indices", None)
                    edge_features_raw = safe_get(graph_field, "edge_features", None)
                    edge_index = None
                    edge_attr = None
                    # Fallback: adjacency matrix if edge_indices missing
                    if edge_indices_raw is None:
                        adj_mat = safe_get(graph_field, "adjacency_matrix", None)
                        if adj_mat:
                            srcs = []
                            dsts = []
                            for i_r, row_adj in enumerate(adj_mat):
                                for j, val2 in enumerate(row_adj):
                                    if val2:
                                        srcs.append(i_r)
                                        dsts.append(j)
                            if len(srcs) > 0:
                                edge_index = [srcs, dsts]
                                E = len(srcs)
                                # No edge features available from an adjacency
                                # matrix: use 3 zeros per edge.
                                edge_attr = [[0.0, 0.0, 0.0] for _ in range(E)]
                    else:
                        # edge_indices can be [[srcs],[dsts]] or list of pairs
                        srcs, dsts = [], []
                        if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0:
                            if isinstance(edge_indices_raw[0], list):
                                first = edge_indices_raw[0]
                                # Heuristic: a 2-int first element means the
                                # whole list is [src, dst] pairs.
                                if len(first) == 2 and isinstance(first[0], int):
                                    try:
                                        srcs = [int(p[0]) for p in edge_indices_raw]
                                        dsts = [int(p[1]) for p in edge_indices_raw]
                                    except Exception:
                                        srcs, dsts = [], []
                                else:
                                    try:
                                        srcs = [int(x) for x in edge_indices_raw[0]]
                                        dsts = [int(x) for x in edge_indices_raw[1]]
                                    except Exception:
                                        srcs, dsts = [], []
                            else:
                                try:
                                    srcs = [int(x) for x in edge_indices_raw[0]]
                                    dsts = [int(x) for x in edge_indices_raw[1]]
                                except Exception:
                                    srcs, dsts = [], []
                        if len(srcs) > 0:
                            edge_index = [srcs, dsts]
                        # edge_features: attempt to map known fields; otherwise zeros
                        if edge_features_raw and isinstance(edge_features_raw, list):
                            bond_types = []
                            stereos = []
                            is_conjs = []
                            for ef in edge_features_raw:
                                bt = safe_get(ef, "bond_type", 0)
                                st = safe_get(ef, "stereo", 0)
                                ic = safe_get(ef, "is_conjugated", False)
                                bond_types.append(float(bt))
                                stereos.append(float(st))
                                is_conjs.append(float(1.0 if ic else 0.0))
                            edge_attr = list(zip(bond_types, stereos, is_conjs))
                        else:
                            E = len(srcs)
                            edge_attr = [[0.0, 0.0, 0.0] for _ in range(E)]
                    if edge_index is not None:
                        gine_data = {
                            'z': torch.tensor(atomic_nums, dtype=torch.long),
                            'chirality': torch.tensor(chirality_vals, dtype=torch.float),
                            'formal_charge': torch.tensor(formal_charges, dtype=torch.float),
                            'edge_index': torch.tensor(edge_index, dtype=torch.long),
                            'edge_attr': torch.tensor(edge_attr, dtype=torch.float)
                        }
            except Exception:
                gine_data = None
        # ---------------------------------------------------------------------
        # Geometry -> SchNet tensors (best conformer)
        # ---------------------------------------------------------------------
        schnet_data = None
        if 'geometry' in data and data['geometry']:
            try:
                geom = json.loads(data['geometry']) if isinstance(data['geometry'], str) else data['geometry']
                conf = geom.get("best_conformer") if isinstance(geom, dict) else None
                if conf:
                    atomic = conf.get("atomic_numbers", [])
                    coords = conf.get("coordinates", [])
                    # Require a consistent, non-empty atom/coordinate pairing.
                    if len(atomic) == len(coords) and len(atomic) > 0:
                        schnet_data = {
                            'z': torch.tensor(atomic, dtype=torch.long),
                            'pos': torch.tensor(coords, dtype=torch.float)
                        }
            except Exception:
                schnet_data = None
        # ---------------------------------------------------------------------
        # Fingerprints -> FP transformer inputs (bit sequence)
        # ---------------------------------------------------------------------
        fp_data = None
        if 'fingerprints' in data and data['fingerprints']:
            try:
                fpval = data['fingerprints']
                if fpval is not None and not (isinstance(fpval, str) and fpval.strip() == ""):
                    # Three-stage parse: strict JSON, then single-quote-fixed
                    # JSON, then a plain comma-separated bit list.
                    try:
                        fp_json = json.loads(fpval) if isinstance(fpval, str) else fpval
                    except Exception:
                        try:
                            fp_json = json.loads(str(fpval).replace("'", '"'))
                        except Exception:
                            parts = [p.strip().strip('"').strip("'") for p in str(fpval).split(",")]
                            bits = [1 if p in ("1", "True", "true") else 0 for p in parts[:FP_LENGTH]]
                            if len(bits) < FP_LENGTH:
                                bits += [0] * (FP_LENGTH - len(bits))
                            fp_json = bits
                    if isinstance(fp_json, dict):
                        bits = safe_get(fp_json, "morgan_r3_bits", None)
                        if bits is None:
                            bits = [0] * FP_LENGTH
                        else:
                            # Normalize mixed string/int encodings to {0,1},
                            # truncating or zero-padding to FP_LENGTH.
                            normalized = []
                            for b in bits:
                                if isinstance(b, str):
                                    b_clean = b.strip().strip('"').strip("'")
                                    normalized.append(1 if b_clean in ("1", "True", "true") else 0)
                                elif isinstance(b, (int, np.integer)):
                                    normalized.append(1 if int(b) != 0 else 0)
                                else:
                                    normalized.append(0)
                                if len(normalized) >= FP_LENGTH:
                                    break
                            if len(normalized) < FP_LENGTH:
                                normalized.extend([0] * (FP_LENGTH - len(normalized)))
                            bits = normalized[:FP_LENGTH]
                    elif isinstance(fp_json, list):
                        bits = fp_json[:FP_LENGTH]
                        if len(bits) < FP_LENGTH:
                            bits += [0] * (FP_LENGTH - len(bits))
                    else:
                        bits = [0] * FP_LENGTH
                    fp_data = {
                        'input_ids': torch.tensor(bits, dtype=torch.long),
                        # Full-ones mask: every bit position is attended.
                        'attention_mask': torch.ones(FP_LENGTH, dtype=torch.bool)
                    }
            except Exception:
                fp_data = None
        # ---------------------------------------------------------------------
        # PSMILES -> DeBERTa tokenizer inputs
        # ---------------------------------------------------------------------
        psmiles_data = None
        if 'psmiles' in data and data['psmiles'] and self.tokenizer is not None:
            try:
                s = str(data['psmiles'])
                enc = self.tokenizer(
                    s,
                    truncation=True,
                    padding="max_length",
                    max_length=PSMILES_MAX_LEN
                )
                psmiles_data = {
                    'input_ids': torch.tensor(enc["input_ids"], dtype=torch.long),
                    'attention_mask': torch.tensor(enc["attention_mask"], dtype=torch.bool)
                }
            except Exception:
                psmiles_data = None
        # ---------------------------------------------------------------------
        # Fill defaults for missing modalities
        # ---------------------------------------------------------------------
        # Empty/zero tensors mark absence; the collate fn derives per-sample
        # modality masks from these (zero-length z, all-zero attention masks).
        if gine_data is None:
            gine_data = {
                'z': torch.tensor([], dtype=torch.long),
                'chirality': torch.tensor([], dtype=torch.float),
                'formal_charge': torch.tensor([], dtype=torch.float),
                'edge_index': torch.tensor([[], []], dtype=torch.long),
                'edge_attr': torch.zeros((0, 3), dtype=torch.float)
            }
        if schnet_data is None:
            schnet_data = {
                'z': torch.tensor([], dtype=torch.long),
                'pos': torch.tensor([], dtype=torch.float)
            }
        if fp_data is None:
            fp_data = {
                'input_ids': torch.zeros(FP_LENGTH, dtype=torch.long),
                'attention_mask': torch.zeros(FP_LENGTH, dtype=torch.bool)
            }
        if psmiles_data is None:
            psmiles_data = {
                'input_ids': torch.zeros(PSMILES_MAX_LEN, dtype=torch.long),
                'attention_mask': torch.zeros(PSMILES_MAX_LEN, dtype=torch.bool)
            }
        # Single-task regression target (already scaled)
        target_scaled = float(data.get("target_scaled", 0.0))
        return {
            'gine': gine_data,
            'schnet': schnet_data,
            'fp': fp_data,
            'psmiles': psmiles_data,
            'target': torch.tensor(target_scaled, dtype=torch.float32),
        }
# =============================================================================
# Collate: pack variable-sized graph/3D into batch tensors + modality masks
# =============================================================================
def multimodal_collate_fn(batch):
    """
    Collate samples into a single minibatch.
    - GINE: concatenate nodes across samples and build a `batch` vector.
    - SchNet: concatenate atoms/coords across samples and build a `batch` vector.
    - FP/PSMILES: stack to (B, L).
    - modality_mask: per-sample boolean flags indicating availability.
    Missing modalities arrive from the dataset as empty/zero tensors, so
    presence is inferred here (node count > 0, attention-mask sum > 0).
    """
    B = len(batch)
    # -------------------------
    # GINE packing
    # -------------------------
    all_z = []
    all_ch = []
    all_fc = []
    all_edge_index = []
    all_edge_attr = []
    batch_mapping = []
    node_offset = 0
    gine_present = []
    for i, item in enumerate(batch):
        g = item["gine"]
        z = g["z"]
        n = z.size(0)
        gine_present.append(bool(n > 0))
        # Empty graphs still append zero-length tensors, so all_z is only
        # empty when the batch itself is empty.
        all_z.append(z)
        all_ch.append(g["chirality"])
        all_fc.append(g["formal_charge"])
        batch_mapping.append(torch.full((n,), i, dtype=torch.long))
        if g["edge_index"] is not None and g["edge_index"].numel() > 0:
            # Shift node indices so edges address the packed node tensor.
            ei_offset = g["edge_index"] + node_offset
            all_edge_index.append(ei_offset)
            # Align edge_attr rows/width (3 features) with edge_index columns.
            ea = match_edge_attr_to_index(g["edge_index"], g["edge_attr"], target_dim=3)
            all_edge_attr.append(ea)
        node_offset += n
    if len(all_z) == 0:
        z_batch = torch.tensor([], dtype=torch.long)
        ch_batch = torch.tensor([], dtype=torch.float)
        fc_batch = torch.tensor([], dtype=torch.float)
        batch_batch = torch.tensor([], dtype=torch.long)
        edge_index_batched = torch.empty((2, 0), dtype=torch.long)
        edge_attr_batched = torch.zeros((0, 3), dtype=torch.float)
    else:
        z_batch = torch.cat(all_z, dim=0)
        ch_batch = torch.cat(all_ch, dim=0)
        fc_batch = torch.cat(all_fc, dim=0)
        batch_batch = torch.cat(batch_mapping, dim=0) if len(batch_mapping) > 0 else torch.tensor([], dtype=torch.long)
        if len(all_edge_index) > 0:
            edge_index_batched = torch.cat(all_edge_index, dim=1)
            edge_attr_batched = torch.cat(all_edge_attr, dim=0) if len(all_edge_attr) > 0 else torch.zeros((0, 3), dtype=torch.float)
        else:
            edge_index_batched = torch.empty((2, 0), dtype=torch.long)
            edge_attr_batched = torch.zeros((0, 3), dtype=torch.float)
    # -------------------------
    # SchNet packing
    # -------------------------
    all_sz = []
    all_pos = []
    schnet_batch = []
    schnet_present = [False] * B
    for i, item in enumerate(batch):
        s = item["schnet"]
        s_z = s["z"]
        s_pos = s["pos"]
        if s_z.numel() == 0:
            # Unlike GINE, absent geometries are skipped entirely.
            continue
        schnet_present[i] = True
        all_sz.append(s_z)
        all_pos.append(s_pos)
        schnet_batch.append(torch.full((s_z.size(0),), i, dtype=torch.long))
    if len(all_sz) == 0:
        s_z_batch = torch.tensor([], dtype=torch.long)
        s_pos_batch = torch.tensor([], dtype=torch.float)
        s_batch_batch = torch.tensor([], dtype=torch.long)
    else:
        s_z_batch = torch.cat(all_sz, dim=0)
        s_pos_batch = torch.cat(all_pos, dim=0)
        s_batch_batch = torch.cat(schnet_batch, dim=0)
    # -------------------------
    # FP stacking
    # -------------------------
    fp_ids = torch.stack([item["fp"]["input_ids"] for item in batch], dim=0)
    fp_attn = torch.stack([item["fp"]["attention_mask"] for item in batch], dim=0)
    # A missing fingerprint has an all-zero attention mask -> present=False.
    fp_present = (fp_attn.sum(dim=1) > 0).cpu().numpy().tolist()
    # -------------------------
    # PSMILES stacking
    # -------------------------
    p_ids = torch.stack([item["psmiles"]["input_ids"] for item in batch], dim=0)
    p_attn = torch.stack([item["psmiles"]["attention_mask"] for item in batch], dim=0)
    psmiles_present = (p_attn.sum(dim=1) > 0).cpu().numpy().tolist()
    # Target
    target = torch.stack([item["target"] for item in batch], dim=0)  # (B,)
    # Presence mask for fusion (per-sample modality availability)
    modality_mask = {
        "gine": torch.tensor(gine_present, dtype=torch.bool),
        "schnet": torch.tensor(schnet_present, dtype=torch.bool),
        "fp": torch.tensor(fp_present, dtype=torch.bool),
        "psmiles": torch.tensor(psmiles_present, dtype=torch.bool),
    }
    return {
        "gine": {
            "z": z_batch,
            "chirality": ch_batch,
            "formal_charge": fc_batch,
            "edge_index": edge_index_batched,
            "edge_attr": edge_attr_batched,
            "batch": batch_batch
        },
        "schnet": {
            "z": s_z_batch,
            "pos": s_pos_batch,
            "batch": s_batch_batch
        },
        "fp": {
            "input_ids": fp_ids,
            "attention_mask": fp_attn
        },
        "psmiles": {
            "input_ids": p_ids,
            "attention_mask": p_attn
        },
        "target": target,
        "modality_mask": modality_mask
    }
# =============================================================================
# Single-task regressor head
# =============================================================================
class PolyFPropertyRegressor(nn.Module):
    """
    Scalar regression head on top of the fused multimodal embedding.

    Wraps a PolyFusion backbone and maps its (B, d) fused embedding to one
    scaled target value per sample through a two-layer MLP.
    """
    def __init__(self, polyf_model: MultimodalContrastiveModel, emb_dim: int = POLYF_EMB_DIM, dropout: float = 0.1):
        super().__init__()
        self.polyf = polyf_model
        hidden_dim = emb_dim // 2
        self.head = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, batch_mods, modality_mask=None):
        """Fuse the modalities and predict one scalar per sample; returns shape (B,)."""
        fused = self.polyf(batch_mods, modality_mask=modality_mask)  # (B, d)
        return self.head(fused).squeeze(-1)  # (B,)
# =============================================================================
# Training / evaluation helpers
# =============================================================================
def compute_metrics(y_true, y_pred):
    """Return MSE / RMSE / MAE / R2 for predictions vs. ground truth (original units)."""
    squared_error = mean_squared_error(y_true, y_pred)
    return {
        "mse": float(squared_error),
        "rmse": float(math.sqrt(squared_error)),
        "mae": float(mean_absolute_error(y_true, y_pred)),
        "r2": float(r2_score(y_true, y_pred)),
    }
def train_one_epoch(model, dataloader, optimizer, device):
    """
    Run one supervised regression epoch with MSE loss.

    Moves every tensor in the nested batch dict onto `device` in place,
    performs a forward/backward/step per batch, and returns the
    sample-weighted mean training loss over the epoch.
    """
    model.train()
    loss_sum = 0.0
    n_seen = 0
    for batch in dataloader:
        # Transfer the nested batch dict to the target device. "target" is a
        # plain tensor; every other entry (modality inputs and the
        # modality_mask) is a dict of tensors.
        for key, value in batch.items():
            if key == "target":
                batch[key] = value.to(device)
            else:
                for inner_key, inner_val in value.items():
                    if isinstance(inner_val, torch.Tensor):
                        value[inner_key] = inner_val.to(device)
        targets = batch["target"]  # (B,)
        mask = batch.get("modality_mask", None)
        model_inputs = {k: v for k, v in batch.items() if k not in ("target", "modality_mask")}
        preds = model(model_inputs, modality_mask=mask)
        loss = F.mse_loss(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_size = int(targets.size(0))
        loss_sum += float(loss.item()) * batch_size
        n_seen += batch_size
    # max(1, ...) guards against an empty dataloader.
    return loss_sum / max(1, n_seen)
@torch.no_grad()
def evaluate(model, dataloader, device):
    """
    Score `model` over `dataloader` without gradient tracking.

    Returns a tuple (avg_loss, preds, trues) where preds/trues are 1-D numpy
    arrays in scaled-target space, or (None, None, None) when the loader
    yields no samples.
    """
    model.eval()
    pred_chunks = []
    true_chunks = []
    loss_sum = 0.0
    n_seen = 0
    for batch in dataloader:
        # Transfer the nested batch dict to the target device (see
        # train_one_epoch for the layout: "target" is a tensor, the rest are
        # dicts of tensors).
        for key, value in batch.items():
            if key == "target":
                batch[key] = value.to(device)
            else:
                for inner_key, inner_val in value.items():
                    if isinstance(inner_val, torch.Tensor):
                        value[inner_key] = inner_val.to(device)
        targets = batch["target"]
        mask = batch.get("modality_mask", None)
        model_inputs = {k: v for k, v in batch.items() if k not in ("target", "modality_mask")}
        preds = model(model_inputs, modality_mask=mask)
        loss = F.mse_loss(preds, targets)
        batch_size = int(targets.size(0))
        loss_sum += float(loss.item()) * batch_size
        n_seen += batch_size
        pred_chunks.append(preds.detach().cpu().numpy())
        true_chunks.append(targets.detach().cpu().numpy())
    if n_seen == 0:
        return None, None, None
    avg_loss = loss_sum / max(1, n_seen)
    return float(avg_loss), np.concatenate(pred_chunks, axis=0), np.concatenate(true_chunks, axis=0)
# =============================================================================
# Pretrained loading helpers
# =============================================================================
def _load_encoder_state(encoder, ckpt_path, ok_label, warn_label):
    """
    Best-effort, non-strict `load_state_dict` of one modality checkpoint.

    Factors out the exists-check / load / warn pattern that was repeated for
    the GINE, SchNet, and fingerprint encoders. A missing file is skipped
    silently; a load failure only warns, leaving the encoder with its fresh
    random initialization. `ok_label`/`warn_label` preserve the original,
    slightly different wording of the success and warning messages.
    """
    if not os.path.exists(ckpt_path):
        return
    try:
        encoder.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
        print(f"[LOAD] {ok_label}: {ckpt_path}")
    except Exception as e:
        print(f"[LOAD][WARN] Could not load {warn_label}: {e}")

def load_pretrained_multimodal(pretrained_path: str) -> MultimodalContrastiveModel:
    """
    Construct modality encoders and load any available pretrained weights:
      - modality-specific checkpoints (BEST_*_DIR)
      - full multimodal checkpoint from `pretrained_path/pytorch_model.bin`
    Returns a ready-to-fine-tune MultimodalContrastiveModel.

    All loads are best-effort: a missing or incompatible checkpoint warns
    and falls back to random initialization rather than raising.
    """
    # -------------------------
    # GINE encoder (graph modality)
    # -------------------------
    gine_encoder = GineEncoder(
        node_emb_dim=NODE_EMB_DIM,
        edge_emb_dim=EDGE_EMB_DIM,
        num_layers=NUM_GNN_LAYERS,
        max_atomic_z=MAX_ATOMIC_Z
    )
    _load_encoder_state(gine_encoder, os.path.join(BEST_GINE_DIR, "pytorch_model.bin"),
                        "GINE weights", "GINE weights")
    # -------------------------
    # SchNet encoder (3D geometry modality)
    # -------------------------
    schnet_encoder = NodeSchNetWrapper(
        hidden_channels=SCHNET_HIDDEN,
        num_interactions=SCHNET_NUM_INTERACTIONS,
        num_gaussians=SCHNET_NUM_GAUSSIANS,
        cutoff=SCHNET_CUTOFF,
        max_num_neighbors=SCHNET_MAX_NEIGHBORS
    )
    _load_encoder_state(schnet_encoder, os.path.join(BEST_SCHNET_DIR, "pytorch_model.bin"),
                        "SchNet weights", "SchNet weights")
    # -------------------------
    # Fingerprint encoder
    # -------------------------
    fp_encoder = FingerprintEncoder(
        vocab_size=VOCAB_SIZE_FP,
        hidden_dim=256,
        seq_len=FP_LENGTH,
        num_layers=4,
        nhead=8,
        dim_feedforward=1024,
        dropout=0.1
    )
    _load_encoder_state(fp_encoder, os.path.join(BEST_FP_DIR, "pytorch_model.bin"),
                        "FP encoder weights", "fingerprint weights")
    # -------------------------
    # PSMILES encoder (directory load, with random-init fallback)
    # -------------------------
    psmiles_encoder = None
    if os.path.isdir(BEST_PSMILES_DIR):
        try:
            psmiles_encoder = PSMILESDebertaEncoder(model_dir_or_name=BEST_PSMILES_DIR)
            print(f"[LOAD] PSMILES encoder: {BEST_PSMILES_DIR}")
        except Exception as e:
            print(f"[LOAD][WARN] Could not load PSMILES encoder from dir: {e}")
    # Fallback: initialize with vocab fallback
    if psmiles_encoder is None:
        try:
            psmiles_encoder = PSMILESDebertaEncoder(
                model_dir_or_name=None,
                vocab_fallback=int(getattr(tokenizer, "vocab_size", 300))
            )
            print("[LOAD] PSMILES encoder: initialized fallback (no pretrained dir).")
        except Exception as e:
            print(f"[LOAD][WARN] Could not initialize PSMILES encoder: {e}")
    # Build multimodal wrapper
    multimodal_model = MultimodalContrastiveModel(
        gine_encoder,
        schnet_encoder,
        fp_encoder,
        psmiles_encoder,
        emb_dim=POLYF_EMB_DIM,
        modalities=["gine", "schnet", "fp", "psmiles"]
    )
    # -------------------------
    # Optional: load full multimodal checkpoint
    # -------------------------
    ckpt_path = os.path.join(pretrained_path, "pytorch_model.bin")
    if os.path.isfile(ckpt_path):
        try:
            state = torch.load(ckpt_path, map_location="cpu")
            model_state = multimodal_model.state_dict()
            # Keep only keys that exist in the current model with matching
            # shapes; mismatched entries would otherwise corrupt the load.
            filtered_state = {
                k: v for k, v in state.items()
                if k in model_state and model_state[k].shape == v.shape
            }
            summarize_state_dict_load(state, model_state, filtered_state)
            missing, unexpected = multimodal_model.load_state_dict(filtered_state, strict=False)
            print(f"[LOAD] Multimodal checkpoint: {ckpt_path}")
            print(f"[LOAD] load_state_dict -> missing={len(missing)} unexpected={len(unexpected)}")
            if missing:
                print("[LOAD] Missing keys (sample):", missing[:50])
            if unexpected:
                print("[LOAD] Unexpected keys (sample):", unexpected[:50])
        except Exception as e:
            print(f"[LOAD][WARN] Failed to load multimodal pretrained weights: {e}")
    else:
        print(f"[LOAD] No multimodal checkpoint found at: {ckpt_path}")
    return multimodal_model
# =============================================================================
# Downstream: sample construction + CV training loop
# =============================================================================
def build_samples_for_property(df: pd.DataFrame, prop_col: str) -> List[dict]:
    """
    Construct training samples for a single property.

    A row is kept only if:
      - at least one modality column ('graph', 'geometry', 'fingerprints',
        'psmiles') holds a non-empty value, and
      - `prop_col` converts to a *finite* float.

    The raw target is stored under 'target_raw'; scaling happens per fold.

    Bug fix: the previous check only rejected float NaN *before* conversion,
    so string values like "nan" (which float() happily parses) and +/-inf
    values slipped through into the training targets. We now validate with
    math.isfinite after conversion, matching the documented contract.
    """
    modality_columns = ('graph', 'geometry', 'fingerprints', 'psmiles')
    samples = []
    for _, row in df.iterrows():
        # Require at least one modality present (non-empty, non-whitespace).
        has_modality = any(
            col in row and row[col] and str(row[col]).strip() != ""
            for col in modality_columns
        )
        if not has_modality:
            continue
        try:
            y = float(row.get(prop_col, np.nan))
        except (TypeError, ValueError):
            # Missing (None) or unparseable value -> drop the row.
            continue
        if not math.isfinite(y):
            # Rejects NaN and +/-inf targets.
            continue
        samples.append({
            'graph': row.get('graph', ''),
            'geometry': row.get('geometry', ''),
            'fingerprints': row.get('fingerprints', ''),
            'psmiles': row.get('psmiles', ''),
            'target_raw': y
        })
    return samples
def run_polyf_downstream(property_list: List[str], property_cols: List[str], df_raw: pd.DataFrame,
                         pretrained_path: str, output_file: str):
    """
    Downstream evaluation:
      For each property:
        - Build samples from PolyInfo
        - 5-fold CV:
            - Split into trainval/test (by KFold)
            - Split trainval into train/val
            - Fit StandardScaler on train targets
            - Fine-tune encoder+head end-to-end with early stopping by val loss
            - Evaluate on held-out test fold in original units
        - Save per-fold results and per-property mean±std
        - Save best fold checkpoint bundle (by test R2) for later reuse

    Parameters
    ----------
    property_list : human-readable property names (parallel to `property_cols`).
    property_cols : dataframe column names holding the raw target values.
    df_raw : raw PolyInfo dataframe.
    pretrained_path : directory holding the multimodal checkpoint (created if missing).
    output_file : JSON-lines results log; fold records are appended incrementally.

    Returns
    -------
    dict with per-property fold records and mean/std aggregates.
    """
    os.makedirs(pretrained_path, exist_ok=True)
    # Optional duplicate aggregation (noise reduction)
    modality_cols = ["graph", "geometry", "fingerprints", "psmiles"]
    df_proc = aggregate_polyinfo_duplicates(df_raw, modality_cols=modality_cols, property_cols=property_cols)
    all_results = {"per_property": {}, "mode": "POLYF_MATCHED_SINGLE_TASK"}
    for pname, pcol in zip(property_list, property_cols):
        samples = build_samples_for_property(df_proc, pcol)
        print(f"[DATA] {pname}: n_samples={len(samples)}")
        # Small-sample guards: warn under 200 samples, skip entirely under 50.
        if len(samples) < 200:
            print(f"[DATA][WARN] '{pname}' has <200 samples; results may be noisy.")
        if len(samples) < 50:
            print(f"[DATA][WARN] Skipping '{pname}' (insufficient samples).")
            continue
        run_metrics = []
        run_records = []
        # Track best-performing fold for this property (by test R2)
        best_overall_r2 = -1e18
        best_overall_payload = None
        idxs = np.arange(len(samples))
        # NUM_RUNS doubles as the number of CV folds; fixed random_state keeps
        # fold membership identical across properties.
        cv = KFold(n_splits=NUM_RUNS, shuffle=True, random_state=42)
        for run_idx, (trainval_idx, test_idx) in enumerate(cv.split(idxs)):
            seed = 42 + run_idx  # per-fold seed for split/model reproducibility
            set_seed(seed)
            print(f"\n--- [CV] {pname}: fold {run_idx+1}/{NUM_RUNS} | seed={seed} ---")
            # Deep copies so the per-fold scaling below never mutates the
            # shared `samples` list.
            trainval = [copy.deepcopy(samples[i]) for i in trainval_idx]
            test = [copy.deepcopy(samples[i]) for i in test_idx]
            # Split trainval into train/val
            tr_idx, va_idx = train_test_split(
                np.arange(len(trainval)),
                test_size=VAL_SIZE_WITHIN_TRAINVAL,
                random_state=seed,
                shuffle=True
            )
            train = [copy.deepcopy(trainval[i]) for i in tr_idx]
            val = [copy.deepcopy(trainval[i]) for i in va_idx]
            # Standardize target using training fold only (prevents leakage)
            sc = StandardScaler()
            sc.fit(np.array([s["target_raw"] for s in train]).reshape(-1, 1))
            def _apply_scale(lst):
                # Writes the scaled target in-place on each sample dict
                # (closes over this fold's fitted scaler `sc`).
                for s in lst:
                    s["target_scaled"] = float(sc.transform(np.array([[s["target_raw"]]])).ravel()[0])
            _apply_scale(train)
            _apply_scale(val)
            _apply_scale(test)
            ds_train = PolymerPropertyDataset(train, tokenizer, max_length=MAX_LEN)
            ds_val = PolymerPropertyDataset(val, tokenizer, max_length=MAX_LEN)
            ds_test = PolymerPropertyDataset(test, tokenizer, max_length=MAX_LEN)
            dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, collate_fn=multimodal_collate_fn)
            dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, collate_fn=multimodal_collate_fn)
            dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, collate_fn=multimodal_collate_fn)
            print(f"[SPLIT] train={len(ds_train)} val={len(ds_val)} test={len(ds_test)}")
            # Fresh base model per fold to avoid any cross-fold leakage
            polyf_base = load_pretrained_multimodal(pretrained_path)
            model = PolyFPropertyRegressor(polyf_base, emb_dim=POLYF_EMB_DIM, dropout=POLYF_DROPOUT).to(DEVICE)
            optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
            best_val = float("inf")
            best_state = None
            no_improve = 0
            # Train with early stopping on validation loss
            for epoch in range(1, NUM_EPOCHS + 1):
                tr_loss = train_one_epoch(model, dl_train, optimizer, DEVICE)
                va_loss, _, _ = evaluate(model, dl_val, DEVICE)
                # An empty val loader makes evaluate return None -> treat as +inf.
                va_loss = va_loss if va_loss is not None else float("inf")
                scheduler.step()
                print(f"[{pname}] fold {run_idx+1}/{NUM_RUNS} epoch {epoch:03d} | train={tr_loss:.6f} | val={va_loss:.6f}")
                # 1e-8 tolerance: float noise does not count as an improvement.
                if va_loss < best_val - 1e-8:
                    best_val = va_loss
                    no_improve = 0
                    # Snapshot weights on CPU so the best model can be restored later.
                    best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                else:
                    no_improve += 1
                    if no_improve >= PATIENCE:
                        print(f"[{pname}] fold {run_idx+1}: early stopping (patience={PATIENCE}) at epoch {epoch}.")
                        break
            if best_state is None:
                print(f"[{pname}][WARN] No best checkpoint captured for fold {run_idx+1}; skipping fold.")
                continue
            # Restore best state and evaluate on test fold
            model.load_state_dict({k: v.to(DEVICE) for k, v in best_state.items()}, strict=True)
            _, pred_scaled, true_scaled = evaluate(model, dl_test, DEVICE)
            if pred_scaled is None:
                print(f"[{pname}][WARN] Test evaluation returned empty predictions for fold {run_idx+1}.")
                continue
            # Convert from scaled space back to original units
            pred = sc.inverse_transform(pred_scaled.reshape(-1, 1)).ravel()
            true = sc.inverse_transform(true_scaled.reshape(-1, 1)).ravel()
            m = compute_metrics(true, pred)
            run_metrics.append(m)
            print(f"[{pname}] fold {run_idx+1} TEST | r2={m['r2']:.4f} mae={m['mae']:.4f} rmse={m['rmse']:.4f}")
            record = {
                "property": pname,
                "property_col": pcol,
                "run": run_idx + 1,
                "seed": seed,
                "n_train": len(ds_train),
                "n_val": len(ds_val),
                "n_test": len(ds_test),
                "best_val_loss": float(best_val),
                "test_metrics": m
            }
            run_records.append(record)
            # Append the fold record as one JSON line (crash-safe incremental log).
            with open(output_file, "a") as fh:
                fh.write(json.dumps(make_json_serializable(record)) + "\n")
            # Update best fold bundle (by test R2)
            if float(m.get("r2", -1e18)) > float(best_overall_r2):
                best_overall_r2 = float(m.get("r2", -1e18))
                best_overall_payload = {
                    "property": pname,
                    "property_col": pcol,
                    "best_run": int(run_idx + 1),
                    "seed": int(seed),
                    "n_train": int(len(ds_train)),
                    "n_val": int(len(ds_val)),
                    "n_test": int(len(ds_test)),
                    "best_val_loss": float(best_val),
                    "test_metrics": make_json_serializable(m),
                    # Scaler statistics so later reuse can un-scale predictions.
                    "scaler_mean": make_json_serializable(getattr(sc, "mean_", None)),
                    "scaler_scale": make_json_serializable(getattr(sc, "scale_", None)),
                    "scaler_var": make_json_serializable(getattr(sc, "var_", None)),
                    "scaler_n_samples_seen": make_json_serializable(getattr(sc, "n_samples_seen_", None)),
                    "model_state_dict": best_state,  # CPU tensors
                }
        # Save best fold weights + metadata per property
        if best_overall_payload is not None and "model_state_dict" in best_overall_payload:
            os.makedirs(BEST_WEIGHTS_DIR, exist_ok=True)
            prop_dir = os.path.join(BEST_WEIGHTS_DIR, _sanitize_name(pname))
            os.makedirs(prop_dir, exist_ok=True)
            ckpt_bundle = {k: v for k, v in best_overall_payload.items() if k != "test_metrics"}
            ckpt_bundle["test_metrics"] = best_overall_payload["test_metrics"]
            torch.save(ckpt_bundle, os.path.join(prop_dir, "best_run_checkpoint.pt"))
            # Metadata JSON omits the (non-JSON-serializable) state dict.
            meta = {k: v for k, v in best_overall_payload.items() if k != "model_state_dict"}
            with open(os.path.join(prop_dir, "best_run_metadata.json"), "w") as fh:
                fh.write(json.dumps(make_json_serializable(meta), indent=2))
            print(f"[BEST] Saved best fold for '{pname}' -> {prop_dir}")
            print(f"[BEST] best_run={best_overall_payload['best_run']} best_test_r2={best_overall_payload['test_metrics'].get('r2', None)}")
        # Aggregate metrics across folds
        if run_metrics:
            r2s = [x["r2"] for x in run_metrics]
            maes = [x["mae"] for x in run_metrics]
            rmses = [x["rmse"] for x in run_metrics]
            mses = [x["mse"] for x in run_metrics]
            # Population std (ddof=0) over the successful folds.
            agg = {
                "r2": {"mean": float(np.mean(r2s)), "std": float(np.std(r2s, ddof=0))},
                "mae": {"mean": float(np.mean(maes)), "std": float(np.std(maes, ddof=0))},
                "rmse": {"mean": float(np.mean(rmses)), "std": float(np.std(rmses, ddof=0))},
                "mse": {"mean": float(np.mean(mses)), "std": float(np.std(mses, ddof=0))},
            }
            print(f"[AGG] {pname} | r2={agg['r2']['mean']:.4f}±{agg['r2']['std']:.4f} mae={agg['mae']['mean']:.4f}±{agg['mae']['std']:.4f}")
        else:
            agg = None
            print(f"[AGG][WARN] No successful folds for '{pname}' (no aggregate computed).")
        all_results["per_property"][pname] = {
            "property_col": pcol,
            "n_samples": len(samples),
            "runs": run_records,
            "agg": agg
        }
        with open(output_file, "a") as fh:
            fh.write("AGG_PROPERTY: " + json.dumps(make_json_serializable({pname: agg})) + "\n")
    return all_results
# =============================================================================
# Main
# =============================================================================
def main():
    """Entry point: prepare the results file, load PolyInfo, run the CV study, write summary."""
    # Start a fresh results file (back up old results if present)
    if os.path.exists(OUTPUT_RESULTS):
        backup_path = OUTPUT_RESULTS + ".bak"
        shutil.copy(OUTPUT_RESULTS, backup_path)
        print(f"[INIT] Existing results backed up: {backup_path}")
    open(OUTPUT_RESULTS, "w").close()  # truncate / create the log
    print(f"[INIT] Writing results to: {OUTPUT_RESULTS}")
    # Load PolyInfo
    if not os.path.isfile(POLYINFO_PATH):
        raise FileNotFoundError(f"PolyInfo file not found at {POLYINFO_PATH}")
    polyinfo_df = pd.read_csv(POLYINFO_PATH, engine="python")
    print(f"[DATA] Loaded PolyInfo: n_rows={len(polyinfo_df)} n_cols={len(polyinfo_df.columns)}")
    # Map requested properties to dataframe columns
    prop_map = dict(find_property_columns(polyinfo_df.columns))
    print(f"[COLMAP] Property-to-column map: {prop_map}")
    selected_names = []
    selected_cols = []
    for requested in REQUESTED_PROPERTIES:
        mapped_col = prop_map.get(requested)
        if mapped_col is None:
            print(f"[COLMAP][WARN] Could not find a column for '{requested}'; skipping.")
        else:
            selected_names.append(requested)
            selected_cols.append(mapped_col)
    overall = run_polyf_downstream(selected_names, selected_cols, polyinfo_df,
                                   PRETRAINED_MULTIMODAL_DIR, OUTPUT_RESULTS)
    # Write final summary (aggregated per property)
    final_agg = {}
    if overall and "per_property" in overall:
        final_agg = {name: info.get("agg", None) for name, info in overall["per_property"].items()}
    with open(OUTPUT_RESULTS, "a") as fh:
        fh.write("\nFINAL_SUMMARY\n")
        fh.write(json.dumps(make_json_serializable(final_agg), indent=2))
        fh.write("\n")
    print(f"\n Results appended to: {OUTPUT_RESULTS}")
    print(f" Best checkpoints saved under: {BEST_WEIGHTS_DIR}")

if __name__ == "__main__":
    main()