File size: 11,901 Bytes
f6ab35f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
from __future__ import annotations

import json
import os
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms

from .config import (
    CLASS_DISPLAY_NAMES,
    CLASS_NAMES,
    ENSEMBLE_MEMBERS,
    IMAGE_SIZE,
    MODELS_DIR,
    NORMALIZE_MEAN,
    NORMALIZE_STD,
    SELECTED_ENSEMBLE_PATH,
)


def _env_flag(name: str, default: bool = True) -> bool:
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() not in {"0", "false", "no", "off"}


STRICT_CHECKPOINT_LOADING = _env_flag("STRICT_CHECKPOINT_LOADING", True)


@dataclass
class LoadedMember:
    """One ensemble member: its configuration entry plus the built model.

    Instances are populated by ``load_ensemble`` from an ``ENSEMBLE_MEMBERS``
    entry. Field order defines the positional constructor signature.
    """

    # Member identifier copied from the ENSEMBLE_MEMBERS entry.
    member: str
    # Human-readable name shown in result tables and error messages.
    display_name: str
    # Backbone name passed to build_model, e.g. "efficientnet_b0".
    model_name: str
    # NOTE(review): presumably the training seed of this member — confirm.
    seed: int
    # Multiplier applied to this member's softmax probabilities in the ensemble.
    weight: float
    # Checkpoint filename relative to MODELS_DIR.
    checkpoint_file: str
    # Resolved location MODELS_DIR / checkpoint_file.
    checkpoint_path: Path
    # Network with the checkpoint weights loaded.
    model: nn.Module


@dataclass
class PredictionResult:
    """Everything ``predict`` returns for a single input image.

    Field order defines the positional constructor signature.
    """

    # CLASS_NAMES key with the highest ensemble probability.
    predicted_class: str
    # CLASS_DISPLAY_NAMES entry for predicted_class.
    predicted_display: str
    # Ensemble probability of predicted_class.
    confidence: float
    # Class key -> ensemble probability, in CLASS_NAMES order.
    probabilities: dict[str, float]
    # Columns "class", "probability", "percent"; sorted by probability, descending.
    probability_df: pd.DataFrame
    # One row per member: "member", "weight", "member prediction", "member confidence".
    member_df: pd.DataFrame
    # Log of the (clamped) ensemble probabilities with a leading batch axis of 1.
    ensemble_logits: torch.Tensor
    # Preprocessed input batch kept on the CPU (reused for Grad-CAM).
    input_tensor: torch.Tensor


# Shared inference-time transform. Resize to an exact IMAGE_SIZE x IMAGE_SIZE square
# (aspect ratio is not preserved), convert the PIL image to a tensor, then normalize
# with NORMALIZE_MEAN / NORMALIZE_STD. The order matters: Normalize operates on
# the tensor that ToTensor produces.
# NOTE(review): presumably mirrors the training-time eval transform — confirm.
_preprocess = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ]
)


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalized model input with a batch axis of 1.

    The image is converted to RGB first, so grayscale or RGBA uploads work.

    Raises:
        ValueError: if no image was supplied (``image`` is ``None``).
    """
    if image is not None:
        tensor = _preprocess(image.convert("RGB"))
        return tensor.unsqueeze(0)
    raise ValueError("Please upload an MRI image first.")


def build_model(model_name: str, num_classes: int = len(CLASS_NAMES)) -> nn.Module:
    """Build an untrained deployment backbone with a ``num_classes``-way head.

    Args:
        model_name: torchvision architecture name; ``"efficientnet_b0"`` or
            ``"mobilenet_v3_small"``.
        num_classes: number of outputs of the replacement final layer.

    Returns:
        The model with randomly initialised weights. The caller is expected to
        load the fine-tuned checkpoint into it.

    Raises:
        ValueError: if ``model_name`` is unsupported, or the architecture does not
            end its ``classifier`` with an ``nn.Linear``.
    """
    # One table drives both construction and head replacement. Previously, a second
    # hard-coded name set had to be kept in sync with this dict.
    constructors = {
        "efficientnet_b0": models.efficientnet_b0,
        "mobilenet_v3_small": models.mobilenet_v3_small,
    }
    constructor = constructors.get(model_name)
    if constructor is None:
        raise ValueError(f"Unsupported deployment backbone: {model_name}")

    # Do not request torchvision pretrained weights at Space startup. The fine-tuned
    # checkpoint is expected to contain the trained weights.
    model = constructor(weights=None)

    # Both supported backbones end ``model.classifier`` with an nn.Linear. Verify
    # that assumption instead of relying on an unreachable else-branch.
    head = model.classifier[-1]
    if not isinstance(head, nn.Linear):
        raise ValueError(f"No classifier replacement rule for {model_name}")
    model.classifier[-1] = nn.Linear(head.in_features, num_classes)
    return model


def _torch_load(path: Path) -> Any:
    """Load a PyTorch checkpoint across torch versions.

    Newer PyTorch versions may support weights_only. We first try the safer path,
    then fall back for older checkpoints that store a richer dictionary.
    """
    try:
        return torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:
        return torch.load(path, map_location="cpu")
    except Exception:
        # Only use this fallback for your own trusted checkpoints.
        return torch.load(path, map_location="cpu", weights_only=False)


def clean_state_dict(checkpoint: Any) -> dict[str, torch.Tensor]:
    """Extract a flat ``{parameter_name: tensor}`` mapping from a loaded checkpoint.

    Accepted inputs:
      * an ``nn.Module`` (its ``state_dict()`` is used);
      * a raw state_dict;
      * a wrapper dict holding the weights under the first present key among
        ``model_state_dict``, ``state_dict``, ``model``, ``net``, ``weights``.
        The wrapped value may be a dict or a whole ``nn.Module`` (as produced by
        ``torch.save({"model": model, ...})``).

    Non-tensor entries are dropped. Leading ``module.`` and then ``model.`` key
    prefixes are stripped (``module.`` is what ``nn.DataParallel`` adds).

    Raises:
        TypeError: if no dict-like state_dict can be found.
        ValueError: if the state_dict contains no tensors.
    """
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()

    if isinstance(checkpoint, dict):
        for key in ("model_state_dict", "state_dict", "model", "net", "weights"):
            value = checkpoint.get(key)
            if isinstance(value, nn.Module):
                # Whole module saved inside a wrapper dict; without this the wrapper
                # itself was scanned and produced "No tensor weights".
                value = value.state_dict()
            if isinstance(value, dict):
                checkpoint = value
                break

    if not isinstance(checkpoint, dict):
        raise TypeError("Checkpoint does not contain a PyTorch state_dict-like object.")

    cleaned: dict[str, torch.Tensor] = {}
    for key, value in checkpoint.items():
        if not torch.is_tensor(value):
            # Skip bookkeeping entries such as epoch counters or config objects.
            continue
        new_key = str(key)
        for prefix in ("module.", "model."):
            if new_key.startswith(prefix):
                new_key = new_key[len(prefix) :]
        cleaned[new_key] = value

    if not cleaned:
        raise ValueError("No tensor weights were found in the checkpoint.")
    return cleaned


def expected_checkpoint_paths() -> dict[str, Path]:
    """Map each ensemble member's checkpoint filename to its path under MODELS_DIR."""
    paths: dict[str, Path] = {}
    for entry in ENSEMBLE_MEMBERS:
        filename = entry["checkpoint_file"]
        paths[filename] = MODELS_DIR / filename
    return paths


def diagnose_checkpoints() -> tuple[bool, pd.DataFrame, str]:
    """Report which ensemble checkpoint files are present in MODELS_DIR.

    Returns:
        ``(all_present, table, message)``:
        * ``all_present`` is whether every member's checkpoint path exists;
        * ``table`` has one row per member with columns member, weight,
          expected file and status;
        * ``message`` is a Markdown summary that lists any missing files.
    """
    table_rows = []
    missing_files = []
    for entry in ENSEMBLE_MEMBERS:
        filename = entry["checkpoint_file"]
        found = (MODELS_DIR / filename).exists()
        shown_path = f"models/{filename}"
        if not found:
            missing_files.append(shown_path)
        table_rows.append(
            {
                "member": entry["display_name"],
                "weight": round(float(entry["weight"]), 8),
                "expected file": shown_path,
                "status": "✅ found" if found else "❌ missing",
            }
        )

    if missing_files:
        bullet_list = "\n".join(f"- `{path}`" for path in missing_files)
        summary = "❌ Missing checkpoint file(s):\n" + bullet_list
    else:
        summary = "✅ All required checkpoint files were found in `models/`."
    return not missing_files, pd.DataFrame(table_rows), summary


def _load_selected_metadata() -> dict[str, Any]:
    """Return the parsed SELECTED_ENSEMBLE_PATH JSON, or ``{}`` if the file is absent."""
    if not SELECTED_ENSEMBLE_PATH.exists():
        return {}
    raw_text = SELECTED_ENSEMBLE_PATH.read_text(encoding="utf-8")
    return json.loads(raw_text)


@lru_cache(maxsize=1)
def load_ensemble() -> tuple[list[LoadedMember], torch.device, dict[str, Any]]:
    """Build every ensemble member, load its checkpoint, and cache the result.

    The result is memoised by ``lru_cache``, so models are loaded once per
    process. A call that raises is not cached and will be retried next time.

    Returns:
        ``(members, device, metadata)``. ``members`` follows ENSEMBLE_MEMBERS
        order. ``device`` is CUDA when available, else CPU. ``metadata`` is the
        selected-ensemble JSON, or ``{}`` if that file is absent.

    Raises:
        FileNotFoundError: if any checkpoint file is missing. The message lists
            the missing files.
    """
    checkpoints_ok, _table, missing_report = diagnose_checkpoints()
    if not checkpoints_ok:
        raise FileNotFoundError(missing_report)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    members: list[LoadedMember] = []
    for spec in ENSEMBLE_MEMBERS:
        ckpt_path = MODELS_DIR / spec["checkpoint_file"]
        network = build_model(spec["model_name"], len(CLASS_NAMES))
        weights = clean_state_dict(_torch_load(ckpt_path))
        # strict is controlled by the STRICT_CHECKPOINT_LOADING env flag.
        network.load_state_dict(weights, strict=STRICT_CHECKPOINT_LOADING)
        network.eval().to(device)
        members.append(
            LoadedMember(
                member=spec["member"],
                display_name=spec["display_name"],
                model_name=spec["model_name"],
                seed=int(spec["seed"]),
                weight=float(spec["weight"]),
                checkpoint_file=spec["checkpoint_file"],
                checkpoint_path=ckpt_path,
                model=network,
            )
        )
    metadata = _load_selected_metadata()
    return members, device, metadata


def predict(image: Image.Image) -> PredictionResult:
    """Classify *image* with the weighted ensemble.

    Each member's softmax probabilities are multiplied by its weight and summed.
    The sum is renormalized to 1, and the top class is the prediction. A table
    of per-member predictions is returned alongside the ensemble result.

    Raises:
        ValueError: if ``image`` is ``None`` (raised by ``preprocess_image``).
        RuntimeError: if the ensemble has no members.
    """
    members, device, _metadata = load_ensemble()
    x_cpu = preprocess_image(image)
    x = x_cpu.to(device)

    weighted_sum = None
    member_rows = []
    with torch.inference_mode():
        for member in members:
            member_probs = F.softmax(member.model(x), dim=1)
            contribution = member_probs * member.weight
            if weighted_sum is None:
                weighted_sum = contribution
            else:
                weighted_sum = weighted_sum + contribution

            member_vec = member_probs.squeeze(0).detach().cpu().numpy()
            best = int(np.argmax(member_vec))
            member_rows.append(
                {
                    "member": member.display_name,
                    "weight": round(member.weight, 8),
                    "member prediction": CLASS_DISPLAY_NAMES[CLASS_NAMES[best]],
                    "member confidence": round(float(member_vec[best]), 6),
                }
            )

    if weighted_sum is None:
        raise RuntimeError("No ensemble members were loaded.")

    ensemble_vec = weighted_sum.squeeze(0).detach().cpu().numpy()
    # Weights are expected to be normalized already; renormalize defensively and
    # guard against a zero sum.
    ensemble_vec = ensemble_vec / max(float(ensemble_vec.sum()), 1e-12)
    top_idx = int(np.argmax(ensemble_vec))
    predicted_class = CLASS_NAMES[top_idx]

    probability_table = pd.DataFrame(
        [
            {
                "class": CLASS_DISPLAY_NAMES[label],
                "probability": float(p),
                "percent": f"{100.0 * float(p):.2f}%",
            }
            for label, p in zip(CLASS_NAMES, ensemble_vec)
        ]
    )
    probability_table = probability_table.sort_values("probability", ascending=False)
    probability_table = probability_table.reset_index(drop=True)

    # Clamp before the log so zero probabilities do not become -inf.
    log_probs = np.log(np.maximum(ensemble_vec, 1e-12))
    return PredictionResult(
        predicted_class=predicted_class,
        predicted_display=CLASS_DISPLAY_NAMES[predicted_class],
        confidence=float(ensemble_vec[top_idx]),
        probabilities={label: float(p) for label, p in zip(CLASS_NAMES, ensemble_vec)},
        probability_df=probability_table,
        member_df=pd.DataFrame(member_rows),
        ensemble_logits=torch.from_numpy(log_probs).unsqueeze(0),
        input_tensor=x_cpu,
    )


def get_target_layer(model: nn.Module, model_name: str) -> nn.Module:
    """Return the module whose activations Grad-CAM should hook for *model_name*.

    Both deployed torchvision backbones keep their convolutional trunk in
    ``model.features``. The last entry of that trunk is used.

    Raises:
        ValueError: if no Grad-CAM layer is configured for *model_name*.
    """
    if model_name not in ("efficientnet_b0", "mobilenet_v3_small"):
        raise ValueError(f"No Grad-CAM layer configured for {model_name}")
    return model.features[-1]


def gradcam_for_member(member: LoadedMember, x_cpu: torch.Tensor, target_index: int, output_size: tuple[int, int]) -> np.ndarray:
    """Compute a Grad-CAM heatmap for one ensemble member.

    Args:
        member: loaded ensemble member. ``get_target_layer`` picks the hooked
            layer from its ``model`` and ``model_name``.
        x_cpu: preprocessed input batch; moved to the model's device here.
        target_index: index of the class logit to explain.
        output_size: ``(height, width)`` of the returned map.

    Returns:
        float32 array of shape ``output_size``. It is shifted so its minimum is 0
        and divided by its maximum when that maximum exceeds 1e-8.

    Raises:
        RuntimeError: if the hooks captured no activations or gradients.
    """
    device = next(member.model.parameters()).device
    x = x_cpu.to(device)
    activations: list[torch.Tensor] = []
    gradients: list[torch.Tensor] = []

    target_layer = get_target_layer(member.model, member.model_name)

    def forward_hook(_module, _inputs, output):
        activations.append(output.detach())

    def backward_hook(_module, _grad_input, grad_output):
        gradients.append(grad_output[0].detach())

    handle_fwd = target_layer.register_forward_hook(forward_hook)
    handle_bwd = target_layer.register_full_backward_hook(backward_hook)
    try:
        member.model.zero_grad(set_to_none=True)
        logits = member.model(x)
        score = logits[0, target_index]
        score.backward()
    finally:
        handle_fwd.remove()
        handle_bwd.remove()
        # Only the hooked layer's output gradient is needed. Release the parameter
        # gradients backward() populated so they do not stay allocated on the
        # cached, shared model between requests.
        member.model.zero_grad(set_to_none=True)

    if not activations or not gradients:
        raise RuntimeError(f"Could not collect gradients for {member.display_name}.")

    acts = activations[-1]
    grads = gradients[-1]
    # Grad-CAM channel weights: spatial mean of the target score's gradient.
    weights = grads.mean(dim=(2, 3), keepdim=True)
    cam = torch.relu((weights * acts).sum(dim=1, keepdim=True))
    cam = F.interpolate(cam, size=output_size, mode="bilinear", align_corners=False)
    # Select the single batch/channel explicitly. squeeze() would also drop a
    # spatial axis of size 1 and return the wrong shape for 1-pixel-wide/tall maps.
    cam_np = cam[0, 0].detach().cpu().numpy()
    cam_np = cam_np - cam_np.min()
    denom = cam_np.max()
    if denom > 1e-8:
        cam_np = cam_np / denom
    return cam_np.astype(np.float32)


def weighted_ensemble_cam(image: Image.Image, target_class: str) -> Image.Image:
    """Render a weight-averaged Grad-CAM overlay for *target_class* on *image*.

    Each member's Grad-CAM map is combined using the member weights and rescaled
    to [0, 1]. The result is colored with the "magma" colormap and blended as
    ``0.58 * image + 0.42 * heatmap``. Members whose Grad-CAM fails are skipped,
    because the heatmap is interpretability assistance, not the prediction itself.

    Raises:
        ValueError: if ``image`` is ``None`` (same message as ``predict``) or
            ``target_class`` is not in ``CLASS_NAMES``.
        RuntimeError: if no member produced a usable map. The last member error,
            if any, is attached as ``__cause__``.
    """
    members, _device, _metadata = load_ensemble()
    # preprocess_image rejects a missing upload with the user-facing ValueError.
    # Call it before touching ``image`` so None does not surface as AttributeError.
    x_cpu = preprocess_image(image)
    rgb = image.convert("RGB")
    target_index = CLASS_NAMES.index(target_class)
    width, height = rgb.size

    combined = np.zeros((height, width), dtype=np.float32)
    total_weight = 0.0
    last_error = None
    for member in members:
        try:
            cam = gradcam_for_member(member, x_cpu, target_index, output_size=(height, width))
            combined += cam * float(member.weight)
            total_weight += float(member.weight)
        except Exception as err:
            # Best effort: keep going if one member fails, but remember why so the
            # final error is diagnosable when every member fails.
            last_error = err
            continue

    if total_weight <= 0:
        raise RuntimeError("Could not generate Grad-CAM for any ensemble member.") from last_error
    combined = combined / total_weight
    combined = combined - combined.min()
    peak = combined.max()
    if peak > 1e-8:
        combined = combined / peak

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9. Prefer the colormap
    # registry (Matplotlib >= 3.5) and fall back for older releases.
    try:
        from matplotlib import colormaps

        colormap = colormaps["magma"]
    except ImportError:
        import matplotlib.cm as cm

        colormap = cm.get_cmap("magma")

    base = np.asarray(rgb).astype(np.float32) / 255.0
    # Drop the alpha channel of the RGBA colormap output.
    heat = colormap(combined)[..., :3].astype(np.float32)
    overlay = np.clip(0.58 * base + 0.42 * heat, 0, 1)
    return Image.fromarray((overlay * 255).astype(np.uint8))