File size: 18,107 Bytes
b9d1861
 
dade807
266a048
 
 
dade807
 
 
 
 
b9d1861
 
 
 
 
fdbb1bf
266a048
 
b9d1861
 
 
266a048
 
 
 
b9d1861
266a048
dade807
266a048
b9d1861
266a048
 
b9d1861
fdbb1bf
266a048
 
 
 
 
fdbb1bf
266a048
fdbb1bf
266a048
fdbb1bf
 
d7767a0
b9d1861
fdbb1bf
 
d26f794
266a048
b9d1861
d7767a0
 
 
 
 
 
 
 
266a048
 
d7767a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9d1861
 
d7767a0
 
 
266a048
f01141e
 
266a048
f01141e
 
266a048
 
f01141e
b9d1861
 
 
fdbb1bf
b9d1861
 
266a048
b9d1861
 
 
 
 
 
 
266a048
b9d1861
 
266a048
b9d1861
 
266a048
b9d1861
266a048
 
b9d1861
266a048
 
b9d1861
 
266a048
dade807
b9d1861
 
 
 
 
 
 
 
266a048
 
 
 
 
 
 
 
b9d1861
266a048
 
b9d1861
fdbb1bf
 
266a048
 
fdbb1bf
b9d1861
 
266a048
 
fdbb1bf
b9d1861
 
 
fdbb1bf
266a048
 
dade807
 
266a048
 
 
 
 
 
 
 
 
 
 
dade807
266a048
dade807
266a048
 
b9d1861
 
266a048
 
b9d1861
 
 
 
 
 
266a048
 
 
b9d1861
 
 
dade807
b9d1861
266a048
b9d1861
 
266a048
b9d1861
266a048
dade807
266a048
 
 
b9d1861
266a048
dade807
266a048
fdbb1bf
266a048
 
fdbb1bf
266a048
fdbb1bf
266a048
 
 
fdbb1bf
266a048
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7767a0
 
dade807
b9d1861
266a048
b9d1861
266a048
 
 
 
 
dade807
266a048
 
 
dade807
 
266a048
 
 
dade807
266a048
dade807
266a048
 
 
dade807
266a048
 
dade807
b9d1861
266a048
b9d1861
 
dade807
266a048
 
fdbb1bf
266a048
 
 
 
 
fdbb1bf
266a048
 
 
fdbb1bf
b9d1861
dade807
266a048
dade807
266a048
dade807
266a048
 
 
 
 
 
 
 
 
 
 
 
 
dade807
266a048
 
 
 
 
 
 
 
 
b9d1861
266a048
dade807
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
#!/usr/bin/env python3
"""
Streamlit Brain MRI Tumor Detection App. This version:
 - loads and displays the uploaded image,
 - passes the image to the ViT model for inference,
 - passes the model's inference result to the DeepSeek-R1 (distilled) LLM hosted on Groq to generate an informational medical report,
 - provides logging, error handling, and a download for the generated report.

Important:
 - This app is informational only and not a medical diagnosis.
 - Set API_KEY in your environment to enable Groq calls.
"""

import base64
import logging
import os
import re
import traceback
from io import BytesIO
from typing import Dict, Optional, Tuple

# ------------------ Safe startup: import torch first and monkeypatch ------------------
# torch is imported (and patched) BEFORE streamlit so the patch is in place before
# Streamlit's file watcher starts introspecting loaded modules.
try:
    import torch
    # Imported inside the try: if torch itself is missing/broken this import would
    # otherwise re-raise below and crash the app, defeating the `torch = None` fallback.
    import torch.nn.functional as F

    # Avoid Streamlit file-watcher introspection triggering a PyTorch C++ error
    if hasattr(torch, "classes"):
        try:
            torch.classes.__path__ = []
        except Exception:
            # Best-effort only: the app still runs if torch refuses the assignment.
            pass
except Exception as e:
    # Keep the app importable without torch; model loading checks `torch is None` and the
    # developer panel reports torch availability instead of the app crashing here.
    torch = None
    F = None
    # Use a named logger, not the module-level ``logging.error``: that function implicitly
    # calls ``logging.basicConfig()``, which would make the explicit basicConfig(...) call
    # later in this file a no-op (and suppress its INFO-level output).
    logging.getLogger(__name__).error("Failed to import torch at startup: %s", e)

# ------------------ Now safe to import Streamlit and other libs ------------------
import streamlit as st
from PIL import Image, ImageOps
import numpy as np

# transformers may fail to import when torch is broken; degrade like the Groq import below
# so the UI can still render and report the model-load failure.
try:
    from transformers import ViTForImageClassification, ViTImageProcessor
except Exception as e:
    ViTForImageClassification = None
    ViTImageProcessor = None
    logging.getLogger(__name__).error("Failed to import transformers at startup: %s", e)

# Groq client (optional: report generation is disabled when it is not installed)
try:
    from groq import Groq
except Exception:
    Groq = None

# ------------------ Logging ------------------
# App-wide root logging configuration: INFO level with timestamped lines.
# Note: logging.basicConfig does nothing if the root logger already has handlers
# (e.g. if logging.error/info/... was called earlier in the process).
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# ------------------ Page config + CSS ------------------
# Page-wide settings; kept before any other st.* call since (at least in older Streamlit
# releases) set_page_config must be the first Streamlit command executed.
st.set_page_config(layout="wide", page_title="Brain MRI Tumor Detection")
# Global stylesheet injected once below via st.markdown(unsafe_allow_html=True).
# Selectors such as .block-container / .stButton target Streamlit-generated DOM classes,
# which may differ between Streamlit versions — TODO confirm against the installed version.
# NOTE(review): .black-white-text uses `color: black` while .block-container is styled with
# a #333 background, so the "Tumor Detection" part of the title may be hard to read —
# confirm this is intended.
combined_css = """
    .main, .sidebar .sidebar-content { background-color: #1c1c1c; color: #f0f2f6; }
    .block-container { padding: 1rem 2rem; background-color: #333; border-radius: 10px; box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5); }
    .stButton>button, .stDownloadButton>button { background: linear-gradient(135deg, #ff7e5f, #feb47b); color: white; border: none; padding: 10px 24px; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; margin: 4px 2px; cursor: pointer; border-radius: 5px; }
    .stSpinner { color: #4CAF50; }
    .title {
        font-size: 3rem;
        font-weight: bold;
        display: flex; 
        align-items: center; 
        justify-content: center;
    }
    .colorful-text {
        background: -webkit-linear-gradient(135deg, #ff7e5f, #feb47b);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
    }
    .black-white-text {
        color: black;
    }
    .custom-text {
        font-size: 1.2rem;
        color: #feb47b;
        text-align: center;
        margin-top: -20px;
        margin-bottom: 20px;
    }
    .disclaimer { color: #ffcc66; font-weight: bold; text-align: center; margin-bottom: 12px; }
    .small-muted { font-size:0.9rem; color:#cccccc; text-align:center; margin-top:8px; }
"""
st.markdown(f"<style>{combined_css}</style>", unsafe_allow_html=True)

# ------------------ Header + disclaimer ------------------
# Static title, subtitle and medical disclaimer, rendered on every script run using the
# CSS classes defined above (.title/.colorful-text/.black-white-text/.custom-text/.disclaimer).
st.markdown(
    '<div class="title"><span class="colorful-text">Brain MRI</span> <span class="black-white-text">Tumor Detection</span></div>',
    unsafe_allow_html=True
)
st.markdown(
    '<div class="custom-text">Upload an MRI image to detect a brain tumor and get an informational medical report.</div>',
    unsafe_allow_html=True
)
st.markdown(
    "<div class='disclaimer'>⚠️ This app is experimental and informational only. It is NOT a medical diagnosis. "
    "If you have health concerns, consult a licensed medical professional. In emergencies call your local emergency number.</div>",
    unsafe_allow_html=True,
)

# ------------------ Model loading ------------------
repository_id = "EnDevSols/brainmri-vit-model"
model = None
feature_extractor = None
model_load_error = None

# Streamlit re-executes this whole script on every widget interaction, so loading weights
# at module level would re-run from_pretrained on every click. Cache the loaded objects
# for the server process when st.cache_resource exists; older Streamlit versions lack it,
# in which case we fall back to loading on each run (the previous behaviour).
_cache_resource = getattr(st, "cache_resource", None) or (lambda func: func)


@_cache_resource
def _load_model(repo_id: str):
    """Load and return ``(model, image_processor)`` for the Hugging Face repo ``repo_id``.

    Raises on any failure. Exceptions propagate instead of being stored in the cache,
    so a transient failure (e.g. network) is retried on the next rerun.
    """
    if torch is None:
        raise RuntimeError("torch is not available in this environment.")
    if ViTForImageClassification is None or ViTImageProcessor is None:
        raise RuntimeError("transformers is not available in this environment.")
    loaded_model = ViTForImageClassification.from_pretrained(repo_id)
    processor = ViTImageProcessor.from_pretrained(repo_id)
    return loaded_model, processor


try:
    model, feature_extractor = _load_model(repository_id)
    logger.info("Model loaded successfully.")
except Exception as e:
    # Surface the error in the UI (see the upload section / developer panel) instead of crashing.
    model_load_error = str(e)
    logger.exception("Failed to load model: %s", e)

# ------------------ Prediction helper ------------------
# Default class-index -> label mapping for this checkpoint; adjust if your model mapping differs.
_DEFAULT_LABEL_MAP = {0: "No", 1: "Yes"}


def predict_image(image: Image.Image, label_map: Optional[Dict[int, str]] = None) -> Tuple[str, float]:
    """
    Run model inference on a PIL image.

    Args:
        image: Image to classify. Resizing/normalisation is done by the module-level
            ``feature_extractor``.
        label_map: Optional class-index -> label mapping. Defaults to
            ``{0: "No", 1: "Yes"}``; a predicted index missing from the map yields "Unknown".

    Returns:
        ``(label, confidence)`` where ``confidence`` (0..1) is the softmax probability of
        the predicted class.

    Raises:
        RuntimeError: if the model or its image processor did not load.
    """
    if model is None or feature_extractor is None:
        raise RuntimeError("Model not loaded.")
    if label_map is None:
        label_map = _DEFAULT_LABEL_MAP

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # cheap once the weights already live on `device`

    encoded = feature_extractor(images=image, return_tensors="pt")
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits

    # Hugging Face classification heads return logits shaped (batch, num_labels) and we pass
    # exactly one image, so take row 0 explicitly. squeeze() would also drop the class axis
    # of a single-output head, leaving a 0-d array that cannot be indexed below.
    probs = F.softmax(logits, dim=-1)[0].cpu().numpy()
    pred_idx = int(np.argmax(probs))
    return label_map.get(pred_idx, "Unknown"), float(probs[pred_idx])

# ------------------ Groq LLM helper (defensive) ------------------
# DeepSeek-R1 style models put their chain of thought in <think>...</think> ahead of the
# answer. Match a closed block, or an unclosed one running to the end of the text (which
# happens when the completion-token budget runs out mid-reasoning).
_THINK_BLOCK_RE = re.compile(r"<think>.*?(?:</think>|\Z)", re.DOTALL | re.IGNORECASE)


def strip_reasoning_tags(text: Optional[str]) -> str:
    """Return ``text`` with ``<think>...</think>`` reasoning blocks removed and ends trimmed.

    ``None`` or empty input yields ``""``.
    """
    if not text:
        return ""
    return _THINK_BLOCK_RE.sub("", str(text)).strip()


def _log_groq_error_response(error: Exception) -> None:
    """Log status/body of the first HTTP response object attached to a Groq SDK exception."""
    resp = None
    for attr in ("response", "http_response", "raw_response", "resp"):
        resp = getattr(error, attr, None)
        # `is not None`, not truthiness: some response objects are falsy for error statuses.
        if resp is not None:
            break
    if resp is None:
        return
    try:
        status = getattr(resp, "status_code", getattr(resp, "status", "unknown"))
        body_preview = getattr(resp, "text", getattr(resp, "body", str(resp)))
        logger.error("Groq response: status=%s body_preview=%s", status, str(body_preview)[:500])
    except Exception:
        logger.error("Could not extract response details.")


def generate_medical_report(diagnosis_label: str, confidence: float, image_info: str, include_image_base64: bool = False, image_b64: Optional[str] = None) -> str:
    """
    Ask the LLM to create an informational medical report based on the model inference.

    Args:
        diagnosis_label: Label returned by the image model (e.g. "Yes"/"No").
        confidence: Model confidence in 0..1.
        image_info: Non-identifying image metadata to include in the prompt.
        include_image_base64: Whether to append a prefix of ``image_b64`` to the prompt.
        image_b64: Base64 image data (only used with ``include_image_base64``).

    Returns:
        The report text with any model reasoning removed and the safety sentence guaranteed
        to be present, or a short "not available" message. Never raises for Groq/config errors.
    """
    safety_sentence = "This information is informational only — seek evaluation from a licensed medical professional."
    if Groq is None:
        logger.error("Groq client unavailable.")
        return "Medical report not available: Groq client library not installed in this environment."

    api_key = os.getenv("API_KEY")
    if not api_key:
        logger.error("API_KEY not set.")
        return "Medical report not available: API_KEY environment variable not configured."

    try:
        client = Groq(api_key=api_key)
    except Exception as e:
        logger.exception("Failed to instantiate Groq client: %s", e)
        return "Medical report temporarily unavailable (client init failed)."

    # Construct a concise prompt that includes the model's result and image metadata.
    # Do NOT include patient identifying data; keep it informational.
    prompt_lines = [
        "You are a careful medical assistant creating an informational medical report for a patient based on an automated image analysis result.",
        f"Model diagnosis: {diagnosis_label}",
        f"Model confidence: {confidence:.2%}",
        f"Image info: {image_info}",
        "Do NOT provide definitive medical diagnoses or prescriptive orders. ALWAYS include the sentence:",
        f"'{safety_sentence}'",
        "Provide the report sections: (1) Brief summary of findings, (2) Suggested next diagnostic steps for a clinician to consider, (3) Questions a patient can ask their clinician, (4) Immediate red-flag signs requiring emergency care.",
        "Keep language clear and non-technical where possible, and keep it concise (about 3-6 short paragraphs)."
    ]
    if include_image_base64 and image_b64:
        # NOTE(review): an 800-char prefix of a base64 JPEG cannot be decoded into an image
        # (it holds little more than the file header), so the LLM cannot actually "see" it.
        prompt_lines.append("Note: a small thumbnail was provided (base64), though you should not rely on it for clinical decision-making.")
        prompt_lines.append(f"Thumbnail (base64, trimmed): {image_b64[:800]}")  # only include a prefix to avoid huge payloads

    prompt = "\n\n".join(prompt_lines)

    messages = [
        {"role": "system", "content": "You are a cautious medical assistant that always advises users to consult licensed clinicians."},
        {"role": "user", "content": prompt}
    ]

    try:
        completion = client.chat.completions.create(
            model="deepseek-r1-distill-llama-70b",
            messages=messages,
            temperature=0.3,
            max_completion_tokens=1024,
            top_p=0.9,
            stream=False,
            stop=None,
        )
    except Exception as e:
        logger.exception("Groq call failed: %s", e)
        _log_groq_error_response(e)
        return "Medical report temporarily unavailable due to an error contacting the assistance model. Please consult a clinician."

    # Extract text robustly (content may legitimately be None)
    try:
        raw_text = completion.choices[0].message.content
    except Exception:
        try:
            raw_text = completion.choices[0].text
        except Exception:
            raw_text = str(completion)

    # Drop the model's chain of thought so only the report itself reaches the patient.
    report_text = strip_reasoning_tags(raw_text)
    if not report_text:
        # Typically means the reasoning consumed the whole max_completion_tokens budget.
        logger.error("Groq returned no report text after removing reasoning blocks.")
        return "Medical report temporarily unavailable: the assistance model returned an empty response. Please consult a clinician."

    # Ensure safety sentence present
    if safety_sentence not in report_text:
        report_text = safety_sentence + "\n\n" + report_text
    return report_text

# ------------------ Helpers for image display & base64 ------------------
# Pillow image modes its JPEG encoder can write directly; anything else (RGBA, LA, P, I, F, ...)
# would make save(format="JPEG") raise, so such images are converted to RGB first.
_JPEG_WRITABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})


def pil_to_base64(img: Image.Image, size: Optional[Tuple[int, int]] = None) -> str:
    """Return a base64-encoded JPEG of ``img``.

    Args:
        img: Source image; it is not modified (a copy is encoded).
        size: Optional ``(width, height)`` box; the image is scaled to fit inside it
            with its aspect ratio kept (``ImageOps.contain``).

    Returns:
        The JPEG bytes base64-encoded as an ASCII ``str``.
    """
    # Work on a copy: Image.save stores encoder settings on the image object it is called on.
    tmp = img.copy()
    if size:
        tmp = ImageOps.contain(tmp, size)
    if tmp.mode not in _JPEG_WRITABLE_MODES:
        tmp = tmp.convert("RGB")
    buff = BytesIO()
    tmp.save(buff, format="JPEG")
    return base64.b64encode(buff.getvalue()).decode()

# ------------------ Streamlit UI: upload, display, inference, report ------------------
# Streamlit re-executes this script on every widget interaction. st.button(...) is True only
# on the rerun triggered by its own click, and clicking "Download report" or toggling the
# checkbox triggers another rerun, so anything rendered only inside the button branch would
# disappear immediately. The latest result is therefore kept in st.session_state under this
# key and re-rendered on later reruns while the same file stays uploaded.
_RESULT_STATE_KEY = "mri_inference_result"


def _upload_identity(upload) -> Tuple:
    """Return ``(name, size)`` of an uploaded file; used to match a stored result to it."""
    return (getattr(upload, "name", None), getattr(upload, "size", None))


def _render_prediction(result: dict) -> None:
    """Show the stored inference outcome, or its error message."""
    if result.get("error"):
        st.error(result["error"])
        return
    st.success("Inference complete")
    st.markdown("### Model prediction:")
    st.write(f"**{result['label']}** (confidence {result['confidence']:.2%})")


def _render_report(result: dict) -> None:
    """Show the stored report and a .txt download button, if a report was generated."""
    report_text = result.get("report_text")
    if report_text is None:
        return
    st.markdown("### Medical Report (informational)")
    st.write(report_text)

    # Allow user to download the report as a .txt file
    try:
        report_bytes = report_text.encode("utf-8")
        download_name = f"medical_report_{result['label']}_{int(result['confidence']*100)}pct.txt"
        st.download_button("Download report", data=report_bytes, file_name=download_name, mime="text/plain")
    except Exception as e:
        logger.exception("Failed to prepare report download: %s", e)
        st.error("Could not prepare download: " + str(e))


# ------------------ Streamlit UI: upload, display, inference, report ------------------
uploaded_file = st.file_uploader("Choose an MRI image (jpg, jpeg, png)", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Load image; convert("RGB") so palette/alpha/greyscale uploads all become 3-channel.
    try:
        pil_image = Image.open(uploaded_file).convert("RGB")
    except Exception as e:
        st.error(f"Unable to open the uploaded file as an image: {e}")
        logger.exception("Open uploaded image failed: %s", e)
        pil_image = None

    if pil_image is not None:
        # Display original and a preprocessed/thumbnail side-by-side
        col1, col2 = st.columns([1, 1])
        with col1:
            st.markdown("**Original image**")
            st.image(pil_image, use_column_width=True)
        # Scaled to fit within 512x512 (aspect ratio kept); this copy is also what the model receives.
        processed_for_display = ImageOps.contain(pil_image, (512, 512))
        with col2:
            st.markdown("**Processed (for model preview)**")
            st.image(processed_for_display, use_column_width=True)

        # Show image metadata
        img_w, img_h = pil_image.size
        st.markdown(f"**Image metadata:** dimensions = {img_w} x {img_h}, mode = {pil_image.mode}")

        # Option to include a small base64 thumbnail in the LLM prompt (default OFF to avoid large payloads)
        include_thumbnail = st.checkbox("Include small thumbnail preview in the generated report prompt (may increase request size)", value=False)

        # Model availability check
        if model_load_error:
            st.error("Model failed to load at startup. See Developer info for details.")
            st.code(model_load_error)
        else:
            upload_id = _upload_identity(uploaded_file)
            if st.button("Run inference & generate report"):
                # Store the result dict up front (and fill it in place) so whatever
                # completes survives later reruns.
                result = {"upload": upload_id, "label": None, "confidence": None, "error": None, "report_text": None}
                st.session_state[_RESULT_STATE_KEY] = result
                try:
                    with st.spinner("Running model inference..."):
                        result["label"], result["confidence"] = predict_image(processed_for_display)
                except Exception as e:
                    logger.exception("Inference failed: %s", e)
                    result["error"] = "Inference failed: " + str(e)
                _render_prediction(result)

                # If inference ok, call LLM to generate report
                if result["error"] is None:
                    # Prepare image_info summary (no identifying data, only whether a filename exists)
                    image_info = f"dimensions={img_w}x{img_h}; mode={pil_image.mode}; filename_provided={hasattr(uploaded_file, 'name') and bool(getattr(uploaded_file, 'name', None))}"
                    # Optionally produce small base64 thumbnail
                    image_b64 = None
                    if include_thumbnail:
                        try:
                            image_b64 = pil_to_base64(processed_for_display, size=(256, 256))
                        except Exception as e:
                            logger.exception("Failed to create base64 thumbnail: %s", e)
                            image_b64 = None

                    with st.spinner("Generating informational medical report from LLM..."):
                        result["report_text"] = generate_medical_report(result["label"], result["confidence"], image_info, include_image_base64=include_thumbnail, image_b64=image_b64)
                    _render_report(result)
            else:
                # Re-show the last result for this same upload after an unrelated rerun.
                previous = st.session_state.get(_RESULT_STATE_KEY)
                if previous is not None and previous.get("upload") == upload_id:
                    _render_prediction(previous)
                    _render_report(previous)
else:
    # If no file uploaded, show placeholder instructions
    st.markdown("<div class='small-muted'>Upload a brain MRI image (jpg/png) to get a model prediction and an informational medical report.</div>", unsafe_allow_html=True)

# ------------------ Developer troubleshooting expander ------------------
def _preview_text(value, limit: int = 1000) -> str:
    """Return ``value`` as text, cut to ``limit`` chars with a trailing "..." only when truncated."""
    if isinstance(value, (bytes, bytearray)):
        value = bytes(value).decode("utf-8", errors="replace")
    text = str(value)
    return text if len(text) <= limit else text[:limit] + "..."


def groq_test_ping(max_tokens: int = 8):
    """Send a tiny 'ping' chat request to Groq and return ``{"ok": bool, "result": ...}``.

    On success ``result`` is the reply text. On failure it is either a short message string
    (library or API key missing) or a dict holding the exception repr plus status/body
    previews of any response objects attached to the exception.
    """
    if Groq is None:
        return {"ok": False, "result": "Groq client library not available."}
    api_key = os.getenv("API_KEY")
    if not api_key:
        return {"ok": False, "result": "API_KEY not configured."}
    try:
        client = Groq(api_key=api_key)
        res = client.chat.completions.create(
            model="deepseek-r1-distill-llama-70b",
            messages=[{"role": "user", "content": "ping"}],
            max_completion_tokens=max_tokens,
        )
    except Exception as e:
        info = {"exception_repr": repr(e)}
        for attr in ("response", "http_response", "raw_response", "resp"):
            rval = getattr(e, attr, None)
            if rval is None:
                continue
            try:
                # Prefer a non-empty .text, then .body; fall back to the object's repr.
                body = getattr(rval, "text", None) or getattr(rval, "body", None)
                info[attr] = {
                    "status": getattr(rval, "status_code", getattr(rval, "status", "unknown")),
                    "body_preview": _preview_text(body) if body else str(rval),
                }
            except Exception:
                info[attr] = str(rval)
        logger.exception("Groq test ping failed: %s", e)
        return {"ok": False, "result": info}

    try:
        content = res.choices[0].message.content
    except Exception:
        try:
            content = res.choices[0].text
        except Exception:
            content = str(res)
    return {"ok": True, "result": content}


# ------------------ Developer troubleshooting expander ------------------
with st.expander("Developer info / Troubleshooting"):
    st.markdown(f"**Model repository**: `{repository_id}`")
    st.markdown(f"**Torch available**: {'Yes' if torch is not None else 'No'}")
    st.markdown(f"**Model loaded**: {'Yes' if model is not None else 'No'}")
    st.write({
        "CUDA available": torch.cuda.is_available() if torch is not None else False,
        "API_KEY set for Groq": bool(os.getenv("API_KEY")),
        "Groq installed": Groq is not None
    })
    if model_load_error:
        st.markdown("**Model load error**:")
        st.code(model_load_error)

    st.markdown("---")
    st.markdown("### Groq quick test (for debugging API errors)")
    st.markdown("Click the button to run a very small 'ping' to the Groq chat endpoint. This helps capture raw error info without sending large prompts.")
    if st.button("Run Groq ping"):
        ping_result = groq_test_ping()
        if ping_result.get("ok"):
            st.success("Groq ping successful")
            st.text_area("Result (truncated)", str(ping_result.get("result"))[:2000], height=200)
        else:
            st.error("Groq ping failed; see details below")
            details = ping_result.get("result")
            # st.json treats a str argument as already-serialized JSON, so plain messages
            # such as "API_KEY not configured." would not render; show those as text.
            if isinstance(details, str):
                st.code(details)
            else:
                st.json(details)

    st.markdown("---")
    st.markdown("Debugging tips:")
    st.markdown(
        "- If Groq returns HTTP 400: check model name, prompt length, and messages shape.\n"
        "- Use the Groq ping to inspect raw error details.\n"
        "- Ensure `API_KEY` is set & has permissions for the requested model.\n"
        "- To avoid the Streamlit <-> PyTorch watcher issue you can also run Streamlit with: "
        "`streamlit run app.py --server.fileWatcherType none` or set `.streamlit/config.toml`."
    )