from __future__ import annotations import json from functools import lru_cache from pathlib import Path from typing import Dict, List import joblib import numpy as np from src.constants import TARGET_NAMES from src.features import FingerprintFeaturizer from src.seed import set_seed BASE_PREDICTION = 0.5 @lru_cache(maxsize=1) def _load_manifest() -> Dict: manifest_path = Path("./checkpoints/training_manifest.json") if not manifest_path.exists(): raise FileNotFoundError("Missing checkpoints/training_manifest.json. Run train.py first.") with manifest_path.open("r", encoding="utf-8") as f: manifest = json.load(f) return manifest @lru_cache(maxsize=2) def _load_stage_models(stage: str): manifest = _load_manifest() stage_info = manifest.get(stage, {}) model_dir = stage_info.get("model_dir") if not model_dir: return {} model_path = Path(model_dir) models = {} for target in manifest.get("target_names", TARGET_NAMES): model_file = model_path / f"{target}.pkl" if model_file.exists(): models[target] = joblib.load(model_file) return models def _compute_stage1_predictions(features: np.ndarray, target_names: List[str]) -> np.ndarray: """Return predictions for the valid molecules from stage-1 models.""" stage1_models = _load_stage_models("stage1") if features.shape[0] == 0: return np.zeros((0, len(target_names)), dtype=np.float32) predictions = np.full((features.shape[0], len(target_names)), BASE_PREDICTION, dtype=np.float32) for idx, target in enumerate(target_names): booster = stage1_models.get(target) if booster is None: continue best_iter = getattr(booster, "best_iteration_", None) kwargs = {"num_iteration": best_iter} if best_iter is not None else {} preds = booster.predict_proba(features, **kwargs)[:, 1] predictions[:, idx] = preds return predictions def _compute_stage2_predictions( base_features: np.ndarray, stage1_preds: np.ndarray, target_names: List[str], ) -> np.ndarray: stage2_models = _load_stage_models("stage2") if not stage2_models: return stage1_preds n_samples = base_features.shape[0] results = np.full((n_samples, len(target_names)), BASE_PREDICTION, dtype=np.float32) for idx, target in enumerate(target_names): model = stage2_models.get(target) if model is None: results[:, idx] = stage1_preds[:, idx] continue augmented = np.concatenate( [ base_features, np.delete(stage1_preds, idx, axis=1), ], axis=1, ) best_iter = getattr(model, "best_iteration_", None) kwargs = {"num_iteration": best_iter} if best_iter is not None else {} preds = model.predict_proba(augmented, **kwargs)[:, 1] results[:, idx] = preds return results def predict(smiles_list: List[str]) -> Dict[str, Dict[str, float]]: """ Predict toxicity targets for a list of SMILES strings. Args: smiles_list (list[str]): SMILES strings Returns: dict: {smiles: {target_name: prediction_prob}} """ set_seed(0) manifest = _load_manifest() target_names = manifest.get("target_names", TARGET_NAMES) feature_config = manifest.get("feature_config", {"type": "ecfp"}) featurizer = FingerprintFeaturizer(feature_config) batch, features = featurizer.featurize_smiles(smiles_list) stage1_preds = _compute_stage1_predictions(features, target_names) stage2_preds = _compute_stage2_predictions(features, stage1_preds, target_names) predictions: Dict[str, Dict[str, float]] = {} valid_idx = 0 for original_smiles, is_valid in zip(smiles_list, batch.mask): if not is_valid: predictions[original_smiles] = {target: BASE_PREDICTION for target in target_names} continue row_preds = stage2_preds[valid_idx] if stage2_preds.size else np.full(len(target_names), BASE_PREDICTION) predictions[original_smiles] = {target: float(score) for target, score in zip(target_names, row_preds)} valid_idx += 1 return predictions