Spaces:
Sleeping
Sleeping
| import json | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from torchvision import models, transforms | |
# Resolve paths relative to this file so the app does not depend on the CWD.
ROOT = Path(__file__).resolve().parent
CONFIG_PATH = ROOT / "ensemble_config.json"

# JSON text is UTF-8; do not rely on the platform's locale-dependent default.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    CFG = json.load(f)

CLASS_NAMES = CFG["classes"]
NUM_CLASSES = int(CFG["num_classes"])
IMAGE_SIZE = int(CFG["image_size"])

# predict() maps model output positions 0..NUM_CLASSES-1 onto CLASS_NAMES, so a
# mismatch would either crash at request time (too few names) or leave the
# class list and the model heads disagreeing (too many). Fail fast instead.
if len(CLASS_NAMES) != NUM_CLASSES:
    raise ValueError(
        f"{CONFIG_PATH}: 'classes' has {len(CLASS_NAMES)} entries "
        f"but 'num_classes' is {NUM_CLASSES}"
    )

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def replace_classifier(model, model_name, num_classes):
    """Give *model* a freshly initialised output layer producing *num_classes* logits.

    Which layer gets swapped depends on the architecture. DenseNet-121
    exposes its head directly as ``model.classifier``. The other supported
    backbones keep it as the last entry of the ``model.classifier``
    container. The replacement layer keeps the old layer's ``in_features``.

    Args:
        model: Backbone to modify in place.
        model_name: One of ``"vgg16_bn"``, ``"densenet121"``,
            ``"efficientnet_b0"``, ``"mobilenet_v3_small"``, ``"convnext_tiny"``.
        num_classes: Number of output units of the new layer.

    Returns:
        The same *model* object, so calls can be chained.

    Raises:
        ValueError: If *model_name* is not a supported architecture.
    """
    # Architectures whose head is the last element of an indexable classifier.
    indexed_heads = (
        "vgg16_bn",
        "efficientnet_b0",
        "mobilenet_v3_small",
        "convnext_tiny",
    )

    if model_name == "densenet121":
        fan_in = model.classifier.in_features
        model.classifier = nn.Linear(fan_in, num_classes)
    elif model_name in indexed_heads:
        fan_in = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(fan_in, num_classes)
    else:
        raise ValueError(f"Unknown model name: {model_name}")
    return model
def build_model(model_name, num_classes):
    """Instantiate an untrained torchvision backbone with a *num_classes* head.

    No pretrained weights are requested (``weights=None``). In this file,
    ``load_ensemble`` fills in the parameters from a checkpoint right after
    construction.

    Args:
        model_name: Name of a supported torchvision classification model.
        num_classes: Number of outputs for the replaced classifier layer.

    Returns:
        The constructed model, with its classifier replaced via
        ``replace_classifier``.

    Raises:
        ValueError: If *model_name* is not a supported architecture.
    """
    supported = (
        "vgg16_bn",
        "densenet121",
        "efficientnet_b0",
        "mobilenet_v3_small",
        "convnext_tiny",
    )
    if model_name not in supported:
        raise ValueError(f"Unknown model name: {model_name}")

    # Every supported name is also the torchvision constructor's name.
    constructor = getattr(models, model_name)
    backbone = constructor(weights=None)
    return replace_classifier(backbone, model_name, num_classes)
def load_state_dict_safely(path):
    """Load a checkpoint file and return a plain ``{name: value}`` state dict.

    The file may hold either a bare state dict or a dict that wraps it under
    ``"model_state_dict"``, ``"state_dict"`` or ``"model"``. The first of
    those keys whose value is a dict wins. A leading ``"module."`` on any
    key is stripped; that prefix is commonly left behind when a model was
    trained wrapped in ``nn.DataParallel`` / ``DistributedDataParallel``.

    Args:
        path: Filesystem path of the checkpoint.

    Returns:
        A new dict mapping (prefix-stripped) keys to the stored values.

    Raises:
        TypeError: If the loaded object, after unwrapping, is not a dict.
    """
    try:
        # Restricted unpickler: only tensors and plain containers are accepted.
        ckpt = torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:
        # Older torch releases reject the ``weights_only`` keyword.
        # NOTE(review): this fallback does unrestricted unpickling, so it is
        # only safe for trusted checkpoint files (here: files shipped with the app).
        ckpt = torch.load(path, map_location="cpu")

    if isinstance(ckpt, dict):
        for key in ("model_state_dict", "state_dict", "model"):
            nested = ckpt.get(key)
            if isinstance(nested, dict):
                ckpt = nested
                break

    if not isinstance(ckpt, dict):
        # Previously surfaced as an opaque AttributeError on ``.items()``.
        raise TypeError(
            f"Checkpoint {path} does not contain a state dict "
            f"(got {type(ckpt).__name__})"
        )

    prefix = "module."
    return {
        (k[len(prefix):] if isinstance(k, str) and k.startswith(prefix) else k): v
        for k, v in ckpt.items()
    }
def load_ensemble():
    """Build, load and prepare every ensemble member listed in ``CFG["members"]``.

    Each member entry must provide:

    - ``"model"``: architecture name.
    - ``"checkpoint_file"``: path relative to ``ROOT``.
    - ``"weight"`` and ``"temperature"``.

    ``"display_name"`` and ``"seed"`` are optional. Checkpoints are loaded
    with ``strict=True``, then each model is moved to ``DEVICE`` and switched
    to eval mode.

    Weights are renormalised to sum to 1. If every weight is zero, each
    member gets an equal share instead.

    Returns:
        A list of dicts with keys ``model``, ``display_name``, ``seed``,
        ``weight`` (normalised) and ``temperature`` (clamped to >= 1e-8).

    Raises:
        ValueError: If no members are configured, or a weight is negative,
            NaN or infinite.
    """
    members = CFG["members"]
    if not members:
        # An empty ensemble would make predict() return all-zero scores.
        raise ValueError(f"No ensemble members configured in {CONFIG_PATH}")

    loaded = []
    for member in members:
        model_name = member["model"]
        weight = float(member["weight"])
        # One comparison chain rejects negatives, infinity and NaN
        # (NaN fails every ordered comparison).
        if not (0.0 <= weight < float("inf")):
            raise ValueError(
                f"Ensemble member {model_name!r} has invalid weight {weight}; "
                "weights must be finite and non-negative"
            )

        model = build_model(model_name, NUM_CLASSES)
        state = load_state_dict_safely(ROOT / member["checkpoint_file"])
        model.load_state_dict(state, strict=True)
        model.to(DEVICE)
        model.eval()
        loaded.append({
            "model": model,
            "display_name": member.get("display_name", model_name),
            "seed": member.get("seed"),
            "weight": weight,
            # Clamp so dividing logits by the temperature never divides by <= 0.
            "temperature": max(float(member["temperature"]), 1e-8),
        })

    weight_sum = sum(m["weight"] for m in loaded)
    if weight_sum <= 0:
        # All weights are zero: fall back to a plain average.
        for m in loaded:
            m["weight"] = 1.0 / len(loaded)
    else:
        for m in loaded:
            m["weight"] /= weight_sum
    return loaded
# Every ensemble member is built and loaded once, when the module is imported.
ENSEMBLE = load_ensemble()


def _build_preprocess():
    """Return the inference transform: square resize, tensor conversion, normalisation.

    The resize target is ``IMAGE_SIZE`` x ``IMAGE_SIZE``, so the aspect
    ratio is not preserved. Mean and std come from the
    ``"preprocessing"`` section of the config.
    """
    norm_cfg = CFG["preprocessing"]
    steps = [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=norm_cfg["normalization_mean"],
            std=norm_cfg["normalization_std"],
        ),
    ]
    return transforms.Compose(steps)


PREPROCESS = _build_preprocess()
def predict(image):
    """Classify one MRI image with the temperature-calibrated ensemble.

    Each member's logits are divided by that member's temperature and
    softmaxed. The resulting probabilities are summed using the member
    weights stored in ``ENSEMBLE``.

    Args:
        image: A PIL image (what the ``gr.Image(type="pil")`` input sends),
            an array accepted by ``Image.fromarray``, or ``None``.

    Returns:
        A ``(label_scores, details)`` pair:

        - ``label_scores`` maps every class name to its ensemble
          probability. It is ``None`` when no image was given.
        - ``details`` is a human-readable summary that includes each
          member's top-1 prediction.
    """
    if image is None:
        return None, "Please upload an MRI image."
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    # Add a batch dimension of 1 and move the input to the models' device.
    x = PREPROCESS(image).unsqueeze(0).to(DEVICE)

    final_probs = torch.zeros((1, NUM_CLASSES), dtype=torch.float32, device=DEVICE)
    member_lines = []
    # Pure inference: disable autograd so no graph is recorded per request.
    with torch.no_grad():
        for member in ENSEMBLE:
            logits = member["model"](x)
            probs = F.softmax(logits / member["temperature"], dim=1)
            final_probs += member["weight"] * probs
            top_prob, top_idx = torch.max(probs, dim=1)
            member_lines.append(
                f"{member['display_name']} seed {member['seed']}: "
                f"{CLASS_NAMES[int(top_idx.item())]} ({float(top_prob.item()):.4f})"
            )

    final_probs_np = final_probs.squeeze(0).cpu().numpy()
    pred_idx = int(np.argmax(final_probs_np))
    pred_class = CLASS_NAMES[pred_idx]
    pred_conf = float(final_probs_np[pred_idx])
    label_scores = {
        CLASS_NAMES[i]: float(final_probs_np[i])
        for i in range(NUM_CLASSES)
    }
    details = (
        f"Predicted class: {pred_class}\n"
        f"Calibrated ensemble confidence: {pred_conf:.4f}\n\n"
        "Member predictions:\n"
        + "\n".join(member_lines)
        + "\n\nDisclaimer: This tool is for research and educational demonstration only. "
        "It is not a medical device and must not be used for diagnosis or treatment."
    )
    return label_scores, details
def _build_demo():
    """Assemble the Gradio interface that wraps ``predict``.

    The UI has one PIL image input and two outputs: a label widget showing
    every class's probability, and a textbox with the prediction details.
    Flagging is disabled.
    """
    image_input = gr.Image(type="pil", label="Upload MRI image")
    probability_output = gr.Label(
        num_top_classes=NUM_CLASSES,
        label="Calibrated ensemble probabilities",
    )
    details_output = gr.Textbox(label="Prediction details")
    blurb = (
        "Leakage-Controlled Validation-Calibrated Multi-Backbone Ensemble. "
        "Research demonstration only; not for clinical use."
    )
    return gr.Interface(
        fn=predict,
        inputs=image_input,
        outputs=[probability_output, details_output],
        title="LCVC-Ensemble Brain Tumor MRI Classifier",
        description=blurb,
        flagging_mode="never",
    )


demo = _build_demo()

if __name__ == "__main__":
    demo.launch()