POLYMER-PROPERTY / src /predictor_router.py
sobinalosious92's picture
Upload 297 files
930ea3d verified
from __future__ import annotations
import json
from pathlib import Path
from typing import Dict, Optional, Tuple
from src.predictor import SingleTaskEnsemblePredictor
from src.predictor_multitask import MultiTaskEnsemblePredictor
class RouterPredictor:
    """
    Route each property to either a single-task or a multitask ensemble.

    Routing is driven by a JSON map (default ``models/best_model_map.json``)
    whose values are dicts with an optional ``"family"`` key and, for
    multitask entries, an optional ``"task"`` key. Properties are served by:
    - the single-task ensemble (``models/single_models``) when they are
      missing from the map or their family is anything other than
      ``"multitask"``;
    - the multitask ensemble (``models/multitask_models/{task}_*``) when
      their family is ``"multitask"``.
    """

    def __init__(
        self,
        map_path: str = "models/best_model_map.json",
        single_dir: str = "models/single_models",
        multitask_dir: str = "models/multitask_models",
        device: str = "cpu",
    ):
        """
        Load the routing map and construct both ensemble predictors.

        Args:
            map_path: Path to the JSON routing map.
            single_dir: ``models_dir`` forwarded to ``SingleTaskEnsemblePredictor``.
            multitask_dir: ``models_dir`` forwarded to ``MultiTaskEnsemblePredictor``.
            device: Device string forwarded to both predictors.

        Raises:
            OSError: If ``map_path`` cannot be opened.
            ValueError: If the map is not valid JSON or its top-level value
                is not a JSON object (``json.JSONDecodeError`` is a
                ``ValueError`` subclass).
        """
        self.map_path = Path(map_path)
        self.map: Dict[str, dict] = self._load_map(self.map_path)
        self.single = SingleTaskEnsemblePredictor(models_dir=single_dir, device=device)
        self.multi = MultiTaskEnsemblePredictor(models_dir=multitask_dir, device=device)

    @staticmethod
    def _load_map(path: Path) -> Dict[str, dict]:
        """Read and validate the routing map stored as JSON at *path*."""
        # The context manager closes the handle promptly. Explicit UTF-8 keeps
        # decoding independent of the process locale.
        with open(path, encoding="utf-8") as fh:
            data = json.load(fh)
        if not isinstance(data, dict):
            raise ValueError(
                f"Routing map {path} must be a JSON object, got {type(data).__name__}"
            )
        return data

    def _resolve_route(self, prop: str) -> Tuple[str, str]:
        """
        Return ``(family, task)`` for *prop*, both lower-cased.

        Missing entries, missing keys and JSON ``null`` values fall back to
        family ``"single"`` and task ``"all"``.
        """
        prop = prop.lower()
        cfg = self.map.get(prop)
        if cfg is None:
            # Property names are matched case-insensitively. Without this, a
            # map key such as "Tg" would never match the lower-cased name and
            # would silently fall through to the single-task route.
            cfg = next(
                (val for key, val in self.map.items() if str(key).lower() == prop),
                None,
            )
        if cfg is None:
            cfg = {}
        # `or` also covers explicit nulls / empty strings in the JSON. The
        # previous form crashed on a null family and produced task "none".
        family = str(cfg.get("family") or "single").strip().lower()
        task = str(cfg.get("task") or "all").strip().lower()
        return family, task

    def predict_mean_std(self, smiles: str, prop: str) -> Tuple[Optional[float], Optional[float], dict, str]:
        """
        Predict *prop* for *smiles* using the ensemble selected by the map.

        Args:
            smiles: Molecule string, passed unchanged to the chosen ensemble.
            prop: Property name, matched case-insensitively against the map
                and forwarded lower-cased.

        Returns:
            ``(mean, std, per_seed, label)``. The first three values come
            directly from the chosen ensemble's ``predict_mean_std``.
            ``label`` is ``"single"`` or ``"multitask:<task>"`` and records
            which route was taken.
        """
        prop = prop.lower()
        family, task = self._resolve_route(prop)
        if family == "multitask":
            mean, std, per_seed = self.multi.predict_mean_std(smiles, prop_key=prop, task=task)
            return mean, std, per_seed, f"multitask:{task}"
        # Any other family value, including unknown ones, uses single-task.
        mean, std, per_seed = self.single.predict_mean_std(smiles, prop)
        return mean, std, per_seed, "single"