Spaces:
Running
Running
| from __future__ import annotations | |
| import logging | |
| import math | |
| import re | |
| from collections import Counter, defaultdict | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from torch import nn | |
| from training.feature_pipeline import transform_pair_features as transform_pair_features_legacy | |
| from training.feature_pipeline_multisource import transform_pair_features as transform_pair_features_multisource | |
| from training.ensemble import EnsemblePredictor | |
| from training.calibration import SafetyCalibrationConfig, risk_adjusted_prediction | |
| from chemistry.smiles_recovery import resolve_drug_name_to_smiles | |
| try: | |
| from chemistry.drug_synonyms import load_drugbank_name_to_smiles_index, lookup_drugbank_smiles, normalize_drug_name | |
| except Exception: # pragma: no cover - optional DrugBank index | |
| load_drugbank_name_to_smiles_index = None # type: ignore | |
| lookup_drugbank_smiles = None # type: ignore | |
| normalize_drug_name = None # type: ignore | |
# Module-wide logger for the hybrid predictor.
logger = logging.getLogger('medcare_ddi.predictor')

# Project root: two levels above the directory that contains this file.
BASE_DIR = Path(__file__).resolve().parents[2]
DATA_PATH = BASE_DIR.joinpath('data', 'processed', 'ddinter_combined.parquet')

# All model artifacts live under <root>/models.
MODEL_DIR = BASE_DIR.joinpath('models')
MODEL_PATH = MODEL_DIR.joinpath('ddi_mlp_best.pt')
PRODUCTION_MODEL_PATH = MODEL_DIR.joinpath('ddi_mlp_production.pt')  # Healthcare-grade production model
FEATURE_PIPELINE_PATH = MODEL_DIR.joinpath('feature_pipeline.pkl')
FEATURE_PIPELINE_MULTISOURCE_PATH = MODEL_DIR.joinpath('feature_pipeline_multisource.pkl')
CALIBRATION_PATH = MODEL_DIR.joinpath('calibration_artifacts.pkl')
PRODUCTION_CALIBRATION_PATH = MODEL_DIR.joinpath('calibration_artifacts_production.pkl')  # Production calibration
ENSEMBLE_DIR = MODEL_DIR.joinpath('ensemble')

# DrugBank accession -> drug name used when a caller passes a DBxxxxx identifier.
DRUGBANK_TO_DDINTER = {
    'DB01048': 'Abacavir',
    'DB01097': 'Leflunomide',
    'DB00331': 'Metformin',
    'DB00682': 'Warfarin',
    'DB00945': 'Acetylsalicylic acid',
    'DB01050': 'Ibuprofen',
    'DB00338': 'Omeprazole',
    'DB01076': 'Atorvastatin',
    'DB00722': 'Lisinopril',
    'DB00381': 'Amlodipine',
    'DB00758': 'Clopidogrel',
    'DB00641': 'Simvastatin',
    'DB00537': 'Ciprofloxacin',
    'DB00196': 'Fluconazole',
    'DB01045': 'Rifampicin',
}

# Severity labels; a label's position in the list is its class index.
LABEL_NAMES = ['unknown', 'minor', 'moderate', 'major']
LABEL_TO_INDEX = dict(zip(LABEL_NAMES, range(len(LABEL_NAMES))))
INDEX_TO_LABEL = dict(enumerate(LABEL_NAMES))

# Advice text returned to callers, keyed by severity label.
SEVERITY_ADVICE = {
    'major': 'Avoid the combination when possible. If there is no alternative, use specialist oversight and close monitoring.',
    'moderate': 'Use with caution. Consider monitoring, dose adjustment, and review of safer alternatives.',
    'minor': 'The combination is generally acceptable with routine clinical monitoring.',
    'unknown': 'The local DDInter table does not provide a clear severity signal for this pair.',
}
def normalize_name(value: str) -> str:
    """Return *value* lower-cased, trimmed, with whitespace runs collapsed to one space."""
    # str.split() with no separator discards leading/trailing whitespace and
    # treats any run of whitespace as a single delimiter.
    tokens = value.lower().split()
    return ' '.join(tokens)
def resolve_input_name(value: str) -> str:
    """Clean a raw drug input and map DrugBank identifiers to drug names.

    Args:
        value: Free-text drug name or a DrugBank accession such as ``DB00682``
            (matched case-insensitively).

    Returns:
        The mapped name from ``DRUGBANK_TO_DDINTER`` for a known accession, the
        upper-cased accession for an unmapped one, or the stripped input
        otherwise.

    Raises:
        ValueError: If *value* is empty, ``None`` or contains only whitespace.
    """
    # Validate after stripping so whitespace-only input is rejected rather than
    # silently resolving to an empty drug name.
    cleaned_value = value.strip() if value else ''
    if not cleaned_value:
        raise ValueError('Drug input is required')
    if re.fullmatch(r'DB\d{5}', cleaned_value, flags=re.IGNORECASE):
        drugbank_id = cleaned_value.upper()
        return DRUGBANK_TO_DDINTER.get(drugbank_id, drugbank_id)
    return cleaned_value
def _load_frontend_alias_map() -> dict[str, str]:
    """Build an alias -> ingredient-name map from the frontend `src/lib/drugAliases.js`.

    The file is scanned with regular expressions (it is not parsed as
    JavaScript): every ``{...}`` group found from the first occurrence of
    ``DRUG_DATABASE`` onwards is treated as one drug entry. For each entry that
    has a ``name``, its ATC code, its marketing names and the name itself are
    registered as normalized keys pointing at that name.

    Returns:
        The alias map. It is empty when the file does not exist, and ``{}`` is
        returned if reading or scanning the file fails.
    """
    aliases: dict[str, str] = {}
    try:
        source_file = BASE_DIR / 'src' / 'lib' / 'drugAliases.js'
        if not source_file.exists():
            return aliases
        contents = source_file.read_text(encoding='utf-8')
        # Find objects inside DRUG_DATABASE array
        start = contents.find('DRUG_DATABASE')
        for body in re.findall(r"\{([^}]+)\}", contents[start:]):
            try:
                name_match = re.search(r"name\s*:\s*[\"']([^\"']+)[\"']", body)
                atc_match = re.search(r"atc\s*:\s*[\"']([^\"']+)[\"']", body)
                brands_match = re.search(r"marketingNames\s*:\s*\[([^\]]+)\]", body, flags=re.S)
                if name_match is None:
                    # Entries without a name cannot be mapped to anything.
                    continue
                ingredient = name_match.group(1).strip()
                if atc_match is not None:
                    aliases[normalize_name(atc_match.group(1))] = ingredient
                if brands_match is not None:
                    for brand in re.findall(r"[\"']([^\"']+)[\"']", brands_match.group(1)):
                        aliases[normalize_name(brand)] = ingredient
                # The ingredient name is registered last, so it wins over an
                # identically spelled ATC code or brand of the same entry.
                aliases[normalize_name(ingredient)] = ingredient
            except Exception:
                # Best effort: a malformed entry is skipped, not fatal.
                continue
    except Exception:
        return {}
    return aliases
def canonical_pair_key(drug_a: str, drug_b: str) -> tuple[str, str]:
    """Return an order-independent key for a drug pair.

    Both names are passed through ``normalize_name`` and the two results are
    sorted, so ``(a, b)`` and ``(b, a)`` produce the same tuple.
    """
    normalized_pair = [normalize_name(drug_a), normalize_name(drug_b)]
    normalized_pair.sort()
    return tuple(normalized_pair)
def severity_rank(value: str) -> int:
    """Return the ordinal rank of a severity label (case-insensitive).

    ``unknown`` is 0, ``minor`` 1, ``moderate`` 2 and ``major`` 3; any other
    label also ranks 0.
    """
    ordering = ('unknown', 'minor', 'moderate', 'major')
    label = value.lower()
    if label in ordering:
        return ordering.index(label)
    return 0
def lookup_confidence(severity: str, support_count: int) -> float:
    """Return the confidence assigned to an evidence lookup.

    A base value is chosen by severity label (0.60 for labels not in the
    table) and a bonus of 0.01 per doubling of *support_count* is added. The
    bonus is capped at 0.06 and the total at 0.99; the result is rounded to
    three decimals.
    """
    base_by_severity = {
        'major': 0.94,
        'moderate': 0.84,
        'minor': 0.72,
        'unknown': 0.58,
    }
    base = base_by_severity[severity] if severity in base_by_severity else 0.60
    # Counts below one are treated as one so the logarithm is defined.
    doublings = math.log2(max(support_count, 1))
    raw_bonus = round(doublings * 0.01, 3)
    support_bonus = min(0.06, raw_bonus)
    capped_total = min(0.99, base + support_bonus)
    return round(capped_total, 3)
def confidence_band(confidence: float) -> str:
    """Bucket a confidence value: below 0.55 is 'low', below 0.75 is 'medium', else 'high'."""
    # Exclusive upper bounds, checked from the lowest band upwards.
    for upper_bound, band in ((0.55, 'low'), (0.75, 'medium')):
        if confidence < upper_bound:
            return band
    return 'high'
def load_ddinter_lookup() -> tuple[dict[tuple[str, str], list[dict[str, str]]], dict[str, Counter]]:
    """Load the ``ddinter_combined`` artifact and index it by drug pair.

    Returns:
        A ``(pair_index, drug_profiles)`` tuple:

        * ``pair_index`` maps a sorted pair of UPPER-CASED drug identifiers to
          the list of rows for that pair; each row is a dict of stringified,
          non-null cell values.
        * ``drug_profiles`` maps each UPPER-CASED drug identifier to a
          ``Counter`` of the ``Major`` / ``Moderate`` / ``Minor`` levels seen in
          its rows.

    Raises:
        FileNotFoundError: If the artifact manager fails to load the table; the
            underlying error is chained as ``__cause__``.
    """
    # Function-scope import: the artifact manager is only needed when the table is loaded.
    from preprocessing.artifact_manager import manager

    try:
        df = manager.load_artifact('ddinter_combined')
    except Exception as exc:
        raise FileNotFoundError(f'Failed to load ddinter_combined artifact: {exc}') from exc

    pair_index: dict[tuple[str, str], list[dict[str, str]]] = defaultdict(list)
    drug_profiles: dict[str, Counter] = defaultdict(Counter)
    counted_levels = ('Major', 'Moderate', 'Minor')

    for _, row in df.iterrows():
        # Keep only populated cells and stringify them so every record is a
        # plain dict[str, str] regardless of the pandas dtypes.
        row_dict = {k: str(v) for k, v in row.to_dict().items() if v is not None and not pd.isna(v)}

        # Prefer the canonical columns; fall back to the raw Drug_A / Drug_B names.
        a_can = str(row_dict.get('canonical_drug_a') or row_dict.get('Drug_A', '')).strip().upper()
        b_can = str(row_dict.get('canonical_drug_b') or row_dict.get('Drug_B', '')).strip().upper()
        if not a_can or not b_can:
            continue

        pair_key = tuple(sorted([a_can, b_can]))
        pair_index[pair_key].append(row_dict)

        level = str(row_dict.get('Level') or row_dict.get('level', 'Unknown')).strip()
        # Only the three clinical levels (exact capitalisation) feed the profiles.
        if level in counted_levels:
            drug_profiles[a_can][level] += 1
            drug_profiles[b_can][level] += 1

    return pair_index, drug_profiles
class DDIEmbeddingMLP(nn.Module):
    """Pair classifier over a shared drug-embedding table.

    Each drug id is embedded with the same ``nn.Embedding`` (index 0 is the
    padding index). The two embeddings, their absolute difference and their
    element-wise product are concatenated and passed through a three-layer MLP
    that outputs one logit per class.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        hidden_dim: int,
        num_classes: int,
        dropout: float = 0.2,
    ) -> None:
        super().__init__()
        half_hidden = hidden_dim // 2
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        # Layer positions (0, 3, 6 for the Linear layers) determine the
        # state_dict keys, so the order must not change.
        layers = [
            nn.Linear(4 * embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, half_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half_hidden, num_classes),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, drug_a_ids: torch.Tensor, drug_b_ids: torch.Tensor) -> torch.Tensor:
        """Return class logits for the given pairs of drug-id tensors."""
        vec_a = self.embedding(drug_a_ids)
        vec_b = self.embedding(drug_b_ids)
        difference = (vec_a - vec_b).abs()
        product = vec_a * vec_b
        pair_features = torch.cat((vec_a, vec_b, difference, product), dim=-1)
        return self.network(pair_features)
class FeatureMLP(nn.Module):
    """Three-layer MLP classifier over a precomputed pair-feature vector.

    The second hidden layer has ``max(8, hidden_dim // 2)`` units.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, dropout: float = 0.2) -> None:
        super().__init__()
        bottleneck = max(8, hidden_dim // 2)
        # Layer positions (0, 3, 6 for the Linear layers) determine the
        # state_dict keys, so the order must not change.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, bottleneck),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(bottleneck, num_classes),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of feature vectors."""
        logits = self.network(features)
        return logits
@dataclass
class InferenceResult:
    """Result of one drug-pair inference, from either the evidence lookup or the model.

    The class is constructed with keyword arguments elsewhere in this module,
    which requires the ``@dataclass``-generated ``__init__``. Instances are
    mutable: ``warning`` and ``additional_notes`` are amended after
    construction by the predictor.
    """

    source: str  # which backend produced the result
    confidence: float
    severity: str
    explanation: str
    warning: str | None
    model_version: str
    mechanism: str
    affected_systems: str
    smiles_a: str
    smiles_b: str
    clinical_advice: str
    additional_notes: str
    drug_a_name: str
    drug_b_name: str
    drug_a_id: str
    drug_b_id: str
    source_dataset: str
    exact_match: bool
    evidence_count: int | None = None
    confidence_band: str | None = None
    probabilities: dict[str, float] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict, keyed by field name in declaration order.

        Values are returned by reference (no deep copy), so the
        ``probabilities`` mapping is shared with the instance.
        """
        # Local import: only `dataclass` is imported at module level.
        from dataclasses import fields

        return {field.name: getattr(self, field.name) for field in fields(self)}
class HybridDDIPredictor:
    """Drug-drug interaction predictor that prefers DDInter evidence over the model.

    ``predict`` first looks the pair up in the in-memory DDInter index. When no
    record exists, or the recorded severity is ``unknown``, it falls back to the
    PyTorch model described by the checkpoint (an embedding MLP or a feature
    MLP, optionally overridden by an ensemble for the multisource variant).
    """

    # Checkpoint `model_type` values that are served by FeatureMLP.
    _FEATURE_MODEL_TYPES = frozenset({'feature_mlp', 'feature_mlp_multisource'})

    def __init__(self, checkpoint: dict[str, Any], calibration: dict[str, Any] | None = None) -> None:
        """Build the predictor from a loaded checkpoint and optional calibration artifacts.

        Args:
            checkpoint: Loaded checkpoint dict; must contain ``model_state_dict``.
                Hyper-parameters, vocabulary and label maps are read with defaults.
            calibration: Optional calibration dict; when non-empty its
                ``temperature`` and ``major_threshold`` override the checkpoint values.

        Raises:
            FileNotFoundError: If a feature model is requested but its feature
                pipeline artifact is missing (or the DDInter table cannot be loaded).
            ValueError: If the checkpoint ``input_dim`` disagrees with the
                feature pipeline's dimension.
        """
        self.checkpoint = checkpoint
        self.model_version = str(checkpoint.get('model_version', 'ddi-mlp-v1'))
        self.embedding_dim = int(checkpoint.get('embedding_dim', 64))
        self.hidden_dim = int(checkpoint.get('hidden_dim', 128))
        self.model_type = str(checkpoint.get('model_type', 'ddi_embedding_mlp'))
        self.input_dim = checkpoint.get('input_dim')
        self.temperature = float(checkpoint.get('temperature', 1.0))
        self.vocab = checkpoint.get('drug_vocab', {})
        self.label_names = list(checkpoint.get('label_names', LABEL_NAMES))
        self.label_to_index = dict(checkpoint.get('label_to_index', LABEL_TO_INDEX))
        self.index_to_label = {int(index): label for index, label in checkpoint.get('index_to_label', {}).items()}
        if not self.index_to_label:
            # Derive the reverse map when the checkpoint does not ship one.
            self.index_to_label = {index: label for label, index in self.label_to_index.items()}

        # Load calibration artifacts (temperature scaling, threshold tuning)
        self.calibration = calibration or {}
        if self.calibration:
            self.temperature = float(self.calibration.get('temperature', self.temperature))
            self.major_threshold = float(self.calibration.get('major_threshold', 0.5))
            logger.info(
                'Loaded calibration: temperature=%.4f, major_threshold=%.4f',
                self.temperature,
                self.major_threshold,
            )
        else:
            self.major_threshold = 0.5

        self.safety_calibration_cfg = SafetyCalibrationConfig(
            severe_class_index=self.label_to_index.get('major', 3),
            low_confidence_threshold=float(self.calibration.get('low_confidence_threshold', 0.55)),
            medium_confidence_threshold=float(self.calibration.get('medium_confidence_threshold', 0.75)),
            severe_alert_threshold=float(self.calibration.get('severe_alert_threshold', 0.45)),
            entropy_high_threshold=float(self.calibration.get('entropy_high_threshold', 1.0)),
            top2_margin_low_threshold=float(self.calibration.get('top2_margin_low_threshold', 0.15)),
        )

        self.pair_index, self.drug_profiles = load_ddinter_lookup()
        self.ensemble: EnsemblePredictor | None = None
        self.feature_pipeline = None

        if self.model_type in self._FEATURE_MODEL_TYPES:
            feature_pipeline_path = (
                FEATURE_PIPELINE_MULTISOURCE_PATH
                if self.model_type == 'feature_mlp_multisource'
                else FEATURE_PIPELINE_PATH
            )
            if not feature_pipeline_path.exists():
                raise FileNotFoundError(f'Feature pipeline artifact not found at {feature_pipeline_path}')
            # NOTE(review): joblib.load unpickles the file; only load trusted artifacts.
            self.feature_pipeline = joblib.load(feature_pipeline_path)
            metadata = self.feature_pipeline.get('metadata', {})
            artifact_dim = int(metadata.get('vector_dim', metadata.get('feature_dim', 0)))
            if self.input_dim is not None and int(self.input_dim) != artifact_dim:
                raise ValueError(
                    f'Checkpoint input_dim={self.input_dim} does not match feature pipeline dim={artifact_dim}. '
                    'Refusing to infer with a mismatched schema.'
                )
            if self.input_dim is None:
                self.input_dim = artifact_dim
            self.model = FeatureMLP(
                input_dim=int(self.input_dim or self.feature_pipeline['metadata']['vector_dim']),
                hidden_dim=self.hidden_dim,
                num_classes=len(self.label_names),
                dropout=float(checkpoint.get('dropout', 0.2)),
            )
        else:
            self.model = DDIEmbeddingMLP(
                # +1 because embedding index 0 is the padding / unknown-drug slot.
                vocab_size=len(self.vocab) + 1,
                embedding_dim=self.embedding_dim,
                hidden_dim=self.hidden_dim,
                num_classes=len(self.label_names),
            )
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Optional ensemble backend for feature-based models.
        if self.model_type in self._FEATURE_MODEL_TYPES and ENSEMBLE_DIR.exists():
            try:
                self.ensemble = EnsemblePredictor(ENSEMBLE_DIR)
                logger.info('Loaded optional ensemble artifacts from %s', ENSEMBLE_DIR)
            except Exception as exc:
                logger.warning('Failed to load ensemble artifacts: %s', exc)

    @classmethod
    def from_default_paths(cls, use_production: bool = True) -> 'HybridDDIPredictor':
        """Load predictor from default model paths.

        Args:
            use_production: If True, try loading production model first; fall back to standard model

        Returns:
            HybridDDIPredictor instance

        Raises:
            FileNotFoundError: If the selected checkpoint file does not exist.
        """
        # Prefer production model if it exists
        model_path = PRODUCTION_MODEL_PATH if (use_production and PRODUCTION_MODEL_PATH.exists()) else MODEL_PATH
        if not model_path.exists():
            raise FileNotFoundError(
                f'Model checkpoint not found at {model_path}. '
                f'Run src/training/train_healthcare_production.py or src/inference/train_model.py to create it.'
            )
        logger.info('Loading checkpoint from %s', model_path)
        # NOTE(review): torch.load / joblib.load unpickle data; only load trusted artifacts.
        checkpoint = torch.load(model_path, map_location='cpu')

        # Try to load calibration artifacts (optional)
        calibration = None
        calibration_path = (
            PRODUCTION_CALIBRATION_PATH
            if (use_production and PRODUCTION_CALIBRATION_PATH.exists())
            else CALIBRATION_PATH
        )
        if calibration_path.exists():
            logger.info('Loading calibration artifacts from %s', calibration_path)
            try:
                calibration = joblib.load(calibration_path)
            except Exception as exc:
                # Calibration is optional: continue uncalibrated.
                logger.warning('Failed to load calibration artifacts: %s', exc)
        return cls(checkpoint, calibration=calibration)

    def health(self) -> dict[str, Any]:
        """Return a status snapshot describing the loaded artifacts."""
        if self.model_type == 'feature_mlp_multisource':
            pipeline_path = str(FEATURE_PIPELINE_MULTISOURCE_PATH)
        elif self.model_type == 'feature_mlp':
            pipeline_path = str(FEATURE_PIPELINE_PATH)
        else:
            pipeline_path = None
        return {
            'status': 'healthy',
            'mode': 'hybrid',
            'model_loaded': True,
            'pipeline_loaded': True,
            'calibration_loaded': bool(self.calibration),
            'ensemble_loaded': self.ensemble is not None,
            'model_version': self.model_version,
            'model_type': self.model_type,
            'pairs_loaded': len(self.pair_index),
            'records_loaded': sum(len(records) for records in self.pair_index.values()),
            'vocab_size': len(self.vocab),
            'label_names': self.label_names,
            'temperature': self.temperature,
            'major_threshold': self.major_threshold,
            'feature_schema': {
                'input_dim': self.input_dim,
                'temperature': self.temperature,
                'feature_pipeline_path': pipeline_path,
                'available': self.feature_pipeline is not None,
                'group_slices': self.feature_pipeline.get('group_slices', {}) if self.feature_pipeline else {},
            },
        }

    def _resolve_drug_name(self, drug_value: str) -> str:
        """Resolve input which may be a DrugBank ID, active ingredient, marketing name, or ATC code.

        Strategy:
        1. Basic cleaning and DB ID mapping via `resolve_input_name`.
        2. If available, consult DrugBank alias index (`load_drugbank_name_to_smiles_index`) for
           direct alias -> canonical name, ATC code matches, and canonical lookup.
        3. Use `lookup_drugbank_smiles` as a fuzzy fallback when available.
        4. Fall back to the frontend alias map defined in `src/lib/drugAliases.js`.
        5. Return the cleaned value if no mapping found.
        """
        raw_text = str(drug_value or '')

        # 1) Basic resolution (handles DBxxxxx -> mapped name)
        try:
            resolved = resolve_input_name(drug_value)
        except Exception:
            resolved = raw_text.strip()

        # 2) Consult DrugBank index if available
        try:
            index = load_drugbank_name_to_smiles_index() if load_drugbank_name_to_smiles_index else None
        except Exception:
            index = None

        if index:
            # Precomputed alias map: alias -> canonical name
            name_to_record = index.get('name_to_record_key', {}) or {}
            canonical_map = index.get('canonical_name_to_record', {}) or {}

            # Normalize input for direct alias lookup
            try:
                norm_key = normalize_drug_name(drug_value) if normalize_drug_name else normalize_name(raw_text)
            except Exception:
                norm_key = normalize_name(raw_text)
            if norm_key in name_to_record:
                return name_to_record.get(norm_key) or resolved

            # ATC code pattern (e.g., C09AA05)
            up = raw_text.strip().upper()
            if re.fullmatch(r'[A-Z]\d{2}[A-Z]{2}\d{2}', up):
                for record in canonical_map.values():
                    for code in record.get('atc_codes') or []:
                        if code and up == str(code).upper():
                            return record.get('canonical_name') or resolved

            # Try expanded variants (brands, synonyms). `drug_name_variants` is
            # not imported in this module, so this only runs if something else
            # has placed it in the module globals.
            try:
                variant_fn = globals().get('drug_name_variants')
                variants = variant_fn(drug_value) if variant_fn is not None else []
            except Exception:
                variants = []
            for var in variants:
                if var in name_to_record:
                    return name_to_record.get(var) or resolved

        # 3) Fuzzy lookup via `lookup_drugbank_smiles` if available
        if lookup_drugbank_smiles:
            try:
                lookup = lookup_drugbank_smiles(drug_value)
                if lookup and lookup.get('matched') and lookup.get('matched_name'):
                    return lookup.get('matched_name')
            except Exception:
                # Best effort: fall through to the next strategy.
                pass

        # 4) Frontend alias map fallback
        try:
            frontend_aliases = _load_frontend_alias_map()
            key = normalize_name(raw_text)
            if key in frontend_aliases:
                return frontend_aliases[key]
        except Exception:
            pass

        # 5) Final fallback
        return resolved

    def _find_vocab_id(self, drug_name: str) -> int:
        """Return the embedding index for *drug_name*, or 0 when it is not in the vocabulary."""
        if self.model_type == 'feature_mlp':
            return 0
        normalized_name = normalize_name(drug_name)
        return int(self.vocab.get(normalized_name, 0))

    def _resolve_smiles(self, drug_name: str) -> str:
        """Return the SMILES string for *drug_name*, or 'N/A' when it cannot be resolved."""
        try:
            resolved = resolve_drug_name_to_smiles(drug_name)
        except Exception as exc:
            logger.warning('smiles_resolution_failed drug=%s error=%s', drug_name, exc)
            resolved = None
        return resolved or 'N/A'

    def _normalize_clinical_severity(self, probabilities: dict[str, float], fallback_severity: str) -> tuple[str, float]:
        """Pick the most probable clinical label, ignoring the 'unknown' class.

        The minor / moderate / major probabilities are renormalised to sum to
        one and the largest is returned with its renormalised probability
        (three decimals). When all three are zero, *fallback_severity* is
        returned if it is a clinical label, otherwise 'minor'; the probability
        is then 0.0.
        """
        clinical_labels = ('minor', 'moderate', 'major')
        clinical_probs = {label: float(probabilities.get(label, 0.0) or 0.0) for label in clinical_labels}
        total = sum(clinical_probs.values())
        if total > 0:
            normalized_probs = {label: value / total for label, value in clinical_probs.items()}
            chosen_label = max(normalized_probs, key=normalized_probs.get)
            return chosen_label, round(float(normalized_probs[chosen_label]), 3)
        fallback = fallback_severity if fallback_severity in clinical_labels else 'minor'
        return fallback, round(float(clinical_probs.get(fallback, 0.0)), 3)

    @staticmethod
    def _record_severity(record: dict[str, str]) -> str:
        """Return the lower-cased severity label of a DDInter row ('unknown' if absent or unrecognised)."""
        # The loader in this module reads the level from `Level` / `level`;
        # `severity` is honoured first for tables that already carry it.
        raw = record.get('severity') or record.get('Level') or record.get('level') or 'unknown'
        label = str(raw).strip().lower()
        return label if label in SEVERITY_ADVICE else 'unknown'

    def _find_pair_records(self, drug_a_name: str, drug_b_name: str) -> list[dict[str, str]]:
        """Return the indexed rows for a pair, trying both key conventions."""
        normalized_key = canonical_pair_key(drug_a_name, drug_b_name)
        # The index built by `load_ddinter_lookup` uses stripped, UPPER-CASED
        # names, which the lower-casing `canonical_pair_key` never matches.
        upper_key = tuple(sorted((drug_a_name.strip().upper(), drug_b_name.strip().upper())))
        for key in (normalized_key, upper_key):
            # .get() avoids inserting empty lists into a defaultdict index.
            records = self.pair_index.get(key)
            if records:
                return records
        return []

    def _lookup_exact(self, drug_a_name: str, drug_b_name: str, drug_a_id: str, drug_b_id: str) -> InferenceResult | None:
        """Return an evidence-backed result for the pair, or None when DDInter has no record.

        The reported severity is the highest severity among the pair's records;
        confidence grows with the number of supporting records.
        """
        records = self._find_pair_records(drug_a_name, drug_b_name)
        if not records:
            return None

        top_severity = max((self._record_severity(record) for record in records), key=severity_rank)
        confidence = lookup_confidence(top_severity, len(records))
        warning = None
        if confidence < 0.75:
            warning = 'Exact DDInter evidence exists, but the support is limited and confidence is below the high-confidence threshold.'
        explanation = (
            f'Exact DDInter evidence match found for {drug_a_name} and {drug_b_name}. '
            f'{len(records)} supporting record(s) were retrieved and the highest observed severity is {top_severity}.'
        )
        return InferenceResult(
            source='ddinter_lookup',
            confidence=confidence,
            severity=top_severity,
            explanation=explanation,
            warning=warning,
            model_version='ddinter-evidence-v1',
            mechanism='Evidence-backed lookup from the processed DDInter table',
            affected_systems='Not available in the local DDInter table',
            smiles_a=self._resolve_smiles(drug_a_name),
            smiles_b=self._resolve_smiles(drug_b_name),
            clinical_advice=SEVERITY_ADVICE[top_severity],
            additional_notes=explanation if warning is None else f'{explanation} {warning}',
            drug_a_name=drug_a_name,
            drug_b_name=drug_b_name,
            drug_a_id=drug_a_id,
            drug_b_id=drug_b_id,
            source_dataset='ddinter_combined.parquet',
            exact_match=True,
            evidence_count=len(records),
            confidence_band=confidence_band(confidence),
            probabilities=None,
        )

    def _predict_with_model(self, drug_a_name: str, drug_b_name: str, drug_a_id: str, drug_b_id: str) -> InferenceResult:
        """Run the neural model (or ensemble) for the pair and wrap the outcome.

        Raises:
            RuntimeError: If a feature model is configured but its feature
                pipeline is not loaded.
        """
        if self.model_type in self._FEATURE_MODEL_TYPES:
            if self.feature_pipeline is None:
                raise RuntimeError('Feature pipeline artifacts are not loaded')
            is_multisource = self.model_type == 'feature_mlp_multisource'
            transform = transform_pair_features_multisource if is_multisource else transform_pair_features_legacy
            features = transform(drug_a_name, drug_b_name, self.feature_pipeline)
            input_features = torch.tensor([features], dtype=torch.float32)
            with torch.no_grad():
                # Temperature scaling; the floor guards against division by zero.
                logits = self.model(input_features) / max(self.temperature, 1e-6)
                probabilities = torch.softmax(logits, dim=-1).squeeze(0)
            # Ensemble override when available.
            # NOTE(review): the ensemble is also loaded for 'feature_mlp' but is
            # only consulted for the multisource model - confirm this is intended.
            if is_multisource and self.ensemble is not None:
                try:
                    probs = self.ensemble.predict_proba(np.array([features], dtype=np.float32))[0]
                    probabilities = torch.tensor(probs, dtype=torch.float32)
                except Exception as exc:
                    logger.warning('Ensemble inference failed; falling back to base model: %s', exc)
        else:
            input_a = torch.tensor([self._find_vocab_id(drug_a_name)], dtype=torch.long)
            input_b = torch.tensor([self._find_vocab_id(drug_b_name)], dtype=torch.long)
            with torch.no_grad():
                # NOTE(review): no temperature scaling on the embedding path,
                # unlike the feature paths - confirm this is intended.
                logits = self.model(input_a, input_b)
                probabilities = torch.softmax(logits, dim=-1).squeeze(0)

        probability_values = probabilities.tolist()
        probability_array = np.array(probability_values, dtype=np.float32)
        safety = risk_adjusted_prediction(probability_array, cfg=self.safety_calibration_cfg)
        best_index = int(safety['pred_index'])
        predicted_severity = self.index_to_label.get(best_index, self.label_names[best_index])
        confidence = round(float(safety['max_probability']), 3)
        confidence_band_value = str(safety['confidence_band'])
        escalated = bool(safety.get('escalated_for_safety', False))

        probabilities_map = {
            self.index_to_label.get(index, self.label_names[index]): round(float(value), 3)
            for index, value in enumerate(probability_values)
        }
        clinical_severity, clinical_confidence = self._normalize_clinical_severity(probabilities_map, predicted_severity)
        # Replace a non-clinical ('unknown') prediction with the best clinical
        # label, but never undo a safety escalation: otherwise the result would
        # report a lower severity while warning that it was escalated.
        if not escalated and clinical_severity != predicted_severity:
            predicted_severity = clinical_severity
            confidence = clinical_confidence

        warning = None
        if confidence_band_value == 'low':
            warning = 'Low-confidence deep learning prediction. Please review this result carefully and consider expert validation.'
        elif confidence_band_value == 'medium':
            warning = 'Moderate-confidence prediction. Treat this result as advisory rather than definitive.'
        if escalated:
            escalated_warning = 'Prediction was conservatively escalated to major due to uncertainty and elevated severe-class probability.'
            warning = escalated_warning if warning is None else f'{warning} {escalated_warning}'

        explanation = (
            f'Deep learning model predicted {predicted_severity} interaction severity with {confidence:.3f} confidence. '
            f'The prediction used learned embeddings for {drug_a_name} and {drug_b_name} from the trained MEDCARE-DDI MLP. '
            f'Uncertainty={safety.get("uncertainty", "normal")}, top2_margin={float(safety.get("top2_margin", 0.0)):.3f}, '
            f'severe_probability={float(safety.get("severe_probability", 0.0)):.3f}.'
        )
        return InferenceResult(
            source='deep_learning_prediction',
            confidence=confidence,
            severity=predicted_severity,
            explanation=explanation,
            warning=warning,
            model_version=self.model_version,
            mechanism='Deep learning inference from the trained MEDCARE-DDI PyTorch model',
            affected_systems='Learned from DDInter interaction patterns',
            smiles_a=self._resolve_smiles(drug_a_name),
            smiles_b=self._resolve_smiles(drug_b_name),
            clinical_advice=SEVERITY_ADVICE.get(predicted_severity, SEVERITY_ADVICE['unknown']),
            additional_notes=explanation if warning is None else f'{explanation} {warning}',
            drug_a_name=drug_a_name,
            drug_b_name=drug_b_name,
            drug_a_id=drug_a_id,
            drug_b_id=drug_b_id,
            source_dataset='ddi_mlp_best.pt',
            exact_match=False,
            evidence_count=None,
            confidence_band=confidence_band_value,
            probabilities=probabilities_map,
        )

    def predict(self, drug_a_value: str, drug_b_value: str) -> dict[str, Any]:
        """Predict the interaction severity for two drugs.

        Args:
            drug_a_value: Name, brand, ATC code or DrugBank ID of the first drug.
            drug_b_value: Same for the second drug.

        Returns:
            The ``InferenceResult.to_dict()`` payload of either the evidence
            lookup or the model prediction.

        Raises:
            ValueError: If either input is missing, not a string, or blank.
        """
        # Validate before touching the values: `.strip()` below would otherwise
        # raise AttributeError for None instead of the intended ValueError.
        for raw_value in (drug_a_value, drug_b_value):
            if not isinstance(raw_value, str) or not raw_value.strip():
                raise ValueError('Both drugs must be provided')

        drug_a_name = self._resolve_drug_name(drug_a_value)
        drug_b_name = self._resolve_drug_name(drug_b_value)
        drug_a_id = drug_a_value.strip().upper()
        drug_b_id = drug_b_value.strip().upper()
        if not drug_a_name or not drug_b_name:
            raise ValueError('Both drugs must be provided')

        exact_result = self._lookup_exact(drug_a_name, drug_b_name, drug_a_id, drug_b_id)
        if exact_result is not None:
            # When DDInter has an exact pair but only "unknown" severity, use the model
            # to provide a more actionable estimate while preserving evidence context.
            if exact_result.severity == 'unknown':
                model_result = self._predict_with_model(drug_a_name, drug_b_name, drug_a_id, drug_b_id)
                evidence_note = (
                    f'Exact DDInter match exists ({exact_result.evidence_count or 0} record(s)) '
                    'but reported severity is unknown; using ML fallback for a more informative estimate.'
                )
                model_result.warning = (
                    f'{evidence_note} {model_result.warning}' if model_result.warning else evidence_note
                )
                model_result.additional_notes = (
                    f'{model_result.additional_notes} {evidence_note}'.strip()
                )
                logger.info(
                    'lookup_unknown_fallback drug_a=%s drug_b=%s evidence_count=%s model_severity=%s model_confidence=%.3f',
                    drug_a_name,
                    drug_b_name,
                    exact_result.evidence_count,
                    model_result.severity,
                    model_result.confidence,
                )
                return model_result.to_dict()

            logger.info(
                'lookup_hit drug_a=%s drug_b=%s severity=%s confidence=%.3f',
                drug_a_name,
                drug_b_name,
                exact_result.severity,
                exact_result.confidence,
            )
            return exact_result.to_dict()

        model_result = self._predict_with_model(drug_a_name, drug_b_name, drug_a_id, drug_b_id)
        logger.info(
            'model_fallback drug_a=%s drug_b=%s severity=%s confidence=%.3f confidence_band=%s',
            drug_a_name,
            drug_b_name,
            model_result.severity,
            model_result.confidence,
            model_result.confidence_band,
        )
        return model_result.to_dict()