# NOTE(review): removed stray "Spaces / Sleeping" banner text captured from the
# hosting page — it was not part of the source and is not valid Python.
"""
Brain Tumor Classification Web App
Optimized for Hugging Face Spaces deployment
- Detects Glioma, Meningioma, No Tumor, Pituitary
- Generates GradCAM visualizations
- Creates PDF reports
"""
import base64
import gc
import io
import os
import sys
import uuid
from datetime import datetime

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from flask import Flask, render_template, request, jsonify, send_file
from PIL import Image
from reportlab.lib import colors
from reportlab.lib.enums import TA_CENTER, TA_LEFT, TA_RIGHT
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import inch, mm
from reportlab.platypus import (
    SimpleDocTemplate, Table, TableStyle, Paragraph, Spacer,
    Image as RLImage, HRFlowable
)
# ============================================================================
# Model Architecture (same as training)
# ============================================================================
class FFEBlock(nn.Module):
    """Feature-extraction backbone: torchvision Swin-T with GradCAM-friendly access."""

    def __init__(self):
        super().__init__()
        from torchvision import models
        backbone = models.swin_t(weights=None)
        self.features = backbone.features
        self.norm = backbone.norm
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.out_dim = 768  # Swin-T final embedding width

    def forward_features(self, x):
        """Return spatial feature map [B, C, H, W] before pooling (for GradCAM)."""
        feats = self.norm(self.features(x))  # Swin layout: [B, H, W, C]
        return feats.permute(0, 3, 1, 2)     # -> [B, C, H, W]

    def forward(self, x):
        """Return pooled embedding [B, 768]."""
        return self.pool(self.forward_features(x)).flatten(1)
class ProjectionMLP(nn.Module):
    """NRPL projection head: in_dim -> hidden_dim -> out_dim (default 768 -> 384 -> 192)."""

    def __init__(self, in_dim=768, hidden_dim=384, out_dim=192):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project embeddings [B, in_dim] -> [B, out_dim]."""
        return self.net(x)
class LinearClassifier(nn.Module):
    """Deployment classifier: Swin-T encoder -> projection head -> linear head."""

    def __init__(self, num_classes=4):
        super().__init__()
        self.encoder = FFEBlock()
        self.projector = ProjectionMLP(in_dim=768, hidden_dim=384, out_dim=192)
        self.head = nn.Linear(192, num_classes)

    def forward(self, x):
        """Return class logits [B, num_classes]."""
        embedding = self.encoder(x)
        projected = self.projector(embedding)
        return self.head(projected)

    def forward_features(self, x):
        """Return spatial encoder features [B, C, H, W] (for GradCAM)."""
        return self.encoder.forward_features(x)

    def predict_proba(self, x):
        """Return softmax class probabilities [B, num_classes]."""
        return torch.softmax(self.forward(x), dim=1)
# ============================================================================
# GradCAM Implementation (Memory Efficient)
# ============================================================================
class GradCAM:
    """Gradient-based Class Activation Map for visualization.

    Hooks a layer whose output is channel-last spatial features (B, H, W, C)
    — for this app, ``encoder.norm`` of the Swin-T backbone — and combines
    gradient-weighted activations into a normalized 224x224 heatmap.
    """

    def __init__(self, model, target_layer):
        """
        Args:
            model: The full LinearClassifier model
            target_layer: The layer to hook — should be encoder.norm for Swin-T
        """
        self.model = model
        self.target_layer = target_layer
        self.gradients = None    # (B, H, W, C) grad of target logit wrt layer output
        self.activations = None  # (B, H, W, C) layer output from the forward pass
        self.handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Register forward and backward hooks on the target layer."""
        def forward_hook(module, input, output):
            # encoder.norm output shape: (B, H, W, C)
            self.activations = output.detach().clone()

        def backward_hook(module, grad_input, grad_output):
            # grad_output[0] shape: (B, H, W, C)
            self.gradients = grad_output[0].detach().clone()

        h1 = self.target_layer.register_forward_hook(forward_hook)
        # register_full_backward_hook is the non-deprecated API
        h2 = self.target_layer.register_full_backward_hook(backward_hook)
        self.handles = [h1, h2]

    def generate(self, input_tensor, class_idx):
        """
        Generate GradCAM for a specific class.

        Args:
            input_tensor: Input image [1, 3, 224, 224]
            class_idx: Class index to visualize

        Returns:
            cam: GradCAM heatmap [224, 224] normalized to 0-1 (flat 0.5 map
            if the hooks did not fire; flat zeros if the map is constant)
        """
        self.model.eval()
        self.gradients = None
        self.activations = None
        # Forward pass — gradients are needed, so NOT under torch.no_grad()
        input_tensor.requires_grad_(True)
        output = self.model(input_tensor)
        # Backward pass for the target class only
        self.model.zero_grad()
        output[0, class_idx].backward()
        # Check hooks fired
        if self.gradients is None or self.activations is None:
            print("[WARN] GradCAM hooks did not fire")
            return np.ones((224, 224), dtype=np.float32) * 0.5
        # First batch element, channel-last layout (H, W, C)
        gradients = self.gradients[0]
        activations = self.activations[0]
        # Channel weights: global average of gradients over spatial dims
        weights = gradients.mean(dim=(0, 1))  # (C,)
        # Weighted sum of activation maps — vectorized (was a Python loop
        # over channels); broadcasting (H, W, C) * (C,) then summing C
        cam = (activations * weights).sum(dim=-1)  # (H, W)
        # ReLU — only positive contributions
        cam = torch.relu(cam)
        # Normalize to 0-1
        cam_min, cam_max = cam.min(), cam.max()
        if (cam_max - cam_min).item() > 1e-6:
            cam = (cam - cam_min) / (cam_max - cam_min)
        else:
            cam = torch.zeros_like(cam)
        # Resize to 224x224 with bilinear interpolation; staying in torch
        # avoids the previous numpy/cv2 round-trip and matches cv2.INTER_LINEAR
        cam = torch.nn.functional.interpolate(
            cam[None, None], size=(224, 224),
            mode='bilinear', align_corners=False
        )[0, 0]
        return cam.detach().cpu().numpy()

    def remove_hooks(self):
        """Detach both hooks (idempotent)."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def __del__(self):
        self.remove_hooks()
# ============================================================================
# MRI Validation — Simplified + Confidence Gating
# ============================================================================
def validate_mri_basic(image_bytes):
    """
    Basic sanity check on uploaded image bytes.

    Returns:
        (is_valid, message) — message explains the failure when invalid.
    """
    allowed_formats = {'JPEG', 'PNG', 'BMP', 'TIFF', 'GIF', 'MPO'}
    try:
        img = Image.open(io.BytesIO(image_bytes))
        # Reject formats outside the allow-list (an undetected format passes)
        if img.format and img.format not in allowed_formats:
            return False, "Invalid image format. Please upload JPEG, PNG, BMP, GIF, or TIFF."
        width, height = img.size
        # Minimum size
        if width < 50 or height < 50:
            return False, "Image too small. Please upload a larger image."
        # Brain MRI slices are roughly square — reject extreme aspect ratios
        long_side = max(width, height)
        short_side = max(min(width, height), 1)
        if long_side / short_side > 3.0:
            return False, "Image aspect ratio is unusual for a Brain MRI scan."
        return True, "OK"
    except Exception as e:
        return False, f"Could not read image: {str(e)}"
def check_confidence_gate(probs, threshold=0.55):
    """
    Reject images where the model is not confident — likely out-of-distribution.

    Combines a max-probability floor with a normalized-entropy ceiling.

    Args:
        probs: softmax probabilities tensor (4,)
        threshold: minimum top-class probability to accept

    Returns:
        (is_confident, max_prob, pred_class_idx)
    """
    top_prob = probs.max().item()
    top_idx = probs.argmax().item()
    # Normalized Shannon entropy: 0 = one-hot certainty, 1 = uniform
    raw_entropy = -(probs * torch.log(probs + 1e-9)).sum().item()
    entropy_ratio = raw_entropy / np.log(len(probs))  # ln(4) ≈ 1.386 for 4 classes
    # Accept only when confidence is high enough AND entropy is low enough
    accepted = bool(top_prob >= threshold and entropy_ratio <= 0.85)
    return accepted, top_prob, top_idx
# ============================================================================
# Flask App Setup
# ============================================================================
app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max upload

# Configuration
# Class order must match the training label order — index i of the model's
# output logits maps to CLASS_NAMES[i].
CLASS_NAMES = ['Glioma', 'Meningioma', 'No Tumor', 'Pituitary']
DEVICE = torch.device("cpu")  # CPU-only deployment
MODEL_PATH = "Model.pt"

# Image preprocessing (same as training) — ImageNet normalization statistics
EVAL_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Global model + GradCAM instance, populated once by load_model()
model = None
gradcam = None

# In-memory storage keyed by UUID (replaces broken session storage)
# Each entry: { 'original_bytes': bytes, 'gradcam_map': np.array, 'prediction': dict, 'timestamp': float }
analysis_store = {}
MAX_STORE_SIZE = 20  # Max concurrent sessions to keep in memory
def cleanup_old_entries():
    """Evict the oldest analyses so the in-memory store stays bounded.

    When the store exceeds MAX_STORE_SIZE entries, keeps only the
    MAX_STORE_SIZE // 2 most recent entries (by 'timestamp') and drops the
    rest, then forces a GC pass so the stored image bytes are released.
    """
    # (removed an unused local `import time` — nothing here reads the clock)
    if len(analysis_store) > MAX_STORE_SIZE:
        # Oldest first; entries missing a timestamp sort as oldest (0)
        sorted_keys = sorted(analysis_store,
                             key=lambda k: analysis_store[k].get('timestamp', 0))
        for key in sorted_keys[:len(sorted_keys) - MAX_STORE_SIZE // 2]:
            del analysis_store[key]
        gc.collect()
def load_model():
    """Load the classifier weights and build the GradCAM hook (idempotent).

    Populates the module-level ``model`` and ``gradcam`` globals on first
    call; subsequent calls are no-ops.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist.
    """
    global model, gradcam
    if model is None:
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Model file {MODEL_PATH} not found!")
        print(f"Loading model from {MODEL_PATH}...")
        model = LinearClassifier(num_classes=4).to(DEVICE)
        # weights_only=True refuses arbitrary pickled objects in the checkpoint
        state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        # GradCAM: hook on encoder.norm (last spatial layer before pooling)
        # encoder.norm is a LayerNorm; its output is (B, H, W, C) = spatial features
        gradcam = GradCAM(model, model.encoder.norm)
        print("[SUCCESS] Model and GradCAM loaded successfully!")
# ============================================================================
# Routes
# ============================================================================
# NOTE(review): the @app.route decorator was missing, so this view was never
# registered. Path '/' inferred from the function's role — confirm against
# the front-end.
@app.route('/')
def index():
    """Render home page."""
    return render_template('index.html')
# NOTE(review): the @app.route decorator was missing, so this view was never
# registered. Path '/upload' inferred — confirm against the front-end JS.
@app.route('/upload', methods=['POST'])
def upload():
    """Handle image upload: validate, classify, generate GradCAM, stash results.

    Expects a multipart form with a 'file' field. On success returns JSON
    containing a session_id the client passes to the GradCAM and report
    endpoints. Returns 400 on validation/confidence failures, 500 on errors.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400
    try:
        # Read file bytes into memory (no temp files)
        image_bytes = file.read()
        # Basic validation (format / size / aspect ratio)
        is_valid, message = validate_mri_basic(image_bytes)
        if not is_valid:
            return jsonify({
                'error': message,
                'warning': 'Please upload a valid Brain MRI image.'
            }), 400
        # Load and preprocess
        img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        img_tensor = EVAL_TRANSFORM(img).unsqueeze(0).to(DEVICE)
        # --- Classification (no gradients needed here) ---
        with torch.no_grad():
            output = model(img_tensor)
            probs = torch.softmax(output, dim=1)[0]
        # Confidence gating — reject out-of-distribution images
        is_confident, max_prob, pred_idx = check_confidence_gate(probs)
        if not is_confident:
            return jsonify({
                'error': 'This image does not appear to be a valid Brain MRI scan. '
                         'The model could not confidently classify it.',
                'warning': 'Please upload a valid Brain MRI image.'
            }), 400
        confidence = probs[pred_idx].item()
        # --- GradCAM Generation (best-effort: fall back to a blank map) ---
        try:
            # Need a fresh tensor with grad tracking for GradCAM
            img_tensor_gc = EVAL_TRANSFORM(img).unsqueeze(0).to(DEVICE)
            gradcam_map = gradcam.generate(img_tensor_gc, pred_idx)
        except Exception as e:
            print(f"[WARN] GradCAM generation failed: {e}")
            import traceback
            traceback.print_exc()
            gradcam_map = np.zeros((224, 224), dtype=np.float32)
        # --- Store in memory with UUID ---
        import time
        session_id = str(uuid.uuid4())
        cleanup_old_entries()
        # Build the per-class probability map once (was duplicated below)
        all_probs = {CLASS_NAMES[i]: float(probs[i].item()) for i in range(4)}
        analysis_store[session_id] = {
            'original_bytes': image_bytes,
            'gradcam_map': gradcam_map,
            'prediction': {
                'class': CLASS_NAMES[pred_idx],
                'confidence': float(confidence),
                'class_idx': pred_idx,
                'all_probs': all_probs
            },
            'timestamp': time.time()
        }
        return jsonify({
            'success': True,
            'session_id': session_id,
            'prediction': CLASS_NAMES[pred_idx],
            'confidence': f"{confidence*100:.1f}%",
            'confidence_num': float(f"{confidence*100:.1f}"),
            'all_probs': all_probs
        })
    except Exception as e:
        print(f"Upload error: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({'error': f"Classification failed: {str(e)}"}), 500
# NOTE(review): the @app.route decorator was missing, so this view was never
# registered. Path '/gradcam' inferred — confirm against the front-end JS.
@app.route('/gradcam', methods=['POST'])
def get_gradcam():
    """Return the original image and its GradCAM overlay as base64 data URIs.

    Expects JSON {'session_id': ...} referencing a prior /upload analysis.
    """
    data = request.json or {}
    session_id = data.get('session_id', '')
    if not session_id or session_id not in analysis_store:
        return jsonify({'error': 'No analysis data found. Please analyze an image first.'}), 400
    try:
        entry = analysis_store[session_id]
        gradcam_map = entry['gradcam_map']
        original_bytes = entry['original_bytes']
        # Load original at the model's working resolution
        original_img = Image.open(io.BytesIO(original_bytes)).convert('RGB')
        original_img = original_img.resize((224, 224))
        original_np = np.array(original_img)  # (224, 224, 3) RGB
        # Create GradCAM heatmap (cv2 colormaps are BGR; convert back to RGB)
        heatmap = (gradcam_map * 255).astype(np.uint8)
        heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
        # Overlay: blend original + heatmap
        overlay = cv2.addWeighted(original_np, 0.55, heatmap_colored, 0.45, 0)

        def _to_b64_jpeg(arr):
            # Encode an RGB array to base64 JPEG (deduplicates the two encodings)
            buf = io.BytesIO()
            Image.fromarray(arr).save(buf, format='JPEG', quality=90)
            return base64.b64encode(buf.getvalue()).decode()

        return jsonify({
            'original': f"data:image/jpeg;base64,{_to_b64_jpeg(original_np)}",
            'gradcam': f"data:image/jpeg;base64,{_to_b64_jpeg(overlay)}"
        })
    except Exception as e:
        print(f"GradCAM error: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({'error': f'Failed to generate GradCAM visualization: {str(e)}'}), 500
# NOTE(review): the @app.route decorator was missing, so this view was never
# registered. Path '/generate_report' inferred — confirm against the front-end JS.
@app.route('/generate_report', methods=['POST'])
def generate_report():
    """Generate a professional PDF report filling one A4 page.

    Expects JSON {'name': ..., 'session_id': ...}; streams the finished PDF
    back as an attachment. Returns 400 when the name or session data is
    missing, 500 when PDF construction fails.
    """
    data = request.json or {}
    user_name = data.get('name', '').strip()
    session_id = data.get('session_id', '')
    if not user_name:
        return jsonify({'error': 'Name is required to generate the report.'}), 400
    if not session_id or session_id not in analysis_store:
        return jsonify({'error': 'No analysis data found. Please analyze an image first.'}), 400
    try:
        entry = analysis_store[session_id]
        pred_data = entry['prediction']
        gradcam_map = entry['gradcam_map']
        original_bytes = entry['original_bytes']
        # --- Prepare images ---
        # Original
        original_img = Image.open(io.BytesIO(original_bytes)).convert('RGB')
        original_img = original_img.resize((224, 224))
        original_np = np.array(original_img)
        # GradCAM overlay (same recipe as the /gradcam endpoint)
        heatmap = (gradcam_map * 255).astype(np.uint8)
        heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
        overlay = cv2.addWeighted(original_np, 0.55, heatmap_colored, 0.45, 0)
        # Save images to BytesIO for ReportLab
        orig_buf = io.BytesIO()
        Image.fromarray(original_np).save(orig_buf, format='PNG')
        orig_buf.seek(0)
        overlay_buf = io.BytesIO()
        Image.fromarray(overlay).save(overlay_buf, format='PNG')
        overlay_buf.seek(0)
        # --- Build PDF ---
        pdf_buffer = io.BytesIO()
        page_w, page_h = A4  # 595.27 x 841.89 points
        margin = 0.55 * inch
        doc = SimpleDocTemplate(
            pdf_buffer,
            pagesize=A4,
            rightMargin=margin,
            leftMargin=margin,
            topMargin=0.4 * inch,
            bottomMargin=0.4 * inch
        )
        usable_w = page_w - 2 * margin
        story = []
        styles = getSampleStyleSheet()
        # --- Custom Styles ---
        title_style = ParagraphStyle(
            'ReportTitle', parent=styles['Heading1'],
            fontSize=22, textColor=colors.HexColor('#0D2B4E'),
            spaceAfter=2, alignment=TA_CENTER, fontName='Helvetica-Bold'
        )
        subtitle_style = ParagraphStyle(
            'ReportSubtitle', parent=styles['Normal'],
            fontSize=10, textColor=colors.HexColor('#5A7DA0'),
            spaceAfter=8, alignment=TA_CENTER, fontName='Helvetica'
        )
        section_style = ParagraphStyle(
            'SectionHeader', parent=styles['Heading2'],
            fontSize=13, textColor=colors.white,
            spaceBefore=10, spaceAfter=4, fontName='Helvetica-Bold',
            backColor=colors.HexColor('#1A3A5C'), leftIndent=6,
            borderPadding=(4, 6, 4, 6)
        )
        label_style = ParagraphStyle(
            'Label', parent=styles['Normal'],
            fontSize=10, fontName='Helvetica-Bold',
            textColor=colors.HexColor('#333333')
        )
        value_style = ParagraphStyle(
            'Value', parent=styles['Normal'],
            fontSize=10, fontName='Helvetica',
            textColor=colors.HexColor('#1A3A5C')
        )
        disclaimer_style = ParagraphStyle(
            'Disclaimer', parent=styles['Normal'],
            fontSize=7.5, textColor=colors.HexColor('#888888'),
            alignment=TA_CENTER, fontName='Helvetica-Oblique',
            spaceBefore=6
        )
        # ========== HEADER ==========
        story.append(Paragraph("🧠 Brain Tumor Classification Report", title_style))
        story.append(Paragraph("AI-Powered MRI Analysis with Explainability · MoCo + Swin-Transformer", subtitle_style))
        story.append(HRFlowable(width="100%", thickness=2, color=colors.HexColor('#1A3A5C')))
        story.append(Spacer(1, 8))
        # ========== PATIENT INFO ==========
        story.append(Paragraph("Patient Information", section_style))
        story.append(Spacer(1, 4))
        now = datetime.now()
        info_data = [
            [Paragraph("<b>Patient Name</b>", label_style),
             Paragraph(user_name, value_style),
             Paragraph("<b>Report Date</b>", label_style),
             Paragraph(now.strftime("%B %d, %Y"), value_style)],
            [Paragraph("<b>Report Time</b>", label_style),
             Paragraph(now.strftime("%I:%M %p"), value_style),
             Paragraph("<b>Report ID</b>", label_style),
             Paragraph(session_id[:8].upper(), value_style)],
        ]
        info_table = Table(info_data, colWidths=[1.3*inch, 2.0*inch, 1.3*inch, 2.0*inch])
        info_table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (0, -1), colors.HexColor('#EDF2F7')),
            ('BACKGROUND', (2, 0), (2, -1), colors.HexColor('#EDF2F7')),
            ('TEXTCOLOR', (0, 0), (-1, -1), colors.black),
            ('ALIGN', (0, 0), (-1, -1), 'LEFT'),
            ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
            ('FONTSIZE', (0, 0), (-1, -1), 10),
            ('TOPPADDING', (0, 0), (-1, -1), 6),
            ('BOTTOMPADDING', (0, 0), (-1, -1), 6),
            ('LEFTPADDING', (0, 0), (-1, -1), 8),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#CBD5E0')),
            ('BOX', (0, 0), (-1, -1), 1, colors.HexColor('#1A3A5C')),
        ]))
        story.append(info_table)
        story.append(Spacer(1, 10))
        # ========== CLASSIFICATION RESULT ==========
        story.append(Paragraph("Classification Result", section_style))
        story.append(Spacer(1, 4))
        # Determine color based on tumor type (green for clean, red for tumor)
        tumor_class = pred_data['class']
        conf_pct = pred_data['confidence'] * 100
        if tumor_class == 'No Tumor':
            result_color = colors.HexColor('#2F855A')  # green
            result_bg = colors.HexColor('#F0FFF4')
        else:
            result_color = colors.HexColor('#C53030')  # red
            result_bg = colors.HexColor('#FFF5F5')
        result_style = ParagraphStyle(
            'ResultValue', parent=styles['Normal'],
            fontSize=16, fontName='Helvetica-Bold',
            textColor=result_color, alignment=TA_CENTER
        )
        conf_style = ParagraphStyle(
            'ConfValue', parent=styles['Normal'],
            fontSize=12, fontName='Helvetica-Bold',
            textColor=colors.HexColor('#1A3A5C'), alignment=TA_CENTER
        )
        result_data = [
            [Paragraph("<b>Diagnosis</b>", ParagraphStyle('x', parent=label_style, alignment=TA_CENTER)),
             Paragraph("<b>Confidence</b>", ParagraphStyle('x', parent=label_style, alignment=TA_CENTER))],
            [Paragraph(tumor_class, result_style),
             Paragraph(f"{conf_pct:.2f}%", conf_style)],
        ]
        result_table = Table(result_data, colWidths=[usable_w * 0.55, usable_w * 0.45])
        result_table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#EDF2F7')),
            ('BACKGROUND', (0, 1), (0, 1), result_bg),
            ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
            ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
            ('TOPPADDING', (0, 0), (-1, -1), 8),
            ('BOTTOMPADDING', (0, 0), (-1, -1), 8),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#CBD5E0')),
            ('BOX', (0, 0), (-1, -1), 1, colors.HexColor('#1A3A5C')),
        ]))
        story.append(result_table)
        story.append(Spacer(1, 10))
        # ========== ALL PROBABILITIES ==========
        story.append(Paragraph("Class Probabilities", section_style))
        story.append(Spacer(1, 4))
        prob_header = [
            Paragraph("<b>Tumor Class</b>", label_style),
            Paragraph("<b>Probability</b>", label_style),
            Paragraph("<b>Score</b>", label_style),
        ]
        prob_rows = [prob_header]
        for cls_name, prob_val in pred_data['all_probs'].items():
            pct = prob_val * 100
            # Text bar chart: 30 cells, at least 1 filled
            bar_width = max(int(pct / 100 * 30), 1)
            bar_str = "█" * bar_width + "░" * (30 - bar_width)
            is_pred = cls_name == tumor_class
            fn = 'Helvetica-Bold' if is_pred else 'Helvetica'
            tc = colors.HexColor('#1A3A5C') if is_pred else colors.HexColor('#555555')
            row_style = ParagraphStyle('ps', parent=styles['Normal'], fontSize=9, fontName=fn, textColor=tc)
            prob_rows.append([
                Paragraph(cls_name, row_style),
                Paragraph(f"{pct:.2f}%", row_style),
                Paragraph(f'<font face="Courier" size="7">{bar_str}</font>', row_style),
            ])
        prob_table = Table(prob_rows, colWidths=[1.5*inch, 1.0*inch, usable_w - 2.5*inch])
        prob_table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#1A3A5C')),
            ('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
            ('ALIGN', (0, 0), (-1, -1), 'LEFT'),
            ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
            ('FONTSIZE', (0, 0), (-1, -1), 9),
            ('TOPPADDING', (0, 0), (-1, -1), 5),
            ('BOTTOMPADDING', (0, 0), (-1, -1), 5),
            ('LEFTPADDING', (0, 0), (-1, -1), 8),
            ('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.HexColor('#F7FAFC')]),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#CBD5E0')),
            ('BOX', (0, 0), (-1, -1), 1, colors.HexColor('#1A3A5C')),
        ]))
        story.append(prob_table)
        story.append(Spacer(1, 10))
        # ========== VISUALIZATIONS ==========
        story.append(Paragraph("MRI Visualization & GradCAM Analysis", section_style))
        story.append(Spacer(1, 4))
        img_width = 2.6 * inch
        img_height = 2.6 * inch
        orig_rl = RLImage(orig_buf, width=img_width, height=img_height)
        overlay_rl = RLImage(overlay_buf, width=img_width, height=img_height)
        img_label_style = ParagraphStyle(
            'ImgLabel', parent=styles['Normal'],
            fontSize=9, fontName='Helvetica-Bold',
            textColor=colors.HexColor('#1A3A5C'), alignment=TA_CENTER,
            spaceBefore=4
        )
        vis_data = [
            [orig_rl, Spacer(0.2*inch, 0), overlay_rl],
            [Paragraph("Original MRI Scan", img_label_style), '',
             Paragraph("GradCAM Attention Map", img_label_style)],
        ]
        vis_table = Table(vis_data, colWidths=[img_width + 0.1*inch, 0.3*inch, img_width + 0.1*inch])
        vis_table.setStyle(TableStyle([
            ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
            ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
            ('TOPPADDING', (0, 0), (-1, -1), 4),
            ('BOTTOMPADDING', (0, 0), (-1, -1), 4),
        ]))
        story.append(vis_table)
        story.append(Spacer(1, 6))
        gradcam_note = ParagraphStyle(
            'GCNote', parent=styles['Normal'],
            fontSize=8, fontName='Helvetica-Oblique',
            textColor=colors.HexColor('#666666'), alignment=TA_CENTER
        )
        story.append(Paragraph(
            "The GradCAM visualization highlights regions of the MRI that most influenced the model's prediction. "
            "Warmer colors (red/yellow) indicate higher importance.", gradcam_note))
        story.append(Spacer(1, 10))
        # ========== FOOTER ==========
        story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#CBD5E0')))
        story.append(Paragraph(
            "⚠ DISCLAIMER: This report is generated by an AI model for research and educational purposes only. "
            "It is NOT a substitute for professional medical diagnosis. Please consult a qualified healthcare "
            "provider for clinical interpretation of MRI scans.",
            disclaimer_style
        ))
        story.append(Paragraph(
            f"Generated on {now.strftime('%B %d, %Y at %I:%M %p')} · "
            "Brain Tumor Classification System · MoCo SSL + Swin-Transformer",
            ParagraphStyle('Footer2', parent=disclaimer_style, spaceBefore=2)
        ))
        # Build PDF
        doc.build(story)
        pdf_buffer.seek(0)
        return send_file(
            pdf_buffer,
            mimetype='application/pdf',
            as_attachment=True,
            download_name=f"Brain_Tumor_Report_{user_name.replace(' ', '_')}.pdf"
        )
    except Exception as e:
        print(f"Report generation error: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({'error': f'Report generation failed: {str(e)}'}), 500
# ============================================================================
# Error Handlers
# ============================================================================
# NOTE(review): the @app.errorhandler decorator was missing, so oversized
# uploads returned Flask's default 413 page instead of this JSON response.
@app.errorhandler(413)
def too_large(e):
    """Friendly JSON response when an upload exceeds MAX_CONTENT_LENGTH (16MB)."""
    return jsonify({'error': 'File too large. Maximum size is 16MB.'}), 413
if __name__ == '__main__':
    try:
        load_model()
        print("Starting Flask application...")
        # HF Spaces expects port 7860; override locally via the PORT env var
        port = int(os.environ.get('PORT', 7860))
        print(f"Running on http://0.0.0.0:{port}")
        app.run(host='0.0.0.0', port=port, debug=False, threaded=True)
    except Exception as e:
        print(f"Fatal error: {e}")
        import traceback
        traceback.print_exc()
        # sys.exit instead of the builtin exit(), which is a site/REPL helper
        sys.exit(1)