AD-Stage-Net / app.py
jungukhur's picture
Update app.py
7f28c1e verified
raw
history blame
2.72 kB
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
class HybridModel(nn.Module):
    """Feature-level fusion of several torchvision CNN backbones.

    Each backbone is built without pretrained weights and its final
    classification layer is swapped for ``nn.Identity`` so it emits its
    penultimate feature vector. The vectors from all backbones are
    concatenated and fed through a small MLP head.

    Args:
        backbones: iterable of backbone names; only ``"ResNet152"`` and
            ``"DenseNet201"`` are supported.
        num_classes: number of logits produced by the head.

    Raises:
        ValueError: if any name in ``backbones`` is not supported.
    """

    def __init__(self, backbones, num_classes):
        super().__init__()
        # Attribute names ("models", "classifier") and the head's layer
        # order define the state_dict keys, so they must stay as-is to load
        # the saved checkpoint.
        self.models = nn.ModuleList()
        self.out_features = 0
        for backbone_name in backbones:
            net, width = self._build_backbone(backbone_name)
            self.models.append(net)
            self.out_features += width

        # For ResNet152 + DenseNet201 the fused width is 2048 + 1920 = 3968,
        # giving the checkpoint's 3968 -> 1024 -> num_classes head.
        hidden = 1024
        self.classifier = nn.Sequential(
            nn.Linear(self.out_features, hidden),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden, num_classes),
        )

    @staticmethod
    def _build_backbone(name):
        """Return ``(headless_backbone, feature_width)`` for a supported name."""
        # Each supported model pairs a torchvision factory with the
        # attribute that holds its classification head.
        if name == "ResNet152":
            factory, head_attr = models.resnet152, "fc"
        elif name == "DenseNet201":
            factory, head_attr = models.densenet201, "classifier"
        else:
            raise ValueError(f"Backbone {name} not supported.")
        net = factory(weights=None)
        width = getattr(net, head_attr).in_features
        # Drop the original head so the backbone outputs raw features.
        setattr(net, head_attr, nn.Identity())
        return net, width

    def forward(self, x):
        """Run every backbone on ``x`` and classify the concatenated features.

        ``x`` is presumably a (batch, 3, H, W) image tensor; the backbone
        outputs are concatenated along dim 1 (the feature axis).
        """
        fused = torch.cat([backbone(x) for backbone in self.models], dim=1)
        return self.classifier(fused)
# --------- Load Model ----------
# Checkpoint holding the fused backbone + head weights. The path is relative,
# so it is resolved against the current working directory.
MODEL_PATH = "ResNet152_DenseNet201_best.pt"
device = torch.device("cpu")  # inference is pinned to CPU
num_classes = 4
backbones = ["ResNet152", "DenseNet201"]

# Build the architecture, then replace its random initialisation with the
# trained weights. load_state_dict is strict by default, so any key or shape
# mismatch with the checkpoint raises instead of silently loading partially.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model = HybridModel(backbones, num_classes)
state_dict = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(state_dict)
# Module.to() and Module.eval() both return the module itself.
model = model.to(device).eval()
# --------- Define Preprocessing ----------
# Resize to a fixed 224x224 (aspect ratio is not preserved), convert to a
# float tensor, then standardise each channel with the ImageNet mean/std.
# Presumably this mirrors the training-time pipeline — confirm.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Class labels, looked up by the index of the model's largest logit.
# NOTE(review): this order is neither alphabetical (the order ImageFolder
# would assign) nor ordinal by severity ("Moderate" precedes "Mild"); confirm
# it matches the label mapping used when the checkpoint was trained.
class_names = [
    'No Impairment',
    'Very Mild Impairment',
    'Moderate Impairment',
    'Mild Impairment',
]
# --------- Prediction Function ----------
def predict(image):
    """Classify one MRI image and return the predicted stage label.

    Args:
        image: a PIL image (as delivered by ``gr.Image(type="pil")``).

    Returns:
        The entry of ``class_names`` at the index of the largest logit.

    Raises:
        ValueError: if no image was supplied.
    """
    if image is None:
        raise ValueError("No image provided; please upload an MRI scan.")
    # The backbones expect 3-channel input, and Normalize carries 3-channel
    # statistics, so grayscale ("L") or RGBA images would fail downstream.
    # Convert defensively; this is a no-op path for images already in RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")
    batch = transform(image).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        logits = model(batch)
    # argmax, like torch.max, returns the first index on ties.
    predicted = logits.argmax(dim=1)
    return class_names[predicted.item()]
# --------- Gradio Interface ----------
# Input: an uploaded image handed to `predict` as a PIL image.
_mri_input = gr.Image(type="pil", label="Upload MRI Scan")
# NOTE(review): gr.Label's num_top_classes only matters when the value is a
# dict of class -> confidence; `predict` returns a single label string, so
# only that one label is shown.
_stage_output = gr.Label(num_top_classes=4, label="Predicted Alzheimer’s Stage")

iface = gr.Interface(
    fn=predict,
    inputs=_mri_input,
    outputs=_stage_output,
    title="Alzheimer’s MRI Classifier",
    description="Upload an MRI brain scan to classify into one of four stages of Alzheimer's disease.",
)

if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    iface.launch()