from __future__ import annotations

import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch_geometric.data import Data

from src.data_builder import featurize_smiles, TargetScaler
from src.model import build_model
from src.utils import to_device, apply_inverse_transform


# Per-property multiplicative factors applied after the target scaler's
# inverse transform (see predict_mean_std).
POST_SCALE = {
    "td": 1e-7,
    "dif": 1e-5,
    "visc": 1e-3,
}


def _load_scaler_compat(path: Path) -> TargetScaler:
    """Rebuild a TargetScaler from a saved blob, tolerating missing optional keys."""
    blob = torch.load(path, map_location="cpu")
    if "mean" not in blob or "std" not in blob:
        raise RuntimeError(f"Unrecognized target_scaler format: {path}")

    ts = TargetScaler(
        transforms=blob.get("transforms", None),
        eps=blob.get("eps", None),
    )
    ts.load_state_dict({
        "mean": blob["mean"].float(),
        "std": blob["std"].float(),
        "transforms": blob.get("transforms", ts.transforms),
        "eps": blob.get("eps", ts.eps),
    })
    ts.targets = [str(t).lower() for t in blob.get("targets", [])]
    return ts
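
# For reference, a compatible scaler blob could be written roughly as below.
# This is a sketch only: the key names come from _load_scaler_compat above,
# while tensor shapes and the "transforms"/"eps" semantics depend on
# TargetScaler in src.data_builder.
#
#   torch.save(
#       {
#           "mean": mean,              # per-target means (one entry per target)
#           "std": std,                # per-target stds
#           "transforms": transforms,  # optional
#           "eps": eps,                # optional
#           "targets": targets,        # list of target names; lower-cased on load
#       },
#       "models/multitask_models/{task}_scalar_{seed}.pt",
#   )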


def _infer_seed(path: Path) -> Optional[int]:
    """Extract the trailing integer seed from a checkpoint name, e.g. ``td_model_3.pt`` -> 3."""
    m = re.search(r"_([0-9]+)\.pt$", path.name)
    return int(m.group(1)) if m else None


def _make_one_graph(smiles: str, T: int, fid_idx: int = 0) -> Data:
    """Featurize a single SMILES string into a PyG ``Data`` object with placeholder targets for T tasks."""
    x, edge_index, edge_attr = featurize_smiles(smiles)
    d = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.zeros(1, T),
        y_mask=torch.zeros(1, T, dtype=torch.bool),
        fid_idx=torch.tensor([fid_idx], dtype=torch.long),
    )
    d.smiles = smiles
    return d


class MultiTaskEnsemblePredictor:
    """
    Seed ensemble of multi-task models, with checkpoints laid out as:

        models/multitask_models/{task}_model_{seed}.pt
        models/multitask_models/{task}_scalar_{seed}.pt
    """

    def __init__(self, models_dir: str = "models/multitask_models", device: str = "cpu"):
        self.models_dir = Path(models_dir)
        # Fall back to CPU unless CUDA was requested and is actually available.
        self.device = torch.device(device if device == "cuda" and torch.cuda.is_available() else "cpu")
        # (task, seed) -> (model or None, scaler, meta); the model is built lazily.
        self._cache: Dict[Tuple[str, int], Tuple[Optional[torch.nn.Module], TargetScaler, dict]] = {}

    def available_seeds(self, task: str) -> List[int]:
        """Return the sorted, de-duplicated seeds for which a model checkpoint exists for ``task``."""
        task = task.strip().lower()
        seeds = []
        for p in self.models_dir.glob(f"{task}_model_*.pt"):
            s = _infer_seed(p)
            if s is not None:
                seeds.append(s)
        return sorted(set(seeds))

    def _load_one_meta(self, task: str, seed: int):
        """Load the scaler and checkpoint metadata for (task, seed); the model itself is built later."""
        task = task.strip().lower()
        key = (task, seed)
        if key in self._cache:
            return self._cache[key]

        ckpt_path = self.models_dir / f"{task}_model_{seed}.pt"
        scaler_path = self.models_dir / f"{task}_scalar_{seed}.pt"
        if not ckpt_path.exists() or not scaler_path.exists():
            raise FileNotFoundError(f"Missing model/scaler for task={task} seed={seed}")

        ckpt = torch.load(ckpt_path, map_location=self.device)
        state_dict = ckpt["model"]
        train_args = ckpt.get("args", {})

        scaler = _load_scaler_compat(scaler_path)
        task_names = list(getattr(scaler, "targets", []))
        if not task_names:
            raise RuntimeError(f"No targets found in scaler: {scaler_path}")

        # Recover the number of fidelity levels from the embedding table, if present.
        if "fid_embed.weight" in state_dict:
            num_fids = state_dict["fid_embed.weight"].shape[0]
        else:
            num_fids = 1

        meta = {
            "train_args": train_args,
            "task_names": task_names,
            "num_fids": num_fids,
        }
        # The model slot stays None until _build_if_needed knows the input dimensions.
        self._cache[key] = (None, scaler, meta)
        return self._cache[key]

    def _build_if_needed(self, task: str, seed: int, in_dim_node: int, in_dim_edge: int):
        """Instantiate and load the model for (task, seed) if it is not already cached."""
        task = task.strip().lower()
        key = (task, seed)
        model, scaler, meta = self._cache[key]
        if model is not None:
            return model, scaler, meta

        train_args = meta["train_args"]
        task_names = meta["task_names"]
        num_fids = meta["num_fids"]

        # Rebuild the architecture from the args stored in the checkpoint,
        # falling back to defaults for any missing keys.
        model = build_model(
            in_dim_node=in_dim_node,
            in_dim_edge=in_dim_edge,
            task_names=task_names,
            num_fids=num_fids,
            gnn_type=train_args.get("gnn_type", "gine"),
            gnn_emb_dim=train_args.get("gnn_emb_dim", 256),
            gnn_layers=train_args.get("gnn_layers", 5),
            gnn_norm=train_args.get("gnn_norm", "batch"),
            gnn_readout=train_args.get("gnn_readout", "mean"),
            gnn_act=train_args.get("gnn_act", "relu"),
            gnn_dropout=train_args.get("gnn_dropout", 0.0),
            gnn_residual=train_args.get("gnn_residual", True),
            fid_emb_dim=train_args.get("fid_emb_dim", 64),
            use_film=train_args.get("use_film", True),
            use_task_embed=train_args.get("use_task_embed", True),
            task_emb_dim=train_args.get("task_emb_dim", 32),
            head_hidden=train_args.get("head_hidden", 512),
            head_depth=train_args.get("head_depth", 2),
            head_act=train_args.get("head_act", "relu"),
            head_dropout=train_args.get("head_dropout", 0.0),
            heteroscedastic=train_args.get("heteroscedastic", False),
            fid_emb_l2=0.0,
            task_emb_l2=0.0,
            use_task_uncertainty=train_args.get("task_uncertainty", False),
        ).to(self.device)

        ckpt_path = self.models_dir / f"{task}_model_{seed}.pt"
        ckpt = torch.load(ckpt_path, map_location=self.device)
        model.load_state_dict(ckpt["model"], strict=True)
        model.eval()

        self._cache[key] = (model, scaler, meta)
        return model, scaler, meta

    def predict_mean_std(self, smiles: str, prop_key: str, task: str) -> Tuple[Optional[float], Optional[float], Dict[int, float]]:
        """Predict ``prop_key`` for one SMILES string with every available seed of ``task``.

        Returns (ensemble mean, ensemble std, per-seed values), or
        (None, None, {}) if no seeds exist, the property is unknown, or
        featurization fails.
        """
        task = task.strip().lower()
        prop_key = prop_key.lower()

        seeds = self.available_seeds(task)
        if not seeds:
            return None, None, {}

        # Resolve the target index and count from the first seed's metadata.
        _, _, meta0 = self._load_one_meta(task, seeds[0])
        targets = list(meta0["task_names"])
        if prop_key not in targets:
            return None, None, {}

        t_idx = targets.index(prop_key)
        T = len(targets)

        try:
            g = _make_one_graph(smiles, T=T, fid_idx=0)
        except Exception:
            return None, None, {}

        in_dim_node = g.x.shape[1]
        in_dim_edge = g.edge_attr.shape[1]

        per_seed: Dict[int, float] = {}
        with torch.no_grad():
            for seed in seeds:
                self._load_one_meta(task, seed)
                model, scaler, meta = self._build_if_needed(task, seed, in_dim_node, in_dim_edge)

                batch = to_device(g, self.device)
                out = model(batch)
                pred_n = out["pred"]
                # Undo the target normalization, then pick the requested property.
                pred = apply_inverse_transform(pred_n, scaler).cpu().numpy().reshape(-1)
                val = float(pred[t_idx])

                # Apply the per-property post-scale factor, if any (see POST_SCALE).
                val *= POST_SCALE.get(prop_key, 1.0)

                per_seed[seed] = val

        vals = np.array(list(per_seed.values()), dtype=float)
        mean = float(vals.mean())
        std = float(vals.std(ddof=1)) if len(vals) > 1 else 0.0
        return mean, std, per_seed
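

# A minimal usage sketch. The task/property name ("visc") and the SMILES string
# below are illustrative placeholders; the real values depend on which targets
# the checkpoints under models/multitask_models were trained on.
if __name__ == "__main__":
    predictor = MultiTaskEnsemblePredictor(models_dir="models/multitask_models", device="cpu")
    mean, std, per_seed = predictor.predict_mean_std(
        smiles="CCO",      # placeholder molecule
        prop_key="visc",   # must match one of the targets stored in the scaler
        task="visc",       # matches the {task}_model_{seed}.pt filename prefix
    )
    print(f"ensemble mean={mean}, std={std}, per-seed={per_seed}")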