# Brain_Mri / app.py
# (Hugging Face Space page residue removed: "Gajendra5490's picture / Update app.py / 70f1082 verified")
import gradio as gr
import os
import nibabel as nib
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import cv2
import tempfile
import gradio as gr
import os
import nibabel as nib
import numpy as np
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import tempfile
import json
import io
from PIL import Image
import time
from tqdm import tqdm
from datetime import datetime, timezone
from pathlib import Path
# Define the model class (same as during training)
class ImprovedUNet3D(nn.Module):
    """Compact two-level 3D U-Net for brain-tumor segmentation.

    Input:  (N, in_channels, D, H, W) — the four MRI modalities stacked as channels.
    Output: (N, out_channels, D, H, W) raw logits (no activation applied here).

    Attribute names (enc1, pool1, ..., final_conv) must not change: they form
    the state-dict keys used when loading the trained checkpoint.
    """

    def __init__(self, in_channels=4, out_channels=2, base_filters=16):
        super(ImprovedUNet3D, self).__init__()
        # Encoder: two down-sampling stages.
        self.enc1 = self._make_layer(in_channels, base_filters)
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.enc2 = self._make_layer(base_filters, base_filters * 2)
        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        # Bottleneck.
        self.enc3 = self._make_layer(base_filters * 2, base_filters * 4)
        # Decoder: transpose convs followed by conv blocks on skip-concatenated features.
        self.upconv2 = nn.ConvTranspose3d(base_filters * 4, base_filters * 2, kernel_size=2, stride=2)
        self.dec2 = self._make_layer(base_filters * 4, base_filters * 2)
        self.upconv1 = nn.ConvTranspose3d(base_filters * 2, base_filters, kernel_size=2, stride=2)
        self.dec1 = self._make_layer(base_filters * 2, base_filters)
        self.dropout = nn.Dropout3d(0.3)
        self.final_conv = nn.Conv3d(base_filters, out_channels, kernel_size=1)
        self._initialize_weights()

    def _make_layer(self, in_channels, out_channels):
        """Two (Conv3d -> InstanceNorm3d -> LeakyReLU) stages."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU(inplace=True),
        )

    def _initialize_weights(self):
        """Kaiming init for convolutions; unit/zero affine params for norms."""
        for module in self.modules():
            if isinstance(module, nn.Conv3d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.InstanceNorm3d):
                if module.weight is not None:
                    nn.init.constant_(module.weight, 1)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    @staticmethod
    def _pad_to_match(upsampled, reference):
        """Zero-pad `upsampled` so its D/H/W match `reference` (handles odd sizes)."""
        d_gap = reference.size(2) - upsampled.size(2)
        h_gap = reference.size(3) - upsampled.size(3)
        w_gap = reference.size(4) - upsampled.size(4)
        # F.pad takes (left, right) pairs starting from the LAST dim: W, H, D.
        return F.pad(upsampled, [
            w_gap // 2, w_gap - w_gap // 2,
            h_gap // 2, h_gap - h_gap // 2,
            d_gap // 2, d_gap - d_gap // 2,
        ])

    def forward(self, x):
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        bottleneck = self.enc3(self.pool2(skip2))

        up = self._pad_to_match(self.upconv2(bottleneck), skip2)
        decoded = self.dec2(torch.cat([up, skip2], dim=1))

        up = self._pad_to_match(self.upconv1(decoded), skip1)
        decoded = self.dec1(torch.cat([up, skip1], dim=1))

        return self.final_conv(self.dropout(decoded))
# --- Load Trained Model ---
# Checkpoint produced by the training run; adjust the path for other deployments.
model_path = "brain_tumor_seg_final_dice_0.8814_epoch_99_complete_20250319_125101.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep `model` defined even on failure so downstream code hits a clear
# "model is None" condition instead of a NameError.
model = None
try:
    print(f"\nLoading model from: {model_path}")
    # The checkpoint is a dict holding the state dict plus training metadata.
    # NOTE(review): torch >= 2.6 flipped torch.load's weights_only default to
    # True; if loading fails there, pass weights_only=False (trusted file only).
    model_data = torch.load(model_path, map_location=device)
    # Rebuild the architecture, then load the trained weights into it.
    model = ImprovedUNet3D(in_channels=4, out_channels=2, base_filters=16)
    model.load_state_dict(model_data['model_state_dict'])
    model = model.to(device)
    model.eval()
    print("\nModel loaded successfully!")
    print(f"Model Details:")
    print(f"Architecture: {model_data['model_architecture']}")
    print(f"Best Validation Dice: {model_data['final_metrics']['final_val_dice']:.4f}")
    print(f"Final Epoch: {model_data['final_metrics']['final_epoch']}")
    print(f"Device: {device}")
except Exception as e:
    # Best-effort startup: report the failure and keep the app process alive.
    print(f"Error loading model: {e}")
    import traceback
    traceback.print_exc()
# --- Define Dataset ---
class BraTSDataset(Dataset):
    """Single-item dataset that stacks 4 MRI modalities into one 4-channel volume.

    Parameters
    ----------
    file_paths : list of 4 NIfTI file paths, one per modality.
    slice_range : (start, end) axial slice window to extract.
    """

    def __init__(self, file_paths, slice_range=(60, 100)):
        assert len(file_paths) == 4, "Must provide exactly 4 file paths for the 4 MRI modalities"
        self.file_paths = file_paths
        self.slice_range = slice_range

    def __len__(self):
        # Only one set of 4 modalities is processed at a time.
        return 1

    def __getitem__(self, idx):
        modalities = []
        for file_path in self.file_paths:
            nii_data = nib.load(file_path).get_fdata()
            lo, hi = np.min(nii_data), np.max(nii_data)
            # Normalize each modality independently to [0, 1]; guard against
            # constant volumes, which would otherwise divide by zero (NaNs).
            if hi > lo:
                nii_data = (nii_data - lo) / (hi - lo)
            else:
                nii_data = np.zeros_like(nii_data)
            # Keep only the configured axial slice window: (H, W, D).
            nii_data = nii_data[:, :, self.slice_range[0]:self.slice_range[1]]
            modalities.append(nii_data)
        # Stack modalities along the channel dimension: (4, H, W, D).
        combined_data = np.stack(modalities, axis=0)
        combined_tensor = torch.tensor(combined_data, dtype=torch.float32)
        # Rearrange to the (C, D, H, W) layout the model expects.
        combined_tensor = combined_tensor.permute(0, 3, 1, 2)
        # NOTE(review): relies on module-level `device`; moving to device here is
        # safe only because the DataLoader runs with num_workers=0 upstream.
        return combined_tensor.to(device)
import io
from PIL import Image
def analyze_tumor_presence(pred_volume, threshold=0.5, min_volume_threshold=100):
    """Decide whether a tumor is present in a probability volume.

    Parameters
    ----------
    pred_volume : ndarray of per-voxel tumor probabilities in [0, 1].
    threshold : probability cutoff for counting a voxel as tumor.
    min_volume_threshold : minimum voxel count to declare a tumor.

    Returns
    -------
    (is_tumor_present, confidence_percent, tumor_volume_in_voxels)
    """
    tumor_prob = pred_volume > threshold
    tumor_volume = int(np.sum(tumor_prob))
    max_probability = float(np.max(pred_volume))
    # np.mean over an empty selection is NaN (with a RuntimeWarning) — only
    # average where at least one voxel passed the threshold.
    mean_probability = float(np.mean(pred_volume[tumor_prob])) if tumor_volume > 0 else 0.0
    # Require both enough voxels and a sufficiently confident peak.
    is_tumor_present = tumor_volume > min_volume_threshold and max_probability > threshold
    if is_tumor_present:
        # Blend average and peak probability into a percentage.
        confidence = (mean_probability + max_probability) / 2 * 100
    else:
        # Confidence of absence: how far the peak stays below certainty.
        confidence = (1 - max_probability) * 100
    return is_tumor_present, confidence, tumor_volume
def create_2d_visualization(pred_volume, original_scan):
    """Render a three-panel middle-slice view (scan / mask / overlay) plus a summary.

    Parameters
    ----------
    pred_volume : (D, H, W) array of tumor probabilities in [0, 1].
    original_scan : (D, H, W) array of MRI intensities shown underneath.

    Returns
    -------
    (PIL.Image of the figure, text summary string)
    """
    D, H, W = pred_volume.shape
    mid_slice = D // 2
    is_tumor_present, confidence, tumor_volume = analyze_tumor_presence(pred_volume)

    plt.figure(figsize=(12, 6))
    # Headline verdict above the panels.
    if is_tumor_present:
        plt.suptitle(f'TUMOR DETECTED\nConfidence: {confidence:.1f}%\nEstimated Volume: {tumor_volume} voxels',
                     color='red', fontsize=14, y=1.05)
    else:
        plt.suptitle(f'NO TUMOR DETECTED\nConfidence: {confidence:.1f}%',
                     color='green', fontsize=14, y=1.05)

    # Panel 1: raw scan.
    plt.subplot(131)
    plt.title('Original MRI Scan')
    plt.imshow(original_scan[mid_slice], cmap='gray')
    plt.axis('off')

    # Panel 2: binary segmentation mask.
    plt.subplot(132)
    plt.title('Tumor Segmentation')
    tumor_mask = pred_volume[mid_slice] > 0.5
    plt.imshow(tumor_mask, cmap='hot', alpha=0.7)
    plt.axis('off')

    # Panel 3: mask overlaid on the scan.
    plt.subplot(133)
    plt.title('Overlay View')
    plt.imshow(original_scan[mid_slice], cmap='gray')
    plt.imshow(tumor_mask, cmap='plasma', alpha=0.5)
    # NOTE(review): the colorbar is attached to a binary mask, so it only
    # spans 0/1 — label kept for continuity with the original UI.
    plt.colorbar(label='Tumor Probability')
    plt.axis('off')

    plt.tight_layout(rect=[0, 0, 1, 0.9])

    # Serialize the figure to an in-memory PNG.
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', dpi=300)
    buf.seek(0)
    plt.close()

    # datetime.utcnow() is deprecated; use the timezone-aware form (same
    # formatted output, consistent with get_current_utc elsewhere in the file).
    timestamp = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
    summary = f"""
📊 Analysis Results ({timestamp} UTC):
{'🚨 TUMOR DETECTED' if is_tumor_present else '✅ NO TUMOR DETECTED'}
Confidence Level: {confidence:.1f}%
{f'Estimated Tumor Volume: {tumor_volume} voxels' if is_tumor_present else ''}
"""
    return Image.open(buf), summary
def format_size(size):
    """Format a byte count as a human-readable string (B, KB, MB, GB, TB).

    Parameters
    ----------
    size : number of bytes (int or float).
    """
    for unit in ('B', 'KB', 'MB', 'GB'):
        if size < 1024.0:
            return f"{size:.1f} {unit}"
        size /= 1024.0
    # Past this point `size` has been divided four times, i.e. it is in TB
    # (the original mislabeled this fall-through case as "GB").
    return f"{size:.1f} TB"
# Get current UTC time using the correct method
def get_current_utc():
    """Return the current UTC time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    now = datetime.now(timezone.utc)
    return now.strftime("%Y-%m-%d %H:%M:%S")
def create_upload_progress(files, progress=gr.Progress()):
    """Build an upload log while advancing the Gradio progress bar.

    Parameters
    ----------
    files : list of local file paths (typically the 4 MRI modalities).
    progress : Gradio-injected progress tracker (default per Gradio convention).

    Returns
    -------
    The accumulated log as a newline-joined string.
    """
    progress_text = []
    total = len(files)
    total_size = sum(os.path.getsize(f) for f in files)
    # Snapshot time/user when processing starts.
    current_time = get_current_utc()
    current_user = os.getenv('USER', 'gajendra82')
    progress_text.append(f"📁 Total files: {total} (Total size: {format_size(total_size)})")
    progress_text.append(f"⏰ Start Time: {current_time} UTC")
    progress_text.append(f"👤 User: {current_user}")
    progress(0, desc="Starting upload...")
    for i, file in enumerate(files, 1):
        file_size = os.path.getsize(file)
        # Use the real file count instead of the hard-coded "4" the original
        # assumed, so both the log and the bar stay correct for any count.
        progress_text.append(f"⏳ Processing file {i}/{total}: {os.path.basename(file)} ({format_size(file_size)})")
        progress(i / total, desc=f"Processing file {i}/{total}...")
        time.sleep(0.5)  # simulated per-file latency so the bar is visible
    return "\n".join(progress_text)
def format_file_info(files):
    """Summarize uploaded files: count, per-file sizes, timestamp, and user.

    Parameters
    ----------
    files : list of local file paths; may be empty or None.

    Returns
    -------
    A human-readable multi-line summary string.
    """
    if not files:
        return "No files uploaded"
    info = []
    info.append("📂 Upload Summary:")
    # Real current time and user instead of the stale hard-coded snapshot
    # ("2025-03-16 08:37:30" / "gajendra82") the original embedded.
    info.append(f"⏰ Time: {get_current_utc()}")
    info.append(f"👤 User: {os.getenv('USER', 'gajendra82')}")
    info.append("\nFiles uploaded:")
    for i, file in enumerate(files, 1):
        size = os.path.getsize(file)
        # Delegate to format_size: the original inlined a copy of that loop
        # which left `formatted_size` unbound for sizes >= 1024 GB.
        info.append(f"{i}. {os.path.basename(file)} ({format_size(size)})")
    return "\n".join(info)
def create_nifti_preview(file_paths):
    """Create a side-by-side middle-slice preview of the input NIfTI files.

    Parameters
    ----------
    file_paths : list of NIfTI file paths (expected order: T1, T2, FLAIR, T1CE).

    Returns
    -------
    PIL.Image of the preview figure, or None when there are no files or
    loading/rendering fails.
    """
    if not file_paths:
        return None
    try:
        plt.figure(figsize=(16, 4))
        modality_names = ['T1', 'T2', 'FLAIR', 'T1CE']
        for idx, file_path in enumerate(file_paths):
            nii_data = nib.load(file_path).get_fdata()
            # Show the middle axial slice of each modality.
            mid_slice = nii_data.shape[2] // 2
            slice_data = nii_data[:, :, mid_slice]
            lo, hi = np.min(slice_data), np.max(slice_data)
            # Normalize to [0, 1]; skip on constant slices, which would
            # otherwise divide by zero and render as NaNs.
            if hi > lo:
                slice_data = (slice_data - lo) / (hi - lo)
            plt.subplot(1, 4, idx + 1)
            # Fall back to a generic label if more than 4 files slip through.
            title = modality_names[idx] if idx < len(modality_names) else f'Modality {idx + 1}'
            plt.title(title)
            plt.imshow(slice_data, cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        # Serialize the figure to an in-memory PNG.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight', dpi=300)
        buf.seek(0)
        plt.close()
        return Image.open(buf)
    except Exception as e:
        # Preview is best-effort: log and let the caller proceed without it.
        logging.error(f"Error creating preview: {str(e)}")
        return None
# Update the process_and_predict function
def process_and_predict(files):
    """Run the segmentation model on the 4 uploaded MRI modality files.

    Parameters
    ----------
    files : list of uploaded files (paths or tempfile wrappers with .name).

    Returns
    -------
    (preview_image, result_image, analysis_text, status_message); the first
    three are None when validation or processing fails.
    """
    if not files or len(files) != 4:
        return None, None, None, "⚠️ Please upload exactly 4 `.nii` files (T1, T2, FLAIR, T1CE)."
    try:
        # gr.File may hand back plain path strings or objects exposing .name.
        file_paths = [f if isinstance(f, str) else f.name for f in files]
        # Best-effort preview of the raw inputs (may be None on failure).
        preview_img = create_nifti_preview(file_paths)

        dataset = BraTSDataset(file_paths, slice_range=(60, 100))
        # num_workers=0 keeps tensors on `device` safe to produce in __getitem__.
        data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

        with torch.no_grad():
            for batch in data_loader:
                output = model(batch)
                pred = torch.sigmoid(output)
                # Channel 0 of the sigmoid output is used as the tumor map;
                # the first input modality serves as the display background.
                pred_volume = pred[0, 0].cpu().numpy()
                original_scan = batch[0, 0].cpu().numpy()
                result_img, analysis_text = create_2d_visualization(pred_volume, original_scan)

                # Real timestamp/user instead of the stale hard-coded values
                # the original embedded ("2025-03-16 09:03:02" / "gajendra82").
                now_utc = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
                analysis_info = f"""Current Date and Time (UTC): {now_utc}
User: {os.getenv('USER', 'gajendra82')}
{analysis_text}"""
                return (
                    preview_img,
                    result_img,
                    analysis_info,
                    "✅ Analysis completed successfully!"
                )
    except Exception as e:
        logging.error(f"Error during processing: {str(e)}")
        return None, None, None, f"⚠️ Error during processing: {str(e)}"
# Update the Gradio interface
# Layout: upload area + instructions on top, then preview, segmentation
# result, and a text summary stacked below.
theme = gr.themes.Default()
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("""
    # 🧠 Brain Tumor Segmentation and Analysis
    Upload your MRI scans for automatic tumor detection and analysis.
    """)
    with gr.Row():
        with gr.Column(scale=2):
            # Accepts multiple .nii files; process_and_predict enforces that
            # exactly four are supplied.
            file_input = gr.File(
                file_types=[".nii"],
                label="Upload 4 MRI Files (T1, T2, FLAIR, T1CE)",
                type="filepath",
                file_count="multiple",
                elem_classes="file-upload"
            )
            submit_btn = gr.Button("Start Analysis", variant="primary")
        with gr.Column(scale=1):
            gr.Markdown("""
            ### Instructions:
            1. Upload all 4 `.nii` files in order:
                - T1-weighted MRI
                - T2-weighted MRI
                - FLAIR sequence
                - T1CE (with contrast)
            2. Click "Start Analysis"
            3. Wait for processing to complete
            """)
    with gr.Row():
        preview_output = gr.Image(label="Input MRI Scans Preview", type="pil")
    with gr.Row():
        result_output = gr.Image(label="Tumor Segmentation Results", type="pil")
    with gr.Row():
        analysis_output = gr.Textbox(label="Analysis Summary", lines=10, interactive=False)
    # Updated CSS for better file upload experience
    # NOTE(review): Gradio sanitizes raw <script> blocks inside Markdown by
    # default, so the JS below may be stripped at render time — confirm in
    # the deployed Space.
    gr.Markdown("""
    <style>
    .file-upload {
        border: 2px dashed #4CAF50;
        border-radius: 10px;
        padding: 20px;
        text-align: center;
        transition: all 0.3s ease;
        position: relative;
        min-height: 100px;
    }
    .file-upload:hover {
        border-color: #2196F3;
        background-color: #f0f8ff;
    }
    .file-upload::before {
        content: '📁 Drop files here or click to upload';
        position: absolute;
        top: 50%;
        left: 50%;
        transform: translate(-50%, -50%);
        color: #666;
        pointer-events: none;
        font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
    }
    .file-upload.dragover {
        border-color: #2196F3;
        background-color: #e3f2fd;
    }
    .file-upload input[type="file"] {
        opacity: 0;
        width: 100%;
        height: 100%;
        position: absolute;
        top: 0;
        left: 0;
        cursor: pointer;
    }
    .upload-progress {
        position: absolute;
        bottom: 0;
        left: 0;
        width: 0%;
        height: 4px;
        background: #2196F3;
        transition: width 0.3s ease;
    }
    /* General font settings */
    body, button, input, label, textarea {
        font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif !important;
    }
    </style>
    <script>
    document.addEventListener('DOMContentLoaded', function() {
        const uploadArea = document.querySelector('.file-upload');
        if (uploadArea) {
            // Add progress bar element
            const progressBar = document.createElement('div');
            progressBar.className = 'upload-progress';
            uploadArea.appendChild(progressBar);
            // Handle drag and drop visual feedback
            ['dragenter', 'dragover'].forEach(eventName => {
                uploadArea.addEventListener(eventName, (e) => {
                    e.preventDefault();
                    uploadArea.classList.add('dragover');
                });
            });
            ['dragleave', 'drop'].forEach(eventName => {
                uploadArea.addEventListener(eventName, (e) => {
                    e.preventDefault();
                    uploadArea.classList.remove('dragover');
                });
            });
            // Handle file upload progress
            uploadArea.addEventListener('change', function(e) {
                if (e.target.files.length > 0) {
                    let progress = 0;
                    progressBar.style.width = '0%';
                    const interval = setInterval(() => {
                        progress += 5;
                        progressBar.style.width = `${progress}%`;
                        if (progress >= 100) {
                            clearInterval(interval);
                            setTimeout(() => {
                                progressBar.style.width = '0%';
                            }, 500);
                        }
                    }, 50);
                }
            });
        }
    });
    </script>
    """)
    # Analysis button click event
    # Note: the 4th output (the status string) targets submit_btn itself,
    # which replaces the button's label when the run finishes.
    submit_btn.click(
        fn=process_and_predict,
        inputs=[file_input],
        outputs=[
            preview_output,
            result_output,
            analysis_output,
            submit_btn
        ]
    )
# Launch the interface
# share=True opens a public tunnel; server_name="0.0.0.0" binds all
# interfaces (expected inside a container/Space on port 7860).
demo.launch(
    debug=True,
    share=True,
    show_error=True,
    height=1000,
    server_port=7860,
    server_name="0.0.0.0"
)