Hammad712 commited on
Commit
dade807
·
verified ·
1 Parent(s): 23b9c8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -6
app.py CHANGED
@@ -1,11 +1,14 @@
1
  #!/usr/bin/env python3
2
  """
3
- Streamlit Brain MRI Tumor Detection App (updated):
4
  - load & display the uploaded image(s),
5
  - pass the image to the ViT model for inference,
6
  - pass the model inference to the Groq Deepseek R1 LLM to generate an informational medical report,
7
- - removed Groq "ping" and debugging tips,
8
- - replaced deprecated use_column_width with use_container_width.
 
 
 
9
  """
10
 
11
  import os
@@ -23,6 +26,7 @@ try:
23
  try:
24
  torch.classes.__path__ = []
25
  except Exception:
 
26
  pass
27
  except Exception as e:
28
  torch = None
@@ -118,6 +122,7 @@ def predict_image(image: Image.Image) -> Tuple[str, float]:
118
  """
119
  if model is None or feature_extractor is None:
120
  raise RuntimeError("Model not loaded.")
 
121
  inputs = feature_extractor(images=image, return_tensors="pt")
122
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
123
  model.to(device)
@@ -154,6 +159,8 @@ def generate_medical_report(diagnosis_label: str, confidence: float, image_info:
154
  logger.exception("Failed to instantiate Groq client: %s", e)
155
  return "Medical report temporarily unavailable (client init failed)."
156
 
 
 
157
  prompt_lines = [
158
  "You are a careful medical assistant creating an informational medical report for a patient based on an automated image analysis result.",
159
  f"Model diagnosis: {diagnosis_label}",
@@ -165,8 +172,9 @@ def generate_medical_report(diagnosis_label: str, confidence: float, image_info:
165
  "Keep language clear and non-technical where possible, and keep it concise (about 3-6 short paragraphs)."
166
  ]
167
  if include_image_base64 and image_b64:
 
168
  prompt_lines.append("Note: a small thumbnail was provided (base64), though you should not rely on it for clinical decision-making.")
169
- prompt_lines.append(f"Thumbnail (base64, trimmed): {image_b64[:800]}")
170
 
171
  prompt = "\n\n".join(prompt_lines)
172
 
@@ -185,6 +193,7 @@ def generate_medical_report(diagnosis_label: str, confidence: float, image_info:
185
  stream=False,
186
  stop=None,
187
  )
 
188
  try:
189
  report_text = completion.choices[0].message.content
190
  except Exception:
@@ -192,11 +201,13 @@ def generate_medical_report(diagnosis_label: str, confidence: float, image_info:
192
  report_text = completion.choices[0].text
193
  except Exception:
194
  report_text = str(completion)
 
195
  if safety_sentence not in report_text:
196
  report_text = safety_sentence + "\n\n" + report_text
197
  return report_text
198
  except Exception as e:
199
  logger.exception("Groq call failed: %s", e)
 
200
  resp = None
201
  for attr in ("response", "http_response", "raw_response", "resp"):
202
  resp = getattr(e, attr, None)
@@ -225,6 +236,7 @@ def pil_to_base64(img: Image.Image, size: Tuple[int, int] = None) -> str:
225
  uploaded_file = st.file_uploader("Choose an MRI image (jpg, jpeg, png)", type=["jpg", "jpeg", "png"])
226
 
227
  if uploaded_file is not None:
 
228
  try:
229
  pil_image = Image.open(uploaded_file).convert("RGB")
230
  except Exception as e:
@@ -233,24 +245,30 @@ if uploaded_file is not None:
233
  pil_image = None
234
 
235
  if pil_image:
 
236
  col1, col2 = st.columns([1, 1])
237
  with col1:
238
  st.markdown("**Original image**")
239
- st.image(pil_image, use_container_width=True)
 
240
  processed_for_display = ImageOps.contain(pil_image, (512, 512))
241
  with col2:
242
  st.markdown("**Processed (for model preview)**")
243
- st.image(processed_for_display, use_container_width=True)
244
 
 
245
  img_w, img_h = pil_image.size
246
  st.markdown(f"**Image metadata:** dimensions = {img_w} x {img_h}, mode = {pil_image.mode}")
247
 
 
248
  include_thumbnail = st.checkbox("Include small thumbnail preview in the generated report prompt (may increase request size)", value=False)
249
 
 
250
  if model_load_error:
251
  st.error("Model failed to load at startup. See Developer info for details.")
252
  st.code(model_load_error)
253
  else:
 
254
  run_infer = st.button("Run inference & generate report")
255
  if run_infer:
256
  try:
@@ -265,8 +283,11 @@ if uploaded_file is not None:
265
  label = None
266
  confidence = None
267
 
 
268
  if label is not None:
 
269
  image_info = f"dimensions={img_w}x{img_h}; mode={pil_image.mode}; filename_provided={hasattr(uploaded_file, 'name') and bool(getattr(uploaded_file, 'name', None))}"
 
270
  image_b64 = None
271
  if include_thumbnail:
272
  try:
@@ -280,6 +301,7 @@ if uploaded_file is not None:
280
  st.markdown("### Medical Report (informational)")
281
  st.write(report_text)
282
 
 
283
  try:
284
  report_bytes = report_text.encode("utf-8")
285
  download_name = f"medical_report_{label}_{int(confidence*100)}pct.txt"
@@ -291,3 +313,76 @@ if uploaded_file is not None:
291
  # If no file uploaded, show placeholder instructions
292
  if uploaded_file is None:
293
  st.markdown("<div class='small-muted'>Upload a brain MRI image (jpg/png) to get a model prediction and an informational medical report.</div>", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ Streamlit Brain MRI Tumor Detection App (updated to:
4
  - load & display the uploaded image(s),
5
  - pass the image to the ViT model for inference,
6
  - pass the model inference to the Groq Deepseek R1 LLM to generate an informational medical report,
7
+ - provide robust logging, error handling, and a download for the generated report.
8
+
9
+ Important:
10
+ - This app is informational only and not a medical diagnosis.
11
+ - Set API_KEY in your environment to enable Groq calls.
12
  """
13
 
14
  import os
 
26
  try:
27
  torch.classes.__path__ = []
28
  except Exception:
29
+ # ignore - best-effort
30
  pass
31
  except Exception as e:
32
  torch = None
 
122
  """
123
  if model is None or feature_extractor is None:
124
  raise RuntimeError("Model not loaded.")
125
+ # Preprocess using the feature extractor
126
  inputs = feature_extractor(images=image, return_tensors="pt")
127
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
128
  model.to(device)
 
159
  logger.exception("Failed to instantiate Groq client: %s", e)
160
  return "Medical report temporarily unavailable (client init failed)."
161
 
162
+ # Construct a concise prompt that includes the model's result and image metadata.
163
+ # Do NOT include patient identifying data; keep it informational.
164
  prompt_lines = [
165
  "You are a careful medical assistant creating an informational medical report for a patient based on an automated image analysis result.",
166
  f"Model diagnosis: {diagnosis_label}",
 
172
  "Keep language clear and non-technical where possible, and keep it concise (about 3-6 short paragraphs)."
173
  ]
174
  if include_image_base64 and image_b64:
175
+ # Optionally include a tiny thumbnail as base64 (be careful with payload size).
176
  prompt_lines.append("Note: a small thumbnail was provided (base64), though you should not rely on it for clinical decision-making.")
177
+ prompt_lines.append(f"Thumbnail (base64, trimmed): {image_b64[:800]}") # only include a prefix to avoid huge payloads
178
 
179
  prompt = "\n\n".join(prompt_lines)
180
 
 
193
  stream=False,
194
  stop=None,
195
  )
196
+ # Extract text robustly
197
  try:
198
  report_text = completion.choices[0].message.content
199
  except Exception:
 
201
  report_text = completion.choices[0].text
202
  except Exception:
203
  report_text = str(completion)
204
+ # Ensure safety sentence present
205
  if safety_sentence not in report_text:
206
  report_text = safety_sentence + "\n\n" + report_text
207
  return report_text
208
  except Exception as e:
209
  logger.exception("Groq call failed: %s", e)
210
+ # Try to pull useful info from exception if it exists
211
  resp = None
212
  for attr in ("response", "http_response", "raw_response", "resp"):
213
  resp = getattr(e, attr, None)
 
236
  uploaded_file = st.file_uploader("Choose an MRI image (jpg, jpeg, png)", type=["jpg", "jpeg", "png"])
237
 
238
  if uploaded_file is not None:
239
+ # Load image
240
  try:
241
  pil_image = Image.open(uploaded_file).convert("RGB")
242
  except Exception as e:
 
245
  pil_image = None
246
 
247
  if pil_image:
248
+ # Display original and a preprocessed/thumbnail side-by-side
249
  col1, col2 = st.columns([1, 1])
250
  with col1:
251
  st.markdown("**Original image**")
252
+ st.image(pil_image, use_column_width=True)
253
+ # Create a centered thumbnail / processed view (resize for model preview)
254
  processed_for_display = ImageOps.contain(pil_image, (512, 512))
255
  with col2:
256
  st.markdown("**Processed (for model preview)**")
257
+ st.image(processed_for_display, use_column_width=True)
258
 
259
+ # Show image metadata
260
  img_w, img_h = pil_image.size
261
  st.markdown(f"**Image metadata:** dimensions = {img_w} x {img_h}, mode = {pil_image.mode}")
262
 
263
+ # Option to include a small base64 thumbnail in the LLM prompt (default OFF to avoid large payloads)
264
  include_thumbnail = st.checkbox("Include small thumbnail preview in the generated report prompt (may increase request size)", value=False)
265
 
266
+ # Model availability check
267
  if model_load_error:
268
  st.error("Model failed to load at startup. See Developer info for details.")
269
  st.code(model_load_error)
270
  else:
271
+ # Run inference
272
  run_infer = st.button("Run inference & generate report")
273
  if run_infer:
274
  try:
 
283
  label = None
284
  confidence = None
285
 
286
+ # If inference ok, call LLM to generate report
287
  if label is not None:
288
+ # Prepare image_info summary
289
  image_info = f"dimensions={img_w}x{img_h}; mode={pil_image.mode}; filename_provided={hasattr(uploaded_file, 'name') and bool(getattr(uploaded_file, 'name', None))}"
290
+ # Optionally produce small base64 thumbnail
291
  image_b64 = None
292
  if include_thumbnail:
293
  try:
 
301
  st.markdown("### Medical Report (informational)")
302
  st.write(report_text)
303
 
304
+ # Allow user to download the report as a .txt file
305
  try:
306
  report_bytes = report_text.encode("utf-8")
307
  download_name = f"medical_report_{label}_{int(confidence*100)}pct.txt"
 
313
  # If no file uploaded, show placeholder instructions
314
  if uploaded_file is None:
315
  st.markdown("<div class='small-muted'>Upload a brain MRI image (jpg/png) to get a model prediction and an informational medical report.</div>", unsafe_allow_html=True)
316
+
317
+ # ------------------ Developer troubleshooting expander ------------------
318
+ with st.expander("Developer info / Troubleshooting"):
319
+ st.markdown(f"**Model repository**: `{repository_id}`")
320
+ st.markdown(f"**Torch available**: {'Yes' if torch is not None else 'No'}")
321
+ st.markdown(f"**Model loaded**: {'Yes' if model is not None else 'No'}")
322
+ st.write({
323
+ "CUDA available": torch.cuda.is_available() if torch is not None else False,
324
+ "API_KEY set for Groq": bool(os.getenv("API_KEY")),
325
+ "Groq installed": Groq is not None
326
+ })
327
+ if model_load_error:
328
+ st.markdown("**Model load error**:")
329
+ st.code(model_load_error)
330
+
331
+ st.markdown("---")
332
+ st.markdown("### Groq quick test (for debugging API errors)")
333
+ st.markdown("Click the button to run a very small 'ping' to the Groq chat endpoint. This helps capture raw error info without sending large prompts.")
334
+ if st.button("Run Groq ping"):
335
+ # small test call
336
+ def groq_test_ping(max_tokens: int = 8):
337
+ if Groq is None:
338
+ return {"ok": False, "result": "Groq client library not available."}
339
+ api_key = os.getenv("API_KEY")
340
+ if not api_key:
341
+ return {"ok": False, "result": "API_KEY not configured."}
342
+ try:
343
+ client = Groq(api_key=api_key)
344
+ res = client.chat.completions.create(
345
+ model="deepseek-r1-distill-llama-70b",
346
+ messages=[{"role": "user", "content": "ping"}],
347
+ max_completion_tokens=max_tokens,
348
+ )
349
+ try:
350
+ content = res.choices[0].message.content
351
+ except Exception:
352
+ try:
353
+ content = res.choices[0].text
354
+ except Exception:
355
+ content = str(res)
356
+ return {"ok": True, "result": content}
357
+ except Exception as e:
358
+ info = {"exception_repr": repr(e)}
359
+ for attr in ("response", "http_response", "raw_response", "resp"):
360
+ if hasattr(e, attr):
361
+ rval = getattr(e, attr)
362
+ try:
363
+ info[attr] = {
364
+ "status": getattr(rval, "status_code", getattr(rval, "status", "unknown")),
365
+ "body_preview": (getattr(rval, "text", getattr(rval, "body", str(rval)))[:1000] + "...") if getattr(rval, "text", None) or getattr(rval, "body", None) else str(rval),
366
+ }
367
+ except Exception:
368
+ info[attr] = str(rval)
369
+ logger.exception("Groq test ping failed: %s", e)
370
+ return {"ok": False, "result": info}
371
+
372
+ ping_result = groq_test_ping()
373
+ if ping_result.get("ok"):
374
+ st.success("Groq ping successful")
375
+ st.text_area("Result (truncated)", str(ping_result.get("result"))[:2000], height=200)
376
+ else:
377
+ st.error("Groq ping failed; see details below")
378
+ st.json(ping_result.get("result"))
379
+
380
+ st.markdown("---")
381
+ st.markdown("Debugging tips:")
382
+ st.markdown(
383
+ "- If Groq returns HTTP 400: check model name, prompt length, and messages shape.\n"
384
+ "- Use the Groq ping to inspect raw error details.\n"
385
+ "- Ensure `API_KEY` is set & has permissions for the requested model.\n"
386
+ "- To avoid the Streamlit <-> PyTorch watcher issue you can also run Streamlit with: "
387
+ "`streamlit run app.py --server.fileWatcherType none` or set `.streamlit/config.toml`."
388
+ )