#!/usr/bin/env python3
"""Validation for the advanced pharmacology-aware DDI model.

Rebuilds a stratified held-out split from the ``ddinter_combined`` artifact,
runs the ``advanced_ddi_safe.pt`` checkpoint on it, and writes publication
metrics plus a JSON summary to ``models/reports``.
"""
from __future__ import annotations

import argparse
import json
import logging
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split

ROOT = Path(__file__).resolve().parents[2]
# The ``src`` directory must be on sys.path before ANY project import so that
# ``preprocessing`` resolves the same way as the ``training`` packages
# (presumably both live under src/ — confirm against the repo layout).
if str(ROOT / "src") not in sys.path:
    sys.path.insert(0, str(ROOT / "src"))

from preprocessing.artifact_manager import manager  # noqa: E402
from training.advanced_feature_engineering import (  # noqa: E402
    AdvancedBiomedicalFeatureEngineer,
    AdvancedFeatureConfig,
    load_metadata_map,
)
from training.advanced_ddi_model import AdvancedDDINet, AdvancedModelConfig  # noqa: E402
from training.healthcare_safe_pipeline import (  # noqa: E402
    compute_publication_metrics,
    optimize_severe_threshold,
    save_publication_outputs,
    set_deterministic_seed,
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

BASE_DIR = ROOT  # identical to the original ``parents[2]`` computation
MODELS_DIR = BASE_DIR / "models"
REPORTS_DIR = MODELS_DIR / "reports"
CHECKPOINT_PATH = MODELS_DIR / "advanced_ddi_safe.pt"
REPORTS_DIR.mkdir(parents=True, exist_ok=True)

# Ordinal encoding of the ``Level`` column; unrecognised values fall back to 0.
SEVERITY_LABELS = {"unknown": 0, "minor": 1, "moderate": 2, "major": 3}
# Feature groups, in the positional order AdvancedDDINet is called with below.
MODEL_INPUT_KEYS = ("fingerprint", "semantic", "pharmacology", "pairwise", "molecular_pair")
# Held-out fraction; presumably mirrors the training split (TODO confirm).
TEST_SIZE = 0.2
# Used when the checkpoint payload carries no "threshold" entry.
DEFAULT_THRESHOLD = 0.5


def _parse_args() -> argparse.Namespace:
    """Parse command-line options for the validation run."""
    parser = argparse.ArgumentParser(description="Validate the advanced DDI model")
    parser.add_argument("--seed", type=int, default=2026)
    parser.add_argument("--metadata-path", type=str, default="")
    parser.add_argument("--sample-limit", type=int, default=0)
    parser.add_argument(
        "--precision-floor",
        type=float,
        default=0.25,
        help="Precision floor passed to optimize_severe_threshold (diagnostic only by default).",
    )
    parser.add_argument(
        "--tune-threshold-on-test",
        action="store_true",
        help=(
            "Report metrics at the threshold optimised on the test split itself. "
            "This leaks test labels into the decision rule and inflates metrics; "
            "by default the threshold stored in the checkpoint is used."
        ),
    )
    return parser.parse_args()


def _encode_level(level: object) -> int:
    """Map a raw ``Level`` cell to its ordinal severity label (0 if unknown)."""
    return SEVERITY_LABELS.get(str(level).strip().lower(), 0)


def _build_dataset(
    df: pd.DataFrame, engineer: AdvancedBiomedicalFeatureEngineer
) -> tuple[list[dict[str, np.ndarray]], list[int]]:
    """Return per-pair feature dicts and severity labels for every row of *df*.

    Rows are processed in DataFrame order, so ``pair_features`` is called in
    the same sequence as a row-by-row loop would.
    """
    features = [engineer.pair_features(drug_a, drug_b) for drug_a, drug_b in zip(df["Drug_A"], df["Drug_B"])]
    labels = [_encode_level(level) for level in df["Level"]]
    return features, labels


def _select_test_split(
    features: list[dict[str, np.ndarray]], labels: list[int], seed: int
) -> tuple[dict[str, np.ndarray], np.ndarray]:
    """Pick the stratified held-out rows and stack them into model-ready arrays.

    Returns ``(test_feats, test_y)``. ``test_feats`` maps each feature key to a
    float32 array built with ``np.vstack`` over the selected rows.
    """
    indices = np.arange(len(features))
    _, test_idx = train_test_split(indices, test_size=TEST_SIZE, random_state=seed, stratify=np.asarray(labels))
    test_y = np.asarray([labels[i] for i in test_idx], dtype=np.int64)
    test_feats = {
        key: np.vstack([features[i][key] for i in test_idx]).astype(np.float32)
        for key in features[0]
    }
    return test_feats, test_y


def _load_checkpoint(path: Path) -> tuple[AdvancedDDINet, dict]:
    """Load the checkpoint at *path* and return ``(model_in_eval_mode, payload)``."""
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    payload = torch.load(path, map_location="cpu")
    model = AdvancedDDINet(AdvancedModelConfig(**payload["config"]))
    model.load_state_dict(payload["model_state_dict"])
    model.eval()
    return model, payload


def _predict_probabilities(model: AdvancedDDINet, test_feats: dict[str, np.ndarray]) -> np.ndarray:
    """Run a single no-grad forward pass and return softmax class probabilities."""
    inputs = [torch.from_numpy(test_feats[key]) for key in MODEL_INPUT_KEYS]
    with torch.no_grad():
        logits, _ = model(*inputs)
    return torch.softmax(logits, dim=1).numpy()


def _json_default(value: object) -> object:
    """Convert numpy/torch values that ``json`` cannot serialise natively."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, (np.ndarray, torch.Tensor)):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def main() -> None:
    """Validate the saved checkpoint on the held-out split and write reports.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        ValueError: if the loaded dataset is empty.
    """
    args = _parse_args()
    set_deterministic_seed(args.seed)
    if not CHECKPOINT_PATH.exists():
        raise FileNotFoundError(f"Missing checkpoint: {CHECKPOINT_PATH}")

    metadata = load_metadata_map(args.metadata_path) if args.metadata_path else {}
    engineer = AdvancedBiomedicalFeatureEngineer(AdvancedFeatureConfig(), metadata=metadata)

    df = manager.load_artifact('ddinter_combined')
    if df.empty:
        raise ValueError("Artifact 'ddinter_combined' contains no rows to validate on")
    if 0 < args.sample_limit < len(df):
        # Sub-sampling changes which rows land in the test split, so it may no
        # longer match the split the model was trained/held out on.
        logger.warning(
            "Sampling %d of %d rows; the test split may overlap training data.",
            args.sample_limit,
            len(df),
        )
        df = df.sample(n=args.sample_limit, random_state=args.seed).reset_index(drop=True)

    features, labels = _build_dataset(df, engineer)
    test_feats, test_y = _select_test_split(features, labels, args.seed)

    model, payload = _load_checkpoint(CHECKPOINT_PATH)
    probs = _predict_probabilities(model, test_feats)

    # The decision threshold must come from training/calibration data, not the
    # test labels; the test-set optimum is computed only as a diagnostic.
    checkpoint_threshold = float(payload.get("threshold", DEFAULT_THRESHOLD))
    threshold_info = optimize_severe_threshold(probs, test_y, precision_floor=args.precision_floor)
    if args.tune_threshold_on_test:
        threshold = float(threshold_info["threshold"])
        threshold_source = "test_set_optimized"
    else:
        threshold = checkpoint_threshold
        threshold_source = "checkpoint"

    metrics = compute_publication_metrics(test_y, probs, threshold)
    save_publication_outputs(metrics, REPORTS_DIR, prefix="advanced_ddi_validation")

    summary = {
        "checkpoint": str(CHECKPOINT_PATH),
        "threshold": threshold,
        "threshold_source": threshold_source,
        "checkpoint_threshold": checkpoint_threshold,
        "threshold_info": threshold_info,
        "metrics": metrics,
    }
    (REPORTS_DIR / "advanced_ddi_validation_summary.json").write_text(
        json.dumps(summary, indent=2, default=_json_default), encoding="utf-8"
    )
    logger.info("Evaluated at threshold %.4f (source: %s)", threshold, threshold_source)
    logger.info("Validation complete. Severe recall: %.4f", metrics["per_class"]["major"]["recall"])


if __name__ == "__main__":
    main()