Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import nibabel as nib | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from torch.utils.data import DataLoader, Dataset | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| import tempfile | |
| import gradio as gr | |
| import os | |
| import nibabel as nib | |
| import numpy as np | |
| from torch.utils.data import DataLoader, Dataset | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import logging | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| import tempfile | |
| import json | |
| import io | |
| from PIL import Image | |
| import time | |
| from tqdm import tqdm | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
# Define the model class (same as during training)
class ImprovedUNet3D(nn.Module):
    """Two-level 3D U-Net (encoder -> bottleneck -> decoder with skips).

    The attribute names (enc1, pool1, ..., final_conv) and the order in which
    modules are created must stay as they are: the checkpoint loaded below is
    a state_dict keyed by these names.
    """

    def __init__(self, in_channels=4, out_channels=2, base_filters=16):
        super().__init__()
        f1, f2, f4 = base_filters, base_filters * 2, base_filters * 4

        # Encoder path
        self.enc1 = self._make_layer(in_channels, f1)
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.enc2 = self._make_layer(f1, f2)
        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.enc3 = self._make_layer(f2, f4)

        # Decoder path: upsample, concatenate the skip (doubling channels), convolve
        self.upconv2 = nn.ConvTranspose3d(f4, f2, kernel_size=2, stride=2)
        self.dec2 = self._make_layer(f4, f2)
        self.upconv1 = nn.ConvTranspose3d(f2, f1, kernel_size=2, stride=2)
        self.dec1 = self._make_layer(f2, f1)

        self.dropout = nn.Dropout3d(0.3)
        self.final_conv = nn.Conv3d(f1, out_channels, kernel_size=1)
        self._initialize_weights()

    def _make_layer(self, in_channels, out_channels):
        """Build two stacked (3x3x3 Conv3d -> InstanceNorm3d -> LeakyReLU) stages."""
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.InstanceNorm3d(out_channels),
                nn.LeakyReLU(inplace=True),
            ]
        return nn.Sequential(*stages)

    def _initialize_weights(self):
        """Kaiming-normal init for Conv3d; unit/zero init for any affine InstanceNorm3d."""
        for module in self.modules():
            if isinstance(module, nn.Conv3d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.InstanceNorm3d):
                # weight/bias are None unless the norm layer was built with affine=True
                for param, value in ((module.weight, 1), (module.bias, 0)):
                    if param is not None:
                        nn.init.constant_(param, value)

    @staticmethod
    def _pad_to_match(upsampled, skip):
        """Zero-pad dims 2-4 of `upsampled` so they equal those of `skip`."""
        d = skip.size()[2] - upsampled.size()[2]
        h = skip.size()[3] - upsampled.size()[3]
        w = skip.size()[4] - upsampled.size()[4]
        # F.pad lists padding from the last dimension backwards.
        return F.pad(upsampled, [w // 2, w - w // 2, h // 2, h - h // 2, d // 2, d - d // 2])

    def forward(self, x):
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        bottleneck = self.enc3(self.pool2(skip2))

        up2 = self._pad_to_match(self.upconv2(bottleneck), skip2)
        dec2 = self.dec2(torch.cat([up2, skip2], dim=1))

        up1 = self._pad_to_match(self.upconv1(dec2), skip1)
        dec1 = self.dec1(torch.cat([up1, skip1], dim=1))

        return self.final_conv(self.dropout(dec1))
# --- Load Trained Model ---
# Checkpoint produced by the training run; change the path if needed.
model_path = "brain_tumor_seg_final_dice_0.8814_epoch_99_complete_20250319_125101.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stays None unless the weights load successfully, so request handlers can
# report "model not loaded" instead of hitting an undefined name or silently
# running an untrained network.
model = None
try:
    print(f"\nLoading model from: {model_path}")
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True; confirm the checkpoint's metadata loads under it.
    model_data = torch.load(model_path, map_location=device)
    # Accept the full training checkpoint ({'model_state_dict': ..., ...})
    # as well as a bare state_dict.
    if isinstance(model_data, dict) and 'model_state_dict' in model_data:
        state_dict = model_data['model_state_dict']
    else:
        state_dict = model_data
    net = ImprovedUNet3D(in_channels=4, out_channels=2, base_filters=16)
    net.load_state_dict(state_dict)
    net = net.to(device)
    net.eval()
    # Publish only after the weights are in place.
    model = net
except Exception as e:
    print(f"Error loading model: {e}")
    import traceback
    traceback.print_exc()
else:
    # Metadata is informational only; missing keys must not be reported as a
    # load failure.
    print("\nModel loaded successfully!")
    print("Model Details:")
    meta = model_data if isinstance(model_data, dict) else {}
    metrics = meta.get('final_metrics')
    if not isinstance(metrics, dict):
        metrics = {}
    print(f"Architecture: {meta.get('model_architecture', 'unknown')}")
    dice = metrics.get('final_val_dice')
    print(f"Best Validation Dice: {dice:.4f}" if dice is not None else "Best Validation Dice: unknown")
    print(f"Final Epoch: {metrics.get('final_epoch', 'unknown')}")
    print(f"Device: {device}")
# --- Define Dataset ---
class BraTSDataset(Dataset):
    """Single-item dataset stacking the four MRI modalities of one case.

    The one item is a float32 tensor of shape (4, D, H, W), where D is the
    number of slices kept by ``slice_range`` along the volume's third axis.
    """

    def __init__(self, file_paths, slice_range=(60, 100)):
        """
        Args:
            file_paths: sequence of exactly 4 NIfTI paths, one per MRI modality
                (the UI asks for T1, T2, FLAIR, T1CE in that order).
            slice_range: (start_slice, end_slice) half-open range along the
                third volume axis to keep.

        Raises:
            ValueError: if ``file_paths`` does not hold exactly 4 entries.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if file_paths is None or len(file_paths) != 4:
            raise ValueError("Must provide exactly 4 file paths for the 4 MRI modalities")
        self.file_paths = file_paths
        self.slice_range = slice_range

    def __len__(self):
        return 1  # We're only processing one set of 4 modalities at a time

    @staticmethod
    def _normalize(volume):
        """Min-max scale `volume` to [0, 1]; a constant volume maps to zeros.

        Guarding the zero range avoids a 0/0 that would fill the input with NaN.
        """
        vmin = np.min(volume)
        span = np.max(volume) - vmin
        if span == 0:
            return np.zeros(np.shape(volume), dtype=np.float64)
        return (volume - vmin) / span

    def __getitem__(self, idx):
        start, end = self.slice_range
        modalities = []
        for file_path in self.file_paths:
            # Normalise over the whole volume first, then keep the slice window.
            volume = self._normalize(nib.load(file_path).get_fdata())
            modalities.append(volume[:, :, start:end])  # (H, W, D)
        combined = torch.tensor(np.stack(modalities, axis=0), dtype=torch.float32)  # (4, H, W, D)
        # Rearrange to (C, D, H, W). `device` is the module-level device chosen
        # in the model-loading step; the caller feeds the batch straight to the model.
        return combined.permute(0, 3, 1, 2).to(device)
| import io | |
| from PIL import Image | |
def analyze_tumor_presence(pred_volume, threshold=0.5, min_volume_threshold=100):
    """
    Decide whether a probability volume contains a tumor.

    Args:
        pred_volume: array of per-voxel tumor probabilities (in this app, the
            sigmoid output of the model).
        threshold: probability above which a voxel counts as tumor.
        min_volume_threshold: number of tumor voxels that must be exceeded
            before a tumor is reported.

    Returns:
        (is_tumor_present, confidence, tumor_volume): ``confidence`` is a
        percentage and ``tumor_volume`` is the number of voxels above
        ``threshold``.
    """
    pred_volume = np.asarray(pred_volume)
    tumor_mask = pred_volume > threshold
    tumor_volume = np.sum(tumor_mask)
    max_probability = np.max(pred_volume)

    is_tumor_present = tumor_volume > min_volume_threshold and max_probability > threshold

    if is_tumor_present:
        # The mask is guaranteed non-empty here; averaging an empty selection
        # would give NaN plus a RuntimeWarning on every tumor-free scan.
        mean_probability = np.mean(pred_volume[tumor_mask])
        confidence = (mean_probability + max_probability) / 2 * 100
    else:
        confidence = (1 - max_probability) * 100
    return is_tumor_present, confidence, tumor_volume
def create_2d_visualization(pred_volume, original_scan):
    """
    Render the middle slice of a prediction next to the scan and summarise it.

    Args:
        pred_volume: 3-D array of tumor probabilities; slices are taken along
            the first axis.
        original_scan: array indexed the same way as ``pred_volume``, shown as
            the grayscale background.

    Returns:
        (PIL.Image.Image, str): the rendered three-panel figure and a text
        summary of the tumor analysis.
    """
    mid_slice = pred_volume.shape[0] // 2
    is_tumor_present, confidence, tumor_volume = analyze_tumor_presence(pred_volume)

    scan_slice = original_scan[mid_slice]
    prob_slice = pred_volume[mid_slice]
    tumor_mask = prob_slice > 0.5
    # Colour only voxels inside the mask: the overlay then leaves healthy tissue
    # untinted and the colourbar genuinely shows tumor probability.
    overlay = np.ma.masked_where(~tumor_mask, prob_slice)

    fig, axes = plt.subplots(1, 3, figsize=(12, 6))
    try:
        if is_tumor_present:
            fig.suptitle(f'TUMOR DETECTED\nConfidence: {confidence:.1f}%\nEstimated Volume: {tumor_volume} voxels',
                         color='red', fontsize=14, y=1.05)
        else:
            fig.suptitle(f'NO TUMOR DETECTED\nConfidence: {confidence:.1f}%',
                         color='green', fontsize=14, y=1.05)

        axes[0].set_title('Original MRI Scan')
        axes[0].imshow(scan_slice, cmap='gray')

        axes[1].set_title('Tumor Segmentation')
        axes[1].imshow(tumor_mask, cmap='hot', alpha=0.7)

        axes[2].set_title('Overlay View')
        axes[2].imshow(scan_slice, cmap='gray')
        overlay_img = axes[2].imshow(overlay, cmap='plasma', alpha=0.5, vmin=0.0, vmax=1.0)
        fig.colorbar(overlay_img, ax=axes[2], label='Tumor Probability')

        for ax in axes:
            ax.axis('off')

        fig.tight_layout(rect=[0, 0, 1, 0.9])
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=300)
        buf.seek(0)
    finally:
        # pyplot keeps figures alive until closed; release it even on error.
        plt.close(fig)

    # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated).
    timestamp = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
    summary = f"""
π Analysis Results ({timestamp} UTC):
{'π¨ TUMOR DETECTED' if is_tumor_present else 'β NO TUMOR DETECTED'}
Confidence Level: {confidence:.1f}%
{f'Estimated Tumor Volume: {tumor_volume} voxels' if is_tumor_present else ''}
"""
    return Image.open(buf), summary
def format_size(size):
    """
    Format a byte count as a human-readable string, e.g. 1536 -> "1.5 KB".

    Steps are binary (1024-based) with one decimal place. Anything of
    1024 TB or more is expressed in PB.
    """
    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
        if size < 1024.0:
            return f"{size:.1f} {unit}"
        size /= 1024.0
    # After five divisions the value is in PB; labelling it with an earlier
    # unit would under-report the size by a factor of 1024.
    return f"{size:.1f} PB"
# Timezone-aware replacement for the deprecated datetime.utcnow()
def get_current_utc():
    """Return the current UTC time as 'YYYY-MM-DD HH:MM:SS'."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.strftime("%Y-%m-%d %H:%M:%S")
def create_upload_progress(files, progress=gr.Progress(), *, delay=0.5):
    """
    Report progress while stepping through the uploaded files.

    Args:
        files: uploaded files, as path strings or objects exposing ``.name``.
        progress: Gradio progress tracker (injected by Gradio when used as an
            event handler).
        delay: seconds to pause per file; the pause only simulates work.

    Returns:
        Multi-line text with totals, start time, user and one line per file.
    """
    paths = [f if isinstance(f, str) else f.name for f in (files or [])]
    total = len(paths)
    total_size = sum(os.path.getsize(p) for p in paths)
    current_user = os.getenv('USER', 'gajendra82')

    progress_text = [
        f"π Total files: {total} (Total size: {format_size(total_size)})",
        f"β° Start Time: {get_current_utc()} UTC",
        f"π€ User: {current_user}",
    ]
    progress(0, desc="Starting upload...")
    for i, path in enumerate(paths, 1):
        # Use the real file count rather than assuming exactly 4 uploads.
        progress_text.append(f"β³ Processing file {i}/{total}: {os.path.basename(path)} ({format_size(os.path.getsize(path))})")
        progress(i / total, desc=f"Processing file {i}/{total}...")
        time.sleep(delay)  # Simulate processing time
    return "\n".join(progress_text)
def format_file_info(files):
    """
    Summarise uploaded files: current UTC time, user, and each file's name and size.

    Args:
        files: uploaded files, as path strings or objects exposing ``.name``.

    Returns:
        Multi-line summary text, or "No files uploaded" when ``files`` is empty.
    """
    if not files:
        return "No files uploaded"
    paths = [f if isinstance(f, str) else f.name for f in files]
    info = [
        "π Upload Summary:",
        # Real timestamp/user instead of the previously hard-coded values.
        f"β° Time: {get_current_utc()} UTC",
        f"π€ User: {os.getenv('USER', 'gajendra82')}",
        "\nFiles uploaded:",
    ]
    # Reuse format_size: the former inline unit loop left the size unbound for
    # files of 1024 GB or more.
    info.extend(
        f"{i}. {os.path.basename(path)} ({format_size(os.path.getsize(path))})"
        for i, path in enumerate(paths, 1)
    )
    return "\n".join(info)
| def create_nifti_preview(file_paths): | |
| """Create a preview of the input NIfTI files""" | |
| if not file_paths or len(file_paths) == 0: | |
| return None | |
| try: | |
| # Create a figure with subplots for each modality | |
| plt.figure(figsize=(16, 4)) | |
| modality_names = ['T1', 'T2', 'FLAIR', 'T1CE'] | |
| for idx, file_path in enumerate(file_paths): | |
| # Load NIfTI file | |
| nii_data = nib.load(file_path).get_fdata() | |
| # Get middle slice | |
| mid_slice = nii_data.shape[2] // 2 | |
| slice_data = nii_data[:, :, mid_slice] | |
| # Normalize the data | |
| slice_data = (slice_data - np.min(slice_data)) / (np.max(slice_data) - np.min(slice_data)) | |
| # Create subplot | |
| plt.subplot(1, 4, idx + 1) | |
| plt.title(f'{modality_names[idx]}') | |
| plt.imshow(slice_data, cmap='gray') | |
| plt.axis('off') | |
| plt.tight_layout() | |
| # Convert plot to image | |
| buf = io.BytesIO() | |
| plt.savefig(buf, format='png', bbox_inches='tight', dpi=300) | |
| buf.seek(0) | |
| plt.close() | |
| return Image.open(buf) | |
| except Exception as e: | |
| logging.error(f"Error creating preview: {str(e)}") | |
| return None | |
# Gradio handler: preview, segment and summarise one uploaded case.
def process_and_predict(files):
    """
    Process MRI files and predict tumor segmentation.

    Args:
        files: 4 uploads (path strings or objects exposing ``.name``), in the
            order T1, T2, FLAIR, T1CE.

    Returns:
        4-tuple (preview image, result image, analysis text, status message).
        The first three are None when the request cannot be processed.
    """
    if not files or len(files) != 4:
        return None, None, None, "β οΈ Please upload exactly 4 `.nii` files (T1, T2, FLAIR, T1CE)."
    # `model` is the module-level network from the loading step; it may be
    # absent or None if the checkpoint failed to load.
    net = globals().get("model")
    if net is None:
        return None, None, None, "β οΈ Error during processing: model is not loaded; check the server logs."
    try:
        file_paths = [f if isinstance(f, str) else f.name for f in files]
        preview_img = create_nifti_preview(file_paths)

        dataset = BraTSDataset(file_paths, slice_range=(60, 100))
        data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
        with torch.no_grad():
            batch = next(iter(data_loader))  # the dataset holds exactly one item
            pred = torch.sigmoid(net(batch))
        # NOTE(review): channel 0 of the 2-channel output is used as the tumor
        # map — confirm this matches the training targets.
        pred_volume = pred[0, 0].cpu().numpy()
        original_scan = batch[0, 0].cpu().numpy()  # first uploaded modality
        result_img, analysis_text = create_2d_visualization(pred_volume, original_scan)

        # Real timestamp/user instead of the previously hard-coded values.
        analysis_info = f"""Current Date and Time (UTC): {get_current_utc()}
User: {os.getenv('USER', 'gajendra82')}
{analysis_text}"""
        return (
            preview_img,
            result_img,
            analysis_info,
            "β Analysis completed successfully!"
        )
    except Exception as e:
        logging.exception("Error during processing")  # keeps the traceback in the logs
        return None, None, None, f"β οΈ Error during processing: {str(e)}"
# Build the Gradio interface
theme = gr.themes.Default()
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("""
# π§ Brain Tumor Segmentation and Analysis
Upload your MRI scans for automatic tumor detection and analysis.
""")
    with gr.Row():
        with gr.Column(scale=2):
            file_input = gr.File(
                file_types=[".nii"],
                label="Upload 4 MRI Files (T1, T2, FLAIR, T1CE)",
                type="filepath",
                file_count="multiple",
                elem_classes="file-upload"
            )
            submit_btn = gr.Button("Start Analysis", variant="primary")
        with gr.Column(scale=1):
            gr.Markdown("""
### Instructions:
1. Upload all 4 `.nii` files in order:
- T1-weighted MRI
- T2-weighted MRI
- FLAIR sequence
- T1CE (with contrast)
2. Click "Start Analysis"
3. Wait for processing to complete
""")
    with gr.Row():
        preview_output = gr.Image(label="Input MRI Scans Preview", type="pil")
    with gr.Row():
        result_output = gr.Image(label="Tumor Segmentation Results", type="pil")
    with gr.Row():
        analysis_output = gr.Textbox(label="Analysis Summary", lines=10, interactive=False)
    # process_and_predict returns a status message as its 4th value. It gets its
    # own component; sending it to the button would overwrite the button label.
    with gr.Row():
        status_output = gr.Textbox(label="Status", interactive=False)
    # Updated CSS for better file upload experience.
    # NOTE(review): this markup is injected via gr.Markdown; <script> tags
    # inserted that way are likely not executed by the browser — verify.
    gr.Markdown("""
<style>
.file-upload {
border: 2px dashed #4CAF50;
border-radius: 10px;
padding: 20px;
text-align: center;
transition: all 0.3s ease;
position: relative;
min-height: 100px;
}
.file-upload:hover {
border-color: #2196F3;
background-color: #f0f8ff;
}
.file-upload::before {
content: 'π Drop files here or click to upload';
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
color: #666;
pointer-events: none;
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
}
.file-upload.dragover {
border-color: #2196F3;
background-color: #e3f2fd;
}
.file-upload input[type="file"] {
opacity: 0;
width: 100%;
height: 100%;
position: absolute;
top: 0;
left: 0;
cursor: pointer;
}
.upload-progress {
position: absolute;
bottom: 0;
left: 0;
width: 0%;
height: 4px;
background: #2196F3;
transition: width 0.3s ease;
}
/* General font settings */
body, button, input, label, textarea {
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif !important;
}
</style>
<script>
document.addEventListener('DOMContentLoaded', function() {
const uploadArea = document.querySelector('.file-upload');
if (uploadArea) {
// Add progress bar element
const progressBar = document.createElement('div');
progressBar.className = 'upload-progress';
uploadArea.appendChild(progressBar);
// Handle drag and drop visual feedback
['dragenter', 'dragover'].forEach(eventName => {
uploadArea.addEventListener(eventName, (e) => {
e.preventDefault();
uploadArea.classList.add('dragover');
});
});
['dragleave', 'drop'].forEach(eventName => {
uploadArea.addEventListener(eventName, (e) => {
e.preventDefault();
uploadArea.classList.remove('dragover');
});
});
// Handle file upload progress
uploadArea.addEventListener('change', function(e) {
if (e.target.files.length > 0) {
let progress = 0;
progressBar.style.width = '0%';
const interval = setInterval(() => {
progress += 5;
progressBar.style.width = `${progress}%`;
if (progress >= 100) {
clearInterval(interval);
setTimeout(() => {
progressBar.style.width = '0%';
}, 500);
}
}, 50);
}
});
}
});
</script>
""")
    # Analysis button click event
    submit_btn.click(
        fn=process_and_predict,
        inputs=[file_input],
        outputs=[
            preview_output,
            result_output,
            analysis_output,
            status_output,
        ]
    )
# Launch the interface only when run as a script, so importing this module
# (e.g. from tests or Gradio's reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch(
        debug=True,
        share=True,  # NOTE(review): share links are not supported on Hugging Face Spaces
        show_error=True,
        height=1000,
        server_port=7860,
        server_name="0.0.0.0",
    )