sharjeel357 commited on
Commit
691bb8f
·
1 Parent(s): 90e8c4e

🚀 Added all models + Gradio app.py + requirements.txt

Browse files
Files changed (1) hide show
  1. app.py +38 -33
app.py CHANGED
@@ -5,10 +5,10 @@ from torchvision import transforms
5
  from PIL import Image
6
  import google.generativeai as genai
7
 
8
- # Inject API key
9
  genai.configure(api_key="AIzaSyDn5yK2_2pYMId3bpFlAf0LkWoJ7dvEcqM")
10
 
11
- # ------------------ Model Definitions ------------------
12
 
13
  class AlzheimerCNN(torch.nn.Module):
14
  def __init__(self, num_classes=2):
@@ -18,6 +18,7 @@ class AlzheimerCNN(torch.nn.Module):
18
  self.pool = torch.nn.MaxPool2d(2, 2)
19
  self.fc1 = torch.nn.Linear(32 * 32 * 32, 64)
20
  self.fc2 = torch.nn.Linear(64, num_classes)
 
21
  def forward(self, x):
22
  x = self.pool(F.relu(self.conv1(x)))
23
  x = self.pool(F.relu(self.conv2(x)))
@@ -34,6 +35,7 @@ class BrainHemorrhageCNN(torch.nn.Module):
34
  self.dropout = torch.nn.Dropout(0.5)
35
  self.fc1 = torch.nn.Linear(32 * 16 * 16, 64)
36
  self.fc2 = torch.nn.Linear(64, num_classes)
 
37
  def forward(self, x):
38
  x = self.pool(F.relu(self.conv1(x)))
39
  x = self.pool(F.relu(self.conv2(x)))
@@ -50,6 +52,7 @@ class StrokeCTCNN(torch.nn.Module):
50
  self.pool = torch.nn.MaxPool2d(2, 2)
51
  self.fc1 = torch.nn.Linear(32 * 32 * 32, 64)
52
  self.fc2 = torch.nn.Linear(64, num_classes)
 
53
  def forward(self, x):
54
  x = self.pool(F.relu(self.conv1(x)))
55
  x = self.pool(F.relu(self.conv2(x)))
@@ -65,6 +68,7 @@ class BrainTumorCNN(torch.nn.Module):
65
  self.pool = torch.nn.MaxPool2d(2, 2)
66
  self.fc1 = torch.nn.Linear(32 * 32 * 32, 64)
67
  self.fc2 = torch.nn.Linear(64, num_classes)
 
68
  def forward(self, x):
69
  x = self.pool(F.relu(self.conv1(x)))
70
  x = self.pool(F.relu(self.conv2(x)))
@@ -72,7 +76,7 @@ class BrainTumorCNN(torch.nn.Module):
72
  x = F.relu(self.fc1(x))
73
  return self.fc2(x)
74
 
75
- # ------------------ Load Models ------------------
76
 
77
  transform = transforms.Compose([
78
  transforms.Resize((128, 128)),
@@ -83,34 +87,34 @@ transform = transforms.Compose([
83
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
84
 
85
  alz_model = AlzheimerCNN()
86
- alz_model.load_state_dict(torch.load(r"C:\Users\HP\OneDrive\Desktop\trained model\alzmodel.pth", map_location=device))
87
  alz_model.eval()
88
 
89
  brainh_model = BrainHemorrhageCNN()
90
- brainh_model.load_state_dict(torch.load(r"C:\Users\HP\OneDrive\Desktop\trained model\brainham.pth", map_location=device))
91
  brainh_model.eval()
92
 
93
  brainst_model = StrokeCTCNN()
94
- brainst_model.load_state_dict(torch.load(r"C:\Users\HP\OneDrive\Desktop\trained model\brainst_model.pth", map_location=device))
95
  brainst_model.eval()
96
 
97
  braint_model = BrainTumorCNN()
98
- braint_model.load_state_dict(torch.load(r"C:\Users\HP\OneDrive\Desktop\trained model\braintmodel.pth", map_location=device))
99
  braint_model.eval()
100
 
101
- # ------------------ Recommendation Mapping ------------------
102
 
103
  recommendation_dict = {
104
- "Non Demented": "No signs of Alzheimer's detected. Maintain a healthy brain with mental activities and regular checkups(recomend you to use the gemini chatbot).",
105
- "Demented": "Signs of dementia detected. Consult a neurologist for proper diagnosis and treatment options.(recomend you to use the gemini chatbot)",
106
- "Normal": "Brain scan appears normal. Stay consistent with health checks.(recomend you to use the gemini chatbot)",
107
- "Hemorrhagic": "Hemorrhage detected. Seek immediate medical attention — may require surgery or ICU.(recomend you to use the gemini chatbot)",
108
- "Bleeding": "Bleeding stroke identified. Emergency treatment may be necessary.(recomend you to use the gemini chatbot)",
109
- "Ischemia": "Ischemic stroke detected. Treatment may include clot-busting medication.(recomend you to use the gemini chatbot)",
110
- "Glioma": "Glioma tumor found. Requires MRI follow-up and oncology consultation.(recomend you to use the gemini chatbot)",
111
- "Meningioma": "Meningioma detected. Often benign but may need surgical evaluation.(recomend you to use the gemini chatbot)",
112
- "No Tumor": "No brain tumor detected. Continue routine monitoring.(recomend you to use the gemini chatbot)",
113
- "Pituitary": "Pituitary tumor detected. Hormonal and visual exams recommended.(recomend you to use the gemini chatbot)"
114
  }
115
 
116
  # ------------------ Prediction Function ------------------
@@ -122,49 +126,50 @@ def predict(disorder, image):
122
  if disorder == "Alzheimer":
123
  outputs = alz_model(img_tensor)
124
  class_names = ["Non Demented", "Demented"]
125
-
126
  elif disorder == "Brain Hemorrhage":
127
  outputs = brainh_model(img_tensor)
128
  class_names = ["Normal", "Hemorrhagic"]
129
-
130
  elif disorder == "Brain Stroke":
131
  outputs = brainst_model(img_tensor)
132
  class_names = ["Bleeding", "Ischemia", "Normal"]
133
-
134
  elif disorder == "Brain Tumor":
135
  outputs = braint_model(img_tensor)
136
  class_names = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
137
 
138
  probs = torch.softmax(outputs, dim=1)[0]
139
- pred_idx = torch.argmax(probs).item()
140
- pred_label = class_names[pred_idx]
141
  recommendation = recommendation_dict.get(pred_label, "No recommendation available.")
142
 
143
- return f"Prediction: {pred_label}\n\nRecommendation: {recommendation}"
144
 
145
- # ------------------ Gemini Chatbot Function ------------------
146
 
147
  model = genai.GenerativeModel("models/gemini-1.5-flash")
148
 
149
  def chat_with_gemini(user_input, history=[]):
150
- prompt = "You are a highly experienced neurologist. Help answer questions related to brain diseases like stroke, dementia, tumors, Alzheimer’s, etc. and you answer questions short and simple to make a non medical field based person understand it better and easily\n"
151
- convo = model.start_chat(history=[])
 
 
 
 
152
  convo.send_message(prompt + user_input)
153
  return convo.last.text
154
 
155
- # ------------------ Gradio Interface ------------------
156
 
157
  with gr.Blocks(theme=gr.themes.Base(), title="Brain Neurologist App") as demo:
158
- with gr.Tab("Brain Scan Predictor"):
159
- gr.Markdown("## 🧠 MRI/CT Brain Disorder Detection")
160
  disorder_input = gr.Dropdown(["Alzheimer", "Brain Hemorrhage", "Brain Stroke", "Brain Tumor"], label="Select Disorder")
161
  image_input = gr.Image(type="filepath", label="Upload Brain Scan")
162
  output_text = gr.Textbox(label="Prediction and Recommendation")
163
  submit_btn = gr.Button("Predict")
164
  submit_btn.click(fn=predict, inputs=[disorder_input, image_input], outputs=output_text)
165
 
166
- with gr.Tab("🧑‍⚕️ Dr.Neuro"):
167
- gr.Markdown("## Gemini powered Neurologist Chatbot")
168
  chatbot = gr.ChatInterface(fn=chat_with_gemini)
169
 
170
- demo.launch(share=True)
 
 
5
  from PIL import Image
6
  import google.generativeai as genai
7
 
8
+ # ------------------ Gemini API Key (Hardcoded - as requested) ------------------
9
  genai.configure(api_key="AIzaSyDn5yK2_2pYMId3bpFlAf0LkWoJ7dvEcqM")
10
 
11
+ # ------------------ CNN Model Definitions ------------------
12
 
13
  class AlzheimerCNN(torch.nn.Module):
14
  def __init__(self, num_classes=2):
 
18
  self.pool = torch.nn.MaxPool2d(2, 2)
19
  self.fc1 = torch.nn.Linear(32 * 32 * 32, 64)
20
  self.fc2 = torch.nn.Linear(64, num_classes)
21
+
22
  def forward(self, x):
23
  x = self.pool(F.relu(self.conv1(x)))
24
  x = self.pool(F.relu(self.conv2(x)))
 
35
  self.dropout = torch.nn.Dropout(0.5)
36
  self.fc1 = torch.nn.Linear(32 * 16 * 16, 64)
37
  self.fc2 = torch.nn.Linear(64, num_classes)
38
+
39
  def forward(self, x):
40
  x = self.pool(F.relu(self.conv1(x)))
41
  x = self.pool(F.relu(self.conv2(x)))
 
52
  self.pool = torch.nn.MaxPool2d(2, 2)
53
  self.fc1 = torch.nn.Linear(32 * 32 * 32, 64)
54
  self.fc2 = torch.nn.Linear(64, num_classes)
55
+
56
  def forward(self, x):
57
  x = self.pool(F.relu(self.conv1(x)))
58
  x = self.pool(F.relu(self.conv2(x)))
 
68
  self.pool = torch.nn.MaxPool2d(2, 2)
69
  self.fc1 = torch.nn.Linear(32 * 32 * 32, 64)
70
  self.fc2 = torch.nn.Linear(64, num_classes)
71
+
72
  def forward(self, x):
73
  x = self.pool(F.relu(self.conv1(x)))
74
  x = self.pool(F.relu(self.conv2(x)))
 
76
  x = F.relu(self.fc1(x))
77
  return self.fc2(x)
78
 
79
+ # ------------------ Model Loading ------------------
80
 
81
  transform = transforms.Compose([
82
  transforms.Resize((128, 128)),
 
87
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88
 
89
  alz_model = AlzheimerCNN()
90
+ alz_model.load_state_dict(torch.load("alzmodel.pth", map_location=device))
91
  alz_model.eval()
92
 
93
  brainh_model = BrainHemorrhageCNN()
94
+ brainh_model.load_state_dict(torch.load("brainham.pth", map_location=device))
95
  brainh_model.eval()
96
 
97
  brainst_model = StrokeCTCNN()
98
+ brainst_model.load_state_dict(torch.load("brainst_model.pth", map_location=device))
99
  brainst_model.eval()
100
 
101
  braint_model = BrainTumorCNN()
102
+ braint_model.load_state_dict(torch.load("braintmodel.pth", map_location=device))
103
  braint_model.eval()
104
 
105
+ # ------------------ Recommendations ------------------
106
 
107
  recommendation_dict = {
108
+ "Non Demented": "No signs of Alzheimer's detected. Maintain a healthy brain with mental activities and regular checkups.",
109
+ "Demented": "Signs of dementia detected. Consult a neurologist for proper diagnosis and treatment options.",
110
+ "Normal": "Brain scan appears normal. Stay consistent with health checks.",
111
+ "Hemorrhagic": "Hemorrhage detected. Seek immediate medical attention — may require surgery or ICU.",
112
+ "Bleeding": "Bleeding stroke identified. Emergency treatment may be necessary.",
113
+ "Ischemia": "Ischemic stroke detected. Treatment may include clot-busting medication.",
114
+ "Glioma": "Glioma tumor found. Requires MRI follow-up and oncology consultation.",
115
+ "Meningioma": "Meningioma detected. Often benign but may need surgical evaluation.",
116
+ "No Tumor": "No brain tumor detected. Continue routine monitoring.",
117
+ "Pituitary": "Pituitary tumor detected. Hormonal and visual exams recommended."
118
  }
119
 
120
  # ------------------ Prediction Function ------------------
 
126
  if disorder == "Alzheimer":
127
  outputs = alz_model(img_tensor)
128
  class_names = ["Non Demented", "Demented"]
 
129
  elif disorder == "Brain Hemorrhage":
130
  outputs = brainh_model(img_tensor)
131
  class_names = ["Normal", "Hemorrhagic"]
 
132
  elif disorder == "Brain Stroke":
133
  outputs = brainst_model(img_tensor)
134
  class_names = ["Bleeding", "Ischemia", "Normal"]
 
135
  elif disorder == "Brain Tumor":
136
  outputs = braint_model(img_tensor)
137
  class_names = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
138
 
139
  probs = torch.softmax(outputs, dim=1)[0]
140
+ pred_label = class_names[torch.argmax(probs).item()]
 
141
  recommendation = recommendation_dict.get(pred_label, "No recommendation available.")
142
 
143
+ return f"🧠 Prediction: {pred_label}\n\n📌 Recommendation: {recommendation}"
144
 
145
+ # ------------------ Gemini Chatbot ------------------
146
 
147
  model = genai.GenerativeModel("models/gemini-1.5-flash")
148
 
149
  def chat_with_gemini(user_input, history=[]):
150
+ prompt = (
151
+ "You are a highly experienced and friendly neurologist. "
152
+ "Help answer questions related to brain diseases like stroke, dementia, tumors, Alzheimer’s, etc. "
153
+ "Answer briefly and simply so non-medical users can understand easily.\n\n"
154
+ )
155
+ convo = model.start_chat()
156
  convo.send_message(prompt + user_input)
157
  return convo.last.text
158
 
159
+ # ------------------ Gradio UI ------------------
160
 
161
  with gr.Blocks(theme=gr.themes.Base(), title="Brain Neurologist App") as demo:
162
+ with gr.Tab("🧠 Brain Scan Predictor"):
163
+ gr.Markdown("## Upload an MRI/CT scan to get brain disorder prediction and medical advice")
164
  disorder_input = gr.Dropdown(["Alzheimer", "Brain Hemorrhage", "Brain Stroke", "Brain Tumor"], label="Select Disorder")
165
  image_input = gr.Image(type="filepath", label="Upload Brain Scan")
166
  output_text = gr.Textbox(label="Prediction and Recommendation")
167
  submit_btn = gr.Button("Predict")
168
  submit_btn.click(fn=predict, inputs=[disorder_input, image_input], outputs=output_text)
169
 
170
+ with gr.Tab("🧑‍⚕️ Dr.Neuro - Gemini Chatbot"):
171
+ gr.Markdown("## Ask neurological questions powered by Gemini")
172
  chatbot = gr.ChatInterface(fn=chat_with_gemini)
173
 
174
+ if __name__ == "__main__":
175
+ demo.launch()