"""Gradio app that classifies a brain MRI scan into one of four Alzheimer's stages.

A hybrid model concatenates features from ResNet152 and DenseNet201 and feeds
them to a small MLP head. Weights are loaded from a local checkpoint at import
time and inference runs on the CPU.
"""

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr


class HybridModel(nn.Module):
    """Concatenate the features of several backbones and classify them.

    Args:
        backbones: Backbone names to instantiate, in order. Supported values are
            ``"ResNet152"`` and ``"DenseNet201"``. The order fixes the layout of
            the concatenated feature vector, so it must match the checkpoint.
        num_classes: Number of output logits.

    Raises:
        ValueError: If an unsupported backbone name is given.
    """

    def __init__(self, backbones, num_classes):
        super().__init__()
        self.models = nn.ModuleList()
        self.out_features = 0  # width of the concatenated feature vector
        for name in backbones:
            if name == "ResNet152":
                model = models.resnet152(weights=None)
                in_features = model.fc.in_features
                # Replace the head so the backbone returns the features that fed it.
                model.fc = nn.Identity()
            elif name == "DenseNet201":
                model = models.densenet201(weights=None)
                in_features = model.classifier.in_features
                model.classifier = nn.Identity()
            else:
                raise ValueError(f"Backbone {name} not supported.")
            self.models.append(model)
            self.out_features += in_features

        # Head layout must match the checkpoint: 3968 → 1024 → num_classes.
        self.classifier = nn.Sequential(
            nn.Linear(self.out_features, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        """Return class logits for the image batch ``x``.

        Assumes ``x`` is an (N, 3, H, W) batch preprocessed like ``transform``
        below (TODO confirm for other callers).
        """
        features = [m(x) for m in self.models]
        combined = torch.cat(features, dim=1)
        return self.classifier(combined)


# --------- Load Model ----------
MODEL_PATH = "ResNet152_DenseNet201_best.pt"
device = torch.device("cpu")  # force CPU
num_classes = 4
backbones = ["ResNet152", "DenseNet201"]

model = HybridModel(backbones, num_classes)
# NOTE(review): torch.load unpickles the file; only ship checkpoints you trust.
state_dict = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# --------- Define Preprocessing ----------
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Class labels, indexed by the model's output position.
# NOTE(review): this order is not alphabetical (torchvision's ImageFolder would
# sort class folders alphabetically). Confirm it matches the label indices used
# when the checkpoint was trained.
class_names = ['No Impairment', 'Very Mild Impairment', 'Moderate Impairment', 'Mild Impairment']


# --------- Prediction Functions ----------
def _to_batch(image):
    """Turn an uploaded PIL image into a normalized 1-image batch on ``device``.

    The image is forced to RGB first. Grayscale ("L") or RGBA uploads would
    otherwise produce a 1- or 4-channel tensor, which the 3-channel
    ``Normalize`` and the backbones cannot accept.

    Raises:
        gr.Error: If no image was uploaded (Gradio passes ``None``).
    """
    if image is None:
        raise gr.Error("Please upload an MRI scan.")
    return transform(image.convert("RGB")).unsqueeze(0).to(device)


def predict_proba(image):
    """Return a ``{class_name: probability}`` dict for one uploaded image.

    This is the format ``gr.Label`` needs to display per-class confidences.
    """
    batch = _to_batch(image)
    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)[0].tolist()
    return dict(zip(class_names, probs))


def predict(image):
    """Return the name of the most likely class for one uploaded image."""
    batch = _to_batch(image)
    with torch.no_grad():
        outputs = model(batch)
    _, predicted = torch.max(outputs, 1)
    return class_names[predicted.item()]


# --------- Gradio Interface ----------
# predict_proba feeds the Label component so num_top_classes=4 actually shows
# a confidence for every stage rather than a single bare string.
iface = gr.Interface(
    fn=predict_proba,
    inputs=gr.Image(type="pil", label="Upload MRI Scan"),
    outputs=gr.Label(num_top_classes=4, label="Predicted Alzheimer’s Stage"),
    title="Alzheimer’s MRI Classifier",
    description="Upload an MRI brain scan to classify into one of four stages of Alzheimer's disease.",
)

if __name__ == "__main__":
    iface.launch()