# Source: commit 22d888b ("Redesign PDF report") by Halemo (verified).
from __future__ import annotations
from datetime import datetime, timezone
from io import BytesIO
import numpy as np
import torch
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from reportlab.lib.colors import Color, HexColor
from reportlab.lib.pagesizes import A4
from reportlab.lib.units import mm
from reportlab.lib.utils import ImageReader
from reportlab.pdfgen import canvas
# Human-readable class names; LABELS[i] names classifier output index i
# (make_pdf_report pairs LABELS[i] with probabilities[i]).
LABELS = [
    "No Tumor",
    "Pituitary",
    "Glioma",
    "Meningioma",
]
# How many leading logits are shown to users (predict_with_probs slices to this).
NUM_DISPLAY_CLASSES = len(LABELS)
def is_likely_mri(
    pil_image: Image.Image,
    *,
    max_mean_chroma: float = 15.0,
    dark_luma: float = 25.0,
    min_dark_fraction: float = 0.05,
) -> tuple[bool, dict]:
    """Heuristic check: brain MRIs are near-grayscale with dark backgrounds.

    This is a cheap screen for obviously wrong uploads (colour logos, photos),
    not a classifier.

    Args:
        pil_image: Image to inspect; converted to RGB before analysis.
        max_mean_chroma: Exclusive upper bound on the mean per-pixel chroma,
            i.e. ``max(R, G, B) - min(R, G, B)`` on a 0-255 scale.
        dark_luma: Pixels whose mean channel value is below this count as
            dark background.
        min_dark_fraction: Exclusive lower bound on the fraction of dark
            pixels.

    Returns:
        ``(is_mri_like, diagnostics)``. When statistics are computed,
        ``diagnostics`` holds ``mean_chroma`` and ``dark_fraction``; when the
        image is rejected early it holds a ``reason`` string instead.
    """
    arr = np.asarray(pil_image.convert("RGB"), dtype=np.float32)
    if arr.ndim != 3 or arr.shape[-1] != 3:
        return False, {"reason": "not RGB"}
    if arr.size == 0:
        # A zero-area image has no pixels to average (NumPy would return NaN
        # plus a RuntimeWarning); reject it explicitly.
        return False, {"reason": "empty image"}

    # Per-pixel saturation: how far the max channel is from the min channel.
    chroma = arr.max(axis=-1) - arr.min(axis=-1)
    mean_chroma = float(chroma.mean())

    # Background darkness: real MRIs have a lot of near-black pixels.
    luma = arr.mean(axis=-1)
    dark_fraction = float((luma < dark_luma).mean())

    diagnostics = {
        "mean_chroma": mean_chroma,
        "dark_fraction": dark_fraction,
    }
    # Defaults tuned empirically: real grayscale MRIs sit at chroma < 8,
    # colour logos / photos easily exceed 25. Background should cover >5%.
    is_mri = mean_chroma < max_mean_chroma and dark_fraction > min_dark_fraction
    return is_mri, diagnostics
def predict(model, image, device):
    """Run one forward pass and return the winning class index as an int.

    Switches ``model`` to eval mode, moves ``image`` onto ``device`` and
    evaluates it without autograd tracking.

    NOTE(review): unlike ``predict_with_probs`` this takes the arg-max over
    *all* output logits rather than the first ``NUM_DISPLAY_CLASSES``; confirm
    the model emits no extra outputs if the result is used to index LABELS.
    ``.item()`` requires the output batch to hold exactly one row.
    """
    model.eval()
    batch = image.to(device)
    with torch.no_grad():
        logits = model(batch)
        best = torch.max(logits, dim=1).indices
    return best.item()
def predict_with_probs(model, image_tensor, device):
    """Classify an image and return ``(top_index, probabilities)``.

    Only the first ``NUM_DISPLAY_CLASSES`` logits are kept before the softmax,
    so ``probabilities`` is a NumPy vector over the displayed classes only.
    ``top_index`` is its arg-max. Only the first item of the batch is used.
    """
    model.eval()
    batch = image_tensor.to(device)
    with torch.no_grad():
        shown_logits = model(batch)[:, :NUM_DISPLAY_CLASSES]
        first_row = torch.softmax(shown_logits, dim=1)[0]
        probs = first_row.cpu().numpy()
    return int(probs.argmax()), probs
def _pil_to_normalized_array(pil_image: Image.Image, size: int = 224) -> np.ndarray:
resized = pil_image.convert("RGB").resize((size, size))
return np.asarray(resized, dtype=np.float32) / 255.0
def compute_gradcam(model, image_tensor, target_layer, predicted_class: int, original_pil: Image.Image) -> Image.Image:
    """Return ``original_pil`` overlaid with a Grad-CAM heatmap.

    Args:
        model: Classifier to explain.
        image_tensor: Model input batch; only the first CAM of the batch is used.
        target_layer: Layer whose activations/gradients Grad-CAM inspects.
        predicted_class: Output index the heatmap explains.
        original_pil: Source image used as the overlay background.

    Returns:
        An RGB PIL image at the CAM's resolution (presumably the input
        tensor's height x width, as pytorch-grad-cam rescales the CAM to the
        input size — TODO confirm for the installed version).
    """
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    cam = GradCAM(model=model, target_layers=[target_layer])
    try:
        grayscale_cam = cam(
            input_tensor=image_tensor,
            targets=[ClassifierOutputTarget(predicted_class)],
        )[0]
    finally:
        # pytorch-grad-cam attaches hooks to target_layer; remove them now,
        # even on error, instead of waiting for the CAM object to be collected.
        cam.activations_and_grads.release()

    # Size the background to match the CAM rather than assuming 224x224, so
    # the overlay works for any model input resolution. PIL takes (w, h).
    cam_h, cam_w = grayscale_cam.shape
    rgb_image = (
        np.asarray(original_pil.convert("RGB").resize((cam_w, cam_h)), dtype=np.float32)
        / 255.0
    )
    overlay = show_cam_on_image(rgb_image, grayscale_cam, use_rgb=True)
    return Image.fromarray(overlay)
# Brand palette mirrors app.py CSS variables.

# Gradient endpoints: header band, "PREDICTION" pill and the top-class bar.
_ACCENT_1 = HexColor("#7c3aed")  # purple
_ACCENT_2 = HexColor("#06b6d4")  # cyan

# Text colours, darkest to lightest.
_INK = HexColor("#0f172a")
_INK_SOFT = HexColor("#475569")
_INK_FAINT = HexColor("#94a3b8")  # also the fill of non-top probability bars

# Surfaces and outlines.
_BG_SOFT = HexColor("#f1f5f9")  # probability-bar track, footer background
_BG_CARD = HexColor("#faf5ff")  # prediction card background
_BORDER = HexColor("#e2e8f0")  # image-card outline, footer divider

# Status colours picked by _confidence_color; _WARN also tints "DISCLAIMER".
_GOOD = HexColor("#10b981")
_WARN = HexColor("#f59e0b")
_BAD = HexColor("#ef4444")
def _confidence_color(p: float) -> Color:
if p >= 0.85:
return _GOOD
if p >= 0.6:
return _WARN
return _BAD
def _confidence_label(p: float) -> str:
if p >= 0.85:
return "High confidence"
if p >= 0.6:
return "Medium confidence"
return "Low confidence"
def _draw_gradient_band(c: canvas.Canvas, x: float, y: float, w: float, h: float,
c1: Color, c2: Color, steps: int = 60) -> None:
"""Horizontal gradient as thin vertical strips."""
strip = w / steps
for i in range(steps):
t = i / (steps - 1)
r = c1.red + (c2.red - c1.red) * t
g = c1.green + (c2.green - c1.green) * t
b = c1.blue + (c2.blue - c1.blue) * t
c.setFillColorRGB(r, g, b)
c.rect(x + i * strip, y, strip + 0.5, h, stroke=0, fill=1)
def _draw_image_card(c: canvas.Canvas, img: Image.Image, x: float, y: float,
                     w: float, h: float, label: str) -> None:
    """Draw ``img`` in an outlined, rounded card with an uppercase caption.

    The image is fitted into the (x, y, w, h) box, keeping its aspect ratio
    and centred. The outline sits 1.5 units outside the box, and the caption
    baseline is 3 mm above the box's top edge.
    """
    pad = 1.5
    caption_y = y + h + 3 * mm

    # Caption
    c.setFont("Helvetica-Bold", 9)
    c.setFillColor(_INK_SOFT)
    c.drawString(x, caption_y, label.upper())

    # Outline around the image box
    c.setStrokeColor(_BORDER)
    c.setLineWidth(0.7)
    c.roundRect(x - pad, y - pad, w + 2 * pad, h + 2 * pad, 3 * mm, stroke=1, fill=0)

    # Picture
    c.drawImage(
        ImageReader(img.convert("RGB")),
        x,
        y,
        width=w,
        height=h,
        preserveAspectRatio=True,
        anchor="c",
        mask="auto",
    )
# One-line explanations printed under the predicted label, keyed by the
# display names in LABELS; unknown labels get an empty description.
_CLASS_DESCRIPTIONS = {
    "No Tumor": "Healthy brain tissue. No tumor detected in the scan.",
    "Pituitary": "Tumor in the pituitary gland (base of the brain). Often benign.",
    "Glioma": "Tumor arising from glial cells. Can be aggressive; needs urgent review.",
    "Meningioma": "Tumor of the meninges (brain's protective layers). Usually slow-growing.",
}


def _draw_report_header(c: canvas.Canvas, width: float, height: float,
                        margin: float, model_version: str) -> float:
    """Draw the gradient title band at the top of the page.

    Returns the y coordinate of the band's bottom edge.
    """
    band_h = 26 * mm
    band_y = height - band_h
    _draw_gradient_band(c, 0, band_y, width, band_h, _ACCENT_1, _ACCENT_2)

    c.setFillColorRGB(1, 1, 1)
    c.setFont("Helvetica-Bold", 18)
    c.drawString(margin, band_y + band_h - 12 * mm, "Brain MRI Tumor Classification")
    c.setFont("Helvetica", 10)
    c.drawString(margin, band_y + band_h - 18 * mm, "Grad-CAM-explained PyTorch CNN report")

    # Right-aligned metadata (still drawn in the white set above).
    timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
    c.setFont("Helvetica", 9)
    meta_x = width - margin
    c.drawRightString(meta_x, band_y + band_h - 12 * mm, f"Generated {timestamp}")
    c.drawRightString(meta_x, band_y + band_h - 17 * mm, f"Model {model_version}")
    return band_y


def _draw_report_images(c: canvas.Canvas, original_image: Image.Image,
                        heatmap_image: Image.Image, width: float, margin: float,
                        band_y: float) -> float:
    """Draw the uploaded MRI and its Grad-CAM overlay side by side.

    Returns the y coordinate of the image row's bottom edge.
    """
    img_y = band_y - 8 * mm - 70 * mm
    img_w = (width - 2 * margin - 8 * mm) / 2
    img_h = 70 * mm
    _draw_image_card(c, original_image, margin, img_y, img_w, img_h, "Uploaded MRI")
    _draw_image_card(c, heatmap_image, margin + img_w + 8 * mm, img_y, img_w, img_h,
                     "Grad-CAM heatmap")
    return img_y


def _draw_prediction_card(c: canvas.Canvas, margin: float, card_w: float,
                          img_y: float, predicted_label: str,
                          top_prob: float) -> float:
    """Draw the card showing the predicted label, its probability, the
    confidence tier and a short class description.

    Returns the y coordinate of the card's bottom edge.
    """
    card_y = img_y - 12 * mm - 38 * mm
    card_h = 38 * mm

    # Card background with outline, then the left accent stripe.
    c.setFillColor(_BG_CARD)
    c.setStrokeColor(_ACCENT_1)
    c.setLineWidth(0.7)
    c.roundRect(margin, card_y, card_w, card_h, 4 * mm, stroke=1, fill=1)
    c.setFillColor(_ACCENT_1)
    c.rect(margin, card_y, 2 * mm, card_h, stroke=0, fill=1)

    inner_x = margin + 8 * mm
    inner_top = card_y + card_h - 8 * mm
    right_x = margin + card_w - 8 * mm

    # "PREDICTION" pill
    pill_w, pill_h = 26 * mm, 5 * mm
    _draw_gradient_band(c, inner_x, inner_top - pill_h + 1 * mm, pill_w, pill_h,
                        _ACCENT_1, _ACCENT_2, steps=30)
    c.setFillColorRGB(1, 1, 1)
    c.setFont("Helvetica-Bold", 8)
    c.drawCentredString(inner_x + pill_w / 2, inner_top - pill_h + 2.6 * mm, "PREDICTION")

    # Predicted label
    c.setFillColor(_INK)
    c.setFont("Helvetica-Bold", 22)
    c.drawString(inner_x, inner_top - 16 * mm, predicted_label)

    # Confidence tier and percentage, right-aligned
    c.setFillColor(_INK_SOFT)
    c.setFont("Helvetica", 9)
    c.drawRightString(right_x, inner_top - 4 * mm, _confidence_label(top_prob).upper())
    c.setFillColor(_confidence_color(top_prob))
    c.setFont("Helvetica-Bold", 26)
    c.drawRightString(right_x, inner_top - 14 * mm, f"{top_prob * 100:.1f}%")

    # Class description along the bottom of the card
    c.setFillColor(_INK_SOFT)
    c.setFont("Helvetica-Oblique", 9)
    c.drawString(inner_x, card_y + 6 * mm, _CLASS_DESCRIPTIONS.get(predicted_label, ""))
    return card_y


def _draw_probability_bars(c: canvas.Canvas, margin: float, card_w: float,
                           card_y: float, probabilities: np.ndarray,
                           top_idx: int) -> None:
    """Draw one labelled horizontal bar per entry of LABELS.

    ``probabilities[i]`` is drawn for ``LABELS[i]``. The ``top_idx`` row uses
    bold text and a gradient fill; the other rows use a flat grey fill.
    """
    bars_y = card_y - 14 * mm
    c.setFillColor(_INK)
    c.setFont("Helvetica-Bold", 11)
    c.drawString(margin, bars_y, "Per-class probability")

    bar_top = bars_y - 6 * mm
    row_h = 9 * mm
    name_w = 30 * mm
    pct_w = 18 * mm
    bar_x = margin + name_w
    bar_w = card_w - name_w - pct_w
    bar_h = 5 * mm

    for i, label in enumerate(LABELS):
        is_top = i == top_idx
        font = "Helvetica-Bold" if is_top else "Helvetica"
        row_y = bar_top - (i + 1) * row_h + (row_h - bar_h) / 2
        # Baseline that roughly centres the 10 pt text on the bar
        # (presumably 8 approximates the glyph height).
        text_y = row_y + (bar_h - 8) / 2 + 0.5
        prob = float(probabilities[i])

        # Class name
        c.setFillColor(_INK)
        c.setFont(font, 10)
        c.drawString(margin, text_y, label)

        # Background track
        c.setFillColor(_BG_SOFT)
        c.roundRect(bar_x, row_y, bar_w, bar_h, bar_h / 2, stroke=0, fill=1)

        # Filled portion; keep a sliver visible even for ~0% probability.
        fill_w = max(prob * bar_w, 0.5)
        if is_top:
            _draw_gradient_band(c, bar_x, row_y, fill_w, bar_h, _ACCENT_1, _ACCENT_2,
                                steps=max(int(fill_w / 2), 4))
        else:
            c.setFillColor(_INK_FAINT)
            c.roundRect(bar_x, row_y, fill_w, bar_h, bar_h / 2, stroke=0, fill=1)

        # Percentage, right-aligned to the card edge
        c.setFillColor(_INK)
        c.setFont(font, 10)
        c.drawRightString(margin + card_w, text_y, f"{prob * 100:.2f}%")


def _draw_disclaimer_footer(c: canvas.Canvas, width: float, margin: float) -> None:
    """Draw the shaded disclaimer strip along the bottom of the page."""
    footer_h = 18 * mm
    c.setFillColor(_BG_SOFT)
    c.rect(0, 0, width, footer_h, stroke=0, fill=1)
    c.setStrokeColor(_BORDER)
    c.setLineWidth(0.5)
    c.line(0, footer_h, width, footer_h)

    c.setFillColor(_WARN)
    c.setFont("Helvetica-Bold", 9)
    c.drawString(margin, footer_h - 7 * mm, "DISCLAIMER")
    c.setFillColor(_INK_SOFT)
    c.setFont("Helvetica", 9)
    c.drawString(
        margin + 22 * mm,
        footer_h - 7 * mm,
        "For educational and research use only. Not a medical device.",
    )
    c.setFillColor(_INK_FAINT)
    c.setFont("Helvetica", 8)
    c.drawString(
        margin,
        footer_h - 12 * mm,
        "Do not use these predictions for diagnosis or treatment decisions. "
        "Consult a qualified clinician for any medical concerns.",
    )


def make_pdf_report(
    *,
    original_image: Image.Image,
    heatmap_image: Image.Image,
    predicted_label: str,
    probabilities: np.ndarray,
    model_version: str = "model_38",
) -> bytes:
    """Render a one-page A4 PDF summarising a single prediction.

    The page contains, from top to bottom:
    - a gradient title band with the generation timestamp and model version;
    - the uploaded MRI beside its Grad-CAM overlay;
    - a prediction card;
    - per-class probability bars;
    - a disclaimer footer.

    Args:
        original_image: The uploaded scan.
        heatmap_image: The Grad-CAM overlay for the same scan.
        predicted_label: Text shown as the prediction; also selects the
            class description. NOTE(review): callers presumably pass
            ``LABELS[argmax(probabilities)]``; the two are not cross-checked.
        probabilities: Per-class probabilities. Entry ``i`` is drawn for
            ``LABELS[i]``, so at least ``len(LABELS)`` entries are required.
            The arg-max drives the headline percentage and confidence tier.
        model_version: Shown in the header.

    Returns:
        The finished PDF document as bytes.
    """
    buf = BytesIO()
    c = canvas.Canvas(buf, pagesize=A4)
    width, height = A4
    margin = 18 * mm
    card_w = width - 2 * margin

    top_idx = int(np.argmax(probabilities))
    top_prob = float(probabilities[top_idx])

    # Sections are drawn top to bottom; each returns the y it ends at.
    band_y = _draw_report_header(c, width, height, margin, model_version)
    img_y = _draw_report_images(c, original_image, heatmap_image, width, margin, band_y)
    card_y = _draw_prediction_card(c, margin, card_w, img_y, predicted_label, top_prob)
    _draw_probability_bars(c, margin, card_w, card_y, probabilities, top_idx)
    _draw_disclaimer_footer(c, width, margin)

    c.showPage()
    c.save()
    return buf.getvalue()