Hammad712 committed on
Commit
23b9c8e
·
verified ·
1 Parent(s): 266a048

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -101
app.py CHANGED
@@ -1,14 +1,11 @@
1
  #!/usr/bin/env python3
2
  """
3
- Streamlit Brain MRI Tumor Detection App (updated to:
4
  - load & display the uploaded image(s),
5
  - pass the image to the ViT model for inference,
6
  - pass the model inference to the Groq Deepseek R1 LLM to generate an informational medical report,
7
- - provide robust logging, error handling, and a download for the generated report.
8
-
9
- Important:
10
- - This app is informational only and not a medical diagnosis.
11
- - Set API_KEY in your environment to enable Groq calls.
12
  """
13
 
14
  import os
@@ -26,7 +23,6 @@ try:
26
  try:
27
  torch.classes.__path__ = []
28
  except Exception:
29
- # ignore - best-effort
30
  pass
31
  except Exception as e:
32
  torch = None
@@ -122,7 +118,6 @@ def predict_image(image: Image.Image) -> Tuple[str, float]:
122
  """
123
  if model is None or feature_extractor is None:
124
  raise RuntimeError("Model not loaded.")
125
- # Preprocess using the feature extractor
126
  inputs = feature_extractor(images=image, return_tensors="pt")
127
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
128
  model.to(device)
@@ -159,8 +154,6 @@ def generate_medical_report(diagnosis_label: str, confidence: float, image_info:
159
  logger.exception("Failed to instantiate Groq client: %s", e)
160
  return "Medical report temporarily unavailable (client init failed)."
161
 
162
- # Construct a concise prompt that includes the model's result and image metadata.
163
- # Do NOT include patient identifying data; keep it informational.
164
  prompt_lines = [
165
  "You are a careful medical assistant creating an informational medical report for a patient based on an automated image analysis result.",
166
  f"Model diagnosis: {diagnosis_label}",
@@ -172,9 +165,8 @@ def generate_medical_report(diagnosis_label: str, confidence: float, image_info:
172
  "Keep language clear and non-technical where possible, and keep it concise (about 3-6 short paragraphs)."
173
  ]
174
  if include_image_base64 and image_b64:
175
- # Optionally include a tiny thumbnail as base64 (be careful with payload size).
176
  prompt_lines.append("Note: a small thumbnail was provided (base64), though you should not rely on it for clinical decision-making.")
177
- prompt_lines.append(f"Thumbnail (base64, trimmed): {image_b64[:800]}") # only include a prefix to avoid huge payloads
178
 
179
  prompt = "\n\n".join(prompt_lines)
180
 
@@ -193,7 +185,6 @@ def generate_medical_report(diagnosis_label: str, confidence: float, image_info:
193
  stream=False,
194
  stop=None,
195
  )
196
- # Extract text robustly
197
  try:
198
  report_text = completion.choices[0].message.content
199
  except Exception:
@@ -201,13 +192,11 @@ def generate_medical_report(diagnosis_label: str, confidence: float, image_info:
201
  report_text = completion.choices[0].text
202
  except Exception:
203
  report_text = str(completion)
204
- # Ensure safety sentence present
205
  if safety_sentence not in report_text:
206
  report_text = safety_sentence + "\n\n" + report_text
207
  return report_text
208
  except Exception as e:
209
  logger.exception("Groq call failed: %s", e)
210
- # Try to pull useful info from exception if it exists
211
  resp = None
212
  for attr in ("response", "http_response", "raw_response", "resp"):
213
  resp = getattr(e, attr, None)
@@ -236,7 +225,6 @@ def pil_to_base64(img: Image.Image, size: Tuple[int, int] = None) -> str:
236
  uploaded_file = st.file_uploader("Choose an MRI image (jpg, jpeg, png)", type=["jpg", "jpeg", "png"])
237
 
238
  if uploaded_file is not None:
239
- # Load image
240
  try:
241
  pil_image = Image.open(uploaded_file).convert("RGB")
242
  except Exception as e:
@@ -245,30 +233,24 @@ if uploaded_file is not None:
245
  pil_image = None
246
 
247
  if pil_image:
248
- # Display original and a preprocessed/thumbnail side-by-side
249
  col1, col2 = st.columns([1, 1])
250
  with col1:
251
  st.markdown("**Original image**")
252
- st.image(pil_image, use_column_width=True)
253
- # Create a centered thumbnail / processed view (resize for model preview)
254
  processed_for_display = ImageOps.contain(pil_image, (512, 512))
255
  with col2:
256
  st.markdown("**Processed (for model preview)**")
257
- st.image(processed_for_display, use_column_width=True)
258
 
259
- # Show image metadata
260
  img_w, img_h = pil_image.size
261
  st.markdown(f"**Image metadata:** dimensions = {img_w} x {img_h}, mode = {pil_image.mode}")
262
 
263
- # Option to include a small base64 thumbnail in the LLM prompt (default OFF to avoid large payloads)
264
  include_thumbnail = st.checkbox("Include small thumbnail preview in the generated report prompt (may increase request size)", value=False)
265
 
266
- # Model availability check
267
  if model_load_error:
268
  st.error("Model failed to load at startup. See Developer info for details.")
269
  st.code(model_load_error)
270
  else:
271
- # Run inference
272
  run_infer = st.button("Run inference & generate report")
273
  if run_infer:
274
  try:
@@ -283,11 +265,8 @@ if uploaded_file is not None:
283
  label = None
284
  confidence = None
285
 
286
- # If inference ok, call LLM to generate report
287
  if label is not None:
288
- # Prepare image_info summary
289
  image_info = f"dimensions={img_w}x{img_h}; mode={pil_image.mode}; filename_provided={hasattr(uploaded_file, 'name') and bool(getattr(uploaded_file, 'name', None))}"
290
- # Optionally produce small base64 thumbnail
291
  image_b64 = None
292
  if include_thumbnail:
293
  try:
@@ -301,7 +280,6 @@ if uploaded_file is not None:
301
  st.markdown("### Medical Report (informational)")
302
  st.write(report_text)
303
 
304
- # Allow user to download the report as a .txt file
305
  try:
306
  report_bytes = report_text.encode("utf-8")
307
  download_name = f"medical_report_{label}_{int(confidence*100)}pct.txt"
@@ -313,76 +291,3 @@ if uploaded_file is not None:
313
  # If no file uploaded, show placeholder instructions
314
  if uploaded_file is None:
315
  st.markdown("<div class='small-muted'>Upload a brain MRI image (jpg/png) to get a model prediction and an informational medical report.</div>", unsafe_allow_html=True)
316
-
317
- # ------------------ Developer troubleshooting expander ------------------
318
- with st.expander("Developer info / Troubleshooting"):
319
- st.markdown(f"**Model repository**: `{repository_id}`")
320
- st.markdown(f"**Torch available**: {'Yes' if torch is not None else 'No'}")
321
- st.markdown(f"**Model loaded**: {'Yes' if model is not None else 'No'}")
322
- st.write({
323
- "CUDA available": torch.cuda.is_available() if torch is not None else False,
324
- "API_KEY set for Groq": bool(os.getenv("API_KEY")),
325
- "Groq installed": Groq is not None
326
- })
327
- if model_load_error:
328
- st.markdown("**Model load error**:")
329
- st.code(model_load_error)
330
-
331
- st.markdown("---")
332
- st.markdown("### Groq quick test (for debugging API errors)")
333
- st.markdown("Click the button to run a very small 'ping' to the Groq chat endpoint. This helps capture raw error info without sending large prompts.")
334
- if st.button("Run Groq ping"):
335
- # small test call
336
- def groq_test_ping(max_tokens: int = 8):
337
- if Groq is None:
338
- return {"ok": False, "result": "Groq client library not available."}
339
- api_key = os.getenv("API_KEY")
340
- if not api_key:
341
- return {"ok": False, "result": "API_KEY not configured."}
342
- try:
343
- client = Groq(api_key=api_key)
344
- res = client.chat.completions.create(
345
- model="deepseek-r1-distill-llama-70b",
346
- messages=[{"role": "user", "content": "ping"}],
347
- max_completion_tokens=max_tokens,
348
- )
349
- try:
350
- content = res.choices[0].message.content
351
- except Exception:
352
- try:
353
- content = res.choices[0].text
354
- except Exception:
355
- content = str(res)
356
- return {"ok": True, "result": content}
357
- except Exception as e:
358
- info = {"exception_repr": repr(e)}
359
- for attr in ("response", "http_response", "raw_response", "resp"):
360
- if hasattr(e, attr):
361
- rval = getattr(e, attr)
362
- try:
363
- info[attr] = {
364
- "status": getattr(rval, "status_code", getattr(rval, "status", "unknown")),
365
- "body_preview": (getattr(rval, "text", getattr(rval, "body", str(rval)))[:1000] + "...") if getattr(rval, "text", None) or getattr(rval, "body", None) else str(rval),
366
- }
367
- except Exception:
368
- info[attr] = str(rval)
369
- logger.exception("Groq test ping failed: %s", e)
370
- return {"ok": False, "result": info}
371
-
372
- ping_result = groq_test_ping()
373
- if ping_result.get("ok"):
374
- st.success("Groq ping successful")
375
- st.text_area("Result (truncated)", str(ping_result.get("result"))[:2000], height=200)
376
- else:
377
- st.error("Groq ping failed; see details below")
378
- st.json(ping_result.get("result"))
379
-
380
- st.markdown("---")
381
- st.markdown("Debugging tips:")
382
- st.markdown(
383
- "- If Groq returns HTTP 400: check model name, prompt length, and messages shape.\n"
384
- "- Use the Groq ping to inspect raw error details.\n"
385
- "- Ensure `API_KEY` is set & has permissions for the requested model.\n"
386
- "- To avoid the Streamlit <-> PyTorch watcher issue you can also run Streamlit with: "
387
- "`streamlit run app.py --server.fileWatcherType none` or set `.streamlit/config.toml`."
388
- )
 
1
  #!/usr/bin/env python3
2
  """
3
+ Streamlit Brain MRI Tumor Detection App (updated):
4
  - load & display the uploaded image(s),
5
  - pass the image to the ViT model for inference,
6
  - pass the model inference to the Groq Deepseek R1 LLM to generate an informational medical report,
7
+ - removed Groq "ping" and debugging tips,
8
+ - replaced deprecated use_column_width with use_container_width.
 
 
 
9
  """
10
 
11
  import os
 
23
  try:
24
  torch.classes.__path__ = []
25
  except Exception:
 
26
  pass
27
  except Exception as e:
28
  torch = None
 
118
  """
119
  if model is None or feature_extractor is None:
120
  raise RuntimeError("Model not loaded.")
 
121
  inputs = feature_extractor(images=image, return_tensors="pt")
122
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
123
  model.to(device)
 
154
  logger.exception("Failed to instantiate Groq client: %s", e)
155
  return "Medical report temporarily unavailable (client init failed)."
156
 
 
 
157
  prompt_lines = [
158
  "You are a careful medical assistant creating an informational medical report for a patient based on an automated image analysis result.",
159
  f"Model diagnosis: {diagnosis_label}",
 
165
  "Keep language clear and non-technical where possible, and keep it concise (about 3-6 short paragraphs)."
166
  ]
167
  if include_image_base64 and image_b64:
 
168
  prompt_lines.append("Note: a small thumbnail was provided (base64), though you should not rely on it for clinical decision-making.")
169
+ prompt_lines.append(f"Thumbnail (base64, trimmed): {image_b64[:800]}")
170
 
171
  prompt = "\n\n".join(prompt_lines)
172
 
 
185
  stream=False,
186
  stop=None,
187
  )
 
188
  try:
189
  report_text = completion.choices[0].message.content
190
  except Exception:
 
192
  report_text = completion.choices[0].text
193
  except Exception:
194
  report_text = str(completion)
 
195
  if safety_sentence not in report_text:
196
  report_text = safety_sentence + "\n\n" + report_text
197
  return report_text
198
  except Exception as e:
199
  logger.exception("Groq call failed: %s", e)
 
200
  resp = None
201
  for attr in ("response", "http_response", "raw_response", "resp"):
202
  resp = getattr(e, attr, None)
 
225
  uploaded_file = st.file_uploader("Choose an MRI image (jpg, jpeg, png)", type=["jpg", "jpeg", "png"])
226
 
227
  if uploaded_file is not None:
 
228
  try:
229
  pil_image = Image.open(uploaded_file).convert("RGB")
230
  except Exception as e:
 
233
  pil_image = None
234
 
235
  if pil_image:
 
236
  col1, col2 = st.columns([1, 1])
237
  with col1:
238
  st.markdown("**Original image**")
239
+ st.image(pil_image, use_container_width=True)
 
240
  processed_for_display = ImageOps.contain(pil_image, (512, 512))
241
  with col2:
242
  st.markdown("**Processed (for model preview)**")
243
+ st.image(processed_for_display, use_container_width=True)
244
 
 
245
  img_w, img_h = pil_image.size
246
  st.markdown(f"**Image metadata:** dimensions = {img_w} x {img_h}, mode = {pil_image.mode}")
247
 
 
248
  include_thumbnail = st.checkbox("Include small thumbnail preview in the generated report prompt (may increase request size)", value=False)
249
 
 
250
  if model_load_error:
251
  st.error("Model failed to load at startup. See Developer info for details.")
252
  st.code(model_load_error)
253
  else:
 
254
  run_infer = st.button("Run inference & generate report")
255
  if run_infer:
256
  try:
 
265
  label = None
266
  confidence = None
267
 
 
268
  if label is not None:
 
269
  image_info = f"dimensions={img_w}x{img_h}; mode={pil_image.mode}; filename_provided={hasattr(uploaded_file, 'name') and bool(getattr(uploaded_file, 'name', None))}"
 
270
  image_b64 = None
271
  if include_thumbnail:
272
  try:
 
280
  st.markdown("### Medical Report (informational)")
281
  st.write(report_text)
282
 
 
283
  try:
284
  report_bytes = report_text.encode("utf-8")
285
  download_name = f"medical_report_{label}_{int(confidence*100)}pct.txt"
 
291
  # If no file uploaded, show placeholder instructions
292
  if uploaded_file is None:
293
  st.markdown("<div class='small-muted'>Upload a brain MRI image (jpg/png) to get a model prediction and an informational medical report.</div>", unsafe_allow_html=True)