SSL-Former / app.py
saifulislamoni's picture
Deploy: Brain Tumor Classifier with GradCAM and PDF reports - MoCo SSL + Swin-Transformer
c703689
"""
Brain Tumor Classification Web App
Optimized for Hugging Face Spaces deployment
- Detects Glioma, Meningioma, No Tumor, Pituitary
- Generates GradCAM visualizations
- Creates PDF reports
"""
import os
import io
import uuid
import cv2
import torch
import base64
import numpy as np
from PIL import Image
from datetime import datetime
from flask import Flask, render_template, request, jsonify, send_file
import torch.nn as nn
import torchvision.transforms as transforms
from reportlab.lib.pagesizes import A4
from reportlab.lib import colors
from reportlab.platypus import (
SimpleDocTemplate, Table, TableStyle, Paragraph, Spacer,
Image as RLImage, HRFlowable
)
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import inch, mm
from reportlab.lib.enums import TA_CENTER, TA_LEFT, TA_RIGHT
import gc
# ============================================================================
# Model Architecture (same as training)
# ============================================================================
class FFEBlock(nn.Module):
    """Swin-T feature extractor that also exposes its un-pooled feature grid.

    ``forward`` produces one pooled vector per image, while
    ``forward_features`` stops before pooling so GradCAM can work on the
    channels-first spatial map.
    """

    def __init__(self):
        super().__init__()
        # Imported lazily, exactly as the original did.
        from torchvision import models

        swin = models.swin_t(weights=None)
        self.features = swin.features
        self.norm = swin.norm
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.out_dim = 768

    def forward_features(self, x):
        """Return the normalised feature grid as [B, C, H, W] (no pooling)."""
        normed = self.norm(self.features(x))
        # Move the channel axis from last to second position.
        return normed.permute(0, 3, 1, 2)

    def forward(self, x):
        """Return the average-pooled embedding flattened to [B, C]."""
        pooled = self.pool(self.forward_features(x))
        return pooled.flatten(1)
class ProjectionMLP(nn.Module):
    """NRPL projection head: Linear -> BatchNorm -> ReLU -> Linear.

    With the defaults the widths are 768 -> 384 -> 192.
    """

    def __init__(self, in_dim=768, hidden_dim=384, out_dim=192):
        super().__init__()
        stages = [
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        ]
        # The attribute name ``net`` and the layer order determine the
        # state-dict keys (net.0, net.1, net.3), so both are kept as-is.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Project a batch of embeddings [B, in_dim] to [B, out_dim]."""
        projected = self.net(x)
        return projected
class LinearClassifier(nn.Module):
    """Deployment classifier: encoder -> projector -> linear head."""

    def __init__(self, num_classes=4):
        super().__init__()
        self.encoder = FFEBlock()
        self.projector = ProjectionMLP(in_dim=768, hidden_dim=384, out_dim=192)
        self.head = nn.Linear(192, num_classes)

    def forward(self, x):
        """Return raw class logits for a batch of images."""
        embedding = self.encoder(x)
        projected = self.projector(embedding)
        logits = self.head(projected)
        return logits

    def forward_features(self, x):
        """Return the encoder's spatial features before pooling (for GradCAM)."""
        spatial = self.encoder.forward_features(x)
        return spatial

    def predict_proba(self, x):
        """Return softmax class probabilities computed over dim 1."""
        logits = self.forward(x)
        return torch.softmax(logits, dim=1)
# ============================================================================
# GradCAM Implementation (Memory Efficient) — FIXED
# ============================================================================
class GradCAM:
    """Gradient-weighted Class Activation Map for a single target layer.

    Forward and backward hooks on ``target_layer`` capture its output and the
    gradient flowing into that output; ``generate`` combines the two into a
    heatmap.
    """

    # Side length, in pixels, of the square heatmap returned by ``generate``.
    OUTPUT_SIZE = 224

    def __init__(self, model, target_layer):
        """
        Args:
            model: The full LinearClassifier model
            target_layer: The layer to hook — should be encoder.norm for Swin-T
        """
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Register the forward and backward hooks on the target layer."""

        def forward_hook(module, inputs, output):
            # encoder.norm output shape: (B, H, W, C)
            self.activations = output.detach().clone()

        def backward_hook(module, grad_input, grad_output):
            # grad_output[0] shape: (B, H, W, C)
            self.gradients = grad_output[0].detach().clone()

        forward_handle = self.target_layer.register_forward_hook(forward_hook)
        # register_full_backward_hook is the non-deprecated backward hook API.
        backward_handle = self.target_layer.register_full_backward_hook(backward_hook)
        self.handles = [forward_handle, backward_handle]

    def generate(self, input_tensor, class_idx):
        """
        Generate GradCAM for a specific class.

        Args:
            input_tensor: Input image [1, 3, 224, 224]. It is switched to
                ``requires_grad=True`` in place.
            class_idx: Class index to visualize

        Returns:
            cam: float32 heatmap of shape (OUTPUT_SIZE, OUTPUT_SIZE) scaled to
                0-1. A constant 0.5 map is returned when the hooks did not
                fire, and an all-zero map when the CAM has no dynamic range.
        """
        size = self.OUTPUT_SIZE
        self.model.eval()
        self.gradients = None
        self.activations = None

        # Forward pass — gradients are needed, so NOT under torch.no_grad().
        input_tensor.requires_grad_(True)
        output = self.model(input_tensor)

        # Backward pass for the target class only.
        self.model.zero_grad()
        output[0, class_idx].backward()
        # Only the hooked gradient is needed; release the parameter gradients
        # that backward() produced so they do not stay allocated.
        self.model.zero_grad(set_to_none=True)

        if self.gradients is None or self.activations is None:
            print("[WARN] GradCAM hooks did not fire")
            return np.ones((size, size), dtype=np.float32) * 0.5

        # Hooked tensors are channels-last (B, H, W, C); use the first sample.
        gradients = self.gradients[0]      # (H, W, C)
        activations = self.activations[0]  # (H, W, C)

        # Channel weights = gradients averaged over the spatial dims.
        weights = gradients.mean(dim=(0, 1))  # (C,)

        # Weighted sum over channels in one broadcasted op. The result lives
        # on the same device as the activations, unlike a torch.zeros
        # accumulator created on the default device.
        cam = (activations * weights).sum(dim=-1)  # (H, W)

        # ReLU — keep only positive contributions.
        cam = torch.clamp(cam, min=0)
        cam_np = cam.cpu().numpy()

        # Normalize to 0-1; a flat map becomes all zeros.
        cam_min, cam_max = cam_np.min(), cam_np.max()
        if cam_max - cam_min > 1e-6:
            cam_np = (cam_np - cam_min) / (cam_max - cam_min)
        else:
            cam_np = np.zeros_like(cam_np)

        # Upsample the coarse grid to the output resolution.
        return cv2.resize(cam_np, (size, size), interpolation=cv2.INTER_LINEAR)

    def remove_hooks(self):
        """Detach both hooks; safe to call more than once."""
        # getattr guards the case where __init__ failed before handles was set
        # and __del__ still runs.
        for handle in getattr(self, "handles", []):
            handle.remove()
        self.handles = []

    def __del__(self):
        self.remove_hooks()
# ============================================================================
# MRI Validation — Simplified + Confidence Gating
# ============================================================================
def validate_mri_basic(image_bytes):
    """
    Run cheap sanity checks on an uploaded image.

    Checks, in order: the decoded format (when one is reported) is in the
    accepted list, both sides are at least 50 px, and the long/short side
    ratio does not exceed 3.

    Returns: (is_valid, message)
    """
    accepted_formats = ('JPEG', 'PNG', 'BMP', 'TIFF', 'GIF', 'MPO')
    try:
        picture = Image.open(io.BytesIO(image_bytes))

        fmt = picture.format
        if fmt and fmt not in accepted_formats:
            return False, "Invalid image format. Please upload JPEG, PNG, BMP, GIF, or TIFF."

        width, height = picture.size
        if min(width, height) < 50:
            return False, "Image too small. Please upload a larger image."

        # Brain MRI slices are roughly square; reject extreme strips.
        long_side = max(width, height)
        short_side = max(min(width, height), 1)
        if long_side / short_side > 3.0:
            return False, "Image aspect ratio is unusual for a Brain MRI scan."
    except Exception as exc:
        return False, f"Could not read image: {str(exc)}"
    return True, "OK"
def check_confidence_gate(probs, threshold=0.55, entropy_threshold=0.85):
    """
    Reject predictions the model is not confident about (likely
    out-of-distribution inputs), using max probability plus an entropy check.

    Args:
        probs: 1-D tensor of softmax probabilities, one per class
        threshold: minimum top-class probability to accept
        entropy_threshold: maximum normalized entropy to accept
            (0 = certain, 1 = uniform)

    Returns:
        (is_confident, max_prob, pred_class_idx)
    """
    max_prob = probs.max().item()
    pred_idx = probs.argmax().item()

    num_classes = len(probs)
    if num_classes > 1:
        # Shannon entropy divided by ln(num_classes), the entropy of a
        # uniform distribution; the epsilon keeps log(0) out of the sum.
        entropy = -(probs * torch.log(probs + 1e-9)).sum().item()
        normalized_entropy = entropy / float(np.log(num_classes))
    else:
        # ln(1) == 0 would divide by zero; one class carries no uncertainty.
        normalized_entropy = 0.0

    # Written as "accept if" rather than "reject if" so NaN probabilities
    # fail the gate: every comparison against NaN is False.
    is_confident = max_prob >= threshold and normalized_entropy <= entropy_threshold
    return is_confident, max_prob, pred_idx
# ============================================================================
# Flask App Setup
# ============================================================================
# Single Flask application object; the route decorators below attach to it.
app = Flask(__name__)
# Request-body cap; the 413 handler in this file reports the same 16MB limit.
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max upload
# Configuration
# Class labels, indexed by the classifier's predicted class index.
CLASS_NAMES = ['Glioma', 'Meningioma', 'No Tumor', 'Pituitary']
# Inference runs on CPU; tensors and the model are moved here with .to(DEVICE).
DEVICE = torch.device("cpu")
# Checkpoint path handed to torch.load in load_model().
MODEL_PATH = "Model.pt"
# Image preprocessing: resize to 256, center-crop to 224, convert to tensor,
# then normalize per channel. The original note says this matches training.
EVAL_TRANSFORM = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# Global model and GradCAM helper; both stay None until load_model() runs.
model = None
gradcam = None
# In-memory storage keyed by a UUID string (the original note says this
# replaced session storage).
# Each entry: { 'original_bytes': bytes, 'gradcam_map': np.array, 'prediction': dict, 'timestamp': float }
analysis_store = {}
# Size threshold that cleanup_old_entries() compares the store against.
MAX_STORE_SIZE = 20 # Max concurrent sessions to keep in memory
def cleanup_old_entries(store=None, max_size=None):
    """Prune the oldest analysis entries so the store stays within its cap.

    Intended to be called before inserting a new entry. Once the store
    already holds ``max_size`` entries, everything except the newest
    ``max_size // 2`` is removed, so the insert that follows cannot push the
    store past ``max_size``.

    Args:
        store: Mapping of session id -> entry dict carrying an optional
            'timestamp' key. Defaults to the module-level ``analysis_store``.
        max_size: Maximum number of entries allowed. Defaults to the
            module-level ``MAX_STORE_SIZE``.
    """
    if store is None:
        store = analysis_store
    if max_size is None:
        max_size = MAX_STORE_SIZE

    # ">=" rather than ">": the caller adds one entry right after this call.
    if len(store) < max_size:
        return

    # Snapshot the items first so sorting never iterates the live dict.
    snapshot = list(store.items())
    snapshot.sort(key=lambda item: item[1].get('timestamp', 0))

    keep = max_size // 2
    for key, _entry in snapshot[:len(snapshot) - keep]:
        # pop() tolerates a key that was already removed elsewhere.
        store.pop(key, None)
    gc.collect()
def load_model():
    """Load the classifier and its GradCAM helper into the module globals.

    Does nothing when a model is already loaded. The globals are assigned
    only after the weights have loaded successfully, so a failed load never
    leaves a randomly initialised model installed (which would also turn
    every later call into a no-op).

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist.
    """
    global model, gradcam
    if model is not None:
        return

    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file {MODEL_PATH} not found!")

    print(f"Loading model from {MODEL_PATH}...")
    loaded = LinearClassifier(num_classes=4).to(DEVICE)
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    loaded.load_state_dict(state_dict)
    loaded.eval()

    # GradCAM hooks encoder.norm, the last spatial layer before pooling.
    # It is a LayerNorm whose output is (B, H, W, C) spatial features.
    cam = GradCAM(loaded, loaded.encoder.norm)

    # Publish both objects together, only once everything above succeeded.
    model, gradcam = loaded, cam
    print("[SUCCESS] Model and GradCAM loaded successfully!")
# ============================================================================
# Routes
# ============================================================================
@app.route('/')
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page
@app.route('/upload', methods=['POST'])
def upload():
    """Validate an uploaded image, classify it and compute its GradCAM map.

    Expects a multipart form with a ``file`` field. On success the analysis
    is stored in ``analysis_store`` under a fresh UUID and the JSON response
    carries that ``session_id`` together with the prediction.

    Returns:
        JSON response. 400 for a missing/invalid file or a prediction that
        fails the confidence gate, 500 for any unexpected failure.
    """
    import time
    import traceback

    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400

    try:
        # load_model() otherwise only runs under ``__main__``; load lazily so
        # the route also works when the app object is imported by a server.
        if model is None:
            load_model()

        # Read file bytes into memory (no temp files)
        image_bytes = file.read()

        # Basic validation
        is_valid, message = validate_mri_basic(image_bytes)
        if not is_valid:
            return jsonify({
                'error': message,
                'warning': 'Please upload a valid Brain MRI image.'
            }), 400

        # Load and preprocess
        img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        img_tensor = EVAL_TRANSFORM(img).unsqueeze(0).to(DEVICE)

        # --- Classification ---
        with torch.no_grad():
            output = model(img_tensor)
            probs = torch.softmax(output, dim=1)[0]

        # Confidence gating — reject out-of-distribution images
        is_confident, max_prob, pred_idx = check_confidence_gate(probs)
        if not is_confident:
            return jsonify({
                'error': 'This image does not appear to be a valid Brain MRI scan. '
                         'The model could not confidently classify it.',
                'warning': 'Please upload a valid Brain MRI image.'
            }), 400

        # The gate's max probability is the predicted class's probability.
        confidence = max_prob
        # Pair labels with probabilities instead of a hard-coded range(4).
        all_probs = {name: float(p) for name, p in zip(CLASS_NAMES, probs.tolist())}

        # --- GradCAM Generation (best effort) ---
        try:
            # GradCAM switches on requires_grad in place, so give it its own
            # copy of the already-preprocessed tensor.
            img_tensor_gc = img_tensor.detach().clone()
            gradcam_map = gradcam.generate(img_tensor_gc, pred_idx)
        except Exception as e:
            print(f"[WARN] GradCAM generation failed: {e}")
            traceback.print_exc()
            # Fall back to an empty heatmap so the prediction is still returned.
            gradcam_map = np.zeros((224, 224), dtype=np.float32)

        # --- Store in memory with UUID ---
        session_id = str(uuid.uuid4())
        cleanup_old_entries()
        analysis_store[session_id] = {
            'original_bytes': image_bytes,
            'gradcam_map': gradcam_map,
            'prediction': {
                'class': CLASS_NAMES[pred_idx],
                'confidence': float(confidence),
                'class_idx': pred_idx,
                'all_probs': dict(all_probs)
            },
            'timestamp': time.time()
        }

        return jsonify({
            'success': True,
            'session_id': session_id,
            'prediction': CLASS_NAMES[pred_idx],
            'confidence': f"{confidence*100:.1f}%",
            'confidence_num': float(f"{confidence*100:.1f}"),
            'all_probs': all_probs
        })
    except Exception as e:
        print(f"Upload error: {e}")
        traceback.print_exc()
        return jsonify({'error': f"Classification failed: {str(e)}"}), 500
@app.route('/get-gradcam', methods=['POST'])
def get_gradcam():
    """Return the original image and its GradCAM overlay as base64 JPEGs.

    Expects a JSON body ``{"session_id": "<id returned by /upload>"}``.

    Returns:
        JSON with ``original`` and ``gradcam`` data URLs; 400 when the
        session is unknown, 500 when rendering fails.
    """
    import traceback

    # silent=True yields None instead of raising on a missing or non-JSON
    # body, so every failure stays on this route's JSON error contract.
    payload = request.get_json(silent=True)
    data = payload if isinstance(payload, dict) else {}
    session_id = data.get('session_id', '')

    # One .get() instead of "in" + indexing, so an entry pruned in between
    # cannot raise KeyError; non-string ids can never match a stored key.
    entry = analysis_store.get(session_id) if isinstance(session_id, str) and session_id else None
    if entry is None:
        return jsonify({'error': 'No analysis data found. Please analyze an image first.'}), 400

    def _to_jpeg_data_url(pixels):
        """Encode an RGB uint8 array as a base64 JPEG data URL."""
        buf = io.BytesIO()
        Image.fromarray(pixels).save(buf, format='JPEG', quality=90)
        return f"data:image/jpeg;base64,{base64.b64encode(buf.getvalue()).decode()}"

    try:
        gradcam_map = entry['gradcam_map']
        source_img = Image.open(io.BytesIO(entry['original_bytes'])).convert('RGB')

        # Plain 224x224 resize of the upload, shown as the "original".
        original_np = np.array(source_img.resize((224, 224)))  # (224, 224, 3) RGB

        # The heatmap is in the coordinates of the tensor the model saw, i.e.
        # after EVAL_TRANSFORM's Resize(256) + CenterCrop(224). Blend on that
        # same view, otherwise the hot spots are shifted and scaled relative
        # to the anatomy.
        model_view = transforms.CenterCrop(224)(transforms.Resize(256)(source_img))
        model_view_np = np.array(model_view)

        # Colorize the heatmap (OpenCV returns BGR) and convert to RGB.
        heatmap = (gradcam_map * 255).astype(np.uint8)
        heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

        # Overlay: blend image + heatmap
        overlay = cv2.addWeighted(model_view_np, 0.55, heatmap_colored, 0.45, 0)

        return jsonify({
            'original': _to_jpeg_data_url(original_np),
            'gradcam': _to_jpeg_data_url(overlay)
        })
    except Exception as e:
        print(f"GradCAM error: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Failed to generate GradCAM visualization: {str(e)}'}), 500
def _report_image_buffers(original_bytes, gradcam_map):
    """Return (original_png, overlay_png) BytesIO buffers for the PDF report."""
    source_img = Image.open(io.BytesIO(original_bytes)).convert('RGB')

    # Plain 224x224 resize of the upload, shown as the "original".
    original_np = np.array(source_img.resize((224, 224)))

    # The heatmap is in the coordinates of the tensor the model saw, i.e.
    # after EVAL_TRANSFORM's Resize(256) + CenterCrop(224); blend on that
    # same view so the hot spots line up with the anatomy.
    model_view = transforms.CenterCrop(224)(transforms.Resize(256)(source_img))
    model_view_np = np.array(model_view)

    heatmap = (gradcam_map * 255).astype(np.uint8)
    heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(model_view_np, 0.55, heatmap_colored, 0.45, 0)

    buffers = []
    for pixels in (original_np, overlay):
        buf = io.BytesIO()
        Image.fromarray(pixels).save(buf, format='PNG')
        buf.seek(0)
        buffers.append(buf)
    return buffers[0], buffers[1]


def _report_download_name(user_name):
    """Build a header-safe download filename from the user-supplied name."""
    import re

    # Collapse anything that is not a word character, dot or dash (spaces,
    # slashes, quotes, control characters) into a single underscore.
    cleaned = re.sub(r'[^\w.-]+', '_', user_name).strip('._')
    return f"Brain_Tumor_Report_{cleaned or 'report'}.pdf"


def _build_report_pdf(user_name, session_id, pred_data, orig_buf, overlay_buf):
    """Lay out the one-page A4 report and return it as a rewound BytesIO."""
    from xml.sax.saxutils import escape

    navy = colors.HexColor('#1A3A5C')
    grid_grey = colors.HexColor('#CBD5E0')
    header_bg = colors.HexColor('#EDF2F7')

    pdf_buffer = io.BytesIO()
    page_w = A4[0]  # 595.27 points
    margin = 0.55 * inch
    doc = SimpleDocTemplate(
        pdf_buffer,
        pagesize=A4,
        rightMargin=margin,
        leftMargin=margin,
        topMargin=0.4 * inch,
        bottomMargin=0.4 * inch
    )
    usable_w = page_w - 2 * margin
    story = []
    styles = getSampleStyleSheet()

    # --- Custom Styles ---
    title_style = ParagraphStyle(
        'ReportTitle', parent=styles['Heading1'],
        fontSize=22, textColor=colors.HexColor('#0D2B4E'),
        spaceAfter=2, alignment=TA_CENTER, fontName='Helvetica-Bold'
    )
    subtitle_style = ParagraphStyle(
        'ReportSubtitle', parent=styles['Normal'],
        fontSize=10, textColor=colors.HexColor('#5A7DA0'),
        spaceAfter=8, alignment=TA_CENTER, fontName='Helvetica'
    )
    section_style = ParagraphStyle(
        'SectionHeader', parent=styles['Heading2'],
        fontSize=13, textColor=colors.white,
        spaceBefore=10, spaceAfter=4, fontName='Helvetica-Bold',
        backColor=navy, leftIndent=6,
        borderPadding=(4, 6, 4, 6)
    )
    label_style = ParagraphStyle(
        'Label', parent=styles['Normal'],
        fontSize=10, fontName='Helvetica-Bold',
        textColor=colors.HexColor('#333333')
    )
    value_style = ParagraphStyle(
        'Value', parent=styles['Normal'],
        fontSize=10, fontName='Helvetica',
        textColor=navy
    )
    disclaimer_style = ParagraphStyle(
        'Disclaimer', parent=styles['Normal'],
        fontSize=7.5, textColor=colors.HexColor('#888888'),
        alignment=TA_CENTER, fontName='Helvetica-Oblique',
        spaceBefore=6
    )

    # ========== HEADER ==========
    # NOTE(review): the emoji, the warning sign and the block characters used
    # for the probability bars may fall outside the built-in Helvetica/Courier
    # glyph sets and render as placeholder boxes — confirm in a generated PDF.
    story.append(Paragraph("🧠 Brain Tumor Classification Report", title_style))
    story.append(Paragraph("AI-Powered MRI Analysis with Explainability · MoCo + Swin-Transformer", subtitle_style))
    story.append(HRFlowable(width="100%", thickness=2, color=navy))
    story.append(Spacer(1, 8))

    # ========== PATIENT INFO ==========
    story.append(Paragraph("Patient Information", section_style))
    story.append(Spacer(1, 4))
    now = datetime.now()
    info_data = [
        [Paragraph("<b>Patient Name</b>", label_style),
         # Paragraph text is parsed as markup, so the user-supplied name must
         # be escaped or "<" / "&" would break or alter the layout.
         Paragraph(escape(user_name), value_style),
         Paragraph("<b>Report Date</b>", label_style),
         Paragraph(now.strftime("%B %d, %Y"), value_style)],
        [Paragraph("<b>Report Time</b>", label_style),
         Paragraph(now.strftime("%I:%M %p"), value_style),
         Paragraph("<b>Report ID</b>", label_style),
         Paragraph(session_id[:8].upper(), value_style)],
    ]
    info_table = Table(info_data, colWidths=[1.3*inch, 2.0*inch, 1.3*inch, 2.0*inch])
    info_table.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (0, -1), header_bg),
        ('BACKGROUND', (2, 0), (2, -1), header_bg),
        ('TEXTCOLOR', (0, 0), (-1, -1), colors.black),
        ('ALIGN', (0, 0), (-1, -1), 'LEFT'),
        ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
        ('FONTSIZE', (0, 0), (-1, -1), 10),
        ('TOPPADDING', (0, 0), (-1, -1), 6),
        ('BOTTOMPADDING', (0, 0), (-1, -1), 6),
        ('LEFTPADDING', (0, 0), (-1, -1), 8),
        ('GRID', (0, 0), (-1, -1), 0.5, grid_grey),
        ('BOX', (0, 0), (-1, -1), 1, navy),
    ]))
    story.append(info_table)
    story.append(Spacer(1, 10))

    # ========== CLASSIFICATION RESULT ==========
    story.append(Paragraph("Classification Result", section_style))
    story.append(Spacer(1, 4))
    tumor_class = pred_data['class']
    conf_pct = pred_data['confidence'] * 100
    # Green for a negative finding, red for any tumor class.
    if tumor_class == 'No Tumor':
        result_color = colors.HexColor('#2F855A')
        result_bg = colors.HexColor('#F0FFF4')
    else:
        result_color = colors.HexColor('#C53030')
        result_bg = colors.HexColor('#FFF5F5')
    result_style = ParagraphStyle(
        'ResultValue', parent=styles['Normal'],
        fontSize=16, fontName='Helvetica-Bold',
        textColor=result_color, alignment=TA_CENTER
    )
    conf_style = ParagraphStyle(
        'ConfValue', parent=styles['Normal'],
        fontSize=12, fontName='Helvetica-Bold',
        textColor=navy, alignment=TA_CENTER
    )
    result_data = [
        [Paragraph("<b>Diagnosis</b>", ParagraphStyle('x', parent=label_style, alignment=TA_CENTER)),
         Paragraph("<b>Confidence</b>", ParagraphStyle('x', parent=label_style, alignment=TA_CENTER))],
        [Paragraph(tumor_class, result_style),
         Paragraph(f"{conf_pct:.2f}%", conf_style)],
    ]
    result_table = Table(result_data, colWidths=[usable_w * 0.55, usable_w * 0.45])
    result_table.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (-1, 0), header_bg),
        ('BACKGROUND', (0, 1), (0, 1), result_bg),
        ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
        ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
        ('TOPPADDING', (0, 0), (-1, -1), 8),
        ('BOTTOMPADDING', (0, 0), (-1, -1), 8),
        ('GRID', (0, 0), (-1, -1), 0.5, grid_grey),
        ('BOX', (0, 0), (-1, -1), 1, navy),
    ]))
    story.append(result_table)
    story.append(Spacer(1, 10))

    # ========== ALL PROBABILITIES ==========
    story.append(Paragraph("Class Probabilities", section_style))
    story.append(Spacer(1, 4))
    prob_rows = [[
        Paragraph("<b>Tumor Class</b>", label_style),
        Paragraph("<b>Probability</b>", label_style),
        Paragraph("<b>Score</b>", label_style),
    ]]
    for cls_name, prob_val in pred_data['all_probs'].items():
        pct = prob_val * 100
        # 30-cell text bar; at least one filled cell so every row is visible.
        bar_width = max(int(pct / 100 * 30), 1)
        bar_str = "█" * bar_width + "░" * (30 - bar_width)
        is_pred = cls_name == tumor_class
        fn = 'Helvetica-Bold' if is_pred else 'Helvetica'
        tc = navy if is_pred else colors.HexColor('#555555')
        row_style = ParagraphStyle('ps', parent=styles['Normal'], fontSize=9, fontName=fn, textColor=tc)
        prob_rows.append([
            Paragraph(cls_name, row_style),
            Paragraph(f"{pct:.2f}%", row_style),
            Paragraph(f'<font face="Courier" size="7">{bar_str}</font>', row_style),
        ])
    prob_table = Table(prob_rows, colWidths=[1.5*inch, 1.0*inch, usable_w - 2.5*inch])
    prob_table.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (-1, 0), navy),
        ('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
        ('ALIGN', (0, 0), (-1, -1), 'LEFT'),
        ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
        ('FONTSIZE', (0, 0), (-1, -1), 9),
        ('TOPPADDING', (0, 0), (-1, -1), 5),
        ('BOTTOMPADDING', (0, 0), (-1, -1), 5),
        ('LEFTPADDING', (0, 0), (-1, -1), 8),
        ('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.HexColor('#F7FAFC')]),
        ('GRID', (0, 0), (-1, -1), 0.5, grid_grey),
        ('BOX', (0, 0), (-1, -1), 1, navy),
    ]))
    story.append(prob_table)
    story.append(Spacer(1, 10))

    # ========== VISUALIZATIONS ==========
    story.append(Paragraph("MRI Visualization & GradCAM Analysis", section_style))
    story.append(Spacer(1, 4))
    img_width = 2.6 * inch
    img_height = 2.6 * inch
    orig_rl = RLImage(orig_buf, width=img_width, height=img_height)
    overlay_rl = RLImage(overlay_buf, width=img_width, height=img_height)
    img_label_style = ParagraphStyle(
        'ImgLabel', parent=styles['Normal'],
        fontSize=9, fontName='Helvetica-Bold',
        textColor=navy, alignment=TA_CENTER,
        spaceBefore=4
    )
    vis_data = [
        [orig_rl, Spacer(0.2*inch, 0), overlay_rl],
        [Paragraph("Original MRI Scan", img_label_style), '',
         Paragraph("GradCAM Attention Map", img_label_style)],
    ]
    vis_table = Table(vis_data, colWidths=[img_width + 0.1*inch, 0.3*inch, img_width + 0.1*inch])
    vis_table.setStyle(TableStyle([
        ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
        ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
        ('TOPPADDING', (0, 0), (-1, -1), 4),
        ('BOTTOMPADDING', (0, 0), (-1, -1), 4),
    ]))
    story.append(vis_table)
    story.append(Spacer(1, 6))
    gradcam_note = ParagraphStyle(
        'GCNote', parent=styles['Normal'],
        fontSize=8, fontName='Helvetica-Oblique',
        textColor=colors.HexColor('#666666'), alignment=TA_CENTER
    )
    story.append(Paragraph(
        "The GradCAM visualization highlights regions of the MRI that most influenced the model's prediction. "
        "Warmer colors (red/yellow) indicate higher importance.", gradcam_note))
    story.append(Spacer(1, 10))

    # ========== FOOTER ==========
    story.append(HRFlowable(width="100%", thickness=1, color=grid_grey))
    story.append(Paragraph(
        "⚠ DISCLAIMER: This report is generated by an AI model for research and educational purposes only. "
        "It is NOT a substitute for professional medical diagnosis. Please consult a qualified healthcare "
        "provider for clinical interpretation of MRI scans.",
        disclaimer_style
    ))
    story.append(Paragraph(
        f"Generated on {now.strftime('%B %d, %Y at %I:%M %p')} · "
        "Brain Tumor Classification System · MoCo SSL + Swin-Transformer",
        ParagraphStyle('Footer2', parent=disclaimer_style, spaceBefore=2)
    ))

    doc.build(story)
    pdf_buffer.seek(0)
    return pdf_buffer


@app.route('/generate-report', methods=['POST'])
def generate_report():
    """Generate a professional PDF report filling one A4 page.

    Expects a JSON body with ``name`` (patient name, required) and
    ``session_id`` (as returned by /upload).

    Returns:
        The PDF as a file download; a JSON error with status 400 when the
        name is missing or the session is unknown, 500 when rendering fails.
    """
    import traceback

    # silent=True yields None instead of raising on a missing or non-JSON
    # body, so every failure stays on this route's JSON error contract.
    payload = request.get_json(silent=True)
    data = payload if isinstance(payload, dict) else {}

    raw_name = data.get('name', '')
    user_name = raw_name.strip() if isinstance(raw_name, str) else ''
    session_id = data.get('session_id', '')

    if not user_name:
        return jsonify({'error': 'Name is required to generate the report.'}), 400

    # One .get() instead of "in" + indexing, so an entry pruned in between
    # cannot raise KeyError; non-string ids can never match a stored key.
    entry = analysis_store.get(session_id) if isinstance(session_id, str) and session_id else None
    if entry is None:
        return jsonify({'error': 'No analysis data found. Please analyze an image first.'}), 400

    try:
        orig_buf, overlay_buf = _report_image_buffers(
            entry['original_bytes'], entry['gradcam_map'])
        pdf_buffer = _build_report_pdf(
            user_name, session_id, entry['prediction'], orig_buf, overlay_buf)
        return send_file(
            pdf_buffer,
            mimetype='application/pdf',
            as_attachment=True,
            download_name=_report_download_name(user_name)
        )
    except Exception as e:
        print(f"Report generation error: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Report generation failed: {str(e)}'}), 500
# ============================================================================
# Error Handlers
# ============================================================================
@app.errorhandler(413)
def too_large(e):
    """Answer oversized uploads (HTTP 413) with a JSON error body."""
    body = jsonify({'error': 'File too large. Maximum size is 16MB.'})
    return body, 413
if __name__ == '__main__':
    # Script entry point: load the model once, then serve until interrupted.
    try:
        load_model()
        print("Starting Flask application...")
        # Defaults to 7860 (Hugging Face Spaces); set PORT to use another
        # port, e.g. 5000 for local development.
        port = int(os.environ.get('PORT', 7860))
        print(f"Running on http://0.0.0.0:{port}")
        app.run(host='0.0.0.0', port=port, debug=False, threaded=True)
    except Exception as e:
        print(f"Fatal error: {e}")
        import traceback
        traceback.print_exc()
        # exit() is a helper injected by the ``site`` module and is not
        # guaranteed to exist; raising SystemExit always works.
        raise SystemExit(1)