# Source: Hugging Face Space ArchCoder/the-op-segmenter — app.py (commit fa07b35).
# (Hugging Face file-viewer chrome removed so this file is valid Python.)
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import io
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import random
import os
import zipfile
import urllib.request
import kagglehub
# Run inference on GPU when one is available; model weights are moved here at load time.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Lazily-initialized model singleton, populated by load_your_attention_model().
model = None
# Your Attention U-Net classes (from your code)
class DoubleConv(nn.Module):
    """Two successive 3x3 Conv -> BatchNorm -> ReLU stages (spatial size preserved)."""

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        # Run both conv stages in one sequential pass.
        return self.conv(x)
class AttentionBlock(nn.Module):
    """Additive attention gate (Attention U-Net): scales skip features by a learned mask."""

    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        # 1x1 projections of gating signal and skip features into a shared space.
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        # Collapse to a single-channel attention coefficient map in [0, 1].
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # Fuse the two projections, squash to [0, 1], then gate the skip features.
        fused = self.relu(self.W_g(g) + self.W_x(x))
        coeff = self.psi(fused)
        # Also return the coefficient map so callers can visualize attention.
        return x * coeff, coeff
class AttentionUNET(nn.Module):
    """Attention U-Net for single-channel image segmentation.

    forward() returns (logits, attention_maps): raw per-pixel logits plus the
    attention coefficient map produced at each decoder level (for visualization).
    """

    def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
        super(AttentionUNET, self).__init__()
        self.out_channels = out_channels
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.attentions = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per resolution level, channel count follows `features`.
        ch = in_channels
        for feature in features:
            self.downs.append(DoubleConv(ch, feature))
            ch = feature

        # Bottleneck doubles the deepest channel count.
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder: alternating [transpose-conv upsample, DoubleConv] pairs in self.ups,
        # with one attention gate per level in self.attentions.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skips = []
        attention_maps = []

        # Encoder pass, remembering each pre-pool activation as a skip connection.
        for down in self.downs:
            x = down(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skips.reverse()  # deepest skip first, matching decoder order

        # Decoder pass: each level consumes one [upsample, conv] pair.
        for level, idx in enumerate(range(0, len(self.ups), 2)):
            x = self.ups[idx](x)
            skip = skips[level]
            # Odd input sizes can leave x one pixel short after upsampling;
            # resize so it matches the skip exactly before gating/concat.
            if x.shape != skip.shape:
                x = TF.resize(x, size=skip.shape[2:])
            gated_skip, att = self.attentions[level](x, skip)
            attention_maps.append(att)
            x = self.ups[idx + 1](torch.cat((gated_skip, x), dim=1))

        return self.final_conv(x), attention_maps
def download_model():
    """Fetch the trained checkpoint from the HuggingFace Space unless already cached.

    Returns:
        The local checkpoint path, or None if the download fails.
    """
    model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
    model_path = "best_attention_model.pth.tar"
    # Cached copy on disk — nothing to do.
    if os.path.exists(model_path):
        return model_path
    print("📥 Downloading your trained model...")
    try:
        urllib.request.urlretrieve(model_url, model_path)
    except Exception as e:
        print(f"❌ Failed to download model: {e}")
        return None
    print("✅ Model downloaded successfully!")
    return model_path
def load_your_attention_model():
    """Load the trained Attention U-Net once and cache it in the global `model`.

    Returns:
        The loaded, eval-mode model, or None when downloading/loading fails.
    """
    global model
    # Already loaded — reuse the cached instance.
    if model is not None:
        return model
    try:
        print("🔄 Loading your trained Attention U-Net model...")
        # Ensure the checkpoint file is present locally.
        model_path = download_model()
        if model_path is None:
            return None
        # Build the architecture, then restore the trained weights.
        net = AttentionUNET(in_channels=1, out_channels=1).to(device)
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        net.load_state_dict(checkpoint["state_dict"])
        net.eval()
        model = net
        print("✅ Your Attention U-Net model loaded successfully!")
    except Exception as e:
        print(f"❌ Error loading your model: {e}")
        model = None
    return model
def download_dataset():
    """Download the brain-tumor dataset via kagglehub and return (images_dir, masks_dir).

    Bug fix: current kagglehub versions return a local *directory* path from
    dataset_download(); the original code always passed that path to
    zipfile.ZipFile, which raises on a directory. Both cases are now handled.
    """
    dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
    if os.path.isdir(dataset_path):
        # Already extracted by kagglehub — use the directory directly.
        root = dataset_path
    else:
        # Archive form: extract once into a local folder and reuse it afterwards.
        root = "brain_tumor_dataset"
        if not os.path.exists(root):
            with zipfile.ZipFile(dataset_path, 'r') as zip_ref:
                zip_ref.extractall(root)
    images_path = os.path.join(root, 'images')
    masks_path = os.path.join(root, 'masks')
    return images_path, masks_path
def load_random_sample():
    """Pick a random image (and its mask, when present) from the dataset.

    Returns:
        (image, mask, filename) as grayscale PIL images; mask is None when no
        matching mask file exists. Returns (None, None, <message>) when the
        dataset contains no images.
    """
    images_path, masks_path = download_dataset()
    candidates = [f for f in os.listdir(images_path) if f.endswith(('.png', '.jpg'))]
    if not candidates:
        return None, None, "No images found in dataset"
    chosen = random.choice(candidates)
    image = Image.open(os.path.join(images_path, chosen)).convert("L")
    # Masks share the image's filename; tolerate a missing one.
    mask_file = os.path.join(masks_path, chosen)
    mask = Image.open(mask_file).convert("L") if os.path.exists(mask_file) else None
    return image, mask, chosen
def preprocess_for_your_model(image):
    """Convert a PIL image into the (1, 1, 256, 256) float tensor the model expects."""
    # The network was trained on single-channel input.
    if image.mode != 'L':
        image = image.convert('L')
    pipeline = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    # unsqueeze adds the batch dimension.
    return pipeline(image).unsqueeze(0)
def apply_tta(model, input_tensor):
    """Test-Time Augmentation: average sigmoid predictions over flips/rotations.

    Each augmentation is paired explicitly with its inverse so every prediction
    is mapped back to the original orientation before averaging. (The original
    implementation matched augmentations back to their inverses by comparing
    lambda identities in an if/elif chain — fragile and easy to break when the
    list is edited.) Inference runs under torch.no_grad(): no autograd graph is
    needed, which saves memory without changing the outputs.

    Args:
        model: network returning (logits, attention_maps).
        input_tensor: batched image tensor to segment.

    Returns:
        The augmentation-averaged probability map (same shape as one prediction).
    """
    # (forward transform, inverse transform) pairs.
    aug_pairs = [
        (lambda x: x, lambda x: x),                              # identity
        (lambda x: TF.rotate(x, 90), lambda x: TF.rotate(x, -90)),
        (lambda x: TF.rotate(x, -90), lambda x: TF.rotate(x, 90)),
        (TF.hflip, TF.hflip),                                    # flips are self-inverse
        (TF.vflip, TF.vflip),
    ]
    predictions = []
    with torch.no_grad():  # inference only — skip gradient bookkeeping
        for forward, inverse in aug_pairs:
            # model(...)[0] is the logits head; sigmoid maps it to probabilities.
            pred = torch.sigmoid(model(forward(input_tensor))[0])
            predictions.append(inverse(pred))
    # Average the aligned predictions.
    return torch.mean(torch.stack(predictions), dim=0)
def generate_attention_heatmap(attention_maps):
    """Combine per-level attention maps into one 256x256 RGB JET heatmap.

    Bug fixes vs. the original:
    - The decoder levels emit maps at *different* spatial resolutions, so
      torch.stack on the raw maps raised a shape-mismatch error; each map is
      now upsampled to 256x256 before averaging.
    - .detach() is required before .numpy() because the maps may carry an
      autograd graph (the caller runs the model outside no_grad).
    - cv2.applyColorMap returns BGR, but the result is displayed with
      matplotlib (RGB); convert so the JET colors render correctly.

    Returns:
        An RGB uint8 (256, 256, 3) array, or a zero (256, 256) array when no
        maps are provided.
    """
    if not attention_maps:
        return np.zeros((256, 256))
    # Bring every (N, 1, H, W) map to a common 256x256 grid, then average levels.
    resized = [
        torch.nn.functional.interpolate(att, size=(256, 256), mode='bilinear', align_corners=False)
        for att in attention_maps
    ]
    combined_att = torch.mean(torch.stack(resized), dim=0).squeeze().detach().cpu().numpy()
    # Normalize to [0, 1]; epsilon guards against a constant map.
    combined_att = (combined_att - combined_att.min()) / (combined_att.max() - combined_att.min() + 1e-8)
    heatmap = cv2.applyColorMap((combined_att * 255).astype(np.uint8), cv2.COLORMAP_JET)
    # OpenCV colormaps are BGR; matplotlib expects RGB.
    return cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
def predict_tumor(image, ground_truth=None, filename=None):
    """Segment one brain MRI and render a multi-panel analysis figure.

    Args:
        image: PIL image to segment (preprocessing converts it to grayscale).
        ground_truth: optional PIL mask; when given, an IoU score and a
            prediction-vs-GT overlay are added to the figure.
        filename: optional name shown in the markdown report.

    Returns:
        (result_image, analysis_text): the rendered matplotlib figure as a PIL
        image plus a markdown report string, or (None, <error message>) when
        the model is unavailable, no image was given, or an exception occurs.
    """
    current_model = load_your_attention_model()
    if current_model is None:
        return None, "Failed to load your trained model."
    if image is None:
        return None, "Please upload or load an image first."
    try:
        # Preprocess to the (1, 1, 256, 256) tensor the model expects.
        input_tensor = preprocess_for_your_model(image).to(device)
        # Augmentation-averaged probability map.
        avg_pred = apply_tta(current_model, input_tensor)
        # Threshold probabilities into a 256x256 binary mask.
        binary_mask = (avg_pred > 0.5).float().squeeze().cpu().numpy()
        # Morphological open (remove speckle) then close (fill small holes).
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5))
        binary_mask = cv2.morphologyEx(binary_mask.astype(np.uint8), cv2.MORPH_OPEN, kernel)
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
        # Second forward pass, solely to collect the attention maps.
        # NOTE(review): this re-runs the model outside no_grad — consider
        # reusing activations from the TTA pass if performance matters.
        _, attention_maps = current_model(input_tensor)
        att_heatmap = generate_attention_heatmap(attention_maps)
        # 2x3 panel figure: originals/attention/prediction on top,
        # overlays and metrics on the bottom row.
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle('Brain Tumor Segmentation Analysis', fontsize=20)
        # Panel (0,0): original input.
        axes[0,0].imshow(image, cmap='gray')
        axes[0,0].set_title('Original Image')
        axes[0,0].axis('off')
        # Panel (0,1): attention heatmap blended over the input.
        axes[0,1].imshow(np.array(image), cmap='gray')
        axes[0,1].imshow(att_heatmap, alpha=0.5)
        axes[0,1].set_title('Attention Heatmap')
        axes[0,1].axis('off')
        # Panel (0,2): raw predicted mask.
        axes[0,2].imshow(binary_mask, cmap='gray')
        axes[0,2].set_title('Predicted Mask')
        axes[0,2].axis('off')
        # Bottom row depends on whether a ground-truth mask was supplied.
        if ground_truth is not None:
            gt_np = np.array(ground_truth.resize((256, 256)))
            axes[1,0].imshow(gt_np, cmap='gray')
            axes[1,0].set_title('Ground Truth Mask')
            axes[1,0].axis('off')
            # Overlay: prediction in green, ground truth painted over in red.
            # NOTE(review): indexing with the 256x256 mask assumes `image` is
            # already 256x256 — an uploaded image of another size would fail
            # here; confirm against the upload path.
            overlay = np.array(image.convert('RGB'))
            overlay[binary_mask > 0] = [0, 255, 0]  # Green for prediction
            overlay[gt_np > 0] = [255, 0, 0]  # Red for ground truth
            axes[1,1].imshow(overlay)
            axes[1,1].set_title('Prediction (Green) vs GT (Red)')
            axes[1,1].axis('off')
            # IoU = |pred ∩ gt| / |pred ∪ gt| (epsilon avoids divide-by-zero).
            intersection = np.sum(binary_mask * (gt_np > 0))
            union = np.sum(binary_mask) + np.sum(gt_np > 0) - intersection
            iou = intersection / (union + 1e-8)
            axes[1,2].text(0.1, 0.5, f'IoU Score: {iou:.4f}', fontsize=20)
            axes[1,2].axis('off')
        else:
            # Prediction-only overlay in red; remaining panels left blank.
            overlay = np.array(image.convert('RGB'))
            overlay[binary_mask > 0] = [255, 0, 0]
            axes[1,0].imshow(overlay)
            axes[1,0].set_title('Prediction Overlay')
            axes[1,0].axis('off')
            axes[1,1].axis('off')
            axes[1,2].axis('off')
        plt.tight_layout()
        # Render the figure into an in-memory PNG and hand it back as a PIL image.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
        buf.seek(0)
        plt.close()
        result_image = Image.open(buf)
        # Summary statistics over the 256x256 mask.
        tumor_pixels = np.sum(binary_mask)
        total_pixels = binary_mask.size
        tumor_percentage = (tumor_pixels / total_pixels) * 100
        analysis_text = f"""
## Brain Tumor Segmentation Results
### Detection Summary
- Tumor Percentage: {tumor_percentage:.2f}%
- Tumor Pixels: {tumor_pixels}
- File: {filename if filename else 'Uploaded Image'}
### Model Information
- Your Attention U-Net Model
- Test-Time Augmentation: Applied
- Attention Visualization: Included
"""
        if ground_truth is not None:
            analysis_text += f"\n- IoU with Ground Truth: {iou:.4f}"
        return result_image, analysis_text
    except Exception as e:
        # Surface the failure as text in the UI rather than crashing the app.
        return None, f"Error: {str(e)}"
def clear_all():
    """Reset all UI components to their initial empty state."""
    placeholder = "Upload or load an image for analysis"
    return (None, None, None, placeholder)
# Professional CSS (white, clean, professional)
css = """
.gradio-container {
max-width: 1400px !important;
margin: auto !important;
background-color: white !important;
font-family: 'Arial', sans-serif !important;
}
h1, h2, h3, h4 {
color: #333333 !important;
}
button {
background-color: #f0f0f0 !important;
color: #333333 !important;
border: 1px solid #dddddd !important;
border-radius: 4px !important;
}
button.primary {
background-color: #007bff !important;
color: white !important;
}
.output-image {
border: 1px solid #dddddd !important;
border-radius: 4px !important;
}
.markdown {
line-height: 1.6 !important;
color: #555555 !important;
}
"""
# Create professional Gradio interface: two-column layout with inputs on the
# left and the rendered analysis on the right.
with gr.Blocks(css=css, title="Brain Tumor Segmentation Application") as app:
    gr.Markdown("""
# Brain Tumor Segmentation Using Attention U-Net
A professional tool for medical image analysis
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input Selection")
            # Upload or capture the MRI to analyze.
            image_input = gr.Image(
                label="Upload Brain MRI",
                type="pil",
                sources=["upload", "webcam"],
                height=300
            )
            load_random_btn = gr.Button("Load Random Sample from Dataset", variant="primary")
            with gr.Row():
                analyze_btn = gr.Button("Analyze Image", variant="primary", scale=2)
                clear_btn = gr.Button("Clear", scale=1)
        with gr.Column(scale=2):
            gr.Markdown("### Analysis Results")
            # Rendered matplotlib figure returned by predict_tumor.
            output_image = gr.Image(
                label="Segmentation Results",
                type="pil",
                height=400
            )
            # Markdown report (tumor stats, IoU when ground truth is known).
            analysis_output = gr.Markdown(
                value="Select an input method to begin analysis."
            )
    # Hidden state carrying the sample's ground-truth mask and filename between events.
    ground_truth_state = gr.State()
    filename_state = gr.State()
    # Event handlers.
    analyze_btn.click(
        fn=predict_tumor,
        inputs=[image_input, ground_truth_state, filename_state],
        outputs=[output_image, analysis_output]
    )
    # NOTE(review): load_random_sample returns 3 values but 4 outputs are
    # wired here — verify against Gradio's output-matching behavior; the
    # analysis_output slot appears to have no corresponding return value.
    load_random_btn.click(
        fn=load_random_sample,
        inputs=[],
        outputs=[image_input, ground_truth_state, filename_state, analysis_output]
    )
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[image_input, output_image, ground_truth_state, analysis_output]
    )
if __name__ == "__main__":
    print("Starting Brain Tumor Segmentation Application...")
    # Blocking call: serves the Gradio app until interrupted.
    app.launch()