katelynhur commited on
Commit
86d3ae4
·
verified ·
1 Parent(s): e7dc2ab

Upload 3 files

Browse files
Files changed (3) hide show
  1. ResNet152_DenseNet201_best.pt +3 -0
  2. app.py +44 -0
  3. requirements.txt +4 -0
ResNet152_DenseNet201_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf755020938d8ae9bb3e61ab2acf960977a78474b6c3a76148bf8e701d5acb44
3
+ size 323570735
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+ # --------- Load Model ----------
8
+ MODEL_PATH = "ResNet152_DenseNet201_best.pt"
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ # Load your trained model
12
+ model = torch.load(MODEL_PATH, map_location=device)
13
+ model.eval()
14
+
15
+ # --------- Define Preprocessing ----------
16
+ transform = transforms.Compose([
17
+ transforms.Resize((224, 224)), # match training size
18
+ transforms.ToTensor(),
19
+ transforms.Normalize([0.485, 0.456, 0.406],
20
+ [0.229, 0.224, 0.225]) # standard ImageNet normalization
21
+ ])
22
+
23
+ # Class labels (adjust if yours are different)
24
+ class_names = ["No Alzheimer’s", "Very Mild", "Mild", "Moderate"]
25
+
26
+ # --------- Prediction Function ----------
27
+ def predict(image):
28
+ image = transform(image).unsqueeze(0).to(device)
29
+ with torch.no_grad():
30
+ outputs = model(image)
31
+ _, predicted = torch.max(outputs, 1)
32
+ return class_names[predicted.item()]
33
+
34
+ # --------- Gradio Interface ----------
35
+ iface = gr.Interface(
36
+ fn=predict,
37
+ inputs=gr.Image(type="pil", label="Upload MRI Scan"),
38
+ outputs=gr.Label(num_top_classes=4, label="Predicted Alzheimer’s Stage"),
39
+ title="Alzheimer’s MRI Classifier",
40
+ description="Upload an MRI brain scan to classify into one of four stages of Alzheimer's disease."
41
+ )
42
+
43
+ if __name__ == "__main__":
44
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow