#!/usr/bin/env python3
"""
Streamlit Brain MRI Tumor Detection App.

The app:
- loads & displays the uploaded image,
- passes the image to the ViT model for inference,
- passes the model inference to the Groq DeepSeek R1 LLM to generate an
  informational medical report,
- provides logging, error handling, and a download for the generated report.

Important:
- This app is informational only and not a medical diagnosis.
- Set API_KEY in your environment to enable Groq calls.
- Optionally set GROQ_MODEL to override the Groq chat model name.
"""
import base64
import logging
import os
import re
import traceback
from io import BytesIO
from typing import Any, Dict, Optional, Tuple

# ------------------ Logging (configured first) ------------------
# basicConfig() must run before anything logs: a module-level logging.error()/
# info()/... call on an unconfigured root logger installs a default handler,
# after which a later basicConfig() call is silently ignored.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# ------------------ Safe startup: import torch first and monkeypatch ------------------
try:
    import torch
    import torch.nn.functional as F

    # Avoid Streamlit file-watcher introspection triggering a PyTorch C++ error
    if hasattr(torch, "classes"):
        try:
            torch.classes.__path__ = []
        except Exception:
            pass  # best-effort
except Exception as e:
    # Keep the page alive without torch; the model loader reports the problem.
    torch = None
    F = None
    logger.error("Failed to import torch at startup: %s", e)

# ------------------ Now safe to import Streamlit and other libs ------------------
import numpy as np
import streamlit as st
from PIL import Image, ImageOps

try:
    from transformers import ViTForImageClassification, ViTImageProcessor
except Exception as e:
    # Surfaced through model_load_error instead of crashing the page.
    ViTForImageClassification = None
    ViTImageProcessor = None
    _transformers_import_error = f"transformers import failed: {e}"
else:
    _transformers_import_error = None

# Groq client (optional dependency)
try:
    from groq import Groq
except Exception:
    Groq = None

# ------------------ Configuration ------------------
repository_id = "EnDevSols/brainmri-vit-model"
# Groq chat model used for both the report and the debug ping; overridable via env.
GROQ_MODEL = os.getenv("GROQ_MODEL", "deepseek-r1-distill-llama-70b")

# ------------------ Page config + CSS ------------------
st.set_page_config(layout="wide", page_title="Brain MRI Tumor Detection")

combined_css = """
.main, .sidebar .sidebar-content { background-color: #1c1c1c; color: #f0f2f6; }
.block-container {
    padding: 1rem 2rem; background-color: #333; border-radius: 10px;
    box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5);
}
.stButton>button, .stDownloadButton>button {
    background: linear-gradient(135deg, #ff7e5f, #feb47b); color: white; border: none;
    padding: 10px 24px; text-align: center; text-decoration: none; display: inline-block;
    font-size: 16px; margin: 4px 2px; cursor: pointer; border-radius: 5px;
}
.stSpinner { color: #4CAF50; }
.title { font-size: 3rem; font-weight: bold; display: flex; align-items: center; justify-content: center; }
.colorful-text {
    background: -webkit-linear-gradient(135deg, #ff7e5f, #feb47b);
    -webkit-background-clip: text; -webkit-text-fill-color: transparent;
}
.black-white-text { color: black; }
.custom-text { font-size: 1.2rem; color: #feb47b; text-align: center; margin-top: -20px; margin-bottom: 20px; }
.disclaimer { color: #ffcc66; font-weight: bold; text-align: center; margin-bottom: 12px; }
.small-muted { font-size:0.9rem; color:#cccccc; text-align:center; margin-top:8px; }
"""
# The stylesheet only takes effect when wrapped in a <style> tag.
st.markdown(f"<style>{combined_css}</style>", unsafe_allow_html=True)

# ------------------ Header + disclaimer ------------------
st.markdown(
    '<div class="title"><span class="colorful-text">Brain MRI Tumor Detection</span></div>',
    unsafe_allow_html=True,
)
st.markdown(
    '<div class="custom-text">Upload an MRI image to detect a brain tumor and get an informational medical report.</div>',
    unsafe_allow_html=True,
)
st.markdown(
    '<div class="disclaimer">⚠️ This app is experimental and informational only. It is NOT a medical diagnosis. '
    "If you have health concerns, consult a licensed medical professional. In emergencies call your local emergency number.</div>",
    unsafe_allow_html=True,
)


# ------------------ Model loading ------------------
def _identity_decorator(func):
    """No-op decorator used when this Streamlit version has no resource cache."""
    return func


# Streamlit re-executes this whole script on every widget interaction; caching
# the loaded model keeps the weights from being reloaded on each rerun.
_cache_resource = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or _identity_decorator
)


@_cache_resource
def _load_model_bundle(repo_id: str):
    """Load the ViT classifier and its image processor for *repo_id*.

    The model is moved to CUDA when available (CPU otherwise) and switched to
    eval mode once here, instead of on every prediction.

    Returns:
        (model, processor)
    Raises:
        RuntimeError: if torch or transformers could not be imported.
        Exception: whatever ``from_pretrained`` raises (network, missing repo, ...).
    """
    if torch is None:
        raise RuntimeError("torch is not available in this environment.")
    if _transformers_import_error:
        raise RuntimeError(_transformers_import_error)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loaded_model = ViTForImageClassification.from_pretrained(repo_id)
    loaded_model.to(device)
    loaded_model.eval()
    processor = ViTImageProcessor.from_pretrained(repo_id)
    logger.info("Model loaded successfully on %s.", device)
    return loaded_model, processor


model = None
feature_extractor = None
model_load_error = None
try:
    model, feature_extractor = _load_model_bundle(repository_id)
except Exception as e:
    model_load_error = str(e)
    logger.exception("Failed to load model: %s", e)


# ------------------ Prediction helper ------------------
def predict_image(image: Image.Image) -> Tuple[str, float]:
    """
    Run model inference on a PIL image.

    Returns (label, confidence) where label in {"Yes","No","Unknown"} and confidence 0..1.
    Raises RuntimeError if the model is not loaded.
    """
    if model is None or feature_extractor is None:
        raise RuntimeError("Model not loaded.")

    # Preprocess with the image processor, then move tensors to the model's device.
    inputs = feature_extractor(images=image, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # Take the first (only) batch row instead of squeeze(), which would also
    # collapse the class axis of a single-label head and break indexing below.
    probs = F.softmax(logits, dim=-1)[0].cpu().numpy()
    pred_idx = int(np.argmax(probs))
    confidence = float(probs[pred_idx])

    label_map = {0: "No", 1: "Yes"}  # adjust if your model mapping differs
    label = label_map.get(pred_idx, "Unknown")
    return label, confidence


# ------------------ Groq LLM helpers (defensive) ------------------
# Attribute names under which SDK exceptions may carry the HTTP response.
_RESPONSE_ATTRS = ("response", "http_response", "raw_response", "resp")
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
_THINK_OPEN_RE = re.compile(r"<think>", re.IGNORECASE)


def _extract_completion_text(completion: Any) -> str:
    """Best-effort extraction of the text content from a chat completion."""
    try:
        return completion.choices[0].message.content or ""
    except Exception:
        pass
    try:
        return completion.choices[0].text or ""
    except Exception:
        return str(completion)


def _strip_reasoning(text: str) -> str:
    """Drop ``<think>...</think>`` reasoning blocks that R1-style models may emit.

    An unterminated ``<think>`` (e.g. output cut off by the token limit) drops
    everything after it, since what follows is reasoning, not the report.
    """
    cleaned = _THINK_BLOCK_RE.sub("", text or "")
    cleaned = _THINK_OPEN_RE.split(cleaned, maxsplit=1)[0]
    return cleaned.strip()


def _find_error_response(exc: BaseException) -> Any:
    """Return the first HTTP-response-like attribute attached to *exc*, or None."""
    for attr in _RESPONSE_ATTRS:
        resp = getattr(exc, attr, None)
        # Compare with None explicitly: some response classes are falsy for
        # error statuses (requests.Response.__bool__ returns ``.ok``).
        if resp is not None:
            return resp
    return None


def _summarize_response(resp: Any, limit: int = 500) -> Dict[str, Any]:
    """Return ``{"status": ..., "body_preview": ...}`` for a response-like object.

    Never raises. *limit* caps the body preview length; "..." marks truncation.
    """
    try:
        status = getattr(resp, "status_code", None)
        if status is None:
            status = getattr(resp, "status", "unknown")
    except Exception:
        status = "unknown"
    try:
        body = getattr(resp, "text", None)
        if body is None:
            body = getattr(resp, "body", None)
        if body is None:
            body = resp
        if isinstance(body, (bytes, bytearray)):
            body = bytes(body).decode("utf-8", errors="replace")
        body = str(body)
    except Exception:
        body = "<unavailable>"
    if len(body) > limit:
        body = body[:limit] + "..."
    return {"status": status if isinstance(status, int) else str(status), "body_preview": body}


def generate_medical_report(diagnosis_label: str, confidence: float, image_info: str,
                            include_image_base64: bool = False, image_b64: Optional[str] = None) -> str:
    """
    Ask the LLM to create an informational medical report based on the model inference.

    Missing client/API key and any error while calling Groq are logged and
    returned as a user-facing "not available" message instead of raising.
    Returns a string report (informational only) that always contains the
    safety sentence.
    """
    safety_sentence = "This information is informational only — seek evaluation from a licensed medical professional."

    if Groq is None:
        logger.error("Groq client unavailable.")
        return "Medical report not available: Groq client library not installed in this environment."

    api_key = os.getenv("API_KEY")
    if not api_key:
        logger.error("API_KEY not set.")
        return "Medical report not available: API_KEY environment variable not configured."

    try:
        client = Groq(api_key=api_key)
    except Exception as e:
        logger.exception("Failed to instantiate Groq client: %s", e)
        return "Medical report temporarily unavailable (client init failed)."

    # Construct a concise prompt that includes the model's result and image metadata.
    # Do NOT include patient identifying data; keep it informational.
    prompt_lines = [
        "You are a careful medical assistant creating an informational medical report for a patient based on an automated image analysis result.",
        f"Model diagnosis: {diagnosis_label}",
        f"Model confidence: {confidence:.2%}",
        f"Image info: {image_info}",
        "Do NOT provide definitive medical diagnoses or prescriptive orders. ALWAYS include the sentence:",
        f"'{safety_sentence}'",
        "Provide the report sections: (1) Brief summary of findings, (2) Suggested next diagnostic steps for a clinician to consider, (3) Questions a patient can ask their clinician, (4) Immediate red-flag signs requiring emergency care.",
        "Keep language clear and non-technical where possible, and keep it concise (about 3-6 short paragraphs).",
    ]
    if include_image_base64 and image_b64:
        # Optionally include a tiny thumbnail as base64 (be careful with payload size).
        prompt_lines.append("Note: a small thumbnail was provided (base64), though you should not rely on it for clinical decision-making.")
        prompt_lines.append(f"Thumbnail (base64, trimmed): {image_b64[:800]}")  # only a prefix, to avoid huge payloads
    prompt = "\n\n".join(prompt_lines)

    messages = [
        {"role": "system", "content": "You are a cautious medical assistant that always advises users to consult licensed clinicians."},
        {"role": "user", "content": prompt},
    ]

    try:
        completion = client.chat.completions.create(
            model=GROQ_MODEL,
            messages=messages,
            temperature=0.3,
            max_completion_tokens=1024,
            top_p=0.9,
            stream=False,
            stop=None,
        )
        # R1-style models may prefix the answer with their reasoning; never show it to the user.
        report_text = _strip_reasoning(_extract_completion_text(completion))
    except Exception as e:
        logger.exception("Groq call failed: %s", e)
        resp = _find_error_response(e)
        if resp is not None:
            details = _summarize_response(resp, limit=500)
            logger.error("Groq response: status=%s body_preview=%s", details["status"], details["body_preview"])
        return "Medical report temporarily unavailable due to an error contacting the assistance model. Please consult a clinician."

    if not report_text:
        logger.error("Groq returned no report text (empty completion or reasoning only).")
        return "Medical report temporarily unavailable: the assistance model returned no report text. " + safety_sentence

    # Ensure safety sentence present
    if safety_sentence not in report_text:
        report_text = safety_sentence + "\n\n" + report_text
    return report_text


def groq_test_ping(max_tokens: int = 8) -> Dict[str, Any]:
    """Send a tiny 'ping' chat request to Groq to surface raw API errors.

    Returns ``{"ok": True, "result": <reply text>}`` on success, otherwise
    ``{"ok": False, "result": <message str or dict of error details>}``.
    """
    if Groq is None:
        return {"ok": False, "result": "Groq client library not available."}
    api_key = os.getenv("API_KEY")
    if not api_key:
        return {"ok": False, "result": "API_KEY not configured."}
    try:
        client = Groq(api_key=api_key)
        res = client.chat.completions.create(
            model=GROQ_MODEL,
            messages=[{"role": "user", "content": "ping"}],
            max_completion_tokens=max_tokens,
        )
        return {"ok": True, "result": _extract_completion_text(res)}
    except Exception as e:
        logger.exception("Groq test ping failed: %s", e)
        info: Dict[str, Any] = {"exception_repr": repr(e)}
        for attr in _RESPONSE_ATTRS:
            rval = getattr(e, attr, None)
            if rval is not None:
                info[attr] = _summarize_response(rval, limit=1000)
        return {"ok": False, "result": info}


# ------------------ Helpers for image display & base64 ------------------
def pil_to_base64(img: Image.Image, size: Optional[Tuple[int, int]] = None) -> str:
    """Return a base64-encoded JPEG for the PIL image.

    Optionally resize to fit within *size* (maintaining aspect ratio).
    """
    tmp = img.copy()
    if size:
        tmp = ImageOps.contain(tmp, size)
    # JPEG cannot store alpha or palette modes; normalise before saving.
    if tmp.mode not in ("RGB", "L"):
        tmp = tmp.convert("RGB")
    buff = BytesIO()
    tmp.save(buff, format="JPEG")
    return base64.b64encode(buff.getvalue()).decode("ascii")


# ------------------ Streamlit UI helpers ------------------
# Session-state slot holding the latest analysis, so the prediction and report
# survive reruns (e.g. the rerun triggered by clicking the download button).
_RESULT_STATE_KEY = "mri_analysis_result"


def _uploaded_file_key(uploaded: Any) -> str:
    """Identify an upload so stored results are only re-shown for the same file."""
    file_id = getattr(uploaded, "file_id", None)
    if file_id:
        return str(file_id)
    return f"{getattr(uploaded, 'name', '')}:{getattr(uploaded, 'size', '')}"


def _render_prediction(label: str, confidence: float) -> None:
    """Show the model's prediction."""
    st.markdown("### Model prediction:")
    st.write(f"**{label}** (confidence {confidence:.2%})")


def _render_report(report_text: str, label: str, confidence: float) -> None:
    """Show the generated report and offer it as a .txt download."""
    st.markdown("### Medical Report (informational)")
    st.write(report_text)
    try:
        report_bytes = report_text.encode("utf-8")
        download_name = f"medical_report_{label}_{int(confidence * 100)}pct.txt"
        st.download_button("Download report", data=report_bytes, file_name=download_name, mime="text/plain")
    except Exception as e:
        logger.exception("Failed to prepare report download: %s", e)
        st.error("Could not prepare download: " + str(e))


def _run_analysis(pil_image: Image.Image, processed_image: Image.Image, uploaded: Any,
                  include_thumbnail: bool) -> None:
    """Run inference and report generation, rendering results and storing them in session state."""
    st.session_state.pop(_RESULT_STATE_KEY, None)
    try:
        with st.spinner("Running model inference..."):
            label, confidence = predict_image(processed_image)
    except Exception as e:
        logger.exception("Inference failed: %s", e)
        st.error("Inference failed: " + str(e))
        return

    st.success("Inference complete")
    _render_prediction(label, confidence)

    # Image metadata only; no patient-identifying data (the filename itself is not sent).
    img_w, img_h = pil_image.size
    image_info = f"dimensions={img_w}x{img_h}; mode={pil_image.mode}; filename_provided={bool(getattr(uploaded, 'name', None))}"

    image_b64 = None
    if include_thumbnail:
        try:
            image_b64 = pil_to_base64(processed_image, size=(256, 256))
        except Exception as e:
            logger.exception("Failed to create base64 thumbnail: %s", e)
            image_b64 = None

    with st.spinner("Generating informational medical report from LLM..."):
        report_text = generate_medical_report(label, confidence, image_info,
                                              include_image_base64=include_thumbnail, image_b64=image_b64)

    st.session_state[_RESULT_STATE_KEY] = {
        "file_key": _uploaded_file_key(uploaded),
        "label": label,
        "confidence": confidence,
        "report": report_text,
    }
    _render_report(report_text, label, confidence)


# ------------------ Streamlit UI: upload, display, inference, report ------------------
uploaded_file = st.file_uploader("Choose an MRI image (jpg, jpeg, png)", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    pil_image = None
    try:
        pil_image = Image.open(uploaded_file).convert("RGB")
    except Exception as e:
        st.error(f"Unable to open the uploaded file as an image: {e}")
        logger.exception("Open uploaded image failed: %s", e)

    if pil_image is not None:
        # Display original and a preprocessed/thumbnail side-by-side
        col1, col2 = st.columns([1, 1])
        with col1:
            st.markdown("**Original image**")
            st.image(pil_image, use_column_width=True)

        # Create a centered thumbnail / processed view (resize for model preview)
        processed_for_display = ImageOps.contain(pil_image, (512, 512))
        with col2:
            st.markdown("**Processed (for model preview)**")
            st.image(processed_for_display, use_column_width=True)

        # Show image metadata
        img_w, img_h = pil_image.size
        st.markdown(f"**Image metadata:** dimensions = {img_w} x {img_h}, mode = {pil_image.mode}")

        # Option to include a small base64 thumbnail in the LLM prompt (default OFF to avoid large payloads)
        include_thumbnail = st.checkbox(
            "Include small thumbnail preview in the generated report prompt (may increase request size)",
            value=False,
        )

        if model_load_error:
            st.error("Model failed to load at startup. See Developer info for details.")
            st.code(model_load_error)
        elif st.button("Run inference & generate report"):
            _run_analysis(pil_image, processed_for_display, uploaded_file, include_thumbnail)
        else:
            # Re-show the last result for this same file on reruns (e.g. after a download click).
            stored = st.session_state.get(_RESULT_STATE_KEY)
            if stored and stored.get("file_key") == _uploaded_file_key(uploaded_file):
                _render_prediction(stored["label"], stored["confidence"])
                _render_report(stored["report"], stored["label"], stored["confidence"])

# If no file uploaded, show placeholder instructions
if uploaded_file is None:
    st.markdown(
        '<div class="small-muted">Upload a brain MRI image (jpg/png) to get a model prediction and an informational medical report.</div>',
        unsafe_allow_html=True,
    )

# ------------------ Developer troubleshooting expander ------------------
with st.expander("Developer info / Troubleshooting"):
    st.markdown(f"**Model repository**: `{repository_id}`")
    st.markdown(f"**Torch available**: {'Yes' if torch is not None else 'No'}")
    st.markdown(f"**Model loaded**: {'Yes' if model is not None else 'No'}")
    st.write({
        "CUDA available": torch.cuda.is_available() if torch is not None else False,
        "API_KEY set for Groq": bool(os.getenv("API_KEY")),
        "Groq installed": Groq is not None,
        "Groq model": GROQ_MODEL,
    })
    if model_load_error:
        st.markdown("**Model load error**:")
        st.code(model_load_error)

    st.markdown("---")
    st.markdown("### Groq quick test (for debugging API errors)")
    st.markdown("Click the button to run a very small 'ping' to the Groq chat endpoint. This helps capture raw error info without sending large prompts.")
    if st.button("Run Groq ping"):
        ping_result = groq_test_ping()
        if ping_result.get("ok"):
            st.success("Groq ping successful")
            st.text_area("Result (truncated)", str(ping_result.get("result"))[:2000], height=200)
        else:
            st.error("Groq ping failed; see details below")
            details = ping_result.get("result")
            # st.json treats a str argument as serialized JSON, so plain-text
            # failure messages are shown with st.write instead.
            if isinstance(details, dict):
                st.json(details)
            else:
                st.write(details)

    st.markdown("---")
    st.markdown("Debugging tips:")
    st.markdown(
        "- If Groq returns HTTP 400: check model name, prompt length, and messages shape.\n"
        "- Use the Groq ping to inspect raw error details.\n"
        "- Ensure `API_KEY` is set & has permissions for the requested model.\n"
        "- To avoid the Streamlit <-> PyTorch watcher issue you can also run Streamlit with: "
        "`streamlit run app.py --server.fileWatcherType none` or set `.streamlit/config.toml`."
    )