Hammad712 commited on
Commit
b9d1861
·
verified ·
1 Parent(s): d26f794

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +224 -124
app.py CHANGED
@@ -1,97 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from PIL import Image
3
- import torch
4
- from transformers import ViTForImageClassification, ViTImageProcessor
5
- import logging
6
- import base64
7
  from io import BytesIO
8
- from groq import Groq # Import the Groq client for Deepseek R1 API
9
- import os
10
- # ------------------ Setup Logging ------------------
11
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12
-
13
- # ------------------ Load the ViT Model ------------------
14
- repository_id = "EnDevSols/brainmri-vit-model"
15
- model = ViTForImageClassification.from_pretrained(repository_id)
16
- feature_extractor = ViTImageProcessor.from_pretrained(repository_id)
17
 
18
- # ------------------ ViT Inference Function ------------------
19
- def predict(image):
20
- """
21
- Given an image, perform inference using the ViT model to detect brain tumor.
22
- Returns a human-readable diagnosis string.
23
- """
24
- # Convert to RGB and preprocess the image
25
- image = image.convert("RGB")
26
- inputs = feature_extractor(images=image, return_tensors="pt")
27
-
28
- # Set the device (GPU if available)
29
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
- model.to(device)
31
- inputs = {k: v.to(device) for k, v in inputs.items()}
32
-
33
- # Perform inference without gradient computation
34
- with torch.no_grad():
35
- outputs = model(**inputs)
36
-
37
- # Get the predicted label and map to a diagnosis
38
- logits = outputs.logits
39
- predicted_label = logits.argmax(-1).item()
40
- label_map = {0: "No", 1: "Yes"}
41
- diagnosis = label_map[predicted_label]
42
-
43
- if diagnosis == "Yes":
44
- return "The diagnosis indicates that you have a brain tumor."
45
- else:
46
- return "The diagnosis indicates that you do not have a brain tumor."
47
-
48
- # ------------------ Deepseek R1 Assistance Function ------------------
49
- def get_assistance_from_deepseek(diagnosis_text):
50
- """
51
- Given the diagnosis from the ViT model, call the Deepseek R1 model via the Groq API
52
- to get additional recommendations and next steps.
53
- """
54
 
55
- api_key=os.getenv("API_KEY")
56
- # Instantiate the Groq client with the provided API key
57
- client = Groq(api_key=api_key)
58
-
59
- # Construct a prompt that includes the diagnosis and asks for detailed guidance
60
- prompt = (
61
- f"Based on the following diagnosis: '{diagnosis_text}', please provide next steps and "
62
- "recommendations for the patient. Include whether to consult a specialist, if further tests "
63
- "are needed, and any other immediate actions or lifestyle recommendations."
64
- )
65
-
66
- messages = [
67
- {
68
- "role": "system",
69
- "content": "You are a helpful medical assistant providing guidance after a brain tumor diagnosis."
70
- },
71
- {"role": "user", "content": prompt}
72
- ]
73
-
74
- # Create the completion using the Deepseek R1 model (non-streaming for simplicity)
75
- completion = client.chat.completions.create(
76
- model="deepseek-r1-distill-llama-70b",
77
- messages=messages,
78
- temperature=0.6,
79
- max_completion_tokens=4096,
80
- top_p=0.95,
81
- stream=False,
82
- stop=None,
83
- )
84
-
85
- # Extract the response text. (Depending on the API response format, adjust as needed.)
86
- try:
87
- assistance_text = completion.choices[0].message.content
88
- except AttributeError:
89
- # Fallback in case the structure is different
90
- assistance_text = completion.choices[0].text
91
-
92
- return assistance_text
93
-
94
- # ------------------ Custom CSS for Styling ------------------
95
  combined_css = """
96
  .main, .sidebar .sidebar-content { background-color: #1c1c1c; color: #f0f2f6; }
97
  .block-container { padding: 1rem 2rem; background-color: #333; border-radius: 10px; box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5); }
@@ -119,13 +73,12 @@ combined_css = """
119
  margin-top: -20px;
120
  margin-bottom: 20px;
121
  }
 
 
122
  """
123
-
124
- # ------------------ Streamlit App Configuration ------------------
125
- st.set_page_config(layout="wide")
126
  st.markdown(f"<style>{combined_css}</style>", unsafe_allow_html=True)
127
 
128
- # App Title and Description
129
  st.markdown(
130
  '<div class="title"><span class="colorful-text">Brain MRI</span> <span class="black-white-text">Tumor Detection</span></div>',
131
  unsafe_allow_html=True
@@ -135,37 +88,184 @@ st.markdown(
135
  unsafe_allow_html=True
136
  )
137
 
138
- # ------------------ Image Upload Section ------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
140
 
141
  if uploaded_file is not None:
142
- image = Image.open(uploaded_file)
143
-
144
- # Resize image for display purposes
145
- resized_image = image.resize((150, 150))
146
-
147
- # Convert image to base64 for HTML display
148
- buffered = BytesIO()
149
- resized_image.save(buffered, format="JPEG")
150
- img_str = base64.b64encode(buffered.getvalue()).decode()
151
-
152
- # Display the uploaded image in the center
153
- st.markdown(
154
- f"<div style='text-align: center;'><img src='data:image/jpeg;base64,{img_str}' alt='Uploaded Image' width='300'></div>",
155
- unsafe_allow_html=True
156
- )
157
-
158
- st.write("")
159
- st.write("Processing the image...")
160
-
161
- # ------------------ Step 1: Get Diagnosis from the ViT Model ------------------
162
- diagnosis = predict(image)
163
- st.markdown("### Diagnosis:")
164
- st.write(diagnosis)
165
-
166
- # ------------------ Step 2: Get Further Assistance from Deepseek R1 ------------------
167
- with st.spinner("Fetching additional guidance based on your diagnosis..."):
168
- assistance = get_assistance_from_deepseek(diagnosis)
169
-
170
- st.markdown("### Next Steps and Recommendations:")
171
- st.write(assistance)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Streamlit Brain MRI Tumor Detection App (updated, safe startup + LLM safety)
4
+ - Monkeypatches torch.classes.__path__ before importing streamlit to avoid a Streamlit <-> PyTorch watcher crash.
5
+ - Returns probabilistic model output (label + confidence).
6
+ - Adds a visible medical disclaimer.
7
+ - Adds robust error handling around the Groq (Deepseek R1) call and ensures the LLM output contains a safety sentence.
8
+ - Keeps your original UI/CSS with small UX improvements.
9
+ """
10
+
11
+ import os
12
+ import logging
13
+ import traceback
14
+
15
+ # ------------------ Safe startup: import torch first and monkeypatch ------------------
16
+ # This avoids Streamlit's file-watcher triggering PyTorch C++ registry introspection errors.
17
+ try:
18
+ import torch
19
+ # Force a benign __path__ so Streamlit's watcher won't attempt unsafe introspection.
20
+ try:
21
+ # If torch.classes exists, ensure __path__ is present and is a harmless list.
22
+ if hasattr(torch, "classes"):
23
+ # Some torch builds may already have __path__; overwrite safely.
24
+ torch.classes.__path__ = []
25
+ except Exception as _e:
26
+ # If something goes wrong, don't crash the app at module import time.
27
+ logging.warning("Failed to set torch.classes.__path__: %s", _e)
28
+ except Exception as e:
29
+ # If torch can't be imported at all, we still continue so Streamlit can display an error to the user.
30
+ # Log it; later we'll surface a friendly message in the UI.
31
+ logging.error("Unable to import torch at startup: %s\n%s", e, traceback.format_exc())
32
+ torch = None
33
+
34
+ # ------------------ Now safe to import Streamlit and other packages ------------------
35
  import streamlit as st
36
  from PIL import Image
 
 
 
 
37
  from io import BytesIO
38
+ import base64
39
+ from transformers import ViTForImageClassification, ViTImageProcessor
40
+ from groq import Groq
41
+ import numpy as np
42
+ import torch.nn.functional as F
 
 
 
 
43
 
44
+ # ------------------ Logging ------------------
45
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ # ------------------ Page config + CSS ------------------
48
+ st.set_page_config(layout="wide", page_title="Brain MRI Tumor Detection")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  combined_css = """
50
  .main, .sidebar .sidebar-content { background-color: #1c1c1c; color: #f0f2f6; }
51
  .block-container { padding: 1rem 2rem; background-color: #333; border-radius: 10px; box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5); }
 
73
  margin-top: -20px;
74
  margin-bottom: 20px;
75
  }
76
+ .disclaimer { color: #ffcc66; font-weight: bold; text-align: center; margin-bottom: 12px; }
77
+ .small-muted { font-size:0.9rem; color:#cccccc; text-align:center; margin-top:8px; }
78
  """
 
 
 
79
  st.markdown(f"<style>{combined_css}</style>", unsafe_allow_html=True)
80
 
81
+ # ------------------ App header ------------------
82
  st.markdown(
83
  '<div class="title"><span class="colorful-text">Brain MRI</span> <span class="black-white-text">Tumor Detection</span></div>',
84
  unsafe_allow_html=True
 
88
  unsafe_allow_html=True
89
  )
90
 
91
+ # Medical disclaimer (visible)
92
+ st.markdown(
93
+ "<div class='disclaimer'>⚠️ This app is experimental and informational only. It is NOT a medical diagnosis. "
94
+ "If you have health concerns, consult a licensed medical professional. In emergencies call your local emergency number.</div>",
95
+ unsafe_allow_html=True
96
+ )
97
+
98
+ # ------------------ Model loading with graceful errors ------------------
99
+ repository_id = "EnDevSols/brainmri-vit-model"
100
+
101
+ model = None
102
+ feature_extractor = None
103
+ model_load_error = None
104
+
105
+ try:
106
+ # Only attempt to load model if torch was imported successfully
107
+ if torch is None:
108
+ raise RuntimeError("Torch is not available in this environment.")
109
+ # Model loading can be slow; catch errors and show a friendly message later.
110
+ model = ViTForImageClassification.from_pretrained(repository_id)
111
+ feature_extractor = ViTImageProcessor.from_pretrained(repository_id)
112
+ logging.info("Model and feature extractor loaded successfully.")
113
+ except Exception as e:
114
+ model_load_error = str(e)
115
+ logging.error("Failed to load model or feature extractor: %s\n%s", e, traceback.format_exc())
116
+
117
+ # ------------------ Prediction function (returns label + confidence) ------------------
118
+ def predict(image):
119
+ """
120
+ Given a PIL image, returns (diagnosis_label, confidence_float).
121
+ 'diagnosis_label' is "Yes" (tumor) or "No" (no tumor).
122
+ 'confidence_float' is between 0 and 1.
123
+ """
124
+ if model is None or feature_extractor is None:
125
+ raise RuntimeError("Model is not loaded.")
126
+ image = image.convert("RGB")
127
+ inputs = feature_extractor(images=image, return_tensors="pt")
128
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
129
+ model.to(device)
130
+ inputs = {k: v.to(device) for k, v in inputs.items()}
131
+ with torch.no_grad():
132
+ outputs = model(**inputs)
133
+ logits = outputs.logits
134
+ probs = F.softmax(logits, dim=-1).squeeze().cpu().numpy()
135
+ predicted_idx = int(np.argmax(probs))
136
+ confidence = float(probs[predicted_idx])
137
+ label_map = {0: "No", 1: "Yes"}
138
+ diagnosis = label_map.get(predicted_idx, "Unknown")
139
+ return diagnosis, confidence
140
+
141
+ # ------------------ Deepseek (Groq) helper ------------------
142
+ def get_assistance_from_deepseek(diagnosis_text):
143
+ """
144
+ Calls Groq Deepseek R1 with a safety-first prompt.
145
+ Returns a string. On error, returns a conservative fallback message.
146
+ """
147
+ api_key = os.getenv("API_KEY")
148
+ if not api_key:
149
+ logging.error("API_KEY environment variable not found for Groq client.")
150
+ return ("No assistance available because the Deepseek API key is not configured. "
151
+ "Please set the API_KEY environment variable.")
152
+ try:
153
+ client = Groq(api_key=api_key)
154
+ except Exception as e:
155
+ logging.error("Failed to instantiate Groq client: %s\n%s", e, traceback.format_exc())
156
+ return ("Assistance temporarily unavailable (failed to initialize model client). "
157
+ "Please try again later or contact support.")
158
+
159
+ # Safer prompt: require the assistant to include a clinician referral sentence
160
+ safety_sentence = "This information is informational only — seek evaluation from a licensed medical professional."
161
+ prompt = (
162
+ f"You are a cautious, safety-first medical assistant. Given the model-diagnosis below, "
163
+ "provide general, non-prescriptive information a patient could use to understand options. "
164
+ "Do NOT provide definitive medical diagnosis or treatment plans. ALWAYS include the sentence: "
165
+ f"'{safety_sentence}'\n\n"
166
+ f"Diagnosis text: '{diagnosis_text}'\n\n"
167
+ "Please list: (1) suggested questions a patient might ask a clinician, (2) common next diagnostic tests a clinician might consider (non-exhaustive), "
168
+ "and (3) immediate safety signs that should prompt emergency care. Keep the language simple and avoid prescriptive medical directives."
169
+ )
170
+
171
+ messages = [
172
+ {"role": "system", "content": "You are a careful medical assistant always advising a user to consult a clinician."},
173
+ {"role": "user", "content": prompt}
174
+ ]
175
+
176
+ try:
177
+ completion = client.chat.completions.create(
178
+ model="deepseek-r1-distill-llama-70b",
179
+ messages=messages,
180
+ temperature=0.6,
181
+ max_completion_tokens=1024,
182
+ top_p=0.95,
183
+ stream=False,
184
+ stop=None,
185
+ )
186
+ # Try different response shapes safely
187
+ assistance_text = ""
188
+ try:
189
+ assistance_text = completion.choices[0].message.content
190
+ except Exception:
191
+ try:
192
+ assistance_text = completion.choices[0].text
193
+ except Exception:
194
+ assistance_text = str(completion)
195
+
196
+ # Ensure the required safety sentence is present
197
+ if safety_sentence not in assistance_text:
198
+ assistance_text = safety_sentence + "\n\n" + assistance_text
199
+
200
+ return assistance_text
201
+ except Exception as e:
202
+ logging.error("Deepseek Groq call failed: %s\n%s", e, traceback.format_exc())
203
+ return ("Assistance is temporarily unavailable due to an error contacting the assistance model. "
204
+ "Please consult a licensed medical professional for evaluation. If you are experiencing severe or life-threatening symptoms, seek emergency care immediately.")
205
+
206
+ # ------------------ Streamlit UI: image upload + inference flow ------------------
207
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
208
 
209
  if uploaded_file is not None:
210
+ try:
211
+ image = Image.open(uploaded_file)
212
+ except Exception as e:
213
+ st.error(f"Failed to open the uploaded file as an image: {e}")
214
+ image = None
215
+
216
+ if image is not None:
217
+ # Display resized thumbnail
218
+ resized_image = image.resize((150, 150))
219
+ buffered = BytesIO()
220
+ resized_image.save(buffered, format="JPEG")
221
+ img_str = base64.b64encode(buffered.getvalue()).decode()
222
+ st.markdown(
223
+ f"<div style='text-align: center;'><img src='data:image/jpeg;base64,{img_str}' alt='Uploaded Image' width='300'></div>",
224
+ unsafe_allow_html=True
225
+ )
226
+
227
+ # Check model loaded
228
+ if model_load_error:
229
+ st.error("The model failed to load at startup. See logs for details.")
230
+ st.code(model_load_error)
231
+ else:
232
+ st.write("")
233
+ st.write("Processing the image...")
234
+
235
+ # Run prediction with try/except
236
+ try:
237
+ diagnosis, confidence = predict(image)
238
+ st.markdown("### Diagnosis (model prediction):")
239
+ st.write(f"**{diagnosis}** (confidence: **{confidence:.2%}**)")
240
+ st.markdown("_Model output is probabilistic and not a clinical diagnosis._", unsafe_allow_html=True)
241
+ except Exception as e:
242
+ st.error("Prediction failed: " + str(e))
243
+ logging.error("Prediction error: %s\n%s", e, traceback.format_exc())
244
+ diagnosis = None
245
+ confidence = None
246
+
247
+ # If we have a diagnosis, call Deepseek for additional guidance (with spinner)
248
+ if diagnosis is not None:
249
+ with st.spinner("Fetching additional guidance based on your diagnosis..."):
250
+ assistance = get_assistance_from_deepseek(f"Diagnosis: {diagnosis} (confidence {confidence:.2%})")
251
+ st.markdown("### Next Steps and Recommendations:")
252
+ # Use st.write which keeps newlines and formatting reasonable.
253
+ st.write(assistance)
254
+
255
+ # ------------------ If no file uploaded, show sample placeholder / instructions ------------------
256
+ if uploaded_file is None:
257
+ st.markdown("<div class='small-muted'>Upload a brain MRI image (jpg/png) to get a model prediction and informational next steps. </div>", unsafe_allow_html=True)
258
+
259
+ # ------------------ Helpful debug / info for developers (hidden by default) ------------------
260
+ with st.expander("Developer info / Troubleshooting"):
261
+ st.markdown("**Model repository**: " + repository_id)
262
+ st.markdown("**Torch available**: " + ("Yes" if torch is not None else "No"))
263
+ st.markdown("**Model loaded**: " + ("Yes" if model is not None else "No"))
264
+ if model_load_error:
265
+ st.code(model_load_error)
266
+ st.markdown("**Environment**:")
267
+ st.write({
268
+ "CUDA available": torch.cuda.is_available() if torch is not None else False,
269
+ "API_KEY set for Groq": bool(os.getenv("API_KEY"))
270
+ })
271
+ st.markdown("**Notes:**\n- This app is for informational use only. Do not use it as a replacement for professional medical advice.")