File size: 1,655 Bytes
930ea3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, Optional, Tuple

from src.predictor import SingleTaskEnsemblePredictor
from src.predictor_multitask import MultiTaskEnsemblePredictor


class RouterPredictor:
    """
    Routes each property to either:
      - single-task ensemble (models/single_models)
      - multitask ensemble (models/multitask_models/{task}_*)
    based on models/best_model_map.json

    The routing map is a JSON object keyed by property name. Each value is a
    dict that may contain:
      - "family": "multitask" selects the multitask ensemble; any other value
        (or a missing key) selects the single-task ensemble.
      - "task": multitask task name passed to the multitask ensemble
        (defaults to "all").
    Properties absent from the map are routed to the single-task ensemble.
    """

    def __init__(
        self,
        map_path: str = "models/best_model_map.json",
        single_dir: str = "models/single_models",
        multitask_dir: str = "models/multitask_models",
        device: str = "cpu",
    ):
        """
        Load the routing map and construct both underlying ensembles.

        Args:
            map_path: path to the JSON routing map.
            single_dir: models directory for the single-task ensemble.
            multitask_dir: models directory for the multitask ensemble.
            device: device string forwarded to both ensembles.
        """
        self.map_path = Path(map_path)
        self.map: Dict[str, dict] = self._load_map(self.map_path)
        self.single = SingleTaskEnsemblePredictor(models_dir=single_dir, device=device)
        self.multi = MultiTaskEnsemblePredictor(models_dir=multitask_dir, device=device)

    @staticmethod
    def _load_map(path) -> Dict[str, dict]:
        """Read the routing-map JSON file, closing the handle when done."""
        with open(path, encoding="utf-8") as fh:
            return json.load(fh)

    def _lookup_cfg(self, prop: str) -> dict:
        """
        Return the routing config for an already-lowercased ``prop``.

        Tries an exact key match first, then a case-insensitive match so map
        keys such as "Tg" still route correctly. Missing or null entries fall
        back to ``{"family": "single"}``.
        """
        cfg = self.map.get(prop)
        if cfg is None:
            for key, value in self.map.items():
                if str(key).lower() == prop:
                    cfg = value
                    break
        return cfg if cfg is not None else {"family": "single"}

    def predict_mean_std(self, smiles: str, prop: str) -> Tuple[Optional[float], Optional[float], dict, str]:
        """
        Predict a property for ``smiles`` using the ensemble chosen by the map.

        Args:
            smiles: molecule SMILES string, forwarded unchanged.
            prop: property name; matched case-insensitively and forwarded
                to the chosen ensemble in lowercase.

        Returns:
            ``(mean, std, per_seed, label)``. The first three values are
            whatever the chosen ensemble's ``predict_mean_std`` returns.
            ``label`` is ``"single"`` or ``"multitask:<task>"``.
        """
        prop = prop.lower()
        cfg = self._lookup_cfg(prop)

        # str() guards against non-string JSON values (e.g. null) for "family".
        family = str(cfg.get("family", "single")).lower()
        if family == "multitask":
            task = str(cfg.get("task", "all")).lower()
            mean, std, per_seed = self.multi.predict_mean_std(smiles, prop_key=prop, task=task)
            return mean, std, per_seed, f"multitask:{task}"

        # Default: anything that is not explicitly "multitask" goes to single.
        mean, std, per_seed = self.single.predict_mean_std(smiles, prop)
        return mean, std, per_seed, "single"