Tumor-Detection / app.py
Hammad712's picture
Update app.py
dade807 verified
#!/usr/bin/env python3
"""
Streamlit Brain MRI Tumor Detection App. It:
- loads and displays the uploaded image,
- passes the image to the ViT model for inference,
- passes the model's prediction to the Groq-hosted DeepSeek R1 LLM to generate an informational medical report,
- provides robust logging, error handling, and a download for the generated report.
Important:
- This app is informational only and not a medical diagnosis.
- Set API_KEY in your environment to enable Groq calls.
"""
import base64
import logging
import os
import re
import traceback
from io import BytesIO
from typing import Optional, Tuple
# ------------------ Safe startup: import torch first and monkeypatch ------------------
# A torch import failure is reported through a *named* logger. The module-level
# helpers (logging.error() etc.) implicitly call logging.basicConfig() when the
# root logger has no handlers, which would make the INFO-level basicConfig(...)
# call further down a silent no-op.
try:
    import torch

    # Avoid Streamlit file-watcher introspection triggering a PyTorch C++ error
    # when it inspects torch.classes; clearing __path__ is a best-effort workaround.
    if hasattr(torch, "classes"):
        try:
            torch.classes.__path__ = []
        except Exception:
            # ignore - best-effort
            pass
except Exception as e:
    torch = None
    logging.getLogger(__name__).error("Failed to import torch at startup: %s", e)
# ------------------ Now safe to import Streamlit and other libs ------------------
import streamlit as st
from PIL import Image, ImageOps
import numpy as np

# torch.nn.functional is only importable when torch itself is. Importing it
# unconditionally would crash here and defeat the torch=None fallback above;
# F is only used by inference, which requires a loaded model (and thus torch).
if torch is not None:
    import torch.nn.functional as F
else:
    F = None
from transformers import ViTForImageClassification, ViTImageProcessor
# Groq client (optional: report generation degrades gracefully without it)
try:
    from groq import Groq
except Exception:
    Groq = None
# ------------------ Logging ------------------
# Root configuration: INFO level with timestamped records. basicConfig() does
# nothing if the root logger already has handlers. Module code logs via `logger`.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# ------------------ Page config + CSS ------------------
# Must be the first Streamlit call of the script run.
st.set_page_config(page_title="Brain MRI Tumor Detection", layout="wide")
# Dark theme, gradient buttons/title text and the helper classes used by the
# header, disclaimer and placeholder HTML snippets.
combined_css = """
.main, .sidebar .sidebar-content { background-color: #1c1c1c; color: #f0f2f6; }
.block-container { padding: 1rem 2rem; background-color: #333; border-radius: 10px; box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5); }
.stButton>button, .stDownloadButton>button { background: linear-gradient(135deg, #ff7e5f, #feb47b); color: white; border: none; padding: 10px 24px; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; margin: 4px 2px; cursor: pointer; border-radius: 5px; }
.stSpinner { color: #4CAF50; }
.title {
font-size: 3rem;
font-weight: bold;
display: flex;
align-items: center;
justify-content: center;
}
.colorful-text {
background: -webkit-linear-gradient(135deg, #ff7e5f, #feb47b);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
.black-white-text {
color: black;
}
.custom-text {
font-size: 1.2rem;
color: #feb47b;
text-align: center;
margin-top: -20px;
margin-bottom: 20px;
}
.disclaimer { color: #ffcc66; font-weight: bold; text-align: center; margin-bottom: 12px; }
.small-muted { font-size:0.9rem; color:#cccccc; text-align:center; margin-top:8px; }
"""
st.markdown("<style>" + combined_css + "</style>", unsafe_allow_html=True)
# ------------------ Header + disclaimer ------------------
# Static HTML snippets rendered in order: title, tagline, safety disclaimer.
# They rely on the classes defined in combined_css.
for _header_html in (
    '<div class="title"><span class="colorful-text">Brain MRI</span> <span class="black-white-text">Tumor Detection</span></div>',
    '<div class="custom-text">Upload an MRI image to detect a brain tumor and get an informational medical report.</div>',
    "<div class='disclaimer'>⚠️ This app is experimental and informational only. It is NOT a medical diagnosis. "
    "If you have health concerns, consult a licensed medical professional. In emergencies call your local emergency number.</div>",
):
    st.markdown(_header_html, unsafe_allow_html=True)
# ------------------ Model loading ------------------
repository_id = "EnDevSols/brainmri-vit-model"


@st.cache_resource
def _load_vit_model(repo_id: str):
    """Load the ViT classifier and its image processor for *repo_id*.

    Streamlit re-executes this whole script on every widget interaction, so
    without st.cache_resource the model would be re-instantiated on every click.
    A raised exception is not cached, so a failed load is retried on the next
    rerun.

    Returns:
        ``(model, image_processor)``.

    Raises:
        RuntimeError: if torch could not be imported at startup.
    """
    if torch is None:
        raise RuntimeError("torch is not available in this environment.")
    vit_model = ViTForImageClassification.from_pretrained(repo_id)
    processor = ViTImageProcessor.from_pretrained(repo_id)
    logger.info("Model loaded successfully.")
    return vit_model, processor


# Module-level handles used by predict_image() and the UI; on failure both stay
# None and model_load_error holds the message shown to the user.
model = None
feature_extractor = None
model_load_error = None
try:
    model, feature_extractor = _load_vit_model(repository_id)
except Exception as e:
    model_load_error = str(e)
    logger.exception("Failed to load model: %s", e)
# ------------------ Prediction helper ------------------
def predict_image(image: Image.Image) -> Tuple[str, float]:
    """Classify one PIL image with the loaded ViT model.

    The image is preprocessed by ``feature_extractor``. Both the inputs and the
    model are placed on CUDA when available, otherwise on the CPU. The logits are
    turned into probabilities with a softmax.

    Returns:
        ``(label, confidence)``. ``label`` is ``"No"`` for class index 0,
        ``"Yes"`` for index 1 and ``"Unknown"`` otherwise. ``confidence`` is the
        softmax probability (0..1) of the chosen class.

    Raises:
        RuntimeError: if the model or its processor is not loaded.
    """
    if feature_extractor is None or model is None:
        raise RuntimeError("Model not loaded.")

    encoded = feature_extractor(images=image, return_tensors="pt")
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(target_device)
    encoded = {name: tensor.to(target_device) for name, tensor in encoded.items()}

    with torch.no_grad():
        raw_logits = model(**encoded).logits
    probabilities = F.softmax(raw_logits, dim=-1).squeeze().cpu().numpy()

    # Class-index -> label mapping; adjust if your model mapping differs.
    index_to_label = {0: "No", 1: "Yes"}
    best_index = int(probabilities.argmax())
    return index_to_label.get(best_index, "Unknown"), float(probabilities[best_index])
# ------------------ Groq LLM helper (defensive) ------------------
# DeepSeek-R1-family models emit their chain of thought between <think> and
# </think> before the answer. That scratchpad must not end up in a patient report.
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
_THINK_OPEN_RE = re.compile(r"<think>", re.IGNORECASE)


def _extract_completion_text(completion) -> Optional[str]:
    """Return the text of the first choice of a chat completion, or None.

    Prefers ``choices[0].message.content`` and falls back to ``choices[0].text``.
    Never returns a stringified response object as if it were report text.
    """
    try:
        choices = getattr(completion, "choices", None) or []
        if not choices:
            return None
        first = choices[0]
        text = getattr(getattr(first, "message", None), "content", None)
        if text is None:
            text = getattr(first, "text", None)
    except (IndexError, KeyError, TypeError):
        return None
    return text if isinstance(text, str) else None


def _strip_reasoning(text: str) -> str:
    """Remove ``<think>...</think>`` reasoning from *text* and trim whitespace.

    If a ``<think>`` tag is never closed (e.g. the token limit was reached
    mid-reasoning), everything from that tag onward is dropped.
    """
    cleaned = _THINK_BLOCK_RE.sub("", text)
    dangling = _THINK_OPEN_RE.search(cleaned)
    if dangling is not None:
        cleaned = cleaned[: dangling.start()]
    return cleaned.strip()


def _log_groq_error_details(exc: Exception) -> None:
    """Log status and body preview of an HTTP response attached to *exc*, if any."""
    resp = None
    for attr in ("response", "http_response", "raw_response", "resp"):
        resp = getattr(exc, attr, None)
        # `is not None` rather than truthiness: some response objects are falsy
        # for error statuses, which would hide exactly the details we want.
        if resp is not None:
            break
    if resp is None:
        return
    try:
        status = getattr(resp, "status_code", getattr(resp, "status", "unknown"))
        body_preview = getattr(resp, "text", getattr(resp, "body", str(resp)))
        logger.error("Groq response: status=%s body_preview=%s", status, str(body_preview)[:500])
    except Exception:
        logger.error("Could not extract response details.")


def generate_medical_report(
    diagnosis_label: str,
    confidence: float,
    image_info: str,
    include_image_base64: bool = False,
    image_b64: Optional[str] = None,
    model_name: str = "deepseek-r1-distill-llama-70b",
) -> str:
    """Ask the Groq-hosted LLM for an informational report on a model prediction.

    Args:
        diagnosis_label: Classifier output label (e.g. "Yes"/"No").
        confidence: Classifier probability for that label, 0..1.
        image_info: Non-identifying image metadata included in the prompt.
        include_image_base64: When True and *image_b64* is given, append the
            first 800 characters of the base64 thumbnail to the prompt.
        image_b64: Base64-encoded thumbnail; only used with the flag above.
        model_name: Groq chat model to call.

    Returns:
        The report text, with any ``<think>`` reasoning removed and the safety
        sentence guaranteed to be present. If the Groq library, the API key,
        client creation or the API call is unavailable or failing, or the reply
        contains no usable text, a human-readable "not available" message is
        returned instead. Groq-related failures are logged, not raised.
    """
    safety_sentence = "This information is informational only — seek evaluation from a licensed medical professional."
    if Groq is None:
        logger.error("Groq client unavailable.")
        return "Medical report not available: Groq client library not installed in this environment."
    api_key = os.getenv("API_KEY")
    if not api_key:
        logger.error("API_KEY not set.")
        return "Medical report not available: API_KEY environment variable not configured."
    try:
        client = Groq(api_key=api_key)
    except Exception as e:
        logger.exception("Failed to instantiate Groq client: %s", e)
        return "Medical report temporarily unavailable (client init failed)."

    # Construct a concise prompt that includes the model's result and image metadata.
    # Do NOT include patient identifying data; keep it informational.
    prompt_lines = [
        "You are a careful medical assistant creating an informational medical report for a patient based on an automated image analysis result.",
        f"Model diagnosis: {diagnosis_label}",
        f"Model confidence: {confidence:.2%}",
        f"Image info: {image_info}",
        "Do NOT provide definitive medical diagnoses or prescriptive orders. ALWAYS include the sentence:",
        f"'{safety_sentence}'",
        "Provide the report sections: (1) Brief summary of findings, (2) Suggested next diagnostic steps for a clinician to consider, (3) Questions a patient can ask their clinician, (4) Immediate red-flag signs requiring emergency care.",
        "Keep language clear and non-technical where possible, and keep it concise (about 3-6 short paragraphs).",
    ]
    if include_image_base64 and image_b64:
        # Only a prefix of the base64 thumbnail is sent, to bound the payload size.
        prompt_lines.append("Note: a small thumbnail was provided (base64), though you should not rely on it for clinical decision-making.")
        prompt_lines.append(f"Thumbnail (base64, trimmed): {image_b64[:800]}")
    prompt = "\n\n".join(prompt_lines)
    messages = [
        {"role": "system", "content": "You are a cautious medical assistant that always advises users to consult licensed clinicians."},
        {"role": "user", "content": prompt},
    ]

    try:
        completion = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=0.3,
            max_completion_tokens=1024,
            top_p=0.9,
            stream=False,
            stop=None,
        )
    except Exception as e:
        logger.exception("Groq call failed: %s", e)
        _log_groq_error_details(e)
        return "Medical report temporarily unavailable due to an error contacting the assistance model. Please consult a clinician."

    report_text = _strip_reasoning(_extract_completion_text(completion) or "")
    if not report_text:
        # e.g. content was None, or the whole token budget went into <think> reasoning.
        logger.error("Groq completion contained no usable report text.")
        return "Medical report temporarily unavailable (the assistance model returned no report text). Please consult a clinician."
    # Ensure safety sentence present
    if safety_sentence not in report_text:
        report_text = safety_sentence + "\n\n" + report_text
    return report_text
# ------------------ Helpers for image display & base64 ------------------
# Image modes saved as JPEG without conversion. Anything else (e.g. RGBA, LA, P)
# is converted to RGB first, because JPEG cannot store alpha or palette data and
# Pillow raises when asked to save those modes as JPEG.
_JPEG_WRITABLE_MODES = ("L", "RGB", "CMYK")


def pil_to_base64(img: Image.Image, size: Optional[Tuple[int, int]] = None) -> str:
    """Encode a PIL image as a base64 JPEG string.

    Args:
        img: Source image. It is not modified.
        size: Optional ``(width, height)`` bounding box. When given, the image is
            resized to fit inside it with its aspect ratio preserved
            (``ImageOps.contain``).

    Returns:
        The base64 encoding of the JPEG bytes, as a ``str``.
    """
    encoded = img.copy()
    if size:
        encoded = ImageOps.contain(encoded, size)
    if encoded.mode not in _JPEG_WRITABLE_MODES:
        encoded = encoded.convert("RGB")
    buff = BytesIO()
    encoded.save(buff, format="JPEG")
    return base64.b64encode(buff.getvalue()).decode()
# ------------------ Streamlit UI: upload, display, inference, report ------------------
# Streamlit re-runs this whole script on every widget interaction, including a
# click on "Download report". The prediction and report are therefore kept in
# st.session_state (keyed by the uploaded file) so they survive those reruns
# instead of vanishing, and the LLM is not re-queried just to display them again.


def _report_state_key(uploaded) -> str:
    """Build the st.session_state key under which results for *uploaded* are kept."""
    return "mri_report::{}::{}::{}".format(
        getattr(uploaded, "file_id", ""),
        getattr(uploaded, "name", ""),
        getattr(uploaded, "size", ""),
    )


def _show_result(result: dict) -> None:
    """Render a stored prediction and report, plus a .txt download of the report.

    *result* must contain ``"label"``, ``"confidence"`` and ``"report_text"``,
    as stored by the inference branch below.
    """
    label = result["label"]
    confidence = result["confidence"]
    report_text = result["report_text"]
    st.markdown("### Model prediction:")
    st.write(f"**{label}** (confidence {confidence:.2%})")
    st.markdown("### Medical Report (informational)")
    st.write(report_text)
    # Allow user to download the report as a .txt file
    try:
        report_bytes = report_text.encode("utf-8")
        download_name = f"medical_report_{label}_{int(confidence*100)}pct.txt"
        st.download_button("Download report", data=report_bytes, file_name=download_name, mime="text/plain")
    except Exception as e:
        logger.exception("Failed to prepare report download: %s", e)
        st.error("Could not prepare download: " + str(e))


uploaded_file = st.file_uploader("Choose an MRI image (jpg, jpeg, png)", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Load image; converting to RGB gives the model 3 channels whatever the upload mode.
    try:
        pil_image = Image.open(uploaded_file).convert("RGB")
    except Exception as e:
        st.error(f"Unable to open the uploaded file as an image: {e}")
        logger.exception("Open uploaded image failed: %s", e)
        pil_image = None
    if pil_image is not None:
        # Display original and a preprocessed/thumbnail side-by-side
        col1, col2 = st.columns([1, 1])
        with col1:
            st.markdown("**Original image**")
            st.image(pil_image, use_column_width=True)
        # Resized to fit within 512x512 (aspect ratio kept); this is also what is fed to the model.
        processed_for_display = ImageOps.contain(pil_image, (512, 512))
        with col2:
            st.markdown("**Processed (for model preview)**")
            st.image(processed_for_display, use_column_width=True)
        # Show image metadata
        img_w, img_h = pil_image.size
        st.markdown(f"**Image metadata:** dimensions = {img_w} x {img_h}, mode = {pil_image.mode}")
        # Option to include a small base64 thumbnail in the LLM prompt (default OFF to avoid large payloads)
        include_thumbnail = st.checkbox("Include small thumbnail preview in the generated report prompt (may increase request size)", value=False)
        # Model availability check
        if model_load_error:
            st.error("Model failed to load at startup. See Developer info for details.")
            st.code(model_load_error)
        else:
            state_key = _report_state_key(uploaded_file)
            if st.button("Run inference & generate report"):
                label = confidence = None
                try:
                    with st.spinner("Running model inference..."):
                        label, confidence = predict_image(processed_for_display)
                    st.success("Inference complete")
                except Exception as e:
                    logger.exception("Inference failed: %s", e)
                    st.error("Inference failed: " + str(e))
                    # Don't keep showing an earlier report next to a failed run.
                    st.session_state.pop(state_key, None)
                # If inference ok, call LLM to generate report
                if label is not None:
                    image_info = (
                        f"dimensions={img_w}x{img_h}; mode={pil_image.mode}; "
                        f"filename_provided={bool(getattr(uploaded_file, 'name', None))}"
                    )
                    # Optionally produce small base64 thumbnail
                    image_b64 = None
                    if include_thumbnail:
                        try:
                            image_b64 = pil_to_base64(processed_for_display, size=(256, 256))
                        except Exception as e:
                            logger.exception("Failed to create base64 thumbnail: %s", e)
                            image_b64 = None
                    with st.spinner("Generating informational medical report from LLM..."):
                        report_text = generate_medical_report(label, confidence, image_info, include_image_base64=include_thumbnail, image_b64=image_b64)
                    st.session_state[state_key] = {
                        "label": label,
                        "confidence": confidence,
                        "report_text": report_text,
                    }
            # Rendered on every rerun (e.g. after the download click), not only right after the button press.
            stored_result = st.session_state.get(state_key)
            if stored_result is not None:
                _show_result(stored_result)
# If no file uploaded, show placeholder instructions
if uploaded_file is None:
    st.markdown("<div class='small-muted'>Upload a brain MRI image (jpg/png) to get a model prediction and an informational medical report.</div>", unsafe_allow_html=True)
# ------------------ Developer troubleshooting expander ------------------
def _http_body_preview(resp, limit: int = 1000) -> str:
    """Return at most *limit* characters of an HTTP response body.

    Uses the ``text`` attribute, else ``body``. A ``bytes`` body is decoded as
    UTF-8, with undecodable bytes replaced. Falls back to ``str(resp)`` when
    neither yields a non-empty body. "..." is appended only if the body was
    actually truncated.
    """
    body = getattr(resp, "text", None) or getattr(resp, "body", None)
    if not body:
        return str(resp)
    if isinstance(body, (bytes, bytearray)):
        body = body.decode("utf-8", errors="replace")
    body = str(body)
    return body if len(body) <= limit else body[:limit] + "..."


def groq_test_ping(max_tokens: int = 8) -> dict:
    """Send a tiny "ping" chat request to Groq and report the outcome.

    Returns:
        ``{"ok": True, "result": <reply text>}`` on success. Otherwise returns
        ``{"ok": False, "result": ...}``, where the result is either a message
        (missing library/API key) or a dict with the exception repr and any
        HTTP status/body details attached to the exception.
    """
    if Groq is None:
        return {"ok": False, "result": "Groq client library not available."}
    api_key = os.getenv("API_KEY")
    if not api_key:
        return {"ok": False, "result": "API_KEY not configured."}
    try:
        client = Groq(api_key=api_key)
        res = client.chat.completions.create(
            model="deepseek-r1-distill-llama-70b",
            messages=[{"role": "user", "content": "ping"}],
            max_completion_tokens=max_tokens,
        )
    except Exception as e:
        info = {"exception_repr": repr(e)}
        for attr in ("response", "http_response", "raw_response", "resp"):
            if hasattr(e, attr):
                rval = getattr(e, attr)
                try:
                    info[attr] = {
                        "status": getattr(rval, "status_code", getattr(rval, "status", "unknown")),
                        "body_preview": _http_body_preview(rval),
                    }
                except Exception:
                    # e.g. reading .text on an unread streaming response can raise.
                    info[attr] = str(rval)
        logger.exception("Groq test ping failed: %s", e)
        return {"ok": False, "result": info}
    try:
        content = res.choices[0].message.content
    except Exception:
        try:
            content = res.choices[0].text
        except Exception:
            content = str(res)
    return {"ok": True, "result": content}


with st.expander("Developer info / Troubleshooting"):
    st.markdown(f"**Model repository**: `{repository_id}`")
    st.markdown(f"**Torch available**: {'Yes' if torch is not None else 'No'}")
    st.markdown(f"**Model loaded**: {'Yes' if model is not None else 'No'}")
    st.write({
        "CUDA available": torch.cuda.is_available() if torch is not None else False,
        "API_KEY set for Groq": bool(os.getenv("API_KEY")),
        "Groq installed": Groq is not None
    })
    if model_load_error:
        st.markdown("**Model load error**:")
        st.code(model_load_error)
    st.markdown("---")
    st.markdown("### Groq quick test (for debugging API errors)")
    st.markdown("Click the button to run a very small 'ping' to the Groq chat endpoint. This helps capture raw error info without sending large prompts.")
    if st.button("Run Groq ping"):
        ping_result = groq_test_ping()
        if ping_result.get("ok"):
            st.success("Groq ping successful")
            st.text_area("Result (truncated)", str(ping_result.get("result"))[:2000], height=200)
        else:
            st.error("Groq ping failed; see details below")
            st.json(ping_result.get("result"))
    st.markdown("---")
    st.markdown("Debugging tips:")
    st.markdown(
        "- If Groq returns HTTP 400: check model name, prompt length, and messages shape.\n"
        "- Use the Groq ping to inspect raw error details.\n"
        "- Ensure `API_KEY` is set & has permissions for the requested model.\n"
        "- To avoid the Streamlit <-> PyTorch watcher issue you can also run Streamlit with: "
        "`streamlit run app.py --server.fileWatcherType none` or set `.streamlit/config.toml`."
    )