Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import io | |
| import torchvision.transforms as transforms | |
| import torchvision.transforms.functional as TF | |
| import random | |
| import os | |
| import urllib.request | |
| import kagglehub | |
| from glob import glob | |
# Module-level state, populated once at startup and read by the callbacks
# (get_random_sample reads the dataset lists, analyze_image reads model).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = None              # set to an AttentionUNET by download_and_load_model()
dataset_images = []       # sorted image file paths
dataset_masks = []        # sorted mask file paths
dataset_loaded = False    # set True by download_and_load_dataset()

print("=" * 50, "BRAIN TUMOR SEGMENTATION APPLICATION", "=" * 50, sep="\n")
# Attention U-Net building blocks. Submodule attribute names and their order
# determine the state_dict keys that load_state_dict() matches against the
# saved checkpoint, so they must stay stable.
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    Stride 1 with padding 1 keeps the spatial size unchanged.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # A single Sequential -> parameter keys conv.0, conv.1, conv.3, conv.4
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
class AttentionBlock(nn.Module):
    """Additive attention gate applied to a U-Net skip connection.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels (1x1 conv + BatchNorm), summed elementwise (so they are
    expected to share spatial size), passed through ReLU and reduced to a
    single-channel sigmoid map.

    Args:
        F_g: channel count of the gating signal ``g``.
        F_l: channel count of the skip features ``x``.
        F_int: intermediate channel count.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(channels_in, channels_out):
            # 1x1 projection; keys <name>.0 (conv) and <name>.1 (batch norm)
            return nn.Sequential(
                nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(channels_out),
            )

        self.W_g = project(F_g, F_int)
        self.W_x = project(F_l, F_int)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``(x * gate, gate)``: the gated skip features and the gate map."""
        combined = self.relu(self.W_g(g) + self.W_x(x))
        gate = self.psi(combined)
        return x * gate, gate
class AttentionUNET(nn.Module):
    """U-Net with an additive attention gate on every skip connection.

    ``forward`` returns ``(logits, attention_maps)``. ``attention_maps`` holds
    one single-channel gate map per decoder stage, ordered from the deepest
    (coarsest) stage to the shallowest.

    Args:
        in_channels: channel count of the input image.
        out_channels: channel count of the output logits.
        features: encoder widths, shallowest first; the bottleneck uses
            ``2 * features[-1]`` channels. Any sequence is accepted. The
            default is a tuple rather than a list, so it cannot be mutated and
            shared across instances.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(32, 64, 128, 256)):
        super().__init__()
        features = list(features)
        self.out_channels = out_channels
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.attentions = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # ``ups`` interleaves [ConvTranspose2d, DoubleConv] per decoder stage.
        # This layout and the attribute names define the state_dict keys the
        # saved checkpoint is loaded against, so they must not be reorganised.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        attention_maps = []
        stages = zip(self.ups[::2], self.ups[1::2], self.attentions, reversed(skip_connections))
        for upsample, fuse, attention, skip in stages:
            x = upsample(x)
            # 2x2/stride-2 pooling floors odd sizes, so the upsampled map can
            # be smaller than its skip; align the spatial dims only.
            if x.shape[2:] != skip.shape[2:]:
                x = TF.resize(x, size=skip.shape[2:])
            attended_skip, att_map = attention(x, skip)
            attention_maps.append(att_map)
            x = fuse(torch.cat((attended_skip, x), dim=1))

        return self.final_conv(x), attention_maps
def download_and_load_model():
    """Download (if needed) and load the Attention U-Net weights once at startup.

    On success, the module-level ``model`` is set to an eval-mode
    ``AttentionUNET`` on ``device`` and True is returned. On any failure,
    False is returned and ``model`` is left unchanged. Callers that test
    ``model is None`` therefore never run inference with randomly-initialised
    weights.
    """
    global model
    print("Loading Attention U-Net model...")
    model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
    model_path = "best_attention_model.pth.tar"

    if not os.path.exists(model_path):
        print("Downloading model weights...")
        # Download under a temporary name and rename only when complete. An
        # interrupted transfer then cannot leave a truncated file that later
        # runs would mistake for an existing download.
        partial_path = model_path + ".part"
        try:
            urllib.request.urlretrieve(model_url, partial_path)
            os.replace(partial_path, model_path)
        except Exception as e:
            print(f"Failed to download model: {e}")
            try:
                os.remove(partial_path)
            except OSError:
                pass  # best-effort cleanup; nothing to remove is fine
            return False

    try:
        net = AttentionUNET(in_channels=1, out_channels=1).to(device)
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        net.load_state_dict(checkpoint["state_dict"])
        net.eval()
    except Exception as e:
        print(f"Failed to load model: {e}")
        return False

    # Publish the network only after its weights loaded successfully.
    model = net
    print("✓ Model loaded successfully!")
    return True
def _split_images_and_masks(paths, root):
    """Classify *paths* into ``(images, masks)`` using their path below *root*.

    Only the part of each path relative to *root* is inspected, so a dataset
    root whose own name contains "image" or "mask" does not pull every file
    into one list. A path mentioning "mask" is a mask. Otherwise, a path
    mentioning "image" is an image. Each file lands in at most one list.
    """
    images, masks = [], []
    for path in paths:
        rel = os.path.relpath(path, root).lower()
        if 'mask' in rel:
            masks.append(path)
        elif 'image' in rel:
            images.append(path)
    return images, masks


def download_and_load_dataset():
    """Download the Kaggle brain-tumour dataset and index its files once.

    Fills the module-level ``dataset_images`` / ``dataset_masks`` with sorted
    file paths and sets ``dataset_loaded``.

    Returns:
        True on success, and also when already loaded. False if the download
        or indexing raised.
    """
    global dataset_images, dataset_masks, dataset_loaded
    if dataset_loaded:
        return True
    print("Loading brain tumor dataset...")
    try:
        # kagglehub returns the local directory holding the downloaded dataset.
        dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
        print(f"Dataset downloaded to: {dataset_path}")

        images_dir = os.path.join(dataset_path, 'images')
        masks_dir = os.path.join(dataset_path, 'masks')
        if not os.path.exists(images_dir):
            # The folders may be nested; the last match found by the walk wins.
            for root, dirs, _files in os.walk(dataset_path):
                if 'images' in dirs:
                    images_dir = os.path.join(root, 'images')
                if 'masks' in dirs:
                    masks_dir = os.path.join(root, 'masks')

        if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
            print("Could not find images/masks directories. Searching all files...")
            all_files = (glob(os.path.join(dataset_path, "**/*.png"), recursive=True)
                         + glob(os.path.join(dataset_path, "**/*.jpg"), recursive=True))
            images, masks = _split_images_and_masks(all_files, dataset_path)
        else:
            images = glob(os.path.join(images_dir, "*.*"))
            masks = glob(os.path.join(masks_dir, "*.*"))

        dataset_images = sorted(images)
        dataset_masks = sorted(masks)
        print(f"✓ Found {len(dataset_images)} images and {len(dataset_masks)} masks")
        dataset_loaded = True
        return True
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        return False
def get_random_sample():
    """Pick a random dataset image together with its same-named mask.

    Returns:
        ``(image, mask, name)`` on success. ``image`` and ``mask`` are
        grayscale PIL images, and ``mask`` is None when no mask file shares
        the image's basename.
        ``(None, None, message)`` when the dataset is unavailable, empty, or a
        file cannot be opened.
    """
    if not dataset_loaded:
        return None, None, "Dataset not loaded"
    if not dataset_images:
        return None, None, "No images found in dataset"

    chosen = dataset_images[random.randint(0, len(dataset_images) - 1)]
    name = os.path.basename(chosen)
    # First mask whose basename equals the image's, else None.
    match = next((m for m in dataset_masks if os.path.basename(m) == name), None)

    try:
        gray_image = Image.open(chosen).convert("L")
        gray_mask = Image.open(match).convert("L") if match else None
    except Exception as exc:
        return None, None, f"Error loading sample: {exc}"
    return gray_image, gray_mask, name
def preprocess_for_model(image):
    """Turn a PIL image into the model's input batch.

    Converts to single-channel grayscale when needed, resizes to 256x256, and
    applies ``ToTensor``, which scales 8-bit pixels to [0, 1]. A batch axis is
    then prepended, giving shape (1, 1, 256, 256).
    """
    grayscale = image if image.mode == 'L' else image.convert('L')
    to_model_input = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    single = to_model_input(grayscale)
    return single.unsqueeze(0)
def generate_attention_heatmap(attention_maps):
    """Fuse the decoder attention maps into a single RGB heatmap for display.

    Each map is squeezed, resized to 256x256 and the maps are averaged. The
    average is min-max normalised to [0, 1] and coloured with the JET
    colormap.

    Args:
        attention_maps: sequence of attention tensors from the model's forward
            pass. Assumes batch size 1, so ``squeeze`` yields a 2-D map (TODO:
            confirm for other callers).

    Returns:
        uint8 array of shape (256, 256, 3) in RGB channel order, ready for
        matplotlib ``imshow``. All zeros when ``attention_maps`` is empty.
    """
    target_size = (256, 256)  # cv2.resize takes (width, height); square here
    if not attention_maps:
        # Same dtype/shape as the normal path so callers see one contract.
        return np.zeros((*target_size, 3), dtype=np.uint8)

    resized_maps = [
        cv2.resize(att_map.squeeze().cpu().numpy(), target_size)
        for att_map in attention_maps
    ]
    combined_att = np.mean(resized_maps, axis=0)

    lo, hi = combined_att.min(), combined_att.max()
    normalised = (combined_att - lo) / (hi - lo + 1e-8)

    heatmap_bgr = cv2.applyColorMap((normalised * 255).astype(np.uint8), cv2.COLORMAP_JET)
    # OpenCV colormaps are BGR, while matplotlib interprets arrays as RGB.
    # Without this conversion, high attention would render blue instead of red.
    return cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
def _predict_mask(image):
    """Run the global ``model`` on *image* and threshold its output.

    Returns:
        ``(pred_mask_np, attention_maps)``:
        - ``pred_mask_np`` is a float numpy array of 0.0/1.0 values at the
          resolution produced by ``preprocess_for_model``. It is the sigmoid
          output thresholded at 0.5.
        - ``attention_maps`` is the list the model returns alongside its
          logits.
    """
    input_tensor = preprocess_for_model(image).to(device)
    print(f"Input tensor shape: {input_tensor.shape}")
    print(f"Input tensor min/max: {input_tensor.min():.4f}/{input_tensor.max():.4f}")
    with torch.no_grad():
        print("Getting model output...")
        model_output, attention_maps = model(input_tensor)
        print(f"Model output shape: {model_output.shape}")
        print(f"Model output min/max BEFORE sigmoid: {model_output.min():.4f}/{model_output.max():.4f}")
        pred_mask = torch.sigmoid(model_output)
        print(f"After sigmoid min/max: {pred_mask.min():.4f}/{pred_mask.max():.4f}")
        binary_mask = (pred_mask > 0.5).float()
        print(f"Binary mask sum (number of 1s): {binary_mask.sum()}")
    pred_mask_np = binary_mask.cpu().squeeze().numpy()
    print(f"Numpy binary mask shape: {pred_mask_np.shape}")
    print(f"Numpy binary mask unique values: {np.unique(pred_mask_np)}")
    print(f"Numpy binary mask sum: {np.sum(pred_mask_np)}")
    return pred_mask_np, attention_maps


def _overlap_metrics(pred, target):
    """Return ``(iou, dice, intersection, union)`` for two masks.

    Both inputs are read as boolean (non-zero = foreground). The small epsilon
    keeps both ratios defined when the masks are empty.
    """
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    intersection = int(np.logical_and(pred, target).sum())
    union = int(np.logical_or(pred, target).sum())
    iou = intersection / (union + 1e-7)
    dice = (2 * intersection) / (int(pred.sum()) + int(target.sum()) + 1e-7)
    return iou, dice, intersection, union


def _show(ax, img, title, **imshow_kwargs):
    """Draw *img* on *ax* with a bold title and the axes hidden."""
    ax.imshow(img, **imshow_kwargs)
    ax.set_title(title, fontsize=12, weight='bold')
    ax.axis('off')


def analyze_image(image, ground_truth, filename):
    """Segment *image* and return a summary figure plus a Markdown report.

    Args:
        image: PIL image to analyse, in any mode (converted to grayscale).
        ground_truth: optional PIL mask for the same image, or None. When
            given, it is resized and scaled like the model input, binarised at
            0.5, and used for the overlay and the IoU/Dice metrics.
        filename: name shown in the report. Falsy values show
            'Uploaded Image'.

    Returns:
        ``(figure_image, markdown)``. On failure, ``figure_image`` is None and
        ``markdown`` holds the error message (with a traceback for
        exceptions).
    """
    if model is None:
        return None, "Model not loaded. Please restart the application."
    if image is None:
        return None, "Please select an image first."
    try:
        print("=" * 50)
        print("DEBUG: Starting analysis...")
        print(f"Input image mode: {image.mode}")
        print(f"Input image size: {image.size}")

        pred_mask_np, attention_maps = _predict_mask(image)
        pred_fg = pred_mask_np == 1

        print("Generating attention heatmap...")
        att_heatmap = generate_attention_heatmap(attention_maps)
        print(f"Attention heatmap shape: {att_heatmap.shape}")

        # Grayscale, so the (H, W) masks below broadcast against it even when
        # the user supplied an RGB/RGBA image.
        original_np = np.array(image.convert('L').resize((256, 256)))
        # Inverted mask for display: tumour black on a white background.
        inv_pred_mask_np = np.where(pred_fg, 0, 255)
        # Original intensities inside the predicted tumour, white elsewhere.
        tumor_only = np.where(pred_fg, original_np, 255)
        overlay = np.array(image.convert('RGB').resize((256, 256)))

        has_gt = ground_truth is not None
        if has_gt:
            # Same resize + [0, 1] scaling as the model input. Binarise once so
            # the overlay, metrics and pixel counts all use the same mask.
            mask_np = preprocess_for_model(ground_truth).squeeze().cpu().numpy()
            gt_fg = mask_np > 0.5
            print(f"Ground truth array shape: {mask_np.shape}")
            print(f"Ground truth unique values: {np.unique(mask_np)}")
            iou, dice, intersection, union = _overlap_metrics(pred_fg, gt_fg)
            print(f"Final IoU: {iou:.4f}")
            print(f"Final Dice: {dice:.4f}")
            print(f"Intersection: {intersection}")
            print(f"Union: {union}")
            print(f"Pred pixels: {int(pred_fg.sum())}")
            print(f"GT pixels: {int(gt_fg.sum())}")

        if has_gt:
            fig, axes = plt.subplots(2, 4, figsize=(16, 8))
        else:
            fig, axes = plt.subplots(2, 3, figsize=(15, 8))
        try:
            fig.suptitle('Brain Tumor Segmentation Analysis', fontsize=16, weight='bold')

            # Row 1: original, attention overlay, predicted mask (+ tumour only)
            _show(axes[0, 0], original_np, 'Original Image', cmap='gray')
            _show(axes[0, 1], original_np, 'Attention Heatmap', cmap='gray')
            axes[0, 1].imshow(att_heatmap, alpha=0.4)
            _show(axes[0, 2], inv_pred_mask_np, 'Predicted Mask', cmap='gray')

            if has_gt:
                _show(axes[0, 3], tumor_only, 'Tumor Only', cmap='gray')

                # Row 2: ground truth, comparison overlay, metrics, tumour
                _show(axes[1, 0], mask_np, 'Ground Truth Mask', cmap='gray')
                overlay[pred_fg] = [0, 255, 0]  # green: prediction
                overlay[gt_fg] = [255, 0, 0]    # red: ground truth (drawn on top)
                _show(axes[1, 1], overlay, 'Prediction (Green) vs GT (Red)')

                metrics_ax = axes[1, 2]
                metrics_ax.text(0.1, 0.6, f'IoU: {iou:.4f}', fontsize=16, weight='bold')
                metrics_ax.text(0.1, 0.4, f'Dice: {dice:.4f}', fontsize=16, weight='bold')
                metrics_ax.set_xlim(0, 1)
                metrics_ax.set_ylim(0, 1)
                metrics_ax.axis('off')
                metrics_ax.set_title('Metrics', fontsize=12, weight='bold')

                _show(axes[1, 3], tumor_only, 'Segmented Tumor', cmap='gray')
            else:
                _show(axes[1, 0], inv_pred_mask_np, 'Predicted Mask', cmap='gray')
                _show(axes[1, 1], tumor_only, 'Tumor Only', cmap='gray')
                overlay[pred_fg] = [255, 0, 0]
                _show(axes[1, 2], overlay, 'Prediction Overlay')

            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
            buf.seek(0)
            result_image = Image.open(buf)
        finally:
            # Always release the figure, including when drawing fails.
            # Otherwise pyplot keeps it registered and alive.
            plt.close(fig)

        tumor_pixels = int(pred_fg.sum())
        total_pixels = pred_fg.size
        tumor_percentage = (tumor_pixels / total_pixels) * 100
        print(f"Final tumor pixels: {tumor_pixels}")
        print(f"Final tumor percentage: {tumor_percentage:.2f}%")
        print("=" * 50)

        report = [
            "# Analysis Results",
            "",
            f"**File:** {filename if filename else 'Uploaded Image'}",
            "",
            "**Tumor Detection:**",
            f"- Tumor Area: {tumor_percentage:.2f}%",
            f"- Tumor Pixels: {tumor_pixels:,}",
            "",
            "**Model Features:**",
            "- Attention Visualization: Generated",
            # No morphological post-processing is performed; only thresholding.
            "- Post-processing: None (sigmoid output thresholded at 0.5)",
        ]
        if has_gt:
            report += [
                "",
                "**Performance Metrics:**",
                f"- IoU Score: {iou:.4f}",
                f"- Dice Score: {dice:.4f}",
            ]
        return result_image, "\n".join(report) + "\n"
    except Exception as e:
        import traceback
        error_msg = f"Analysis failed: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)  # For debugging
        return None, error_msg
# Initialize model and dataset at startup. This runs at import time, before the
# Gradio UI below is built; the UI reads model_loaded and dataset_loaded_success
# to render its status line.
print("Initializing application components...")
model_loaded = download_and_load_model()  # True only if the weights loaded
dataset_loaded_success = download_and_load_dataset()  # True once the dataset is indexed
# Failures are reported but not fatal: the app still starts, and the callbacks
# check the module-level state themselves (model is None / dataset_loaded).
if not model_loaded:
    print("WARNING: Model failed to load!")
if not dataset_loaded_success:
    print("WARNING: Dataset failed to load!")
print("Application ready!")
# Custom stylesheet passed to gr.Blocks(css=...). It widens and centres the
# page container and restyles buttons, headings and form panels.
# NOTE(review): the .gr-button / .gr-button-primary / .gr-form selectors depend
# on Gradio's internal class names; confirm they still match the installed
# Gradio release.
css = """
.gradio-container {
max-width: 1600px !important;
margin: auto !important;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
}
.gr-button {
border-radius: 6px !important;
font-weight: 500 !important;
}
.gr-button-primary {
background: #2563eb !important;
border-color: #2563eb !important;
}
.gr-button-secondary {
background: #6b7280 !important;
border-color: #6b7280 !important;
}
h1, h2, h3 {
color: #1f2937 !important;
}
.gr-form {
border: 1px solid #e5e7eb !important;
border-radius: 8px !important;
}
"""
# Create Gradio interface
with gr.Blocks(css=css, title="Brain Tumor Segmentation Analysis") as app:
    # Header text lists only features this app actually implements.
    gr.Markdown("""
# Brain Tumor Segmentation Using Attention U-Net
**Advanced Medical Image Analysis Tool**
Features: Attention Visualization, Dataset Integration
""")
    # Status display
    with gr.Row():
        with gr.Column():
            status_text = f"Model Status: {'✓ Loaded' if model_loaded else '✗ Failed'} | Dataset Status: {'✓ Loaded' if dataset_loaded_success else '✗ Failed'}"
            if dataset_loaded_success:
                status_text += f" | Images: {len(dataset_images)} | Masks: {len(dataset_masks)}"
            gr.Markdown(f"**{status_text}**")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input Selection")
            # Image display
            image_display = gr.Image(
                label="Selected Image",
                type="pil",
                height=300
            )
            # Control buttons
            with gr.Row():
                load_sample_btn = gr.Button("Load Random Sample", variant="primary", scale=1)
                upload_btn = gr.UploadButton("Upload Image", file_types=["image"], scale=1)
            analyze_btn = gr.Button("Analyze Image", variant="primary", size="lg")
            # Dataset info
            gr.Markdown(f"""
**Dataset Information:**
- Total Images: {len(dataset_images) if dataset_loaded_success else 'N/A'}
- Total Masks: {len(dataset_masks) if dataset_loaded_success else 'N/A'}
- Source: nikhilroxtomar/brain-tumor-segmentation
""")
        with gr.Column(scale=2):
            gr.Markdown("### Analysis Results")
            result_display = gr.Image(
                label="Segmentation Analysis",
                type="pil",
                height=500
            )
            analysis_text = gr.Markdown(
                value="Load an image and click 'Analyze Image' to begin."
            )
    # Hidden per-session state: the sample's ground-truth mask and filename.
    current_ground_truth = gr.State()
    current_filename = gr.State()

    # Event handlers
    def handle_sample_load():
        """Load a random dataset image; returns (image, mask, filename/message)."""
        image, mask, filename = get_random_sample()
        return image, mask, filename

    def handle_upload(file):
        """Open an uploaded image as grayscale and clear any ground truth."""
        if file is None:
            return None, None, ""
        # Depending on the Gradio version and UploadButton ``type``, the
        # payload is a filepath string or a tempfile-like object with ``.name``.
        path = file if isinstance(file, (str, os.PathLike)) else file.name
        image = Image.open(path).convert("L")
        return image, None, os.path.basename(path)

    load_sample_btn.click(
        fn=handle_sample_load,
        outputs=[image_display, current_ground_truth, current_filename]
    )
    upload_btn.upload(
        fn=handle_upload,
        inputs=[upload_btn],
        outputs=[image_display, current_ground_truth, current_filename]
    )
    analyze_btn.click(
        fn=analyze_image,
        inputs=[image_display, current_ground_truth, current_filename],
        outputs=[result_display, analysis_text]
    )
if __name__ == "__main__":
    # One print call emitting the same three banner lines.
    print("\n" + "=" * 50, "LAUNCHING BRAIN TUMOR SEGMENTATION APPLICATION", "=" * 50, sep="\n")
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "show_error": True,        # report callback errors in the browser UI
        "share": False,            # no public Gradio share link
    }
    app.launch(**launch_options)