# Hugging Face file-viewer header (author: ArchCoder; last commit c5d3869
# "Update app.py", verified; file size 15.1 kB).
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import io
from torchvision import transforms
import torchvision.transforms.functional as TF
import urllib.request
import os
# Inference device: the GPU when PyTorch can see CUDA, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Cached network; stays None until load_your_attention_model() loads it successfully.
model = None
# Define your Attention U-Net architecture (from your training code)
class DoubleConv(nn.Module):
    """Two back-to-back ``3x3 conv -> BatchNorm -> ReLU`` stages.

    Spatial size is preserved (kernel 3, stride 1, padding 1). Channels go
    ``in_channels -> out_channels -> out_channels``. The convolutions use
    ``bias=False`` and each feeds straight into a BatchNorm layer.

    All six layers live in ``self.conv`` as one ``nn.Sequential`` (indices 0-5).
    That layout determines the state_dict keys used when checkpoints are loaded.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_input in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_input, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
class AttentionBlock(nn.Module):
    """Additive attention gate.

    ``g`` and ``x`` are each projected to ``F_int`` channels by a 1x1 conv +
    BatchNorm. The projections are summed and passed through ReLU. A 1x1 conv +
    BatchNorm + sigmoid then reduces the result to a single-channel map ``psi``.
    The output is ``x * psi`` (psi broadcast over x's channels).

    Args:
        F_g: channel count of the gating input ``g``.
        F_l: channel count of the gated input ``x``.
        F_int: channel count of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def conv1x1_bn(n_in, n_out):
            # 1x1 projection (with bias) followed by BatchNorm, as a layer list.
            return [
                nn.Conv2d(n_in, n_out, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(n_out),
            ]

        # Attribute names and Sequential indices define the checkpoint keys.
        self.W_g = nn.Sequential(*conv1x1_bn(F_g, F_int))
        self.W_x = nn.Sequential(*conv1x1_bn(F_l, F_int))
        self.psi = nn.Sequential(*conv1x1_bn(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        combined = self.W_g(g) + self.W_x(x)
        attention_map = self.psi(self.relu(combined))
        return x * attention_map
class AttentionUNET(nn.Module):
    """U-Net with attention gates on the decoder path.

    Encoder: one DoubleConv per entry of ``features``, with 2x2 max-pooling after
    each. Bottleneck: DoubleConv to ``2 * features[-1]`` channels. Decoder, deepest
    level first: a 2x2 transposed conv, an AttentionBlock and a DoubleConv per
    level. A final 1x1 conv maps to ``out_channels`` raw logits; no activation is
    applied here.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output logits.
        features: encoder widths, shallowest first. Any iterable of ints is
            accepted. The default is an immutable tuple, so it cannot leak state
            between instances.

    Attribute names (``ups``, ``downs``, ``attentions``, ``pool``, ``bottleneck``,
    ``final_conv``) and construction order are unchanged. This keeps checkpoint
    state_dict keys and parameter-initialisation order identical.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(32, 64, 128, 256)):
        super().__init__()
        features = list(features)
        self.out_channels = out_channels
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.attentions = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder: `ups` alternates [ConvTranspose2d, DoubleConv] per level, so it
        # holds 2 * len(features) modules. `attentions` holds one gate per level.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Deepest skip first, matching the reversed() order used to build the decoder.
        for level, skip_connection in enumerate(reversed(skip_connections)):
            x = self.ups[2 * level](x)
            if x.shape != skip_connection.shape:
                # Max-pooling floors odd sizes; resize back to the skip's H x W.
                x = TF.resize(x, size=skip_connection.shape[2:])
            # NOTE(review): the gate receives the encoder skip as `g` and the
            # upsampled decoder tensor as `x`. What gets concatenated is therefore
            # (gated decoder features, decoder features); the raw skip itself is
            # never concatenated. Trained weights depend on this wiring, so it is
            # kept as-is. Confirm it is intended before changing it.
            gated = self.attentions[level](skip_connection, x)
            x = self.ups[2 * level + 1](torch.cat((gated, x), dim=1))

        return self.final_conv(x)
def download_model(
    model_url="https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar",
    model_path="best_attention_model.pth.tar",
):
    """Download the trained checkpoint unless it is already cached locally.

    The file is fetched to ``<model_path>.part`` and renamed to ``model_path``
    only after the transfer finishes. A failed or interrupted download therefore
    never leaves a truncated file that later runs would treat as a valid cache.

    Args:
        model_url: URL of the checkpoint (any scheme ``urllib`` supports).
        model_path: local destination; an existing file there is reused as-is.

    Returns:
        ``model_path`` on success or when the file already exists, otherwise
        ``None`` (the error is printed).
    """
    if os.path.exists(model_path):
        print("✅ Model already exists!")
        return model_path

    print("📥 Downloading your trained model...")
    partial_path = f"{model_path}.part"
    try:
        urllib.request.urlretrieve(model_url, partial_path)
        # Rename only once the download is complete.
        os.replace(partial_path, model_path)
    except Exception as e:
        print(f"❌ Failed to download model: {e}")
        # Best-effort cleanup of any partial download; a missing file is fine.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        return None

    print("✅ Model downloaded successfully!")
    return model_path
def load_your_attention_model():
    """Return the trained Attention U-Net, loading it on first use.

    The network is cached in the module-level ``model`` global. On the first
    call:

    - the checkpoint is obtained via ``download_model()``;
    - its ``"state_dict"`` entry is loaded into a fresh
      ``AttentionUNET(in_channels=1, out_channels=1)`` on ``device``;
    - the network is switched to eval mode.

    The global is assigned only after the weights load successfully, so no
    caller can receive a randomly-initialised network.

    Returns:
        The loaded model, or ``None`` if downloading or loading failed. The error
        is printed, and the next call tries again.
    """
    global model
    if model is not None:
        return model

    try:
        print("🔄 Loading your trained Attention U-Net model...")
        model_path = download_model()
        if model_path is None:
            return None
        candidate = AttentionUNET(in_channels=1, out_channels=1).to(device)
        # weights_only=True restricts unpickling to tensors and plain containers.
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        candidate.load_state_dict(checkpoint["state_dict"])
        candidate.eval()
    except Exception as e:
        print(f"❌ Error loading your model: {e}")
        return None

    model = candidate
    print("✅ Your Attention U-Net model loaded successfully!")
    return model
def preprocess_for_your_model(image):
    """Convert a PIL image into a model input batch.

    Applies the same steps as the training notebook's validation/test transform
    (per the original comments): grayscale ("L" mode),
    ``transforms.Resize((256, 256))``, then ``transforms.ToTensor()``.

    Args:
        image: a PIL image in any mode.

    Returns:
        A float tensor of shape (1, 1, 256, 256): a batch of one single-channel image.
    """
    grayscale = image if image.mode == 'L' else image.convert('L')
    pipeline = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    single = pipeline(grayscale)
    return single.unsqueeze(0)  # prepend the batch dimension
def _render_result_figure(original_np, pred_mask_np, inv_pred_mask_np, tumor_only):
    """Draw the four result panels side by side and return them as a PIL image."""
    # Use an object-oriented Figure instead of pyplot. There is no shared
    # "current figure" across concurrent requests, and nothing stays registered
    # with pyplot if drawing raises.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(20, 5), tight_layout=True)
    # No emoji here: matplotlib's default fonts have no emoji glyphs.
    fig.suptitle('Your Attention U-Net Results', fontsize=16, fontweight='bold')
    axes = fig.subplots(1, 4)
    panels = [
        ("Original Image", original_np, 'gray'),
        ("Tumor Segmentation", pred_mask_np * 255, 'hot'),
        ("Inverted Mask", inv_pred_mask_np, 'gray'),
        ("Tumor Only", tumor_only, 'gray'),
    ]
    for ax, (title, data, cmap) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.axis('off')

    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
    buf.seek(0)
    result = Image.open(buf)
    result.load()  # decode eagerly so the image no longer depends on the buffer
    return result


def _format_analysis(tumor_pixels, tumor_percentage, max_confidence, mean_confidence, device_label):
    """Build the Markdown report shown next to the result figure."""
    status = '🔴 TUMOR DETECTED' if tumor_pixels > 50 else '🟢 NO SIGNIFICANT TUMOR'
    lines = [
        "",
        "## 🧠 Your Attention U-Net Analysis Results",
        "### 📊 Detection Summary:",
        f"- **Status**: {status}",
        # The denominator is the whole 256x256 mask, not a segmented brain region.
        f"- **Tumor Area**: {tumor_percentage:.2f}% of image area",
        f"- **Tumor Pixels**: {tumor_pixels:,} pixels",
        f"- **Max Confidence**: {max_confidence:.4f}",
        f"- **Mean Confidence**: {mean_confidence:.4f}",
        "### 🔬 Your Model Information:",
        "- **Architecture**: YOUR trained Attention U-Net",
        "- **Training Performance**: Dice: 0.8420, IoU: 0.7297",
        "- **Input**: Grayscale (single channel)",
        "- **Output**: Binary segmentation mask",
        f"- **Device**: {device_label}",
        "### 🎯 Model Performance:",
        "- **Training Accuracy**: 98.90%",
        "- **Best Dice Score**: 0.8420",
        "- **Best IoU Score**: 0.7297",
        "- **Training Dataset**: Brain tumor segmentation dataset",
        "### 📈 Processing Details:",
        "- **Preprocessing**: Resize(256×256) + ToTensor (your exact method)",
        "- **Threshold**: 0.5 (sigmoid > 0.5)",
        "- **Architecture**: Attention gates + Skip connections",
        "- **Features**: [32, 64, 128, 256] channels",
        "### ⚠️ Medical Disclaimer:",
        "This is YOUR trained AI model for **research and educational purposes only**.",
        "Results should be validated by medical professionals. Not for clinical diagnosis.",
        "### 🏆 Model Quality:",
        "✅ These results were produced by your own trained Attention U-Net model.",
        "",
    ]
    return "\n".join(lines)


def predict_tumor(image):
    """Segment a brain MRI with the trained model and build the two UI outputs.

    Args:
        image: PIL image from the Gradio ``Image`` input, or ``None`` when nothing
            has been uploaded.

    Returns:
        ``(result_image, markdown_report)`` on success. ``(None, message)`` when no
        image was given, the model could not be loaded, or inference failed.
    """
    # Check the cheap precondition first so an empty submit does not trigger a
    # model download.
    if image is None:
        return None, "⚠️ Please upload an image first."
    current_model = load_your_attention_model()
    if current_model is None:
        return None, "❌ Failed to load your trained model."

    try:
        print("🧠 Processing with YOUR trained Attention U-Net...")
        input_tensor = preprocess_for_your_model(image).to(device)
        with torch.no_grad():
            pred_mask = torch.sigmoid(current_model(input_tensor))
        pred_mask_binary = (pred_mask > 0.5).float()

        pred_mask_np = pred_mask_binary.cpu().squeeze().numpy()
        # Bilinear matches torchvision Resize's default, so the panels show the
        # same pixels the model received.
        original_np = np.array(image.convert('L').resize((256, 256), Image.BILINEAR))
        is_tumor = pred_mask_np == 1
        inv_pred_mask_np = np.where(is_tumor, 0, 255)
        tumor_only = np.where(is_tumor, original_np, 255)

        result_image = _render_result_figure(original_np, pred_mask_np, inv_pred_mask_np, tumor_only)

        # The mask is float32 0/1; report an integer pixel count.
        tumor_pixels = int(pred_mask_np.sum())
        tumor_percentage = tumor_pixels / pred_mask_np.size * 100
        analysis_text = _format_analysis(
            tumor_pixels,
            tumor_percentage,
            pred_mask.max().item(),
            pred_mask.mean().item(),
            device.type.upper(),
        )
        print(f"✅ Your model analysis completed! Tumor area: {tumor_percentage:.2f}%")
        return result_image, analysis_text
    except Exception as e:
        error_msg = f"❌ Error with your model: {e}"
        print(error_msg)
        return None, error_msg
def clear_all():
    """Reset the UI to its empty state.

    Returns:
        A 3-tuple for ``outputs=[image_input, output_image, analysis_output]``:
        no input image, no result image, and the placeholder prompt text.
    """
    placeholder = "Upload a brain MRI image to test YOUR trained Attention U-Net model"
    return (None, None, placeholder)
# Custom CSS passed to gr.Blocks: caps the app width and styles the #title banner.
css = "\n".join([
    "",
    ".gradio-container {",
    "max-width: 1400px !important;",
    "margin: auto !important;",
    "}",
    "#title {",
    "text-align: center;",
    "background: linear-gradient(135deg, #8B5CF6 0%, #7C3AED 100%);",
    "color: white;",
    "padding: 30px;",
    "border-radius: 15px;",
    "margin-bottom: 25px;",
    "box-shadow: 0 8px 16px rgba(139, 92, 246, 0.3);",
    "}",
    "",
])
# Create Gradio interface for your model.
# Layout: title banner, then a row with an upload/controls column (left) and a
# results column (right), then a footer. Buttons are wired to predict_tumor and
# clear_all at the end.
with gr.Blocks(css=css, title="🧠 Your Attention U-Net Model", theme=gr.themes.Soft()) as app:
    gr.HTML("""
    <div id="title">
        <h1>🧠 YOUR Attention U-Net Model</h1>
        <p style="font-size: 18px; margin-top: 15px;">
            Using Your Own Trained Model • Dice: 0.8420 • IoU: 0.7297
        </p>
        <p style="font-size: 14px; margin-top: 10px; opacity: 0.9;">
            Loaded from: ArchCoder/the-op-segmenter HuggingFace Space
        </p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Upload Brain MRI")
            image_input = gr.Image(
                label="Brain MRI Scan",
                type="pil",
                sources=["upload", "webcam"],
                height=350,
            )
            with gr.Row():
                analyze_btn = gr.Button("🔍 Analyze with YOUR Model", variant="primary", scale=2, size="lg")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary", scale=1)
            gr.HTML("""
            <div style="margin-top: 20px; padding: 20px; background: linear-gradient(135deg, #F3E8FF 0%, #EDE9FE 100%); border-radius: 10px; border-left: 4px solid #8B5CF6;">
                <h4 style="color: #8B5CF6; margin-bottom: 15px;">🏆 Your Model Features:</h4>
                <ul style="margin: 10px 0; padding-left: 20px; line-height: 1.6;">
                    <li><strong>Personal Model:</strong> Your own trained Attention U-Net</li>
                    <li><strong>Proven Performance:</strong> 84.2% Dice Score, 72.97% IoU</li>
                    <li><strong>Attention Gates:</strong> Advanced feature selection</li>
                    <li><strong>Clean Output:</strong> Binary segmentation masks</li>
                    <li><strong>4-Panel View:</strong> Complete analysis like your Colab</li>
                </ul>
            </div>
            """)

        with gr.Column(scale=2):
            gr.Markdown("### 📊 Your Model Results")
            output_image = gr.Image(
                label="Your Attention U-Net Analysis",
                type="pil",
                height=500,
            )
            analysis_output = gr.Markdown(
                value="Upload a brain MRI image to test YOUR trained Attention U-Net model.",
                elem_id="analysis",
            )

    # Footer highlighting your model
    gr.HTML("""
    <div style="margin-top: 30px; padding: 25px; background-color: #F8FAFC; border-radius: 15px; border: 2px solid #8B5CF6;">
        <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 30px;">
            <div>
                <h4 style="color: #8B5CF6; margin-bottom: 15px;">🏆 Your Personal AI Model</h4>
                <p><strong>Architecture:</strong> Attention U-Net with skip connections</p>
                <p><strong>Performance:</strong> Dice: 0.8420, IoU: 0.7297, Accuracy: 98.90%</p>
                <p><strong>Training:</strong> Your own dataset-specific training</p>
                <p><strong>Features:</strong> [32, 64, 128, 256] channel progression</p>
            </div>
            <div>
                <h4 style="color: #DC2626; margin-bottom: 15px;">⚠️ Your Model Disclaimer</h4>
                <p style="color: #DC2626; font-weight: 600; line-height: 1.4;">
                    This is YOUR personally trained AI model for <strong>research purposes only</strong>.<br>
                    Results reflect your model's training performance.<br>
                    Always validate with medical professionals for any clinical application.
                </p>
            </div>
        </div>
        <hr style="margin: 20px 0; border: none; border-top: 2px solid #E5E7EB;">
        <p style="text-align: center; color: #6B7280; margin: 10px 0; font-weight: 600;">
            🚀 Your Personal Attention U-Net • Downloaded from HuggingFace • Research-Grade Performance
        </p>
    </div>
    """)

    # Event handlers
    analyze_btn.click(
        fn=predict_tumor,
        inputs=[image_input],
        outputs=[output_image, analysis_output],
        # Gradio 4 takes "full" / "minimal" / "hidden" here, not a bool.
        show_progress="full",
    )
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[image_input, output_image, analysis_output],
    )
if __name__ == "__main__":
    print("🚀 Starting YOUR Attention U-Net Model System...")
    print("🏆 Using your personally trained model")
    print("📥 Auto-downloading from HuggingFace...")
    # Warm the model cache before serving, so the message above is accurate and the
    # first request does not pay for the download. If this fails it returns None,
    # and predict_tumor will retry on the next request.
    load_your_attention_model()
    print("🎯 Expected performance: Dice 0.8420, IoU 0.7297")
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False,
    )