Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from datetime import datetime, timezone | |
| from io import BytesIO | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| from reportlab.lib.colors import Color, HexColor | |
| from reportlab.lib.pagesizes import A4 | |
| from reportlab.lib.units import mm | |
| from reportlab.lib.utils import ImageReader | |
| from reportlab.pdfgen import canvas | |
# Human-readable class names, in the order of the classifier's output indices.
LABELS = [
    "No Tumor",
    "Pituitary",
    "Glioma",
    "Meningioma",
]
# predict_with_probs only softmaxes the first NUM_DISPLAY_CLASSES logits.
NUM_DISPLAY_CLASSES = len(LABELS)
def is_likely_mri(
    pil_image: Image.Image,
    *,
    max_mean_chroma: float = 15.0,
    min_dark_fraction: float = 0.05,
    dark_luma: float = 25.0,
) -> tuple[bool, dict]:
    """Heuristically decide whether *pil_image* looks like a brain MRI slice.

    Brain MRIs are near-grayscale and have a large near-black background, so
    two statistics are computed on the 0-255 RGB pixel values:

    * ``mean_chroma``: mean over all pixels of (max channel - min channel).
    * ``dark_fraction``: fraction of pixels whose channel mean is below
      ``dark_luma``.

    Args:
        pil_image: Image to inspect; it is converted to RGB first.
        max_mean_chroma: The image is MRI-like only if ``mean_chroma`` is
            strictly below this value.
        min_dark_fraction: The image is MRI-like only if ``dark_fraction`` is
            strictly above this value.
        dark_luma: Channel-mean threshold below which a pixel counts as dark.

    Returns:
        ``(is_mri_like, diagnostics)``. ``diagnostics`` holds ``mean_chroma``
        and ``dark_fraction``. When the image cannot be analysed, it holds a
        single ``"reason"`` key instead and ``is_mri_like`` is False.
    """
    arr = np.asarray(pil_image.convert("RGB"), dtype=np.float32)
    # Defensive: convert("RGB") should always give (H, W, 3), but guard anyway.
    if arr.ndim != 3 or arr.shape[-1] != 3:
        return False, {"reason": "not RGB"}
    # Without this guard a 0-pixel image yields nan statistics and a RuntimeWarning.
    if arr.shape[0] == 0 or arr.shape[1] == 0:
        return False, {"reason": "empty image"}

    # Per-pixel saturation: how far the max channel is from the min channel.
    chroma = arr.max(axis=-1) - arr.min(axis=-1)
    mean_chroma = float(chroma.mean())
    # Background darkness: real MRIs have a lot of near-black pixels.
    luma = arr.mean(axis=-1)
    dark_fraction = float((luma < dark_luma).mean())

    diagnostics = {
        "mean_chroma": mean_chroma,
        "dark_fraction": dark_fraction,
    }
    # Default thresholds were tuned empirically: real grayscale MRIs sit at
    # chroma < 8, while color logos / photos easily exceed 25. The background
    # should cover >5% of the image.
    is_mri = mean_chroma < max_mean_chroma and dark_fraction > min_dark_fraction
    return is_mri, diagnostics
def predict(model, image, device):
    """Return the arg-max class index for a single-image batch.

    Switches *model* to eval mode, moves *image* to *device*, runs a forward
    pass without gradient tracking and picks the highest score along dim 1.
    ``.item()`` requires that result to hold exactly one element, i.e. a batch
    of one.
    """
    model.eval()
    batch = image.to(device)
    with torch.no_grad():
        scores = model(batch)
    winner = scores.max(dim=1).indices
    return winner.item()
def predict_with_probs(model, image_tensor, device):
    """Return ``(top_index, probabilities)`` for the first image in the batch.

    Only the first NUM_DISPLAY_CLASSES logits are kept. The softmax is taken
    over that slice, so the returned probabilities sum to 1 across the
    displayed classes. ``probabilities`` is a NumPy array on the CPU.
    """
    model.eval()
    with torch.no_grad():
        raw = model(image_tensor.to(device))
    displayed = raw[:, :NUM_DISPLAY_CLASSES]
    first_row = torch.softmax(displayed, dim=1)[0]
    probs = first_row.cpu().numpy()
    return int(probs.argmax()), probs
| def _pil_to_normalized_array(pil_image: Image.Image, size: int = 224) -> np.ndarray: | |
| resized = pil_image.convert("RGB").resize((size, size)) | |
| return np.asarray(resized, dtype=np.float32) / 255.0 | |
def compute_gradcam(model, image_tensor, target_layer, predicted_class: int, original_pil: Image.Image) -> Image.Image:
    """Return *original_pil* overlaid with a Grad-CAM heatmap for *predicted_class*.

    Args:
        model: Network passed to ``pytorch_grad_cam.GradCAM``.
        image_tensor: Model input batch. Only the first CAM in the result is used.
        target_layer: Layer whose activations/gradients GradCAM inspects.
        predicted_class: Output index the heatmap explains.
        original_pil: Image drawn underneath the heatmap.

    Returns:
        A PIL image of the RGB overlay produced by ``show_cam_on_image``.
    """
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    cam = GradCAM(model=model, target_layers=[target_layer])
    grayscale_cam = cam(
        input_tensor=image_tensor,
        targets=[ClassifierOutputTarget(predicted_class)],
    )[0]

    # show_cam_on_image needs the background and the CAM to share a shape.
    # Previously the background was always resized to the helper's default of
    # 224 px, which broke for any other model input size. Resize to the CAM's
    # own resolution instead. NOTE(review): assumes a square model input
    # (H == W) -- confirm against the preprocessing transform.
    cam_size = int(grayscale_cam.shape[0])
    rgb_image = _pil_to_normalized_array(original_pil, size=cam_size)
    overlay = show_cam_on_image(rgb_image, grayscale_cam, use_rgb=True)
    return Image.fromarray(overlay)
# Brand palette mirrors app.py CSS variables.

# Accents: header band, "PREDICTION" pill, top-class probability bar.
_ACCENT_1 = HexColor("#7c3aed")  # purple
_ACCENT_2 = HexColor("#06b6d4")  # cyan

# Text inks, darkest to lightest.
_INK = HexColor("#0f172a")
_INK_SOFT = HexColor("#475569")
_INK_FAINT = HexColor("#94a3b8")

# Surfaces and outlines.
_BG_SOFT = HexColor("#f1f5f9")
_BG_CARD = HexColor("#faf5ff")
_BORDER = HexColor("#e2e8f0")

# Status colours (confidence traffic light, disclaimer heading).
_GOOD = HexColor("#10b981")
_WARN = HexColor("#f59e0b")
_BAD = HexColor("#ef4444")
def _confidence_color(p: float) -> Color:
    """Return _GOOD for p >= 0.85, _WARN for p >= 0.6, otherwise _BAD (also for NaN)."""
    for threshold, colour in ((0.85, _GOOD), (0.6, _WARN)):
        if p >= threshold:
            return colour
    return _BAD
| def _confidence_label(p: float) -> str: | |
| if p >= 0.85: | |
| return "High confidence" | |
| if p >= 0.6: | |
| return "Medium confidence" | |
| return "Low confidence" | |
| def _draw_gradient_band(c: canvas.Canvas, x: float, y: float, w: float, h: float, | |
| c1: Color, c2: Color, steps: int = 60) -> None: | |
| """Horizontal gradient as thin vertical strips.""" | |
| strip = w / steps | |
| for i in range(steps): | |
| t = i / (steps - 1) | |
| r = c1.red + (c2.red - c1.red) * t | |
| g = c1.green + (c2.green - c1.green) * t | |
| b = c1.blue + (c2.blue - c1.blue) * t | |
| c.setFillColorRGB(r, g, b) | |
| c.rect(x + i * strip, y, strip + 0.5, h, stroke=0, fill=1) | |
def _draw_image_card(c: canvas.Canvas, img: Image.Image, x: float, y: float,
                     w: float, h: float, label: str) -> None:
    """Draw *img* in a thin rounded frame, with an uppercase caption above it.

    The image is fitted into the (x, y, w, h) box with its aspect ratio
    preserved and centred. The frame is padded 1.5 pt outside that box.
    """
    # Caption, 3 mm above the top edge of the image box.
    caption_y = y + h + 3 * mm
    c.setFont("Helvetica-Bold", 9)
    c.setFillColor(_INK_SOFT)
    c.drawString(x, caption_y, label.upper())

    # Frame.
    pad = 1.5
    c.setStrokeColor(_BORDER)
    c.setLineWidth(0.7)
    c.roundRect(x - pad, y - pad, w + 2 * pad, h + 2 * pad, 3 * mm, stroke=1, fill=0)

    # Picture.
    rgb = img.convert("RGB")
    c.drawImage(
        ImageReader(rgb), x, y, width=w, height=h,
        preserveAspectRatio=True, anchor="c", mask="auto",
    )
# One-line explanations printed under the predicted label, keyed by LABELS entries.
_CLASS_DESCRIPTIONS = {
    "No Tumor": "Healthy brain tissue. No tumor detected in the scan.",
    "Pituitary": "Tumor in the pituitary gland (base of the brain). Often benign.",
    "Glioma": "Tumor arising from glial cells. Can be aggressive; needs urgent review.",
    "Meningioma": "Tumor of the meninges (brain's protective layers). Usually slow-growing.",
}


def _draw_report_header(c: canvas.Canvas, width: float, height: float,
                        margin: float, model_version: str) -> float:
    """Draw the gradient title band at the top of the page and return its bottom y."""
    band_h = 26 * mm
    band_y = height - band_h
    band_top = band_y + band_h
    _draw_gradient_band(c, 0, band_y, width, band_h, _ACCENT_1, _ACCENT_2)

    c.setFillColorRGB(1, 1, 1)
    c.setFont("Helvetica-Bold", 18)
    c.drawString(margin, band_top - 12 * mm, "Brain MRI Tumor Classification")
    c.setFont("Helvetica", 10)
    c.drawString(margin, band_top - 18 * mm, "Grad-CAM-explained PyTorch CNN report")

    # Right-aligned meta
    timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
    c.setFont("Helvetica", 9)
    meta_x = width - margin
    c.drawRightString(meta_x, band_top - 12 * mm, f"Generated {timestamp}")
    c.drawRightString(meta_x, band_top - 17 * mm, f"Model {model_version}")
    return band_y


def _draw_prediction_card(c: canvas.Canvas, x: float, y: float, w: float, h: float,
                          predicted_label: str, top_prob: float) -> None:
    """Draw the card with the predicted label, its confidence and a short description."""
    # Card background and border
    c.setFillColor(_BG_CARD)
    c.setStrokeColor(_ACCENT_1)
    c.setLineWidth(0.7)
    c.roundRect(x, y, w, h, 4 * mm, stroke=1, fill=1)
    # Left accent stripe
    c.setFillColor(_ACCENT_1)
    c.rect(x, y, 2 * mm, h, stroke=0, fill=1)

    inner_x = x + 8 * mm
    inner_top = y + h - 8 * mm
    right_x = x + w - 8 * mm

    # "PREDICTION" pill
    pill_w, pill_h = 26 * mm, 5 * mm
    _draw_gradient_band(c, inner_x, inner_top - pill_h + 1 * mm, pill_w, pill_h,
                        _ACCENT_1, _ACCENT_2, steps=30)
    c.setFillColorRGB(1, 1, 1)
    c.setFont("Helvetica-Bold", 8)
    c.drawCentredString(inner_x + pill_w / 2, inner_top - pill_h + 2.6 * mm, "PREDICTION")

    # Label
    c.setFillColor(_INK)
    c.setFont("Helvetica-Bold", 22)
    c.drawString(inner_x, inner_top - 16 * mm, predicted_label)

    # Confidence (right side)
    c.setFillColor(_INK_SOFT)
    c.setFont("Helvetica", 9)
    c.drawRightString(right_x, inner_top - 4 * mm, _confidence_label(top_prob).upper())
    c.setFillColor(_confidence_color(top_prob))
    c.setFont("Helvetica-Bold", 26)
    c.drawRightString(right_x, inner_top - 14 * mm, f"{top_prob * 100:.1f}%")

    # Class description (empty for labels without an entry)
    c.setFillColor(_INK_SOFT)
    c.setFont("Helvetica-Oblique", 9)
    c.drawString(inner_x, y + 6 * mm, _CLASS_DESCRIPTIONS.get(predicted_label, ""))


def _draw_probability_bars(c: canvas.Canvas, x: float, title_y: float, w: float,
                           probs: np.ndarray, top_idx: int) -> None:
    """Draw one labelled horizontal bar per LABELS entry; the top class gets a gradient."""
    c.setFillColor(_INK)
    c.setFont("Helvetica-Bold", 11)
    c.drawString(x, title_y, "Per-class probability")

    bar_top = title_y - 6 * mm
    row_h = 9 * mm
    name_w = 30 * mm
    pct_w = 18 * mm
    bar_x = x + name_w
    bar_w = w - name_w - pct_w
    bar_h = 5 * mm
    radius = bar_h / 2

    for i, label in enumerate(LABELS):
        row_y = bar_top - (i + 1) * row_h + (row_h - bar_h) / 2
        text_y = row_y + (bar_h - 8) / 2 + 0.5
        prob = float(probs[i])
        font = "Helvetica-Bold" if i == top_idx else "Helvetica"

        # name
        c.setFillColor(_INK)
        c.setFont(font, 10)
        c.drawString(x, text_y, label)

        # background track
        c.setFillColor(_BG_SOFT)
        c.roundRect(bar_x, row_y, bar_w, bar_h, radius, stroke=0, fill=1)

        # Filled portion. Clamp to [0, 1] so float noise cannot overrun the
        # track, and keep a 0.5 pt sliver visible for ~0 probabilities.
        fill_w = max(min(max(prob, 0.0), 1.0) * bar_w, 0.5)
        # A corner radius above half the width yields a self-intersecting path.
        fill_r = min(radius, fill_w / 2)
        if i == top_idx:
            # Clip the gradient strips to the rounded shape so the top bar has
            # rounded ends like the others instead of square corners.
            c.saveState()
            clip = c.beginPath()
            clip.roundRect(bar_x, row_y, fill_w, bar_h, fill_r)
            c.clipPath(clip, stroke=0, fill=0)
            _draw_gradient_band(c, bar_x, row_y, fill_w, bar_h, _ACCENT_1, _ACCENT_2,
                                steps=max(int(fill_w / 2), 4))
            c.restoreState()
        else:
            c.setFillColor(_INK_FAINT)
            c.roundRect(bar_x, row_y, fill_w, bar_h, fill_r, stroke=0, fill=1)

        # percentage
        c.setFillColor(_INK)
        c.setFont(font, 10)
        c.drawRightString(x + w, text_y, f"{prob * 100:.2f}%")


def _draw_disclaimer_footer(c: canvas.Canvas, width: float, margin: float) -> None:
    """Draw the shaded footer strip carrying the not-a-medical-device disclaimer."""
    footer_h = 18 * mm
    c.setFillColor(_BG_SOFT)
    c.rect(0, 0, width, footer_h, stroke=0, fill=1)
    c.setStrokeColor(_BORDER)
    c.setLineWidth(0.5)
    c.line(0, footer_h, width, footer_h)

    c.setFillColor(_WARN)
    c.setFont("Helvetica-Bold", 9)
    c.drawString(margin, footer_h - 7 * mm, "DISCLAIMER")
    c.setFillColor(_INK_SOFT)
    c.setFont("Helvetica", 9)
    c.drawString(
        margin + 22 * mm,
        footer_h - 7 * mm,
        "For educational and research use only. Not a medical device.",
    )
    c.setFillColor(_INK_FAINT)
    c.setFont("Helvetica", 8)
    c.drawString(
        margin,
        footer_h - 12 * mm,
        "Do not use these predictions for diagnosis or treatment decisions. "
        "Consult a qualified clinician for any medical concerns.",
    )


def make_pdf_report(
    *,
    original_image: Image.Image,
    heatmap_image: Image.Image,
    predicted_label: str,
    probabilities: np.ndarray,
    model_version: str = "model_38",
) -> bytes:
    """Render a one-page A4 PDF summarising a single prediction.

    Args:
        original_image: Uploaded scan, drawn in the left image card.
        heatmap_image: Grad-CAM overlay, drawn in the right image card.
        predicted_label: Text shown as the prediction; also selects the
            description line. NOTE(review): presumably LABELS[argmax]; the
            confidence and bar highlight are derived from *probabilities*, so
            the two should agree.
        probabilities: Per-class scores aligned with LABELS. Only the first
            ``len(LABELS)`` entries are displayed and used for the top class.
        model_version: Shown in the header metadata.

    Returns:
        The PDF document as bytes.

    Raises:
        ValueError: If *probabilities* has fewer entries than LABELS. This is
            checked before drawing anything, whereas an IndexError used to be
            raised halfway through the layout.
    """
    probs = np.asarray(probabilities, dtype=np.float64).ravel()
    if probs.size < len(LABELS):
        raise ValueError(
            f"probabilities must have at least {len(LABELS)} entries "
            f"(one per label); got {probs.size}"
        )
    # Pick the top class among the displayed classes only, so the highlighted
    # bar and the headline confidence always refer to a class that is shown.
    displayed = probs[: len(LABELS)]
    top_idx = int(np.argmax(displayed))
    top_prob = float(displayed[top_idx])

    buf = BytesIO()
    c = canvas.Canvas(buf, pagesize=A4)
    width, height = A4
    margin = 18 * mm
    content_w = width - 2 * margin

    band_y = _draw_report_header(c, width, height, margin, model_version)

    # ---- Image cards (two side by side, 8 mm gutter) ----
    gap = 8 * mm
    img_h = 70 * mm
    img_y = band_y - gap - img_h
    img_w = (content_w - gap) / 2
    _draw_image_card(c, original_image, margin, img_y, img_w, img_h, "Uploaded MRI")
    _draw_image_card(c, heatmap_image, margin + img_w + gap, img_y, img_w, img_h,
                     "Grad-CAM heatmap")

    # ---- Prediction card ----
    card_h = 38 * mm
    card_y = img_y - 12 * mm - card_h
    _draw_prediction_card(c, margin, card_y, content_w, card_h, predicted_label, top_prob)

    # ---- Per-class probability bars ----
    _draw_probability_bars(c, margin, card_y - 14 * mm, content_w, displayed, top_idx)

    # ---- Footer disclaimer ----
    _draw_disclaimer_footer(c, width, margin)

    c.showPage()
    c.save()
    return buf.getvalue()