#!/usr/bin/env python3
"""
Streamlit Brain MRI Tumor Detection App. It:
- loads and displays the uploaded image,
- passes the image to the ViT model for inference,
- passes the model's inference result to the Groq DeepSeek R1 LLM to generate an informational medical report,
- provides robust logging, error handling, and a download button for the generated report.

Important:
- This app is informational only and not a medical diagnosis.
- Set API_KEY in your environment to enable Groq calls.
"""
import os
import logging
import base64
from io import BytesIO
from typing import Optional, Tuple
# ------------------ Safe startup: import torch first and monkeypatch ------------------
try:
    import torch
    import torch.nn.functional as F
    # Avoid Streamlit's file watcher introspecting torch.classes, which can
    # trigger a PyTorch C++ error on some versions.
    if hasattr(torch, "classes"):
        try:
            torch.classes.__path__ = []
        except Exception:
            # best-effort; safe to ignore
            pass
except Exception as e:
    torch = None
    F = None
    logging.error("Failed to import torch at startup: %s", e)
# ------------------ Now safe to import Streamlit and other libs ------------------
import streamlit as st
from PIL import Image, ImageOps
import numpy as np
from transformers import ViTForImageClassification, ViTImageProcessor

# Groq client (optional dependency)
try:
    from groq import Groq
except Exception:
    Groq = None
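# (The client ships as the `groq` package on PyPI: `pip install groq`.)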
# ------------------ Logging ------------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# ------------------ Page config + CSS ------------------
st.set_page_config(layout="wide", page_title="Brain MRI Tumor Detection")
combined_css = """
.main, .sidebar .sidebar-content { background-color: #1c1c1c; color: #f0f2f6; }
.block-container { padding: 1rem 2rem; background-color: #333; border-radius: 10px; box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5); }
.stButton>button, .stDownloadButton>button { background: linear-gradient(135deg, #ff7e5f, #feb47b); color: white; border: none; padding: 10px 24px; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; margin: 4px 2px; cursor: pointer; border-radius: 5px; }
.stSpinner { color: #4CAF50; }
.title {
    font-size: 3rem;
    font-weight: bold;
    display: flex;
    align-items: center;
    justify-content: center;
}
.colorful-text {
    background: -webkit-linear-gradient(135deg, #ff7e5f, #feb47b);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}
.black-white-text {
    color: black;
}
.custom-text {
    font-size: 1.2rem;
    color: #feb47b;
    text-align: center;
    margin-top: -20px;
    margin-bottom: 20px;
}
.disclaimer { color: #ffcc66; font-weight: bold; text-align: center; margin-bottom: 12px; }
.small-muted { font-size: 0.9rem; color: #cccccc; text-align: center; margin-top: 8px; }
"""
st.markdown(f"<style>{combined_css}</style>", unsafe_allow_html=True)
# ------------------ Header + disclaimer ------------------
st.markdown(
    '<div class="title"><span class="colorful-text">Brain MRI</span> <span class="black-white-text">Tumor Detection</span></div>',
    unsafe_allow_html=True,
)
st.markdown(
    '<div class="custom-text">Upload an MRI image to detect a brain tumor and get an informational medical report.</div>',
    unsafe_allow_html=True,
)
st.markdown(
    "<div class='disclaimer'>⚠️ This app is experimental and informational only. It is NOT a medical diagnosis. "
    "If you have health concerns, consult a licensed medical professional. In emergencies call your local emergency number.</div>",
    unsafe_allow_html=True,
)
# ------------------ Model loading ------------------
repository_id = "EnDevSols/brainmri-vit-model"
model = None
feature_extractor = None
model_load_error = None
try:
    if torch is None:
        raise RuntimeError("torch is not available in this environment.")
    model = ViTForImageClassification.from_pretrained(repository_id)
    feature_extractor = ViTImageProcessor.from_pretrained(repository_id)
    logger.info("Model loaded successfully.")
except Exception as e:
    model_load_error = str(e)
    logger.exception("Failed to load model: %s", e)
# ------------------ Prediction helper ------------------
def predict_image(image: Image.Image) -> Tuple[str, float]:
    """
    Run model inference on a PIL image.
    Returns (label, confidence) where label is one of {"Yes", "No", "Unknown"}
    and confidence is in [0, 1].
    """
    if model is None or feature_extractor is None:
        raise RuntimeError("Model not loaded.")
    # Preprocess using the image processor
    inputs = feature_extractor(images=image, return_tensors="pt")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probs = F.softmax(logits, dim=-1).squeeze().cpu().numpy()
    pred_idx = int(np.argmax(probs))
    confidence = float(probs[pred_idx])
    label_map = {0: "No", 1: "Yes"}  # adjust if your model's id2label mapping differs
    label = label_map.get(pred_idx, "Unknown")
    return label, confidence
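# Hypothetical usage (file name and values are illustrative only):
#   label, conf = predict_image(Image.open("scan.jpg").convert("RGB"))
#   # e.g. ("Yes", 0.93) -> tumor detected with 93% confidence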
# ------------------ Groq LLM helper (defensive) ------------------
def generate_medical_report(
    diagnosis_label: str,
    confidence: float,
    image_info: str,
    include_image_base64: bool = False,
    image_b64: Optional[str] = None,
) -> str:
    """
    Ask the LLM to create an informational medical report based on the model inference.
    Returns a string report (informational only).
    """
    safety_sentence = "This information is informational only — seek evaluation from a licensed medical professional."
    if Groq is None:
        logger.error("Groq client unavailable.")
        return "Medical report not available: Groq client library not installed in this environment."
    api_key = os.getenv("API_KEY")
    if not api_key:
        logger.error("API_KEY not set.")
        return "Medical report not available: API_KEY environment variable not configured."
    try:
        client = Groq(api_key=api_key)
    except Exception as e:
        logger.exception("Failed to instantiate Groq client: %s", e)
        return "Medical report temporarily unavailable (client init failed)."
    # Construct a concise prompt that includes the model's result and image metadata.
    # Do NOT include patient-identifying data; keep it informational.
    prompt_lines = [
        "You are a careful medical assistant creating an informational medical report for a patient based on an automated image analysis result.",
        f"Model diagnosis: {diagnosis_label}",
        f"Model confidence: {confidence:.2%}",
        f"Image info: {image_info}",
        "Do NOT provide definitive medical diagnoses or prescriptive orders. ALWAYS include the sentence:",
        f"'{safety_sentence}'",
        "Provide the report sections: (1) Brief summary of findings, (2) Suggested next diagnostic steps for a clinician to consider, (3) Questions a patient can ask their clinician, (4) Immediate red-flag signs requiring emergency care.",
        "Keep language clear and non-technical where possible, and keep it concise (about 3-6 short paragraphs).",
    ]
    if include_image_base64 and image_b64:
        # Optionally include a tiny thumbnail as base64 (be careful with payload size).
        prompt_lines.append("Note: a small thumbnail was provided (base64), though you should not rely on it for clinical decision-making.")
        prompt_lines.append(f"Thumbnail (base64, trimmed): {image_b64[:800]}")  # only a prefix, to avoid huge payloads
    prompt = "\n\n".join(prompt_lines)
    messages = [
        {"role": "system", "content": "You are a cautious medical assistant that always advises users to consult licensed clinicians."},
        {"role": "user", "content": prompt},
    ]
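    # The messages list uses the OpenAI-style chat format that Groq's
    # chat.completions endpoint accepts: one system message framing the
    # assistant's behavior, plus one user message carrying the prompt.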
    try:
        completion = client.chat.completions.create(
            model="deepseek-r1-distill-llama-70b",
            messages=messages,
            temperature=0.3,
            max_completion_tokens=1024,
            top_p=0.9,
            stream=False,
            stop=None,
        )
        # Extract text robustly across response shapes
        try:
            report_text = completion.choices[0].message.content
        except Exception:
            try:
                report_text = completion.choices[0].text
            except Exception:
                report_text = str(completion)
        # Ensure the safety sentence is present
        if safety_sentence not in report_text:
            report_text = safety_sentence + "\n\n" + report_text
        return report_text
    except Exception as e:
        logger.exception("Groq call failed: %s", e)
        # Try to pull useful info from the exception if available
        resp = None
        for attr in ("response", "http_response", "raw_response", "resp"):
            resp = getattr(e, attr, None)
            if resp:
                break
        if resp:
            try:
                status = getattr(resp, "status_code", getattr(resp, "status", "unknown"))
                body_preview = getattr(resp, "text", getattr(resp, "body", str(resp)))
                logger.error("Groq response: status=%s body_preview=%s", status, str(body_preview)[:500])
            except Exception:
                logger.error("Could not extract response details.")
        return "Medical report temporarily unavailable due to an error contacting the assistance model. Please consult a clinician."
# ------------------ Helpers for image display & base64 ------------------
def pil_to_base64(img: Image.Image, size: Optional[Tuple[int, int]] = None) -> str:
    """Return a base64-encoded JPEG for the PIL image. Optionally resize (maintaining aspect ratio)."""
    tmp = img.convert("RGB")  # convert() returns a copy; JPEG cannot encode RGBA/P images
    if size:
        tmp = ImageOps.contain(tmp, size)
    buff = BytesIO()
    tmp.save(buff, format="JPEG")
    return base64.b64encode(buff.getvalue()).decode()
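# Example: pil_to_base64(img, size=(256, 256)) returns a plain base64 string,
# typically starting with "/9j/" (the JPEG magic bytes FF D8 FF in base64).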
# ------------------ Streamlit UI: upload, display, inference, report ------------------
uploaded_file = st.file_uploader("Choose an MRI image (jpg, jpeg, png)", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Load image
    try:
        pil_image = Image.open(uploaded_file).convert("RGB")
    except Exception as e:
        st.error(f"Unable to open the uploaded file as an image: {e}")
        logger.exception("Open uploaded image failed: %s", e)
        pil_image = None

    if pil_image:
        # Display the original and a preprocessed/thumbnail view side by side
        col1, col2 = st.columns([1, 1])
        # Create a resized view for the model preview (aspect ratio preserved)
        processed_for_display = ImageOps.contain(pil_image, (512, 512))
        with col1:
            st.markdown("**Original image**")
            st.image(pil_image, use_column_width=True)
        with col2:
            st.markdown("**Processed (for model preview)**")
            st.image(processed_for_display, use_column_width=True)

        # Show image metadata
        img_w, img_h = pil_image.size
        st.markdown(f"**Image metadata:** dimensions = {img_w} x {img_h}, mode = {pil_image.mode}")

        # Option to include a small base64 thumbnail in the LLM prompt (default OFF to avoid large payloads)
        include_thumbnail = st.checkbox("Include small thumbnail preview in the generated report prompt (may increase request size)", value=False)

        # Model availability check
        if model_load_error:
            st.error("Model failed to load at startup. See Developer info for details.")
            st.code(model_load_error)
        else:
            # Run inference
            run_infer = st.button("Run inference & generate report")
            if run_infer:
                label = None
                confidence = None
                try:
                    with st.spinner("Running model inference..."):
                        label, confidence = predict_image(processed_for_display)
                    st.success("Inference complete")
                    st.markdown("### Model prediction:")
                    st.write(f"**{label}** (confidence {confidence:.2%})")
                except Exception as e:
                    logger.exception("Inference failed: %s", e)
                    st.error("Inference failed: " + str(e))

                # If inference succeeded, call the LLM to generate a report
                if label is not None:
                    # Prepare an image_info summary
                    image_info = (
                        f"dimensions={img_w}x{img_h}; mode={pil_image.mode}; "
                        f"filename_provided={bool(getattr(uploaded_file, 'name', None))}"
                    )
                    # Optionally produce a small base64 thumbnail
                    image_b64 = None
                    if include_thumbnail:
                        try:
                            image_b64 = pil_to_base64(processed_for_display, size=(256, 256))
                        except Exception as e:
                            logger.exception("Failed to create base64 thumbnail: %s", e)
                            image_b64 = None
                    with st.spinner("Generating informational medical report from LLM..."):
                        report_text = generate_medical_report(label, confidence, image_info, include_image_base64=include_thumbnail, image_b64=image_b64)
                    st.markdown("### Medical Report (informational)")
                    st.write(report_text)
                    # Allow the user to download the report as a .txt file
                    try:
                        report_bytes = report_text.encode("utf-8")
                        download_name = f"medical_report_{label}_{int(confidence * 100)}pct.txt"
                        st.download_button("Download report", data=report_bytes, file_name=download_name, mime="text/plain")
                    except Exception as e:
                        logger.exception("Failed to prepare report download: %s", e)
                        st.error("Could not prepare download: " + str(e))
# If no file is uploaded, show placeholder instructions
if uploaded_file is None:
    st.markdown("<div class='small-muted'>Upload a brain MRI image (jpg/png) to get a model prediction and an informational medical report.</div>", unsafe_allow_html=True)
# ------------------ Developer troubleshooting expander ------------------
with st.expander("Developer info / Troubleshooting"):
    st.markdown(f"**Model repository**: `{repository_id}`")
    st.markdown(f"**Torch available**: {'Yes' if torch is not None else 'No'}")
    st.markdown(f"**Model loaded**: {'Yes' if model is not None else 'No'}")
    st.write({
        "CUDA available": torch.cuda.is_available() if torch is not None else False,
        "API_KEY set for Groq": bool(os.getenv("API_KEY")),
        "Groq installed": Groq is not None,
    })
    if model_load_error:
        st.markdown("**Model load error**:")
        st.code(model_load_error)

    st.markdown("---")
    st.markdown("### Groq quick test (for debugging API errors)")
    st.markdown("Click the button to run a very small 'ping' against the Groq chat endpoint. This helps capture raw error info without sending large prompts.")

    if st.button("Run Groq ping"):
        # Small test call
        def groq_test_ping(max_tokens: int = 8):
            if Groq is None:
                return {"ok": False, "result": "Groq client library not available."}
            api_key = os.getenv("API_KEY")
            if not api_key:
                return {"ok": False, "result": "API_KEY not configured."}
            try:
                client = Groq(api_key=api_key)
                res = client.chat.completions.create(
                    model="deepseek-r1-distill-llama-70b",
                    messages=[{"role": "user", "content": "ping"}],
                    max_completion_tokens=max_tokens,
                )
                try:
                    content = res.choices[0].message.content
                except Exception:
                    try:
                        content = res.choices[0].text
                    except Exception:
                        content = str(res)
                return {"ok": True, "result": content}
            except Exception as e:
                info = {"exception_repr": repr(e)}
                for attr in ("response", "http_response", "raw_response", "resp"):
                    if hasattr(e, attr):
                        rval = getattr(e, attr)
                        try:
                            if getattr(rval, "text", None) or getattr(rval, "body", None):
                                body = str(getattr(rval, "text", getattr(rval, "body", "")))
                                preview = body[:1000] + "..."
                            else:
                                preview = str(rval)
                            info[attr] = {
                                "status": getattr(rval, "status_code", getattr(rval, "status", "unknown")),
                                "body_preview": preview,
                            }
                        except Exception:
                            info[attr] = str(rval)
                logger.exception("Groq test ping failed: %s", e)
                return {"ok": False, "result": info}

        ping_result = groq_test_ping()
        if ping_result.get("ok"):
            st.success("Groq ping successful")
            st.text_area("Result (truncated)", str(ping_result.get("result"))[:2000], height=200)
        else:
            st.error("Groq ping failed; see details below")
            st.json(ping_result.get("result"))
| st.markdown("---") | |
| st.markdown("Debugging tips:") | |
| st.markdown( | |
| "- If Groq returns HTTP 400: check model name, prompt length, and messages shape.\n" | |
| "- Use the Groq ping to inspect raw error details.\n" | |
| "- Ensure `API_KEY` is set & has permissions for the requested model.\n" | |
| "- To avoid the Streamlit <-> PyTorch watcher issue you can also run Streamlit with: " | |
| "`streamlit run app.py --server.fileWatcherType none` or set `.streamlit/config.toml`." | |
| ) | |
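# For reference, the equivalent .streamlit/config.toml setting is:
#
#   [server]
#   fileWatcherType = "none"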