from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, Optional, Tuple

from src.predictor import SingleTaskEnsemblePredictor
from src.predictor_multitask import MultiTaskEnsemblePredictor


class RouterPredictor:
    """
    Routes each property to either:
      - single-task ensemble (models/single_models)
      - multitask ensemble (models/multitask_models/{task}_*)
    based on models/best_model_map.json

    The map is a JSON object keyed by property name. Each entry may carry a
    ``"family"`` field (``"multitask"`` selects the multitask ensemble, any
    other value selects the single-task one) and, for multitask entries, a
    ``"task"`` field (defaults to ``"all"``).
    """

    def __init__(
        self,
        map_path: str = "models/best_model_map.json",
        single_dir: str = "models/single_models",
        multitask_dir: str = "models/multitask_models",
        device: str = "cpu",
    ):
        """Load the routing map and build both underlying predictors.

        Args:
            map_path: Path to the JSON routing map.
            single_dir: Directory handed to ``SingleTaskEnsemblePredictor``.
            multitask_dir: Directory handed to ``MultiTaskEnsemblePredictor``.
            device: Device string forwarded to both predictors.

        Raises:
            OSError: If the map file cannot be opened.
            json.JSONDecodeError: If the map file is not valid JSON.
        """
        self.map_path = Path(map_path)
        # Use a context manager so the file handle is closed deterministically,
        # and read as UTF-8 (the JSON interchange encoding) rather than the
        # platform's locale default.
        with self.map_path.open(encoding="utf-8") as fh:
            self.map: Dict[str, dict] = json.load(fh)

        self.single = SingleTaskEnsemblePredictor(models_dir=single_dir, device=device)
        self.multi = MultiTaskEnsemblePredictor(models_dir=multitask_dir, device=device)

    def predict_mean_std(
        self, smiles: str, prop: str
    ) -> Tuple[Optional[float], Optional[float], dict, str]:
        """Predict ``prop`` for ``smiles`` with whichever ensemble the map selects.

        Args:
            smiles: Molecule string, passed through unchanged to the predictor.
            prop: Property name; lower-cased before the map lookup and before
                being forwarded to the predictor.

        Returns:
            ``(mean, std, per_seed, label)`` where the first three values are
            whatever the selected predictor's ``predict_mean_std`` returns and
            ``label`` is ``"multitask:<task>"`` or ``"single"``.
        """
        prop = prop.lower()
        # Properties absent from the map fall back to the single-task family.
        cfg = self.map.get(prop, {"family": "single"})
        # str() mirrors the handling of "task" below, so a non-string value
        # (e.g. JSON null) routes to the default branch instead of raising
        # AttributeError on .lower().
        fam = str(cfg.get("family", "single")).lower()

        if fam == "multitask":
            task = str(cfg.get("task", "all")).lower()
            mean, std, per_seed = self.multi.predict_mean_std(
                smiles, prop_key=prop, task=task
            )
            label = f"multitask:{task}"
            return mean, std, per_seed, label

        # default: single
        mean, std, per_seed = self.single.predict_mean_std(smiles, prop)
        label = "single"
        return mean, std, per_seed, label