from __future__ import annotations import json import os from dataclasses import dataclass from functools import lru_cache from pathlib import Path from typing import Any import numpy as np import pandas as pd import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image from torchvision import models, transforms from .config import ( CLASS_DISPLAY_NAMES, CLASS_NAMES, ENSEMBLE_MEMBERS, IMAGE_SIZE, MODELS_DIR, NORMALIZE_MEAN, NORMALIZE_STD, SELECTED_ENSEMBLE_PATH, ) def _env_flag(name: str, default: bool = True) -> bool: raw = os.getenv(name) if raw is None: return default return raw.strip().lower() not in {"0", "false", "no", "off"} STRICT_CHECKPOINT_LOADING = _env_flag("STRICT_CHECKPOINT_LOADING", True) @dataclass class LoadedMember: member: str display_name: str model_name: str seed: int weight: float checkpoint_file: str checkpoint_path: Path model: nn.Module @dataclass class PredictionResult: predicted_class: str predicted_display: str confidence: float probabilities: dict[str, float] probability_df: pd.DataFrame member_df: pd.DataFrame ensemble_logits: torch.Tensor input_tensor: torch.Tensor _preprocess = transforms.Compose( [ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), transforms.ToTensor(), transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD), ] ) def preprocess_image(image: Image.Image) -> torch.Tensor: if image is None: raise ValueError("Please upload an MRI image first.") return _preprocess(image.convert("RGB")).unsqueeze(0) def build_model(model_name: str, num_classes: int = len(CLASS_NAMES)) -> nn.Module: constructors = { "efficientnet_b0": models.efficientnet_b0, "mobilenet_v3_small": models.mobilenet_v3_small, } if model_name not in constructors: raise ValueError(f"Unsupported deployment backbone: {model_name}") # Do not request torchvision pretrained weights at Space startup. The fine-tuned # checkpoint is expected to contain the trained weights. model = constructors[model_name](weights=None) if model_name in {"efficientnet_b0", "mobilenet_v3_small"}: in_features = model.classifier[-1].in_features model.classifier[-1] = nn.Linear(in_features, num_classes) else: # Defensive; guarded above. raise ValueError(f"No classifier replacement rule for {model_name}") return model def _torch_load(path: Path) -> Any: """Load a PyTorch checkpoint across torch versions. Newer PyTorch versions may support weights_only. We first try the safer path, then fall back for older checkpoints that store a richer dictionary. """ try: return torch.load(path, map_location="cpu", weights_only=True) except TypeError: return torch.load(path, map_location="cpu") except Exception: # Only use this fallback for your own trusted checkpoints. return torch.load(path, map_location="cpu", weights_only=False) def clean_state_dict(checkpoint: Any) -> dict[str, torch.Tensor]: if isinstance(checkpoint, nn.Module): checkpoint = checkpoint.state_dict() if isinstance(checkpoint, dict): for key in ("model_state_dict", "state_dict", "model", "net", "weights"): value = checkpoint.get(key) if isinstance(value, dict): checkpoint = value break if not isinstance(checkpoint, dict): raise TypeError("Checkpoint does not contain a PyTorch state_dict-like object.") cleaned: dict[str, torch.Tensor] = {} for key, value in checkpoint.items(): if not torch.is_tensor(value): continue new_key = str(key) for prefix in ("module.", "model."): if new_key.startswith(prefix): new_key = new_key[len(prefix) :] cleaned[new_key] = value if not cleaned: raise ValueError("No tensor weights were found in the checkpoint.") return cleaned def expected_checkpoint_paths() -> dict[str, Path]: return {m["checkpoint_file"]: MODELS_DIR / m["checkpoint_file"] for m in ENSEMBLE_MEMBERS} def diagnose_checkpoints() -> tuple[bool, pd.DataFrame, str]: rows = [] all_present = True for m in ENSEMBLE_MEMBERS: path = MODELS_DIR / m["checkpoint_file"] exists = path.exists() all_present = all_present and exists rows.append( { "member": m["display_name"], "weight": round(float(m["weight"]), 8), "expected file": f"models/{m['checkpoint_file']}", "status": "✅ found" if exists else "❌ missing", } ) df = pd.DataFrame(rows) if all_present: message = "✅ All required checkpoint files were found in `models/`." else: missing = [r["expected file"] for r in rows if r["status"].startswith("❌")] message = "❌ Missing checkpoint file(s):\n" + "\n".join(f"- `{m}`" for m in missing) return all_present, df, message def _load_selected_metadata() -> dict[str, Any]: if SELECTED_ENSEMBLE_PATH.exists(): return json.loads(SELECTED_ENSEMBLE_PATH.read_text(encoding="utf-8")) return {} @lru_cache(maxsize=1) def load_ensemble() -> tuple[list[LoadedMember], torch.device, dict[str, Any]]: all_present, _df, message = diagnose_checkpoints() if not all_present: raise FileNotFoundError(message) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") loaded: list[LoadedMember] = [] for m in ENSEMBLE_MEMBERS: checkpoint_path = MODELS_DIR / m["checkpoint_file"] model = build_model(m["model_name"], len(CLASS_NAMES)) state_dict = clean_state_dict(_torch_load(checkpoint_path)) model.load_state_dict(state_dict, strict=STRICT_CHECKPOINT_LOADING) model.eval().to(device) loaded.append( LoadedMember( member=m["member"], display_name=m["display_name"], model_name=m["model_name"], seed=int(m["seed"]), weight=float(m["weight"]), checkpoint_file=m["checkpoint_file"], checkpoint_path=checkpoint_path, model=model, ) ) return loaded, device, _load_selected_metadata() def predict(image: Image.Image) -> PredictionResult: members, device, _metadata = load_ensemble() x_cpu = preprocess_image(image) x = x_cpu.to(device) ensemble_probs = None rows = [] with torch.inference_mode(): for m in members: logits = m.model(x) probs = F.softmax(logits, dim=1) weighted_probs = probs * m.weight ensemble_probs = weighted_probs if ensemble_probs is None else ensemble_probs + weighted_probs probs_np = probs.squeeze(0).detach().cpu().numpy() idx = int(np.argmax(probs_np)) rows.append( { "member": m.display_name, "weight": round(m.weight, 8), "member prediction": CLASS_DISPLAY_NAMES[CLASS_NAMES[idx]], "member confidence": round(float(probs_np[idx]), 6), } ) if ensemble_probs is None: raise RuntimeError("No ensemble members were loaded.") probs_np = ensemble_probs.squeeze(0).detach().cpu().numpy() # The weights are normalized from the optimization result, but normalize defensively. probs_np = probs_np / max(float(probs_np.sum()), 1e-12) top_idx = int(np.argmax(probs_np)) predicted_class = CLASS_NAMES[top_idx] prob_rows = [] for label, probability in zip(CLASS_NAMES, probs_np): prob_rows.append( { "class": CLASS_DISPLAY_NAMES[label], "probability": float(probability), "percent": f"{100.0 * float(probability):.2f}%", } ) prob_df = pd.DataFrame(prob_rows).sort_values("probability", ascending=False).reset_index(drop=True) return PredictionResult( predicted_class=predicted_class, predicted_display=CLASS_DISPLAY_NAMES[predicted_class], confidence=float(probs_np[top_idx]), probabilities={label: float(prob) for label, prob in zip(CLASS_NAMES, probs_np)}, probability_df=prob_df, member_df=pd.DataFrame(rows), ensemble_logits=torch.from_numpy(np.log(np.maximum(probs_np, 1e-12))).unsqueeze(0), input_tensor=x_cpu, ) def get_target_layer(model: nn.Module, model_name: str) -> nn.Module: # Last convolutional feature block for each deployed torchvision architecture. if model_name == "efficientnet_b0": return model.features[-1] if model_name == "mobilenet_v3_small": return model.features[-1] raise ValueError(f"No Grad-CAM layer configured for {model_name}") def gradcam_for_member(member: LoadedMember, x_cpu: torch.Tensor, target_index: int, output_size: tuple[int, int]) -> np.ndarray: device = next(member.model.parameters()).device x = x_cpu.to(device) activations: list[torch.Tensor] = [] gradients: list[torch.Tensor] = [] target_layer = get_target_layer(member.model, member.model_name) def forward_hook(_module, _inputs, output): activations.append(output.detach()) def backward_hook(_module, _grad_input, grad_output): gradients.append(grad_output[0].detach()) handle_fwd = target_layer.register_forward_hook(forward_hook) handle_bwd = target_layer.register_full_backward_hook(backward_hook) try: member.model.zero_grad(set_to_none=True) logits = member.model(x) score = logits[0, target_index] score.backward() finally: handle_fwd.remove() handle_bwd.remove() if not activations or not gradients: raise RuntimeError(f"Could not collect gradients for {member.display_name}.") acts = activations[-1] grads = gradients[-1] weights = grads.mean(dim=(2, 3), keepdim=True) cam = torch.relu((weights * acts).sum(dim=1, keepdim=True)) cam = F.interpolate(cam, size=output_size, mode="bilinear", align_corners=False) cam_np = cam.squeeze().detach().cpu().numpy() cam_np = cam_np - cam_np.min() denom = cam_np.max() if denom > 1e-8: cam_np = cam_np / denom return cam_np.astype(np.float32) def weighted_ensemble_cam(image: Image.Image, target_class: str) -> Image.Image: members, _device, _metadata = load_ensemble() rgb = image.convert("RGB") x_cpu = preprocess_image(rgb) target_index = CLASS_NAMES.index(target_class) width, height = rgb.size combined = np.zeros((height, width), dtype=np.float32) total_weight = 0.0 for member in members: try: cam = gradcam_for_member(member, x_cpu, target_index, output_size=(height, width)) combined += cam * float(member.weight) total_weight += float(member.weight) except Exception: # Heatmap is interpretability assistance, not the core prediction. Keep # going if one hook fails; deployment prediction remains unaffected. continue if total_weight <= 0: raise RuntimeError("Could not generate Grad-CAM for any ensemble member.") combined = combined / total_weight combined = combined - combined.min() if combined.max() > 1e-8: combined = combined / combined.max() import matplotlib.cm as cm base = np.asarray(rgb).astype(np.float32) / 255.0 heat = cm.get_cmap("magma")(combined)[..., :3].astype(np.float32) overlay = np.clip(0.58 * base + 0.42 * heat, 0, 1) return Image.fromarray((overlay * 255).astype(np.uint8))