File size: 4,252 Bytes
94b1553
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from __future__ import annotations

import json
from functools import lru_cache
from pathlib import Path
from typing import Dict, List

import joblib
import numpy as np

from src.constants import TARGET_NAMES
from src.features import FingerprintFeaturizer
from src.seed import set_seed

BASE_PREDICTION = 0.5


@lru_cache(maxsize=1)
def _load_manifest() -> Dict:
    manifest_path = Path("./checkpoints/training_manifest.json")
    if not manifest_path.exists():
        raise FileNotFoundError("Missing checkpoints/training_manifest.json. Run train.py first.")
    with manifest_path.open("r", encoding="utf-8") as f:
        manifest = json.load(f)
    return manifest


@lru_cache(maxsize=2)
def _load_stage_models(stage: str):
    """Load and cache every per-target model saved for one stacking stage.

    Args:
        stage: Manifest key identifying the stage (e.g. ``"stage1"``).

    Returns:
        dict: Mapping of target name -> deserialized model. Empty when the
        manifest records no model directory for this stage; targets whose
        pickle file is absent are silently skipped.
    """
    manifest = _load_manifest()
    directory = manifest.get(stage, {}).get("model_dir")
    if not directory:
        return {}
    root = Path(directory)
    return {
        name: joblib.load(root / f"{name}.pkl")
        for name in manifest.get("target_names", TARGET_NAMES)
        if (root / f"{name}.pkl").exists()
    }


def _compute_stage1_predictions(features: np.ndarray, target_names: List[str]) -> np.ndarray:
    """Return predictions for the valid molecules from stage-1 models."""
    stage1_models = _load_stage_models("stage1")
    if features.shape[0] == 0:
        return np.zeros((0, len(target_names)), dtype=np.float32)

    predictions = np.full((features.shape[0], len(target_names)), BASE_PREDICTION, dtype=np.float32)
    for idx, target in enumerate(target_names):
        booster = stage1_models.get(target)
        if booster is None:
            continue
        best_iter = getattr(booster, "best_iteration_", None)
        kwargs = {"num_iteration": best_iter} if best_iter is not None else {}
        preds = booster.predict_proba(features, **kwargs)[:, 1]
        predictions[:, idx] = preds
    return predictions


def _compute_stage2_predictions(
    base_features: np.ndarray,
    stage1_preds: np.ndarray,
    target_names: List[str],
) -> np.ndarray:
    stage2_models = _load_stage_models("stage2")
    if not stage2_models:
        return stage1_preds

    n_samples = base_features.shape[0]
    results = np.full((n_samples, len(target_names)), BASE_PREDICTION, dtype=np.float32)
    for idx, target in enumerate(target_names):
        model = stage2_models.get(target)
        if model is None:
            results[:, idx] = stage1_preds[:, idx]
            continue
        augmented = np.concatenate(
            [
                base_features,
                np.delete(stage1_preds, idx, axis=1),
            ],
            axis=1,
        )
        best_iter = getattr(model, "best_iteration_", None)
        kwargs = {"num_iteration": best_iter} if best_iter is not None else {}
        preds = model.predict_proba(augmented, **kwargs)[:, 1]
        results[:, idx] = preds
    return results


def predict(smiles_list: List[str]) -> Dict[str, Dict[str, float]]:
    """
    Predict toxicity targets for a list of SMILES strings.

    Invalid SMILES (per the featurizer's validity mask) receive
    BASE_PREDICTION for every target instead of a model score.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}
    """
    set_seed(0)
    manifest = _load_manifest()
    target_names = manifest.get("target_names", TARGET_NAMES)
    feature_config = manifest.get("feature_config", {"type": "ecfp"})

    featurizer = FingerprintFeaturizer(feature_config)
    batch, features = featurizer.featurize_smiles(smiles_list)

    stage1 = _compute_stage1_predictions(features, target_names)
    stage2 = _compute_stage2_predictions(features, stage1, target_names)

    output: Dict[str, Dict[str, float]] = {}
    row = 0
    # batch.mask is aligned with smiles_list; stage2 rows cover only the
    # valid molecules, so `row` advances on valid entries alone.
    for smiles, ok in zip(smiles_list, batch.mask):
        if not ok:
            output[smiles] = {target: BASE_PREDICTION for target in target_names}
            continue
        scores = stage2[row] if stage2.size else np.full(len(target_names), BASE_PREDICTION)
        output[smiles] = dict(zip(target_names, map(float, scores)))
        row += 1
    return output