# ArchCoder's picture
# Update app.py
# 16f55d5 verified
# raw
# history blame
# 14.9 kB
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import io
import torchvision.transforms.functional as TF
from torchvision import transforms
import os
import random
import urllib.request
import zipfile
import kagglehub
# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Lazily-loaded model singleton; populated on first use by load_model().
model = None
# Local directory the Kaggle dataset is extracted into.
DATASET_PATH = "brain_tumor_dataset"
# Your model classes (from previous code)
class DoubleConv(nn.Module):
    """Two stacked 3x3 Conv -> BatchNorm -> ReLU layers (U-Net building block).

    Padding of 1 preserves the spatial size; bias is disabled because each
    convolution is immediately followed by a BatchNorm.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        layers = []
        # First conv maps in->out channels, second keeps out->out.
        for c_in, c_out in ((in_channels, out_channels), (out_channels, out_channels)):
            layers.extend([
                nn.Conv2d(c_in, c_out, 3, 1, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv blocks to x and return the result."""
        return self.conv(x)
class AttentionBlock(nn.Module):
    """Additive attention gate used on U-Net skip connections.

    Projects the gating signal g and the skip features x to a common width
    F_int, combines them, and produces a single-channel sigmoid mask in
    [0, 1] that spatially reweights x.
    """

    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        # 1x1 projection of the gating (decoder) signal.
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # 1x1 projection of the skip-connection features.
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # Collapse to one channel and squash into [0, 1].
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return (x weighted by the attention mask, the mask itself)."""
        combined = self.relu(self.W_g(g) + self.W_x(x))
        coefficients = self.psi(combined)
        # Both the attended features and the raw map are returned so the
        # caller can visualize the attention.
        return x * coefficients, coefficients
class AttentionUNET(nn.Module):
    """Attention U-Net for single-channel image segmentation.

    Encoder/decoder with skip connections gated by AttentionBlock.
    forward() returns the raw logits plus the per-level attention maps
    (shallowest decoder level last) for visualization.
    """

    # IDIOM FIX: `features` previously defaulted to a mutable list
    # ([32, 64, 128, 256]); a tuple default avoids the shared-mutable-default
    # pitfall while remaining backward-compatible for all callers.
    def __init__(self, in_channels=1, out_channels=1, features=(32, 64, 128, 256)):
        """
        Args:
            in_channels: channels of the input image (1 = grayscale MRI).
            out_channels: channels of the output logits (1 = binary mask).
            features: encoder widths, shallow to deep.
        """
        super(AttentionUNET, self).__init__()
        self.out_channels = out_channels
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.attentions = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder: DoubleConv at each level (pooling happens in forward).
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature
        # Bottleneck doubles the deepest width.
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        # Decoder: self.ups interleaves [upsample, conv, upsample, conv, ...];
        # one attention gate per level.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
            self.ups.append(DoubleConv(feature * 2, feature))
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return (logits, attention_maps) for the input batch x."""
        skip_connections = []
        attention_maps = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        # Deepest skip first, to pair with the decoder's upsampling order.
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)  # transpose-conv upsample
            skip_connection = skip_connections[idx // 2]
            # Resize when the input size is not divisible by 2^depth.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
            # Gate the skip features using the decoder state as the signal.
            attended, attn_map = self.attentions[idx // 2](x, skip_connection)
            attention_maps.append(attn_map)
            concat_skip = torch.cat((attended, x), dim=1)
            x = self.ups[idx + 1](concat_skip)
        return self.final_conv(x), attention_maps
def download_model():
    """Fetch the trained checkpoint from the HuggingFace Space if missing.

    Returns:
        The local checkpoint path, or None when the download fails.
    """
    model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
    model_path = "best_attention_model.pth.tar"
    # Cached on disk from a previous run — nothing to do.
    if os.path.exists(model_path):
        return model_path
    print("📥 Downloading your trained model...")
    try:
        urllib.request.urlretrieve(model_url, model_path)
    except Exception as e:
        print(f"❌ Failed to download model: {e}")
        return None
    print("✅ Model downloaded successfully!")
    return model_path
def load_model():
    """Load the trained Attention U-Net once and cache it in the global `model`.

    Downloads the checkpoint if needed, restores the weights, and puts the
    network in eval mode.

    Returns:
        The cached nn.Module, or None when downloading/loading fails.
    """
    global model
    # Already loaded on a previous call — reuse the singleton.
    if model is not None:
        return model
    try:
        print("🔄 Loading your trained Attention U-Net model...")
        model_path = download_model()
        if model_path is None:
            return None
        net = AttentionUNET(in_channels=1, out_channels=1).to(device)
        # weights_only=True restricts unpickling to tensors (safer load).
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        net.load_state_dict(checkpoint["state_dict"])
        net.eval()
        model = net
        print("✅ Your Attention U-Net model loaded successfully!")
    except Exception as e:
        print(f"❌ Error loading your model: {e}")
        model = None
    return model
def preprocess_image(image):
    """Convert a PIL image to the model's input tensor.

    Grayscales, resizes to 256x256, converts to a [0, 1] float tensor, and
    adds a batch dimension -> shape (1, 1, 256, 256).
    """
    gray = image if image.mode == 'L' else image.convert('L')
    # Same transform pipeline used at validation/test time during training.
    pipeline = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    return pipeline(gray).unsqueeze(0)
def post_process_mask(pred_mask_np):
    """Threshold a probability map and clean it up morphologically (Novelty 1).

    Opening removes small speckle noise; closing fills small gaps.

    Args:
        pred_mask_np: 2-D float array of per-pixel probabilities.

    Returns:
        uint8 binary mask (0/1) of the same shape.
    """
    mask = (pred_mask_np > 0.5).astype(np.uint8)
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    # Open first (denoise), then close (fill holes), with the same kernel.
    for operation in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        mask = cv2.morphologyEx(mask, operation, ellipse)
    return mask
def test_time_augmentation(input_tensor, model):
    """Average sigmoid predictions over flip augmentations (Novelty 2).

    Runs the model on the original image plus horizontal and vertical flips,
    un-flips each prediction back into the original frame, and averages the
    probabilities.

    Returns:
        2-D numpy array of averaged probabilities.
    """
    # (forward transform, inverse transform) pairs; flips are self-inverse.
    variants = [
        (lambda t: t, lambda t: t),
        (TF.hflip, TF.hflip),
        (TF.vflip, TF.vflip),
    ]
    predictions = []
    for augment, restore in variants:
        logits, _ = model(augment(input_tensor))
        predictions.append(torch.sigmoid(restore(logits)))
    averaged = torch.mean(torch.stack(predictions), dim=0)
    return averaged.squeeze().cpu().numpy()
def generate_attention_heatmap(attention_maps, size=(256, 256)):
    """Fuse multi-level attention maps into one RGB heatmap (Novelty 3).

    Resizes every map to `size`, averages them, min-max normalizes, and
    colorizes with matplotlib's 'hot' colormap.

    Returns:
        uint8 RGB array of shape (size[0], size[1], 3).
    """
    resized = [TF.resize(attn, size) for attn in attention_maps]
    fused = torch.cat(resized).mean(dim=0).squeeze().cpu().numpy()
    # Normalize to [0, 1]; the epsilon guards against a flat map.
    low, high = fused.min(), fused.max()
    normalized = (fused - low) / (high - low + 1e-8)
    colorized = plt.cm.hot(normalized)[:, :, :3] * 255
    return colorized.astype(np.uint8)
def download_dataset():
    """Download and extract the Kaggle brain-tumor dataset if not present.

    Returns:
        True when the dataset directory is available, False on failure.
    """
    if os.path.exists(DATASET_PATH):
        print("✅ Dataset already exists!")
        return True
    print("📥 Downloading brain tumor dataset...")
    try:
        path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
        print(f"Dataset downloaded to: {path}")
        # NOTE(review): only zip archives are extracted into DATASET_PATH;
        # if kagglehub delivers already-unzipped files they stay under
        # `path` and DATASET_PATH is never created — confirm the layout.
        archives = [name for name in os.listdir(path) if name.endswith('.zip')]
        for archive in archives:
            with zipfile.ZipFile(os.path.join(path, archive), 'r') as zip_ref:
                zip_ref.extractall(DATASET_PATH)
        print("✅ Dataset extracted!")
        return True
    except Exception as e:
        print(f"❌ Failed to download dataset: {e}")
        return False
def get_random_sample():
    """Pick a random (image, mask) pair from the local dataset.

    Returns:
        (PIL image, PIL mask), or (None, None) when the dataset or the
        matching files are unavailable.
    """
    if not os.path.exists(DATASET_PATH):
        if not download_dataset():
            return None, None
    images_path = os.path.join(DATASET_PATH, "images")
    masks_path = os.path.join(DATASET_PATH, "masks")
    # ROBUSTNESS FIX: guard against a dataset layout without images/masks
    # subdirectories — os.listdir on a missing path raises FileNotFoundError
    # and would crash the Gradio handler.
    if not os.path.isdir(images_path) or not os.path.isdir(masks_path):
        return None, None
    image_files = [f for f in os.listdir(images_path) if f.endswith(('.png', '.jpg'))]
    if not image_files:
        return None, None
    random_file = random.choice(image_files)
    img_path = os.path.join(images_path, random_file)
    # Masks are assumed to share the image's filename — TODO confirm.
    mask_path = os.path.join(masks_path, random_file)
    if not os.path.exists(mask_path):
        return None, None
    return Image.open(img_path), Image.open(mask_path)
def predict_tumor(image, use_tta=True, show_attention=True, is_dataset_sample=False, ground_truth=None):
    """Segment a brain-tumor image and build the result figure + statistics.

    Args:
        image: input PIL image (any mode; grayscaled internally).
        use_tta: average predictions over flip augmentations when True.
        show_attention: add an attention-heatmap panel when True.
        is_dataset_sample: True when `ground_truth` comes from the dataset.
        ground_truth: optional PIL mask used for IoU/Dice scoring.

    Returns:
        (PIL image of the rendered figure, markdown analysis text), or
        (None, error message) on failure.
    """
    # BUG FIX: the original called load_your_attention_model(), which is not
    # defined anywhere in this file; the loader is named load_model().
    current_model = load_model()
    if current_model is None:
        return None, "Failed to load your trained model."
    if image is None:
        return None, "Please upload an image first."
    try:
        print("Processing image...")
        input_tensor = preprocess_image(image).to(device)
        # BUG FIX: attn_maps was left undefined on the TTA path, raising a
        # NameError whenever use_tta and show_attention were both True.
        attn_maps = None
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            if use_tta:
                pred_np = test_time_augmentation(input_tensor, current_model)
                if show_attention:
                    # TTA averages probabilities only; run one plain forward
                    # pass to recover attention maps for visualization.
                    _, attn_maps = current_model(input_tensor)
            else:
                pred, maps = current_model(input_tensor)
                pred_np = torch.sigmoid(pred).squeeze().cpu().numpy()
                if show_attention:
                    attn_maps = maps
        # Threshold + morphological cleanup.
        binary_mask = post_process_mask(pred_np)
        attention_heatmap = None
        if show_attention and attn_maps:
            attention_heatmap = generate_attention_heatmap(attn_maps)
        has_gt = is_dataset_sample and ground_truth is not None
        n_panels = 3 + int(show_attention) + int(has_gt)
        fig, axes = plt.subplots(1, n_panels, figsize=(20, 5))
        fig.suptitle('Brain Tumor Segmentation Results', fontsize=16)
        # Panel 0: original image.
        original_np = np.array(image.resize((256, 256)))
        axes[0].imshow(original_np, cmap='gray')
        axes[0].set_title('Original Image')
        axes[0].axis('off')
        # Panel 1: predicted binary mask.
        axes[1].imshow(binary_mask * 255, cmap='gray')
        axes[1].set_title('Predicted Mask')
        axes[1].axis('off')
        # Panel 2: overlay.
        # BUG FIX: the original blended a 2-D grayscale array against a 3-D
        # RGB overlay (cv2.addWeighted shape mismatch) and, for RGB inputs,
        # painted the mask directly onto the same array it then blended with.
        base_rgb = (cv2.cvtColor(original_np, cv2.COLOR_GRAY2RGB)
                    if original_np.ndim == 2 else original_np)
        overlay = base_rgb.copy()
        overlay[binary_mask == 1] = [255, 0, 0]  # red tumor region
        overlay = cv2.addWeighted(base_rgb, 0.7, overlay, 0.3, 0)
        axes[2].imshow(overlay)
        axes[2].set_title('Overlay')
        axes[2].axis('off')
        col = 3
        if show_attention and attention_heatmap is not None:
            axes[col].imshow(attention_heatmap)
            axes[col].set_title('Attention Heatmap')
            axes[col].axis('off')
            col += 1
        if has_gt:
            gt_np = np.array(ground_truth.resize((256, 256)))
            axes[col].imshow(gt_np, cmap='gray')
            axes[col].set_title('Ground Truth')
            axes[col].axis('off')
        plt.tight_layout()
        # Render the figure into an in-memory PNG and hand it back as PIL.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        buf.seek(0)
        plt.close()
        result_image = Image.open(buf)
        # Summary statistics.
        tumor_pixels = np.sum(binary_mask)
        total_pixels = binary_mask.size
        tumor_percentage = (tumor_pixels / total_pixels) * 100
        analysis_text = f"""
### Segmentation Statistics
- Tumor Area Percentage: {tumor_percentage:.2f}%
- Tumor Pixels: {tumor_pixels}
- Total Pixels: {total_pixels}
- TTA Used: {'Yes' if use_tta else 'No'}
- Attention Visualization: {'Yes' if show_attention else 'No'}
"""
        if has_gt:
            gt_np = np.array(ground_truth.resize((256, 256)))
            gt_bin = gt_np > 0
            intersection = np.logical_and(binary_mask, gt_bin).sum()
            union = np.logical_or(binary_mask, gt_bin).sum()
            iou = intersection / (union + 1e-8)
            dice = (2 * intersection) / (binary_mask.sum() + gt_bin.sum() + 1e-8)
            analysis_text += f"""
### Comparison with Ground Truth
- IoU Score: {iou:.4f}
- Dice Score: {dice:.4f}
"""
        return result_image, analysis_text
    except Exception as e:
        return None, f"Error: {str(e)}"
def test_random_sample():
    """Run the full pipeline on one random dataset sample with ground truth."""
    sample_image, sample_mask = get_random_sample()
    if sample_image is None:
        return None, "Failed to load dataset sample. Please download dataset first."
    return predict_tumor(
        sample_image,
        use_tta=True,
        show_attention=True,
        is_dataset_sample=True,
        ground_truth=sample_mask,
    )
# Custom CSS for professional, minimalist look.
# Injected into the UI via gr.Blocks(css=...).
css = """
body, .gradio-container { font-family: 'Arial', sans-serif; color: #333; }
h1, h2, h3, h4 { color: #2c3e50; font-weight: 500; }
.button { background-color: #3498db; color: white; border: none; border-radius: 4px; padding: 10px 20px; font-size: 16px; cursor: pointer; }
.button:hover { background-color: #2980b9; }
.card { border: 1px solid #e0e0e0; border-radius: 8px; padding: 20px; background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
"""
# Build the Gradio UI: image input and action buttons on the left,
# result figure and textual analysis on the right.
with gr.Blocks(css=css, title="Brain Tumor Segmentation") as app:
    gr.Markdown("# Brain Tumor Segmentation Using Attention U-Net")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input")
            image_input = gr.Image(label="Upload Image", type="pil")
            with gr.Row():
                predict_btn = gr.Button("Predict")
                random_btn = gr.Button("Test Random Sample")
                download_btn = gr.Button("Download Dataset")
        with gr.Column(scale=2):
            gr.Markdown("### Output")
            output_image = gr.Image(label="Result")
            analysis_output = gr.Textbox(label="Analysis", lines=10)
    # Event handlers.
    # predict_tumor's remaining parameters fall back to their defaults
    # (use_tta=True, show_attention=True, no ground truth).
    predict_btn.click(
        fn=predict_tumor,
        inputs=[image_input],
        outputs=[output_image, analysis_output]
    )
    random_btn.click(
        fn=test_random_sample,
        inputs=[],
        outputs=[output_image, analysis_output]
    )
    # NOTE(review): download_dataset returns a bool that lands in a Textbox
    # constructed inline here — confirm it renders where intended.
    download_btn.click(
        fn=download_dataset,
        inputs=[],
        outputs=gr.Textbox(value="Dataset download status...")
    )

# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    app.launch()