Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import os | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import Any | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from torchvision import models, transforms | |
| from .config import ( | |
| CLASS_DISPLAY_NAMES, | |
| CLASS_NAMES, | |
| ENSEMBLE_MEMBERS, | |
| IMAGE_SIZE, | |
| MODELS_DIR, | |
| NORMALIZE_MEAN, | |
| NORMALIZE_STD, | |
| SELECTED_ENSEMBLE_PATH, | |
| ) | |
def _env_flag(name: str, default: bool = True) -> bool:
    """Interpret the environment variable *name* as a boolean switch.

    Args:
        name: Name of the environment variable to read.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` when the variable is unset. Otherwise ``False`` only when
        the trimmed, lower-cased value is one of ``0``, ``false``, ``no`` or
        ``off``; every other value (including an empty string) yields ``True``.
    """
    value = os.environ.get(name)
    if value is not None:
        normalized = value.strip().lower()
        return normalized not in ("0", "false", "no", "off")
    return default


# Passed as ``strict=`` to ``load_state_dict`` when checkpoints are loaded;
# strict unless the environment variable explicitly disables it.
STRICT_CHECKPOINT_LOADING = _env_flag("STRICT_CHECKPOINT_LOADING", default=True)
@dataclass
class LoadedMember:
    """One ensemble member together with its instantiated network.

    The ``@dataclass`` decorator is required: ``load_ensemble`` constructs
    instances with keyword arguments, which a plain class carrying only
    annotations would reject with ``TypeError``.
    """

    # Identifier of the member, copied from its ``ENSEMBLE_MEMBERS`` entry.
    member: str
    # Human-readable label shown in the per-member result table.
    display_name: str
    # Backbone key understood by ``build_model`` / ``get_target_layer``.
    model_name: str
    # Seed recorded for this member in the configuration.
    seed: int
    # Factor applied to this member's probabilities (and CAM) when combining.
    weight: float
    # Checkpoint file name as listed in the configuration.
    checkpoint_file: str
    # Location of the checkpoint file under ``MODELS_DIR``.
    checkpoint_path: Path
    # Network with the checkpoint weights loaded.
    model: nn.Module
@dataclass
class PredictionResult:
    """Everything ``predict`` reports for a single image.

    The ``@dataclass`` decorator is required: ``predict`` constructs the
    result with keyword arguments, which a plain class carrying only
    annotations would reject with ``TypeError``.
    """

    # Entry of ``CLASS_NAMES`` with the highest ensemble probability.
    predicted_class: str
    # ``CLASS_DISPLAY_NAMES`` label for ``predicted_class``.
    predicted_display: str
    # Normalized ensemble probability of the predicted class.
    confidence: float
    # Normalized ensemble probability for every entry of ``CLASS_NAMES``.
    probabilities: dict[str, float]
    # Per-class table (class / probability / percent), highest probability first.
    probability_df: pd.DataFrame
    # One row per ensemble member: weight, its own prediction and confidence.
    member_df: pd.DataFrame
    # Natural log of the normalized ensemble probabilities, with a batch axis.
    ensemble_logits: torch.Tensor
    # Preprocessed input batch kept on the CPU (reusable for Grad-CAM).
    input_tensor: torch.Tensor
# Inference-time pipeline: resize to the configured square size, convert to a
# tensor, then normalize with the configured per-channel statistics.
_preprocess = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ]
)


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a single-item batch for the ensemble.

    Args:
        image: The uploaded image; ``None`` means nothing was uploaded.

    Returns:
        The image converted to RGB, passed through ``_preprocess`` and given
        a leading batch axis.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("Please upload an MRI image first.")
    rgb_image = image.convert("RGB")
    tensor = _preprocess(rgb_image)
    return tensor.unsqueeze(0)
def build_model(model_name: str, num_classes: int = len(CLASS_NAMES)) -> nn.Module:
    """Create an untrained backbone whose final layer emits ``num_classes`` logits.

    Args:
        model_name: Key of a supported torchvision backbone.
        num_classes: Output width of the replacement classification layer.

    Returns:
        The freshly constructed model with its last classifier layer replaced.

    Raises:
        ValueError: If ``model_name`` is not a supported backbone, or if the
            constructed model has no ``classifier`` head ending in a linear
            layer that could be replaced.
    """
    constructors = {
        "efficientnet_b0": models.efficientnet_b0,
        "mobilenet_v3_small": models.mobilenet_v3_small,
    }
    try:
        constructor = constructors[model_name]
    except KeyError:
        raise ValueError(f"Unsupported deployment backbone: {model_name}") from None

    # Do not request torchvision pretrained weights at Space startup. The fine-tuned
    # checkpoint is expected to contain the trained weights.
    model = constructor(weights=None)

    # Check the head's structure rather than repeating the list of names, so a
    # backbone added to ``constructors`` without a compatible head fails here
    # with a clear message instead of an AttributeError.
    head = getattr(model, "classifier", None)
    if not isinstance(head, nn.Sequential) or not isinstance(head[-1], nn.Linear):
        raise ValueError(f"No classifier replacement rule for {model_name}")
    head[-1] = nn.Linear(head[-1].in_features, num_classes)
    return model
def _torch_load(path: Path) -> Any:
    """Load a PyTorch checkpoint onto the CPU across torch versions.

    The restricted ``weights_only=True`` loader is tried first. If this torch
    version does not accept that argument, the plain loader is used. If the
    restricted loader rejects the file's contents, the checkpoint is loaded
    with the full unpickler and a ``RuntimeWarning`` is emitted so the
    downgrade is visible in the logs.

    Args:
        path: Location of the checkpoint file.

    Returns:
        Whatever object the checkpoint file contains.

    Raises:
        OSError: If the file is missing or unreadable; no fallback is tried,
            because a different unpickler cannot fix an I/O problem.
    """
    try:
        return torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:
        # Older torch: ``weights_only`` is not a known keyword.
        return torch.load(path, map_location="cpu")
    except OSError:
        raise
    except Exception as exc:
        import warnings

        # SECURITY: the full unpickler can execute arbitrary code embedded in
        # the file. Only use this fallback for your own trusted checkpoints.
        warnings.warn(
            f"weights_only load of {path} failed ({type(exc).__name__}); "
            "falling back to full unpickling of a trusted checkpoint.",
            RuntimeWarning,
            stacklevel=2,
        )
        return torch.load(path, map_location="cpu", weights_only=False)
def clean_state_dict(checkpoint: Any) -> dict[str, torch.Tensor]:
    """Extract a plain ``name -> tensor`` mapping from a loaded checkpoint.

    Accepts a module (its ``state_dict()`` is used), a bare state dict, or a
    dictionary that nests the state dict under one of the keys
    ``model_state_dict``, ``state_dict``, ``model``, ``net`` or ``weights``.
    Non-tensor entries are dropped, and leading ``module.`` / ``model.``
    wrapper prefixes are removed from every key.

    Args:
        checkpoint: Object returned by the checkpoint loader.

    Returns:
        The cleaned state dict.

    Raises:
        TypeError: If no dictionary-like state could be found.
        ValueError: If the dictionary contains no tensors.
    """
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        for container_key in ("model_state_dict", "state_dict", "model", "net", "weights"):
            nested = checkpoint.get(container_key)
            if isinstance(nested, dict):
                checkpoint = nested
                break
    if not isinstance(checkpoint, dict):
        raise TypeError("Checkpoint does not contain a PyTorch state_dict-like object.")

    wrapper_prefixes = ("module.", "model.")
    cleaned: dict[str, torch.Tensor] = {}
    for key, value in checkpoint.items():
        if not torch.is_tensor(value):
            continue
        name = str(key)
        # Strip wrapper prefixes until none is left, so any nesting order
        # ("module.model.", "model.module.", ...) ends at the bare name.
        # A single fixed-order pass would leave "model.module.x" as "module.x".
        while True:
            prefix = next((p for p in wrapper_prefixes if name.startswith(p)), None)
            if prefix is None:
                break
            name = name[len(prefix):]
        cleaned[name] = value
    if not cleaned:
        raise ValueError("No tensor weights were found in the checkpoint.")
    return cleaned
def expected_checkpoint_paths() -> dict[str, Path]:
    """Map every configured checkpoint file name to its path under ``MODELS_DIR``."""
    paths: dict[str, Path] = {}
    for spec in ENSEMBLE_MEMBERS:
        paths[spec["checkpoint_file"]] = MODELS_DIR / spec["checkpoint_file"]
    return paths
def diagnose_checkpoints() -> tuple[bool, pd.DataFrame, str]:
    """Report which configured checkpoint files exist under ``MODELS_DIR``.

    Returns:
        A triple ``(all_present, table, message)``: whether every file was
        found, a DataFrame with one row per ensemble member (member, weight,
        expected file, status), and a Markdown summary that lists the missing
        files when there are any.
    """
    found_flags = [(MODELS_DIR / spec["checkpoint_file"]).exists() for spec in ENSEMBLE_MEMBERS]
    table_rows = [
        {
            "member": spec["display_name"],
            "weight": round(float(spec["weight"]), 8),
            "expected file": f"models/{spec['checkpoint_file']}",
            "status": "✅ found" if found else "❌ missing",
        }
        for spec, found in zip(ENSEMBLE_MEMBERS, found_flags)
    ]
    table = pd.DataFrame(table_rows)

    everything_found = all(found_flags)
    if everything_found:
        return everything_found, table, "✅ All required checkpoint files were found in `models/`."

    absent = [row["expected file"] for row, found in zip(table_rows, found_flags) if not found]
    bullet_list = "\n".join(f"- `{name}`" for name in absent)
    return everything_found, table, "❌ Missing checkpoint file(s):\n" + bullet_list
def _load_selected_metadata() -> dict[str, Any]:
    """Return the parsed JSON at ``SELECTED_ENSEMBLE_PATH``, or ``{}`` if the file is absent."""
    if not SELECTED_ENSEMBLE_PATH.exists():
        return {}
    raw_text = SELECTED_ENSEMBLE_PATH.read_text(encoding="utf-8")
    return json.loads(raw_text)
@lru_cache(maxsize=1)
def load_ensemble() -> tuple[list[LoadedMember], torch.device, dict[str, Any]]:
    """Build every ensemble member and load its checkpoint, once per process.

    The result is memoized. Both ``predict`` and ``weighted_ensemble_cam``
    call this on every request, and without the cache each call would rebuild
    all models and re-read every checkpoint from disk. A call that raises is
    not cached, so adding the missing files and retrying works. Call
    ``load_ensemble.cache_clear()`` to force a reload.

    Returns:
        A triple ``(members, device, metadata)``: the loaded members in
        configuration order, the device they were moved to (CUDA when
        available, otherwise CPU), and the selected-ensemble metadata. The
        same objects are returned on every call, so callers must treat them
        as read-only.

    Raises:
        FileNotFoundError: If any configured checkpoint file is missing; the
            message lists the missing files.
    """
    all_present, _table, message = diagnose_checkpoints()
    if not all_present:
        raise FileNotFoundError(message)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loaded: list[LoadedMember] = []
    for spec in ENSEMBLE_MEMBERS:
        checkpoint_path = MODELS_DIR / spec["checkpoint_file"]
        model = build_model(spec["model_name"], len(CLASS_NAMES))
        state_dict = clean_state_dict(_torch_load(checkpoint_path))
        model.load_state_dict(state_dict, strict=STRICT_CHECKPOINT_LOADING)
        model.eval().to(device)
        loaded.append(
            LoadedMember(
                member=spec["member"],
                display_name=spec["display_name"],
                model_name=spec["model_name"],
                seed=int(spec["seed"]),
                weight=float(spec["weight"]),
                checkpoint_file=spec["checkpoint_file"],
                checkpoint_path=checkpoint_path,
                model=model,
            )
        )
    return loaded, device, _load_selected_metadata()
def predict(image: Image.Image) -> PredictionResult:
    """Classify one image with the weighted ensemble.

    Each member's softmax output is scaled by its weight and the scaled
    outputs are summed; the sum is then renormalized to total one.

    Args:
        image: The uploaded image.

    Returns:
        A ``PredictionResult`` with the winning class, per-class
        probabilities, a per-member breakdown and the preprocessed input.

    Raises:
        ValueError: If ``image`` is ``None`` (raised by ``preprocess_image``).
        RuntimeError: If the ensemble contains no members.
    """
    members, device, _metadata = load_ensemble()
    x_cpu = preprocess_image(image)
    batch = x_cpu.to(device)

    weighted_sum = None
    member_rows = []
    with torch.inference_mode():
        for member in members:
            member_probs = F.softmax(member.model(batch), dim=1)
            contribution = member_probs * member.weight
            if weighted_sum is None:
                weighted_sum = contribution
            else:
                weighted_sum = weighted_sum + contribution

            member_np = member_probs.squeeze(0).detach().cpu().numpy()
            best = int(np.argmax(member_np))
            member_rows.append(
                {
                    "member": member.display_name,
                    "weight": round(member.weight, 8),
                    "member prediction": CLASS_DISPLAY_NAMES[CLASS_NAMES[best]],
                    "member confidence": round(float(member_np[best]), 6),
                }
            )

    if weighted_sum is None:
        raise RuntimeError("No ensemble members were loaded.")

    ensemble_np = weighted_sum.squeeze(0).detach().cpu().numpy()
    # The weights are normalized from the optimization result, but normalize defensively.
    ensemble_np = ensemble_np / max(float(ensemble_np.sum()), 1e-12)

    top_idx = int(np.argmax(ensemble_np))
    winner = CLASS_NAMES[top_idx]

    class_rows = [
        {
            "class": CLASS_DISPLAY_NAMES[label],
            "probability": float(probability),
            "percent": f"{100.0 * float(probability):.2f}%",
        }
        for label, probability in zip(CLASS_NAMES, ensemble_np)
    ]
    ranked = pd.DataFrame(class_rows).sort_values("probability", ascending=False)
    ranked = ranked.reset_index(drop=True)

    # Clamp before the log so a zero probability does not produce -inf.
    log_probs = np.log(np.maximum(ensemble_np, 1e-12))

    return PredictionResult(
        predicted_class=winner,
        predicted_display=CLASS_DISPLAY_NAMES[winner],
        confidence=float(ensemble_np[top_idx]),
        probabilities={label: float(prob) for label, prob in zip(CLASS_NAMES, ensemble_np)},
        probability_df=ranked,
        member_df=pd.DataFrame(member_rows),
        ensemble_logits=torch.from_numpy(log_probs).unsqueeze(0),
        input_tensor=x_cpu,
    )
def get_target_layer(model: nn.Module, model_name: str) -> nn.Module:
    """Return the layer whose activations and gradients feed Grad-CAM.

    For both deployed torchvision architectures this is the last entry of
    ``model.features``.

    Raises:
        ValueError: If no layer is configured for ``model_name``.
    """
    if model_name in ("efficientnet_b0", "mobilenet_v3_small"):
        return model.features[-1]
    raise ValueError(f"No Grad-CAM layer configured for {model_name}")
def gradcam_for_member(
    member: LoadedMember,
    x_cpu: torch.Tensor,
    target_index: int,
    output_size: tuple[int, int],
) -> np.ndarray:
    """Compute a Grad-CAM heatmap for one ensemble member.

    The target layer's output is captured on the forward pass and the
    gradient of the chosen class score with respect to that output on the
    backward pass. Channels are weighted by their spatially averaged
    gradient, summed, passed through ReLU, resized to ``output_size`` and
    min-max scaled.

    Args:
        member: The ensemble member to explain.
        x_cpu: Preprocessed input batch; row 0 is the one explained.
        target_index: Index of the class logit to differentiate.
        output_size: ``(height, width)`` of the returned map.

    Returns:
        A ``float32`` array of shape ``output_size``. Values are scaled into
        ``[0, 1]`` unless the map is (near-)constant, in which case it is
        all zeros.

    Raises:
        RuntimeError: If the hooks captured no activations or gradients.
        ValueError: If no Grad-CAM layer is configured for the member.
    """
    device = next(member.model.parameters()).device
    x = x_cpu.to(device)
    activations: list[torch.Tensor] = []
    gradients: list[torch.Tensor] = []
    target_layer = get_target_layer(member.model, member.model_name)

    def forward_hook(_module, _inputs, output):
        activations.append(output.detach())

    def backward_hook(_module, _grad_input, grad_output):
        gradients.append(grad_output[0].detach())

    # Register inside the try so the forward hook is removed even if
    # registering the backward hook fails.
    handles = [target_layer.register_forward_hook(forward_hook)]
    try:
        handles.append(target_layer.register_full_backward_hook(backward_hook))
        member.model.zero_grad(set_to_none=True)
        # Grad-CAM needs autograd even if the caller disabled gradient
        # tracking (e.g. inside ``torch.no_grad()``).
        with torch.enable_grad():
            logits = member.model(x)
            score = logits[0, target_index]
            score.backward()
    finally:
        for handle in handles:
            handle.remove()
        # Only the hooked gradients are needed; free the parameter gradients
        # that backward() populated as a side effect.
        member.model.zero_grad(set_to_none=True)

    if not activations or not gradients:
        raise RuntimeError(f"Could not collect gradients for {member.display_name}.")

    acts = activations[-1]
    grads = gradients[-1]
    weights = grads.mean(dim=(2, 3), keepdim=True)
    cam = torch.relu((weights * acts).sum(dim=1, keepdim=True))
    cam = F.interpolate(cam, size=output_size, mode="bilinear", align_corners=False)
    # Index instead of squeeze(): squeeze() would also drop a height or width
    # of 1 and return the wrong shape for 1-pixel-wide or -tall images.
    cam_np = cam[0, 0].detach().cpu().numpy()
    cam_np = cam_np - cam_np.min()
    denom = cam_np.max()
    if denom > 1e-8:
        cam_np = cam_np / denom
    return cam_np.astype(np.float32)
def weighted_ensemble_cam(image: Image.Image, target_class: str) -> Image.Image:
    """Blend a weighted-average Grad-CAM heatmap over the input image.

    Each member's heatmap is scaled by the member's weight, the scaled maps
    are summed and divided by the total weight of the members that
    succeeded, and the result is min-max scaled, coloured with the "magma"
    colormap and mixed with the RGB image (58% image, 42% heat).

    Args:
        image: The image to explain.
        target_class: Entry of ``CLASS_NAMES`` whose evidence is visualized.

    Returns:
        An RGB image of the same size as the input with the heatmap overlaid.

    Raises:
        ValueError: If ``target_class`` is not in ``CLASS_NAMES``.
        RuntimeError: If no member produced a heatmap; the last member
            failure is attached as the cause.
    """
    members, _device, _metadata = load_ensemble()
    rgb = image.convert("RGB")
    x_cpu = preprocess_image(rgb)
    target_index = CLASS_NAMES.index(target_class)
    width, height = rgb.size

    combined = np.zeros((height, width), dtype=np.float32)
    total_weight = 0.0
    last_error: Exception | None = None
    for member in members:
        try:
            cam = gradcam_for_member(member, x_cpu, target_index, output_size=(height, width))
            combined += cam * float(member.weight)
            total_weight += float(member.weight)
        except Exception as exc:
            # Heatmap is interpretability assistance, not the core prediction. Keep
            # going if one hook fails; deployment prediction remains unaffected.
            # Remember the failure so a total failure can report its cause.
            last_error = exc
            continue

    if total_weight <= 0:
        raise RuntimeError("Could not generate Grad-CAM for any ensemble member.") from last_error

    combined = combined / total_weight
    combined = combined - combined.min()
    if combined.max() > 1e-8:
        combined = combined / combined.max()

    import matplotlib

    try:
        # Registry lookup: ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9.
        colormap = matplotlib.colormaps["magma"]
    except AttributeError:
        # matplotlib releases that predate the ``colormaps`` registry.
        import matplotlib.cm as cm

        colormap = cm.get_cmap("magma")

    base = np.asarray(rgb).astype(np.float32) / 255.0
    heat = colormap(combined)[..., :3].astype(np.float32)  # drop the alpha channel
    overlay = np.clip(0.58 * base + 0.42 * heat, 0, 1)
    return Image.fromarray((overlay * 255).astype(np.uint8))