# Brain Tumor Segmentation — Attention U-Net Gradio app (HuggingFace Space)
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import io | |
| import torchvision.transforms as transforms | |
| import torchvision.transforms.functional as TF | |
| import random | |
| import os | |
| import zipfile | |
| import urllib.request | |
| import kagglehub | |
# Run inference on GPU when available; the model and all input tensors are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Module-global model cache, lazily populated by load_your_attention_model().
model = None
| # Your Attention U-Net classes (from your code) | |
class DoubleConv(nn.Module):
    """Two stacked 3x3 Conv -> BatchNorm -> ReLU stages (spatial size preserved)."""

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        # Submodule is named `conv` so checkpoint state-dict keys stay stable.
        stages = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to x."""
        return self.conv(x)
class AttentionBlock(nn.Module):
    """Additive attention gate: scales skip features `x` by a mask derived from gate `g`."""

    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        # 1x1 projections of the gating signal and the skip connection into F_int channels.
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        # Collapse to a single-channel attention coefficient in (0, 1).
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return (gated skip features, attention map)."""
        combined = self.relu(self.W_g(g) + self.W_x(x))
        att = self.psi(combined)
        return x * att, att
class AttentionUNET(nn.Module):
    """U-Net with attention gates on the skip connections.

    forward() returns ``(logits, attention_maps)`` where ``attention_maps``
    collects the per-level gate masks produced during decoding (coarsest
    level first). With the defaults this maps a (N, 1, H, W) input to
    (N, 1, H, W) logits.
    """

    # NOTE: default changed from a mutable list to a tuple (same values) —
    # avoids the shared-mutable-default pitfall; fully backward compatible.
    def __init__(self, in_channels=1, out_channels=1, features=(32, 64, 128, 256)):
        super(AttentionUNET, self).__init__()
        self.out_channels = out_channels
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.attentions = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per entry in `features`.
        channels = in_channels
        for feature in features:
            self.downs.append(DoubleConv(channels, feature))
            channels = feature

        # Bottleneck doubles the deepest feature count.
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder: per level, a transposed conv (upsample) then a DoubleConv;
        # an attention gate filters the matching skip connection.
        # Registration order is preserved so checkpoint keys still match.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder/decoder; return (logits, list of attention maps)."""
        skip_connections = []
        attention_maps = []

        # Encoder path: record features before each downsampling step.
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest skip first

        # Decoder path: ups holds [upsample, DoubleConv] pairs per level.
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            if x.shape != skip_connection.shape:
                # Guard for odd input sizes where pool/upsample disagree by a pixel.
                x = TF.resize(x, size=skip_connection.shape[2:])
            attended_skip, att_map = self.attentions[idx // 2](x, skip_connection)
            attention_maps.append(att_map)
            concat_skip = torch.cat((attended_skip, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x), attention_maps
def download_model():
    """Fetch the trained checkpoint from HuggingFace unless it is cached locally.

    Returns:
        The local checkpoint path, or None when the download fails.
    """
    model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
    model_path = "best_attention_model.pth.tar"

    # Cached from a previous run — nothing to do.
    if os.path.exists(model_path):
        return model_path

    print("📥 Downloading your trained model...")
    try:
        urllib.request.urlretrieve(model_url, model_path)
        print("✅ Model downloaded successfully!")
    except Exception as e:
        print(f"❌ Failed to download model: {e}")
        return None
    return model_path
def load_your_attention_model():
    """Load the trained Attention U-Net once and cache it in the global `model`.

    Returns the cached model, or None when download/loading fails.
    """
    global model
    if model is not None:
        return model  # already loaded on a previous call

    try:
        print("🔄 Loading your trained Attention U-Net model...")
        model_path = download_model()
        if model_path is None:
            return None

        # Build the architecture, then restore the trained weights onto `device`.
        model = AttentionUNET(in_channels=1, out_channels=1).to(device)
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint["state_dict"])
        model.eval()
        print("✅ Your Attention U-Net model loaded successfully!")
    except Exception as e:
        print(f"❌ Error loading your model: {e}")
        model = None
    return model
def download_dataset():
    """Download the brain-tumor dataset via kagglehub; return (images_dir, masks_dir).

    Bug fix: ``kagglehub.dataset_download`` normally returns a path to an
    already-extracted *directory*; the original code passed that directly to
    ``zipfile.ZipFile`` and would raise. Extraction is now only attempted when
    a zip archive is actually returned.
    """
    dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')

    if os.path.isdir(dataset_path):
        # Already extracted by kagglehub — use it directly.
        extracted_path = dataset_path
    else:
        # Fallback: a zip archive was returned; extract it once.
        extracted_path = "brain_tumor_dataset"
        if not os.path.exists(extracted_path):
            with zipfile.ZipFile(dataset_path, 'r') as zip_ref:
                zip_ref.extractall(extracted_path)

    images_path = os.path.join(extracted_path, 'images')
    masks_path = os.path.join(extracted_path, 'masks')
    return images_path, masks_path
def load_random_sample():
    """Pick a random image (and its mask, when present) from the dataset.

    Bug fix: the "Load Random Sample" click handler wires FOUR outputs
    (image_input, ground_truth_state, filename_state, analysis_output), but
    the original function returned only three values, which makes Gradio
    error at runtime. All paths now return a 4-tuple
    ``(image, mask, filename, status_text)``.
    """
    images_path, masks_path = download_dataset()
    image_files = [f for f in os.listdir(images_path) if f.endswith(('.png', '.jpg'))]
    if not image_files:
        return None, None, None, "No images found in dataset"

    random_file = random.choice(image_files)
    img_path = os.path.join(images_path, random_file)
    mask_path = os.path.join(masks_path, random_file)

    image = Image.open(img_path).convert("L")
    # Masks share the image's filename; some samples may lack one.
    mask = Image.open(mask_path).convert("L") if os.path.exists(mask_path) else None
    return image, mask, random_file, f"Loaded sample: {random_file}"
def preprocess_for_your_model(image):
    """Convert a PIL image to a (1, 1, 256, 256) float tensor, matching training preprocessing."""
    # The model expects single-channel input.
    if image.mode != 'L':
        image = image.convert('L')

    transform = transforms.Compose([
        transforms.Resize((256,256)),
        transforms.ToTensor()
    ])
    # unsqueeze adds the batch dimension.
    return transform(image).unsqueeze(0)
def apply_tta(model, input_tensor):
    """Test-Time Augmentation: average sigmoid predictions over geometric
    augmentations, undoing each transform before averaging.

    Improvements over the original: runs under ``torch.no_grad()`` (pure
    inference, no autograd bookkeeping) and pairs each augmentation with its
    inverse explicitly instead of identity-comparing lambdas in a chain.

    Args:
        model: callable whose output's first element is the logits tensor.
        input_tensor: batched input, e.g. (1, 1, 256, 256).

    Returns:
        The averaged probability map, same shape as one model prediction.
    """
    # Each entry pairs an augmentation with the transform that undoes it.
    aug_pairs = [
        (lambda x: x, lambda x: x),                           # identity
        (lambda x: TF.rotate(x, 90), lambda x: TF.rotate(x, -90)),
        (lambda x: TF.rotate(x, -90), lambda x: TF.rotate(x, 90)),
        (TF.hflip, TF.hflip),                                 # flips are self-inverse
        (TF.vflip, TF.vflip),
    ]

    predictions = []
    with torch.no_grad():  # inference only — skip gradient tracking
        for aug, inverse in aug_pairs:
            # model(...) returns (logits, attention_maps); [0] selects logits.
            pred = torch.sigmoid(model(aug(input_tensor))[0])
            predictions.append(inverse(pred))

    return torch.mean(torch.stack(predictions), dim=0)
def generate_attention_heatmap(attention_maps):
    """Combine per-level attention maps into one 256x256 RGB JET heatmap.

    Bug fixes:
    - The decoder emits attention maps at *different* spatial resolutions, so
      ``torch.stack`` on the raw maps raises a size-mismatch error; each map
      is now upsampled to 256x256 before averaging.
    - ``cv2.applyColorMap`` returns BGR, but the result is displayed with
      matplotlib (RGB) — colors were inverted; converted to RGB here.
    - The empty-input fallback now matches the non-empty return shape/dtype.
    """
    if not attention_maps:
        return np.zeros((256, 256, 3), dtype=np.uint8)

    # Bring every (1, 1, H, W) map to a common resolution, then average levels.
    resized = [
        nn.functional.interpolate(att, size=(256, 256), mode='bilinear', align_corners=False)
        for att in attention_maps
    ]
    combined_att = torch.mean(torch.stack(resized), dim=0).squeeze().cpu().numpy()

    # Min-max normalize to [0, 1] for colormapping (epsilon avoids div-by-zero).
    combined_att = (combined_att - combined_att.min()) / (combined_att.max() - combined_att.min() + 1e-8)

    heatmap = cv2.applyColorMap((combined_att * 255).astype(np.uint8), cv2.COLORMAP_JET)
    # OpenCV produces BGR; convert so matplotlib's imshow renders correct colors.
    return cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
def predict_tumor(image, ground_truth=None, filename=None):
    """Run the full segmentation pipeline on one image and build the result figure.

    Bug fix: the prediction mask is 256x256, but the original code indexed an
    overlay built from the *original-size* image (``overlay[binary_mask > 0]``),
    which crashes or misaligns for any non-256x256 upload. All panels now use
    a 256x256 view of the input. The attention-map forward pass also runs
    under ``torch.no_grad()``.

    Args:
        image: input PIL image (any mode/size).
        ground_truth: optional PIL mask used for the comparison/IoU panels.
        filename: optional name shown in the report.

    Returns:
        (result figure as a PIL image, markdown analysis text), or
        (None, error message) on failure.
    """
    current_model = load_your_attention_model()
    if current_model is None:
        return None, "Failed to load your trained model."
    if image is None:
        return None, "Please upload or load an image first."
    try:
        input_tensor = preprocess_for_your_model(image).to(device)

        # TTA-averaged probability map -> binary mask at the 0.5 threshold.
        avg_pred = apply_tta(current_model, input_tensor)
        binary_mask = (avg_pred > 0.5).float().squeeze().cpu().numpy()

        # Morphological open/close to drop speckle and fill small holes.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5))
        binary_mask = cv2.morphologyEx(binary_mask.astype(np.uint8), cv2.MORPH_OPEN, kernel)
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)

        # Separate forward pass just to collect the attention maps.
        with torch.no_grad():
            _, attention_maps = current_model(input_tensor)
        att_heatmap = generate_attention_heatmap(attention_maps)

        # 256x256 grayscale view of the input — must match the mask resolution
        # for every overlay/indexing operation below.
        display_image = image.convert('L').resize((256, 256))

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle('Brain Tumor Segmentation Analysis', fontsize=20)

        # Original image panel.
        axes[0,0].imshow(display_image, cmap='gray')
        axes[0,0].set_title('Original Image')
        axes[0,0].axis('off')

        # Attention heatmap blended over the input.
        axes[0,1].imshow(np.array(display_image), cmap='gray')
        axes[0,1].imshow(att_heatmap, alpha=0.5)
        axes[0,1].set_title('Attention Heatmap')
        axes[0,1].axis('off')

        # Raw predicted mask panel.
        axes[0,2].imshow(binary_mask, cmap='gray')
        axes[0,2].set_title('Predicted Mask')
        axes[0,2].axis('off')

        if ground_truth is not None:
            gt_np = np.array(ground_truth.resize((256, 256)))
            axes[1,0].imshow(gt_np, cmap='gray')
            axes[1,0].set_title('Ground Truth Mask')
            axes[1,0].axis('off')

            # Comparison overlay: prediction vs ground truth.
            overlay = np.array(display_image.convert('RGB'))
            overlay[binary_mask > 0] = [0, 255, 0]  # Green for prediction
            overlay[gt_np > 0] = [255, 0, 0]        # Red for ground truth
            axes[1,1].imshow(overlay)
            axes[1,1].set_title('Prediction (Green) vs GT (Red)')
            axes[1,1].axis('off')

            # Intersection-over-Union between prediction and ground truth.
            intersection = np.sum(binary_mask * (gt_np > 0))
            union = np.sum(binary_mask) + np.sum(gt_np > 0) - intersection
            iou = intersection / (union + 1e-8)
            axes[1,2].text(0.1, 0.5, f'IoU Score: {iou:.4f}', fontsize=20)
            axes[1,2].axis('off')
        else:
            # Prediction-only overlay.
            overlay = np.array(display_image.convert('RGB'))
            overlay[binary_mask > 0] = [255, 0, 0]
            axes[1,0].imshow(overlay)
            axes[1,0].set_title('Prediction Overlay')
            axes[1,0].axis('off')
            axes[1,1].axis('off')
            axes[1,2].axis('off')

        plt.tight_layout()
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
        buf.seek(0)
        plt.close(fig)  # close this specific figure to avoid leaks
        result_image = Image.open(buf)

        # Summary statistics from the binary mask.
        tumor_pixels = np.sum(binary_mask)
        total_pixels = binary_mask.size
        tumor_percentage = (tumor_pixels / total_pixels) * 100
        analysis_text = f"""
## Brain Tumor Segmentation Results
### Detection Summary
- Tumor Percentage: {tumor_percentage:.2f}%
- Tumor Pixels: {tumor_pixels}
- File: {filename if filename else 'Uploaded Image'}
### Model Information
- Your Attention U-Net Model
- Test-Time Augmentation: Applied
- Attention Visualization: Included
"""
        if ground_truth is not None:
            analysis_text += f"\n- IoU with Ground Truth: {iou:.4f}"
        return result_image, analysis_text
    except Exception as e:
        return None, f"Error: {str(e)}"
def clear_all():
    """Reset the input image, result image, ground-truth state and status text."""
    placeholder = "Upload or load an image for analysis"
    return None, None, None, placeholder
# Professional CSS (white, clean, professional)
# Injected via gr.Blocks(css=...); selectors target Gradio's rendered DOM.
css = """
.gradio-container {
max-width: 1400px !important;
margin: auto !important;
background-color: white !important;
font-family: 'Arial', sans-serif !important;
}
h1, h2, h3, h4 {
color: #333333 !important;
}
button {
background-color: #f0f0f0 !important;
color: #333333 !important;
border: 1px solid #dddddd !important;
border-radius: 4px !important;
}
button.primary {
background-color: #007bff !important;
color: white !important;
}
.output-image {
border: 1px solid #dddddd !important;
border-radius: 4px !important;
}
.markdown {
line-height: 1.6 !important;
color: #555555 !important;
}
"""
# Create professional Gradio interface
with gr.Blocks(css=css, title="Brain Tumor Segmentation Application") as app:
    gr.Markdown("""
# Brain Tumor Segmentation Using Attention U-Net
A professional tool for medical image analysis
""")
    with gr.Row():
        # Left column: input image plus action buttons.
        with gr.Column(scale=1):
            gr.Markdown("### Input Selection")
            image_input = gr.Image(
                label="Upload Brain MRI",
                type="pil",
                sources=["upload", "webcam"],
                height=300
            )
            load_random_btn = gr.Button("Load Random Sample from Dataset", variant="primary")
            with gr.Row():
                analyze_btn = gr.Button("Analyze Image", variant="primary", scale=2)
                clear_btn = gr.Button("Clear", scale=1)
        # Right column: segmentation figure and markdown report.
        with gr.Column(scale=2):
            gr.Markdown("### Analysis Results")
            output_image = gr.Image(
                label="Segmentation Results",
                type="pil",
                height=400
            )
            analysis_output = gr.Markdown(
                value="Select an input method to begin analysis."
            )
    # Hidden state for ground truth and filename
    ground_truth_state = gr.State()
    filename_state = gr.State()
    # Event handlers
    # Run the segmentation pipeline on the current image + stored GT/filename.
    analyze_btn.click(
        fn=predict_tumor,
        inputs=[image_input, ground_truth_state, filename_state],
        outputs=[output_image, analysis_output]
    )
    # NOTE(review): this wires FOUR outputs — ensure load_random_sample
    # returns exactly four values to match; verify against its definition.
    load_random_btn.click(
        fn=load_random_sample,
        inputs=[],
        outputs=[image_input, ground_truth_state, filename_state, analysis_output]
    )
    # Clears image, result, GT state and status (filename_state is left as-is).
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[image_input, output_image, ground_truth_state, analysis_output]
    )
if __name__ == "__main__":
    # Script entry point: start the Gradio server (blocking call).
    print("Starting Brain Tumor Segmentation Application...")
    app.launch()