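"""Gradio demo: brain tumor segmentation on MRI slices with an Attention U-Net.

Loads a trained checkpoint, optionally applies test-time augmentation,
cleans the predicted mask with morphological operations, and can visualize
the decoder attention maps next to the segmentation overlay.
"""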
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import cv2
from PIL import Image
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: figures are rendered off-screen on the server
import matplotlib.pyplot as plt
import io
import torchvision.transforms.functional as TF
from torchvision import transforms
import os
import random
import urllib.request
import zipfile
import kagglehub

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = None
DATASET_PATH = "brain_tumor_dataset"

# Model architecture (must match the trained checkpoint)
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi, psi  # Return both attended features and attention map

class AttentionUNET(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
        super(AttentionUNET, self).__init__()
        self.out_channels = out_channels
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.attentions = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Up part
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        attention_maps = []  # To collect attention maps

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
            attended, attn_map = self.attentions[idx // 2](x, skip_connection)  # Get attention map
            attention_maps.append(attn_map)
            concat_skip = torch.cat((attended, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x), attention_maps

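# Illustrative shape check (assumes the 256x256 grayscale input and default features used here):
#   x = torch.randn(1, 1, 256, 256)
#   logits, attn_maps = AttentionUNET()(x)
#   logits.shape == (1, 1, 256, 256); len(attn_maps) == 4 (one map per decoder level)
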
def download_model():
    """Download your trained model from HuggingFace"""
    model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
    model_path = "best_attention_model.pth.tar"
    if not os.path.exists(model_path):
        print("📥 Downloading your trained model...")
        try:
            urllib.request.urlretrieve(model_url, model_path)
            print("✅ Model downloaded successfully!")
        except Exception as e:
            print(f"❌ Failed to download model: {e}")
            return None
    return model_path

def load_model():
    """Load your trained Attention U-Net model"""
    global model
    if model is None:
        try:
            print("🔄 Loading your trained Attention U-Net model...")
            # Download model if needed
            model_path = download_model()
            if model_path is None:
                return None
            # Initialize your model architecture
            model = AttentionUNET(in_channels=1, out_channels=1).to(device)
            # Load your trained weights
            checkpoint = torch.load(model_path, map_location=device, weights_only=True)
            model.load_state_dict(checkpoint["state_dict"])
            model.eval()
            print("✅ Your Attention U-Net model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading your model: {e}")
            model = None
    return model

def preprocess_image(image):
    """Preprocessing like your Colab code"""
    # Convert to grayscale
    if image.mode != 'L':
        image = image.convert('L')
    # Use your exact transform
    val_test_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    return val_test_transform(image).unsqueeze(0)  # Add batch dimension

def post_process_mask(pred_mask_np):
    """Post-processing with morphological operations (Novelty 1)"""
    # Binarize
    binary_mask = (pred_mask_np > 0.5).astype(np.uint8)
    # Morphological opening to remove small noise
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
    # Morphological closing to fill gaps
    binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
    return binary_mask

def test_time_augmentation(input_tensor, model):
    """Test-Time Augmentation (Novelty 2)"""
    predictions = []
    # Original
    pred, _ = model(input_tensor)
    predictions.append(torch.sigmoid(pred))
    # Horizontal flip
    hflip = TF.hflip(input_tensor)
    pred_h, _ = model(hflip)
    pred_h = TF.hflip(pred_h)
    predictions.append(torch.sigmoid(pred_h))
    # Vertical flip
    vflip = TF.vflip(input_tensor)
    pred_v, _ = model(vflip)
    pred_v = TF.vflip(pred_v)
    predictions.append(torch.sigmoid(pred_v))
    # Average predictions
    avg_pred = torch.mean(torch.stack(predictions), dim=0)
    return avg_pred.squeeze().cpu().numpy()

def generate_attention_heatmap(attention_maps, size=(256, 256)):
    """Generate attention heatmap visualization (Novelty 3)"""
    # Average attention maps from all levels
    avg_attn = torch.mean(torch.cat([TF.resize(m, size) for m in attention_maps]), dim=0)
    attn_np = avg_attn.squeeze().cpu().numpy()
    # Normalize and apply colormap
    attn_norm = (attn_np - attn_np.min()) / (attn_np.max() - attn_np.min() + 1e-8)
    heatmap = plt.cm.hot(attn_norm)[:, :, :3] * 255
    return heatmap.astype(np.uint8)

def download_dataset():
    """Download and extract the dataset if not present"""
    if not os.path.exists(DATASET_PATH):
        print("📥 Downloading brain tumor dataset...")
        try:
            path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
            print(f"Dataset downloaded to: {path}")
            # Extract if zipped
            for file in os.listdir(path):
                if file.endswith('.zip'):
                    with zipfile.ZipFile(os.path.join(path, file), 'r') as zip_ref:
                        zip_ref.extractall(DATASET_PATH)
            print("✅ Dataset extracted!")
            return True
        except Exception as e:
            print(f"❌ Failed to download dataset: {e}")
            return False
    print("✅ Dataset already exists!")
    return True

def get_random_sample():
    """Get a random image/mask pair from the dataset"""
    if not os.path.exists(DATASET_PATH):
        if not download_dataset():
            return None, None
    images_path = os.path.join(DATASET_PATH, "images")
    masks_path = os.path.join(DATASET_PATH, "masks")
    # Guard against an unexpected folder layout after extraction
    if not os.path.isdir(images_path) or not os.path.isdir(masks_path):
        return None, None
    image_files = [f for f in os.listdir(images_path) if f.endswith(('.png', '.jpg'))]
    if not image_files:
        return None, None
    random_file = random.choice(image_files)
    img_path = os.path.join(images_path, random_file)
    mask_path = os.path.join(masks_path, random_file)
    if not os.path.exists(mask_path):
        return None, None
    return Image.open(img_path), Image.open(mask_path)

def predict_tumor(image, use_tta=True, show_attention=True, is_dataset_sample=False, ground_truth=None):
    """Segment a single image and return (result figure, analysis text)."""
    current_model = load_model()
    if current_model is None:
        return None, "Failed to load your trained model."
    if image is None:
        return None, "Please upload an image first."
    try:
        print("Processing image...")
        input_tensor = preprocess_image(image).to(device)

        # Inference (no gradients needed)
        attn_maps = None
        with torch.no_grad():
            if use_tta:
                pred_np = test_time_augmentation(input_tensor, current_model)
                if show_attention:
                    # Extra plain forward pass just to collect the attention maps
                    _, attn_maps = current_model(input_tensor)
            else:
                pred, attn_maps = current_model(input_tensor)
                pred_np = torch.sigmoid(pred).squeeze().cpu().numpy()
                if not show_attention:
                    attn_maps = None

        # Post-processing
        binary_mask = post_process_mask(pred_np)

        # Generate attention heatmap if enabled
        attention_heatmap = None
        if show_attention and attn_maps:
            attention_heatmap = generate_attention_heatmap(attn_maps)

        # Create visualization
        n_cols = 3 + int(attention_heatmap is not None) + int(is_dataset_sample and ground_truth is not None)
        fig, axes = plt.subplots(1, n_cols, figsize=(20, 5))
        fig.suptitle('Brain Tumor Segmentation Results', fontsize=16)

        # Original image
        original_np = np.array(image.resize((256, 256)))
        axes[0].imshow(original_np, cmap='gray')
        axes[0].set_title('Original Image')
        axes[0].axis('off')

        # Predicted mask
        axes[1].imshow(binary_mask * 255, cmap='gray')
        axes[1].set_title('Predicted Mask')
        axes[1].axis('off')

        # Overlay: blend the original with the tumor region highlighted in red
        base = cv2.cvtColor(original_np, cv2.COLOR_GRAY2RGB) if original_np.ndim == 2 else original_np.copy()
        overlay = base.copy()
        overlay[binary_mask == 1] = [255, 0, 0]  # Red for tumor
        overlay = cv2.addWeighted(base, 0.7, overlay, 0.3, 0)
        axes[2].imshow(overlay)
        axes[2].set_title('Overlay')
        axes[2].axis('off')

        col = 3
        if attention_heatmap is not None:
            axes[col].imshow(attention_heatmap)
            axes[col].set_title('Attention Heatmap')
            axes[col].axis('off')
            col += 1
        if is_dataset_sample and ground_truth is not None:
            gt_np = np.array(ground_truth.convert('L').resize((256, 256)))
            axes[col].imshow(gt_np, cmap='gray')
            axes[col].set_title('Ground Truth')
            axes[col].axis('off')

        plt.tight_layout()
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        buf.seek(0)
        plt.close(fig)
        result_image = Image.open(buf)

        # Statistics
        tumor_pixels = int(np.sum(binary_mask))
        total_pixels = binary_mask.size
        tumor_percentage = (tumor_pixels / total_pixels) * 100
        analysis_text = f"""
### Segmentation Statistics
- Tumor Area Percentage: {tumor_percentage:.2f}%
- Tumor Pixels: {tumor_pixels}
- Total Pixels: {total_pixels}
- TTA Used: {'Yes' if use_tta else 'No'}
- Attention Visualization: {'Yes' if show_attention else 'No'}
"""
        if is_dataset_sample and ground_truth is not None:
            gt_np = np.array(ground_truth.convert('L').resize((256, 256)))
            intersection = np.logical_and(binary_mask, gt_np > 0).sum()
            union = np.logical_or(binary_mask, gt_np > 0).sum()
            iou = intersection / (union + 1e-8)
            dice = (2 * intersection) / (binary_mask.sum() + (gt_np > 0).sum() + 1e-8)
            analysis_text += f"""
### Comparison with Ground Truth
- IoU Score: {iou:.4f}
- Dice Score: {dice:.4f}
"""
        return result_image, analysis_text
    except Exception as e:
        return None, f"Error: {str(e)}"

def test_random_sample():
    image, mask = get_random_sample()
    if image is None:
        return None, "Failed to load dataset sample. Please download dataset first."
    return predict_tumor(image, use_tta=True, show_attention=True, is_dataset_sample=True, ground_truth=mask)

# Custom CSS for a professional, minimalist look
css = """
body, .gradio-container { font-family: 'Arial', sans-serif; color: #333; }
h1, h2, h3, h4 { color: #2c3e50; font-weight: 500; }
.button { background-color: #3498db; color: white; border: none; border-radius: 4px; padding: 10px 20px; font-size: 16px; cursor: pointer; }
.button:hover { background-color: #2980b9; }
.card { border: 1px solid #e0e0e0; border-radius: 8px; padding: 20px; background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
"""

with gr.Blocks(css=css, title="Brain Tumor Segmentation") as app:
    gr.Markdown("# Brain Tumor Segmentation Using Attention U-Net")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input")
            image_input = gr.Image(label="Upload Image", type="pil")
            with gr.Row():
                predict_btn = gr.Button("Predict")
                random_btn = gr.Button("Test Random Sample")
                download_btn = gr.Button("Download Dataset")
        with gr.Column(scale=2):
            gr.Markdown("### Output")
            output_image = gr.Image(label="Result")
            analysis_output = gr.Textbox(label="Analysis", lines=10)

    # Event handlers
    predict_btn.click(
        fn=predict_tumor,
        inputs=[image_input],
        outputs=[output_image, analysis_output]
    )
    random_btn.click(
        fn=test_random_sample,
        inputs=[],
        outputs=[output_image, analysis_output]
    )
    download_btn.click(
        fn=lambda: "✅ Dataset ready." if download_dataset() else "❌ Dataset download failed.",
        inputs=[],
        outputs=analysis_output  # report status in the existing analysis box instead of an unplaced textbox
    )

if __name__ == "__main__":
    app.launch()