Spaces:
Sleeping
Sleeping
File size: 4,252 Bytes
94b1553 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
from __future__ import annotations
import json
from functools import lru_cache
from pathlib import Path
from typing import Dict, List
import joblib
import numpy as np
from src.constants import TARGET_NAMES
from src.features import FingerprintFeaturizer
from src.seed import set_seed
# Fallback probability emitted for invalid SMILES and for any target whose
# stage-1 model is missing (and as the initial fill of the prediction arrays).
BASE_PREDICTION: float = 0.5
@lru_cache(maxsize=1)
def _load_manifest() -> Dict:
    """Read the training manifest produced by ``train.py`` and memoize it.

    The manifest path is relative, so it is resolved against the current
    working directory. Because of ``lru_cache`` the very same dict object is
    handed to every caller after the first successful load; treat it as
    read-only. A failed load (missing file) is not cached.

    Returns:
        The parsed JSON content of ``checkpoints/training_manifest.json``.

    Raises:
        FileNotFoundError: If the manifest file does not exist.
    """
    location = Path("./checkpoints/training_manifest.json")
    if location.exists():
        with location.open("r", encoding="utf-8") as handle:
            return json.load(handle)
    raise FileNotFoundError("Missing checkpoints/training_manifest.json. Run train.py first.")
@lru_cache(maxsize=2)
def _load_stage_models(stage: str):
    """Load the per-target model objects for one training stage.

    The stage's directory is taken from ``manifest[stage]["model_dir"]``. One
    ``<target>.pkl`` file is expected per target, where the target list comes
    from ``manifest["target_names"]`` (falling back to ``TARGET_NAMES``).
    Targets whose file is absent are silently skipped. Results are memoized
    per ``stage`` (up to two stages).

    Args:
        stage: Manifest key of the stage, e.g. ``"stage1"`` or ``"stage2"``.

    Returns:
        Dict mapping target name to the loaded model. It is empty when the
        stage has no ``model_dir`` entry.
    """
    manifest = _load_manifest()
    directory = manifest.get(stage, {}).get("model_dir")
    if not directory:
        return {}

    root = Path(directory)
    loaded = {}
    for target in manifest.get("target_names", TARGET_NAMES):
        candidate = root / f"{target}.pkl"
        if not candidate.exists():
            continue
        # NOTE(review): joblib.load unpickles arbitrary objects; this assumes
        # the checkpoint directory is trusted (written by our own train.py).
        loaded[target] = joblib.load(candidate)
    return loaded
def _compute_stage1_predictions(features: np.ndarray, target_names: List[str]) -> np.ndarray:
    """Return predictions for the valid molecules from stage-1 models.

    Args:
        features: Feature matrix with one row per valid molecule.
        target_names: Target order; output column ``i`` holds ``target_names[i]``.

    Returns:
        float32 array of shape ``(features.shape[0], len(target_names))``.
        Targets without a stage-1 model keep ``BASE_PREDICTION``. With zero
        input rows an empty ``(0, len(target_names))`` array is returned.
    """
    models_by_target = _load_stage_models("stage1")
    row_count = features.shape[0]
    column_count = len(target_names)
    if row_count == 0:
        return np.zeros((0, column_count), dtype=np.float32)

    scores = np.full((row_count, column_count), BASE_PREDICTION, dtype=np.float32)
    for column, target in enumerate(target_names):
        estimator = models_by_target.get(target)
        if estimator is not None:
            # Use the early-stopping iteration when the model exposes one.
            best = getattr(estimator, "best_iteration_", None)
            extra = {} if best is None else {"num_iteration": best}
            # Column 1 of predict_proba, presumably the positive class.
            scores[:, column] = estimator.predict_proba(features, **extra)[:, 1]
    return scores
def _compute_stage2_predictions(
    base_features: np.ndarray,
    stage1_preds: np.ndarray,
    target_names: List[str],
) -> np.ndarray:
    """Refine stage-1 probabilities with stacked stage-2 models.

    For each target, the stage-2 model receives ``base_features`` concatenated
    with the stage-1 predictions for every *other* target (the target's own
    column is dropped). Targets without a stage-2 model keep their stage-1
    prediction.

    Args:
        base_features: Feature matrix with one row per valid molecule.
        stage1_preds: Stage-1 predictions; column ``i`` corresponds to
            ``target_names[i]`` (assumed shape ``(n_samples, len(target_names))``).
        target_names: Target order shared by both arrays.

    Returns:
        float32 array of shape ``(n_samples, len(target_names))``. When no
        stage-2 models are available, ``stage1_preds`` itself is returned
        unchanged.
    """
    stage2_models = _load_stage_models("stage2")
    if not stage2_models:
        return stage1_preds

    n_samples = base_features.shape[0]
    if n_samples == 0:
        # Every input was invalid. sklearn-compatible predict_proba
        # implementations (e.g. LightGBM's, which validates via check_array)
        # reject 0-sample input, so skip the models entirely, mirroring the
        # empty-input guard in stage 1.
        return np.zeros((0, len(target_names)), dtype=np.float32)

    results = np.full((n_samples, len(target_names)), BASE_PREDICTION, dtype=np.float32)
    for idx, target in enumerate(target_names):
        model = stage2_models.get(target)
        if model is None:
            results[:, idx] = stage1_preds[:, idx]
            continue
        # Stacked input: raw features + stage-1 scores of the other targets.
        augmented = np.concatenate(
            [
                base_features,
                np.delete(stage1_preds, idx, axis=1),
            ],
            axis=1,
        )
        # Use the early-stopping iteration when the model exposes one.
        best_iter = getattr(model, "best_iteration_", None)
        kwargs = {"num_iteration": best_iter} if best_iter is not None else {}
        results[:, idx] = model.predict_proba(augmented, **kwargs)[:, 1]
    return results
def predict(smiles_list: List[str]) -> Dict[str, Dict[str, float]]:
    """
    Predict toxicity targets for a list of SMILES strings.

    Resets the random seed to 0 on every call, featurizes the inputs with the
    manifest's feature config, and runs stage-1 then stage-2 models. SMILES
    rejected by the featurizer (``batch.mask`` false) get ``BASE_PREDICTION``
    for every target. Duplicate SMILES share one dict key, so a later
    occurrence overwrites an earlier one.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}
    """
    set_seed(0)
    manifest = _load_manifest()
    targets = manifest.get("target_names", TARGET_NAMES)
    featurizer = FingerprintFeaturizer(manifest.get("feature_config", {"type": "ecfp"}))
    batch, features = featurizer.featurize_smiles(smiles_list)
    refined = _compute_stage2_predictions(
        features,
        _compute_stage1_predictions(features, targets),
        targets,
    )

    output: Dict[str, Dict[str, float]] = {}
    # Prediction rows exist only for valid molecules, so walk them with a
    # separate cursor that advances only on valid entries.
    cursor = 0
    for smiles, valid in zip(smiles_list, batch.mask):
        if valid:
            if refined.size:
                scores = refined[cursor]
            else:
                scores = np.full(len(targets), BASE_PREDICTION)
            output[smiles] = dict(zip(targets, map(float, scores)))
            cursor += 1
        else:
            output[smiles] = dict.fromkeys(targets, BASE_PREDICTION)
    return output
|