Spaces:
Running
Running
| """Embedding model benchmark and comparison. | |
| Compares: | |
| - BioBERT | |
| - PubMedBERT | |
| - SapBERT | |
| - ChemBERTa | |
| Output: | |
| - embedding_benchmark_results.csv | |
| - embedding_ablation_report.md | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| from preprocessing.artifact_manager import manager | |
| import torch | |
| from sklearn.metrics import accuracy_score, f1_score, recall_score, roc_auc_score | |
| from sklearn.model_selection import train_test_split | |
# Log line layout shared by every logger in this script.
_LOG_FORMAT = '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger('medcare_ddi.embedding_bench')

# Repository root: this file sits two directories below it (parents[2]).
BASE_DIR = Path(__file__).resolve().parents[2]
DATA_DIR = BASE_DIR.joinpath('data')
PROCESSED_DIR = DATA_DIR.joinpath('processed')
MODEL_DIR = BASE_DIR.joinpath('models')
REPORTS_DIR = MODEL_DIR.joinpath('reports')
# Created eagerly at import time so the default report paths always have a parent.
REPORTS_DIR.mkdir(exist_ok=True, parents=True)

# Short model key -> model identifier (presumably Hugging Face Hub ids; confirm
# against training.embeddings.EmbeddingService).
EMBEDDING_MODELS = dict(
    biobert='dmis-lab/biobert-base-cased-v1.1',
    pubmedbert='microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext',
    sapbert='cambridgeltl/SapBERT-from-PubMedBERT-fulltext',
    chemberta='seyonec/ChemBERTa-zinc-base-v1',
)

# Severity classes; each label's position in the list is its integer class index.
LABEL_NAMES = ['unknown', 'minor', 'moderate', 'major']
LABEL_TO_INDEX = dict(zip(LABEL_NAMES, range(len(LABEL_NAMES))))
| def _normalize_text(v: str) -> str: | |
| return ' '.join(str(v).strip().lower().split()) | |
def load_data() -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load DDInter drug-name pairs and their interaction-severity labels.

    Only drug names and labels are needed for the embedding benchmark, so the
    engineered feature pipeline is deliberately not loaded here.

    Returns:
        ``(drug_a_names, drug_b_names, y)``: the first and second drug name of
        every DDInter row (as strings), and an ``int64`` array of label indices
        into ``LABEL_NAMES``. Labels are matched case- and
        whitespace-insensitively; unrecognized labels fall back to index 0
        (``'unknown'``) and are reported in a warning.

    Raises:
        FileNotFoundError: if the processed DDInter parquet file is missing.
    """
    ddinter_path = PROCESSED_DIR / 'ddinter_combined.parquet'
    if not ddinter_path.exists():
        raise FileNotFoundError(f'DDInter not found: {ddinter_path}')
    # NOTE(review): existence is checked on ddinter_path but loading goes through
    # the artifact manager; presumably both refer to the same file — confirm.
    df = manager.load_artifact('ddinter_combined')
    logger.info(f'Loaded {len(df)} DDInter records')

    # Normalize so values like ' Major ' still map to the right class.
    levels = [_normalize_text(lbl) for lbl in df['Level']]
    y = np.array([LABEL_TO_INDEX.get(lvl, 0) for lvl in levels], dtype=np.int64)

    # Surface silent fallbacks: anything not in LABEL_TO_INDEX becomes class 0.
    n_unrecognized = sum(lvl not in LABEL_TO_INDEX for lvl in levels)
    if n_unrecognized:
        logger.warning(
            f'{n_unrecognized} DDInter rows have an unrecognized Level; '
            f'mapped to {LABEL_NAMES[0]!r}'
        )

    drug_a_names = df['Drug_A'].astype(str).to_numpy()
    drug_b_names = df['Drug_B'].astype(str).to_numpy()
    return drug_a_names, drug_b_names, y
def benchmark_embedding_model(
    model_name: str,
    model_id: str,
    drug_a_names: np.ndarray,
    drug_b_names: np.ndarray,
    y_true: np.ndarray,
    seed: int = 2026,
) -> Dict[str, Any]:
    """Embed drug pairs with one model, train the ensemble on them and score it.

    Each pair is represented by concatenating the embeddings of its two drug
    names. 20% of the pairs are held out for evaluation (stratified by label
    when every label has at least two samples).

    Args:
        model_name: Short model key (e.g. ``'biobert'``); passed to
            ``EmbeddingService.get_text_embeddings`` and used in output paths.
        model_id: Full model identifier; only logged and echoed in the result.
        drug_a_names: First drug name of each pair.
        drug_b_names: Second drug name of each pair (same length as
            ``drug_a_names``).
        y_true: Integer label index per pair (indices into ``LABEL_NAMES``).
        seed: Random seed for the train/test split and ensemble training.

    Returns:
        On success, a dict with ``status='success'``, accuracy, macro F1,
        recall of the ``'major'`` class, macro one-vs-rest AUROC (NaN when it
        cannot be computed), embedding dimension and test-set size. On any
        error, a dict with ``status='failed'`` and the error message. Errors
        are logged, not raised, so one failing model does not abort the run.
    """
    logger.info(f'Benchmarking {model_name} ({model_id})')
    try:
        from training.embeddings import EmbeddingService
        from training.ensemble import EnsemblePredictor, train_base_models

        if len(drug_a_names) != len(drug_b_names):
            raise ValueError(
                f'Pair arrays differ in length: {len(drug_a_names)} vs {len(drug_b_names)}'
            )

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        svc = EmbeddingService(device=device)

        # The same drug names recur across many pairs: embed each distinct
        # name once, then scatter the vectors back to both sides of every pair.
        n_pairs = len(drug_a_names)
        all_names = np.concatenate([drug_a_names, drug_b_names]).astype(str)
        unique_names, inverse = np.unique(all_names, return_inverse=True)
        unique_embs = np.asarray(
            svc.get_text_embeddings(unique_names.tolist(), model_name=model_name, batch_size=32),
            dtype=np.float32,
        )
        embs_a = unique_embs[inverse[:n_pairs]]
        embs_b = unique_embs[inverse[n_pairs:]]

        X = np.hstack([embs_a, embs_b]).astype(np.float32)
        logger.info(f'{model_name}: feature shape {X.shape}')

        # Stratified splitting requires at least two samples of every class.
        _, class_counts = np.unique(y_true, return_counts=True)
        stratify = y_true if class_counts.size and class_counts.min() >= 2 else None
        if stratify is None:
            logger.warning(f'{model_name}: some label has < 2 samples; using an unstratified split')
        X_train, X_test, y_train, y_test = train_test_split(
            X, y_true, test_size=0.2, random_state=seed, stratify=stratify
        )

        ensemble_dir = REPORTS_DIR / f'embedding_{model_name}_ensemble'
        train_base_models(X_train, y_train, ensemble_dir, random_state=seed)

        predictor = EnsemblePredictor(ensemble_dir)
        probs = predictor.predict_proba(X_test)
        preds = np.argmax(probs, axis=1)

        accuracy = float(accuracy_score(y_test, preds))
        macro_f1 = float(f1_score(y_test, preds, average='macro', zero_division=0))
        severe_idx = LABEL_TO_INDEX['major']
        severe_recall = float(
            recall_score(y_test, preds, labels=[severe_idx], average='macro', zero_division=0)
        )

        # Best-effort: AUROC is undefined e.g. when a class is absent from the
        # test split. Report NaN rather than 0.0, which would read as a
        # perfectly wrong classifier.
        try:
            y_test_ovr = np.eye(len(LABEL_NAMES))[y_test]
            auroc = float(roc_auc_score(y_test_ovr, probs, average='macro', multi_class='ovr'))
        except Exception as e:
            logger.warning(f'AUROC calculation failed: {e}')
            auroc = float('nan')

        return {
            'model_name': model_name,
            'model_id': model_id,
            'accuracy': accuracy,
            'macro_f1': macro_f1,
            'severe_recall': severe_recall,
            'auroc': auroc,
            'embedding_dim': int(unique_embs.shape[1]),
            'test_samples': len(y_test),
            'status': 'success',
        }
    except Exception as e:
        # Per-model boundary: record the failure and let main() continue.
        logger.error(f'Benchmark failed for {model_name}: {e}', exc_info=True)
        return {
            'model_name': model_name,
            'model_id': model_id,
            'status': 'failed',
            'error': str(e),
        }
def _write_csv_results(results: List[Dict[str, Any]], csv_path: Path) -> None:
    """Write one CSV row per benchmarked model; failed models get empty metric cells."""
    fieldnames = ['model_name', 'accuracy', 'macro_f1', 'severe_recall', 'auroc', 'embedding_dim', 'status']
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    with csv_path.open('w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        # Failed rows are kept so the 'status' column actually distinguishes
        # outcomes; DictWriter renders the missing (None) metrics as ''.
        for r in results:
            writer.writerow({k: r.get(k) for k in fieldnames})
    logger.info(f'Saved benchmark results to {csv_path}')


def _write_markdown_report(results: List[Dict[str, Any]], md_path: Path) -> None:
    """Write a markdown summary: best model by severe recall, results table, failures."""
    successful = [r for r in results if r.get('status') == 'success']
    failed = [r for r in results if r.get('status') == 'failed']
    md_path.parent.mkdir(parents=True, exist_ok=True)
    with md_path.open('w', encoding='utf-8') as f:
        f.write('# Embedding Model Benchmark\n\n')
        f.write('## Summary\n\n')
        if successful:
            best = max(successful, key=lambda r: r.get('severe_recall', 0))
            f.write(f'**Best model (by severe recall): {best["model_name"]}**\n\n')
            f.write(f'- Severe Recall: {best.get("severe_recall", 0):.4f}\n')
            f.write(f'- Accuracy: {best.get("accuracy", 0):.4f}\n')
            f.write(f'- Macro F1: {best.get("macro_f1", 0):.4f}\n')
            f.write(f'- AUROC: {best.get("auroc", 0):.4f}\n\n')
        else:
            f.write('No embedding model benchmark succeeded.\n\n')
        f.write('## Results\n\n')
        f.write('| Model | Accuracy | Macro F1 | Severe Recall | AUROC | Dim |\n')
        f.write('|-------|----------|----------|---------------|-------|-----|\n')
        for r in successful:
            f.write(
                f"| {r['model_name']} | "
                f"{r.get('accuracy', 0):.4f} | "
                f"{r.get('macro_f1', 0):.4f} | "
                f"{r.get('severe_recall', 0):.4f} | "
                f"{r.get('auroc', 0):.4f} | "
                f"{r.get('embedding_dim', 0)} |\n"
            )
        if failed:
            f.write('\n## Failed Benchmarks\n\n')
            for r in failed:
                # Collapse whitespace so a multi-line error stays one list item.
                error = ' '.join(str(r.get('error', 'unknown error')).split())
                f.write(f"- {r['model_name']}: {error}\n")
    logger.info(f'Saved markdown report to {md_path}')


def main() -> None:
    """Command-line entry point: benchmark every model in EMBEDDING_MODELS and write reports.

    Writes a CSV of per-model results and a markdown report. The paths are
    configurable via ``--output-csv`` / ``--output-md``, and missing parent
    directories are created.
    """
    parser = argparse.ArgumentParser(description='Benchmark embedding models')
    parser.add_argument('--seed', type=int, default=2026)
    parser.add_argument('--output-csv', type=str, default=str(REPORTS_DIR / 'embedding_benchmark_results.csv'))
    parser.add_argument('--output-md', type=str, default=str(REPORTS_DIR / 'embedding_ablation_report.md'))
    args = parser.parse_args()

    logger.info('Loading data...')
    drug_a_names, drug_b_names, y_true = load_data()
    logger.info(f'Loaded {len(y_true)} samples')

    results = [
        benchmark_embedding_model(
            model_name=model_name,
            model_id=model_id,
            drug_a_names=drug_a_names,
            drug_b_names=drug_b_names,
            y_true=y_true,
            seed=args.seed,
        )
        for model_name, model_id in EMBEDDING_MODELS.items()
    ]

    _write_csv_results(results, Path(args.output_csv))
    _write_markdown_report(results, Path(args.output_md))

    if not any(r.get('status') == 'success' for r in results):
        logger.warning('No embedding model benchmark succeeded; see the report for errors')
    logger.info('✓ Embedding benchmark complete')


if __name__ == '__main__':
    main()