import gradio as gr import torch import torch.nn as nn import numpy as np import cv2 from PIL import Image import matplotlib.pyplot as plt import io import torchvision.transforms as transforms import torchvision.transforms.functional as TF import random import os import urllib.request import kagglehub from glob import glob # Global variables - loaded once at startup device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = None dataset_images = [] dataset_masks = [] dataset_loaded = False print("="*50) print("BRAIN TUMOR SEGMENTATION APPLICATION") print("="*50) # Your Attention U-Net classes (unchanged) class DoubleConv(nn.Module): def __init__(self, in_channels, out_channels): super(DoubleConv, self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), ) def forward(self, x): return self.conv(x) class AttentionBlock(nn.Module): def __init__(self, F_g, F_l, F_int): super(AttentionBlock, self).__init__() self.W_g = nn.Sequential( nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True), nn.BatchNorm2d(F_int) ) self.W_x = nn.Sequential( nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True), nn.BatchNorm2d(F_int) ) self.psi = nn.Sequential( nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True), nn.BatchNorm2d(1), nn.Sigmoid() ) self.relu = nn.ReLU(inplace=True) def forward(self, g, x): g1 = self.W_g(g) x1 = self.W_x(x) psi = self.relu(g1 + x1) psi = self.psi(psi) return x * psi, psi # Return both attended features AND attention map class AttentionUNET(nn.Module): def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]): super(AttentionUNET, self).__init__() self.out_channels = out_channels self.ups = nn.ModuleList() self.downs = nn.ModuleList() self.attentions = nn.ModuleList() self.pool = nn.MaxPool2d(kernel_size=2, stride=2) for feature in features: self.downs.append(DoubleConv(in_channels, feature)) in_channels = feature self.bottleneck = DoubleConv(features[-1], features[-1]*2) for feature in reversed(features): self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2)) self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2)) self.ups.append(DoubleConv(feature*2, feature)) self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1) def forward(self, x): skip_connections = [] attention_maps = [] for down in self.downs: x = down(x) skip_connections.append(x) x = self.pool(x) x = self.bottleneck(x) skip_connections = skip_connections[::-1] for idx in range(0, len(self.ups), 2): x = self.ups[idx](x) skip_connection = skip_connections[idx//2] if x.shape != skip_connection.shape: x = TF.resize(x, size=skip_connection.shape[2:]) attended_skip, att_map = self.attentions[idx // 2](x, skip_connection) attention_maps.append(att_map) concat_skip = torch.cat((attended_skip, x), dim=1) x = self.ups[idx+1](concat_skip) return self.final_conv(x), attention_maps def download_and_load_model(): """Download and load model once at startup""" global model print("Loading Attention U-Net model...") model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar" model_path = "best_attention_model.pth.tar" # Download model if needed if not os.path.exists(model_path): print("Downloading model weights...") try: urllib.request.urlretrieve(model_url, model_path) except Exception as e: print(f"Failed to download model: {e}") return False # Load model try: model = AttentionUNET(in_channels=1, out_channels=1).to(device) checkpoint = torch.load(model_path, map_location=device, weights_only=True) model.load_state_dict(checkpoint["state_dict"]) model.eval() print("✓ Model loaded successfully!") return True except Exception as e: print(f"Failed to load model: {e}") return False def download_and_load_dataset(): """Download and load entire dataset once at startup""" global dataset_images, dataset_masks, dataset_loaded if dataset_loaded: return True print("Loading brain tumor dataset...") try: # Download dataset using kagglehub - returns directory path dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation') print(f"Dataset downloaded to: {dataset_path}") # Find images and masks directories images_dir = os.path.join(dataset_path, 'images') masks_dir = os.path.join(dataset_path, 'masks') # If direct path doesn't exist, search subdirectories if not os.path.exists(images_dir): # Search for images and masks directories for root, dirs, files in os.walk(dataset_path): if 'images' in dirs: images_dir = os.path.join(root, 'images') if 'masks' in dirs: masks_dir = os.path.join(root, 'masks') if not os.path.exists(images_dir) or not os.path.exists(masks_dir): print("Could not find images/masks directories. Searching all files...") # Fallback: find all image files all_files = glob(os.path.join(dataset_path, "**/*.png"), recursive=True) + \ glob(os.path.join(dataset_path, "**/*.jpg"), recursive=True) dataset_images = [f for f in all_files if '/images/' in f or 'image' in f.lower()] dataset_masks = [f for f in all_files if '/masks/' in f or 'mask' in f.lower()] else: # Load image and mask file paths dataset_images = glob(os.path.join(images_dir, "*.*")) dataset_masks = glob(os.path.join(masks_dir, "*.*")) dataset_images = sorted(dataset_images) dataset_masks = sorted(dataset_masks) print(f"✓ Found {len(dataset_images)} images and {len(dataset_masks)} masks") dataset_loaded = True return True except Exception as e: print(f"Failed to load dataset: {e}") return False def get_random_sample(): """Get a random image and corresponding mask from dataset""" if not dataset_loaded: return None, None, "Dataset not loaded" if not dataset_images: return None, None, "No images found in dataset" # Get random index idx = random.randint(0, len(dataset_images) - 1) img_path = dataset_images[idx] # Find corresponding mask img_name = os.path.basename(img_path) mask_path = None for mask in dataset_masks: if os.path.basename(mask) == img_name: mask_path = mask break try: image = Image.open(img_path).convert("L") mask = Image.open(mask_path).convert("L") if mask_path else None return image, mask, img_name except Exception as e: return None, None, f"Error loading sample: {e}" def preprocess_for_model(image): """Preprocessing for your model - matches the working notebook""" if image.mode != 'L': image = image.convert('L') transform = transforms.Compose([ transforms.Resize((256,256)), transforms.ToTensor() ]) return transform(image).unsqueeze(0) def generate_attention_heatmap(attention_maps): """Generate attention heatmap""" if not attention_maps: return np.zeros((256, 256, 3)) # Resize all attention maps to the same size (256x256) before combining resized_maps = [] target_size = (256, 256) for att_map in attention_maps: # Convert to numpy and squeeze att_np = att_map.squeeze().cpu().numpy() # Resize to target size att_resized = cv2.resize(att_np, target_size) resized_maps.append(att_resized) # Now we can safely average the maps since they're all the same size combined_att = np.mean(resized_maps, axis=0) # Normalize to [0, 1] combined_att = (combined_att - combined_att.min()) / (combined_att.max() - combined_att.min() + 1e-8) # Apply colormap heatmap = cv2.applyColorMap((combined_att * 255).astype(np.uint8), cv2.COLORMAP_JET) return heatmap def analyze_image(image, ground_truth, filename): """Main analysis function - FIXED VERSION matching the working notebook""" if model is None: return None, "Model not loaded. Please restart the application." if image is None: return None, "Please select an image first." try: print("="*50) print("DEBUG: Starting analysis...") print(f"Input image mode: {image.mode}") print(f"Input image size: {image.size}") # Preprocess - exactly like the working notebook input_tensor = preprocess_for_model(image).to(device) print(f"Input tensor shape: {input_tensor.shape}") print(f"Input tensor min/max: {input_tensor.min():.4f}/{input_tensor.max():.4f}") # Get prediction and attention maps with torch.no_grad(): print("Getting model output...") model_output, attention_maps = model(input_tensor) print(f"Model output shape: {model_output.shape}") print(f"Model output min/max BEFORE sigmoid: {model_output.min():.4f}/{model_output.max():.4f}") # Apply sigmoid and threshold - EXACTLY like the working notebook pred_mask = torch.sigmoid(model_output) print(f"After sigmoid min/max: {pred_mask.min():.4f}/{pred_mask.max():.4f}") # Apply threshold to get binary mask binary_mask = (pred_mask > 0.5).float() print(f"Binary mask sum (number of 1s): {binary_mask.sum()}") # Convert to numpy - following notebook approach pred_mask_np = binary_mask.cpu().squeeze().numpy() print(f"Numpy binary mask shape: {pred_mask_np.shape}") print(f"Numpy binary mask unique values: {np.unique(pred_mask_np)}") print(f"Numpy binary mask sum: {np.sum(pred_mask_np)}") # Create visualization mask like in the notebook # The notebook uses: inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255) # This inverts the mask for better visualization inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255) # Generate attention heatmap print("Generating attention heatmap...") att_heatmap = generate_attention_heatmap(attention_maps) print(f"Attention heatmap shape: {att_heatmap.shape}") # Prepare original image array original_np = np.array(image.resize((256, 256))) # Create tumor-only image (like in notebook) tumor_only = np.where(pred_mask_np == 1, original_np, 255) # Create visualization if ground_truth is not None: fig, axes = plt.subplots(2, 4, figsize=(16, 8)) else: fig, axes = plt.subplots(2, 3, figsize=(15, 8)) fig.suptitle('Brain Tumor Segmentation Analysis', fontsize=16, weight='bold') # Row 1: Original, Attention, Predicted Mask, Tumor Only axes[0,0].imshow(original_np, cmap='gray') axes[0,0].set_title('Original Image', fontsize=12, weight='bold') axes[0,0].axis('off') # Attention heatmap overlay axes[0,1].imshow(original_np, cmap='gray') axes[0,1].imshow(att_heatmap, alpha=0.4) axes[0,1].set_title('Attention Heatmap', fontsize=12, weight='bold') axes[0,1].axis('off') # Predicted mask (inverted for visualization) axes[0,2].imshow(inv_pred_mask_np, cmap='gray') axes[0,2].set_title('Predicted Mask', fontsize=12, weight='bold') axes[0,2].axis('off') if ground_truth is not None: # Ground truth processing - convert to binary like notebook gt_array = np.array(ground_truth.resize((256, 256))) # Apply same preprocessing as notebook val_test_transform = transforms.Compose([ transforms.Resize((256,256)), transforms.ToTensor() ]) mask_np = val_test_transform(ground_truth).cpu().squeeze().numpy() print(f"Ground truth array shape: {gt_array.shape}") print(f"Ground truth unique values: {np.unique(gt_array)}") # Tumor only image axes[0,3].imshow(tumor_only, cmap='gray') axes[0,3].set_title('Tumor Only', fontsize=12, weight='bold') axes[0,3].axis('off') # Row 2: Ground truth, overlay comparison, metrics axes[1,0].imshow(mask_np, cmap='gray') axes[1,0].set_title('Ground Truth Mask', fontsize=12, weight='bold') axes[1,0].axis('off') # Overlay comparison - following notebook style overlay = np.array(image.convert('RGB').resize((256, 256))) overlay[pred_mask_np == 1] = [0, 255, 0] # Green for prediction overlay[mask_np > 0.5] = [255, 0, 0] # Red for ground truth axes[1,1].imshow(overlay) axes[1,1].set_title('Prediction (Green) vs GT (Red)', fontsize=12, weight='bold') axes[1,1].axis('off') # Calculate IoU and Dice exactly like notebook intersection = np.logical_and(pred_mask_np, mask_np).sum() union = np.logical_or(pred_mask_np, mask_np).sum() iou = intersection / (union + 1e-7) # Dice score dice = (2 * intersection) / (pred_mask_np.sum() + mask_np.sum() + 1e-7) print(f"Final IoU: {iou:.4f}") print(f"Final Dice: {dice:.4f}") print(f"Intersection: {intersection}") print(f"Union: {union}") print(f"Pred pixels: {np.sum(pred_mask_np)}") print(f"GT pixels: {np.sum(mask_np > 0.5)}") axes[1,2].text(0.1, 0.6, f'IoU: {iou:.4f}', fontsize=16, weight='bold') axes[1,2].text(0.1, 0.4, f'Dice: {dice:.4f}', fontsize=16, weight='bold') axes[1,2].set_xlim(0, 1) axes[1,2].set_ylim(0, 1) axes[1,2].axis('off') axes[1,2].set_title('Metrics', fontsize=12, weight='bold') # Additional tumor statistics axes[1,3].imshow(tumor_only, cmap='gray') axes[1,3].set_title('Segmented Tumor', fontsize=12, weight='bold') axes[1,3].axis('off') else: # No ground truth case axes[1,0].imshow(inv_pred_mask_np, cmap='gray') axes[1,0].set_title('Predicted Mask', fontsize=12, weight='bold') axes[1,0].axis('off') # Tumor only axes[1,1].imshow(tumor_only, cmap='gray') axes[1,1].set_title('Tumor Only', fontsize=12, weight='bold') axes[1,1].axis('off') # Overlay overlay = np.array(image.convert('RGB').resize((256, 256))) overlay[pred_mask_np == 1] = [255, 0, 0] axes[1,2].imshow(overlay) axes[1,2].set_title('Prediction Overlay', fontsize=12, weight='bold') axes[1,2].axis('off') plt.tight_layout() # Save plot buf = io.BytesIO() plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white') buf.seek(0) plt.close() result_image = Image.open(buf) # Generate analysis text tumor_pixels = np.sum(pred_mask_np) total_pixels = pred_mask_np.size tumor_percentage = (tumor_pixels / total_pixels) * 100 print(f"Final tumor pixels: {tumor_pixels}") print(f"Final tumor percentage: {tumor_percentage:.2f}%") print("="*50) analysis_text = f""" # Analysis Results **File:** {filename if filename else 'Uploaded Image'} **Tumor Detection:** - Tumor Area: {tumor_percentage:.2f}% - Tumor Pixels: {tumor_pixels:,} **Model Features:** - Attention Visualization: Generated - Post-processing: Applied """ if ground_truth is not None: analysis_text += f""" **Performance Metrics:** - IoU Score: {iou:.4f} - Dice Score: {dice:.4f} """ return result_image, analysis_text except Exception as e: import traceback error_msg = f"Analysis failed: {str(e)}\n\nTraceback:\n{traceback.format_exc()}" print(error_msg) # For debugging return None, error_msg # Initialize model and dataset at startup print("Initializing application components...") model_loaded = download_and_load_model() dataset_loaded_success = download_and_load_dataset() if not model_loaded: print("WARNING: Model failed to load!") if not dataset_loaded_success: print("WARNING: Dataset failed to load!") print("Application ready!") # Professional CSS css = """ .gradio-container { max-width: 1600px !important; margin: auto !important; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important; } .gr-button { border-radius: 6px !important; font-weight: 500 !important; } .gr-button-primary { background: #2563eb !important; border-color: #2563eb !important; } .gr-button-secondary { background: #6b7280 !important; border-color: #6b7280 !important; } h1, h2, h3 { color: #1f2937 !important; } .gr-form { border: 1px solid #e5e7eb !important; border-radius: 8px !important; } """ # Create Gradio interface with gr.Blocks(css=css, title="Brain Tumor Segmentation Analysis") as app: gr.Markdown(""" # Brain Tumor Segmentation Using Attention U-Net **Advanced Medical Image Analysis Tool** Features: Attention Visualization, Dataset Integration, Morphological Post-processing """) # Status display with gr.Row(): with gr.Column(): status_text = f"Model Status: {'✓ Loaded' if model_loaded else '✗ Failed'} | Dataset Status: {'✓ Loaded' if dataset_loaded_success else '✗ Failed'}" if dataset_loaded_success: status_text += f" | Images: {len(dataset_images)} | Masks: {len(dataset_masks)}" gr.Markdown(f"**{status_text}**") with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Input Selection") # Image display image_display = gr.Image( label="Selected Image", type="pil", height=300 ) # Control buttons with gr.Row(): load_sample_btn = gr.Button("Load Random Sample", variant="primary", scale=1) upload_btn = gr.UploadButton("Upload Image", file_types=["image"], scale=1) analyze_btn = gr.Button("Analyze Image", variant="primary", size="lg") # Dataset info gr.Markdown(f""" **Dataset Information:** - Total Images: {len(dataset_images) if dataset_loaded_success else 'N/A'} - Total Masks: {len(dataset_masks) if dataset_loaded_success else 'N/A'} - Source: nikhilroxtomar/brain-tumor-segmentation """) with gr.Column(scale=2): gr.Markdown("### Analysis Results") result_display = gr.Image( label="Segmentation Analysis", type="pil", height=500 ) analysis_text = gr.Markdown( value="Load an image and click 'Analyze Image' to begin." ) # Hidden states current_ground_truth = gr.State() current_filename = gr.State() # Event handlers def handle_sample_load(): image, mask, filename = get_random_sample() return image, mask, filename def handle_upload(file): if file is not None: image = Image.open(file.name).convert("L") return image, None, os.path.basename(file.name) return None, None, "" load_sample_btn.click( fn=handle_sample_load, outputs=[image_display, current_ground_truth, current_filename] ) upload_btn.upload( fn=handle_upload, inputs=[upload_btn], outputs=[image_display, current_ground_truth, current_filename] ) analyze_btn.click( fn=analyze_image, inputs=[image_display, current_ground_truth, current_filename], outputs=[result_display, analysis_text] ) if __name__ == "__main__": print("\n" + "="*50) print("LAUNCHING BRAIN TUMOR SEGMENTATION APPLICATION") print("="*50) app.launch( server_name="0.0.0.0", server_port=7860, show_error=True, share=False )