File size: 10,830 Bytes
7f1af80
 
 
 
 
 
 
c1969c3
 
 
 
 
 
 
 
 
 
 
7f1af80
c1969c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91347e8
c1969c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f1af80
c1969c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f1af80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1969c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f1af80
 
 
 
c1969c3
 
 
7f1af80
 
 
c1969c3
7f1af80
 
c1969c3
 
 
 
7f1af80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1969c3
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
"""Model evaluation and comparison script.

Supports three inference modes:
- Standard: single forward pass per image.
- TTA: average over 4 augmented views (original + h-flip + rotate ±7°).
- Ensemble + TTA: average TTA predictions from all loaded models.
"""

from __future__ import annotations

import argparse
import json
import logging
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.functional as transforms_functional
from torch.utils.data import DataLoader

from src.data.dataset import PATHOLOGY_LABELS, ChestXrayDataset
from src.models.densenet_transfer import CheXVisionDenseNet
from src.models.scratch_cnn import CheXVisionScratch
from src.training.metrics import compute_binary_metrics, compute_multilabel_metrics
from src.training.trainer import set_seed

# Module-level logger named after this module; level and format are set up by
# the logging.basicConfig call in main().
logger = logging.getLogger(__name__)


def load_model(checkpoint_path: Path, device: torch.device) -> tuple[torch.nn.Module, dict]:
    """Rebuild a trained model from a checkpoint file and put it in eval mode.

    The checkpoint must contain a ``"config"`` entry (whose ``"model"`` section
    has a ``"type"`` and, optionally, an ``"architecture"`` mapping) and a
    ``"model_state_dict"`` entry.

    Args:
        checkpoint_path: Path of the checkpoint to read.
        device: Device the tensors are mapped to and the model is moved to.

    Returns:
        A ``(model, config)`` pair: the model with its weights loaded, moved to
        ``device`` and switched to eval mode, plus the config stored in the
        checkpoint.

    Raises:
        KeyError: If one of the required checkpoint/config entries is missing.
    """
    # NOTE(review): weights_only=False performs a full unpickle, which can run
    # arbitrary code. Only load checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = checkpoint["config"]

    model_cfg = config["model"]
    model_type = model_cfg["type"]
    # The architecture section is optional; every hyper-parameter has a default.
    arch = model_cfg.get("architecture", {})

    model: torch.nn.Module
    if model_type != "scratch":
        # Any type other than "scratch" is treated as the DenseNet variant.
        # pretrained=False: the weights come from the checkpoint's state dict.
        model = CheXVisionDenseNet(
            num_classes=14,
            pretrained=False,
            dropout=arch.get("dropout", 0.3),
        )
    else:
        blocks = tuple(arch.get("block_config", [2, 2, 2, 2]))
        filters = tuple(arch.get("filter_sizes", [64, 128, 256, 512]))
        model = CheXVisionScratch(
            in_channels=3,
            num_classes=14,
            block_config=blocks,
            filter_sizes=filters,
            dropout=arch.get("dropout", 0.5),
        )

    state_dict = checkpoint["model_state_dict"]
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, config


@torch.no_grad()
def predict(model: torch.nn.Module, dataloader: DataLoader, device: torch.device) -> dict[str, np.ndarray]:
    """Run plain inference: one forward pass per batch, no augmentation.

    Each batch must provide ``"image"``, ``"multilabel_target"`` and
    ``"binary_target"``; the model must return a mapping with
    ``"multilabel_logits"`` and ``"binary_logits"``.

    Args:
        model: Model to evaluate (called once per batch).
        dataloader: Iterable of batches.
        device: Device the images are moved to before the forward pass.

    Returns:
        Dict with ``ml_probs``, ``ml_targets``, ``bin_probs`` and
        ``bin_targets``, each concatenated over all batches. The two ``bin_*``
        arrays have their trailing axis squeezed away.

    Raises:
        ValueError: If the dataloader yields no batches (nothing to
            concatenate) or the binary arrays' last axis is not of size 1.
    """
    collected: dict[str, list[np.ndarray]] = {
        "ml_probs": [],
        "ml_targets": [],
        "bin_probs": [],
        "bin_targets": [],
    }

    for batch in dataloader:
        outputs = model(batch["image"].to(device))

        # Logits -> independent per-label probabilities.
        ml_probs = torch.sigmoid(outputs["multilabel_logits"])
        collected["ml_probs"].append(ml_probs.cpu().numpy())
        collected["ml_targets"].append(batch["multilabel_target"].numpy())

        bin_probs = torch.sigmoid(outputs["binary_logits"])
        collected["bin_probs"].append(bin_probs.cpu().numpy())
        collected["bin_targets"].append(batch["binary_target"].numpy())

    stacked = {key: np.concatenate(chunks) for key, chunks in collected.items()}
    # Drop the trailing axis of the binary head so callers get flat vectors.
    for key in ("bin_probs", "bin_targets"):
        stacked[key] = stacked[key].squeeze(-1)
    return stacked


@torch.no_grad()
def predict_with_tta(
    model: torch.nn.Module,
    dataloader: DataLoader,
    device: torch.device,
) -> dict[str, np.ndarray]:
    """Run inference with Test-Time Augmentation (TTA).

    Averages the sigmoid probabilities over 4 views of each image batch:
      1. Original
      2. Horizontal flip
      3. Rotate +7°
      4. Rotate -7°

    NOTE(review): the flip assumes left/right mirroring does not change the
    labels — confirm for laterality-sensitive findings.

    Each batch must provide ``"image"``, ``"multilabel_target"`` and
    ``"binary_target"``; the model must return a mapping with
    ``"multilabel_logits"`` and ``"binary_logits"``.

    Args:
        model: Model to evaluate (called once per view per batch).
        dataloader: Iterable of batches.
        device: Device the images are moved to before the forward passes.

    Returns:
        Dict with ``ml_probs``, ``ml_targets``, ``bin_probs`` and
        ``bin_targets``, each concatenated over all batches. The two ``bin_*``
        arrays have their trailing axis squeezed away.

    Raises:
        ValueError: If the dataloader yields no batches (nothing to
            concatenate) or the binary arrays' last axis is not of size 1.
    """
    # Views are produced one at a time inside the loop so only a single
    # augmented copy of the batch is alive at once (the list itself is
    # loop-invariant, hence built once here).
    view_fns = (
        lambda x: x,
        transforms_functional.hflip,
        lambda x: transforms_functional.rotate(x, angle=7),
        lambda x: transforms_functional.rotate(x, angle=-7),
    )

    all_ml_probs, all_ml_targets = [], []
    all_bin_probs, all_bin_targets = [], []

    for batch in dataloader:
        images = batch["image"].to(device)

        # Accumulators take their shape from the model output instead of a
        # hard-coded (batch, n_labels) / (batch, 1) layout.
        ml_sum = None
        bin_sum = None

        for make_view in view_fns:
            out = model(make_view(images))
            # .float(): keep accumulating in float32 regardless of model dtype.
            ml_probs = torch.sigmoid(out["multilabel_logits"]).float()
            bin_probs = torch.sigmoid(out["binary_logits"]).float()
            ml_sum = ml_probs if ml_sum is None else ml_sum + ml_probs
            bin_sum = bin_probs if bin_sum is None else bin_sum + bin_probs

        n_views = len(view_fns)
        all_ml_probs.append((ml_sum / n_views).cpu().numpy())
        all_ml_targets.append(batch["multilabel_target"].numpy())
        all_bin_probs.append((bin_sum / n_views).cpu().numpy())
        all_bin_targets.append(batch["binary_target"].numpy())

    return {
        "ml_probs": np.concatenate(all_ml_probs),
        "ml_targets": np.concatenate(all_ml_targets),
        "bin_probs": np.concatenate(all_bin_probs).squeeze(-1),
        "bin_targets": np.concatenate(all_bin_targets).squeeze(-1),
    }


def predict_ensemble(
    models: list[torch.nn.Module],
    dataloader: DataLoader,
    device: torch.device,
    use_tta: bool = True,
) -> dict[str, np.ndarray]:
    """Average predictions from multiple models (ensemble), optionally with TTA.

    Every model is run over the whole dataloader independently and the
    resulting probability arrays are averaged element-wise with equal weight.
    This only makes sense if each pass visits the samples in the same order,
    so the targets collected for each model are checked against the first
    model's targets.

    Args:
        models: List of loaded, eval-mode models (at least one).
        dataloader: DataLoader over the evaluation split. Must yield samples
            in a deterministic order (i.e. no shuffling).
        device: Target device.
        use_tta: If True, apply TTA to each model before averaging.

    Returns:
        Dict with the averaged ``ml_probs`` / ``bin_probs`` and the shared
        ``ml_targets`` / ``bin_targets``.

    Raises:
        ValueError: If ``models`` is empty, or if the targets seen by
            different models differ (the dataloader order is not stable).
    """
    if not models:
        raise ValueError("predict_ensemble requires at least one model")

    predict_fn = predict_with_tta if use_tta else predict
    all_results = [predict_fn(m, dataloader, device) for m in models]

    reference = all_results[0]
    # Guard against silently averaging predictions of different samples, which
    # is what a shuffling dataloader would cause.
    for idx, result in enumerate(all_results[1:], start=1):
        for key in ("ml_targets", "bin_targets"):
            if not np.array_equal(result[key], reference[key], equal_nan=True):
                raise ValueError(
                    f"{key} for model {idx} differ from model 0; the dataloader must "
                    "yield samples in the same order on every pass (disable shuffling)"
                )

    return {
        "ml_probs": np.mean([r["ml_probs"] for r in all_results], axis=0),
        "ml_targets": reference["ml_targets"],
        "bin_probs": np.mean([r["bin_probs"] for r in all_results], axis=0),
        "bin_targets": reference["bin_targets"],
    }


def compare_models(results: dict[str, dict], output_dir: Path) -> None:
    """Generate the per-class AUC comparison plot and a JSON summary.

    Writes two files into ``output_dir`` (created if missing):
      - ``auc_comparison.png``: grouped bar chart, one group per pathology
        and one bar per entry in ``results``.
      - ``comparison_summary.json``: headline metrics per entry.

    Args:
        results: Maps a display name to ``{"ml_metrics": ..., "bin_metrics": ...}``.
            ``ml_metrics`` must contain ``auc_roc_macro`` and ``f1_macro``
            (per-class ``auc_<label>`` entries default to 0 when absent);
            ``bin_metrics`` must contain ``binary_f1`` and ``binary_accuracy``
            (``binary_auc_roc`` defaults to 0 when absent).
        output_dir: Directory the plot and summary are written to.

    Raises:
        KeyError: If a required metric is missing from an entry.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    model_names = list(results)
    n_models = len(model_names)

    # Per-class AUC comparison
    fig, ax = plt.subplots(figsize=(14, 6))
    x = np.arange(len(PATHOLOGY_LABELS))
    # Groups sit one unit apart. Split a fixed 0.8-wide slot between all
    # entries so bars never spill into the neighbouring pathology, however
    # many models are compared.
    width = 0.8 / max(n_models, 1)

    for i, name in enumerate(model_names):
        ml_metrics = results[name]["ml_metrics"]
        aucs = [ml_metrics.get(f"auc_{label}", 0) for label in PATHOLOGY_LABELS]
        ax.bar(x + i * width, aucs, width, label=name)

    ax.set_xlabel("Pathology")
    ax.set_ylabel("AUC-ROC")
    ax.set_title("Per-Class AUC-ROC Comparison")
    # Centre each tick under its group of bars.
    ax.set_xticks(x + width * max(n_models - 1, 0) / 2)
    ax.set_xticklabels(PATHOLOGY_LABELS, rotation=45, ha="right")
    if model_names:
        ax.legend()
    ax.set_ylim(0, 1)
    fig.tight_layout()
    fig.savefig(output_dir / "auc_comparison.png", dpi=150)
    plt.close(fig)

    # Summary table. Values are cast to built-in float so NumPy scalar types
    # returned by the metric functions remain JSON-serialisable.
    summary = {}
    for name in model_names:
        ml_metrics = results[name]["ml_metrics"]
        bin_metrics = results[name]["bin_metrics"]
        summary[name] = {
            "macro_auc": float(ml_metrics["auc_roc_macro"]),
            "macro_f1": float(ml_metrics["f1_macro"]),
            "binary_auc": float(bin_metrics.get("binary_auc_roc", 0)),
            "binary_f1": float(bin_metrics["binary_f1"]),
            "binary_accuracy": float(bin_metrics["binary_accuracy"]),
        }

    with open(output_dir / "comparison_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    logger.info("Comparison results saved to %s", output_dir)
    for name, metrics in summary.items():
        logger.info("  %s: Macro AUC=%.4f, Binary AUC=%.4f", name, metrics["macro_auc"], metrics["binary_auc"])


def _score_predictions(preds: dict[str, np.ndarray], threshold: float = 0.5) -> dict[str, dict]:
    """Binarise probabilities at ``threshold`` and compute both metric sets.

    Args:
        preds: Output of ``predict`` / ``predict_with_tta`` / ``predict_ensemble``
            (keys ``ml_probs``, ``ml_targets``, ``bin_probs``, ``bin_targets``).
        threshold: Probabilities ``>= threshold`` count as positive.

    Returns:
        ``{"ml_metrics": ..., "bin_metrics": ...}`` as consumed by ``compare_models``.
    """
    ml_metrics = compute_multilabel_metrics(
        preds["ml_targets"], (preds["ml_probs"] >= threshold).astype(int), preds["ml_probs"]
    )
    bin_metrics = compute_binary_metrics(
        preds["bin_targets"], (preds["bin_probs"] >= threshold).astype(int), preds["bin_probs"]
    )
    return {"ml_metrics": ml_metrics, "bin_metrics": bin_metrics}


def main() -> None:
    """Command-line entry point: evaluate every ``*_best.pth`` checkpoint.

    For each checkpoint in ``--model-dir`` the test split is scored with
    standard inference and with TTA. When at least two checkpoints are loaded,
    a TTA ensemble of all of them is scored as well. With ``--compare`` (and at
    least two result entries) plots and a JSON summary are written to
    ``--output-dir``; otherwise the headline metrics are only logged.
    """
    parser = argparse.ArgumentParser(description="Evaluate and compare CheXVision models")
    parser.add_argument("--model-dir", type=Path, default=Path("checkpoints"), help="Directory with model checkpoints")
    parser.add_argument("--data-dir", type=Path, default=Path("data"), help="Data directory")
    parser.add_argument("--output-dir", type=Path, default=Path("results"), help="Output directory for plots")
    parser.add_argument("--compare", action="store_true", help="Compare all models in model-dir")
    parser.add_argument("--batch-size", type=int, default=32, help="Evaluation batch size")
    parser.add_argument("--num-workers", type=int, default=4, help="Number of DataLoader worker processes")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    set_seed(42)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load test dataset. shuffle=False keeps the sample order identical across
    # passes, which the ensemble averaging relies on.
    test_dataset = ChestXrayDataset(args.data_dir / "images", args.data_dir / "labels.csv", split="test")
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

    # Evaluate all checkpoints (standard + TTA)
    results: dict[str, dict] = {}
    loaded_models: list[torch.nn.Module] = []

    for ckpt_path in sorted(args.model_dir.glob("*_best.pth")):
        logger.info("Evaluating %s", ckpt_path.name)
        model, config = load_model(ckpt_path, device)
        loaded_models.append(model)

        # Fall back to the file stem when the config carries no display name.
        name = config["model"].get("name", ckpt_path.stem)

        # Standard inference
        standard = _score_predictions(predict(model, test_loader, device))
        results[name] = standard

        # TTA inference
        logger.info("  Running TTA for %s …", name)
        with_tta = _score_predictions(predict_with_tta(model, test_loader, device))
        results[f"{name} + TTA"] = with_tta
        logger.info(
            "  %s: AUC %.4f → TTA %.4f",
            name,
            standard["ml_metrics"]["auc_roc_macro"],
            with_tta["ml_metrics"]["auc_roc_macro"],
        )

    if not results:
        # Previously this case ended silently, which looked like a successful run.
        logger.warning("No checkpoints matching '*_best.pth' found in %s", args.model_dir)
        return

    # Ensemble (only when at least two models were loaded)
    if len(loaded_models) >= 2:
        logger.info("Running ensemble (all models + TTA) …")
        ensemble = _score_predictions(predict_ensemble(loaded_models, test_loader, device, use_tta=True))
        results["Ensemble (CNN + DenseNet + TTA)"] = ensemble
        logger.info(
            "  Ensemble: Macro AUC=%.4f, Binary AUC=%.4f",
            ensemble["ml_metrics"]["auc_roc_macro"],
            ensemble["bin_metrics"].get("binary_auc_roc", 0),
        )

    if args.compare and len(results) >= 2:
        compare_models(results, args.output_dir)
    else:
        for name, r in results.items():
            logger.info("%s — Macro AUC: %.4f, Binary AUC: %.4f", name, r["ml_metrics"]["auc_roc_macro"], r["bin_metrics"].get("binary_auc_roc", 0))


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()