# LCVC-Ensemble / app.py
# Hugging Face Space file view: uploaded by vimdhayak ("Upload 7 files"),
# commit 292b6c2 (verified), 5.92 kB.
import json
from pathlib import Path
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms
# Directory containing this script; checkpoints and the config are resolved
# relative to it.
ROOT = Path(__file__).resolve().parent
CONFIG_PATH = ROOT / "ensemble_config.json"

# Parse the ensemble description once at import time. An explicit encoding
# keeps parsing independent of the host locale, and read_text() avoids
# leaking an open-file variable into the module namespace.
CFG = json.loads(CONFIG_PATH.read_text(encoding="utf-8"))

CLASS_NAMES = CFG["classes"]
NUM_CLASSES = int(CFG["num_classes"])
IMAGE_SIZE = int(CFG["image_size"])

# Class labels are looked up by model output index, so the two config fields
# must agree; fail fast at startup instead of raising IndexError mid-request.
if len(CLASS_NAMES) != NUM_CLASSES:
    raise ValueError(
        f"{CONFIG_PATH.name} lists {len(CLASS_NAMES)} classes "
        f"but num_classes is {NUM_CLASSES}"
    )

# Run inference on the GPU when one is visible, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def replace_classifier(model, model_name, num_classes):
    """Swap the final classification layer of a backbone for a new Linear head.

    The replacement ``nn.Linear`` keeps the input width of the layer it
    replaces and produces ``num_classes`` outputs.

    Args:
        model: Network exposing a ``classifier`` attribute (a bare Linear
            for ``"densenet121"``, an indexable container otherwise).
        model_name: Architecture key; one of ``"vgg16_bn"``,
            ``"densenet121"``, ``"efficientnet_b0"``,
            ``"mobilenet_v3_small"`` or ``"convnext_tiny"``.
        num_classes: Output width of the new head.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: If ``model_name`` is not one of the supported keys.
    """
    # Architectures whose ``classifier`` is a container ending in the Linear
    # layer to be replaced.
    sequential_heads = {
        "vgg16_bn",
        "efficientnet_b0",
        "mobilenet_v3_small",
        "convnext_tiny",
    }

    if model_name == "densenet121":
        # Here ``classifier`` is itself the Linear layer.
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    elif model_name in sequential_heads:
        old_head = model.classifier[-1]
        model.classifier[-1] = nn.Linear(old_head.in_features, num_classes)
    else:
        raise ValueError(f"Unknown model name: {model_name}")
    return model
def build_model(model_name, num_classes):
    """Create an uninitialised torchvision backbone with a resized head.

    Args:
        model_name: Name of the torchvision constructor to call; must be one
            of the supported architectures listed below.
        num_classes: Output width passed on to ``replace_classifier``.

    Returns:
        The constructed model with its final layer replaced.

    Raises:
        ValueError: If ``model_name`` is not a supported architecture.
    """
    supported = (
        "vgg16_bn",
        "densenet121",
        "efficientnet_b0",
        "mobilenet_v3_small",
        "convnext_tiny",
    )
    if model_name not in supported:
        raise ValueError(f"Unknown model name: {model_name}")

    # Each supported key matches the torchvision constructor's name; resolve
    # it only once the name is known to be valid. weights=None means no
    # pretrained weights are fetched; load_ensemble later overwrites the
    # parameters from a checkpoint.
    backbone = getattr(models, model_name)(weights=None)
    return replace_classifier(backbone, model_name, num_classes)
def load_state_dict_safely(path):
    """Load a checkpoint file and return a plain state dict on the CPU.

    Accepts either a bare state dict or a checkpoint dict that nests it under
    ``"model_state_dict"``, ``"state_dict"`` or ``"model"`` (first match
    wins). A leading ``"module."`` prefix on keys (typically added when a
    model wrapped in ``DataParallel`` was saved) is stripped.

    Args:
        path: Path to a file written by ``torch.save``.

    Returns:
        dict mapping parameter/buffer names to the loaded values.

    Raises:
        TypeError: If the file does not contain a dict-shaped state dict.
    """
    try:
        # weights_only=True restricts unpickling to tensors and plain data.
        ckpt = torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:
        # Presumably for torch releases whose torch.load lacks weights_only.
        # NOTE(security): this fallback is a full pickle load, which can run
        # arbitrary code; only use it with trusted checkpoint files.
        ckpt = torch.load(path, map_location="cpu")

    if isinstance(ckpt, dict):
        for key in ("model_state_dict", "state_dict", "model"):
            nested = ckpt.get(key)
            if isinstance(nested, dict):
                ckpt = nested
                break

    # A non-dict payload (e.g. a list, or a whole pickled module) would
    # otherwise fail later with an opaque AttributeError on ``.items()``.
    if not isinstance(ckpt, dict):
        raise TypeError(
            f"Checkpoint {path} does not contain a state dict "
            f"(got {type(ckpt).__name__})"
        )

    prefix = "module."
    return {
        (k[len(prefix):] if isinstance(k, str) and k.startswith(prefix) else k): v
        for k, v in ckpt.items()
    }
def _normalize_member_weights(members):
    """Rescale each member's ``"weight"`` in place so the weights sum to 1."""
    weight_sum = sum(m["weight"] for m in members)
    if weight_sum <= 0:
        # Degenerate config (e.g. all weights zero): fall back to a plain average.
        for m in members:
            m["weight"] = 1.0 / len(members)
    else:
        for m in members:
            m["weight"] /= weight_sum


def load_ensemble():
    """Build, load and prepare every ensemble member listed in the config.

    For each entry of ``CFG["members"]`` the architecture is constructed,
    its checkpoint (resolved relative to ``ROOT``) is loaded with
    ``strict=True``, and the model is moved to ``DEVICE`` in eval mode.

    Returns:
        List of dicts with keys ``"model"``, ``"display_name"``, ``"seed"``,
        ``"weight"`` (normalised so all weights sum to 1) and
        ``"temperature"`` (clamped to at least 1e-8 because ``predict``
        divides logits by it).

    Raises:
        ValueError: If the config lists no members, or a member names an
            unsupported architecture.
    """
    members_cfg = CFG["members"]
    if not members_cfg:
        # An empty ensemble would make predict() report class 0 with
        # confidence 0 for every image; refuse to start instead.
        raise ValueError(f"{CONFIG_PATH.name} defines no ensemble members")

    loaded = []
    for member in members_cfg:
        model_name = member["model"]
        ckpt_path = ROOT / member["checkpoint_file"]
        model = build_model(model_name, NUM_CLASSES)
        state = load_state_dict_safely(ckpt_path)
        model.load_state_dict(state, strict=True)
        model.to(DEVICE)
        model.eval()
        loaded.append({
            "model": model,
            "display_name": member.get("display_name", model_name),
            "seed": member.get("seed"),
            "weight": float(member["weight"]),
            "temperature": max(float(member["temperature"]), 1e-8),
        })

    _normalize_member_weights(loaded)
    return loaded
# Load every ensemble member once at import so all requests share them.
ENSEMBLE = load_ensemble()

# Inference preprocessing: resize to IMAGE_SIZE x IMAGE_SIZE, convert to a
# tensor, then normalise with the mean/std given in the config.
PREPROCESS = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            CFG["preprocessing"]["normalization_mean"],
            CFG["preprocessing"]["normalization_std"],
        ),
    ]
)
def _describe_member(member, probs):
    """Return a one-line summary of a single member's top-1 prediction."""
    top_prob, top_idx = torch.max(probs, dim=1)
    name = member["display_name"]
    # Only mention the seed when the config supplied one; otherwise the
    # line would literally read "seed None".
    if member["seed"] is not None:
        name = f"{name} seed {member['seed']}"
    return (
        f"{name}: "
        f"{CLASS_NAMES[int(top_idx.item())]} ({float(top_prob.item()):.4f})"
    )


@torch.no_grad()
def predict(image):
    """Classify one uploaded image with the weighted, temperature-scaled ensemble.

    Args:
        image: A PIL image (as produced by the ``gr.Image(type="pil")``
            input), an array accepted by ``Image.fromarray``, or ``None``
            when nothing was uploaded.

    Returns:
        ``(label_scores, details)``. ``label_scores`` maps every class name
        to its ensemble probability (``None`` when no image was given);
        ``details`` is a human-readable summary including each member's
        top-1 prediction and a usage disclaimer.
    """
    if image is None:
        return None, "Please upload an MRI image."
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    x = PREPROCESS(image).unsqueeze(0).to(DEVICE)

    # Weighted sum of each member's temperature-scaled softmax; member
    # weights were normalised to sum to 1 when the ensemble was loaded.
    final_probs = torch.zeros((1, NUM_CLASSES), dtype=torch.float32, device=DEVICE)
    member_lines = []
    for member in ENSEMBLE:
        logits = member["model"](x)
        probs = F.softmax(logits / member["temperature"], dim=1)
        final_probs += member["weight"] * probs
        member_lines.append(_describe_member(member, probs))

    final_probs_np = final_probs.squeeze(0).detach().cpu().numpy()
    pred_idx = int(np.argmax(final_probs_np))
    pred_class = CLASS_NAMES[pred_idx]
    pred_conf = float(final_probs_np[pred_idx])
    label_scores = {
        CLASS_NAMES[i]: float(final_probs_np[i])
        for i in range(NUM_CLASSES)
    }
    details = (
        f"Predicted class: {pred_class}\n"
        f"Calibrated ensemble confidence: {pred_conf:.4f}\n\n"
        "Member predictions:\n"
        + "\n".join(member_lines)
        + "\n\nDisclaimer: This tool is for research and educational demonstration only. "
        "It is not a medical device and must not be used for diagnosis or treatment."
    )
    return label_scores, details
# Web UI: one image input, mapped to a class-probability label and a
# free-text details box.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload MRI image"),
    outputs=[
        gr.Label(
            label="Calibrated ensemble probabilities",
            num_top_classes=NUM_CLASSES,
        ),
        gr.Textbox(label="Prediction details"),
    ],
    flagging_mode="never",  # disable Gradio's flagging feature
    title="LCVC-Ensemble Brain Tumor MRI Classifier",
    description=(
        "Leakage-Controlled Validation-Calibrated Multi-Backbone Ensemble. "
        "Research demonstration only; "
        "not for clinical use."
    ),
)


if __name__ == "__main__":
    # Start the server only when executed directly, not on import.
    demo.launch()