Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Validation for the advanced pharmacology-aware DDI model.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| from preprocessing.artifact_manager import manager | |
| import torch | |
| from sklearn.model_selection import train_test_split | |
# Directory two levels above the folder containing this file.
ROOT = Path(__file__).resolve().parents[2]
# Put ROOT/src at the front of sys.path (once) before the project imports
# below run.
# NOTE(review): presumably the ``training`` package lives under ROOT/src —
# confirm against the repository layout.
if str(ROOT / 'src') not in sys.path:
    sys.path.insert(0, str(ROOT / 'src'))
from training.advanced_feature_engineering import AdvancedBiomedicalFeatureEngineer, AdvancedFeatureConfig, load_metadata_map
from training.advanced_ddi_model import AdvancedDDINet, AdvancedModelConfig
from training.healthcare_safe_pipeline import compute_publication_metrics, optimize_severe_threshold, save_publication_outputs, set_deterministic_seed
# Configure root logging at INFO level; each record carries a timestamp and
# the level name.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
# BASE_DIR is computed with the same expression as ROOT above.
BASE_DIR = Path(__file__).resolve().parents[2]
MODELS_DIR = BASE_DIR / "models"
REPORTS_DIR = MODELS_DIR / "reports"
# Checkpoint file that main() loads; main() raises if it does not exist.
CHECKPOINT_PATH = MODELS_DIR / "advanced_ddi_safe.pt"
# Created at import time (including missing parents) so the report writes in
# main() have a directory to write into.
REPORTS_DIR.mkdir(parents=True, exist_ok=True)
def main() -> None:
    """Evaluate the saved advanced DDI checkpoint on a held-out 20% split.

    Command-line flags:
        --seed            Seed passed to ``set_deterministic_seed``, to the
                          optional row sampling, and to the train/test split.
        --metadata-path   Optional path handed to ``load_metadata_map``; when
                          empty, an empty metadata dict is used.
        --sample-limit    When positive and smaller than the dataset, evaluate
                          on a random sample of that many rows.

    The function builds pair features for every row, takes a stratified 20%
    test split, runs the checkpointed model on it, computes metrics, hands
    them to ``save_publication_outputs`` and writes
    ``advanced_ddi_validation_summary.json`` into ``REPORTS_DIR``.

    The decision threshold stored in the checkpoint is used when present.
    ``optimize_severe_threshold`` is still run on the held-out data and its
    result is reported under ``threshold_info``, but it only drives the
    metrics when the checkpoint carries no threshold.

    Raises:
        FileNotFoundError: If ``CHECKPOINT_PATH`` does not exist.
        ValueError: If the loaded dataset has no rows.
    """
    parser = argparse.ArgumentParser(description="Validate the advanced DDI model")
    parser.add_argument("--seed", type=int, default=2026)
    parser.add_argument("--metadata-path", type=str, default="")
    parser.add_argument("--sample-limit", type=int, default=0)
    args = parser.parse_args()

    set_deterministic_seed(args.seed)
    if not CHECKPOINT_PATH.exists():
        raise FileNotFoundError(f"Missing checkpoint: {CHECKPOINT_PATH}")

    metadata = load_metadata_map(args.metadata_path) if args.metadata_path else {}
    engineer = AdvancedBiomedicalFeatureEngineer(AdvancedFeatureConfig(), metadata=metadata)

    df = manager.load_artifact('ddinter_combined')
    if 0 < args.sample_limit < len(df):
        df = df.sample(n=args.sample_limit, random_state=args.seed).reset_index(drop=True)
    if len(df) == 0:
        raise ValueError("Dataset 'ddinter_combined' is empty; nothing to validate")

    # Severity label -> class id; anything unrecognised is treated as class 0.
    label_ids = {"unknown": 0, "minor": 1, "moderate": 2, "major": 3}
    features = []
    labels = []
    # NOTE(review): features are built for every row although only the test
    # split is used below. Restricting this to the test rows would be cheaper,
    # but is only safe if pair_features is stateless — confirm before changing.
    for drug_a, drug_b, level in zip(df["Drug_A"], df["Drug_B"], df["Level"]):
        features.append(engineer.pair_features(drug_a, drug_b))
        labels.append(label_ids.get(str(level).strip().lower(), 0))
    labels = np.asarray(labels, dtype=np.int64)

    # Only the held-out 20% is evaluated; the split is stratified by label.
    _, test_idx = train_test_split(
        np.arange(len(features)),
        test_size=0.2,
        random_state=args.seed,
        stratify=labels,
    )
    test_y = labels[test_idx]
    test_feats = {
        key: np.vstack([features[i][key] for i in test_idx]).astype(np.float32)
        for key in features[0].keys()
    }

    # NOTE: depending on the torch version / ``weights_only`` setting,
    # torch.load may unpickle arbitrary objects — only load trusted checkpoints.
    payload = torch.load(CHECKPOINT_PATH, map_location="cpu")
    config = AdvancedModelConfig(**payload["config"])
    model = AdvancedDDINet(config)
    model.load_state_dict(payload["model_state_dict"])
    model.eval()

    # Positional order of the feature groups expected by the model's forward.
    input_order = ("fingerprint", "semantic", "pharmacology", "pairwise", "molecular_pair")
    with torch.no_grad():
        logits, _ = model(*(torch.from_numpy(test_feats[name]) for name in input_order))
        probs = torch.softmax(logits, dim=1).numpy()

    # Always report what threshold the held-out data would pick ...
    threshold_info = optimize_severe_threshold(probs, test_y, precision_floor=0.25)
    # ... but score with the checkpoint's own operating point when it has one.
    # Previously the stored value was read and then unconditionally overwritten
    # by the threshold tuned on these same labels, which biases the metrics.
    stored_threshold = payload.get("threshold")
    if stored_threshold is not None:
        threshold = float(stored_threshold)
        threshold_source = "checkpoint"
    else:
        threshold = threshold_info["threshold"]
        threshold_source = "held_out_optimized"
        logger.warning(
            "Checkpoint has no stored threshold; using %s tuned on the held-out "
            "split, so reported metrics are optimistic.",
            threshold,
        )

    metrics = compute_publication_metrics(test_y, probs, threshold)
    save_publication_outputs(metrics, REPORTS_DIR, prefix="advanced_ddi_validation")

    summary = {
        "checkpoint": str(CHECKPOINT_PATH),
        "threshold": threshold,
        "threshold_source": threshold_source,
        "threshold_info": threshold_info,
        "metrics": metrics,
    }
    (REPORTS_DIR / "advanced_ddi_validation_summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
    logger.info("Validation complete. Severe recall: %.4f", metrics["per_class"]["major"]["recall"])
# Run the validation only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()