from __future__ import annotations import re from pathlib import Path from typing import Dict, List, Optional, Tuple import numpy as np import torch from torch_geometric.data import Data from src.data_builder import featurize_smiles, TargetScaler from src.model import build_model from src.utils import to_device, apply_inverse_transform # ------------------------- # Unit correction (ML only) # ------------------------- POST_SCALE = { "td": 1e-7, "dif": 1e-5, "visc": 1e-3, } def _load_scaler_compat(path: Path) -> TargetScaler: blob = torch.load(path, map_location="cpu") if "mean" not in blob or "std" not in blob: raise RuntimeError(f"Unrecognized target_scaler format: {path}") ts = TargetScaler( transforms=blob.get("transforms", None), eps=blob.get("eps", None), ) ts.load_state_dict({ "mean": blob["mean"].float(), "std": blob["std"].float(), "transforms": blob.get("transforms", ts.transforms), "eps": blob.get("eps", ts.eps), }) ts.targets = [str(t).lower() for t in blob.get("targets", [])] return ts def _infer_seed_from_name(path: Path) -> Optional[int]: m = re.search(r"_([0-9]+)\.pt$", path.name) return int(m.group(1)) if m else None def _make_one_graph(smiles: str) -> Data: x, edge_index, edge_attr = featurize_smiles(smiles) d = Data( x=x, edge_index=edge_index, edge_attr=edge_attr, y=torch.zeros(1, 1), y_mask=torch.zeros(1, 1, dtype=torch.bool), fid_idx=torch.tensor([0], dtype=torch.long), ) d.smiles = smiles return d class SingleTaskEnsemblePredictor: """ Single-task ensemble: models/single_models/{prop}_single_model_{seed}.pt models/single_models/{prop}_single_scalar_{seed}.pt """ def __init__(self, models_dir: str = "models/single_models", device: str = "cpu"): self.models_dir = Path(models_dir) self.device = torch.device(device if device == "cuda" and torch.cuda.is_available() else "cpu") self._cache: Dict[Tuple[str, int], Tuple[Optional[torch.nn.Module], TargetScaler, dict]] = {} def available_seeds(self, prop: str) -> List[int]: prop = prop.lower() seeds = [] for p in self.models_dir.glob(f"{prop}_single_model_*.pt"): s = _infer_seed_from_name(p) if s is not None: seeds.append(s) return sorted(set(seeds)) def _load_one(self, prop: str, seed: int): prop = prop.lower() key = (prop, seed) if key in self._cache: return self._cache[key] ckpt_path = self.models_dir / f"{prop}_single_model_{seed}.pt" scaler_path = self.models_dir / f"{prop}_single_scalar_{seed}.pt" if not ckpt_path.exists() or not scaler_path.exists(): raise FileNotFoundError(f"Missing model/scaler for {prop} seed {seed}") ckpt = torch.load(ckpt_path, map_location=self.device) train_args = ckpt.get("args", {}) scaler = _load_scaler_compat(scaler_path) task_names = list(getattr(scaler, "targets", [])) or [prop] meta = {"train_args": train_args, "task_names": task_names} self._cache[key] = (None, scaler, meta) return self._cache[key] def _build_model_if_needed(self, prop: str, seed: int, in_dim_node: int, in_dim_edge: int): prop = prop.lower() key = (prop, seed) model, scaler, meta = self._cache[key] if model is not None: return model, scaler, meta train_args = meta["train_args"] task_names = meta["task_names"] ckpt_path = self.models_dir / f"{prop}_single_model_{seed}.pt" ckpt = torch.load(ckpt_path, map_location=self.device) state_dict = ckpt["model"] # infer num_fids from checkpoint if "fid_embed.weight" in state_dict: num_fids = state_dict["fid_embed.weight"].shape[0] else: num_fids = 1 model = build_model( in_dim_node=in_dim_node, in_dim_edge=in_dim_edge, task_names=task_names, num_fids=num_fids, gnn_type=train_args.get("gnn_type", "gine"), gnn_emb_dim=train_args.get("gnn_emb_dim", 256), gnn_layers=train_args.get("gnn_layers", 5), gnn_norm=train_args.get("gnn_norm", "batch"), gnn_readout=train_args.get("gnn_readout", "mean"), gnn_act=train_args.get("gnn_act", "relu"), gnn_dropout=train_args.get("gnn_dropout", 0.0), gnn_residual=train_args.get("gnn_residual", True), fid_emb_dim=train_args.get("fid_emb_dim", 64), use_film=train_args.get("use_film", True), use_task_embed=train_args.get("use_task_embed", True), task_emb_dim=train_args.get("task_emb_dim", 32), head_hidden=train_args.get("head_hidden", 512), head_depth=train_args.get("head_depth", 2), head_act=train_args.get("head_act", "relu"), head_dropout=train_args.get("head_dropout", 0.0), heteroscedastic=train_args.get("heteroscedastic", False), fid_emb_l2=0.0, task_emb_l2=0.0, use_task_uncertainty=train_args.get("task_uncertainty", False), ).to(self.device) model.load_state_dict(state_dict, strict=True) model.eval() self._cache[key] = (model, scaler, meta) return model, scaler, meta def predict_mean_std(self, smiles: str, prop: str) -> Tuple[Optional[float], Optional[float], Dict[int, float]]: prop = prop.lower() seeds = self.available_seeds(prop) if not seeds: return None, None, {} try: g = _make_one_graph(smiles) except Exception: return None, None, {} in_dim_node = g.x.shape[1] in_dim_edge = g.edge_attr.shape[1] per_seed: Dict[int, float] = {} with torch.no_grad(): for seed in seeds: self._load_one(prop, seed) model, scaler, meta = self._build_model_if_needed(prop, seed, in_dim_node, in_dim_edge) batch = to_device(g, self.device) out = model(batch) pred_n = out["pred"] # [1, 1] pred = apply_inverse_transform(pred_n, scaler).cpu().numpy().reshape(-1) val = float(pred[0]) # unit correction val *= POST_SCALE.get(prop, 1.0) per_seed[seed] = val vals = np.array(list(per_seed.values()), dtype=float) mean = float(vals.mean()) std = float(vals.std(ddof=1)) if len(vals) > 1 else 0.0 return mean, std, per_seed