# POLYMER-PROPERTY / src / predictor_multitask.py
# (Hugging Face upload-page metadata removed: "sobinalosious92 / Upload 297 files / 930ea3d verified")
from __future__ import annotations
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch_geometric.data import Data
from src.data_builder import featurize_smiles, TargetScaler
from src.model import build_model
from src.utils import to_device, apply_inverse_transform
# -------------------------
# Unit correction (ML only)
# -------------------------
# Multiplicative factors applied to raw ML predictions in predict_mean_std,
# keyed by property name; properties not listed are left unscaled (factor 1.0).
# Presumably these convert model-output units into the reporting units for
# td / dif / visc — TODO confirm the exact unit conventions against training data.
POST_SCALE = {
    "td": 1e-7,
    "dif": 1e-5,
    "visc": 1e-3,
}
def _load_scaler_compat(path: Path) -> TargetScaler:
    """Load a TargetScaler checkpoint, tolerating older save formats.

    The blob must contain "mean" and "std" tensors; "transforms", "eps"
    and "targets" are optional and fall back to the scaler's own defaults.

    Raises:
        RuntimeError: if the blob lacks the required mean/std entries.
    """
    blob = torch.load(path, map_location="cpu")
    if not ("mean" in blob and "std" in blob):
        raise RuntimeError(f"Unrecognized target_scaler format: {path}")
    scaler = TargetScaler(
        transforms=blob.get("transforms", None),
        eps=blob.get("eps", None),
    )
    state = {
        "mean": blob["mean"].float(),
        "std": blob["std"].float(),
        "transforms": blob.get("transforms", scaler.transforms),
        "eps": blob.get("eps", scaler.eps),
    }
    scaler.load_state_dict(state)
    # Target names are normalized to lower case so lookups are case-insensitive.
    scaler.targets = [str(name).lower() for name in blob.get("targets", [])]
    return scaler
def _infer_seed(path: Path) -> Optional[int]:
m = re.search(r"_([0-9]+)\.pt$", path.name)
return int(m.group(1)) if m else None
def _make_one_graph(smiles: str, T: int, fid_idx: int = 0) -> Data:
    """Featurize one SMILES into a PyG Data object with empty T-dim targets.

    y is an all-zero [1, T] placeholder and y_mask is all-False, since at
    prediction time no ground-truth labels exist. The SMILES string is
    attached for traceability.
    """
    node_feats, edge_index, edge_feats = featurize_smiles(smiles)
    graph = Data(
        x=node_feats,
        edge_index=edge_index,
        edge_attr=edge_feats,
        y=torch.zeros(1, T),
        y_mask=torch.zeros(1, T, dtype=torch.bool),
        fid_idx=torch.tensor([fid_idx], dtype=torch.long),
    )
    graph.smiles = smiles
    return graph
class MultiTaskEnsemblePredictor:
    """
    Ensemble predictor over per-seed multi-task GNN checkpoints.

    Expects checkpoint pairs on disk:
        models/multitask_models/{task}_model_{seed}.pt
        models/multitask_models/{task}_scalar_{seed}.pt

    Models are built lazily: the node/edge feature dimensions are only known
    after the first SMILES is featurized, so _load_one_meta caches the scaler
    and checkpoint metadata first, and _build_if_needed constructs the model
    on first prediction. Everything is cached per (task, seed).
    """

    def __init__(self, models_dir: str = "models/multitask_models", device: str = "cpu"):
        self.models_dir = Path(models_dir)
        # Accept any CUDA device string ("cuda", "cuda:0", ...) and fall back
        # to CPU when CUDA is unavailable. (Previously only the exact string
        # "cuda" was recognized, so "cuda:0" silently ran on CPU.)
        want_cuda = str(device).startswith("cuda")
        self.device = torch.device(device if want_cuda and torch.cuda.is_available() else "cpu")
        # (task, seed) -> (model_or_None, scaler, meta); the model slot stays
        # None until _build_if_needed has run for that key.
        self._cache: Dict[Tuple[str, int], Tuple[Optional[torch.nn.Module], TargetScaler, dict]] = {}

    def available_seeds(self, task: str) -> List[int]:
        """Return the sorted, de-duplicated seeds that have a model file for `task`."""
        task = task.strip().lower()
        seeds = []
        for p in self.models_dir.glob(f"{task}_model_*.pt"):
            s = _infer_seed(p)
            if s is not None:
                seeds.append(s)
        return sorted(set(seeds))

    def _load_one_meta(self, task: str, seed: int):
        """Load the scaler and checkpoint metadata for (task, seed) into the cache.

        The model itself is NOT built here (its input dims are unknown until a
        graph has been featurized); the cached model slot is left as None.

        Returns:
            The cached (model_or_None, scaler, meta) triple.

        Raises:
            FileNotFoundError: if either the model or scaler file is missing.
            RuntimeError: if the scaler declares no target names.
        """
        task = task.strip().lower()
        key = (task, seed)
        if key in self._cache:
            return self._cache[key]
        ckpt_path = self.models_dir / f"{task}_model_{seed}.pt"
        scaler_path = self.models_dir / f"{task}_scalar_{seed}.pt"
        if not ckpt_path.exists() or not scaler_path.exists():
            raise FileNotFoundError(f"Missing model/scaler for task={task} seed={seed}")
        ckpt = torch.load(ckpt_path, map_location=self.device)
        state_dict = ckpt["model"]
        train_args = ckpt.get("args", {})
        scaler = _load_scaler_compat(scaler_path)
        task_names = list(getattr(scaler, "targets", []))
        if not task_names:
            raise RuntimeError(f"No targets found in scaler: {scaler_path}")
        # Infer the fid vocabulary size from the embedding weights; checkpoints
        # without a fid embedding imply a single fid.
        if "fid_embed.weight" in state_dict:
            num_fids = state_dict["fid_embed.weight"].shape[0]
        else:
            num_fids = 1
        meta = {
            "train_args": train_args,
            "task_names": task_names,
            "num_fids": num_fids,
        }
        self._cache[key] = (None, scaler, meta)
        return self._cache[key]

    def _build_if_needed(self, task: str, seed: int, in_dim_node: int, in_dim_edge: int):
        """Construct and weight-load the model for (task, seed) on first use.

        Requires _load_one_meta(task, seed) to have populated the cache
        (raises KeyError otherwise). `in_dim_node`/`in_dim_edge` come from the
        first featurized graph. Returns the (model, scaler, meta) triple.
        """
        task = task.strip().lower()
        key = (task, seed)
        model, scaler, meta = self._cache[key]
        if model is not None:
            return model, scaler, meta
        train_args = meta["train_args"]
        model = build_model(
            in_dim_node=in_dim_node,
            in_dim_edge=in_dim_edge,
            task_names=meta["task_names"],
            num_fids=meta["num_fids"],
            gnn_type=train_args.get("gnn_type", "gine"),
            gnn_emb_dim=train_args.get("gnn_emb_dim", 256),
            gnn_layers=train_args.get("gnn_layers", 5),
            gnn_norm=train_args.get("gnn_norm", "batch"),
            gnn_readout=train_args.get("gnn_readout", "mean"),
            gnn_act=train_args.get("gnn_act", "relu"),
            gnn_dropout=train_args.get("gnn_dropout", 0.0),
            gnn_residual=train_args.get("gnn_residual", True),
            fid_emb_dim=train_args.get("fid_emb_dim", 64),
            use_film=train_args.get("use_film", True),
            use_task_embed=train_args.get("use_task_embed", True),
            task_emb_dim=train_args.get("task_emb_dim", 32),
            head_hidden=train_args.get("head_hidden", 512),
            head_depth=train_args.get("head_depth", 2),
            head_act=train_args.get("head_act", "relu"),
            head_dropout=train_args.get("head_dropout", 0.0),
            heteroscedastic=train_args.get("heteroscedastic", False),
            # Regularization weights are irrelevant at inference time.
            fid_emb_l2=0.0,
            task_emb_l2=0.0,
            use_task_uncertainty=train_args.get("task_uncertainty", False),
        ).to(self.device)
        # Re-read the checkpoint: the state dict is deliberately not kept in
        # `meta` (trades a second disk read for lower resident memory).
        ckpt_path = self.models_dir / f"{task}_model_{seed}.pt"
        ckpt = torch.load(ckpt_path, map_location=self.device)
        model.load_state_dict(ckpt["model"], strict=True)
        model.eval()
        self._cache[key] = (model, scaler, meta)
        return model, scaler, meta

    def predict_mean_std(self, smiles: str, prop_key: str, task: str) -> Tuple[Optional[float], Optional[float], Dict[int, float]]:
        """Predict `prop_key` for `smiles` using the `task` seed ensemble.

        Returns:
            (mean, std, per_seed): ensemble mean, sample std (ddof=1; 0.0 for
            a single seed), and the per-seed predictions. Returns
            (None, None, {}) when no seeds exist, the property is unknown to
            the scaler, or featurization fails — a deliberate best-effort
            contract for callers probing many SMILES.
        """
        task = task.strip().lower()
        prop_key = prop_key.lower()
        seeds = self.available_seeds(task)
        if not seeds:
            return None, None, {}
        _, _, meta0 = self._load_one_meta(task, seeds[0])
        targets = list(meta0["task_names"])  # already lower-cased by the scaler loader
        if prop_key not in targets:
            return None, None, {}
        t_idx = targets.index(prop_key)
        T = len(targets)
        try:
            g = _make_one_graph(smiles, T=T, fid_idx=0)
        except Exception:
            # Invalid SMILES etc. -> soft failure per the contract above.
            return None, None, {}
        in_dim_node = g.x.shape[1]
        in_dim_edge = g.edge_attr.shape[1]
        # The input graph is identical for every seed: move it to the device
        # once instead of once per seed.
        batch = to_device(g, self.device)
        per_seed: Dict[int, float] = {}
        with torch.no_grad():
            for seed in seeds:
                self._load_one_meta(task, seed)
                model, scaler, meta = self._build_if_needed(task, seed, in_dim_node, in_dim_edge)
                out = model(batch)
                pred_n = out["pred"]  # normalized predictions, shape [1, T]
                pred = apply_inverse_transform(pred_n, scaler).cpu().numpy().reshape(-1)
                val = float(pred[t_idx])
                # unit correction (see POST_SCALE)
                val *= POST_SCALE.get(prop_key, 1.0)
                per_seed[seed] = val
        vals = np.array(list(per_seed.values()), dtype=float)
        mean = float(vals.mean())
        std = float(vals.std(ddof=1)) if len(vals) > 1 else 0.0
        return mean, std, per_seed