# LCVC-DeepFuse / src /modeling.py
# vimdhayak's picture
# Upload folder using huggingface_hub
# f6ab35f verified
# Raw
# History Blame Contribute Delete
# 11.9 kB
from __future__ import annotations
import json
import os
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms
from .config import (
CLASS_DISPLAY_NAMES,
CLASS_NAMES,
ENSEMBLE_MEMBERS,
IMAGE_SIZE,
MODELS_DIR,
NORMALIZE_MEAN,
NORMALIZE_STD,
SELECTED_ENSEMBLE_PATH,
)
def _env_flag(name: str, default: bool = True) -> bool:
    """Interpret the environment variable *name* as an on/off switch.

    Returns *default* when the variable is not set. Any value that is set
    counts as enabled unless, after trimming whitespace and lower-casing, it
    is one of ``"0"``, ``"false"``, ``"no"`` or ``"off"``.
    """
    value = os.environ.get(name)
    if value is not None:
        normalized = value.strip().lower()
        return normalized not in {"0", "false", "no", "off"}
    return default


# Passed as the ``strict`` argument of ``load_state_dict`` when the ensemble
# checkpoints are loaded; enabled unless the environment switches it off.
STRICT_CHECKPOINT_LOADING = _env_flag("STRICT_CHECKPOINT_LOADING", default=True)
@dataclass
class LoadedMember:
    """A single ensemble member bundled with its instantiated network."""

    # Identifiers copied from the member's entry in the ensemble configuration.
    member: str
    display_name: str
    # Backbone key; selects the architecture and the Grad-CAM target layer.
    model_name: str
    seed: int
    # Multiplier applied to this member's softmax output when fusing predictions.
    weight: float
    # Checkpoint file name and its location inside the models directory.
    checkpoint_file: str
    checkpoint_path: Path
    # Network with the checkpoint weights loaded.
    model: nn.Module
@dataclass
class PredictionResult:
    """Everything ``predict`` reports for one image."""

    # Winning class: internal label, display label, and its fused probability.
    predicted_class: str
    predicted_display: str
    confidence: float
    # Fused probability per internal class label.
    probabilities: dict[str, float]
    # Per-class table, sorted by descending probability.
    probability_df: pd.DataFrame
    # One row per ensemble member with its own top prediction.
    member_df: pd.DataFrame
    # Natural log of the fused probabilities, with a leading batch axis.
    ensemble_logits: torch.Tensor
    # The preprocessed input batch, kept on the CPU.
    input_tensor: torch.Tensor
# Inference-time pipeline shared by prediction and Grad-CAM: resize to the
# configured square size, convert to a tensor, then normalize each channel
# with the configured mean and standard deviation.
_preprocess = transforms.Compose(
    [
        transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ]
)
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert an image into a single-item model input batch.

    The image is converted to RGB, passed through the module's preprocessing
    pipeline, and given a leading batch axis.

    Raises:
        ValueError: If *image* is ``None`` (nothing was uploaded).
    """
    if image is None:
        raise ValueError("Please upload an MRI image first.")
    rgb = image.convert("RGB")
    tensor = _preprocess(rgb)
    return tensor.unsqueeze(0)
def build_model(model_name: str, num_classes: int = len(CLASS_NAMES)) -> nn.Module:
    """Create an untrained deployment backbone with a resized classifier head.

    Args:
        model_name: ``"efficientnet_b0"`` or ``"mobilenet_v3_small"``.
        num_classes: Output size of the replacement classification layer.

    Returns:
        The torchvision model with its final classifier layer swapped for a
        fresh ``nn.Linear`` producing *num_classes* outputs.

    Raises:
        ValueError: If *model_name* is not a supported backbone.
    """
    constructors = {
        "efficientnet_b0": models.efficientnet_b0,
        "mobilenet_v3_small": models.mobilenet_v3_small,
    }
    if model_name not in constructors:
        raise ValueError(f"Unsupported deployment backbone: {model_name}")

    # Do not request torchvision pretrained weights at Space startup. The
    # fine-tuned checkpoint is expected to contain the trained weights.
    model = constructors[model_name](weights=None)

    # The same rule is applied to every supported backbone: replace the last
    # entry of ``classifier`` with a linear layer of matching input width.
    in_features = model.classifier[-1].in_features
    model.classifier[-1] = nn.Linear(in_features, num_classes)
    return model
def _torch_load(path: Path) -> Any:
    """Load a PyTorch checkpoint onto the CPU across torch versions.

    The restricted ``weights_only=True`` loader is tried first. Torch
    versions that do not know that keyword raise ``TypeError`` and are served
    by a plain ``torch.load``. If the restricted loader rejects the file for
    any other reason, the checkpoint is re-read with ``weights_only=False``
    and the downgrade is logged.

    Args:
        path: Location of the checkpoint file.

    Returns:
        Whatever object the checkpoint file contains.

    Raises:
        OSError: If the file cannot be opened or read (for example it does
            not exist). Such errors are not retried because a second attempt
            would fail the same way.
    """
    try:
        return torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:
        return torch.load(path, map_location="cpu")
    except OSError:
        raise
    except Exception as exc:
        import logging

        # SECURITY: weights_only=False performs unrestricted unpickling, which
        # can execute arbitrary code. Only use this fallback for your own
        # trusted checkpoints.
        logging.getLogger(__name__).warning(
            "Restricted load of %s failed (%s: %s); retrying with weights_only=False.",
            path,
            type(exc).__name__,
            exc,
        )
        return torch.load(path, map_location="cpu", weights_only=False)
def clean_state_dict(checkpoint: Any) -> dict[str, torch.Tensor]:
    """Extract a plain ``{parameter name: tensor}`` mapping from a checkpoint.

    Accepted inputs are an ``nn.Module``, a state dict, or a dictionary that
    wraps either of those under one of the keys ``"model_state_dict"``,
    ``"state_dict"``, ``"model"``, ``"net"`` or ``"weights"`` (the first
    matching key wins). Non-tensor entries are dropped, and every leading
    ``"module."`` / ``"model."`` prefix is removed from the parameter names,
    in whatever order the prefixes were stacked.

    Raises:
        TypeError: If no dictionary-like state can be obtained.
        ValueError: If the state contains no tensors.
    """
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        for key in ("model_state_dict", "state_dict", "model", "net", "weights"):
            value = checkpoint.get(key)
            # A whole module may have been saved under the wrapper key.
            if isinstance(value, nn.Module):
                value = value.state_dict()
            if isinstance(value, dict):
                checkpoint = value
                break
    if not isinstance(checkpoint, dict):
        raise TypeError("Checkpoint does not contain a PyTorch state_dict-like object.")

    prefixes = ("module.", "model.")
    cleaned: dict[str, torch.Tensor] = {}
    for key, value in checkpoint.items():
        if not torch.is_tensor(value):
            continue
        new_key = str(key)
        # Strip repeatedly so "model.module.x" and "module.model.x" both
        # reduce to "x". Each prefix ends at the key's first dot, so keeping
        # what follows that dot removes exactly one prefix.
        while new_key.startswith(prefixes):
            new_key = new_key.partition(".")[2]
        cleaned[new_key] = value
    if not cleaned:
        raise ValueError("No tensor weights were found in the checkpoint.")
    return cleaned
def expected_checkpoint_paths() -> dict[str, Path]:
    """Map each configured checkpoint file name to its path in the models directory."""
    paths: dict[str, Path] = {}
    for spec in ENSEMBLE_MEMBERS:
        file_name = spec["checkpoint_file"]
        paths[file_name] = MODELS_DIR / file_name
    return paths
def diagnose_checkpoints() -> tuple[bool, pd.DataFrame, str]:
    """Check that every configured checkpoint file exists.

    Returns:
        A tuple ``(all_present, table, message)``: whether every file was
        found, a table with one status row per ensemble member, and a
        Markdown-formatted summary that lists any missing files.
    """
    report = []
    missing_files = []
    for spec in ENSEMBLE_MEMBERS:
        found = (MODELS_DIR / spec["checkpoint_file"]).exists()
        expected = f"models/{spec['checkpoint_file']}"
        if not found:
            missing_files.append(expected)
        report.append(
            {
                "member": spec["display_name"],
                "weight": round(float(spec["weight"]), 8),
                "expected file": expected,
                "status": "✅ found" if found else "❌ missing",
            }
        )

    table = pd.DataFrame(report)
    if missing_files:
        bullet_list = "\n".join(f"- `{name}`" for name in missing_files)
        return False, table, "❌ Missing checkpoint file(s):\n" + bullet_list
    return True, table, "✅ All required checkpoint files were found in `models/`."
def _load_selected_metadata() -> dict[str, Any]:
    """Return the parsed selected-ensemble JSON file, or ``{}`` if it is absent."""
    if not SELECTED_ENSEMBLE_PATH.exists():
        return {}
    raw = SELECTED_ENSEMBLE_PATH.read_text(encoding="utf-8")
    return json.loads(raw)
@lru_cache(maxsize=1)
def load_ensemble() -> tuple[list[LoadedMember], torch.device, dict[str, Any]]:
    """Build every ensemble member and load its checkpoint weights.

    The result is memoized, so the checkpoints are read only on the first
    successful call.

    Returns:
        A tuple ``(members, device, metadata)``: the loaded members in
        configuration order, the device they were moved to (CUDA when
        available, otherwise CPU), and the selected-ensemble metadata.

    Raises:
        FileNotFoundError: If any configured checkpoint file is missing.
    """
    checkpoints_ok, _table, report = diagnose_checkpoints()
    if not checkpoints_ok:
        raise FileNotFoundError(report)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    def _instantiate(spec: dict[str, Any]) -> LoadedMember:
        # Build the bare architecture, then pour the checkpoint weights in.
        weights_path = MODELS_DIR / spec["checkpoint_file"]
        network = build_model(spec["model_name"], len(CLASS_NAMES))
        weights = clean_state_dict(_torch_load(weights_path))
        network.load_state_dict(weights, strict=STRICT_CHECKPOINT_LOADING)
        network.eval()
        network.to(device)
        return LoadedMember(
            member=spec["member"],
            display_name=spec["display_name"],
            model_name=spec["model_name"],
            seed=int(spec["seed"]),
            weight=float(spec["weight"]),
            checkpoint_file=spec["checkpoint_file"],
            checkpoint_path=weights_path,
            model=network,
        )

    members = [_instantiate(spec) for spec in ENSEMBLE_MEMBERS]
    return members, device, _load_selected_metadata()
def predict(image: Image.Image) -> PredictionResult:
    """Classify *image* with the weighted soft-voting ensemble.

    Each member's softmax output is scaled by the member's weight and the
    scaled outputs are summed; the sum is then renormalized to total one.

    Returns:
        A ``PredictionResult`` holding the winning class, the fused
        probabilities, a per-member breakdown, and the preprocessed input.

    Raises:
        ValueError: If *image* is ``None``.
        RuntimeError: If the ensemble contains no members.
    """
    members, device, _metadata = load_ensemble()
    x_cpu = preprocess_image(image)
    batch = x_cpu.to(device)

    fused = None
    member_rows = []
    with torch.inference_mode():
        for entry in members:
            member_probs = F.softmax(entry.model(batch), dim=1)
            contribution = member_probs * entry.weight
            if fused is None:
                fused = contribution
            else:
                fused = fused + contribution

            member_np = member_probs.squeeze(0).detach().cpu().numpy()
            best = int(np.argmax(member_np))
            member_rows.append(
                {
                    "member": entry.display_name,
                    "weight": round(entry.weight, 8),
                    "member prediction": CLASS_DISPLAY_NAMES[CLASS_NAMES[best]],
                    "member confidence": round(float(member_np[best]), 6),
                }
            )

    if fused is None:
        raise RuntimeError("No ensemble members were loaded.")

    fused_np = fused.squeeze(0).detach().cpu().numpy()
    # The weights are normalized from the optimization result, but normalize
    # defensively; the floor guards against division by zero.
    fused_np = fused_np / max(float(fused_np.sum()), 1e-12)
    winner = int(np.argmax(fused_np))
    winner_label = CLASS_NAMES[winner]

    class_rows = [
        {
            "class": CLASS_DISPLAY_NAMES[label],
            "probability": float(value),
            "percent": f"{100.0 * float(value):.2f}%",
        }
        for label, value in zip(CLASS_NAMES, fused_np)
    ]
    class_table = pd.DataFrame(class_rows)
    class_table = class_table.sort_values("probability", ascending=False)
    class_table = class_table.reset_index(drop=True)

    # Log-probabilities (floored to avoid log(0)) stand in for ensemble logits.
    log_probs = np.log(np.maximum(fused_np, 1e-12))

    return PredictionResult(
        predicted_class=winner_label,
        predicted_display=CLASS_DISPLAY_NAMES[winner_label],
        confidence=float(fused_np[winner]),
        probabilities={label: float(value) for label, value in zip(CLASS_NAMES, fused_np)},
        probability_df=class_table,
        member_df=pd.DataFrame(member_rows),
        ensemble_logits=torch.from_numpy(log_probs).unsqueeze(0),
        input_tensor=x_cpu,
    )
def get_target_layer(model: nn.Module, model_name: str) -> nn.Module:
    """Return the layer whose activations are used for Grad-CAM.

    Both deployed torchvision architectures use the last entry of
    ``model.features``.

    Raises:
        ValueError: If no layer is configured for *model_name*.
    """
    if model_name in ("efficientnet_b0", "mobilenet_v3_small"):
        return model.features[-1]
    raise ValueError(f"No Grad-CAM layer configured for {model_name}")
def gradcam_for_member(member: LoadedMember, x_cpu: torch.Tensor, target_index: int, output_size: tuple[int, int]) -> np.ndarray:
    """Compute a Grad-CAM heat-map for one ensemble member.

    Args:
        member: The loaded member whose network is inspected.
        x_cpu: Preprocessed input batch; only its first item is scored.
        target_index: Index of the class logit that is back-propagated.
        output_size: ``(height, width)`` the heat-map is resized to.

    Returns:
        A ``float32`` array of shape ``output_size``, shifted to start at zero
        and scaled to a maximum of one unless the map is (almost) constant.

    Raises:
        RuntimeError: If the hooks captured no activations or gradients.
    """
    device = next(member.model.parameters()).device
    x = x_cpu.to(device)
    activations: list[torch.Tensor] = []
    gradients: list[torch.Tensor] = []
    target_layer = get_target_layer(member.model, member.model_name)

    def forward_hook(_module, _inputs, output):
        activations.append(output.detach())

    def backward_hook(_module, _grad_input, grad_output):
        gradients.append(grad_output[0].detach())

    handles = []
    try:
        # Register inside the try so a failure while adding the second hook
        # still removes the first one.
        handles.append(target_layer.register_forward_hook(forward_hook))
        handles.append(target_layer.register_full_backward_hook(backward_hook))
        member.model.zero_grad(set_to_none=True)
        logits = member.model(x)
        score = logits[0, target_index]
        score.backward()
    finally:
        for handle in handles:
            handle.remove()
        # The backward pass filled every parameter's .grad; only the hooked
        # gradients are needed, so release the rest immediately.
        member.model.zero_grad(set_to_none=True)

    if not activations or not gradients:
        raise RuntimeError(f"Could not collect gradients for {member.display_name}.")

    acts = activations[-1]
    grads = gradients[-1]
    # Channel importance = spatial mean of the gradients (axes 2 and 3).
    weights = grads.mean(dim=(2, 3), keepdim=True)
    cam = torch.relu((weights * acts).sum(dim=1, keepdim=True))
    cam = F.interpolate(cam, size=output_size, mode="bilinear", align_corners=False)
    # Index instead of squeeze(): squeeze() would also drop a spatial axis of
    # length 1 and return the wrong shape for 1-pixel-wide/tall outputs.
    cam_np = cam[0, 0].detach().cpu().numpy()
    cam_np = cam_np - cam_np.min()
    denom = cam_np.max()
    if denom > 1e-8:
        cam_np = cam_np / denom
    return cam_np.astype(np.float32)
def weighted_ensemble_cam(image: Image.Image, target_class: str) -> Image.Image:
    """Blend the members' Grad-CAM maps into one overlay on *image*.

    Each member's heat-map is weighted by the member's ensemble weight. The
    weighted average is rescaled to [0, 1], coloured with the "magma"
    colormap, and alpha-blended onto the RGB image.

    Args:
        image: The image to explain.
        target_class: Internal class label whose evidence is visualized.

    Returns:
        An RGB image of the same size as the input with the heat-map overlaid.

    Raises:
        ValueError: If *image* is ``None`` or *target_class* is not a known
            class label.
        RuntimeError: If Grad-CAM failed for every ensemble member.
    """
    # Same guard and message as preprocess_image, instead of an opaque
    # AttributeError from image.convert below.
    if image is None:
        raise ValueError("Please upload an MRI image first.")

    import logging

    members, _device, _metadata = load_ensemble()
    rgb = image.convert("RGB")
    x_cpu = preprocess_image(rgb)
    target_index = CLASS_NAMES.index(target_class)
    width, height = rgb.size
    combined = np.zeros((height, width), dtype=np.float32)
    total_weight = 0.0
    for member in members:
        try:
            cam = gradcam_for_member(member, x_cpu, target_index, output_size=(height, width))
            combined += cam * float(member.weight)
            total_weight += float(member.weight)
        except Exception:
            # Heatmap is interpretability assistance, not the core prediction.
            # Keep going if one member fails, but leave a trace in the logs.
            logging.getLogger(__name__).warning(
                "Grad-CAM failed for %s; skipping this member.",
                member.display_name,
                exc_info=True,
            )
            continue
    if total_weight <= 0:
        raise RuntimeError("Could not generate Grad-CAM for any ensemble member.")

    combined = combined / total_weight
    combined = combined - combined.min()
    if combined.max() > 1e-8:
        combined = combined / combined.max()

    import matplotlib

    try:
        colormap = matplotlib.colormaps["magma"]
    except AttributeError:
        # Older Matplotlib without the colormap registry attribute; newer
        # releases have removed matplotlib.cm.get_cmap.
        import matplotlib.cm as cm

        colormap = cm.get_cmap("magma")

    base = np.asarray(rgb).astype(np.float32) / 255.0
    # The colormap yields RGBA; keep only the RGB channels.
    heat = colormap(combined)[..., :3].astype(np.float32)
    overlay = np.clip(0.58 * base + 0.42 * heat, 0, 1)
    return Image.fromarray((overlay * 255).astype(np.uint8))