from __future__ import annotations import re from pathlib import Path from typing import Dict, List, Optional, Tuple import numpy as np import torch from torch_geometric.data import Data from src.data_builder import featurize_smiles, TargetScaler from src.model import build_model from src.utils import to_device, apply_inverse_transform # ------------------------- # Unit correction (ML only) # ------------------------- POST_SCALE = { "td": 1e-7, "dif": 1e-5, "visc": 1e-3, } def _load_scaler_compat(path: Path) -> TargetScaler: blob = torch.load(path, map_location="cpu") if "mean" not in blob or "std" not in blob: raise RuntimeError(f"Unrecognized target_scaler format: {path}") ts = TargetScaler( transforms=blob.get("transforms", None), eps=blob.get("eps", None), ) ts.load_state_dict({ "mean": blob["mean"].float(), "std": blob["std"].float(), "transforms": blob.get("transforms", ts.transforms), "eps": blob.get("eps", ts.eps), }) ts.targets = [str(t).lower() for t in blob.get("targets", [])] return ts def _infer_seed(path: Path) -> Optional[int]: m = re.search(r"_([0-9]+)\.pt$", path.name) return int(m.group(1)) if m else None def _make_one_graph(smiles: str, T: int, fid_idx: int = 0) -> Data: x, edge_index, edge_attr = featurize_smiles(smiles) d = Data( x=x, edge_index=edge_index, edge_attr=edge_attr, y=torch.zeros(1, T), y_mask=torch.zeros(1, T, dtype=torch.bool), fid_idx=torch.tensor([fid_idx], dtype=torch.long), ) d.smiles = smiles return d class MultiTaskEnsemblePredictor: """ Multi-task ensemble: models/multitask_models/{task}_model_{seed}.pt models/multitask_models/{task}_scalar_{seed}.pt """ def __init__(self, models_dir: str = "models/multitask_models", device: str = "cpu"): self.models_dir = Path(models_dir) self.device = torch.device(device if device == "cuda" and torch.cuda.is_available() else "cpu") self._cache: Dict[Tuple[str, int], Tuple[Optional[torch.nn.Module], TargetScaler, dict]] = {} def available_seeds(self, task: str) -> List[int]: task = task.strip().lower() seeds = [] for p in self.models_dir.glob(f"{task}_model_*.pt"): s = _infer_seed(p) if s is not None: seeds.append(s) return sorted(set(seeds)) def _load_one_meta(self, task: str, seed: int): task = task.strip().lower() key = (task, seed) if key in self._cache: return self._cache[key] ckpt_path = self.models_dir / f"{task}_model_{seed}.pt" scaler_path = self.models_dir / f"{task}_scalar_{seed}.pt" if not ckpt_path.exists() or not scaler_path.exists(): raise FileNotFoundError(f"Missing model/scaler for task={task} seed={seed}") ckpt = torch.load(ckpt_path, map_location=self.device) state_dict = ckpt["model"] train_args = ckpt.get("args", {}) scaler = _load_scaler_compat(scaler_path) task_names = list(getattr(scaler, "targets", [])) if not task_names: raise RuntimeError(f"No targets found in scaler: {scaler_path}") if "fid_embed.weight" in state_dict: num_fids = state_dict["fid_embed.weight"].shape[0] else: num_fids = 1 meta = { "train_args": train_args, "task_names": task_names, "num_fids": num_fids, } self._cache[key] = (None, scaler, meta) return self._cache[key] def _build_if_needed(self, task: str, seed: int, in_dim_node: int, in_dim_edge: int): task = task.strip().lower() key = (task, seed) model, scaler, meta = self._cache[key] if model is not None: return model, scaler, meta train_args = meta["train_args"] task_names = meta["task_names"] num_fids = meta["num_fids"] model = build_model( in_dim_node=in_dim_node, in_dim_edge=in_dim_edge, task_names=task_names, num_fids=num_fids, gnn_type=train_args.get("gnn_type", "gine"), gnn_emb_dim=train_args.get("gnn_emb_dim", 256), gnn_layers=train_args.get("gnn_layers", 5), gnn_norm=train_args.get("gnn_norm", "batch"), gnn_readout=train_args.get("gnn_readout", "mean"), gnn_act=train_args.get("gnn_act", "relu"), gnn_dropout=train_args.get("gnn_dropout", 0.0), gnn_residual=train_args.get("gnn_residual", True), fid_emb_dim=train_args.get("fid_emb_dim", 64), use_film=train_args.get("use_film", True), use_task_embed=train_args.get("use_task_embed", True), task_emb_dim=train_args.get("task_emb_dim", 32), head_hidden=train_args.get("head_hidden", 512), head_depth=train_args.get("head_depth", 2), head_act=train_args.get("head_act", "relu"), head_dropout=train_args.get("head_dropout", 0.0), heteroscedastic=train_args.get("heteroscedastic", False), fid_emb_l2=0.0, task_emb_l2=0.0, use_task_uncertainty=train_args.get("task_uncertainty", False), ).to(self.device) ckpt_path = self.models_dir / f"{task}_model_{seed}.pt" ckpt = torch.load(ckpt_path, map_location=self.device) model.load_state_dict(ckpt["model"], strict=True) model.eval() self._cache[key] = (model, scaler, meta) return model, scaler, meta def predict_mean_std(self, smiles: str, prop_key: str, task: str) -> Tuple[Optional[float], Optional[float], Dict[int, float]]: task = task.strip().lower() prop_key = prop_key.lower() seeds = self.available_seeds(task) if not seeds: return None, None, {} self._load_one_meta(task, seeds[0]) _, scaler0, meta0 = self._cache[(task, seeds[0])] targets = list(meta0["task_names"]) # already lower() if prop_key not in targets: return None, None, {} t_idx = targets.index(prop_key) T = len(targets) try: g = _make_one_graph(smiles, T=T, fid_idx=0) except Exception: return None, None, {} in_dim_node = g.x.shape[1] in_dim_edge = g.edge_attr.shape[1] per_seed: Dict[int, float] = {} with torch.no_grad(): for seed in seeds: self._load_one_meta(task, seed) model, scaler, meta = self._build_if_needed(task, seed, in_dim_node, in_dim_edge) batch = to_device(g, self.device) out = model(batch) pred_n = out["pred"] # [1, T] pred = apply_inverse_transform(pred_n, scaler).cpu().numpy().reshape(-1) val = float(pred[t_idx]) # unit correction val *= POST_SCALE.get(prop_key, 1.0) per_seed[seed] = val vals = np.array(list(per_seed.values()), dtype=float) mean = float(vals.mean()) std = float(vals.std(ddof=1)) if len(vals) > 1 else 0.0 return mean, std, per_seed