Spaces:
Running
Running
manpreet88
commited on
Commit
·
a5954f7
1
Parent(s):
231c5c8
Update Property_Prediction.py
Browse files- Downstream Tasks/Property_Prediction.py +331 -181
Downstream Tasks/Property_Prediction.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
|
| 2 |
import os
|
| 3 |
import random
|
| 4 |
import time
|
|
@@ -20,66 +19,77 @@ import csv
|
|
| 20 |
import copy
|
| 21 |
from typing import List, Dict, Optional, Tuple, Any
|
| 22 |
|
|
|
|
| 23 |
csv.field_size_limit(sys.maxsize)
|
| 24 |
|
| 25 |
-
#
|
|
|
|
|
|
|
| 26 |
from PolyFusion.GINE import GineEncoder, match_edge_attr_to_index, safe_get
|
| 27 |
from PolyFusion.SchNet import NodeSchNetWrapper
|
| 28 |
from PolyFusion.Transformer import PooledFingerprintEncoder as FingerprintEncoder
|
| 29 |
from PolyFusion.DeBERTav2 import PSMILESDebertaEncoder, build_psmiles_tokenizer
|
| 30 |
|
| 31 |
-
#
|
| 32 |
# Configuration
|
| 33 |
-
#
|
| 34 |
BASE_DIR = "/path/to/Polymer_Foundational_Model"
|
| 35 |
POLYINFO_PATH = "/path/to/polyinfo_with_modalities.csv"
|
| 36 |
|
| 37 |
-
# Pretrained
|
| 38 |
PRETRAINED_MULTIMODAL_DIR = "/path/to/multimodal_output/best"
|
| 39 |
BEST_GINE_DIR = "/path/to/gin_output/best"
|
| 40 |
BEST_SCHNET_DIR = "/path/to/schnet_output/best"
|
| 41 |
BEST_FP_DIR = "/path/to/fingerprint_mlm_output/best"
|
| 42 |
BEST_PSMILES_DIR = "/path/to/polybert_output/best"
|
| 43 |
|
|
|
|
| 44 |
OUTPUT_RESULTS = "/path/to/multimodal_downstream_results.txt"
|
| 45 |
|
| 46 |
-
#
|
| 47 |
BEST_WEIGHTS_DIR = "/path/to/multimodal_downstream_bestweights"
|
| 48 |
|
| 49 |
-
#
|
|
|
|
|
|
|
| 50 |
MAX_ATOMIC_Z = 85
|
| 51 |
MASK_ATOM_ID = MAX_ATOMIC_Z + 1
|
| 52 |
|
| 53 |
-
# GINE
|
| 54 |
NODE_EMB_DIM = 300
|
| 55 |
EDGE_EMB_DIM = 300
|
| 56 |
NUM_GNN_LAYERS = 5
|
| 57 |
|
| 58 |
-
# SchNet
|
| 59 |
SCHNET_NUM_GAUSSIANS = 50
|
| 60 |
SCHNET_NUM_INTERACTIONS = 6
|
| 61 |
SCHNET_CUTOFF = 10.0
|
| 62 |
SCHNET_MAX_NEIGHBORS = 64
|
| 63 |
SCHNET_HIDDEN = 600
|
| 64 |
|
| 65 |
-
#
|
| 66 |
FP_LENGTH = 2048
|
| 67 |
MASK_TOKEN_ID_FP = 2
|
| 68 |
VOCAB_SIZE_FP = 3
|
| 69 |
|
|
|
|
| 70 |
CL_EMB_DIM = 600
|
| 71 |
|
| 72 |
-
# PSMILES/
|
| 73 |
DEBERTA_HIDDEN = 600
|
| 74 |
PSMILES_MAX_LEN = 128
|
| 75 |
|
| 76 |
-
#
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
MAX_LEN = 128
|
| 84 |
BATCH_SIZE = 32
|
| 85 |
NUM_EPOCHS = 100
|
|
@@ -89,6 +99,7 @@ WEIGHT_DECAY = 0.0
|
|
| 89 |
|
| 90 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 91 |
|
|
|
|
| 92 |
REQUESTED_PROPERTIES = [
|
| 93 |
"density",
|
| 94 |
"glass transition",
|
|
@@ -97,28 +108,31 @@ REQUESTED_PROPERTIES = [
|
|
| 97 |
"thermal decomposition"
|
| 98 |
]
|
| 99 |
|
| 100 |
-
#
|
| 101 |
NUM_RUNS = 5
|
| 102 |
TEST_SIZE = 0.10
|
| 103 |
-
VAL_SIZE_WITHIN_TRAINVAL = 0.10 # of trainval
|
| 104 |
|
| 105 |
-
#
|
| 106 |
AGG_KEYS_PREFERENCE = ["polymer_id", "PolymerID", "poly_id", "psmiles", "smiles", "canonical_smiles"]
|
| 107 |
|
| 108 |
-
#
|
| 109 |
# Utilities
|
| 110 |
-
#
|
| 111 |
def set_seed(seed: int):
|
|
|
|
| 112 |
random.seed(seed)
|
| 113 |
np.random.seed(seed)
|
| 114 |
torch.manual_seed(seed)
|
| 115 |
if torch.cuda.is_available():
|
| 116 |
torch.cuda.manual_seed_all(seed)
|
|
|
|
| 117 |
torch.backends.cudnn.deterministic = True
|
| 118 |
torch.backends.cudnn.benchmark = False
|
| 119 |
|
| 120 |
|
| 121 |
def make_json_serializable(obj):
|
|
|
|
| 122 |
if isinstance(obj, dict):
|
| 123 |
return {make_json_serializable(k): make_json_serializable(v) for k, v in obj.items()}
|
| 124 |
if isinstance(obj, (list, tuple, set)):
|
|
@@ -144,7 +158,15 @@ def make_json_serializable(obj):
|
|
| 144 |
pass
|
| 145 |
return obj
|
| 146 |
|
|
|
|
| 147 |
def summarize_state_dict_load(full_state: dict, model_state: dict, filtered_state: dict):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
n_ckpt = len(full_state)
|
| 149 |
n_model = len(model_state)
|
| 150 |
n_loaded = len(filtered_state)
|
|
@@ -160,13 +182,13 @@ def summarize_state_dict_load(full_state: dict, model_state: dict, filtered_stat
|
|
| 160 |
print(f" ckpt keys: {n_ckpt}")
|
| 161 |
print(f" model keys: {n_model}")
|
| 162 |
print(f" loaded keys: {n_loaded}")
|
| 163 |
-
print(f" skipped (not in model):
|
| 164 |
-
print(f" skipped (shape mismatch):
|
| 165 |
|
| 166 |
if missing_in_model:
|
| 167 |
-
print(" examples
|
| 168 |
if shape_mismatch:
|
| 169 |
-
print(" examples
|
| 170 |
for k in shape_mismatch[:10]:
|
| 171 |
print(f" {k}: ckpt={tuple(full_state[k].shape)} model={tuple(model_state[k].shape)}")
|
| 172 |
print("")
|
|
@@ -174,10 +196,10 @@ def summarize_state_dict_load(full_state: dict, model_state: dict, filtered_stat
|
|
| 174 |
|
| 175 |
def find_property_columns(columns):
|
| 176 |
"""
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
"""
|
| 182 |
lowered = {c.lower(): c for c in columns}
|
| 183 |
found = {}
|
|
@@ -186,6 +208,7 @@ def find_property_columns(columns):
|
|
| 186 |
req_low = req.lower().strip()
|
| 187 |
exact = None
|
| 188 |
|
|
|
|
| 189 |
for c_low, c_orig in lowered.items():
|
| 190 |
tokens = set(c_low.replace('_', ' ').split())
|
| 191 |
if req_low in tokens or c_low == req_low:
|
|
@@ -198,6 +221,7 @@ def find_property_columns(columns):
|
|
| 198 |
found[req] = exact
|
| 199 |
continue
|
| 200 |
|
|
|
|
| 201 |
candidates = [c_orig for c_low, c_orig in lowered.items() if req_low in c_low]
|
| 202 |
if req_low == "density":
|
| 203 |
candidates = [c for c in candidates if "cohesive" not in c.lower() and "cohesive energy" not in c.lower()]
|
|
@@ -207,14 +231,16 @@ def find_property_columns(columns):
|
|
| 207 |
else:
|
| 208 |
chosen = candidates[0] if candidates else None
|
| 209 |
found[req] = chosen
|
| 210 |
-
print(f"[
|
| 211 |
if candidates:
|
| 212 |
-
print(f"[
|
| 213 |
else:
|
| 214 |
-
print(f"[WARN] No candidates found for '{req}' using substring search.")
|
| 215 |
return found
|
| 216 |
|
|
|
|
| 217 |
def choose_aggregation_key(df: pd.DataFrame) -> Optional[str]:
|
|
|
|
| 218 |
for k in AGG_KEYS_PREFERENCE:
|
| 219 |
if k in df.columns:
|
| 220 |
return k
|
|
@@ -223,38 +249,38 @@ def choose_aggregation_key(df: pd.DataFrame) -> Optional[str]:
|
|
| 223 |
|
| 224 |
def aggregate_polyinfo_duplicates(df: pd.DataFrame, modality_cols: List[str], property_cols: List[str]) -> pd.DataFrame:
|
| 225 |
"""
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
| 228 |
"""
|
| 229 |
key = choose_aggregation_key(df)
|
| 230 |
if key is None:
|
| 231 |
-
print("[
|
| 232 |
return df
|
| 233 |
|
| 234 |
-
# Keep only rows with a non-empty key
|
| 235 |
df2 = df.copy()
|
| 236 |
df2[key] = df2[key].astype(str)
|
| 237 |
df2 = df2[df2[key].str.strip() != ""].copy()
|
| 238 |
if len(df2) == 0:
|
| 239 |
-
print("[
|
| 240 |
return df
|
| 241 |
|
| 242 |
agg_dict = {}
|
| 243 |
-
# modalities: take first non-null string (they should be identical per polymer id in your dataset)
|
| 244 |
for mc in modality_cols:
|
| 245 |
if mc in df2.columns:
|
| 246 |
agg_dict[mc] = "first"
|
| 247 |
-
# properties: mean
|
| 248 |
for pc in property_cols:
|
| 249 |
if pc in df2.columns:
|
| 250 |
agg_dict[pc] = "mean"
|
| 251 |
|
| 252 |
grouped = df2.groupby(key, as_index=False).agg(agg_dict)
|
| 253 |
-
print(f"[
|
| 254 |
return grouped
|
| 255 |
|
| 256 |
|
| 257 |
def _sanitize_name(s: str) -> str:
|
|
|
|
| 258 |
s2 = str(s).strip().lower()
|
| 259 |
keep = []
|
| 260 |
for ch in s2:
|
|
@@ -270,20 +296,29 @@ def _sanitize_name(s: str) -> str:
|
|
| 270 |
out = out.strip("_")
|
| 271 |
return out or "property"
|
| 272 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
class MultimodalContrastiveModel(nn.Module):
|
| 274 |
"""
|
| 275 |
-
Multimodal encoder wrapper
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
"""
|
| 288 |
|
| 289 |
def __init__(
|
|
@@ -310,7 +345,7 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 310 |
self.out_dim = self.emb_dim
|
| 311 |
self.dropout = nn.Dropout(float(dropout))
|
| 312 |
|
| 313 |
-
#
|
| 314 |
if modalities is None:
|
| 315 |
mods = []
|
| 316 |
if self.gine is not None:
|
|
@@ -325,12 +360,12 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 325 |
else:
|
| 326 |
self.modalities = [m for m in modalities if m in ("gine", "schnet", "fp", "psmiles")]
|
| 327 |
|
| 328 |
-
# Projection heads
|
| 329 |
self.proj_gine = nn.Linear(NODE_EMB_DIM, self.emb_dim) if self.gine is not None else None
|
| 330 |
self.proj_schnet = nn.Linear(SCHNET_HIDDEN, self.emb_dim) if self.schnet is not None else None
|
| 331 |
self.proj_fp = nn.Linear(256, self.emb_dim) if self.fp is not None else None
|
| 332 |
|
| 333 |
-
# PSMILES
|
| 334 |
psm_in = None
|
| 335 |
if self.psmiles is not None:
|
| 336 |
if hasattr(self.psmiles, "out_dim"):
|
|
@@ -349,7 +384,7 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 349 |
self.proj_psmiles = nn.Linear(psm_in, self.emb_dim) if (self.psmiles is not None) else None
|
| 350 |
|
| 351 |
def freeze_cl_encoders(self):
|
| 352 |
-
"""
|
| 353 |
for enc in (self.gine, self.schnet, self.fp, self.psmiles):
|
| 354 |
if enc is None:
|
| 355 |
continue
|
|
@@ -359,9 +394,11 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 359 |
|
| 360 |
def _masked_mean_combine(self, zs: List[torch.Tensor], masks: List[torch.Tensor]) -> torch.Tensor:
|
| 361 |
"""
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
|
|
|
|
|
|
| 365 |
"""
|
| 366 |
if not zs:
|
| 367 |
device = next(self.parameters()).device
|
|
@@ -384,22 +421,21 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 384 |
def forward(self, batch_mods: dict, modality_mask: Optional[dict] = None) -> torch.Tensor:
|
| 385 |
"""
|
| 386 |
batch_mods keys: 'gine', 'schnet', 'fp', 'psmiles'
|
| 387 |
-
modality_mask: dict
|
| 388 |
"""
|
| 389 |
device = next(self.parameters()).device
|
| 390 |
|
| 391 |
zs = []
|
| 392 |
ms = []
|
| 393 |
|
| 394 |
-
#
|
| 395 |
B = None
|
| 396 |
if modality_mask is not None:
|
| 397 |
-
for
|
| 398 |
if isinstance(v, torch.Tensor) and v.numel() > 0:
|
| 399 |
B = int(v.size(0))
|
| 400 |
break
|
| 401 |
|
| 402 |
-
# Fallback: infer B from fp/psmiles tensors
|
| 403 |
if B is None:
|
| 404 |
if "fp" in batch_mods and batch_mods["fp"] is not None and isinstance(batch_mods["fp"].get("input_ids", None), torch.Tensor):
|
| 405 |
B = int(batch_mods["fp"]["input_ids"].size(0))
|
|
@@ -409,13 +445,14 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 409 |
if B is None:
|
| 410 |
return torch.zeros((1, self.emb_dim), device=device)
|
| 411 |
|
| 412 |
-
# Helper to get modality present mask
|
| 413 |
def _get_mask(name: str) -> torch.Tensor:
|
| 414 |
if modality_mask is not None and name in modality_mask and isinstance(modality_mask[name], torch.Tensor):
|
| 415 |
return modality_mask[name].to(device).bool()
|
| 416 |
return torch.ones((B,), device=device, dtype=torch.bool)
|
| 417 |
|
| 418 |
-
#
|
|
|
|
|
|
|
| 419 |
if "gine" in self.modalities and self.gine is not None and batch_mods.get("gine", None) is not None:
|
| 420 |
g = batch_mods["gine"]
|
| 421 |
if isinstance(g.get("z", None), torch.Tensor) and g["z"].numel() > 0:
|
|
@@ -433,7 +470,9 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 433 |
zs.append(z)
|
| 434 |
ms.append(_get_mask("gine"))
|
| 435 |
|
| 436 |
-
#
|
|
|
|
|
|
|
| 437 |
if "schnet" in self.modalities and self.schnet is not None and batch_mods.get("schnet", None) is not None:
|
| 438 |
s = batch_mods["schnet"]
|
| 439 |
if isinstance(s.get("z", None), torch.Tensor) and s["z"].numel() > 0:
|
|
@@ -448,7 +487,9 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 448 |
zs.append(z)
|
| 449 |
ms.append(_get_mask("schnet"))
|
| 450 |
|
| 451 |
-
#
|
|
|
|
|
|
|
| 452 |
if "fp" in self.modalities and self.fp is not None and batch_mods.get("fp", None) is not None:
|
| 453 |
f = batch_mods["fp"]
|
| 454 |
if isinstance(f.get("input_ids", None), torch.Tensor) and f["input_ids"].numel() > 0:
|
|
@@ -462,7 +503,9 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 462 |
zs.append(z)
|
| 463 |
ms.append(_get_mask("fp"))
|
| 464 |
|
| 465 |
-
#
|
|
|
|
|
|
|
| 466 |
if "psmiles" in self.modalities and self.psmiles is not None and batch_mods.get("psmiles", None) is not None:
|
| 467 |
p = batch_mods["psmiles"]
|
| 468 |
if isinstance(p.get("input_ids", None), torch.Tensor) and p["input_ids"].numel() > 0:
|
|
@@ -476,7 +519,7 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 476 |
zs.append(z)
|
| 477 |
ms.append(_get_mask("psmiles"))
|
| 478 |
|
| 479 |
-
#
|
| 480 |
if not zs:
|
| 481 |
return torch.zeros((B, self.emb_dim), device=device)
|
| 482 |
|
|
@@ -484,7 +527,6 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 484 |
z = F.normalize(z, dim=-1)
|
| 485 |
return z
|
| 486 |
|
| 487 |
-
|
| 488 |
@torch.no_grad()
|
| 489 |
def encode_psmiles(
|
| 490 |
self,
|
|
@@ -494,7 +536,7 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 494 |
device: str = DEVICE
|
| 495 |
) -> np.ndarray:
|
| 496 |
"""
|
| 497 |
-
PSMILES-only embeddings (used for
|
| 498 |
"""
|
| 499 |
self.eval()
|
| 500 |
if self.psm_tok is None or self.psmiles is None or self.proj_psmiles is None:
|
|
@@ -519,9 +561,9 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 519 |
device: str = DEVICE
|
| 520 |
) -> np.ndarray:
|
| 521 |
"""
|
| 522 |
-
|
| 523 |
-
graph, geometry, fingerprints, psmiles
|
| 524 |
-
Missing modalities are
|
| 525 |
"""
|
| 526 |
self.eval()
|
| 527 |
dev = torch.device(device)
|
|
@@ -531,14 +573,13 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 531 |
for i in range(0, len(records), batch_size):
|
| 532 |
chunk = records[i:i + batch_size]
|
| 533 |
|
| 534 |
-
# Build per-sample modality tensors; do minimal batching (stack where possible)
|
| 535 |
# PSMILES batch
|
| 536 |
psmiles_texts = [str(r.get("psmiles", "")) for r in chunk]
|
| 537 |
p_enc = None
|
| 538 |
if self.psm_tok is not None:
|
| 539 |
p_enc = self.psm_tok(psmiles_texts, truncation=True, padding="max_length", max_length=PSMILES_MAX_LEN, return_tensors="pt")
|
| 540 |
|
| 541 |
-
# FP batch
|
| 542 |
fp_ids, fp_attn = [], []
|
| 543 |
for r in chunk:
|
| 544 |
f = _parse_fingerprints(r.get("fingerprints", None), fp_len=FP_LENGTH)
|
|
@@ -547,8 +588,7 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 547 |
fp_ids = torch.stack(fp_ids, dim=0)
|
| 548 |
fp_attn = torch.stack(fp_attn, dim=0)
|
| 549 |
|
| 550 |
-
# GINE
|
| 551 |
-
# (mean pooling per sample already handled in encoder via batch vectors)
|
| 552 |
gine_all = {"z": [], "chirality": [], "formal_charge": [], "edge_index": [], "edge_attr": [], "batch": []}
|
| 553 |
node_offset = 0
|
| 554 |
for bi, r in enumerate(chunk):
|
|
@@ -606,21 +646,32 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 606 |
"psmiles": {"input_ids": p_enc["input_ids"], "attention_mask": p_enc["attention_mask"]} if p_enc is not None else None
|
| 607 |
}
|
| 608 |
|
| 609 |
-
|
|
|
|
| 610 |
outs.append(z.detach().cpu().numpy())
|
| 611 |
|
| 612 |
return np.concatenate(outs, axis=0) if outs else np.zeros((0, self.emb_dim), dtype=np.float32)
|
| 613 |
|
| 614 |
-
|
|
|
|
| 615 |
# Tokenizer setup
|
| 616 |
-
#
|
| 617 |
SPM_MODEL = "/path/to/spm.model"
|
| 618 |
tokenizer = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)
|
| 619 |
|
| 620 |
-
#
|
| 621 |
-
# Dataset
|
| 622 |
-
#
|
| 623 |
class PolymerPropertyDataset(Dataset):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 624 |
def __init__(self, data_list, tokenizer, max_length=128):
|
| 625 |
self.data_list = data_list
|
| 626 |
self.tokenizer = tokenizer
|
|
@@ -632,7 +683,9 @@ class PolymerPropertyDataset(Dataset):
|
|
| 632 |
def __getitem__(self, idx):
|
| 633 |
data = self.data_list[idx]
|
| 634 |
|
| 635 |
-
#
|
|
|
|
|
|
|
| 636 |
gine_data = None
|
| 637 |
if 'graph' in data and data['graph']:
|
| 638 |
try:
|
|
@@ -661,6 +714,7 @@ class PolymerPropertyDataset(Dataset):
|
|
| 661 |
edge_index = None
|
| 662 |
edge_attr = None
|
| 663 |
|
|
|
|
| 664 |
if edge_indices_raw is None:
|
| 665 |
adj_mat = safe_get(graph_field, "adjacency_matrix", None)
|
| 666 |
if adj_mat:
|
|
@@ -676,6 +730,7 @@ class PolymerPropertyDataset(Dataset):
|
|
| 676 |
E = len(srcs)
|
| 677 |
edge_attr = [[0.0, 0.0, 0.0] for _ in range(E)]
|
| 678 |
else:
|
|
|
|
| 679 |
srcs, dsts = [], []
|
| 680 |
if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0:
|
| 681 |
if isinstance(edge_indices_raw[0], list):
|
|
@@ -702,6 +757,7 @@ class PolymerPropertyDataset(Dataset):
|
|
| 702 |
if len(srcs) > 0:
|
| 703 |
edge_index = [srcs, dsts]
|
| 704 |
|
|
|
|
| 705 |
if edge_features_raw and isinstance(edge_features_raw, list):
|
| 706 |
bond_types = []
|
| 707 |
stereos = []
|
|
@@ -729,7 +785,9 @@ class PolymerPropertyDataset(Dataset):
|
|
| 729 |
except Exception:
|
| 730 |
gine_data = None
|
| 731 |
|
| 732 |
-
#
|
|
|
|
|
|
|
| 733 |
schnet_data = None
|
| 734 |
if 'geometry' in data and data['geometry']:
|
| 735 |
try:
|
|
@@ -746,7 +804,9 @@ class PolymerPropertyDataset(Dataset):
|
|
| 746 |
except Exception:
|
| 747 |
schnet_data = None
|
| 748 |
|
| 749 |
-
#
|
|
|
|
|
|
|
| 750 |
fp_data = None
|
| 751 |
if 'fingerprints' in data and data['fingerprints']:
|
| 752 |
try:
|
|
@@ -797,7 +857,9 @@ class PolymerPropertyDataset(Dataset):
|
|
| 797 |
except Exception:
|
| 798 |
fp_data = None
|
| 799 |
|
| 800 |
-
#
|
|
|
|
|
|
|
| 801 |
psmiles_data = None
|
| 802 |
if 'psmiles' in data and data['psmiles'] and self.tokenizer is not None:
|
| 803 |
try:
|
|
@@ -815,7 +877,9 @@ class PolymerPropertyDataset(Dataset):
|
|
| 815 |
except Exception:
|
| 816 |
psmiles_data = None
|
| 817 |
|
| 818 |
-
#
|
|
|
|
|
|
|
| 819 |
if gine_data is None:
|
| 820 |
gine_data = {
|
| 821 |
'z': torch.tensor([], dtype=torch.long),
|
|
@@ -843,7 +907,7 @@ class PolymerPropertyDataset(Dataset):
|
|
| 843 |
'attention_mask': torch.zeros(PSMILES_MAX_LEN, dtype=torch.bool)
|
| 844 |
}
|
| 845 |
|
| 846 |
-
# Single-task target (scaled)
|
| 847 |
target_scaled = float(data.get("target_scaled", 0.0))
|
| 848 |
|
| 849 |
return {
|
|
@@ -855,10 +919,23 @@ class PolymerPropertyDataset(Dataset):
|
|
| 855 |
}
|
| 856 |
|
| 857 |
|
|
|
|
|
|
|
|
|
|
| 858 |
def multimodal_collate_fn(batch):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 859 |
B = len(batch)
|
| 860 |
|
| 861 |
-
#
|
|
|
|
|
|
|
| 862 |
all_z = []
|
| 863 |
all_ch = []
|
| 864 |
all_fc = []
|
|
@@ -905,7 +982,9 @@ def multimodal_collate_fn(batch):
|
|
| 905 |
edge_index_batched = torch.empty((2, 0), dtype=torch.long)
|
| 906 |
edge_attr_batched = torch.zeros((0, 3), dtype=torch.float)
|
| 907 |
|
| 908 |
-
#
|
|
|
|
|
|
|
| 909 |
all_sz = []
|
| 910 |
all_pos = []
|
| 911 |
schnet_batch = []
|
|
@@ -930,12 +1009,16 @@ def multimodal_collate_fn(batch):
|
|
| 930 |
s_pos_batch = torch.cat(all_pos, dim=0)
|
| 931 |
s_batch_batch = torch.cat(schnet_batch, dim=0)
|
| 932 |
|
| 933 |
-
#
|
|
|
|
|
|
|
| 934 |
fp_ids = torch.stack([item["fp"]["input_ids"] for item in batch], dim=0)
|
| 935 |
fp_attn = torch.stack([item["fp"]["attention_mask"] for item in batch], dim=0)
|
| 936 |
fp_present = (fp_attn.sum(dim=1) > 0).cpu().numpy().tolist()
|
| 937 |
|
| 938 |
-
#
|
|
|
|
|
|
|
| 939 |
p_ids = torch.stack([item["psmiles"]["input_ids"] for item in batch], dim=0)
|
| 940 |
p_attn = torch.stack([item["psmiles"]["attention_mask"] for item in batch], dim=0)
|
| 941 |
psmiles_present = (p_attn.sum(dim=1) > 0).cpu().numpy().tolist()
|
|
@@ -943,6 +1026,7 @@ def multimodal_collate_fn(batch):
|
|
| 943 |
# Target
|
| 944 |
target = torch.stack([item["target"] for item in batch], dim=0) # (B,)
|
| 945 |
|
|
|
|
| 946 |
modality_mask = {
|
| 947 |
"gine": torch.tensor(gine_present, dtype=torch.bool),
|
| 948 |
"schnet": torch.tensor(schnet_present, dtype=torch.bool),
|
|
@@ -977,13 +1061,17 @@ def multimodal_collate_fn(batch):
|
|
| 977 |
}
|
| 978 |
|
| 979 |
|
| 980 |
-
#
|
| 981 |
-
#
|
| 982 |
-
#
|
| 983 |
-
class
|
| 984 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 985 |
super().__init__()
|
| 986 |
-
self.
|
| 987 |
self.head = nn.Sequential(
|
| 988 |
nn.Linear(emb_dim, emb_dim // 2),
|
| 989 |
nn.ReLU(),
|
|
@@ -992,15 +1080,16 @@ class UniPolyPropertyRegressor(nn.Module):
|
|
| 992 |
)
|
| 993 |
|
| 994 |
def forward(self, batch_mods, modality_mask=None):
|
| 995 |
-
emb = self.
|
| 996 |
y = self.head(emb).squeeze(-1) # (B,)
|
| 997 |
return y
|
| 998 |
|
| 999 |
|
| 1000 |
-
#
|
| 1001 |
-
# Training / evaluation
|
| 1002 |
-
#
|
| 1003 |
def compute_metrics(y_true, y_pred):
|
|
|
|
| 1004 |
mse = mean_squared_error(y_true, y_pred)
|
| 1005 |
rmse = math.sqrt(mse)
|
| 1006 |
mae = mean_absolute_error(y_true, y_pred)
|
|
@@ -1009,11 +1098,13 @@ def compute_metrics(y_true, y_pred):
|
|
| 1009 |
|
| 1010 |
|
| 1011 |
def train_one_epoch(model, dataloader, optimizer, device):
|
|
|
|
| 1012 |
model.train()
|
| 1013 |
total_loss = 0.0
|
| 1014 |
total_n = 0
|
| 1015 |
|
| 1016 |
for batch in dataloader:
|
|
|
|
| 1017 |
for k in batch:
|
| 1018 |
if k == "target":
|
| 1019 |
batch[k] = batch[k].to(device)
|
|
@@ -1046,6 +1137,10 @@ def train_one_epoch(model, dataloader, optimizer, device):
|
|
| 1046 |
|
| 1047 |
@torch.no_grad()
|
| 1048 |
def evaluate(model, dataloader, device):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1049 |
model.eval()
|
| 1050 |
preds = []
|
| 1051 |
trues = []
|
|
@@ -1053,6 +1148,7 @@ def evaluate(model, dataloader, device):
|
|
| 1053 |
total_n = 0
|
| 1054 |
|
| 1055 |
for batch in dataloader:
|
|
|
|
| 1056 |
for k in batch:
|
| 1057 |
if k == "target":
|
| 1058 |
batch[k] = batch[k].to(device)
|
|
@@ -1088,26 +1184,37 @@ def evaluate(model, dataloader, device):
|
|
| 1088 |
return float(avg_loss), preds, trues
|
| 1089 |
|
| 1090 |
|
| 1091 |
-
#
|
| 1092 |
-
# Pretrained loading
|
| 1093 |
-
#
|
| 1094 |
def load_pretrained_multimodal(pretrained_path: str) -> MultimodalContrastiveModel:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1095 |
gine_encoder = GineEncoder(
|
| 1096 |
node_emb_dim=NODE_EMB_DIM,
|
| 1097 |
edge_emb_dim=EDGE_EMB_DIM,
|
| 1098 |
num_layers=NUM_GNN_LAYERS,
|
| 1099 |
max_atomic_z=MAX_ATOMIC_Z
|
| 1100 |
)
|
| 1101 |
-
|
|
|
|
| 1102 |
try:
|
| 1103 |
-
gine_encoder.load_state_dict(
|
| 1104 |
-
|
| 1105 |
-
strict=False
|
| 1106 |
-
)
|
| 1107 |
-
print(f"Loaded GINE weights from {BEST_GINE_DIR}")
|
| 1108 |
except Exception as e:
|
| 1109 |
-
print(f"
|
| 1110 |
|
|
|
|
|
|
|
|
|
|
| 1111 |
schnet_encoder = NodeSchNetWrapper(
|
| 1112 |
hidden_channels=SCHNET_HIDDEN,
|
| 1113 |
num_interactions=SCHNET_NUM_INTERACTIONS,
|
|
@@ -1115,16 +1222,17 @@ def load_pretrained_multimodal(pretrained_path: str) -> MultimodalContrastiveMod
|
|
| 1115 |
cutoff=SCHNET_CUTOFF,
|
| 1116 |
max_num_neighbors=SCHNET_MAX_NEIGHBORS
|
| 1117 |
)
|
| 1118 |
-
|
|
|
|
| 1119 |
try:
|
| 1120 |
-
schnet_encoder.load_state_dict(
|
| 1121 |
-
|
| 1122 |
-
strict=False
|
| 1123 |
-
)
|
| 1124 |
-
print(f"Loaded SchNet weights from {BEST_SCHNET_DIR}")
|
| 1125 |
except Exception as e:
|
| 1126 |
-
print(f"
|
| 1127 |
|
|
|
|
|
|
|
|
|
|
| 1128 |
fp_encoder = FingerprintEncoder(
|
| 1129 |
vocab_size=VOCAB_SIZE_FP,
|
| 1130 |
hidden_dim=256,
|
|
@@ -1134,42 +1242,49 @@ def load_pretrained_multimodal(pretrained_path: str) -> MultimodalContrastiveMod
|
|
| 1134 |
dim_feedforward=1024,
|
| 1135 |
dropout=0.1
|
| 1136 |
)
|
| 1137 |
-
|
|
|
|
| 1138 |
try:
|
| 1139 |
-
fp_encoder.load_state_dict(
|
| 1140 |
-
|
| 1141 |
-
strict=False
|
| 1142 |
-
)
|
| 1143 |
-
print(f"Loaded fingerprint encoder weights from {BEST_FP_DIR}")
|
| 1144 |
except Exception as e:
|
| 1145 |
-
print(f"
|
| 1146 |
|
|
|
|
|
|
|
|
|
|
| 1147 |
psmiles_encoder = None
|
| 1148 |
if os.path.isdir(BEST_PSMILES_DIR):
|
| 1149 |
try:
|
| 1150 |
psmiles_encoder = PSMILESDebertaEncoder(model_dir_or_name=BEST_PSMILES_DIR)
|
| 1151 |
-
print(f"
|
| 1152 |
except Exception as e:
|
| 1153 |
-
print(f"
|
| 1154 |
|
|
|
|
| 1155 |
if psmiles_encoder is None:
|
| 1156 |
try:
|
| 1157 |
psmiles_encoder = PSMILESDebertaEncoder(
|
| 1158 |
model_dir_or_name=None,
|
| 1159 |
vocab_fallback=int(getattr(tokenizer, "vocab_size", 300))
|
| 1160 |
)
|
|
|
|
| 1161 |
except Exception as e:
|
| 1162 |
-
print(f"
|
| 1163 |
|
|
|
|
| 1164 |
multimodal_model = MultimodalContrastiveModel(
|
| 1165 |
gine_encoder,
|
| 1166 |
schnet_encoder,
|
| 1167 |
fp_encoder,
|
| 1168 |
psmiles_encoder,
|
| 1169 |
-
emb_dim=
|
| 1170 |
modalities=["gine", "schnet", "fp", "psmiles"]
|
| 1171 |
)
|
| 1172 |
|
|
|
|
|
|
|
|
|
|
| 1173 |
ckpt_path = os.path.join(pretrained_path, "pytorch_model.bin")
|
| 1174 |
if os.path.isfile(ckpt_path):
|
| 1175 |
try:
|
|
@@ -1185,26 +1300,34 @@ def load_pretrained_multimodal(pretrained_path: str) -> MultimodalContrastiveMod
|
|
| 1185 |
|
| 1186 |
summarize_state_dict_load(state, model_state, filtered_state)
|
| 1187 |
missing, unexpected = multimodal_model.load_state_dict(filtered_state, strict=False)
|
| 1188 |
-
print(f"[
|
| 1189 |
-
print(f"[
|
| 1190 |
if missing:
|
| 1191 |
-
print("[
|
| 1192 |
if unexpected:
|
| 1193 |
-
print("[
|
| 1194 |
|
| 1195 |
except Exception as e:
|
| 1196 |
-
print(f"
|
|
|
|
|
|
|
| 1197 |
|
| 1198 |
return multimodal_model
|
| 1199 |
|
| 1200 |
|
| 1201 |
-
#
|
| 1202 |
-
# Downstream
|
| 1203 |
-
#
|
| 1204 |
def build_samples_for_property(df: pd.DataFrame, prop_col: str) -> List[dict]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1205 |
samples = []
|
| 1206 |
for _, row in df.iterrows():
|
| 1207 |
-
# Require at least one modality
|
| 1208 |
has_modality = False
|
| 1209 |
for col in ['graph', 'geometry', 'fingerprints', 'psmiles']:
|
| 1210 |
if col in row and row[col] and str(row[col]).strip() != "":
|
|
@@ -1232,32 +1355,47 @@ def build_samples_for_property(df: pd.DataFrame, prop_col: str) -> List[dict]:
|
|
| 1232 |
return samples
|
| 1233 |
|
| 1234 |
|
| 1235 |
-
def
|
| 1236 |
pretrained_path: str, output_file: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1237 |
os.makedirs(pretrained_path, exist_ok=True)
|
| 1238 |
|
| 1239 |
-
#
|
| 1240 |
modality_cols = ["graph", "geometry", "fingerprints", "psmiles"]
|
| 1241 |
df_proc = aggregate_polyinfo_duplicates(df_raw, modality_cols=modality_cols, property_cols=property_cols)
|
| 1242 |
|
| 1243 |
-
all_results = {"per_property": {}, "mode": "
|
| 1244 |
|
| 1245 |
-
# Load base pretrained model once per property-run (we reload inside each run to avoid leakage across runs)
|
| 1246 |
for pname, pcol in zip(property_list, property_cols):
|
| 1247 |
-
print(f"\n=== Uni-Poly matched fine-tuning: {pname} (col='{pcol}') ===")
|
| 1248 |
samples = build_samples_for_property(df_proc, pcol)
|
|
|
|
|
|
|
| 1249 |
if len(samples) < 200:
|
| 1250 |
-
print(f"[WARN]
|
| 1251 |
if len(samples) < 50:
|
| 1252 |
-
print(f"[WARN] Skipping '{pname}'
|
| 1253 |
continue
|
| 1254 |
|
| 1255 |
run_metrics = []
|
| 1256 |
run_records = []
|
| 1257 |
|
| 1258 |
-
#
|
| 1259 |
best_overall_r2 = -1e18
|
| 1260 |
-
best_overall_payload = None
|
| 1261 |
|
| 1262 |
idxs = np.arange(len(samples))
|
| 1263 |
cv = KFold(n_splits=NUM_RUNS, shuffle=True, random_state=42)
|
|
@@ -1266,10 +1404,12 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
|
|
| 1266 |
seed = 42 + run_idx
|
| 1267 |
set_seed(seed)
|
| 1268 |
|
|
|
|
|
|
|
| 1269 |
trainval = [copy.deepcopy(samples[i]) for i in trainval_idx]
|
| 1270 |
test = [copy.deepcopy(samples[i]) for i in test_idx]
|
| 1271 |
|
| 1272 |
-
# train/val
|
| 1273 |
tr_idx, va_idx = train_test_split(
|
| 1274 |
np.arange(len(trainval)),
|
| 1275 |
test_size=VAL_SIZE_WITHIN_TRAINVAL,
|
|
@@ -1279,7 +1419,7 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
|
|
| 1279 |
train = [copy.deepcopy(trainval[i]) for i in tr_idx]
|
| 1280 |
val = [copy.deepcopy(trainval[i]) for i in va_idx]
|
| 1281 |
|
| 1282 |
-
#
|
| 1283 |
sc = StandardScaler()
|
| 1284 |
sc.fit(np.array([s["target_raw"] for s in train]).reshape(-1, 1))
|
| 1285 |
|
|
@@ -1299,9 +1439,11 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
|
|
| 1299 |
dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, collate_fn=multimodal_collate_fn)
|
| 1300 |
dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, collate_fn=multimodal_collate_fn)
|
| 1301 |
|
| 1302 |
-
|
| 1303 |
-
|
| 1304 |
-
model
|
|
|
|
|
|
|
| 1305 |
|
| 1306 |
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
|
| 1307 |
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
|
|
@@ -1310,6 +1452,7 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
|
|
| 1310 |
best_state = None
|
| 1311 |
no_improve = 0
|
| 1312 |
|
|
|
|
| 1313 |
for epoch in range(1, NUM_EPOCHS + 1):
|
| 1314 |
tr_loss = train_one_epoch(model, dl_train, optimizer, DEVICE)
|
| 1315 |
va_loss, _, _ = evaluate(model, dl_val, DEVICE)
|
|
@@ -1317,7 +1460,7 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
|
|
| 1317 |
|
| 1318 |
scheduler.step()
|
| 1319 |
|
| 1320 |
-
print(f"[{pname}] fold {run_idx+1}/{NUM_RUNS} epoch {epoch:03d}
|
| 1321 |
|
| 1322 |
if va_loss < best_val - 1e-8:
|
| 1323 |
best_val = va_loss
|
|
@@ -1326,25 +1469,29 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
|
|
| 1326 |
else:
|
| 1327 |
no_improve += 1
|
| 1328 |
if no_improve >= PATIENCE:
|
| 1329 |
-
print(f"[{pname}]
|
| 1330 |
break
|
| 1331 |
|
| 1332 |
if best_state is None:
|
| 1333 |
-
print(f"[WARN] No best
|
| 1334 |
continue
|
| 1335 |
|
|
|
|
| 1336 |
model.load_state_dict({k: v.to(DEVICE) for k, v in best_state.items()}, strict=True)
|
| 1337 |
_, pred_scaled, true_scaled = evaluate(model, dl_test, DEVICE)
|
| 1338 |
if pred_scaled is None:
|
|
|
|
| 1339 |
continue
|
| 1340 |
|
| 1341 |
-
#
|
| 1342 |
pred = sc.inverse_transform(pred_scaled.reshape(-1, 1)).ravel()
|
| 1343 |
true = sc.inverse_transform(true_scaled.reshape(-1, 1)).ravel()
|
| 1344 |
|
| 1345 |
m = compute_metrics(true, pred)
|
| 1346 |
run_metrics.append(m)
|
| 1347 |
|
|
|
|
|
|
|
| 1348 |
record = {
|
| 1349 |
"property": pname,
|
| 1350 |
"property_col": pcol,
|
|
@@ -1361,7 +1508,7 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
|
|
| 1361 |
with open(output_file, "a") as fh:
|
| 1362 |
fh.write(json.dumps(make_json_serializable(record)) + "\n")
|
| 1363 |
|
| 1364 |
-
#
|
| 1365 |
if float(m.get("r2", -1e18)) > float(best_overall_r2):
|
| 1366 |
best_overall_r2 = float(m.get("r2", -1e18))
|
| 1367 |
best_overall_payload = {
|
|
@@ -1378,20 +1525,16 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
|
|
| 1378 |
"scaler_scale": make_json_serializable(getattr(sc, "scale_", None)),
|
| 1379 |
"scaler_var": make_json_serializable(getattr(sc, "var_", None)),
|
| 1380 |
"scaler_n_samples_seen": make_json_serializable(getattr(sc, "n_samples_seen_", None)),
|
| 1381 |
-
"model_state_dict": best_state, #
|
| 1382 |
}
|
| 1383 |
|
| 1384 |
-
#
|
| 1385 |
if best_overall_payload is not None and "model_state_dict" in best_overall_payload:
|
| 1386 |
os.makedirs(BEST_WEIGHTS_DIR, exist_ok=True)
|
| 1387 |
prop_dir = os.path.join(BEST_WEIGHTS_DIR, _sanitize_name(pname))
|
| 1388 |
os.makedirs(prop_dir, exist_ok=True)
|
| 1389 |
|
| 1390 |
-
|
| 1391 |
-
ckpt_bundle = {
|
| 1392 |
-
k: v for k, v in best_overall_payload.items()
|
| 1393 |
-
if k != "test_metrics"
|
| 1394 |
-
}
|
| 1395 |
ckpt_bundle["test_metrics"] = best_overall_payload["test_metrics"]
|
| 1396 |
|
| 1397 |
torch.save(ckpt_bundle, os.path.join(prop_dir, "best_run_checkpoint.pt"))
|
|
@@ -1400,10 +1543,10 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
|
|
| 1400 |
with open(os.path.join(prop_dir, "best_run_metadata.json"), "w") as fh:
|
| 1401 |
fh.write(json.dumps(make_json_serializable(meta), indent=2))
|
| 1402 |
|
| 1403 |
-
print(f"[
|
| 1404 |
-
print(f"[
|
| 1405 |
|
| 1406 |
-
# Aggregate across
|
| 1407 |
if run_metrics:
|
| 1408 |
r2s = [x["r2"] for x in run_metrics]
|
| 1409 |
maes = [x["mae"] for x in run_metrics]
|
|
@@ -1415,8 +1558,10 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
|
|
| 1415 |
"rmse": {"mean": float(np.mean(rmses)), "std": float(np.std(rmses, ddof=0))},
|
| 1416 |
"mse": {"mean": float(np.mean(mses)), "std": float(np.std(mses, ddof=0))},
|
| 1417 |
}
|
|
|
|
| 1418 |
else:
|
| 1419 |
agg = None
|
|
|
|
| 1420 |
|
| 1421 |
all_results["per_property"][pname] = {
|
| 1422 |
"property_col": pcol,
|
|
@@ -1431,38 +1576,42 @@ def run_unipoly_downstream(property_list: List[str], property_cols: List[str], d
|
|
| 1431 |
return all_results
|
| 1432 |
|
| 1433 |
|
| 1434 |
-
#
|
| 1435 |
# Main
|
| 1436 |
-
#
|
| 1437 |
def main():
|
|
|
|
| 1438 |
if os.path.exists(OUTPUT_RESULTS):
|
| 1439 |
backup = OUTPUT_RESULTS + ".bak"
|
| 1440 |
shutil.copy(OUTPUT_RESULTS, backup)
|
| 1441 |
-
print(f"[
|
| 1442 |
open(OUTPUT_RESULTS, "w").close()
|
|
|
|
| 1443 |
|
|
|
|
| 1444 |
if not os.path.isfile(POLYINFO_PATH):
|
| 1445 |
raise FileNotFoundError(f"PolyInfo file not found at {POLYINFO_PATH}")
|
| 1446 |
polyinfo_raw = pd.read_csv(POLYINFO_PATH, engine="python")
|
|
|
|
| 1447 |
|
|
|
|
| 1448 |
found = find_property_columns(polyinfo_raw.columns)
|
| 1449 |
prop_map = {req: col for req, col in found.items()}
|
| 1450 |
-
print(f"[
|
| 1451 |
|
| 1452 |
property_list = []
|
| 1453 |
property_cols = []
|
| 1454 |
for req in REQUESTED_PROPERTIES:
|
| 1455 |
col = prop_map.get(req)
|
| 1456 |
if col is None:
|
| 1457 |
-
print(f"[WARN] Could not find a column for
|
| 1458 |
continue
|
| 1459 |
property_list.append(req)
|
| 1460 |
property_cols.append(col)
|
| 1461 |
|
| 1462 |
-
|
| 1463 |
-
overall = run_unipoly_downstream(property_list, property_cols, polyinfo_raw, PRETRAINED_MULTIMODAL_DIR, OUTPUT_RESULTS)
|
| 1464 |
|
| 1465 |
-
# Write final summary (aggregated per property)
|
| 1466 |
final_agg = {}
|
| 1467 |
if overall and "per_property" in overall:
|
| 1468 |
for pname, info in overall["per_property"].items():
|
|
@@ -1473,7 +1622,8 @@ def main():
|
|
| 1473 |
fh.write(json.dumps(make_json_serializable(final_agg), indent=2))
|
| 1474 |
fh.write("\n")
|
| 1475 |
|
| 1476 |
-
print(f"\n
|
|
|
|
| 1477 |
|
| 1478 |
|
| 1479 |
if __name__ == "__main__":
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import random
|
| 3 |
import time
|
|
|
|
| 19 |
import copy
|
| 20 |
from typing import List, Dict, Optional, Tuple, Any
|
| 21 |
|
| 22 |
+
# Increase CSV field size limit (PolyInfo modality JSON fields can be large).
|
| 23 |
csv.field_size_limit(sys.maxsize)
|
| 24 |
|
| 25 |
+
# =============================================================================
|
| 26 |
+
# Imports: Shared encoders/helpers from PolyFusion
|
| 27 |
+
# =============================================================================
|
| 28 |
from PolyFusion.GINE import GineEncoder, match_edge_attr_to_index, safe_get
|
| 29 |
from PolyFusion.SchNet import NodeSchNetWrapper
|
| 30 |
from PolyFusion.Transformer import PooledFingerprintEncoder as FingerprintEncoder
|
| 31 |
from PolyFusion.DeBERTav2 import PSMILESDebertaEncoder, build_psmiles_tokenizer
|
| 32 |
|
| 33 |
+
# =============================================================================
|
| 34 |
# Configuration
|
| 35 |
+
# =============================================================================
|
| 36 |
BASE_DIR = "/path/to/Polymer_Foundational_Model"
|
| 37 |
POLYINFO_PATH = "/path/to/polyinfo_with_modalities.csv"
|
| 38 |
|
| 39 |
+
# Pretrained encoder directories (update these placeholders)
|
| 40 |
PRETRAINED_MULTIMODAL_DIR = "/path/to/multimodal_output/best"
|
| 41 |
BEST_GINE_DIR = "/path/to/gin_output/best"
|
| 42 |
BEST_SCHNET_DIR = "/path/to/schnet_output/best"
|
| 43 |
BEST_FP_DIR = "/path/to/fingerprint_mlm_output/best"
|
| 44 |
BEST_PSMILES_DIR = "/path/to/polybert_output/best"
|
| 45 |
|
| 46 |
+
# Output log file (per-run json lines + per-property aggregated summary)
|
| 47 |
OUTPUT_RESULTS = "/path/to/multimodal_downstream_results.txt"
|
| 48 |
|
| 49 |
+
# Directory to save best-performing checkpoint bundle per property (best CV run)
|
| 50 |
BEST_WEIGHTS_DIR = "/path/to/multimodal_downstream_bestweights"
|
| 51 |
|
| 52 |
+
# -----------------------------------------------------------------------------
|
| 53 |
+
# Model sizes / dims (must match your pretrained multimodal encoder settings)
|
| 54 |
+
# -----------------------------------------------------------------------------
|
| 55 |
MAX_ATOMIC_Z = 85
|
| 56 |
MASK_ATOM_ID = MAX_ATOMIC_Z + 1
|
| 57 |
|
| 58 |
+
# GINE
|
| 59 |
NODE_EMB_DIM = 300
|
| 60 |
EDGE_EMB_DIM = 300
|
| 61 |
NUM_GNN_LAYERS = 5
|
| 62 |
|
| 63 |
+
# SchNet
|
| 64 |
SCHNET_NUM_GAUSSIANS = 50
|
| 65 |
SCHNET_NUM_INTERACTIONS = 6
|
| 66 |
SCHNET_CUTOFF = 10.0
|
| 67 |
SCHNET_MAX_NEIGHBORS = 64
|
| 68 |
SCHNET_HIDDEN = 600
|
| 69 |
|
| 70 |
+
# Fingerprints
|
| 71 |
FP_LENGTH = 2048
|
| 72 |
MASK_TOKEN_ID_FP = 2
|
| 73 |
VOCAB_SIZE_FP = 3
|
| 74 |
|
| 75 |
+
# Contrastive embedding dim
|
| 76 |
CL_EMB_DIM = 600
|
| 77 |
|
| 78 |
+
# PSMILES/DeBERTa
|
| 79 |
DEBERTA_HIDDEN = 600
|
| 80 |
PSMILES_MAX_LEN = 128
|
| 81 |
|
| 82 |
+
# -----------------------------------------------------------------------------
|
| 83 |
+
# Uni-Poly style fusion + regression head hyperparameters
|
| 84 |
+
# -----------------------------------------------------------------------------
|
| 85 |
+
POLYF_EMB_DIM = 600
|
| 86 |
+
POLYF_ATTN_HEADS = 8
|
| 87 |
+
POLYF_DROPOUT = 0.1
|
| 88 |
+
POLYF_FF_MULT = 4 # FFN hidden = 4*d (common transformer practice)
|
| 89 |
+
|
| 90 |
+
# -----------------------------------------------------------------------------
|
| 91 |
+
# Fine-tuning parameters (single-task per property)
|
| 92 |
+
# -----------------------------------------------------------------------------
|
| 93 |
MAX_LEN = 128
|
| 94 |
BATCH_SIZE = 32
|
| 95 |
NUM_EPOCHS = 100
|
|
|
|
| 99 |
|
| 100 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 101 |
|
| 102 |
+
# Properties to evaluate (order preserved)
|
| 103 |
REQUESTED_PROPERTIES = [
|
| 104 |
"density",
|
| 105 |
"glass transition",
|
|
|
|
| 108 |
"thermal decomposition"
|
| 109 |
]
|
| 110 |
|
| 111 |
+
# True K-fold evaluation to match "fivefold per property"
|
| 112 |
NUM_RUNS = 5
|
| 113 |
TEST_SIZE = 0.10
|
| 114 |
+
VAL_SIZE_WITHIN_TRAINVAL = 0.10 # fraction of trainval reserved for val split
|
| 115 |
|
| 116 |
+
# Duplicate aggregation (noise reduction) key preference order
|
| 117 |
AGG_KEYS_PREFERENCE = ["polymer_id", "PolymerID", "poly_id", "psmiles", "smiles", "canonical_smiles"]
|
| 118 |
|
| 119 |
+
# =============================================================================
|
| 120 |
# Utilities
|
| 121 |
+
# =============================================================================
|
| 122 |
def set_seed(seed: int):
    """
    Seed every random number generator this script relies on.

    Covers Python's ``random`` module, NumPy's global RNG and PyTorch's CPU
    generator; when CUDA is available, all CUDA device generators are seeded
    as well. cuDNN is also asked for deterministic kernels with autotuning
    (benchmark mode) turned off, which favours reproducibility over speed.

    Args:
        seed: Integer seed handed unchanged to each seeding function.
    """
    # CPU-side generators: same order as before (random, numpy, torch).
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.manual_seed_all(seed)

    # Deterministic settings: reproducible but may reduce throughput.
    # NOTE(review): the original nesting of these two flags relative to the
    # CUDA check is not recoverable from this view; they are set at function
    # level here — confirm against the real file.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
|
| 132 |
|
| 133 |
|
| 134 |
def make_json_serializable(obj):
|
| 135 |
+
"""Convert common numpy/torch/pandas objects into JSON-safe Python types."""
|
| 136 |
if isinstance(obj, dict):
|
| 137 |
return {make_json_serializable(k): make_json_serializable(v) for k, v in obj.items()}
|
| 138 |
if isinstance(obj, (list, tuple, set)):
|
|
|
|
| 158 |
pass
|
| 159 |
return obj
|
| 160 |
|
| 161 |
+
|
| 162 |
def summarize_state_dict_load(full_state: dict, model_state: dict, filtered_state: dict):
|
| 163 |
+
"""
|
| 164 |
+
Print a concise load report:
|
| 165 |
+
- how many checkpoint keys exist
|
| 166 |
+
- how many model keys exist
|
| 167 |
+
- how many keys will be loaded (intersection with matching shapes)
|
| 168 |
+
- common reasons for skipped keys
|
| 169 |
+
"""
|
| 170 |
n_ckpt = len(full_state)
|
| 171 |
n_model = len(model_state)
|
| 172 |
n_loaded = len(filtered_state)
|
|
|
|
| 182 |
print(f" ckpt keys: {n_ckpt}")
|
| 183 |
print(f" model keys: {n_model}")
|
| 184 |
print(f" loaded keys: {n_loaded}")
|
| 185 |
+
print(f" skipped (not in model): {len(missing_in_model)}")
|
| 186 |
+
print(f" skipped (shape mismatch): {len(shape_mismatch)}")
|
| 187 |
|
| 188 |
if missing_in_model:
|
| 189 |
+
print(" examples skipped (not in model):", missing_in_model[:10])
|
| 190 |
if shape_mismatch:
|
| 191 |
+
print(" examples skipped (shape mismatch):")
|
| 192 |
for k in shape_mismatch[:10]:
|
| 193 |
print(f" {k}: ckpt={tuple(full_state[k].shape)} model={tuple(model_state[k].shape)}")
|
| 194 |
print("")
|
|
|
|
| 196 |
|
| 197 |
def find_property_columns(columns):
|
| 198 |
"""
|
| 199 |
+
Robust property column matching with guardrails:
|
| 200 |
+
- Prefer word-level (token) matches over substring matches.
|
| 201 |
+
- For 'density', avoid confusing with 'cohesive energy density' columns.
|
| 202 |
+
- Log chosen column and competing candidates when ambiguous.
|
| 203 |
"""
|
| 204 |
lowered = {c.lower(): c for c in columns}
|
| 205 |
found = {}
|
|
|
|
| 208 |
req_low = req.lower().strip()
|
| 209 |
exact = None
|
| 210 |
|
| 211 |
+
# Pass 1: token-level exactness (safer than substring match)
|
| 212 |
for c_low, c_orig in lowered.items():
|
| 213 |
tokens = set(c_low.replace('_', ' ').split())
|
| 214 |
if req_low in tokens or c_low == req_low:
|
|
|
|
| 221 |
found[req] = exact
|
| 222 |
continue
|
| 223 |
|
| 224 |
+
# Pass 2: substring match as fallback
|
| 225 |
candidates = [c_orig for c_low, c_orig in lowered.items() if req_low in c_low]
|
| 226 |
if req_low == "density":
|
| 227 |
candidates = [c for c in candidates if "cohesive" not in c.lower() and "cohesive energy" not in c.lower()]
|
|
|
|
| 231 |
else:
|
| 232 |
chosen = candidates[0] if candidates else None
|
| 233 |
found[req] = chosen
|
| 234 |
+
print(f"[COLMAP] Requested '{req}' -> chosen column: {chosen}")
|
| 235 |
if candidates:
|
| 236 |
+
print(f"[COLMAP] Candidates for '{req}': {candidates}")
|
| 237 |
else:
|
| 238 |
+
print(f"[COLMAP][WARN] No candidates found for '{req}' using substring search.")
|
| 239 |
return found
|
| 240 |
|
| 241 |
+
|
| 242 |
def choose_aggregation_key(df: pd.DataFrame) -> Optional[str]:
|
| 243 |
+
"""Pick the most stable identifier available for duplicate aggregation."""
|
| 244 |
for k in AGG_KEYS_PREFERENCE:
|
| 245 |
if k in df.columns:
|
| 246 |
return k
|
|
|
|
| 249 |
|
| 250 |
def aggregate_polyinfo_duplicates(df: pd.DataFrame, modality_cols: List[str], property_cols: List[str]) -> pd.DataFrame:
    """
    Optional noise reduction: collapse duplicate polymer entries to one row per key.

    The grouping column is whatever ``choose_aggregation_key(df)`` returns.

    - Modality columns keep the first value of each group (they should be
      consistent per polymer key).
    - Property columns are averaged (mean) within each group.

    Rows whose key is missing (NaN/None) or blank are excluded from grouping.

    Args:
        df: Input table; it is never modified in place.
        modality_cols: Column names to carry over with ``"first"``.
        property_cols: Column names to aggregate with ``"mean"``. A name that
            appears in both lists is averaged.

    Returns:
        A new frame with one row per distinct key (key values cast to ``str``),
        or ``df`` itself when no key column is found or no row has a usable key.
    """
    key = choose_aggregation_key(df)
    if key is None:
        print("[AGG] No aggregation key found; skipping duplicate aggregation.")
        return df

    # Remove missing keys *before* the string cast: astype(str) turns NaN into
    # the literal "nan", which would survive the blank filter below and merge
    # every key-less row into a single bogus "nan" polymer.
    df2 = df[df[key].notna()].copy()
    df2[key] = df2[key].astype(str)
    df2 = df2[df2[key].str.strip() != ""].copy()
    if df2.empty:
        print("[AGG] Aggregation key exists but is empty; skipping duplicate aggregation.")
        return df

    # The grouping column is already emitted by groupby(as_index=False), so it
    # is kept out of the aggregation spec even if it is listed as a modality
    # (e.g. when the key is "psmiles").
    agg_dict = {mc: "first" for mc in modality_cols if mc in df2.columns and mc != key}
    agg_dict.update({pc: "mean" for pc in property_cols if pc in df2.columns and pc != key})

    if agg_dict:
        grouped = df2.groupby(key, as_index=False).agg(agg_dict)
    else:
        # Nothing to aggregate: still return one row per key (sorted, as
        # groupby would) instead of handing an empty spec to .agg().
        grouped = df2[[key]].drop_duplicates().sort_values(key).reset_index(drop=True)

    print(f"[AGG] Grouped by '{key}': {len(df)} rows -> {len(grouped)} unique keys")
    return grouped
|
| 280 |
|
| 281 |
|
| 282 |
def _sanitize_name(s: str) -> str:
|
| 283 |
+
"""Create a filesystem-safe name for property directories."""
|
| 284 |
s2 = str(s).strip().lower()
|
| 285 |
keep = []
|
| 286 |
for ch in s2:
|
|
|
|
| 296 |
out = out.strip("_")
|
| 297 |
return out or "property"
|
| 298 |
|
| 299 |
+
|
| 300 |
+
# =============================================================================
|
| 301 |
+
# Multimodal backbone: encode + project + modality-aware fusion
|
| 302 |
+
# =============================================================================
|
| 303 |
class MultimodalContrastiveModel(nn.Module):
|
| 304 |
"""
|
| 305 |
+
Multimodal encoder wrapper:
|
| 306 |
+
|
| 307 |
+
1) Runs each available modality encoder:
|
| 308 |
+
- GINE (graph)
|
| 309 |
+
- SchNet (3D geometry)
|
| 310 |
+
- Transformer FP encoder (Morgan bit sequence)
|
| 311 |
+
- DeBERTa-based PSMILES encoder (sequence)
|
| 312 |
+
|
| 313 |
+
2) Projects each modality embedding to a shared dim (emb_dim).
|
| 314 |
+
|
| 315 |
+
3) Normalizes each modality embedding (L2), drops out, then fuses via
|
| 316 |
+
a masked mean across modalities that are present for each sample.
|
| 317 |
+
|
| 318 |
+
4) Normalizes the final fused embedding (L2).
|
| 319 |
+
|
| 320 |
+
Expected downstream usage:
|
| 321 |
+
z = model(batch_mods, modality_mask=modality_mask) # (B, emb_dim)
|
| 322 |
"""
|
| 323 |
|
| 324 |
def __init__(
|
|
|
|
| 345 |
self.out_dim = self.emb_dim
|
| 346 |
self.dropout = nn.Dropout(float(dropout))
|
| 347 |
|
| 348 |
+
# Determine which modalities are enabled
|
| 349 |
if modalities is None:
|
| 350 |
mods = []
|
| 351 |
if self.gine is not None:
|
|
|
|
| 360 |
else:
|
| 361 |
self.modalities = [m for m in modalities if m in ("gine", "schnet", "fp", "psmiles")]
|
| 362 |
|
| 363 |
+
# Projection heads into shared embedding space
|
| 364 |
self.proj_gine = nn.Linear(NODE_EMB_DIM, self.emb_dim) if self.gine is not None else None
|
| 365 |
self.proj_schnet = nn.Linear(SCHNET_HIDDEN, self.emb_dim) if self.schnet is not None else None
|
| 366 |
self.proj_fp = nn.Linear(256, self.emb_dim) if self.fp is not None else None
|
| 367 |
|
| 368 |
+
# Infer PSMILES hidden size if possible; fallback to DEBERTA_HIDDEN
|
| 369 |
psm_in = None
|
| 370 |
if self.psmiles is not None:
|
| 371 |
if hasattr(self.psmiles, "out_dim"):
|
|
|
|
| 384 |
self.proj_psmiles = nn.Linear(psm_in, self.emb_dim) if (self.psmiles is not None) else None
|
| 385 |
|
| 386 |
def freeze_cl_encoders(self):
|
| 387 |
+
"""Freeze all modality encoders (optional for evaluation-only usage)."""
|
| 388 |
for enc in (self.gine, self.schnet, self.fp, self.psmiles):
|
| 389 |
if enc is None:
|
| 390 |
continue
|
|
|
|
| 394 |
|
| 395 |
def _masked_mean_combine(self, zs: List[torch.Tensor], masks: List[torch.Tensor]) -> torch.Tensor:
|
| 396 |
"""
|
| 397 |
+
Compute sample-wise mean over available modalities.
|
| 398 |
+
|
| 399 |
+
zs: list of modality embeddings, each (B,D)
|
| 400 |
+
masks: list of modality presence masks, each (B,) bool
|
| 401 |
+
returns: (B,D)
|
| 402 |
"""
|
| 403 |
if not zs:
|
| 404 |
device = next(self.parameters()).device
|
|
|
|
| 421 |
def forward(self, batch_mods: dict, modality_mask: Optional[dict] = None) -> torch.Tensor:
|
| 422 |
"""
|
| 423 |
batch_mods keys: 'gine', 'schnet', 'fp', 'psmiles'
|
| 424 |
+
modality_mask: dict {modality_name: (B,) bool} describing presence.
|
| 425 |
"""
|
| 426 |
device = next(self.parameters()).device
|
| 427 |
|
| 428 |
zs = []
|
| 429 |
ms = []
|
| 430 |
|
| 431 |
+
# Infer batch size B
|
| 432 |
B = None
|
| 433 |
if modality_mask is not None:
|
| 434 |
+
for _, v in modality_mask.items():
|
| 435 |
if isinstance(v, torch.Tensor) and v.numel() > 0:
|
| 436 |
B = int(v.size(0))
|
| 437 |
break
|
| 438 |
|
|
|
|
| 439 |
if B is None:
|
| 440 |
if "fp" in batch_mods and batch_mods["fp"] is not None and isinstance(batch_mods["fp"].get("input_ids", None), torch.Tensor):
|
| 441 |
B = int(batch_mods["fp"]["input_ids"].size(0))
|
|
|
|
| 445 |
if B is None:
|
| 446 |
return torch.zeros((1, self.emb_dim), device=device)
|
| 447 |
|
|
|
|
| 448 |
def _get_mask(name: str) -> torch.Tensor:
|
| 449 |
if modality_mask is not None and name in modality_mask and isinstance(modality_mask[name], torch.Tensor):
|
| 450 |
return modality_mask[name].to(device).bool()
|
| 451 |
return torch.ones((B,), device=device, dtype=torch.bool)
|
| 452 |
|
| 453 |
+
# -------------------------
|
| 454 |
+
# GINE (graph modality)
|
| 455 |
+
# -------------------------
|
| 456 |
if "gine" in self.modalities and self.gine is not None and batch_mods.get("gine", None) is not None:
|
| 457 |
g = batch_mods["gine"]
|
| 458 |
if isinstance(g.get("z", None), torch.Tensor) and g["z"].numel() > 0:
|
|
|
|
| 470 |
zs.append(z)
|
| 471 |
ms.append(_get_mask("gine"))
|
| 472 |
|
| 473 |
+
# -------------------------
|
| 474 |
+
# SchNet (3D geometry)
|
| 475 |
+
# -------------------------
|
| 476 |
if "schnet" in self.modalities and self.schnet is not None and batch_mods.get("schnet", None) is not None:
|
| 477 |
s = batch_mods["schnet"]
|
| 478 |
if isinstance(s.get("z", None), torch.Tensor) and s["z"].numel() > 0:
|
|
|
|
| 487 |
zs.append(z)
|
| 488 |
ms.append(_get_mask("schnet"))
|
| 489 |
|
| 490 |
+
# -------------------------
|
| 491 |
+
# Fingerprint modality
|
| 492 |
+
# -------------------------
|
| 493 |
if "fp" in self.modalities and self.fp is not None and batch_mods.get("fp", None) is not None:
|
| 494 |
f = batch_mods["fp"]
|
| 495 |
if isinstance(f.get("input_ids", None), torch.Tensor) and f["input_ids"].numel() > 0:
|
|
|
|
| 503 |
zs.append(z)
|
| 504 |
ms.append(_get_mask("fp"))
|
| 505 |
|
| 506 |
+
# -------------------------
|
| 507 |
+
# PSMILES text modality
|
| 508 |
+
# -------------------------
|
| 509 |
if "psmiles" in self.modalities and self.psmiles is not None and batch_mods.get("psmiles", None) is not None:
|
| 510 |
p = batch_mods["psmiles"]
|
| 511 |
if isinstance(p.get("input_ids", None), torch.Tensor) and p["input_ids"].numel() > 0:
|
|
|
|
| 519 |
zs.append(z)
|
| 520 |
ms.append(_get_mask("psmiles"))
|
| 521 |
|
| 522 |
+
# Fuse and normalize
|
| 523 |
if not zs:
|
| 524 |
return torch.zeros((B, self.emb_dim), device=device)
|
| 525 |
|
|
|
|
| 527 |
z = F.normalize(z, dim=-1)
|
| 528 |
return z
|
| 529 |
|
|
|
|
| 530 |
@torch.no_grad()
|
| 531 |
def encode_psmiles(
|
| 532 |
self,
|
|
|
|
| 536 |
device: str = DEVICE
|
| 537 |
) -> np.ndarray:
|
| 538 |
"""
|
| 539 |
+
Convenience: PSMILES-only embeddings (used for fast bulk encoding tasks).
|
| 540 |
"""
|
| 541 |
self.eval()
|
| 542 |
if self.psm_tok is None or self.psmiles is None or self.proj_psmiles is None:
|
|
|
|
| 561 |
device: str = DEVICE
|
| 562 |
) -> np.ndarray:
|
| 563 |
"""
|
| 564 |
+
Convenience: multimodal embedding for records carrying:
|
| 565 |
+
- graph, geometry, fingerprints, psmiles
|
| 566 |
+
Missing modalities are handled sample-wise via modality masking.
|
| 567 |
"""
|
| 568 |
self.eval()
|
| 569 |
dev = torch.device(device)
|
|
|
|
| 573 |
for i in range(0, len(records), batch_size):
|
| 574 |
chunk = records[i:i + batch_size]
|
| 575 |
|
|
|
|
| 576 |
# PSMILES batch
|
| 577 |
psmiles_texts = [str(r.get("psmiles", "")) for r in chunk]
|
| 578 |
p_enc = None
|
| 579 |
if self.psm_tok is not None:
|
| 580 |
p_enc = self.psm_tok(psmiles_texts, truncation=True, padding="max_length", max_length=PSMILES_MAX_LEN, return_tensors="pt")
|
| 581 |
|
| 582 |
+
# FP batch (always stack; missing handled by attention_mask downstream)
|
| 583 |
fp_ids, fp_attn = [], []
|
| 584 |
for r in chunk:
|
| 585 |
f = _parse_fingerprints(r.get("fingerprints", None), fp_len=FP_LENGTH)
|
|
|
|
| 588 |
fp_ids = torch.stack(fp_ids, dim=0)
|
| 589 |
fp_attn = torch.stack(fp_attn, dim=0)
|
| 590 |
|
| 591 |
+
# GINE + SchNet packed batching
|
|
|
|
| 592 |
gine_all = {"z": [], "chirality": [], "formal_charge": [], "edge_index": [], "edge_attr": [], "batch": []}
|
| 593 |
node_offset = 0
|
| 594 |
for bi, r in enumerate(chunk):
|
|
|
|
| 646 |
"psmiles": {"input_ids": p_enc["input_ids"], "attention_mask": p_enc["attention_mask"]} if p_enc is not None else None
|
| 647 |
}
|
| 648 |
|
| 649 |
+
# NOTE: This script uses forward() as the encoder entry point.
|
| 650 |
+
z = self.forward(batch_mods, modality_mask=None)
|
| 651 |
outs.append(z.detach().cpu().numpy())
|
| 652 |
|
| 653 |
return np.concatenate(outs, axis=0) if outs else np.zeros((0, self.emb_dim), dtype=np.float32)
|
| 654 |
|
| 655 |
+
|
| 656 |
+
# =============================================================================
|
| 657 |
# Tokenizer setup
|
| 658 |
+
# =============================================================================
|
| 659 |
SPM_MODEL = "/path/to/spm.model"
|
| 660 |
tokenizer = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)
|
| 661 |
|
| 662 |
+
# =============================================================================
|
| 663 |
+
# Dataset: single-task property prediction (with modality parsing)
|
| 664 |
+
# =============================================================================
|
| 665 |
class PolymerPropertyDataset(Dataset):
|
| 666 |
+
"""
|
| 667 |
+
Dataset that prepares one sample with up to four modalities:
|
| 668 |
+
- graph (for GINE)
|
| 669 |
+
- geometry (for SchNet)
|
| 670 |
+
- fingerprints (for FP transformer)
|
| 671 |
+
- psmiles text (for DeBERTa encoder)
|
| 672 |
+
|
| 673 |
+
Target is a single scalar per sample (already scaled externally).
|
| 674 |
+
"""
|
| 675 |
def __init__(self, data_list, tokenizer, max_length=128):
|
| 676 |
self.data_list = data_list
|
| 677 |
self.tokenizer = tokenizer
|
|
|
|
| 683 |
def __getitem__(self, idx):
|
| 684 |
data = self.data_list[idx]
|
| 685 |
|
| 686 |
+
# ---------------------------------------------------------------------
|
| 687 |
+
# Graph -> GINE tensors (robust parsing of stored JSON fields)
|
| 688 |
+
# ---------------------------------------------------------------------
|
| 689 |
gine_data = None
|
| 690 |
if 'graph' in data and data['graph']:
|
| 691 |
try:
|
|
|
|
| 714 |
edge_index = None
|
| 715 |
edge_attr = None
|
| 716 |
|
| 717 |
+
# Fallback: adjacency matrix if edge_indices missing
|
| 718 |
if edge_indices_raw is None:
|
| 719 |
adj_mat = safe_get(graph_field, "adjacency_matrix", None)
|
| 720 |
if adj_mat:
|
|
|
|
| 730 |
E = len(srcs)
|
| 731 |
edge_attr = [[0.0, 0.0, 0.0] for _ in range(E)]
|
| 732 |
else:
|
| 733 |
+
# edge_indices can be [[srcs],[dsts]] or list of pairs
|
| 734 |
srcs, dsts = [], []
|
| 735 |
if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0:
|
| 736 |
if isinstance(edge_indices_raw[0], list):
|
|
|
|
| 757 |
if len(srcs) > 0:
|
| 758 |
edge_index = [srcs, dsts]
|
| 759 |
|
| 760 |
+
# edge_features: attempt to map known fields; otherwise zeros
|
| 761 |
if edge_features_raw and isinstance(edge_features_raw, list):
|
| 762 |
bond_types = []
|
| 763 |
stereos = []
|
|
|
|
| 785 |
except Exception:
|
| 786 |
gine_data = None
|
| 787 |
|
| 788 |
+
# ---------------------------------------------------------------------
|
| 789 |
+
# Geometry -> SchNet tensors (best conformer)
|
| 790 |
+
# ---------------------------------------------------------------------
|
| 791 |
schnet_data = None
|
| 792 |
if 'geometry' in data and data['geometry']:
|
| 793 |
try:
|
|
|
|
| 804 |
except Exception:
|
| 805 |
schnet_data = None
|
| 806 |
|
| 807 |
+
# ---------------------------------------------------------------------
|
| 808 |
+
# Fingerprints -> FP transformer inputs (bit sequence)
|
| 809 |
+
# ---------------------------------------------------------------------
|
| 810 |
fp_data = None
|
| 811 |
if 'fingerprints' in data and data['fingerprints']:
|
| 812 |
try:
|
|
|
|
| 857 |
except Exception:
|
| 858 |
fp_data = None
|
| 859 |
|
| 860 |
+
# ---------------------------------------------------------------------
|
| 861 |
+
# PSMILES -> DeBERTa tokenizer inputs
|
| 862 |
+
# ---------------------------------------------------------------------
|
| 863 |
psmiles_data = None
|
| 864 |
if 'psmiles' in data and data['psmiles'] and self.tokenizer is not None:
|
| 865 |
try:
|
|
|
|
| 877 |
except Exception:
|
| 878 |
psmiles_data = None
|
| 879 |
|
| 880 |
+
# ---------------------------------------------------------------------
|
| 881 |
+
# Fill defaults for missing modalities (keeps collate simpler)
|
| 882 |
+
# ---------------------------------------------------------------------
|
| 883 |
if gine_data is None:
|
| 884 |
gine_data = {
|
| 885 |
'z': torch.tensor([], dtype=torch.long),
|
|
|
|
| 907 |
'attention_mask': torch.zeros(PSMILES_MAX_LEN, dtype=torch.bool)
|
| 908 |
}
|
| 909 |
|
| 910 |
+
# Single-task regression target (already scaled)
|
| 911 |
target_scaled = float(data.get("target_scaled", 0.0))
|
| 912 |
|
| 913 |
return {
|
|
|
|
| 919 |
}
|
| 920 |
|
| 921 |
|
| 922 |
+
# =============================================================================
|
| 923 |
+
# Collate: pack variable-sized graph/3D into batch tensors + modality masks
|
| 924 |
+
# =============================================================================
|
| 925 |
def multimodal_collate_fn(batch):
|
| 926 |
+
"""
|
| 927 |
+
Collate samples into a single minibatch.
|
| 928 |
+
|
| 929 |
+
- GINE: concatenate nodes across samples and build a `batch` vector.
|
| 930 |
+
- SchNet: concatenate atoms/coords across samples and build a `batch` vector.
|
| 931 |
+
- FP/PSMILES: stack to (B, L).
|
| 932 |
+
- modality_mask: per-sample boolean flags indicating availability.
|
| 933 |
+
"""
|
| 934 |
B = len(batch)
|
| 935 |
|
| 936 |
+
# -------------------------
|
| 937 |
+
# GINE packing
|
| 938 |
+
# -------------------------
|
| 939 |
all_z = []
|
| 940 |
all_ch = []
|
| 941 |
all_fc = []
|
|
|
|
| 982 |
edge_index_batched = torch.empty((2, 0), dtype=torch.long)
|
| 983 |
edge_attr_batched = torch.zeros((0, 3), dtype=torch.float)
|
| 984 |
|
| 985 |
+
# -------------------------
|
| 986 |
+
# SchNet packing
|
| 987 |
+
# -------------------------
|
| 988 |
all_sz = []
|
| 989 |
all_pos = []
|
| 990 |
schnet_batch = []
|
|
|
|
| 1009 |
s_pos_batch = torch.cat(all_pos, dim=0)
|
| 1010 |
s_batch_batch = torch.cat(schnet_batch, dim=0)
|
| 1011 |
|
| 1012 |
+
# -------------------------
|
| 1013 |
+
# FP stacking
|
| 1014 |
+
# -------------------------
|
| 1015 |
fp_ids = torch.stack([item["fp"]["input_ids"] for item in batch], dim=0)
|
| 1016 |
fp_attn = torch.stack([item["fp"]["attention_mask"] for item in batch], dim=0)
|
| 1017 |
fp_present = (fp_attn.sum(dim=1) > 0).cpu().numpy().tolist()
|
| 1018 |
|
| 1019 |
+
# -------------------------
|
| 1020 |
+
# PSMILES stacking
|
| 1021 |
+
# -------------------------
|
| 1022 |
p_ids = torch.stack([item["psmiles"]["input_ids"] for item in batch], dim=0)
|
| 1023 |
p_attn = torch.stack([item["psmiles"]["attention_mask"] for item in batch], dim=0)
|
| 1024 |
psmiles_present = (p_attn.sum(dim=1) > 0).cpu().numpy().tolist()
|
|
|
|
| 1026 |
# Target
|
| 1027 |
target = torch.stack([item["target"] for item in batch], dim=0) # (B,)
|
| 1028 |
|
| 1029 |
+
# Presence mask for fusion (per-sample modality availability)
|
| 1030 |
modality_mask = {
|
| 1031 |
"gine": torch.tensor(gine_present, dtype=torch.bool),
|
| 1032 |
"schnet": torch.tensor(schnet_present, dtype=torch.bool),
|
|
|
|
| 1061 |
}
|
| 1062 |
|
| 1063 |
|
| 1064 |
+
# =============================================================================
|
| 1065 |
+
# Single-task regressor head (Uni-Poly-style fine-tuning)
|
| 1066 |
+
# =============================================================================
|
| 1067 |
+
class PolyFPropertyRegressor(nn.Module):
|
| 1068 |
+
"""
|
| 1069 |
+
Simple MLP head on top of the multimodal fused embedding.
|
| 1070 |
+
Predicts one scalar (scaled target) per sample.
|
| 1071 |
+
"""
|
| 1072 |
+
def __init__(self, polyf_model: MultimodalContrastiveModel, emb_dim: int = POLYF_EMB_DIM, dropout: float = 0.1):
|
| 1073 |
super().__init__()
|
| 1074 |
+
self.polyf = polyf_model
|
| 1075 |
self.head = nn.Sequential(
|
| 1076 |
nn.Linear(emb_dim, emb_dim // 2),
|
| 1077 |
nn.ReLU(),
|
|
|
|
| 1080 |
)
|
| 1081 |
|
| 1082 |
def forward(self, batch_mods, modality_mask=None):
|
| 1083 |
+
emb = self.polyf(batch_mods, modality_mask=modality_mask) # (B,d)
|
| 1084 |
y = self.head(emb).squeeze(-1) # (B,)
|
| 1085 |
return y
|
| 1086 |
|
| 1087 |
|
| 1088 |
+
# =============================================================================
|
| 1089 |
+
# Training / evaluation helpers
|
| 1090 |
+
# =============================================================================
|
| 1091 |
def compute_metrics(y_true, y_pred):
|
| 1092 |
+
"""Compute standard regression metrics in original units."""
|
| 1093 |
mse = mean_squared_error(y_true, y_pred)
|
| 1094 |
rmse = math.sqrt(mse)
|
| 1095 |
mae = mean_absolute_error(y_true, y_pred)
|
|
|
|
| 1098 |
|
| 1099 |
|
| 1100 |
def train_one_epoch(model, dataloader, optimizer, device):
|
| 1101 |
+
"""One epoch of supervised regression training (MSE loss)."""
|
| 1102 |
model.train()
|
| 1103 |
total_loss = 0.0
|
| 1104 |
total_n = 0
|
| 1105 |
|
| 1106 |
for batch in dataloader:
|
| 1107 |
+
# Move nested batch dict to device
|
| 1108 |
for k in batch:
|
| 1109 |
if k == "target":
|
| 1110 |
batch[k] = batch[k].to(device)
|
|
|
|
| 1137 |
|
| 1138 |
@torch.no_grad()
|
| 1139 |
def evaluate(model, dataloader, device):
|
| 1140 |
+
"""
|
| 1141 |
+
Evaluate on a dataloader:
|
| 1142 |
+
- returns avg loss, predicted scaled values, true scaled values
|
| 1143 |
+
"""
|
| 1144 |
model.eval()
|
| 1145 |
preds = []
|
| 1146 |
trues = []
|
|
|
|
| 1148 |
total_n = 0
|
| 1149 |
|
| 1150 |
for batch in dataloader:
|
| 1151 |
+
# Move nested batch dict to device
|
| 1152 |
for k in batch:
|
| 1153 |
if k == "target":
|
| 1154 |
batch[k] = batch[k].to(device)
|
|
|
|
| 1184 |
return float(avg_loss), preds, trues
|
| 1185 |
|
| 1186 |
|
| 1187 |
+
# =============================================================================
|
| 1188 |
+
# Pretrained loading helpers
|
| 1189 |
+
# =============================================================================
|
| 1190 |
def load_pretrained_multimodal(pretrained_path: str) -> MultimodalContrastiveModel:
|
| 1191 |
+
"""
|
| 1192 |
+
Construct modality encoders and load any available pretrained weights:
|
| 1193 |
+
- modality-specific checkpoints (BEST_*_DIR)
|
| 1194 |
+
- full multimodal checkpoint from `pretrained_path/pytorch_model.bin`
|
| 1195 |
+
|
| 1196 |
+
Returns a ready-to-fine-tune MultimodalContrastiveModel.
|
| 1197 |
+
"""
|
| 1198 |
+
# -------------------------
|
| 1199 |
+
# GINE encoder
|
| 1200 |
+
# -------------------------
|
| 1201 |
gine_encoder = GineEncoder(
|
| 1202 |
node_emb_dim=NODE_EMB_DIM,
|
| 1203 |
edge_emb_dim=EDGE_EMB_DIM,
|
| 1204 |
num_layers=NUM_GNN_LAYERS,
|
| 1205 |
max_atomic_z=MAX_ATOMIC_Z
|
| 1206 |
)
|
| 1207 |
+
gine_ckpt = os.path.join(BEST_GINE_DIR, "pytorch_model.bin")
|
| 1208 |
+
if os.path.exists(gine_ckpt):
|
| 1209 |
try:
|
| 1210 |
+
gine_encoder.load_state_dict(torch.load(gine_ckpt, map_location="cpu"), strict=False)
|
| 1211 |
+
print(f"[LOAD] GINE weights: {gine_ckpt}")
|
|
|
|
|
|
|
|
|
|
| 1212 |
except Exception as e:
|
| 1213 |
+
print(f"[LOAD][WARN] Could not load GINE weights: {e}")
|
| 1214 |
|
| 1215 |
+
# -------------------------
|
| 1216 |
+
# SchNet encoder
|
| 1217 |
+
# -------------------------
|
| 1218 |
schnet_encoder = NodeSchNetWrapper(
|
| 1219 |
hidden_channels=SCHNET_HIDDEN,
|
| 1220 |
num_interactions=SCHNET_NUM_INTERACTIONS,
|
|
|
|
| 1222 |
cutoff=SCHNET_CUTOFF,
|
| 1223 |
max_num_neighbors=SCHNET_MAX_NEIGHBORS
|
| 1224 |
)
|
| 1225 |
+
sch_ckpt = os.path.join(BEST_SCHNET_DIR, "pytorch_model.bin")
|
| 1226 |
+
if os.path.exists(sch_ckpt):
|
| 1227 |
try:
|
| 1228 |
+
schnet_encoder.load_state_dict(torch.load(sch_ckpt, map_location="cpu"), strict=False)
|
| 1229 |
+
print(f"[LOAD] SchNet weights: {sch_ckpt}")
|
|
|
|
|
|
|
|
|
|
| 1230 |
except Exception as e:
|
| 1231 |
+
print(f"[LOAD][WARN] Could not load SchNet weights: {e}")
|
| 1232 |
|
| 1233 |
+
# -------------------------
|
| 1234 |
+
# Fingerprint encoder
|
| 1235 |
+
# -------------------------
|
| 1236 |
fp_encoder = FingerprintEncoder(
|
| 1237 |
vocab_size=VOCAB_SIZE_FP,
|
| 1238 |
hidden_dim=256,
|
|
|
|
| 1242 |
dim_feedforward=1024,
|
| 1243 |
dropout=0.1
|
| 1244 |
)
|
| 1245 |
+
fp_ckpt = os.path.join(BEST_FP_DIR, "pytorch_model.bin")
|
| 1246 |
+
if os.path.exists(fp_ckpt):
|
| 1247 |
try:
|
| 1248 |
+
fp_encoder.load_state_dict(torch.load(fp_ckpt, map_location="cpu"), strict=False)
|
| 1249 |
+
print(f"[LOAD] FP encoder weights: {fp_ckpt}")
|
|
|
|
|
|
|
|
|
|
| 1250 |
except Exception as e:
|
| 1251 |
+
print(f"[LOAD][WARN] Could not load fingerprint weights: {e}")
|
| 1252 |
|
| 1253 |
+
# -------------------------
|
| 1254 |
+
# PSMILES encoder
|
| 1255 |
+
# -------------------------
|
| 1256 |
psmiles_encoder = None
|
| 1257 |
if os.path.isdir(BEST_PSMILES_DIR):
|
| 1258 |
try:
|
| 1259 |
psmiles_encoder = PSMILESDebertaEncoder(model_dir_or_name=BEST_PSMILES_DIR)
|
| 1260 |
+
print(f"[LOAD] PSMILES encoder: {BEST_PSMILES_DIR}")
|
| 1261 |
except Exception as e:
|
| 1262 |
+
print(f"[LOAD][WARN] Could not load PSMILES encoder from dir: {e}")
|
| 1263 |
|
| 1264 |
+
# Fallback: initialize with vocab fallback (still functional, but not your finetuned weights)
|
| 1265 |
if psmiles_encoder is None:
|
| 1266 |
try:
|
| 1267 |
psmiles_encoder = PSMILESDebertaEncoder(
|
| 1268 |
model_dir_or_name=None,
|
| 1269 |
vocab_fallback=int(getattr(tokenizer, "vocab_size", 300))
|
| 1270 |
)
|
| 1271 |
+
print("[LOAD] PSMILES encoder: initialized fallback (no pretrained dir).")
|
| 1272 |
except Exception as e:
|
| 1273 |
+
print(f"[LOAD][WARN] Could not initialize PSMILES encoder: {e}")
|
| 1274 |
|
| 1275 |
+
# Build multimodal wrapper
|
| 1276 |
multimodal_model = MultimodalContrastiveModel(
|
| 1277 |
gine_encoder,
|
| 1278 |
schnet_encoder,
|
| 1279 |
fp_encoder,
|
| 1280 |
psmiles_encoder,
|
| 1281 |
+
emb_dim=POLYF_EMB_DIM,
|
| 1282 |
modalities=["gine", "schnet", "fp", "psmiles"]
|
| 1283 |
)
|
| 1284 |
|
| 1285 |
+
# -------------------------
|
| 1286 |
+
# Optional: load full multimodal checkpoint (projection/fusion layers, etc.)
|
| 1287 |
+
# -------------------------
|
| 1288 |
ckpt_path = os.path.join(pretrained_path, "pytorch_model.bin")
|
| 1289 |
if os.path.isfile(ckpt_path):
|
| 1290 |
try:
|
|
|
|
| 1300 |
|
| 1301 |
summarize_state_dict_load(state, model_state, filtered_state)
|
| 1302 |
missing, unexpected = multimodal_model.load_state_dict(filtered_state, strict=False)
|
| 1303 |
+
print(f"[LOAD] Multimodal checkpoint: {ckpt_path}")
|
| 1304 |
+
print(f"[LOAD] load_state_dict -> missing={len(missing)} unexpected={len(unexpected)}")
|
| 1305 |
if missing:
|
| 1306 |
+
print("[LOAD] Missing keys (sample):", missing[:50])
|
| 1307 |
if unexpected:
|
| 1308 |
+
print("[LOAD] Unexpected keys (sample):", unexpected[:50])
|
| 1309 |
|
| 1310 |
except Exception as e:
|
| 1311 |
+
print(f"[LOAD][WARN] Failed to load multimodal pretrained weights: {e}")
|
| 1312 |
+
else:
|
| 1313 |
+
print(f"[LOAD] No multimodal checkpoint found at: {ckpt_path}")
|
| 1314 |
|
| 1315 |
return multimodal_model
|
| 1316 |
|
| 1317 |
|
| 1318 |
+
# =============================================================================
|
| 1319 |
+
# Downstream: sample construction + CV training loop
|
| 1320 |
+
# =============================================================================
|
| 1321 |
def build_samples_for_property(df: pd.DataFrame, prop_col: str) -> List[dict]:
|
| 1322 |
+
"""
|
| 1323 |
+
Construct training samples for a single property:
|
| 1324 |
+
- Keep rows that have at least one modality present.
|
| 1325 |
+
- Keep rows with a finite property value in `prop_col`.
|
| 1326 |
+
- Store raw target (will be scaled per fold).
|
| 1327 |
+
"""
|
| 1328 |
samples = []
|
| 1329 |
for _, row in df.iterrows():
|
| 1330 |
+
# Require at least one modality present
|
| 1331 |
has_modality = False
|
| 1332 |
for col in ['graph', 'geometry', 'fingerprints', 'psmiles']:
|
| 1333 |
if col in row and row[col] and str(row[col]).strip() != "":
|
|
|
|
| 1355 |
return samples
|
| 1356 |
|
| 1357 |
|
| 1358 |
+
def run_polyf_downstream(property_list: List[str], property_cols: List[str], df_raw: pd.DataFrame,
|
| 1359 |
pretrained_path: str, output_file: str):
|
| 1360 |
+
"""
|
| 1361 |
+
Uni-Poly-matched downstream evaluation:
|
| 1362 |
+
|
| 1363 |
+
For each property:
|
| 1364 |
+
- Build samples from PolyInfo
|
| 1365 |
+
- 5-fold CV:
|
| 1366 |
+
- Split into trainval/test (by KFold)
|
| 1367 |
+
- Split trainval into train/val
|
| 1368 |
+
- Fit StandardScaler on train targets
|
| 1369 |
+
- Fine-tune encoder+head end-to-end with early stopping by val loss
|
| 1370 |
+
- Evaluate on held-out test fold in original units
|
| 1371 |
+
- Save per-fold results and per-property mean±std
|
| 1372 |
+
- Save best fold checkpoint bundle (by test R2) for later reuse
|
| 1373 |
+
"""
|
| 1374 |
os.makedirs(pretrained_path, exist_ok=True)
|
| 1375 |
|
| 1376 |
+
# Optional duplicate aggregation (noise reduction)
|
| 1377 |
modality_cols = ["graph", "geometry", "fingerprints", "psmiles"]
|
| 1378 |
df_proc = aggregate_polyinfo_duplicates(df_raw, modality_cols=modality_cols, property_cols=property_cols)
|
| 1379 |
|
| 1380 |
+
all_results = {"per_property": {}, "mode": "POLYF_MATCHED_SINGLE_TASK"}
|
| 1381 |
|
|
|
|
| 1382 |
for pname, pcol in zip(property_list, property_cols):
|
| 1383 |
+
print(f"\n=== [DOWNSTREAM] Uni-Poly matched fine-tuning: {pname} (col='{pcol}') ===")
|
| 1384 |
samples = build_samples_for_property(df_proc, pcol)
|
| 1385 |
+
|
| 1386 |
+
print(f"[DATA] {pname}: n_samples={len(samples)}")
|
| 1387 |
if len(samples) < 200:
|
| 1388 |
+
print(f"[DATA][WARN] '{pname}' has <200 samples; results may be noisy.")
|
| 1389 |
if len(samples) < 50:
|
| 1390 |
+
print(f"[DATA][WARN] Skipping '{pname}' (insufficient samples).")
|
| 1391 |
continue
|
| 1392 |
|
| 1393 |
run_metrics = []
|
| 1394 |
run_records = []
|
| 1395 |
|
| 1396 |
+
# Track best-performing fold for this property (by test R2)
|
| 1397 |
best_overall_r2 = -1e18
|
| 1398 |
+
best_overall_payload = None
|
| 1399 |
|
| 1400 |
idxs = np.arange(len(samples))
|
| 1401 |
cv = KFold(n_splits=NUM_RUNS, shuffle=True, random_state=42)
|
|
|
|
| 1404 |
seed = 42 + run_idx
|
| 1405 |
set_seed(seed)
|
| 1406 |
|
| 1407 |
+
print(f"\n--- [CV] {pname}: fold {run_idx+1}/{NUM_RUNS} | seed={seed} ---")
|
| 1408 |
+
|
| 1409 |
trainval = [copy.deepcopy(samples[i]) for i in trainval_idx]
|
| 1410 |
test = [copy.deepcopy(samples[i]) for i in test_idx]
|
| 1411 |
|
| 1412 |
+
# Split trainval into train/val
|
| 1413 |
tr_idx, va_idx = train_test_split(
|
| 1414 |
np.arange(len(trainval)),
|
| 1415 |
test_size=VAL_SIZE_WITHIN_TRAINVAL,
|
|
|
|
| 1419 |
train = [copy.deepcopy(trainval[i]) for i in tr_idx]
|
| 1420 |
val = [copy.deepcopy(trainval[i]) for i in va_idx]
|
| 1421 |
|
| 1422 |
+
# Standardize target using training fold only (prevents leakage)
|
| 1423 |
sc = StandardScaler()
|
| 1424 |
sc.fit(np.array([s["target_raw"] for s in train]).reshape(-1, 1))
|
| 1425 |
|
|
|
|
| 1439 |
dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, collate_fn=multimodal_collate_fn)
|
| 1440 |
dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, collate_fn=multimodal_collate_fn)
|
| 1441 |
|
| 1442 |
+
print(f"[SPLIT] train={len(ds_train)} val={len(ds_val)} test={len(ds_test)}")
|
| 1443 |
+
|
| 1444 |
+
# Fresh base model per fold to avoid any cross-fold leakage
|
| 1445 |
+
polyf_base = load_pretrained_multimodal(pretrained_path)
|
| 1446 |
+
model = PolyFPropertyRegressor(polyf_base, emb_dim=POLYF_EMB_DIM, dropout=POLYF_DROPOUT).to(DEVICE)
|
| 1447 |
|
| 1448 |
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
|
| 1449 |
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
|
|
|
|
| 1452 |
best_state = None
|
| 1453 |
no_improve = 0
|
| 1454 |
|
| 1455 |
+
# Train with early stopping on validation loss
|
| 1456 |
for epoch in range(1, NUM_EPOCHS + 1):
|
| 1457 |
tr_loss = train_one_epoch(model, dl_train, optimizer, DEVICE)
|
| 1458 |
va_loss, _, _ = evaluate(model, dl_val, DEVICE)
|
|
|
|
| 1460 |
|
| 1461 |
scheduler.step()
|
| 1462 |
|
| 1463 |
+
print(f"[{pname}] fold {run_idx+1}/{NUM_RUNS} epoch {epoch:03d} | train={tr_loss:.6f} | val={va_loss:.6f}")
|
| 1464 |
|
| 1465 |
if va_loss < best_val - 1e-8:
|
| 1466 |
best_val = va_loss
|
|
|
|
| 1469 |
else:
|
| 1470 |
no_improve += 1
|
| 1471 |
if no_improve >= PATIENCE:
|
| 1472 |
+
print(f"[{pname}] fold {run_idx+1}: early stopping (patience={PATIENCE}) at epoch {epoch}.")
|
| 1473 |
break
|
| 1474 |
|
| 1475 |
if best_state is None:
|
| 1476 |
+
print(f"[{pname}][WARN] No best checkpoint captured for fold {run_idx+1}; skipping fold.")
|
| 1477 |
continue
|
| 1478 |
|
| 1479 |
+
# Restore best state and evaluate on test fold
|
| 1480 |
model.load_state_dict({k: v.to(DEVICE) for k, v in best_state.items()}, strict=True)
|
| 1481 |
_, pred_scaled, true_scaled = evaluate(model, dl_test, DEVICE)
|
| 1482 |
if pred_scaled is None:
|
| 1483 |
+
print(f"[{pname}][WARN] Test evaluation returned empty predictions for fold {run_idx+1}.")
|
| 1484 |
continue
|
| 1485 |
|
| 1486 |
+
# Convert from scaled space back to original units
|
| 1487 |
pred = sc.inverse_transform(pred_scaled.reshape(-1, 1)).ravel()
|
| 1488 |
true = sc.inverse_transform(true_scaled.reshape(-1, 1)).ravel()
|
| 1489 |
|
| 1490 |
m = compute_metrics(true, pred)
|
| 1491 |
run_metrics.append(m)
|
| 1492 |
|
| 1493 |
+
print(f"[{pname}] fold {run_idx+1} TEST | r2={m['r2']:.4f} mae={m['mae']:.4f} rmse={m['rmse']:.4f}")
|
| 1494 |
+
|
| 1495 |
record = {
|
| 1496 |
"property": pname,
|
| 1497 |
"property_col": pcol,
|
|
|
|
| 1508 |
with open(output_file, "a") as fh:
|
| 1509 |
fh.write(json.dumps(make_json_serializable(record)) + "\n")
|
| 1510 |
|
| 1511 |
+
# Update best fold bundle (by test R2)
|
| 1512 |
if float(m.get("r2", -1e18)) > float(best_overall_r2):
|
| 1513 |
best_overall_r2 = float(m.get("r2", -1e18))
|
| 1514 |
best_overall_payload = {
|
|
|
|
| 1525 |
"scaler_scale": make_json_serializable(getattr(sc, "scale_", None)),
|
| 1526 |
"scaler_var": make_json_serializable(getattr(sc, "var_", None)),
|
| 1527 |
"scaler_n_samples_seen": make_json_serializable(getattr(sc, "n_samples_seen_", None)),
|
| 1528 |
+
"model_state_dict": best_state, # CPU tensors
|
| 1529 |
}
|
| 1530 |
|
| 1531 |
+
# Save best fold weights + metadata per property
|
| 1532 |
if best_overall_payload is not None and "model_state_dict" in best_overall_payload:
|
| 1533 |
os.makedirs(BEST_WEIGHTS_DIR, exist_ok=True)
|
| 1534 |
prop_dir = os.path.join(BEST_WEIGHTS_DIR, _sanitize_name(pname))
|
| 1535 |
os.makedirs(prop_dir, exist_ok=True)
|
| 1536 |
|
| 1537 |
+
ckpt_bundle = {k: v for k, v in best_overall_payload.items() if k != "test_metrics"}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1538 |
ckpt_bundle["test_metrics"] = best_overall_payload["test_metrics"]
|
| 1539 |
|
| 1540 |
torch.save(ckpt_bundle, os.path.join(prop_dir, "best_run_checkpoint.pt"))
|
|
|
|
| 1543 |
with open(os.path.join(prop_dir, "best_run_metadata.json"), "w") as fh:
|
| 1544 |
fh.write(json.dumps(make_json_serializable(meta), indent=2))
|
| 1545 |
|
| 1546 |
+
print(f"[BEST] Saved best fold for '{pname}' -> {prop_dir}")
|
| 1547 |
+
print(f"[BEST] best_run={best_overall_payload['best_run']} best_test_r2={best_overall_payload['test_metrics'].get('r2', None)}")
|
| 1548 |
|
| 1549 |
+
# Aggregate metrics across folds
|
| 1550 |
if run_metrics:
|
| 1551 |
r2s = [x["r2"] for x in run_metrics]
|
| 1552 |
maes = [x["mae"] for x in run_metrics]
|
|
|
|
| 1558 |
"rmse": {"mean": float(np.mean(rmses)), "std": float(np.std(rmses, ddof=0))},
|
| 1559 |
"mse": {"mean": float(np.mean(mses)), "std": float(np.std(mses, ddof=0))},
|
| 1560 |
}
|
| 1561 |
+
print(f"[AGG] {pname} | r2={agg['r2']['mean']:.4f}±{agg['r2']['std']:.4f} mae={agg['mae']['mean']:.4f}±{agg['mae']['std']:.4f}")
|
| 1562 |
else:
|
| 1563 |
agg = None
|
| 1564 |
+
print(f"[AGG][WARN] No successful folds for '{pname}' (no aggregate computed).")
|
| 1565 |
|
| 1566 |
all_results["per_property"][pname] = {
|
| 1567 |
"property_col": pcol,
|
|
|
|
| 1576 |
return all_results
|
| 1577 |
|
| 1578 |
|
| 1579 |
+
# =============================================================================
|
| 1580 |
# Main
|
| 1581 |
+
# =============================================================================
|
| 1582 |
def main():
|
| 1583 |
+
# Start a fresh results file (back up old results if present)
|
| 1584 |
if os.path.exists(OUTPUT_RESULTS):
|
| 1585 |
backup = OUTPUT_RESULTS + ".bak"
|
| 1586 |
shutil.copy(OUTPUT_RESULTS, backup)
|
| 1587 |
+
print(f"[INIT] Existing results backed up: {backup}")
|
| 1588 |
open(OUTPUT_RESULTS, "w").close()
|
| 1589 |
+
print(f"[INIT] Writing results to: {OUTPUT_RESULTS}")
|
| 1590 |
|
| 1591 |
+
# Load PolyInfo
|
| 1592 |
if not os.path.isfile(POLYINFO_PATH):
|
| 1593 |
raise FileNotFoundError(f"PolyInfo file not found at {POLYINFO_PATH}")
|
| 1594 |
polyinfo_raw = pd.read_csv(POLYINFO_PATH, engine="python")
|
| 1595 |
+
print(f"[DATA] Loaded PolyInfo: n_rows={len(polyinfo_raw)} n_cols={len(polyinfo_raw.columns)}")
|
| 1596 |
|
| 1597 |
+
# Map requested properties to dataframe columns
|
| 1598 |
found = find_property_columns(polyinfo_raw.columns)
|
| 1599 |
prop_map = {req: col for req, col in found.items()}
|
| 1600 |
+
print(f"[COLMAP] Property-to-column map: {prop_map}")
|
| 1601 |
|
| 1602 |
property_list = []
|
| 1603 |
property_cols = []
|
| 1604 |
for req in REQUESTED_PROPERTIES:
|
| 1605 |
col = prop_map.get(req)
|
| 1606 |
if col is None:
|
| 1607 |
+
print(f"[COLMAP][WARN] Could not find a column for '{req}'; skipping.")
|
| 1608 |
continue
|
| 1609 |
property_list.append(req)
|
| 1610 |
property_cols.append(col)
|
| 1611 |
|
| 1612 |
+
overall = run_polyf_downstream(property_list, property_cols, polyinfo_raw, PRETRAINED_MULTIMODAL_DIR, OUTPUT_RESULTS)
|
|
|
|
| 1613 |
|
| 1614 |
+
# Write final summary (aggregated per property)
|
| 1615 |
final_agg = {}
|
| 1616 |
if overall and "per_property" in overall:
|
| 1617 |
for pname, info in overall["per_property"].items():
|
|
|
|
| 1622 |
fh.write(json.dumps(make_json_serializable(final_agg), indent=2))
|
| 1623 |
fh.write("\n")
|
| 1624 |
|
| 1625 |
+
print(f"\n Results appended to: {OUTPUT_RESULTS}")
|
| 1626 |
+
print(f" Best checkpoints saved under: {BEST_WEIGHTS_DIR}")
|
| 1627 |
|
| 1628 |
|
| 1629 |
if __name__ == "__main__":
|