Spaces:
Sleeping
Sleeping
Commit ·
691bb8f
1
Parent(s): 90e8c4e
🚀 Added all models + Gradio app.py + requirements.txt
Browse files
app.py
CHANGED
|
@@ -5,10 +5,10 @@ from torchvision import transforms
|
|
| 5 |
from PIL import Image
|
| 6 |
import google.generativeai as genai
|
| 7 |
|
| 8 |
-
#
|
| 9 |
genai.configure(api_key="AIzaSyDn5yK2_2pYMId3bpFlAf0LkWoJ7dvEcqM")
|
| 10 |
|
| 11 |
-
# ------------------ Model Definitions ------------------
|
| 12 |
|
| 13 |
class AlzheimerCNN(torch.nn.Module):
|
| 14 |
def __init__(self, num_classes=2):
|
|
@@ -18,6 +18,7 @@ class AlzheimerCNN(torch.nn.Module):
|
|
| 18 |
self.pool = torch.nn.MaxPool2d(2, 2)
|
| 19 |
self.fc1 = torch.nn.Linear(32 * 32 * 32, 64)
|
| 20 |
self.fc2 = torch.nn.Linear(64, num_classes)
|
|
|
|
| 21 |
def forward(self, x):
|
| 22 |
x = self.pool(F.relu(self.conv1(x)))
|
| 23 |
x = self.pool(F.relu(self.conv2(x)))
|
|
@@ -34,6 +35,7 @@ class BrainHemorrhageCNN(torch.nn.Module):
|
|
| 34 |
self.dropout = torch.nn.Dropout(0.5)
|
| 35 |
self.fc1 = torch.nn.Linear(32 * 16 * 16, 64)
|
| 36 |
self.fc2 = torch.nn.Linear(64, num_classes)
|
|
|
|
| 37 |
def forward(self, x):
|
| 38 |
x = self.pool(F.relu(self.conv1(x)))
|
| 39 |
x = self.pool(F.relu(self.conv2(x)))
|
|
@@ -50,6 +52,7 @@ class StrokeCTCNN(torch.nn.Module):
|
|
| 50 |
self.pool = torch.nn.MaxPool2d(2, 2)
|
| 51 |
self.fc1 = torch.nn.Linear(32 * 32 * 32, 64)
|
| 52 |
self.fc2 = torch.nn.Linear(64, num_classes)
|
|
|
|
| 53 |
def forward(self, x):
|
| 54 |
x = self.pool(F.relu(self.conv1(x)))
|
| 55 |
x = self.pool(F.relu(self.conv2(x)))
|
|
@@ -65,6 +68,7 @@ class BrainTumorCNN(torch.nn.Module):
|
|
| 65 |
self.pool = torch.nn.MaxPool2d(2, 2)
|
| 66 |
self.fc1 = torch.nn.Linear(32 * 32 * 32, 64)
|
| 67 |
self.fc2 = torch.nn.Linear(64, num_classes)
|
|
|
|
| 68 |
def forward(self, x):
|
| 69 |
x = self.pool(F.relu(self.conv1(x)))
|
| 70 |
x = self.pool(F.relu(self.conv2(x)))
|
|
@@ -72,7 +76,7 @@ class BrainTumorCNN(torch.nn.Module):
|
|
| 72 |
x = F.relu(self.fc1(x))
|
| 73 |
return self.fc2(x)
|
| 74 |
|
| 75 |
-
# ------------------
|
| 76 |
|
| 77 |
transform = transforms.Compose([
|
| 78 |
transforms.Resize((128, 128)),
|
|
@@ -83,34 +87,34 @@ transform = transforms.Compose([
|
|
| 83 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 84 |
|
| 85 |
alz_model = AlzheimerCNN()
|
| 86 |
-
alz_model.load_state_dict(torch.load(
|
| 87 |
alz_model.eval()
|
| 88 |
|
| 89 |
brainh_model = BrainHemorrhageCNN()
|
| 90 |
-
brainh_model.load_state_dict(torch.load(
|
| 91 |
brainh_model.eval()
|
| 92 |
|
| 93 |
brainst_model = StrokeCTCNN()
|
| 94 |
-
brainst_model.load_state_dict(torch.load(
|
| 95 |
brainst_model.eval()
|
| 96 |
|
| 97 |
braint_model = BrainTumorCNN()
|
| 98 |
-
braint_model.load_state_dict(torch.load(
|
| 99 |
braint_model.eval()
|
| 100 |
|
| 101 |
-
# ------------------
|
| 102 |
|
| 103 |
recommendation_dict = {
|
| 104 |
-
"Non Demented": "No signs of Alzheimer's detected. Maintain a healthy brain with mental activities and regular checkups
|
| 105 |
-
"Demented": "Signs of dementia detected. Consult a neurologist for proper diagnosis and treatment options.
|
| 106 |
-
"Normal": "Brain scan appears normal. Stay consistent with health checks.
|
| 107 |
-
"Hemorrhagic": "Hemorrhage detected. Seek immediate medical attention — may require surgery or ICU.
|
| 108 |
-
"Bleeding": "Bleeding stroke identified. Emergency treatment may be necessary.
|
| 109 |
-
"Ischemia": "Ischemic stroke detected. Treatment may include clot-busting medication.
|
| 110 |
-
"Glioma": "Glioma tumor found. Requires MRI follow-up and oncology consultation.
|
| 111 |
-
"Meningioma": "Meningioma detected. Often benign but may need surgical evaluation.
|
| 112 |
-
"No Tumor": "No brain tumor detected. Continue routine monitoring.
|
| 113 |
-
"Pituitary": "Pituitary tumor detected. Hormonal and visual exams recommended.
|
| 114 |
}
|
| 115 |
|
| 116 |
# ------------------ Prediction Function ------------------
|
|
@@ -122,49 +126,50 @@ def predict(disorder, image):
|
|
| 122 |
if disorder == "Alzheimer":
|
| 123 |
outputs = alz_model(img_tensor)
|
| 124 |
class_names = ["Non Demented", "Demented"]
|
| 125 |
-
|
| 126 |
elif disorder == "Brain Hemorrhage":
|
| 127 |
outputs = brainh_model(img_tensor)
|
| 128 |
class_names = ["Normal", "Hemorrhagic"]
|
| 129 |
-
|
| 130 |
elif disorder == "Brain Stroke":
|
| 131 |
outputs = brainst_model(img_tensor)
|
| 132 |
class_names = ["Bleeding", "Ischemia", "Normal"]
|
| 133 |
-
|
| 134 |
elif disorder == "Brain Tumor":
|
| 135 |
outputs = braint_model(img_tensor)
|
| 136 |
class_names = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
|
| 137 |
|
| 138 |
probs = torch.softmax(outputs, dim=1)[0]
|
| 139 |
-
|
| 140 |
-
pred_label = class_names[pred_idx]
|
| 141 |
recommendation = recommendation_dict.get(pred_label, "No recommendation available.")
|
| 142 |
|
| 143 |
-
return f"Prediction: {pred_label}\n\
|
| 144 |
|
| 145 |
-
# ------------------ Gemini Chatbot
|
| 146 |
|
| 147 |
model = genai.GenerativeModel("models/gemini-1.5-flash")
|
| 148 |
|
| 149 |
def chat_with_gemini(user_input, history=[]):
|
| 150 |
-
prompt =
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
convo.send_message(prompt + user_input)
|
| 153 |
return convo.last.text
|
| 154 |
|
| 155 |
-
# ------------------ Gradio
|
| 156 |
|
| 157 |
with gr.Blocks(theme=gr.themes.Base(), title="Brain Neurologist App") as demo:
|
| 158 |
-
with gr.Tab("Brain Scan Predictor"):
|
| 159 |
-
gr.Markdown("##
|
| 160 |
disorder_input = gr.Dropdown(["Alzheimer", "Brain Hemorrhage", "Brain Stroke", "Brain Tumor"], label="Select Disorder")
|
| 161 |
image_input = gr.Image(type="filepath", label="Upload Brain Scan")
|
| 162 |
output_text = gr.Textbox(label="Prediction and Recommendation")
|
| 163 |
submit_btn = gr.Button("Predict")
|
| 164 |
submit_btn.click(fn=predict, inputs=[disorder_input, image_input], outputs=output_text)
|
| 165 |
|
| 166 |
-
with gr.Tab("🧑⚕️ Dr.Neuro"):
|
| 167 |
-
gr.Markdown("##
|
| 168 |
chatbot = gr.ChatInterface(fn=chat_with_gemini)
|
| 169 |
|
| 170 |
-
|
|
|
|
|
|
| 5 |
from PIL import Image
|
| 6 |
import google.generativeai as genai
|
| 7 |
|
| 8 |
+
# ------------------ Gemini API Key (Hardcoded - as requested) ------------------
|
| 9 |
genai.configure(api_key="AIzaSyDn5yK2_2pYMId3bpFlAf0LkWoJ7dvEcqM")
|
| 10 |
|
| 11 |
+
# ------------------ CNN Model Definitions ------------------
|
| 12 |
|
| 13 |
class AlzheimerCNN(torch.nn.Module):
|
| 14 |
def __init__(self, num_classes=2):
|
|
|
|
| 18 |
self.pool = torch.nn.MaxPool2d(2, 2)
|
| 19 |
self.fc1 = torch.nn.Linear(32 * 32 * 32, 64)
|
| 20 |
self.fc2 = torch.nn.Linear(64, num_classes)
|
| 21 |
+
|
| 22 |
def forward(self, x):
|
| 23 |
x = self.pool(F.relu(self.conv1(x)))
|
| 24 |
x = self.pool(F.relu(self.conv2(x)))
|
|
|
|
| 35 |
self.dropout = torch.nn.Dropout(0.5)
|
| 36 |
self.fc1 = torch.nn.Linear(32 * 16 * 16, 64)
|
| 37 |
self.fc2 = torch.nn.Linear(64, num_classes)
|
| 38 |
+
|
| 39 |
def forward(self, x):
|
| 40 |
x = self.pool(F.relu(self.conv1(x)))
|
| 41 |
x = self.pool(F.relu(self.conv2(x)))
|
|
|
|
| 52 |
self.pool = torch.nn.MaxPool2d(2, 2)
|
| 53 |
self.fc1 = torch.nn.Linear(32 * 32 * 32, 64)
|
| 54 |
self.fc2 = torch.nn.Linear(64, num_classes)
|
| 55 |
+
|
| 56 |
def forward(self, x):
|
| 57 |
x = self.pool(F.relu(self.conv1(x)))
|
| 58 |
x = self.pool(F.relu(self.conv2(x)))
|
|
|
|
| 68 |
self.pool = torch.nn.MaxPool2d(2, 2)
|
| 69 |
self.fc1 = torch.nn.Linear(32 * 32 * 32, 64)
|
| 70 |
self.fc2 = torch.nn.Linear(64, num_classes)
|
| 71 |
+
|
| 72 |
def forward(self, x):
|
| 73 |
x = self.pool(F.relu(self.conv1(x)))
|
| 74 |
x = self.pool(F.relu(self.conv2(x)))
|
|
|
|
| 76 |
x = F.relu(self.fc1(x))
|
| 77 |
return self.fc2(x)
|
| 78 |
|
| 79 |
+
# ------------------ Model Loading ------------------
|
| 80 |
|
| 81 |
transform = transforms.Compose([
|
| 82 |
transforms.Resize((128, 128)),
|
|
|
|
| 87 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 88 |
|
| 89 |
alz_model = AlzheimerCNN()
|
| 90 |
+
alz_model.load_state_dict(torch.load("alzmodel.pth", map_location=device))
|
| 91 |
alz_model.eval()
|
| 92 |
|
| 93 |
brainh_model = BrainHemorrhageCNN()
|
| 94 |
+
brainh_model.load_state_dict(torch.load("brainham.pth", map_location=device))
|
| 95 |
brainh_model.eval()
|
| 96 |
|
| 97 |
brainst_model = StrokeCTCNN()
|
| 98 |
+
brainst_model.load_state_dict(torch.load("brainst_model.pth", map_location=device))
|
| 99 |
brainst_model.eval()
|
| 100 |
|
| 101 |
braint_model = BrainTumorCNN()
|
| 102 |
+
braint_model.load_state_dict(torch.load("braintmodel.pth", map_location=device))
|
| 103 |
braint_model.eval()
|
| 104 |
|
| 105 |
+
# ------------------ Recommendations ------------------
|
| 106 |
|
| 107 |
recommendation_dict = {
|
| 108 |
+
"Non Demented": "No signs of Alzheimer's detected. Maintain a healthy brain with mental activities and regular checkups.",
|
| 109 |
+
"Demented": "Signs of dementia detected. Consult a neurologist for proper diagnosis and treatment options.",
|
| 110 |
+
"Normal": "Brain scan appears normal. Stay consistent with health checks.",
|
| 111 |
+
"Hemorrhagic": "Hemorrhage detected. Seek immediate medical attention — may require surgery or ICU.",
|
| 112 |
+
"Bleeding": "Bleeding stroke identified. Emergency treatment may be necessary.",
|
| 113 |
+
"Ischemia": "Ischemic stroke detected. Treatment may include clot-busting medication.",
|
| 114 |
+
"Glioma": "Glioma tumor found. Requires MRI follow-up and oncology consultation.",
|
| 115 |
+
"Meningioma": "Meningioma detected. Often benign but may need surgical evaluation.",
|
| 116 |
+
"No Tumor": "No brain tumor detected. Continue routine monitoring.",
|
| 117 |
+
"Pituitary": "Pituitary tumor detected. Hormonal and visual exams recommended."
|
| 118 |
}
|
| 119 |
|
| 120 |
# ------------------ Prediction Function ------------------
|
|
|
|
| 126 |
if disorder == "Alzheimer":
|
| 127 |
outputs = alz_model(img_tensor)
|
| 128 |
class_names = ["Non Demented", "Demented"]
|
|
|
|
| 129 |
elif disorder == "Brain Hemorrhage":
|
| 130 |
outputs = brainh_model(img_tensor)
|
| 131 |
class_names = ["Normal", "Hemorrhagic"]
|
|
|
|
| 132 |
elif disorder == "Brain Stroke":
|
| 133 |
outputs = brainst_model(img_tensor)
|
| 134 |
class_names = ["Bleeding", "Ischemia", "Normal"]
|
|
|
|
| 135 |
elif disorder == "Brain Tumor":
|
| 136 |
outputs = braint_model(img_tensor)
|
| 137 |
class_names = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
|
| 138 |
|
| 139 |
probs = torch.softmax(outputs, dim=1)[0]
|
| 140 |
+
pred_label = class_names[torch.argmax(probs).item()]
|
|
|
|
| 141 |
recommendation = recommendation_dict.get(pred_label, "No recommendation available.")
|
| 142 |
|
| 143 |
+
return f"🧠 Prediction: {pred_label}\n\n📌 Recommendation: {recommendation}"
|
| 144 |
|
| 145 |
+
# ------------------ Gemini Chatbot ------------------
|
| 146 |
|
| 147 |
model = genai.GenerativeModel("models/gemini-1.5-flash")
|
| 148 |
|
| 149 |
def chat_with_gemini(user_input, history=[]):
|
| 150 |
+
prompt = (
|
| 151 |
+
"You are a highly experienced and friendly neurologist. "
|
| 152 |
+
"Help answer questions related to brain diseases like stroke, dementia, tumors, Alzheimer’s, etc. "
|
| 153 |
+
"Answer briefly and simply so non-medical users can understand easily.\n\n"
|
| 154 |
+
)
|
| 155 |
+
convo = model.start_chat()
|
| 156 |
convo.send_message(prompt + user_input)
|
| 157 |
return convo.last.text
|
| 158 |
|
| 159 |
+
# ------------------ Gradio UI ------------------
|
| 160 |
|
| 161 |
with gr.Blocks(theme=gr.themes.Base(), title="Brain Neurologist App") as demo:
|
| 162 |
+
with gr.Tab("🧠 Brain Scan Predictor"):
|
| 163 |
+
gr.Markdown("## Upload an MRI/CT scan to get brain disorder prediction and medical advice")
|
| 164 |
disorder_input = gr.Dropdown(["Alzheimer", "Brain Hemorrhage", "Brain Stroke", "Brain Tumor"], label="Select Disorder")
|
| 165 |
image_input = gr.Image(type="filepath", label="Upload Brain Scan")
|
| 166 |
output_text = gr.Textbox(label="Prediction and Recommendation")
|
| 167 |
submit_btn = gr.Button("Predict")
|
| 168 |
submit_btn.click(fn=predict, inputs=[disorder_input, image_input], outputs=output_text)
|
| 169 |
|
| 170 |
+
with gr.Tab("🧑⚕️ Dr.Neuro - Gemini Chatbot"):
|
| 171 |
+
gr.Markdown("## Ask neurological questions powered by Gemini")
|
| 172 |
chatbot = gr.ChatInterface(fn=chat_with_gemini)
|
| 173 |
|
| 174 |
+
if __name__ == "__main__":
|
| 175 |
+
demo.launch()
|