Hugging Face Space metadata — status: Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import io | |
| from torchvision import transforms | |
| import torchvision.transforms.functional as TF | |
| import urllib.request | |
| import os | |
| import kagglehub | |
| import random | |
| from pathlib import Path | |
| import seaborn as sns | |
# Pick the compute device once at import time; prefer GPU, fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Lazily-populated module globals: the cached segmentation model and the
# local path of the downloaded dataset (both filled on first use).
model = None
dataset_path = None
| # Define your Attention U-Net architecture (from your training code) | |
class DoubleConv(nn.Module):
    """Two stacked 3x3 Conv -> BatchNorm -> ReLU stages (basic U-Net unit).

    Padding of 1 with stride 1 keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # bias=False on the convs because BatchNorm supplies the shift.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to x."""
        return self.conv(x)
class AttentionBlock(nn.Module):
    """Additive soft-attention gate for U-Net skip connections.

    The gating signal g (decoder) and the skip features x (encoder) are each
    projected to F_int channels by 1x1 convs, summed, passed through ReLU and
    a 1x1 conv + sigmoid, yielding a single-channel coefficient map in [0, 1]
    that rescales the skip features pixel-wise.
    """

    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        # 1x1 projection of the gating signal.
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        # 1x1 projection of the skip-connection features.
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        # Collapse to one attention channel squashed into [0, 1].
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return (gated skip features, attention coefficients)."""
        combined = self.relu(self.W_g(g) + self.W_x(x))
        coeff = self.psi(combined)
        # Coefficients are also returned so callers can visualize the gate.
        return x * coeff, coeff
class AttentionUNET(nn.Module):
    """U-Net with an attention gate on every skip connection.

    The decoder can optionally return the raw attention coefficients so that
    callers may visualize where each gate focuses.

    Args:
        in_channels: number of input image channels (1 for grayscale MRI).
        out_channels: number of output mask channels.
        features: encoder channel widths, shallow to deep.
    """

    # NOTE: the default is a tuple (immutable) rather than a list — a mutable
    # default argument would be shared across all instantiations. The value
    # is only iterated/indexed, so this is fully backward compatible.
    def __init__(self, in_channels=1, out_channels=1, features=(32, 64, 128, 256)):
        super(AttentionUNET, self).__init__()
        self.out_channels = out_channels
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.attentions = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per level; pooling happens in forward().
        channels = in_channels
        for feature in features:
            self.downs.append(DoubleConv(channels, feature))
            channels = feature

        # Bottleneck doubles the deepest feature width.
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder: per level, a transpose conv (upsample), an attention gate
        # on the skip connection, and a DoubleConv to fuse the concatenation.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x, return_attention=False):
        """Run the network.

        Returns the raw (pre-sigmoid) output map, or
        (output, [attention maps]) when return_attention is True.
        """
        skip_connections = []
        attention_maps = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        # Deepest skip first, matching decoder order.
        skip_connections = skip_connections[::-1]

        # self.ups alternates [transpose conv, DoubleConv] per decoder level.
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            # Guard against odd input sizes where pooling truncated a pixel.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
            skip_connection, attention_coeff = self.attentions[idx // 2](skip_connection, x)
            if return_attention:
                attention_maps.append(attention_coeff)
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        output = self.final_conv(x)
        if return_attention:
            return output, attention_maps
        return output
def download_dataset():
    """Fetch the Kaggle brain-tumor dataset, caching its path in a global.

    Returns the local dataset directory, or None when the download fails.
    """
    global dataset_path
    try:
        print("π₯ Downloading Brain Tumor Segmentation dataset...")
        dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
        print(f"β Dataset downloaded to: {dataset_path}")
    except Exception as e:
        # Best effort: the app still works without random samples.
        print(f"β Failed to download dataset: {e}")
        return None
    return dataset_path
def download_model():
    """Download the trained checkpoint if it is not already on disk.

    Returns:
        The local checkpoint path, or None when the download fails.

    A failed or interrupted download is cleaned up so that a partially
    written file is never mistaken for a valid checkpoint on the next run
    (the original code left the partial file behind and then skipped the
    download because the path existed).
    """
    model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
    model_path = "best_attention_model.pth.tar"
    if os.path.exists(model_path):
        print("β Model already exists!")
        return model_path
    print("π₯ Downloading trained model...")
    try:
        urllib.request.urlretrieve(model_url, model_path)
    except Exception as e:
        print(f"β Failed to download model: {e}")
        # Remove any partially written file so it is not reused later.
        if os.path.exists(model_path):
            try:
                os.remove(model_path)
            except OSError:
                pass
        return None
    print("β Model downloaded successfully!")
    return model_path
def load_attention_model():
    """Load the trained Attention U-Net once and cache it in global `model`.

    Returns the cached model, or None when download/loading fails.
    """
    global model
    if model is not None:
        return model
    try:
        print("π Loading Attention U-Net model...")
        model_path = download_model()
        if model_path is None:
            return None
        model = AttentionUNET(in_channels=1, out_channels=1).to(device)
        # weights_only=True avoids unpickling arbitrary objects from the file.
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint["state_dict"])
        model.eval()
        print("β Attention U-Net model loaded successfully!")
    except Exception as e:
        print(f"β Error loading model: {e}")
        model = None
    return model
def get_random_sample_from_dataset():
    """Pick a random (image, mask) pair from the downloaded dataset.

    Returns:
        (PIL image, PIL mask), both grayscale, or (None, None) when the
        dataset is unavailable or no matching mask exists.
    """
    global dataset_path
    if dataset_path is None:
        dataset_path = download_dataset()
        if dataset_path is None:
            return None, None
    try:
        images_path = Path(dataset_path) / "images"
        masks_path = Path(dataset_path) / "masks"
        if not images_path.exists() or not masks_path.exists():
            print("β Dataset structure not found")
            return None, None

        # Collect candidate images across the supported extensions.
        image_files = []
        for pattern in ("*.jpg", "*.png", "*.tif"):
            image_files.extend(images_path.glob(pattern))
        if not image_files:
            print("β No image files found in dataset")
            return None, None

        random_image_file = random.choice(image_files)
        image_name = random_image_file.stem

        # A mask may carry a different extension than its image; try each.
        mask_file = None
        for ext in ('.jpg', '.png', '.tif', '.gif'):
            candidate = masks_path / f"{image_name}{ext}"
            if candidate.exists():
                mask_file = candidate
                break
        if mask_file is None:
            print(f"β No corresponding mask found for {image_name}")
            return None, None

        image = Image.open(random_image_file).convert('L')
        mask = Image.open(mask_file).convert('L')
        print(f"β Loaded random sample: {image_name}")
        return image, mask
    except Exception as e:
        print(f"β Error loading random sample: {e}")
        return None, None
| def test_time_augmentation(model, image_tensor): | |
| """Apply Test-Time Augmentation (TTA) for robust predictions""" | |
| augmentations = [ | |
| lambda x: x, # Original | |
| lambda x: torch.flip(x, dims=[3]), # Horizontal flip | |
| lambda x: torch.flip(x, dims=[2]), # Vertical flip | |
| lambda x: torch.flip(x, dims=[2, 3]), # Both flips | |
| lambda x: torch.rot90(x, k=1, dims=[2, 3]), # 90Β° rotation | |
| lambda x: torch.rot90(x, k=3, dims=[2, 3]), # 270Β° rotation | |
| ] | |
| reverse_augmentations = [ | |
| lambda x: x, # Original | |
| lambda x: torch.flip(x, dims=[3]), # Reverse horizontal flip | |
| lambda x: torch.flip(x, dims=[2]), # Reverse vertical flip | |
| lambda x: torch.flip(x, dims=[2, 3]), # Reverse both flips | |
| lambda x: torch.rot90(x, k=3, dims=[2, 3]), # Reverse 90Β° rotation | |
| lambda x: torch.rot90(x, k=1, dims=[2, 3]), # Reverse 270Β° rotation | |
| ] | |
| predictions = [] | |
| with torch.no_grad(): | |
| for aug, rev_aug in zip(augmentations, reverse_augmentations): | |
| # Apply augmentation | |
| aug_input = aug(image_tensor) | |
| # Get prediction | |
| pred = torch.sigmoid(model(aug_input)) | |
| # Reverse augmentation on prediction | |
| pred = rev_aug(pred) | |
| predictions.append(pred) | |
| # Average all predictions | |
| tta_prediction = torch.mean(torch.stack(predictions), dim=0) | |
| return tta_prediction | |
def generate_attention_heatmaps(model, image_tensor):
    """Return each attention gate's coefficient map as a 256x256 numpy array.

    The model must accept `return_attention=True` and yield a list of
    per-gate coefficient tensors alongside its prediction.
    """
    with torch.no_grad():
        _, attention_maps = model(image_tensor, return_attention=True)

    heatmaps = []
    for att_map in attention_maps:
        # Upsample every gate's map to the display resolution.
        resized = TF.resize(att_map, (256, 256))
        heatmaps.append(resized.cpu().squeeze().numpy())
    return heatmaps
def preprocess_image(image):
    """Convert a PIL image into the 1x1x256x256 float tensor the model expects.

    Mirrors the validation/test transform used during training:
    grayscale -> resize to 256x256 -> ToTensor (scales to [0, 1]).
    """
    if image.mode != 'L':
        image = image.convert('L')
    pipeline = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    # unsqueeze adds the batch dimension.
    return pipeline(image).unsqueeze(0)
def calculate_metrics(pred_mask, ground_truth_mask):
    """Return (dice, iou) between two masks, each thresholded at 0.5.

    A small epsilon keeps both ratios defined when the masks are empty.
    """
    eps = 1e-8
    pred_binary = (pred_mask > 0.5).float()
    gt_binary = (ground_truth_mask > 0.5).float()

    intersection = torch.sum(pred_binary * gt_binary)
    pred_area = torch.sum(pred_binary)
    gt_area = torch.sum(gt_binary)

    # Dice: twice the overlap over the summed areas.
    dice = (2.0 * intersection) / (pred_area + gt_area + eps)
    # IoU: overlap over union (inclusion-exclusion).
    iou = intersection / (pred_area + gt_area - intersection + eps)
    return dice.item(), iou.item()
def predict_with_enhancements(image, ground_truth=None, use_tta=True, show_attention=True):
    """Run the full enhanced inference pipeline on one image.

    Args:
        image: input brain MRI as a PIL image (converted to grayscale).
        ground_truth: optional PIL mask; when present, Dice/IoU metrics and a
            prediction-vs-truth visualization are added.
        use_tta: average predictions over six test-time augmentations.
        show_attention: include per-gate attention heatmaps in the figure.

    Returns:
        (PIL image of the rendered matplotlib figure, markdown report string),
        or (None, error message) when the model/image is missing or a step fails.
    """
    current_model = load_attention_model()
    if current_model is None:
        return None, "β Failed to load trained model."
    if image is None:
        return None, "β οΈ Please upload an image first."
    try:
        print("π§ Processing with enhanced Attention U-Net...")
        input_tensor = preprocess_image(image).to(device)
        # Single-pass prediction (kept for side-by-side comparison with TTA).
        with torch.no_grad():
            standard_pred = torch.sigmoid(current_model(input_tensor))
        # Test-Time Augmentation ensemble, if enabled.
        if use_tta:
            tta_pred = test_time_augmentation(current_model, input_tensor)
            final_pred = tta_pred
        else:
            final_pred = standard_pred
        # Attention heatmaps for interpretability (one per attention gate).
        attention_heatmaps = []
        if show_attention:
            attention_heatmaps = generate_attention_heatmaps(current_model, input_tensor)
        # Threshold probabilities at 0.5 into binary masks for display/metrics.
        pred_mask_binary = (final_pred > 0.5).float()
        pred_mask_np = pred_mask_binary.cpu().squeeze().numpy()
        standard_mask_np = (standard_pred > 0.5).float().cpu().squeeze().numpy()
        # Original image at the model's working resolution for overlays.
        original_np = np.array(image.convert('L').resize((256, 256)))
        # Create comprehensive visualization
        if ground_truth is not None:
            # With ground truth comparison
            gt_np = np.array(ground_truth.convert('L').resize((256, 256)))
            gt_binary = (gt_np > 127).astype(np.float32)  # Threshold ground truth
            # Calculate metrics against the provided mask.
            gt_tensor = torch.tensor(gt_binary).unsqueeze(0).unsqueeze(0).to(device)
            dice_score, iou_score = calculate_metrics(final_pred, gt_tensor)
            # Create figure with ground truth comparison; an extra column is
            # added when attention heatmaps are available.
            n_cols = 6 if show_attention and attention_heatmaps else 5
            fig, axes = plt.subplots(2, n_cols, figsize=(4*n_cols, 8))
            fig.suptitle('π§ Enhanced Attention U-Net Analysis with Ground Truth Comparison', fontsize=16, weight='bold')
            # Top row - Standard analysis
            axes[0, 0].imshow(original_np, cmap='gray')
            axes[0, 0].set_title('Original Image', fontsize=12, weight='bold')
            axes[0, 0].axis('off')
            axes[0, 1].imshow(standard_mask_np * 255, cmap='hot')
            axes[0, 1].set_title('Standard Prediction', fontsize=12, weight='bold')
            axes[0, 1].axis('off')
            axes[0, 2].imshow(pred_mask_np * 255, cmap='hot')
            axes[0, 2].set_title(f'{"TTA Enhanced" if use_tta else "Final Prediction"}', fontsize=12, weight='bold')
            axes[0, 2].axis('off')
            axes[0, 3].imshow(gt_binary * 255, cmap='hot')
            axes[0, 3].set_title('Ground Truth', fontsize=12, weight='bold')
            axes[0, 3].axis('off')
            # Overlay comparison: prediction red, truth green, agreement yellow
            # (yellow is painted last so it wins where both masks are set).
            overlay = original_np.copy()
            overlay = np.stack([overlay, overlay, overlay], axis=-1)
            overlay[pred_mask_np > 0.5] = [255, 0, 0]  # Red for prediction
            overlay[gt_binary > 0.5] = [0, 255, 0]  # Green for ground truth
            overlap = (pred_mask_np > 0.5) & (gt_binary > 0.5)
            overlay[overlap] = [255, 255, 0]  # Yellow for overlap
            axes[0, 4].imshow(overlay.astype(np.uint8))
            axes[0, 4].set_title('Overlay (Red:Pred, Green:GT, Yellow:Match)', fontsize=10, weight='bold')
            axes[0, 4].axis('off')
            if show_attention and attention_heatmaps:
                # Show combined attention averaged across all gates.
                combined_attention = np.mean(attention_heatmaps, axis=0)
                axes[0, 5].imshow(combined_attention, cmap='jet', alpha=0.7)
                axes[0, 5].imshow(original_np, cmap='gray', alpha=0.3)
                axes[0, 5].set_title('Attention Heatmap', fontsize=12, weight='bold')
                axes[0, 5].axis('off')
            # Bottom row - Individual attention maps or detailed analysis
            if show_attention and attention_heatmaps:
                for i, heatmap in enumerate(attention_heatmaps[:n_cols]):
                    axes[1, i].imshow(heatmap, cmap='jet', alpha=0.7)
                    axes[1, i].imshow(original_np, cmap='gray', alpha=0.3)
                    axes[1, i].set_title(f'Attention Gate {i+1}', fontsize=10, weight='bold')
                    axes[1, i].axis('off')
            else:
                # Show tumor extraction and analysis
                tumor_only = np.where(pred_mask_np == 1, original_np, 255)
                inv_mask = np.where(pred_mask_np == 1, 0, 255)
                axes[1, 0].imshow(tumor_only, cmap='gray')
                axes[1, 0].set_title('Tumor Extraction', fontsize=12, weight='bold')
                axes[1, 0].axis('off')
                axes[1, 1].imshow(inv_mask, cmap='gray')
                axes[1, 1].set_title('Inverted Mask', fontsize=12, weight='bold')
                axes[1, 1].axis('off')
                # Difference map highlights where prediction and truth disagree.
                diff_map = np.abs(pred_mask_np - gt_binary)
                axes[1, 2].imshow(diff_map, cmap='Reds')
                axes[1, 2].set_title('Difference Map', fontsize=12, weight='bold')
                axes[1, 2].axis('off')
                # Clear remaining axes
                for j in range(3, n_cols):
                    axes[1, j].axis('off')
        else:
            # Without ground truth
            n_cols = 5 if show_attention and attention_heatmaps else 4
            fig, axes = plt.subplots(2, n_cols, figsize=(4*n_cols, 8))
            fig.suptitle('π§ Enhanced Attention U-Net Analysis', fontsize=16, weight='bold')
            # Top row
            images = [original_np, standard_mask_np * 255, pred_mask_np * 255]
            titles = ["Original Image", "Standard Prediction", f'{"TTA Enhanced" if use_tta else "Final Prediction"}']
            cmaps = ['gray', 'hot', 'hot']
            for i in range(3):
                axes[0, i].imshow(images[i], cmap=cmaps[i])
                axes[0, i].set_title(titles[i], fontsize=12, weight='bold')
                axes[0, i].axis('off')
            # Tumor extraction
            tumor_only = np.where(pred_mask_np == 1, original_np, 255)
            axes[0, 3].imshow(tumor_only, cmap='gray')
            axes[0, 3].set_title('Tumor Extraction', fontsize=12, weight='bold')
            axes[0, 3].axis('off')
            if show_attention and attention_heatmaps:
                combined_attention = np.mean(attention_heatmaps, axis=0)
                axes[0, 4].imshow(combined_attention, cmap='jet', alpha=0.7)
                axes[0, 4].imshow(original_np, cmap='gray', alpha=0.3)
                axes[0, 4].set_title('Combined Attention', fontsize=12, weight='bold')
                axes[0, 4].axis('off')
            # Bottom row - Individual attention maps
            if show_attention and attention_heatmaps:
                for i, heatmap in enumerate(attention_heatmaps[:n_cols]):
                    axes[1, i].imshow(heatmap, cmap='jet', alpha=0.7)
                    axes[1, i].imshow(original_np, cmap='gray', alpha=0.3)
                    axes[1, i].set_title(f'Attention Gate {i+1}', fontsize=10, weight='bold')
                    axes[1, i].axis('off')
            else:
                # Clear bottom row
                for j in range(n_cols):
                    axes[1, j].axis('off')
        plt.tight_layout()
        # Render the figure to an in-memory PNG and hand it back as a PIL image.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
        buf.seek(0)
        plt.close()
        result_image = Image.open(buf)
        # Summary statistics over the final binary mask.
        tumor_pixels = np.sum(pred_mask_np)
        total_pixels = pred_mask_np.size
        tumor_percentage = (tumor_pixels / total_pixels) * 100
        max_confidence = torch.max(final_pred).item()
        mean_confidence = torch.mean(final_pred).item()
        # Markdown report assembled section by section below.
        analysis_text = f"""
## π§ Enhanced Attention U-Net Analysis Results
### π Detection Summary
- **Status**: {'π΄ TUMOR DETECTED' if tumor_pixels > 50 else 'π’ NO SIGNIFICANT TUMOR'}
- **Tumor Coverage**: {tumor_percentage:.2f}% of brain region
- **Tumor Pixels**: {tumor_pixels:,} pixels
- **Max Confidence**: {max_confidence:.4f}
- **Mean Confidence**: {mean_confidence:.4f}
"""
        if ground_truth is not None:
            analysis_text += f"""
### π― Ground Truth Comparison
- **Dice Score**: {dice_score:.4f} {'β Excellent' if dice_score > 0.8 else 'β οΈ Good' if dice_score > 0.6 else 'β Poor'}
- **IoU Score**: {iou_score:.4f} {'β Excellent' if iou_score > 0.7 else 'β οΈ Good' if iou_score > 0.5 else 'β Poor'}
- **Model Accuracy**: {'High precision match' if dice_score > 0.8 else 'Reasonable match' if dice_score > 0.6 else 'Needs improvement'}
"""
        analysis_text += f"""
### π Enhancement Features
- **Test-Time Augmentation**: {'β Applied (6 augmentations averaged)' if use_tta else 'β Disabled'}
- **Attention Visualization**: {'β Generated attention heatmaps' if show_attention else 'β Disabled'}
- **Boundary Enhancement**: {'β TTA improves edge detection' if use_tta else 'β οΈ Standard prediction only'}
- **Interpretability**: {'β Attention gates show focus areas' if show_attention else 'β Black box mode'}
### π¬ Model Architecture
- **Base Model**: Attention U-Net with skip connections
- **Training Performance**: Dice: 0.8420, IoU: 0.7297, Accuracy: 98.90%
- **Attention Gates**: 4 levels with soft attention mechanism
- **Features Channels**: [32, 64, 128, 256] progression
- **Device**: {device.type.upper()}
### π Enhanced Processing Pipeline
- **Preprocessing**: Resize(256Γ256) + Normalization
- **Augmentations**: Flips (H,V), Rotations (90Β°,270Β°), Combined
- **Attention Fusion**: Multi-scale attention coefficient extraction
- **Post-processing**: Ensemble averaging + Binary thresholding (0.5)
### β οΈ Medical Disclaimer
This enhanced AI model is for **research and educational purposes only**.
Results include advanced features for better accuracy and interpretability.
Always consult medical professionals for clinical applications.
### π Research Contributions
β **Attention Gates**: Enhanced boundary detection through selective feature passing
β **Test-Time Augmentation**: Robust predictions via ensemble averaging
β **Interpretability**: Attention heatmaps for clinical trust and validation
β **Efficiency**: No retraining required, minimal computational overhead
"""
        print(f"β Enhanced analysis completed! Tumor coverage: {tumor_percentage:.2f}%")
        return result_image, analysis_text
    except Exception as e:
        error_msg = f"β Error during enhanced analysis: {str(e)}"
        print(error_msg)
        return None, error_msg
def load_random_sample():
    """Fetch a random dataset sample for the UI.

    Returns (image, mask, status message); image/mask are None on failure.
    """
    sample_image, sample_mask = get_random_sample_from_dataset()
    if sample_image is not None:
        return sample_image, sample_mask, "β Random sample loaded from dataset"
    return None, None, "β Failed to load random sample from dataset"
def clear_all():
    """Reset all UI inputs and the analysis text to their initial state."""
    placeholder = "Upload a brain MRI image or load a random sample to test the enhanced model"
    return None, None, None, placeholder
# Custom CSS for the Gradio app: overall page width/typography, the gradient
# page header (#title), feature/metric cards, and pill-shaped badges used in
# the "Research Innovations" box.
css = """
.gradio-container {
max-width: 1600px !important;
margin: auto !important;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
#title {
text-align: center;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 40px;
border-radius: 20px;
margin-bottom: 30px;
box-shadow: 0 12px 24px rgba(102, 126, 234, 0.4);
}
.feature-box {
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
border-radius: 15px;
padding: 25px;
margin: 15px 0;
color: white;
box-shadow: 0 8px 16px rgba(240, 147, 251, 0.3);
}
.metric-card {
background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
border-radius: 12px;
padding: 20px;
text-align: center;
margin: 10px;
box-shadow: 0 6px 12px rgba(79, 172, 254, 0.3);
}
.enhancement-badge {
display: inline-block;
background: linear-gradient(45deg, #fa709a 0%, #fee140 100%);
color: white;
padding: 8px 16px;
border-radius: 25px;
margin: 5px;
font-weight: bold;
box-shadow: 0 4px 8px rgba(250, 112, 154, 0.3);
}
"""
# Build the Gradio UI. Layout: a header, then a two-column row (inputs and
# enhancement toggles on the left, analysis output on the right), followed by
# static HTML sections for metrics, research notes, and the disclaimer.
# Widgets are wired to the prediction helpers defined above.
with gr.Blocks(css=css, title="π§ Enhanced Brain Tumor Segmentation", theme=gr.themes.Soft()) as app:
    # Gradient page header.
    gr.HTML("""
<div id="title">
<h1>π§ Enhanced Attention U-Net Brain Tumor Segmentation</h1>
<p style="font-size: 20px; margin-top: 20px; font-weight: 300;">
π Advanced Medical AI with Test-Time Augmentation & Attention Visualization
</p>
<p style="font-size: 16px; margin-top: 15px; opacity: 0.9;">
π Performance: Dice 0.8420 β’ IoU 0.7297 β’ Accuracy 98.90% |
π¬ Research-Grade Interpretability & Robustness
</p>
</div>
""")
    with gr.Row():
        # Left column: image input tabs and enhancement controls.
        with gr.Column(scale=1):
            gr.Markdown("### π€ Input & Controls")
            with gr.Tab("πΈ Upload Image"):
                image_input = gr.Image(
                    label="Brain MRI Scan",
                    type="pil",
                    sources=["upload", "webcam"],
                    height=300
                )
            with gr.Tab("π² Random Sample"):
                # Read-only preview of a dataset sample plus its ground truth.
                random_image = gr.Image(
                    label="Sample Image",
                    type="pil",
                    height=300,
                    interactive=False
                )
                random_ground_truth = gr.Image(
                    label="Ground Truth Mask",
                    type="pil",
                    height=300,
                    interactive=False
                )
                load_sample_btn = gr.Button("π² Load Random Sample", variant="secondary", size="lg")
                sample_status = gr.Textbox(label="Sample Status", interactive=False)
            gr.Markdown("### βοΈ Enhancement Options")
            # Toggles feeding directly into predict_with_enhancements().
            use_tta = gr.Checkbox(
                label="π Test-Time Augmentation",
                value=True,
                info="Apply multiple augmentations for robust predictions"
            )
            show_attention = gr.Checkbox(
                label="π₯ Attention Visualization",
                value=True,
                info="Generate attention heatmaps for interpretability"
            )
            with gr.Row():
                analyze_btn = gr.Button(
                    "π§ Analyze with Enhanced Model",
                    variant="primary",
                    scale=3,
                    size="lg"
                )
                clear_btn = gr.Button("ποΈ Clear All", variant="secondary", scale=1)
            # Static badge box describing the app's research features.
            gr.HTML("""
<div class="feature-box">
<h4 style="margin-bottom: 15px;">π― Research Innovations</h4>
<div class="enhancement-badge">Attention Gates</div>
<div class="enhancement-badge">Test-Time Augmentation</div>
<div class="enhancement-badge">Interpretability</div>
<div class="enhancement-badge">Ground Truth Comparison</div>
<p style="margin-top: 15px; font-size: 14px; opacity: 0.9;">
Advanced medical AI combining accuracy, robustness, and clinical interpretability
</p>
</div>
""")
        # Right column: rendered figure and the markdown report.
        with gr.Column(scale=2):
            gr.Markdown("### π Enhanced Analysis Results")
            output_image = gr.Image(
                label="Comprehensive Analysis Visualization",
                type="pil",
                height=600
            )
            with gr.Accordion("π Detailed Analysis Report", open=True):
                analysis_output = gr.Markdown(
                    value="Upload a brain MRI image or load a random sample to test the enhanced Attention U-Net model.",
                    elem_id="analysis"
                )
    # Performance metrics section (static HTML cards).
    gr.HTML("""
<div style="margin-top: 40px;">
<h3 style="text-align: center; color: #4a5568; margin-bottom: 25px;">π Model Performance & Research Contributions</h3>
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 20px; margin-bottom: 30px;">
<div class="metric-card">
<h4 style="color: white; margin-bottom: 10px;">π― Segmentation Accuracy</h4>
<div style="font-size: 24px; font-weight: bold; margin: 10px 0;">98.90%</div>
<p style="font-size: 14px; opacity: 0.9;">Training accuracy on brain tumor dataset</p>
</div>
<div class="metric-card">
<h4 style="color: white; margin-bottom: 10px;">π Dice Score</h4>
<div style="font-size: 24px; font-weight: bold; margin: 10px 0;">0.8420</div>
<p style="font-size: 14px; opacity: 0.9;">Overlap similarity coefficient</p>
</div>
<div class="metric-card">
<h4 style="color: white; margin-bottom: 10px;">π² IoU Score</h4>
<div style="font-size: 24px; font-weight: bold; margin: 10px 0;">0.7297</div>
<p style="font-size: 14px; opacity: 0.9;">Intersection over Union metric</p>
</div>
<div class="metric-card">
<h4 style="color: white; margin-bottom: 10px;">β‘ Enhancement Features</h4>
<div style="font-size: 20px; font-weight: bold; margin: 10px 0;">TTA + Attention</div>
<p style="font-size: 14px; opacity: 0.9;">Advanced robustness & interpretability</p>
</div>
</div>
</div>
""")
    # Research contributions section (static HTML).
    gr.HTML("""
<div style="margin-top: 30px; padding: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 20px; color: white;">
<h3 style="text-align: center; margin-bottom: 25px; color: white;">π Novel Research Contributions</h3>
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 30px; margin-bottom: 20px;">
<div>
<h4 style="margin-bottom: 15px; color: #ffd700;">π 1. Enhanced Boundary Detection</h4>
<ul style="line-height: 1.8; margin-left: 20px;">
<li><strong>Problem:</strong> Traditional U-Net passes noisy features through skip connections</li>
<li><strong>Solution:</strong> Attention gates filter irrelevant encoder features</li>
<li><strong>Impact:</strong> Cleaner boundaries, reduced false positives</li>
</ul>
</div>
<div>
<h4 style="margin-bottom: 15px; color: #ffd700;">π 2. Test-Time Augmentation</h4>
<ul style="line-height: 1.8; margin-left: 20px;">
<li><strong>Problem:</strong> Medical datasets are small, MRI scans vary across centers</li>
<li><strong>Solution:</strong> Multiple augmentations averaged for robust predictions</li>
<li><strong>Impact:</strong> Improved robustness without retraining</li>
</ul>
</div>
<div>
<h4 style="margin-bottom: 15px; color: #ffd700;">π₯ 3. Attention Visualization</h4>
<ul style="line-height: 1.8; margin-left: 20px;">
<li><strong>Problem:</strong> Deep networks are "black boxes" for clinicians</li>
<li><strong>Solution:</strong> Extract attention coefficients as interpretable heatmaps</li>
<li><strong>Impact:</strong> Build clinical trust through transparency</li>
</ul>
</div>
<div>
<h4 style="margin-bottom: 15px; color: #ffd700;">β‘ 4. Efficient Implementation</h4>
<ul style="line-height: 1.8; margin-left: 20px;">
<li><strong>Problem:</strong> Complex architectures are hard to deploy</li>
<li><strong>Solution:</strong> Low-overhead enhancements within existing backbone</li>
<li><strong>Impact:</strong> Practical for real-world medical workflows</li>
</ul>
</div>
</div>
<div style="text-align: center; padding-top: 20px; border-top: 2px solid rgba(255,255,255,0.3);">
<p style="font-size: 16px; font-weight: 600; margin-bottom: 10px;">
π― Research Gap Addressed: Accuracy + Robustness + Interpretability
</p>
<p style="font-size: 14px; opacity: 0.9;">
This combination tackles three major challenges in medical AI with minimal architectural changes
</p>
</div>
</div>
""")
    # Dataset and disclaimer section (static HTML).
    gr.HTML("""
<div style="margin-top: 30px; padding: 25px; background-color: #f7fafc; border-radius: 15px; border-left: 5px solid #667eea;">
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 30px;">
<div>
<h4 style="color: #667eea; margin-bottom: 15px;">π Dataset Information</h4>
<p><strong>Source:</strong> Brain Tumor Segmentation (Kaggle)</p>
<p><strong>Author:</strong> nikhilroxtomar</p>
<p><strong>Structure:</strong> Images + Ground Truth Masks</p>
<p><strong>Format:</strong> Grayscale MRI scans</p>
<p><strong>Use Case:</strong> Medical image segmentation research</p>
<p><strong>Ground Truth:</strong> Available for metric calculation</p>
</div>
<div>
<h4 style="color: #dc2626; margin-bottom: 15px;">β οΈ Medical Disclaimer</h4>
<p style="color: #dc2626; font-weight: 600; line-height: 1.5;">
This enhanced AI system is designed for <strong>research and educational purposes only</strong>.<br><br>
While the model includes advanced features like attention visualization and test-time augmentation
for improved accuracy and interpretability, all results must be validated by qualified medical professionals.<br><br>
<strong>Not approved for clinical diagnosis or medical decision making.</strong>
</p>
</div>
</div>
<hr style="margin: 25px 0; border: none; border-top: 2px solid #e2e8f0;">
<p style="text-align: center; color: #4a5568; margin: 15px 0; font-weight: 600;">
π¬ Research-Grade Medical AI β’ Enhanced Interpretability β’ Robust Predictions β’ Ground Truth Validation
</p>
</div>
""")
    # Event handlers
    def analyze_with_ground_truth(image, gt_mask, use_tta, show_attention):
        """Wrapper function to handle ground truth comparison"""
        return predict_with_enhancements(image, gt_mask, use_tta, show_attention)
    def analyze_uploaded_image(image, use_tta, show_attention):
        """Wrapper function for uploaded images without ground truth"""
        return predict_with_enhancements(image, None, use_tta, show_attention)
    # Button event handlers. Analyze prefers the random sample (which carries
    # a ground-truth mask) over an uploaded image when both are present.
    analyze_btn.click(
        fn=lambda img, rand_img, rand_gt, tta, attention: (
            analyze_with_ground_truth(rand_img, rand_gt, tta, attention)
            if rand_img is not None
            else analyze_uploaded_image(img, tta, attention)
        ),
        inputs=[image_input, random_image, random_ground_truth, use_tta, show_attention],
        outputs=[output_image, analysis_output],
        show_progress=True
    )
    load_sample_btn.click(
        fn=load_random_sample,
        inputs=[],
        outputs=[random_image, random_ground_truth, sample_status],
        show_progress=True
    )
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[image_input, random_image, random_ground_truth, analysis_output]
    )
    # Auto-load dataset on startup (browser-side console logging only).
    gr.HTML("""
<script>
document.addEventListener('DOMContentLoaded', function() {
console.log('Enhanced Brain Tumor Segmentation App Loaded');
console.log('Features: TTA + Attention Visualization + Ground Truth Comparison');
});
</script>
""")
if __name__ == "__main__":
    # Script entry point: print a startup banner, pre-fetch the dataset so
    # "Load Random Sample" works immediately, then serve the Gradio app.
    print("π Starting Enhanced Brain Tumor Segmentation System...")
    print("π Model Performance: Dice 0.8420, IoU 0.7297, Accuracy 98.90%")
    print("π¬ Research Features: Attention Gates + TTA + Interpretability")
    print("π₯ Auto-downloading dataset and model...")
    # Initialize dataset download
    print("π Initializing dataset...")
    try:
        dataset_path = download_dataset()
        if dataset_path:
            print(f"β Dataset ready at: {dataset_path}")
        else:
            print("β οΈ Dataset download failed, random samples unavailable")
    except Exception as e:
        # Non-fatal: the app still runs for uploaded images.
        print(f"β οΈ Dataset initialization error: {e}")
    # Bind on all interfaces at the standard Hugging Face Spaces port.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False
    )