"""Gradio demo that serves a weighted, temperature-scaled ensemble classifier.

At import time the module reads ``ensemble_config.json`` (located next to this
file), builds every ensemble member listed there, loads its checkpoint and
prepares the image preprocessing pipeline.  ``predict`` is the Gradio callback.
"""

import json
from pathlib import Path

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms

ROOT = Path(__file__).resolve().parent
CONFIG_PATH = ROOT / "ensemble_config.json"

# JSON text is UTF-8; pass the encoding explicitly so decoding does not depend
# on the platform's locale (non-ASCII names would otherwise break on some OSes).
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    CFG = json.load(f)

CLASS_NAMES = CFG["classes"]
NUM_CLASSES = int(CFG["num_classes"])
IMAGE_SIZE = int(CFG["image_size"])
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Backbones whose output layer is the last element of ``model.classifier``.
_INDEXED_HEAD_MODELS = (
    "vgg16_bn",
    "efficientnet_b0",
    "mobilenet_v3_small",
    "convnext_tiny",
)
# Backbones whose ``model.classifier`` attribute is itself the output layer.
_DIRECT_HEAD_MODELS = ("densenet121",)
_SUPPORTED_MODELS = _INDEXED_HEAD_MODELS + _DIRECT_HEAD_MODELS


def replace_classifier(model, model_name, num_classes):
    """Swap the output layer of *model* for a fresh ``nn.Linear``.

    Args:
        model: A torchvision model instance matching *model_name*.
        model_name: One of the supported backbone names.
        num_classes: Number of output features for the new linear layer.

    Returns:
        The same *model* object, modified in place.

    Raises:
        ValueError: If *model_name* is not a supported backbone.
    """
    if model_name in _INDEXED_HEAD_MODELS:
        in_features = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(in_features, num_classes)
    elif model_name in _DIRECT_HEAD_MODELS:
        in_features = model.classifier.in_features
        model.classifier = nn.Linear(in_features, num_classes)
    else:
        raise ValueError(f"Unknown model name: {model_name}")
    return model


def build_model(model_name, num_classes):
    """Create an uninitialised backbone with a *num_classes*-way output layer.

    Args:
        model_name: One of the supported backbone names; it is also the name
            of the constructor looked up on ``torchvision.models``.
        num_classes: Number of output features for the classifier head.

    Returns:
        The constructed model (``weights=None``, i.e. no pretrained weights).

    Raises:
        ValueError: If *model_name* is not a supported backbone.
    """
    if model_name not in _SUPPORTED_MODELS:
        raise ValueError(f"Unknown model name: {model_name}")
    # Resolve the constructor lazily so that only requested backbones need to
    # exist in the installed torchvision version.
    constructor = getattr(models, model_name)
    model = constructor(weights=None)
    return replace_classifier(model, model_name, num_classes)


def load_state_dict_safely(path):
    """Load a checkpoint from *path* and return a cleaned state dict.

    The checkpoint may be a bare state dict or a dict wrapping one under
    ``"model_state_dict"``, ``"state_dict"`` or ``"model"`` (first match wins).
    A leading ``"module."`` prefix is stripped from every key.

    Args:
        path: Checkpoint location accepted by ``torch.load``.

    Returns:
        A new ``dict`` mapping parameter names to the loaded values.
    """
    try:
        ckpt = torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:
        # Fallback for torch versions whose ``torch.load`` does not accept
        # ``weights_only``.
        # SECURITY NOTE: this path performs full unpickling; only load
        # checkpoints from trusted sources.
        ckpt = torch.load(path, map_location="cpu")

    if isinstance(ckpt, dict):
        for key in ["model_state_dict", "state_dict", "model"]:
            if key in ckpt and isinstance(ckpt[key], dict):
                ckpt = ckpt[key]
                break

    cleaned = {}
    for k, v in ckpt.items():
        # 7 == len("module.")
        nk = k[7:] if str(k).startswith("module.") else k
        cleaned[nk] = v
    return cleaned


def load_ensemble():
    """Build and load every ensemble member described in ``CFG["members"]``.

    Each member's model is built, loaded strictly from its checkpoint, moved
    to ``DEVICE`` and put in eval mode.  Weights are normalised to sum to 1;
    if their sum is not positive, uniform weights are used instead.
    Temperatures are clamped to at least ``1e-8``.

    Returns:
        A list of dicts with keys ``"model"``, ``"display_name"``, ``"seed"``,
        ``"weight"`` and ``"temperature"``.

    Raises:
        ValueError: If the configuration lists no ensemble members.
    """
    loaded = []
    for member in CFG["members"]:
        model_name = member["model"]
        ckpt_path = ROOT / member["checkpoint_file"]

        model = build_model(model_name, NUM_CLASSES)
        state = load_state_dict_safely(ckpt_path)
        model.load_state_dict(state, strict=True)
        model.to(DEVICE)
        model.eval()

        loaded.append({
            "model": model,
            "display_name": member.get("display_name", model_name),
            "seed": member.get("seed"),
            "weight": float(member["weight"]),
            # Guard against division by zero in the temperature scaling.
            "temperature": max(float(member["temperature"]), 1e-8),
        })

    if not loaded:
        # Without members every "probability" would be 0 and class 0 would be
        # reported as the prediction; fail loudly at startup instead.
        raise ValueError(
            f"No ensemble members configured in {CONFIG_PATH}"
        )

    weight_sum = sum(m["weight"] for m in loaded)
    if weight_sum <= 0:
        for m in loaded:
            m["weight"] = 1.0 / len(loaded)
    else:
        for m in loaded:
            m["weight"] /= weight_sum
    return loaded


ENSEMBLE = load_ensemble()

PREPROCESS = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=CFG["preprocessing"]["normalization_mean"],
        std=CFG["preprocessing"]["normalization_std"],
    ),
])


@torch.no_grad()
def predict(image):
    """Classify one image with the weighted ensemble (Gradio callback).

    Args:
        image: A ``PIL.Image.Image``, an array accepted by
            ``Image.fromarray``, or ``None``.

    Returns:
        A ``(label_scores, details)`` tuple.  ``label_scores`` maps each class
        name to its weighted ensemble probability and ``details`` is a
        human-readable summary.  When *image* is ``None`` the tuple is
        ``(None, "Please upload an MRI image.")``.
    """
    if image is None:
        return None, "Please upload an MRI image."

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    # Add the batch dimension expected by the models.
    x = PREPROCESS(image).unsqueeze(0).to(DEVICE)

    final_probs = torch.zeros(
        (1, NUM_CLASSES), dtype=torch.float32, device=DEVICE
    )
    member_lines = []
    for member in ENSEMBLE:
        logits = member["model"](x)
        # Temperature-scale the logits before the softmax.
        probs = F.softmax(logits / member["temperature"], dim=1)
        final_probs += member["weight"] * probs

        top_prob, top_idx = torch.max(probs, dim=1)
        member_lines.append(
            f"{member['display_name']} seed {member['seed']}: "
            f"{CLASS_NAMES[int(top_idx.item())]} ({float(top_prob.item()):.4f})"
        )

    final_probs_np = final_probs.squeeze(0).detach().cpu().numpy()
    pred_idx = int(np.argmax(final_probs_np))
    pred_class = CLASS_NAMES[pred_idx]
    pred_conf = float(final_probs_np[pred_idx])

    label_scores = {
        CLASS_NAMES[i]: float(final_probs_np[i]) for i in range(NUM_CLASSES)
    }
    details = (
        f"Predicted class: {pred_class}\n"
        f"Calibrated ensemble confidence: {pred_conf:.4f}\n\n"
        "Member predictions:\n"
        + "\n".join(member_lines)
        + "\n\nDisclaimer: This tool is for research and educational demonstration only. "
        "It is not a medical device and must not be used for diagnosis or treatment."
    )
    return label_scores, details


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload MRI image"),
    outputs=[
        gr.Label(
            num_top_classes=NUM_CLASSES,
            label="Calibrated ensemble probabilities",
        ),
        gr.Textbox(label="Prediction details"),
    ],
    title="LCVC-Ensemble Brain Tumor MRI Classifier",
    description=(
        "Leakage-Controlled Validation-Calibrated Multi-Backbone Ensemble. "
        "Research demonstration only; not for clinical use."
    ),
    flagging_mode="never",
)

if __name__ == "__main__":
    demo.launch()