Spaces:
Running
Running
| """Model evaluation and comparison script. | |
| Supports three inference modes: | |
| - Standard: single forward pass per image. | |
| - TTA: average over 4 augmented views (original + h-flip + rotate ±7°). | |
| - Ensemble + TTA: average TTA predictions from all loaded models. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| import torchvision.transforms.functional as transforms_functional | |
| from torch.utils.data import DataLoader | |
| from src.data.dataset import PATHOLOGY_LABELS, ChestXrayDataset | |
| from src.models.densenet_transfer import CheXVisionDenseNet | |
| from src.models.scratch_cnn import CheXVisionScratch | |
| from src.training.metrics import compute_binary_metrics, compute_multilabel_metrics | |
| from src.training.trainer import set_seed | |
| logger = logging.getLogger(__name__) | |
def load_model(checkpoint_path: Path, device: torch.device) -> tuple[torch.nn.Module, dict]:
    """Rebuild a trained model from a checkpoint file and switch it to eval mode.

    The checkpoint must contain a ``"config"`` entry, whose
    ``config["model"]["type"]`` selects the architecture, and a
    ``"model_state_dict"`` entry holding the weights.

    Args:
        checkpoint_path: Path of the checkpoint file to read.
        device: Device the stored tensors are mapped to and the model is moved to.

    Returns:
        A ``(model, config)`` tuple: the model with weights loaded, placed on
        ``device`` and set to eval mode, together with the stored config dict.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # only checkpoints from a trusted source should be passed in here.
    state = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = state["config"]
    model_cfg = config["model"]
    is_scratch = model_cfg["type"] == "scratch"
    # Both architectures read their hyper-parameters from the same optional section.
    arch = model_cfg.get("architecture", {})

    model: torch.nn.Module
    if not is_scratch:
        # Any type other than "scratch" is treated as the DenseNet transfer model.
        model = CheXVisionDenseNet(
            num_classes=14,
            pretrained=False,
            dropout=arch.get("dropout", 0.3),
        )
    else:
        blocks = tuple(arch.get("block_config", [2, 2, 2, 2]))
        filters = tuple(arch.get("filter_sizes", [64, 128, 256, 512]))
        model = CheXVisionScratch(
            in_channels=3,
            num_classes=14,
            block_config=blocks,
            filter_sizes=filters,
            dropout=arch.get("dropout", 0.5),
        )

    model.load_state_dict(state["model_state_dict"])
    model.to(device)
    model.eval()
    return model, config
def predict(model: torch.nn.Module, dataloader: DataLoader, device: torch.device) -> dict[str, np.ndarray]:
    """Run standard inference (single forward pass per image).

    Args:
        model: Module whose forward pass returns a mapping with
            ``"multilabel_logits"`` and ``"binary_logits"`` tensors.
        dataloader: Iterable of batches; each batch is a mapping with
            ``"image"``, ``"multilabel_target"`` and ``"binary_target"`` tensors.
        device: Device the images are moved to before the forward pass.

    Returns:
        Dict with ``"ml_probs"`` / ``"ml_targets"`` (batches concatenated
        along the first axis) and ``"bin_probs"`` / ``"bin_targets"``
        (concatenated, then with the trailing axis squeezed away).
    """
    all_ml_probs, all_ml_targets = [], []
    all_bin_probs, all_bin_targets = [], []

    # Inference only: with autograd enabled the outputs require grad, so
    # ``.numpy()`` raises RuntimeError and a useless graph is kept in memory.
    with torch.no_grad():
        for batch in dataloader:
            images = batch["image"].to(device)
            outputs = model(images)

            all_ml_probs.append(torch.sigmoid(outputs["multilabel_logits"]).cpu().numpy())
            all_ml_targets.append(batch["multilabel_target"].numpy())
            all_bin_probs.append(torch.sigmoid(outputs["binary_logits"]).cpu().numpy())
            all_bin_targets.append(batch["binary_target"].numpy())

    return {
        "ml_probs": np.concatenate(all_ml_probs),
        "ml_targets": np.concatenate(all_ml_targets),
        "bin_probs": np.concatenate(all_bin_probs).squeeze(-1),
        "bin_targets": np.concatenate(all_bin_targets).squeeze(-1),
    }
def predict_with_tta(
    model: torch.nn.Module,
    dataloader: DataLoader,
    device: torch.device,
) -> dict[str, np.ndarray]:
    """Run inference with Test-Time Augmentation (TTA).

    Averages the sigmoid probabilities over 4 views of each batch:
        1. Original
        2. Horizontal flip
        3. Rotate +7°
        4. Rotate -7°

    NOTE(review): the horizontal flip mirrors left/right anatomy; confirm
    this is acceptable for any laterality-sensitive labels.

    Args:
        model: Module whose forward pass returns a mapping with
            ``"multilabel_logits"`` and ``"binary_logits"`` tensors.
        dataloader: Iterable of batches; each batch is a mapping with
            ``"image"``, ``"multilabel_target"`` and ``"binary_target"`` tensors.
        device: Device the images are moved to before the forward passes.

    Returns:
        Dict with ``"ml_probs"`` / ``"ml_targets"`` (batches concatenated
        along the first axis) and ``"bin_probs"`` / ``"bin_targets"``
        (concatenated, then with the trailing axis squeezed away).
    """
    all_ml_probs, all_ml_targets = [], []
    all_bin_probs, all_bin_targets = [], []

    # Inference only: with autograd enabled the accumulated probabilities
    # require grad, so ``.numpy()`` raises RuntimeError and every view's graph
    # is kept alive for no reason.
    with torch.no_grad():
        for batch in dataloader:
            images = batch["image"].to(device)
            views = [
                images,
                transforms_functional.hflip(images),
                transforms_functional.rotate(images, angle=7),
                transforms_functional.rotate(images, angle=-7),
            ]

            # Seed the running sums from the first view's output, so their
            # shape follows the model rather than a hard-coded label count.
            ml_sum = None
            bin_sum = None
            for view in views:
                out = model(view)
                ml_probs = torch.sigmoid(out["multilabel_logits"])
                bin_probs = torch.sigmoid(out["binary_logits"])
                ml_sum = ml_probs if ml_sum is None else ml_sum + ml_probs
                bin_sum = bin_probs if bin_sum is None else bin_sum + bin_probs

            all_ml_probs.append((ml_sum / len(views)).cpu().numpy())
            all_ml_targets.append(batch["multilabel_target"].numpy())
            all_bin_probs.append((bin_sum / len(views)).cpu().numpy())
            all_bin_targets.append(batch["binary_target"].numpy())

    return {
        "ml_probs": np.concatenate(all_ml_probs),
        "ml_targets": np.concatenate(all_ml_targets),
        "bin_probs": np.concatenate(all_bin_probs).squeeze(-1),
        "bin_targets": np.concatenate(all_bin_targets).squeeze(-1),
    }
def predict_ensemble(
    models: list[torch.nn.Module],
    dataloader: DataLoader,
    device: torch.device,
    use_tta: bool = True,
) -> dict[str, np.ndarray]:
    """Average predictions from multiple models (ensemble), optionally with TTA.

    Each model is run over the full ``dataloader`` independently and the
    resulting probability arrays are averaged element-wise. Targets are taken
    from the first model's pass, so the dataloader must yield samples in the
    same order on every pass (i.e. it must not shuffle).

    Args:
        models: Non-empty list of loaded, eval-mode models.
        dataloader: DataLoader over the evaluation split.
        device: Target device.
        use_tta: If True, apply TTA to each model before averaging.

    Returns:
        Dict with the keys ``"ml_probs"``, ``"ml_targets"``, ``"bin_probs"``
        and ``"bin_targets"``, in the same layout the per-model predictors use.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if not models:
        # Otherwise np.mean over an empty list yields NaN with a warning and
        # the target lookup fails with an uninformative IndexError.
        raise ValueError("predict_ensemble requires at least one model")

    predict_fn = predict_with_tta if use_tta else predict
    # Gradient tracking is never needed here; disabling it also keeps the
    # ``.numpy()`` conversions in the per-model predictors from failing.
    with torch.no_grad():
        all_results = [predict_fn(m, dataloader, device) for m in models]

    first = all_results[0]
    return {
        "ml_probs": np.mean([r["ml_probs"] for r in all_results], axis=0),
        "ml_targets": first["ml_targets"],
        "bin_probs": np.mean([r["bin_probs"] for r in all_results], axis=0),
        "bin_targets": first["bin_targets"],
    }
def compare_models(results: dict[str, dict], output_dir: Path) -> None:
    """Generate comparison plots and summary.

    Writes two files into ``output_dir`` (created if missing):
    ``auc_comparison.png``, a grouped bar chart of per-class AUC for every
    entry in ``results``, and ``comparison_summary.json`` with the headline
    metrics per entry. The headline metrics are also logged.

    Args:
        results: Mapping of display name to a dict with ``"ml_metrics"`` and
            ``"bin_metrics"`` sub-dicts.
        output_dir: Directory the plot and JSON summary are written to.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    model_names = list(results)
    n_models = len(model_names)

    # Per-class AUC comparison
    fig, ax = plt.subplots(figsize=(14, 6))
    x = np.arange(len(PATHOLOGY_LABELS))
    # Label slots are 1.0 apart, so a group of bars may span at most ~0.8 or
    # it runs into the next label. The 0.35 cap keeps the two-model layout
    # unchanged while larger comparisons shrink their bars to fit.
    width = min(0.35, 0.8 / max(n_models, 1))
    for i, name in enumerate(model_names):
        ml_metrics = results[name]["ml_metrics"]
        aucs = [ml_metrics.get(f"auc_{label}", 0) for label in PATHOLOGY_LABELS]
        ax.bar(x + i * width, aucs, width, label=name)

    ax.set_xlabel("Pathology")
    ax.set_ylabel("AUC-ROC")
    ax.set_title("Per-Class AUC-ROC Comparison")
    # Centre each tick under its group regardless of how many bars it holds.
    ax.set_xticks(x + width * max(n_models - 1, 0) / 2)
    ax.set_xticklabels(PATHOLOGY_LABELS, rotation=45, ha="right")
    if model_names:
        ax.legend()
    ax.set_ylim(0, 1)
    # Operate on this figure explicitly instead of pyplot's "current" figure.
    fig.tight_layout()
    fig.savefig(output_dir / "auc_comparison.png", dpi=150)
    plt.close(fig)

    # Summary table
    summary = {
        name: {
            "macro_auc": results[name]["ml_metrics"]["auc_roc_macro"],
            "macro_f1": results[name]["ml_metrics"]["f1_macro"],
            "binary_auc": results[name]["bin_metrics"].get("binary_auc_roc", 0),
            "binary_f1": results[name]["bin_metrics"]["binary_f1"],
            "binary_accuracy": results[name]["bin_metrics"]["binary_accuracy"],
        }
        for name in model_names
    }
    with open(output_dir / "comparison_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    logger.info("Comparison results saved to %s", output_dir)
    for name, metrics in summary.items():
        logger.info(" %s: Macro AUC=%.4f, Binary AUC=%.4f", name, metrics["macro_auc"], metrics["binary_auc"])
def _score_predictions(preds: dict[str, np.ndarray]) -> dict[str, dict]:
    """Threshold probabilities at 0.5 and compute multilabel and binary metrics."""
    ml_metrics = compute_multilabel_metrics(
        preds["ml_targets"], (preds["ml_probs"] >= 0.5).astype(int), preds["ml_probs"]
    )
    bin_metrics = compute_binary_metrics(
        preds["bin_targets"], (preds["bin_probs"] >= 0.5).astype(int), preds["bin_probs"]
    )
    return {"ml_metrics": ml_metrics, "bin_metrics": bin_metrics}


def main() -> None:
    """Evaluate every ``*_best.pth`` checkpoint and optionally compare them.

    Each checkpoint in ``--model-dir`` is scored on the test split with
    standard inference and with TTA. When at least two checkpoints were
    loaded, a TTA ensemble of all of them is scored as well. With
    ``--compare`` and at least two result sets, plots and a JSON summary are
    written to ``--output-dir``; otherwise the headline metrics are logged.
    """
    parser = argparse.ArgumentParser(description="Evaluate and compare CheXVision models")
    parser.add_argument("--model-dir", type=Path, default=Path("checkpoints"), help="Directory with model checkpoints")
    parser.add_argument("--data-dir", type=Path, default=Path("data"), help="Data directory")
    parser.add_argument("--output-dir", type=Path, default=Path("results"), help="Output directory for plots")
    parser.add_argument("--compare", action="store_true", help="Compare all models in model-dir")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load test dataset
    test_dataset = ChestXrayDataset(args.data_dir / "images", args.data_dir / "labels.csv", split="test")
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

    results: dict[str, dict] = {}
    loaded_models: list[torch.nn.Module] = []

    # Evaluation never needs gradients. Disabling them here avoids building
    # autograd graphs and keeps the ``.numpy()`` conversions inside the
    # predictors from failing on tensors that require grad.
    with torch.no_grad():
        # Evaluate all checkpoints (standard + TTA)
        for ckpt_path in sorted(args.model_dir.glob("*_best.pth")):
            logger.info("Evaluating %s", ckpt_path.name)
            model, config = load_model(ckpt_path, device)
            loaded_models.append(model)
            name = config["model"].get("name", ckpt_path.stem)

            # Standard inference
            standard = _score_predictions(predict(model, test_loader, device))
            results[name] = standard

            # TTA inference
            logger.info(" Running TTA for %s …", name)
            tta = _score_predictions(predict_with_tta(model, test_loader, device))
            results[f"{name} + TTA"] = tta
            logger.info(
                " %s: AUC %.4f → TTA %.4f",
                name,
                standard["ml_metrics"]["auc_roc_macro"],
                tta["ml_metrics"]["auc_roc_macro"],
            )

        # Ensemble (only when at least two models are present)
        if len(loaded_models) >= 2:
            logger.info("Running ensemble (all models + TTA) …")
            ensemble = _score_predictions(predict_ensemble(loaded_models, test_loader, device, use_tta=True))
            results["Ensemble (CNN + DenseNet + TTA)"] = ensemble
            logger.info(
                " Ensemble: Macro AUC=%.4f, Binary AUC=%.4f",
                ensemble["ml_metrics"]["auc_roc_macro"],
                ensemble["bin_metrics"].get("binary_auc_roc", 0),
            )

    if not results:
        # Make an empty run visible instead of exiting silently.
        logger.warning("No '*_best.pth' checkpoints found in %s", args.model_dir)

    if args.compare and len(results) >= 2:
        compare_models(results, args.output_dir)
    elif results:
        for name, r in results.items():
            logger.info("%s — Macro AUC: %.4f, Binary AUC: %.4f", name, r["ml_metrics"]["auc_roc_macro"], r["bin_metrics"].get("binary_auc_roc", 0))


if __name__ == "__main__":
    main()