import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
from torchvision import transforms
import torchvision.transforms.functional as TF
import urllib.request
import os
import kagglehub
import random
from pathlib import Path
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = None
dataset_path = None
# Attention U-Net architecture (identical to the training code so the checkpoint weights load)
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(DoubleConv, self).__init__()
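        # Two 3x3 conv -> BatchNorm -> ReLU stages at constant resolution; bias=False
        # because each conv is followed by BatchNorm, whose affine shift makes a conv bias redundant.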
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
def forward(self, x):
return self.conv(x)
class AttentionBlock(nn.Module):
def __init__(self, F_g, F_l, F_int):
super(AttentionBlock, self).__init__()
self.W_g = nn.Sequential(
nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(F_int)
)
self.W_x = nn.Sequential(
nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(F_int)
)
self.psi = nn.Sequential(
nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(1),
nn.Sigmoid()
)
self.relu = nn.ReLU(inplace=True)
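    # Additive attention gate: the 1x1-projected sum of gate g and feature x passes through
    # ReLU, then a 1x1 conv + sigmoid (self.psi) to give per-pixel coefficients in (0, 1).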
def forward(self, g, x):
g1 = self.W_g(g)
x1 = self.W_x(x)
psi = self.relu(g1 + x1)
psi = self.psi(psi)
return x * psi, psi # Return attention coefficients for visualization
class AttentionUNET(nn.Module):
def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
super(AttentionUNET, self).__init__()
self.out_channels = out_channels
self.ups = nn.ModuleList()
self.downs = nn.ModuleList()
self.attentions = nn.ModuleList()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# Down part of UNET
for feature in features:
self.downs.append(DoubleConv(in_channels, feature))
in_channels = feature
# Bottleneck
self.bottleneck = DoubleConv(features[-1], features[-1]*2)
# Up part of UNET
for feature in reversed(features):
self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
self.ups.append(DoubleConv(feature*2, feature))
self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
def forward(self, x, return_attention=False):
skip_connections = []
attention_maps = []
for down in self.downs:
x = down(x)
skip_connections.append(x)
x = self.pool(x)
x = self.bottleneck(x)
skip_connections = skip_connections[::-1]
for idx in range(0, len(self.ups), 2):
x = self.ups[idx](x)
skip_connection = skip_connections[idx//2]
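            # MaxPool floors odd spatial dims, so the upsampled tensor can be one pixel
            # smaller than the skip connection; resize restores matching shapes.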
if x.shape != skip_connection.shape:
x = TF.resize(x, size=skip_connection.shape[2:])
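            # As in the training code, the skip connection serves as the gating signal g and
            # the upsampled decoder feature as x, so the gated decoder output is returned here.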
skip_connection, attention_coeff = self.attentions[idx // 2](skip_connection, x)
if return_attention:
attention_maps.append(attention_coeff)
concat_skip = torch.cat((skip_connection, x), dim=1)
x = self.ups[idx+1](concat_skip)
output = self.final_conv(x)
if return_attention:
return output, attention_maps
return output
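
# Illustrative shape check, not part of the original app: a hypothetical helper that runs a
# dummy batch through the network. With the default features [32, 64, 128, 256] and a
# 256x256 input, the logits are (N, 1, 256, 256) and four attention maps are returned.
def _smoke_test_attention_unet():
    net = AttentionUNET(in_channels=1, out_channels=1)
    dummy = torch.randn(1, 1, 256, 256)
    logits, att_maps = net(dummy, return_attention=True)
    assert logits.shape == (1, 1, 256, 256)
    assert len(att_maps) == 4  # one attention map per decoder stage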
def download_dataset():
"""Download Brain Tumor Segmentation dataset from Kaggle"""
global dataset_path
try:
print("πŸ“₯ Downloading Brain Tumor Segmentation dataset...")
dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
print(f"βœ… Dataset downloaded to: {dataset_path}")
return dataset_path
except Exception as e:
print(f"❌ Failed to download dataset: {e}")
return None
def download_model():
"""Download your trained model from HuggingFace"""
model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
model_path = "best_attention_model.pth.tar"
if not os.path.exists(model_path):
print("πŸ“₯ Downloading trained model...")
try:
urllib.request.urlretrieve(model_url, model_path)
print("βœ… Model downloaded successfully!")
except Exception as e:
print(f"❌ Failed to download model: {e}")
return None
else:
print("βœ… Model already exists!")
return model_path
def load_attention_model():
"""Load trained Attention U-Net model"""
global model
if model is None:
try:
print("πŸ”„ Loading Attention U-Net model...")
model_path = download_model()
if model_path is None:
return None
model = AttentionUNET(in_channels=1, out_channels=1).to(device)
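            # The .pth.tar checkpoint is a dict whose "state_dict" entry holds the weights;
            # weights_only=True restricts unpickling to tensors (requires a recent PyTorch).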
checkpoint = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint["state_dict"])
model.eval()
print("βœ… Attention U-Net model loaded successfully!")
except Exception as e:
print(f"❌ Error loading model: {e}")
model = None
return model
def get_random_sample_from_dataset():
"""Get a random sample image and ground truth mask from the dataset"""
global dataset_path
if dataset_path is None:
dataset_path = download_dataset()
if dataset_path is None:
return None, None
try:
images_path = Path(dataset_path) / "images"
masks_path = Path(dataset_path) / "masks"
if not images_path.exists() or not masks_path.exists():
print("❌ Dataset structure not found")
return None, None
# Get all image files
image_files = list(images_path.glob("*.jpg")) + list(images_path.glob("*.png")) + list(images_path.glob("*.tif"))
if not image_files:
print("❌ No image files found in dataset")
return None, None
# Select random image
random_image_file = random.choice(image_files)
image_name = random_image_file.stem
# Find corresponding mask
possible_mask_extensions = ['.jpg', '.png', '.tif', '.gif']
mask_file = None
for ext in possible_mask_extensions:
potential_mask = masks_path / f"{image_name}{ext}"
if potential_mask.exists():
mask_file = potential_mask
break
if mask_file is None:
print(f"❌ No corresponding mask found for {image_name}")
return None, None
# Load image and mask
image = Image.open(random_image_file).convert('L')
mask = Image.open(mask_file).convert('L')
print(f"βœ… Loaded random sample: {image_name}")
return image, mask
except Exception as e:
print(f"❌ Error loading random sample: {e}")
return None, None
def test_time_augmentation(model, image_tensor):
"""Apply Test-Time Augmentation (TTA) for robust predictions"""
augmentations = [
lambda x: x, # Original
lambda x: torch.flip(x, dims=[3]), # Horizontal flip
lambda x: torch.flip(x, dims=[2]), # Vertical flip
lambda x: torch.flip(x, dims=[2, 3]), # Both flips
        lambda x: torch.rot90(x, k=1, dims=[2, 3]),  # 90° rotation
        lambda x: torch.rot90(x, k=3, dims=[2, 3]),  # 270° rotation
]
reverse_augmentations = [
lambda x: x, # Original
lambda x: torch.flip(x, dims=[3]), # Reverse horizontal flip
lambda x: torch.flip(x, dims=[2]), # Reverse vertical flip
lambda x: torch.flip(x, dims=[2, 3]), # Reverse both flips
        lambda x: torch.rot90(x, k=3, dims=[2, 3]),  # Reverse 90° rotation
        lambda x: torch.rot90(x, k=1, dims=[2, 3]),  # Reverse 270° rotation
]
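    # Each reverse transform inverts its augmentation (flips are self-inverse; rot90 with
    # k=1 is undone by k=3), so every prediction is mapped back into the original frame
    # before averaging. Note: TTA costs six forward passes per image.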
predictions = []
with torch.no_grad():
for aug, rev_aug in zip(augmentations, reverse_augmentations):
# Apply augmentation
aug_input = aug(image_tensor)
# Get prediction
pred = torch.sigmoid(model(aug_input))
# Reverse augmentation on prediction
pred = rev_aug(pred)
predictions.append(pred)
# Average all predictions
tta_prediction = torch.mean(torch.stack(predictions), dim=0)
return tta_prediction
def generate_attention_heatmaps(model, image_tensor):
"""Generate attention heatmaps for interpretability"""
with torch.no_grad():
pred, attention_maps = model(image_tensor, return_attention=True)
# Convert attention maps to numpy for visualization
heatmaps = []
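    # The four maps come from successive decoder levels (32x32 up to 256x256 for a
    # 256x256 input); each is upsampled to 256x256 so it can be overlaid on the scan.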
for i, att_map in enumerate(attention_maps):
# Resize attention map to match input size
att_map_resized = TF.resize(att_map, (256, 256))
att_np = att_map_resized.cpu().squeeze().numpy()
heatmaps.append(att_np)
return heatmaps
def preprocess_image(image):
"""Preprocessing exactly like training code"""
if image.mode != 'L':
image = image.convert('L')
val_test_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor()
])
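    # ToTensor scales pixel values to [0, 1]; unsqueeze(0) adds the batch dimension.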
return val_test_transform(image).unsqueeze(0)
def calculate_metrics(pred_mask, ground_truth_mask):
"""Calculate Dice and IoU metrics"""
pred_binary = (pred_mask > 0.5).float()
gt_binary = (ground_truth_mask > 0.5).float()
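    # Dice = 2|A∩B| / (|A| + |B|); IoU = |A∩B| / |A∪B|. The 1e-8 epsilon avoids division
    # by zero when both masks are empty. Example: 30 px overlap with |A|=60, |B|=40 gives
    # Dice = 60/100 = 0.60 and IoU = 30/70 ≈ 0.43.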
# Dice coefficient
intersection = torch.sum(pred_binary * gt_binary)
dice = (2.0 * intersection) / (torch.sum(pred_binary) + torch.sum(gt_binary) + 1e-8)
# IoU
union = torch.sum(pred_binary) + torch.sum(gt_binary) - intersection
iou = intersection / (union + 1e-8)
return dice.item(), iou.item()
def predict_with_enhancements(image, ground_truth=None, use_tta=True, show_attention=True):
"""Enhanced prediction with TTA and attention visualization"""
current_model = load_attention_model()
if current_model is None:
return None, "❌ Failed to load trained model."
if image is None:
return None, "⚠️ Please upload an image first."
try:
print("🧠 Processing with enhanced Attention U-Net...")
input_tensor = preprocess_image(image).to(device)
# Standard prediction
with torch.no_grad():
standard_pred = torch.sigmoid(current_model(input_tensor))
# Test-Time Augmentation
if use_tta:
tta_pred = test_time_augmentation(current_model, input_tensor)
final_pred = tta_pred
else:
final_pred = standard_pred
# Generate attention heatmaps
attention_heatmaps = []
if show_attention:
attention_heatmaps = generate_attention_heatmaps(current_model, input_tensor)
# Convert predictions to binary
pred_mask_binary = (final_pred > 0.5).float()
pred_mask_np = pred_mask_binary.cpu().squeeze().numpy()
standard_mask_np = (standard_pred > 0.5).float().cpu().squeeze().numpy()
# Prepare images for visualization
original_np = np.array(image.convert('L').resize((256, 256)))
# Create comprehensive visualization
if ground_truth is not None:
# With ground truth comparison
gt_np = np.array(ground_truth.convert('L').resize((256, 256)))
gt_binary = (gt_np > 127).astype(np.float32) # Threshold ground truth
# Calculate metrics
gt_tensor = torch.tensor(gt_binary).unsqueeze(0).unsqueeze(0).to(device)
dice_score, iou_score = calculate_metrics(final_pred, gt_tensor)
# Create figure with ground truth comparison
n_cols = 6 if show_attention and attention_heatmaps else 5
fig, axes = plt.subplots(2, n_cols, figsize=(4*n_cols, 8))
fig.suptitle('🧠 Enhanced Attention U-Net Analysis with Ground Truth Comparison', fontsize=16, weight='bold')
# Top row - Standard analysis
axes[0, 0].imshow(original_np, cmap='gray')
axes[0, 0].set_title('Original Image', fontsize=12, weight='bold')
axes[0, 0].axis('off')
axes[0, 1].imshow(standard_mask_np * 255, cmap='hot')
axes[0, 1].set_title('Standard Prediction', fontsize=12, weight='bold')
axes[0, 1].axis('off')
axes[0, 2].imshow(pred_mask_np * 255, cmap='hot')
axes[0, 2].set_title(f'{"TTA Enhanced" if use_tta else "Final Prediction"}', fontsize=12, weight='bold')
axes[0, 2].axis('off')
axes[0, 3].imshow(gt_binary * 255, cmap='hot')
axes[0, 3].set_title('Ground Truth', fontsize=12, weight='bold')
axes[0, 3].axis('off')
# Overlay comparison
overlay = original_np.copy()
overlay = np.stack([overlay, overlay, overlay], axis=-1)
overlay[pred_mask_np > 0.5] = [255, 0, 0] # Red for prediction
overlay[gt_binary > 0.5] = [0, 255, 0] # Green for ground truth
overlap = (pred_mask_np > 0.5) & (gt_binary > 0.5)
overlay[overlap] = [255, 255, 0] # Yellow for overlap
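            # Painting order matters: overlap is painted last so matched pixels end up
            # yellow rather than being left green by the ground-truth pass.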
axes[0, 4].imshow(overlay.astype(np.uint8))
axes[0, 4].set_title('Overlay (Red:Pred, Green:GT, Yellow:Match)', fontsize=10, weight='bold')
axes[0, 4].axis('off')
if show_attention and attention_heatmaps:
# Show combined attention
combined_attention = np.mean(attention_heatmaps, axis=0)
axes[0, 5].imshow(combined_attention, cmap='jet', alpha=0.7)
axes[0, 5].imshow(original_np, cmap='gray', alpha=0.3)
axes[0, 5].set_title('Attention Heatmap', fontsize=12, weight='bold')
axes[0, 5].axis('off')
# Bottom row - Individual attention maps or detailed analysis
            if show_attention and attention_heatmaps:
                for i, heatmap in enumerate(attention_heatmaps[:n_cols]):
                    axes[1, i].imshow(heatmap, cmap='jet', alpha=0.7)
                    axes[1, i].imshow(original_np, cmap='gray', alpha=0.3)
                    axes[1, i].set_title(f'Attention Gate {i+1}', fontsize=10, weight='bold')
                    axes[1, i].axis('off')
                for j in range(len(attention_heatmaps), n_cols):
                    axes[1, j].axis('off')  # Hide columns without an attention map
else:
# Show tumor extraction and analysis
tumor_only = np.where(pred_mask_np == 1, original_np, 255)
inv_mask = np.where(pred_mask_np == 1, 0, 255)
axes[1, 0].imshow(tumor_only, cmap='gray')
axes[1, 0].set_title('Tumor Extraction', fontsize=12, weight='bold')
axes[1, 0].axis('off')
axes[1, 1].imshow(inv_mask, cmap='gray')
axes[1, 1].set_title('Inverted Mask', fontsize=12, weight='bold')
axes[1, 1].axis('off')
# Difference map
diff_map = np.abs(pred_mask_np - gt_binary)
axes[1, 2].imshow(diff_map, cmap='Reds')
axes[1, 2].set_title('Difference Map', fontsize=12, weight='bold')
axes[1, 2].axis('off')
# Clear remaining axes
for j in range(3, n_cols):
axes[1, j].axis('off')
else:
# Without ground truth
n_cols = 5 if show_attention and attention_heatmaps else 4
fig, axes = plt.subplots(2, n_cols, figsize=(4*n_cols, 8))
fig.suptitle('🧠 Enhanced Attention U-Net Analysis', fontsize=16, weight='bold')
# Top row
images = [original_np, standard_mask_np * 255, pred_mask_np * 255]
titles = ["Original Image", "Standard Prediction", f'{"TTA Enhanced" if use_tta else "Final Prediction"}']
cmaps = ['gray', 'hot', 'hot']
for i in range(3):
axes[0, i].imshow(images[i], cmap=cmaps[i])
axes[0, i].set_title(titles[i], fontsize=12, weight='bold')
axes[0, i].axis('off')
# Tumor extraction
tumor_only = np.where(pred_mask_np == 1, original_np, 255)
axes[0, 3].imshow(tumor_only, cmap='gray')
axes[0, 3].set_title('Tumor Extraction', fontsize=12, weight='bold')
axes[0, 3].axis('off')
if show_attention and attention_heatmaps:
combined_attention = np.mean(attention_heatmaps, axis=0)
axes[0, 4].imshow(combined_attention, cmap='jet', alpha=0.7)
axes[0, 4].imshow(original_np, cmap='gray', alpha=0.3)
axes[0, 4].set_title('Combined Attention', fontsize=12, weight='bold')
axes[0, 4].axis('off')
# Bottom row - Individual attention maps
            if show_attention and attention_heatmaps:
                for i, heatmap in enumerate(attention_heatmaps[:n_cols]):
                    axes[1, i].imshow(heatmap, cmap='jet', alpha=0.7)
                    axes[1, i].imshow(original_np, cmap='gray', alpha=0.3)
                    axes[1, i].set_title(f'Attention Gate {i+1}', fontsize=10, weight='bold')
                    axes[1, i].axis('off')
                for j in range(len(attention_heatmaps), n_cols):
                    axes[1, j].axis('off')  # Hide columns without an attention map
else:
# Clear bottom row
for j in range(n_cols):
axes[1, j].axis('off')
plt.tight_layout()
# Save result
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
buf.seek(0)
plt.close()
result_image = Image.open(buf)
# Calculate statistics
        tumor_pixels = int(np.sum(pred_mask_np))  # int so the {:,} format shows a pixel count
total_pixels = pred_mask_np.size
tumor_percentage = (tumor_pixels / total_pixels) * 100
max_confidence = torch.max(final_pred).item()
mean_confidence = torch.mean(final_pred).item()
# Enhanced analysis text
analysis_text = f"""
## 🧠 Enhanced Attention U-Net Analysis Results
### 📊 Detection Summary
- **Status**: {'🔴 TUMOR DETECTED' if tumor_pixels > 50 else '🟢 NO SIGNIFICANT TUMOR'}
- **Tumor Coverage**: {tumor_percentage:.2f}% of brain region
- **Tumor Pixels**: {tumor_pixels:,} pixels
- **Max Confidence**: {max_confidence:.4f}
- **Mean Confidence**: {mean_confidence:.4f}
"""
if ground_truth is not None:
analysis_text += f"""
### 🎯 Ground Truth Comparison
- **Dice Score**: {dice_score:.4f} {'✅ Excellent' if dice_score > 0.8 else '⚠️ Good' if dice_score > 0.6 else '❌ Poor'}
- **IoU Score**: {iou_score:.4f} {'✅ Excellent' if iou_score > 0.7 else '⚠️ Good' if iou_score > 0.5 else '❌ Poor'}
- **Model Accuracy**: {'High precision match' if dice_score > 0.8 else 'Reasonable match' if dice_score > 0.6 else 'Needs improvement'}
"""
analysis_text += f"""
### 🚀 Enhancement Features
- **Test-Time Augmentation**: {'✅ Applied (6 augmentations averaged)' if use_tta else '❌ Disabled'}
- **Attention Visualization**: {'✅ Generated attention heatmaps' if show_attention else '❌ Disabled'}
- **Boundary Enhancement**: {'✅ TTA improves edge detection' if use_tta else '⚠️ Standard prediction only'}
- **Interpretability**: {'✅ Attention gates show focus areas' if show_attention else '❌ Black box mode'}
### 🔬 Model Architecture
- **Base Model**: Attention U-Net with skip connections
- **Training Performance**: Dice: 0.8420, IoU: 0.7297, Accuracy: 98.90%
- **Attention Gates**: 4 levels with soft attention mechanism
- **Feature Channels**: [32, 64, 128, 256] progression
- **Device**: {device.type.upper()}
### 📈 Enhanced Processing Pipeline
- **Preprocessing**: Resize(256×256) + ToTensor scaling to [0, 1]
- **Augmentations**: Flips (H, V), Rotations (90°, 270°), Combined
- **Attention Fusion**: Multi-scale attention coefficient extraction
- **Post-processing**: Ensemble averaging + Binary thresholding (0.5)
### ⚠️ Medical Disclaimer
This enhanced AI model is for **research and educational purposes only**.
Although TTA and attention visualization improve robustness and interpretability, the outputs are not clinically validated.
Always consult medical professionals for clinical applications.
### 🏆 Research Contributions
✅ **Attention Gates**: Enhanced boundary detection through selective feature passing
✅ **Test-Time Augmentation**: Robust predictions via ensemble averaging
✅ **Interpretability**: Attention heatmaps for clinical trust and validation
✅ **Efficiency**: No retraining required, minimal computational overhead
"""
print(f"βœ… Enhanced analysis completed! Tumor coverage: {tumor_percentage:.2f}%")
return result_image, analysis_text
except Exception as e:
error_msg = f"❌ Error during enhanced analysis: {str(e)}"
print(error_msg)
return None, error_msg
def load_random_sample():
"""Load a random sample from the dataset"""
image, mask = get_random_sample_from_dataset()
if image is None:
return None, None, "❌ Failed to load random sample from dataset"
    return image, mask, "✅ Random sample loaded from dataset"
def clear_all():
    """Reset the inputs, result image, and analysis report"""
    return None, None, None, None, "Upload a brain MRI image or load a random sample to test the enhanced model"
# Enhanced professional CSS
css = """
.gradio-container {
max-width: 1600px !important;
margin: auto !important;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
#title {
text-align: center;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 40px;
border-radius: 20px;
margin-bottom: 30px;
box-shadow: 0 12px 24px rgba(102, 126, 234, 0.4);
}
.feature-box {
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
border-radius: 15px;
padding: 25px;
margin: 15px 0;
color: white;
box-shadow: 0 8px 16px rgba(240, 147, 251, 0.3);
}
.metric-card {
background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
border-radius: 12px;
padding: 20px;
text-align: center;
margin: 10px;
box-shadow: 0 6px 12px rgba(79, 172, 254, 0.3);
}
.enhancement-badge {
display: inline-block;
background: linear-gradient(45deg, #fa709a 0%, #fee140 100%);
color: white;
padding: 8px 16px;
border-radius: 25px;
margin: 5px;
font-weight: bold;
box-shadow: 0 4px 8px rgba(250, 112, 154, 0.3);
}
"""
# Create enhanced Gradio interface
with gr.Blocks(css=css, title="🧠 Enhanced Brain Tumor Segmentation", theme=gr.themes.Soft()) as app:
gr.HTML("""
<div id="title">
<h1>🧠 Enhanced Attention U-Net Brain Tumor Segmentation</h1>
<p style="font-size: 20px; margin-top: 20px; font-weight: 300;">
            🚀 Advanced Medical AI with Test-Time Augmentation & Attention Visualization
        </p>
        <p style="font-size: 16px; margin-top: 15px; opacity: 0.9;">
            📊 Performance: Dice 0.8420 • IoU 0.7297 • Accuracy 98.90% |
            🔬 Research-Grade Interpretability & Robustness
</p>
</div>
""")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### πŸ“€ Input & Controls")
with gr.Tab("πŸ“Έ Upload Image"):
image_input = gr.Image(
label="Brain MRI Scan",
type="pil",
sources=["upload", "webcam"],
height=300
)
with gr.Tab("🎲 Random Sample"):
random_image = gr.Image(
label="Sample Image",
type="pil",
height=300,
interactive=False
)
random_ground_truth = gr.Image(
label="Ground Truth Mask",
type="pil",
height=300,
interactive=False
)
load_sample_btn = gr.Button("🎲 Load Random Sample", variant="secondary", size="lg")
sample_status = gr.Textbox(label="Sample Status", interactive=False)
gr.Markdown("### βš™οΈ Enhancement Options")
use_tta = gr.Checkbox(
label="πŸ”„ Test-Time Augmentation",
value=True,
info="Apply multiple augmentations for robust predictions"
)
show_attention = gr.Checkbox(
label="πŸ”₯ Attention Visualization",
value=True,
info="Generate attention heatmaps for interpretability"
)
with gr.Row():
analyze_btn = gr.Button(
"🧠 Analyze with Enhanced Model",
variant="primary",
scale=3,
size="lg"
)
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary", scale=1)
gr.HTML("""
<div class="feature-box">
<h4 style="margin-bottom: 15px;">🎯 Research Innovations</h4>
<div class="enhancement-badge">Attention Gates</div>
<div class="enhancement-badge">Test-Time Augmentation</div>
<div class="enhancement-badge">Interpretability</div>
<div class="enhancement-badge">Ground Truth Comparison</div>
<p style="margin-top: 15px; font-size: 14px; opacity: 0.9;">
Advanced medical AI combining accuracy, robustness, and clinical interpretability
</p>
</div>
""")
with gr.Column(scale=2):
gr.Markdown("### πŸ“Š Enhanced Analysis Results")
output_image = gr.Image(
label="Comprehensive Analysis Visualization",
type="pil",
height=600
)
            with gr.Accordion("📈 Detailed Analysis Report", open=True):
analysis_output = gr.Markdown(
value="Upload a brain MRI image or load a random sample to test the enhanced Attention U-Net model.",
elem_id="analysis"
)
# Performance metrics section
gr.HTML("""
<div style="margin-top: 40px;">
<h3 style="text-align: center; color: #4a5568; margin-bottom: 25px;">πŸ“Š Model Performance & Research Contributions</h3>
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 20px; margin-bottom: 30px;">
<div class="metric-card">
<h4 style="color: white; margin-bottom: 10px;">🎯 Segmentation Accuracy</h4>
<div style="font-size: 24px; font-weight: bold; margin: 10px 0;">98.90%</div>
<p style="font-size: 14px; opacity: 0.9;">Training accuracy on brain tumor dataset</p>
</div>
<div class="metric-card">
<h4 style="color: white; margin-bottom: 10px;">πŸ“ Dice Score</h4>
<div style="font-size: 24px; font-weight: bold; margin: 10px 0;">0.8420</div>
<p style="font-size: 14px; opacity: 0.9;">Overlap similarity coefficient</p>
</div>
<div class="metric-card">
<h4 style="color: white; margin-bottom: 10px;">πŸ”² IoU Score</h4>
<div style="font-size: 24px; font-weight: bold; margin: 10px 0;">0.7297</div>
<p style="font-size: 14px; opacity: 0.9;">Intersection over Union metric</p>
</div>
<div class="metric-card">
<h4 style="color: white; margin-bottom: 10px;">⚑ Enhancement Features</h4>
<div style="font-size: 20px; font-weight: bold; margin: 10px 0;">TTA + Attention</div>
<p style="font-size: 14px; opacity: 0.9;">Advanced robustness & interpretability</p>
</div>
</div>
</div>
""")
# Research contributions section
gr.HTML("""
<div style="margin-top: 30px; padding: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 20px; color: white;">
<h3 style="text-align: center; margin-bottom: 25px; color: white;">πŸš€ Novel Research Contributions</h3>
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 30px; margin-bottom: 20px;">
<div>
<h4 style="margin-bottom: 15px; color: #ffd700;">πŸ” 1. Enhanced Boundary Detection</h4>
<ul style="line-height: 1.8; margin-left: 20px;">
<li><strong>Problem:</strong> Traditional U-Net passes noisy features through skip connections</li>
<li><strong>Solution:</strong> Attention gates filter irrelevant encoder features</li>
<li><strong>Impact:</strong> Cleaner boundaries, reduced false positives</li>
</ul>
</div>
<div>
<h4 style="margin-bottom: 15px; color: #ffd700;">πŸ”„ 2. Test-Time Augmentation</h4>
<ul style="line-height: 1.8; margin-left: 20px;">
<li><strong>Problem:</strong> Medical datasets are small, MRI scans vary across centers</li>
<li><strong>Solution:</strong> Multiple augmentations averaged for robust predictions</li>
<li><strong>Impact:</strong> Improved robustness without retraining</li>
</ul>
</div>
<div>
<h4 style="margin-bottom: 15px; color: #ffd700;">πŸ”₯ 3. Attention Visualization</h4>
<ul style="line-height: 1.8; margin-left: 20px;">
<li><strong>Problem:</strong> Deep networks are "black boxes" for clinicians</li>
<li><strong>Solution:</strong> Extract attention coefficients as interpretable heatmaps</li>
<li><strong>Impact:</strong> Build clinical trust through transparency</li>
</ul>
</div>
<div>
<h4 style="margin-bottom: 15px; color: #ffd700;">⚑ 4. Efficient Implementation</h4>
<ul style="line-height: 1.8; margin-left: 20px;">
<li><strong>Problem:</strong> Complex architectures are hard to deploy</li>
<li><strong>Solution:</strong> Low-overhead enhancements within existing backbone</li>
<li><strong>Impact:</strong> Practical for real-world medical workflows</li>
</ul>
</div>
</div>
<div style="text-align: center; padding-top: 20px; border-top: 2px solid rgba(255,255,255,0.3);">
<p style="font-size: 16px; font-weight: 600; margin-bottom: 10px;">
🎯 Research Gap Addressed: Accuracy + Robustness + Interpretability
</p>
<p style="font-size: 14px; opacity: 0.9;">
This combination tackles three major challenges in medical AI with minimal architectural changes
</p>
</div>
</div>
""")
# Dataset and disclaimer section
gr.HTML("""
<div style="margin-top: 30px; padding: 25px; background-color: #f7fafc; border-radius: 15px; border-left: 5px solid #667eea;">
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 30px;">
<div>
<h4 style="color: #667eea; margin-bottom: 15px;">πŸ“š Dataset Information</h4>
<p><strong>Source:</strong> Brain Tumor Segmentation (Kaggle)</p>
<p><strong>Author:</strong> nikhilroxtomar</p>
<p><strong>Structure:</strong> Images + Ground Truth Masks</p>
<p><strong>Format:</strong> Grayscale MRI scans</p>
<p><strong>Use Case:</strong> Medical image segmentation research</p>
<p><strong>Ground Truth:</strong> Available for metric calculation</p>
</div>
<div>
<h4 style="color: #dc2626; margin-bottom: 15px;">⚠️ Medical Disclaimer</h4>
<p style="color: #dc2626; font-weight: 600; line-height: 1.5;">
This enhanced AI system is designed for <strong>research and educational purposes only</strong>.<br><br>
While the model includes advanced features like attention visualization and test-time augmentation
for improved accuracy and interpretability, all results must be validated by qualified medical professionals.<br><br>
<strong>Not approved for clinical diagnosis or medical decision making.</strong>
</p>
</div>
</div>
<hr style="margin: 25px 0; border: none; border-top: 2px solid #e2e8f0;">
<p style="text-align: center; color: #4a5568; margin: 15px 0; font-weight: 600;">
            🔬 Research-Grade Medical AI • Enhanced Interpretability • Robust Predictions • Ground Truth Validation
</p>
</div>
""")
# Event handlers
def analyze_with_ground_truth(image, gt_mask, use_tta, show_attention):
"""Wrapper function to handle ground truth comparison"""
return predict_with_enhancements(image, gt_mask, use_tta, show_attention)
def analyze_uploaded_image(image, use_tta, show_attention):
"""Wrapper function for uploaded images without ground truth"""
return predict_with_enhancements(image, None, use_tta, show_attention)
# Button event handlers
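    # A loaded random sample (with its ground-truth mask) takes precedence over an
    # uploaded image, so Dice/IoU can be reported whenever ground truth is available.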
analyze_btn.click(
fn=lambda img, rand_img, rand_gt, tta, attention: (
analyze_with_ground_truth(rand_img, rand_gt, tta, attention)
if rand_img is not None
else analyze_uploaded_image(img, tta, attention)
),
inputs=[image_input, random_image, random_ground_truth, use_tta, show_attention],
outputs=[output_image, analysis_output],
show_progress=True
)
load_sample_btn.click(
fn=load_random_sample,
inputs=[],
outputs=[random_image, random_ground_truth, sample_status],
show_progress=True
)
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[image_input, random_image, random_ground_truth, output_image, analysis_output]
    )
    # Console startup banner (note: Gradio may sanitize <script> tags inside gr.HTML, so this may not run;
    # the dataset itself is downloaded in the __main__ block below)
gr.HTML("""
<script>
document.addEventListener('DOMContentLoaded', function() {
console.log('Enhanced Brain Tumor Segmentation App Loaded');
console.log('Features: TTA + Attention Visualization + Ground Truth Comparison');
});
</script>
""")
if __name__ == "__main__":
print("πŸš€ Starting Enhanced Brain Tumor Segmentation System...")
print("πŸ“Š Model Performance: Dice 0.8420, IoU 0.7297, Accuracy 98.90%")
print("πŸ”¬ Research Features: Attention Gates + TTA + Interpretability")
print("πŸ“₯ Auto-downloading dataset and model...")
# Initialize dataset download
print("πŸ“š Initializing dataset...")
try:
dataset_path = download_dataset()
if dataset_path:
print(f"βœ… Dataset ready at: {dataset_path}")
else:
print("⚠️ Dataset download failed, random samples unavailable")
except Exception as e:
print(f"⚠️ Dataset initialization error: {e}")
app.launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True,
share=False
)