# chexvision-demo / src / training / evaluate.py
# arudaev's picture
# feat: add CLAHE, label smoothing, TTA, and ensemble inference
# 7f1af80
"""Model evaluation and comparison script.
Supports three inference modes:
- Standard: single forward pass per image.
- TTA: average over 4 augmented views (original + h-flip + rotate ±7°).
- Ensemble + TTA: average TTA predictions from all loaded models.
"""
from __future__ import annotations
import argparse
import json
import logging
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.functional as transforms_functional
from torch.utils.data import DataLoader
from src.data.dataset import PATHOLOGY_LABELS, ChestXrayDataset
from src.models.densenet_transfer import CheXVisionDenseNet
from src.models.scratch_cnn import CheXVisionScratch
from src.training.metrics import compute_binary_metrics, compute_multilabel_metrics
from src.training.trainer import set_seed
# Module-level logger named after this module; main() configures the root
# handler and level through logging.basicConfig.
logger = logging.getLogger(__name__)
def load_model(checkpoint_path: Path, device: torch.device) -> tuple[torch.nn.Module, dict]:
    """Rebuild a trained network from a checkpoint and prepare it for inference.

    The checkpoint must contain a ``"config"`` mapping, whose
    ``["model"]["type"]`` entry selects the architecture, and a
    ``"model_state_dict"`` entry holding the weights.

    Args:
        checkpoint_path: Location of the saved checkpoint file.
        device: Device the stored tensors are mapped to and the model is moved onto.

    Returns:
        A ``(model, config)`` pair. The model is on ``device`` and in eval
        mode; ``config`` is the dictionary stored inside the checkpoint.

    Raises:
        KeyError: If the checkpoint lacks ``"config"``, ``"model_state_dict"``
            or the ``["model"]["type"]`` entry.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, so only checkpoints from a trusted source should be loaded.
    saved = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = saved["config"]
    model_cfg = config["model"]

    # Resolve the architecture type first, then the optional hyper-parameters
    # that both branches share.
    is_scratch = model_cfg["type"] == "scratch"
    arch = model_cfg.get("architecture", {})

    network: torch.nn.Module
    if not is_scratch:
        # Anything that is not "scratch" is treated as the DenseNet transfer model.
        # pretrained=False: the weights come from the checkpoint right below.
        network = CheXVisionDenseNet(
            num_classes=14,
            pretrained=False,
            dropout=arch.get("dropout", 0.3),
        )
    else:
        network = CheXVisionScratch(
            in_channels=3,
            num_classes=14,
            block_config=tuple(arch.get("block_config", [2, 2, 2, 2])),
            filter_sizes=tuple(arch.get("filter_sizes", [64, 128, 256, 512])),
            dropout=arch.get("dropout", 0.5),
        )

    network.load_state_dict(saved["model_state_dict"])
    network.to(device)
    network.eval()
    return network, config
@torch.no_grad()
def predict(model: torch.nn.Module, dataloader: DataLoader, device: torch.device) -> dict[str, np.ndarray]:
    """Score every batch once and gather probabilities alongside their targets.

    Each batch must provide ``"image"``, ``"multilabel_target"`` and
    ``"binary_target"``; the model must return a mapping with
    ``"multilabel_logits"`` and ``"binary_logits"``.

    Args:
        model: Network to evaluate; called once per batch.
        dataloader: Iterable yielding the batch dictionaries described above.
        device: Device the images are moved to before the forward pass.

    Returns:
        Dictionary with ``"ml_probs"``, ``"ml_targets"``, ``"bin_probs"`` and
        ``"bin_targets"``, each concatenated over all batches. The two binary
        arrays have their trailing axis squeezed away.
    """
    chunks: dict[str, list[np.ndarray]] = {
        "ml_probs": [],
        "ml_targets": [],
        "bin_probs": [],
        "bin_targets": [],
    }

    for batch in dataloader:
        outputs = model(batch["image"].to(device))
        # Sigmoid turns each logit into an independent probability.
        chunks["ml_probs"].append(torch.sigmoid(outputs["multilabel_logits"]).cpu().numpy())
        chunks["ml_targets"].append(batch["multilabel_target"].numpy())
        chunks["bin_probs"].append(torch.sigmoid(outputs["binary_logits"]).cpu().numpy())
        chunks["bin_targets"].append(batch["binary_target"].numpy())

    merged = {key: np.concatenate(parts) for key, parts in chunks.items()}
    # Drop the trailing axis of the binary head so both arrays are flat.
    for key in ("bin_probs", "bin_targets"):
        merged[key] = merged[key].squeeze(-1)
    return merged
@torch.no_grad()
def predict_with_tta(
    model: torch.nn.Module,
    dataloader: DataLoader,
    device: torch.device,
) -> dict[str, np.ndarray]:
    """Score every batch under Test-Time Augmentation (TTA).

    The model sees four views of each batch and the sigmoid probabilities are
    averaged across them:

    1. The unmodified images.
    2. A horizontal flip.
    3. A rotation by +7 degrees.
    4. A rotation by -7 degrees.

    Averaging several views is intended to lower prediction variance without
    any retraining.

    Args:
        model: Network to evaluate; called once per view and batch.
        dataloader: Iterable yielding dictionaries with ``"image"``,
            ``"multilabel_target"`` and ``"binary_target"``.
        device: Device the images and the accumulators live on.

    Returns:
        Dictionary with ``"ml_probs"``, ``"ml_targets"``, ``"bin_probs"`` and
        ``"bin_targets"``, each concatenated over all batches. The two binary
        arrays have their trailing axis squeezed away.
    """
    ml_prob_chunks, ml_target_chunks = [], []
    bin_prob_chunks, bin_target_chunks = [], []
    num_labels = len(PATHOLOGY_LABELS)

    for batch in dataloader:
        images = batch["image"].to(device)
        views = (
            images,
            transforms_functional.hflip(images),
            transforms_functional.rotate(images, angle=7),
            transforms_functional.rotate(images, angle=-7),
        )
        num_views = len(views)
        batch_size = images.size(0)

        # Running totals of the per-view probabilities; the binary head is
        # accumulated as a single column.
        ml_total = torch.zeros(batch_size, num_labels, device=device)
        bin_total = torch.zeros(batch_size, 1, device=device)
        for view in views:
            outputs = model(view)
            ml_total += torch.sigmoid(outputs["multilabel_logits"])
            bin_total += torch.sigmoid(outputs["binary_logits"])

        ml_prob_chunks.append((ml_total / num_views).cpu().numpy())
        ml_target_chunks.append(batch["multilabel_target"].numpy())
        bin_prob_chunks.append((bin_total / num_views).cpu().numpy())
        bin_target_chunks.append(batch["binary_target"].numpy())

    ml_probs = np.concatenate(ml_prob_chunks)
    ml_targets = np.concatenate(ml_target_chunks)
    bin_probs = np.concatenate(bin_prob_chunks).squeeze(-1)
    bin_targets = np.concatenate(bin_target_chunks).squeeze(-1)
    return {
        "ml_probs": ml_probs,
        "ml_targets": ml_targets,
        "bin_probs": bin_probs,
        "bin_targets": bin_targets,
    }
def predict_ensemble(
    models: list[torch.nn.Module],
    dataloader: DataLoader,
    device: torch.device,
    use_tta: bool = True,
) -> dict[str, np.ndarray]:
    """Average the predictions of several models, optionally with TTA.

    Every model is run over the full ``dataloader`` separately and the
    resulting probability arrays are averaged element-wise. Because each model
    makes its own pass, the dataloader must yield samples in the same order
    every time (i.e. no shuffling); this is verified by comparing the targets
    returned by each pass.

    Args:
        models: Non-empty list of loaded, eval-mode models.
        dataloader: DataLoader over the evaluation split.
        device: Target device.
        use_tta: If True, apply TTA to each model before averaging.

    Returns:
        Dictionary with the averaged ``"ml_probs"`` and ``"bin_probs"`` plus
        the ``"ml_targets"`` and ``"bin_targets"`` of the first model's pass.

    Raises:
        ValueError: If ``models`` is empty, or if the targets seen by two
            models differ (the dataloader order is not deterministic).
    """
    if not models:
        raise ValueError("predict_ensemble requires at least one model")

    predict_fn = predict_with_tta if use_tta else predict
    all_results = [predict_fn(m, dataloader, device) for m in models]

    reference = all_results[0]
    # Averaging is only meaningful when row i refers to the same sample in
    # every pass, so refuse to combine misaligned predictions.
    for index, result in enumerate(all_results[1:], start=1):
        for key in ("ml_targets", "bin_targets"):
            if not np.array_equal(result[key], reference[key]):
                raise ValueError(
                    f"{key} of model {index} differ from those of model 0; "
                    "the dataloader must yield samples in a fixed order"
                )

    return {
        "ml_probs": np.mean([r["ml_probs"] for r in all_results], axis=0),
        "ml_targets": reference["ml_targets"],
        "bin_probs": np.mean([r["bin_probs"] for r in all_results], axis=0),
        "bin_targets": reference["bin_targets"],
    }
def compare_models(results: dict[str, dict], output_dir: Path) -> None:
    """Write a per-class AUC bar chart and a JSON summary for all result sets.

    Two files are created inside ``output_dir`` (which is created if needed):
    ``auc_comparison.png`` and ``comparison_summary.json``.

    Args:
        results: Maps a display name to a dictionary with two entries.
            ``"ml_metrics"`` must contain ``"auc_roc_macro"`` and
            ``"f1_macro"`` and may contain per-class ``"auc_<label>"`` values
            (missing ones are plotted as 0). ``"bin_metrics"`` must contain
            ``"binary_f1"`` and ``"binary_accuracy"`` and may contain
            ``"binary_auc_roc"`` (reported as 0 when missing).
        output_dir: Directory receiving the plot and the summary.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    model_names = list(results)
    n_models = len(model_names)

    # Per-class AUC comparison
    fig, ax = plt.subplots(figsize=(14, 6))
    x = np.arange(len(PATHOLOGY_LABELS))
    # Keep the 0.35 width for one or two result sets, but shrink the bars when
    # more are plotted so a group never spills into the next pathology's slot.
    width = min(0.35, 0.8 / max(n_models, 1))
    for i, name in enumerate(model_names):
        ml_metrics = results[name]["ml_metrics"]
        aucs = [ml_metrics.get(f"auc_{label}", 0) for label in PATHOLOGY_LABELS]
        ax.bar(x + i * width, aucs, width, label=name)
    ax.set_xlabel("Pathology")
    ax.set_ylabel("AUC-ROC")
    ax.set_title("Per-Class AUC-ROC Comparison")
    # Centre each tick under its whole group of bars, whatever the group size.
    ax.set_xticks(x + width * max(n_models - 1, 0) / 2)
    ax.set_xticklabels(PATHOLOGY_LABELS, rotation=45, ha="right")
    if model_names:
        # legend() warns when no labelled artist exists.
        ax.legend()
    ax.set_ylim(0, 1)
    # Operate on this figure explicitly instead of pyplot's "current figure".
    fig.tight_layout()
    fig.savefig(output_dir / "auc_comparison.png", dpi=150)
    plt.close(fig)

    # Summary table
    summary = {
        name: {
            "macro_auc": entry["ml_metrics"]["auc_roc_macro"],
            "macro_f1": entry["ml_metrics"]["f1_macro"],
            "binary_auc": entry["bin_metrics"].get("binary_auc_roc", 0),
            "binary_f1": entry["bin_metrics"]["binary_f1"],
            "binary_accuracy": entry["bin_metrics"]["binary_accuracy"],
        }
        for name, entry in results.items()
    }
    with open(output_dir / "comparison_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    logger.info("Comparison results saved to %s", output_dir)
    for name, metrics in summary.items():
        logger.info(" %s: Macro AUC=%.4f, Binary AUC=%.4f", name, metrics["macro_auc"], metrics["binary_auc"])
def _score_predictions(preds: dict[str, np.ndarray], threshold: float = 0.5) -> dict[str, dict]:
    """Threshold the probabilities in ``preds`` and compute both metric sets."""
    ml_metrics = compute_multilabel_metrics(
        preds["ml_targets"],
        (preds["ml_probs"] >= threshold).astype(int),
        preds["ml_probs"],
    )
    bin_metrics = compute_binary_metrics(
        preds["bin_targets"],
        (preds["bin_probs"] >= threshold).astype(int),
        preds["bin_probs"],
    )
    return {"ml_metrics": ml_metrics, "bin_metrics": bin_metrics}


def main() -> None:
    """Command-line entry point: evaluate every ``*_best.pth`` checkpoint.

    Each checkpoint in ``--model-dir`` is scored on the test split with
    standard inference and with TTA. When at least two checkpoints are loaded,
    an ensemble of all of them (with TTA) is scored as well. With ``--compare``
    and at least two result sets, plots and a JSON summary are written to
    ``--output-dir``; otherwise the headline metrics are only logged.
    """
    parser = argparse.ArgumentParser(description="Evaluate and compare CheXVision models")
    parser.add_argument("--model-dir", type=Path, default=Path("checkpoints"), help="Directory with model checkpoints")
    parser.add_argument("--data-dir", type=Path, default=Path("data"), help="Data directory")
    parser.add_argument("--output-dir", type=Path, default=Path("results"), help="Output directory for plots")
    parser.add_argument("--compare", action="store_true", help="Compare all models in model-dir")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fail loudly instead of exiting silently when there is nothing to evaluate.
    checkpoint_paths = sorted(args.model_dir.glob("*_best.pth"))
    if not checkpoint_paths:
        logger.warning("No '*_best.pth' checkpoints found in %s; nothing to evaluate", args.model_dir)
        return

    # Load test dataset. shuffle=False keeps the sample order identical across
    # the repeated passes made by the individual and ensemble evaluations.
    test_dataset = ChestXrayDataset(args.data_dir / "images", args.data_dir / "labels.csv", split="test")
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

    # Evaluate all checkpoints (standard + TTA)
    results: dict[str, dict] = {}
    loaded_models: list[torch.nn.Module] = []
    for ckpt_path in checkpoint_paths:
        logger.info("Evaluating %s", ckpt_path.name)
        model, config = load_model(ckpt_path, device)
        loaded_models.append(model)
        name = config["model"].get("name", ckpt_path.stem)
        if name in results:
            # Two checkpoints share a configured name; keep both result sets
            # instead of letting the later one overwrite the earlier one.
            name = f"{name} ({ckpt_path.stem})"

        # Standard inference
        standard = _score_predictions(predict(model, test_loader, device))
        results[name] = standard

        # TTA inference
        logger.info(" Running TTA for %s …", name)
        with_tta = _score_predictions(predict_with_tta(model, test_loader, device))
        results[f"{name} + TTA"] = with_tta
        logger.info(
            " %s: AUC %.4f → TTA %.4f",
            name,
            standard["ml_metrics"]["auc_roc_macro"],
            with_tta["ml_metrics"]["auc_roc_macro"],
        )

    # Ensemble (only when at least two models were loaded)
    if len(loaded_models) >= 2:
        logger.info("Running ensemble (all models + TTA) …")
        ensemble = _score_predictions(predict_ensemble(loaded_models, test_loader, device, use_tta=True))
        results["Ensemble (CNN + DenseNet + TTA)"] = ensemble
        logger.info(
            " Ensemble: Macro AUC=%.4f, Binary AUC=%.4f",
            ensemble["ml_metrics"]["auc_roc_macro"],
            ensemble["bin_metrics"].get("binary_auc_roc", 0),
        )

    if args.compare and len(results) >= 2:
        compare_models(results, args.output_dir)
    elif results:
        for name, r in results.items():
            logger.info(
                "%s — Macro AUC: %.4f, Binary AUC: %.4f",
                name,
                r["ml_metrics"]["auc_roc_macro"],
                r["bin_metrics"].get("binary_auc_roc", 0),
            )


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()