Spaces:
Running
Running
| """Evaluate the locally generated checkpoint against the processed DDInter dataset. | |
| Produces: | |
| - MEDCARE-DDI-AI/models/eval/metrics_summary.json | |
| - MEDCARE-DDI-AI/models/eval/confusion_matrix.png | |
| - MEDCARE-DDI-AI/models/eval/inference_validation_report.md | |
| The script re-uses preprocessing logic from predictor.py to ensure consistency. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import math | |
| from collections import Counter | |
| from pathlib import Path | |
| from typing import Any | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| from preprocessing.artifact_manager import manager | |
| import torch | |
| from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, recall_score | |
| from predictor import ( | |
| BASE_DIR, | |
| DATA_PATH, | |
| MODEL_PATH, | |
| LABEL_NAMES, | |
| LABEL_TO_INDEX, | |
| INDEX_TO_LABEL, | |
| normalize_name, | |
| canonical_pair_key, | |
| HybridDDIPredictor, | |
| ) | |
# Evaluation artifacts are written to an 'eval' directory that sits next to the
# checkpoint file. The directory (and any missing parents) is created as soon
# as this module is imported, so later writers can assume it exists.
OUT_DIR = MODEL_PATH.parent.joinpath('eval')
OUT_DIR.mkdir(exist_ok=True, parents=True)
def build_pairs_from_csv(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse per-evidence interaction rows into one record per drug pair.

    The ``Drug_A``, ``Drug_B`` and ``Level`` columns of *df* are read row by
    row. Each value is converted to ``str`` and stripped of surrounding
    whitespace; the level is also lower-cased. Rows are grouped under the key
    returned by ``canonical_pair_key(a, b)`` and the levels seen for each key
    are tallied.

    Args:
        df: Frame holding at least the ``Drug_A``, ``Drug_B`` and ``Level``
            columns.

    Returns:
        A frame with one row per distinct key and the columns ``drug_a`` and
        ``drug_b`` (the two elements of the key), ``label`` (the most frequent
        level for that key) and ``support`` (the number of source rows that
        were merged into it). Rows appear in the order keys were first seen.
    """
    votes_by_pair = {}
    for _, record in df.iterrows():
        first_drug = str(record['Drug_A']).strip()
        second_drug = str(record['Drug_B']).strip()
        severity = str(record['Level']).strip().lower()
        pair_key = canonical_pair_key(first_drug, second_drug)
        # One Counter per pair; it is created the first time the pair shows up.
        votes_by_pair.setdefault(pair_key, Counter())[severity] += 1

    # The key is unpacked as a 2-tuple. `most_common(1)` yields the majority
    # level; ties go to the level that was encountered first.
    return pd.DataFrame(
        [
            {
                'drug_a': drug_a,
                'drug_b': drug_b,
                'label': votes.most_common(1)[0][0],
                'support': sum(votes.values()),
            }
            for (drug_a, drug_b), votes in votes_by_pair.items()
        ]
    )
def evaluate(
    predictor: HybridDDIPredictor, eval_df: pd.DataFrame
) -> tuple[dict[str, Any], np.ndarray, np.ndarray, np.ndarray]:
    """Score the predictor's model on every pair in *eval_df*.

    Each row's ``drug_a`` / ``drug_b`` names are mapped to vocabulary ids with
    ``predictor._find_vocab_id`` and fed one pair at a time to
    ``predictor.model``; the predicted class is the argmax of the softmaxed
    output. A pair is counted as out-of-vocabulary when either id is 0.

    The model is switched to eval mode for the duration of the call and its
    previous training flag is restored afterwards, so stochastic layers such
    as dropout cannot perturb the reported metrics.

    Args:
        predictor: Loaded predictor exposing ``model``, ``label_names`` and
            ``_find_vocab_id``.
        eval_df: Frame with ``drug_a``, ``drug_b`` and ``label`` columns.

    Returns:
        A 4-tuple ``(metrics, cm, y_true, y_pred)``:

        * ``metrics`` - dict with ``accuracy``, ``macro_f1``,
          ``severe_recall`` (recall of the ``'major'`` class, ``0.0`` if that
          label is not in ``LABEL_TO_INDEX``), ``num_examples``,
          ``oov_count``, ``oov_rate`` and ``unknown_label_count``. Scores and
          the rate are rounded to 4 decimals.
        * ``cm`` - confusion matrix over ``range(len(predictor.label_names))``.
        * ``y_true`` / ``y_pred`` - integer class-index arrays, one entry per
          row of *eval_df*.
    """
    model = predictor.model
    true_indices: list[int] = []
    pred_indices: list[int] = []
    oov_count = 0
    unknown_label_count = 0

    was_training = model.training
    model.eval()
    try:
        # Gradients are never needed here; disable them once for the whole loop.
        with torch.no_grad():
            for row in eval_df.itertuples(index=False):
                # Labels missing from LABEL_TO_INDEX still fall back to class 0
                # (unchanged behavior), but are now counted so that the
                # distortion is visible in the metrics instead of silent.
                if row.label not in LABEL_TO_INDEX:
                    unknown_label_count += 1
                true_indices.append(LABEL_TO_INDEX.get(row.label, 0))

                a_id = predictor._find_vocab_id(row.drug_a)
                b_id = predictor._find_vocab_id(row.drug_b)
                if a_id == 0 or b_id == 0:
                    oov_count += 1

                logits = model(
                    torch.tensor([a_id], dtype=torch.long),
                    torch.tensor([b_id], dtype=torch.long),
                )
                probs = torch.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
                pred_indices.append(int(np.argmax(probs).item()))
    finally:
        # Leave the caller's model in whatever mode it was handed to us in.
        model.train(was_training)

    y_true = np.array(true_indices)
    y_pred = np.array(pred_indices)

    acc = float(accuracy_score(y_true, y_pred))
    macro_f1 = float(f1_score(y_true, y_pred, average='macro', zero_division=0))

    # "Severe" recall is the recall of the 'major' class alone.
    if 'major' in LABEL_TO_INDEX:
        major_idx = LABEL_TO_INDEX['major']
        severe_recall = float(
            recall_score(y_true, y_pred, labels=[major_idx], average='macro', zero_division=0)
        )
    else:
        severe_recall = 0.0

    # NOTE(review): true labels are indexed via the module-level LABEL_TO_INDEX
    # while the matrix is sized from predictor.label_names - confirm that the
    # two always describe the same label ordering.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(predictor.label_names))))

    num_examples = len(eval_df)
    metrics = {
        'accuracy': round(acc, 4),
        'macro_f1': round(macro_f1, 4),
        'severe_recall': round(severe_recall, 4),
        'num_examples': int(num_examples),
        'oov_count': int(oov_count),
        'oov_rate': round(float(oov_count) / max(1, num_examples), 4),
        'unknown_label_count': int(unknown_label_count),
    }
    return metrics, cm, y_true, y_pred
def save_confusion_matrix(cm: np.ndarray, labels: list[str], out_path: Path) -> None:
    """Draw *cm* as an annotated heat map and save it to *out_path*.

    Args:
        cm: Two-dimensional matrix of counts; rows are plotted as true labels
            and columns as predicted labels.
        labels: Tick labels applied to both axes.
        out_path: Destination of the image, written at 150 dpi.
    """
    figure, axes = plt.subplots(figsize=(6, 5))
    heatmap = axes.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    figure.colorbar(heatmap, ax=axes)

    n_rows, n_cols = cm.shape[0], cm.shape[1]
    axes.set(
        xticks=np.arange(n_cols),
        yticks=np.arange(n_rows),
        xticklabels=labels,
        yticklabels=labels,
        ylabel='True label',
        xlabel='Predicted label',
        title='Confusion Matrix',
    )
    plt.setp(axes.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')

    # Cells darker than half of the maximum get white text so the count stays
    # readable against the colormap.
    half_max = cm.max() / 2.0
    for row_idx, col_idx in np.ndindex(n_rows, n_cols):
        count = cm[row_idx, col_idx]
        text_color = 'white' if count > half_max else 'black'
        axes.text(col_idx, row_idx, f'{int(count):d}', ha='center', va='center', color=text_color)

    figure.tight_layout()
    figure.savefig(out_path, dpi=150)
    plt.close(figure)
def main() -> None:
    """Run the evaluation end to end and write its three artifacts.

    Steps, in order:

    1. Fail with ``FileNotFoundError`` if ``MODEL_PATH`` does not exist, then
       build the predictor via ``HybridDDIPredictor.from_default_paths()``.
    2. Load the ``'ddinter_combined'`` artifact, collapse it to canonical
       pairs and evaluate the predictor on them.
    3. Merge checkpoint/model metadata into the metrics and write
       ``metrics_summary.json``, ``confusion_matrix.png`` and
       ``inference_validation_report.md`` into ``OUT_DIR``.
    """
    print('Loading checkpoint and predictor...')
    if not MODEL_PATH.exists():
        raise FileNotFoundError(f'Checkpoint not found at {MODEL_PATH}')
    predictor = HybridDDIPredictor.from_default_paths()

    print('Loading processed dataset...')
    raw_df = manager.load_artifact('ddinter_combined')
    eval_df = build_pairs_from_csv(raw_df)
    print(f'Prepared {len(eval_df)} canonical pairs for evaluation')

    # Only the metrics and the confusion matrix are used below.
    metrics, cm, _y_true, _y_pred = evaluate(predictor, eval_df)

    # Checkpoint-versus-model figures recorded side by side so that a mismatch
    # is visible in both the JSON summary and the report.
    metadata = {
        'model_version': predictor.model_version,
        'vocab_size_checkpoint': len(predictor.vocab),
        'vocab_size_used_by_model': predictor.model.embedding.num_embeddings - 1,
        'embedding_dim_checkpoint': predictor.embedding_dim,
        'model_embedding_dim': predictor.model.embedding.embedding_dim,
        'label_names': predictor.label_names,
        'index_to_label': predictor.index_to_label,
        'num_eval_pairs': len(eval_df),
    }
    metrics.update(metadata)

    # Artifact 1: metrics JSON.
    metrics_path = OUT_DIR / 'metrics_summary.json'
    with metrics_path.open('w', encoding='utf-8') as fh:
        json.dump(metrics, fh, indent=2)

    # Artifact 2: confusion matrix image.
    cm_path = OUT_DIR / 'confusion_matrix.png'
    save_confusion_matrix(cm, predictor.label_names, cm_path)

    # Artifact 3: Markdown report, assembled section by section.
    report_lines = [
        '# Inference Validation Report',
        '',
        f'- Model version: {predictor.model_version}',
        f'- Eval pairs: {len(eval_df)}',
        f"- Vocab size (checkpoint): {metadata['vocab_size_checkpoint']}",
        f"- Vocab size (model): {metadata['vocab_size_used_by_model']}",
        f"- Embedding dim (checkpoint): {metadata['embedding_dim_checkpoint']}",
        f"- Embedding dim (model): {metadata['model_embedding_dim']}",
        '',
        '## Metrics',
        '',
        f"- Accuracy: {metrics['accuracy']}",
        f"- Macro F1: {metrics['macro_f1']}",
        f"- Severe (major) recall: {metrics['severe_recall']}",
        f"- OOV count: {metrics['oov_count']} (rate: {metrics['oov_rate']})",
        '',
        '## Confusion matrix',
        f'Confusion matrix saved to `{cm_path}`',
        '',
        '## Preprocessing & Consistency Checks',
        '- Label ordering (checkpoint): ' + ', '.join(predictor.label_names),
        '- Index to label mapping:',
        '',
    ]
    report_lines.extend(
        f'- {idx} -> {label}' for idx, label in predictor.index_to_label.items()
    )
    report_lines.extend(['', '## Observations & drift checks'])

    oov_is_high = metrics['oov_rate'] > 0.05
    report_lines.append(
        '- Warning: OOV rate exceeds 5% — incoming drug names differ from training vocabulary.'
        if oov_is_high
        else '- OOV rate within expected bounds.'
    )

    # Quick pass/fail gate on recall of the 'major' class.
    threshold_severe_recall = 0.90
    severe_recall_ok = metrics['severe_recall'] >= threshold_severe_recall
    report_lines.append(
        f'- Severe recall >= {threshold_severe_recall} (PASS)'
        if severe_recall_ok
        else f'- Severe recall < {threshold_severe_recall} (FAIL) — consider retraining or calibration for higher sensitivity on critical events'
    )

    report_path = OUT_DIR / 'inference_validation_report.md'
    report_path.write_text('\n'.join(report_lines), encoding='utf-8')

    print('Evaluation complete.')
    for caption, artifact in (
        ('Metrics JSON', metrics_path),
        ('Confusion matrix', cm_path),
        ('Report', report_path),
    ):
        print(f'{caption}: {artifact}')


# Run the evaluation only when the file is executed as a script.
if __name__ == '__main__':
    main()