"""Gradio app: 3D U-Net brain-tumour segmentation on 4-modality BraTS MRI.

Loads a trained ``ImprovedUNet3D`` checkpoint at import time, then serves a
Gradio UI that accepts four NIfTI files, runs inference on a fixed slab of
axial slices, and returns a preview, an overlay visualisation and a text
summary.
"""
import io
import json
import logging
import os
import tempfile
import time
from datetime import datetime, timezone
from pathlib import Path

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from plotly.subplots import make_subplots
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm


# Define the model class (same as during training).
# NOTE: module attribute names/order must stay identical to training so that
# ``load_state_dict`` keys match.
class ImprovedUNet3D(nn.Module):
    """Two-level 3D U-Net with InstanceNorm + LeakyReLU conv blocks.

    Input is ``(N, in_channels, D, H, W)``; output is
    ``(N, out_channels, D, H, W)`` raw logits (no activation applied).
    """

    def __init__(self, in_channels=4, out_channels=2, base_filters=16):
        super(ImprovedUNet3D, self).__init__()
        self.enc1 = self._make_layer(in_channels, base_filters)
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.enc2 = self._make_layer(base_filters, base_filters * 2)
        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.enc3 = self._make_layer(base_filters * 2, base_filters * 4)
        self.upconv2 = nn.ConvTranspose3d(base_filters * 4, base_filters * 2, kernel_size=2, stride=2)
        self.dec2 = self._make_layer(base_filters * 4, base_filters * 2)
        self.upconv1 = nn.ConvTranspose3d(base_filters * 2, base_filters, kernel_size=2, stride=2)
        self.dec1 = self._make_layer(base_filters * 2, base_filters)
        self.dropout = nn.Dropout3d(0.3)
        self.final_conv = nn.Conv3d(base_filters, out_channels, kernel_size=1)
        self._initialize_weights()

    def _make_layer(self, in_channels, out_channels):
        """Return two (Conv3d 3x3x3 -> InstanceNorm3d -> LeakyReLU) stages."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU(inplace=True),
        )

    def _initialize_weights(self):
        """Kaiming-normal conv weights, zero biases, unit/zero norm affine params."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.InstanceNorm3d):
                # InstanceNorm3d has no affine params by default, hence the None checks.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    @staticmethod
    def _pad_to_match(upsampled, skip):
        """Zero-pad ``upsampled`` so its spatial dims (2, 3, 4) equal ``skip``'s.

        Needed when an input dimension is not divisible by 4 and the
        transposed convolution comes out one voxel short.
        """
        diff_d = skip.size(2) - upsampled.size(2)
        diff_h = skip.size(3) - upsampled.size(3)
        diff_w = skip.size(4) - upsampled.size(4)
        # F.pad lists (left, right) pairs starting from the LAST dimension.
        return F.pad(upsampled, [
            diff_w // 2, diff_w - diff_w // 2,
            diff_h // 2, diff_h - diff_h // 2,
            diff_d // 2, diff_d - diff_d // 2,
        ])

    def forward(self, x):
        """Encode twice, decode twice with skip connections, return logits."""
        enc1_out = self.enc1(x)
        enc2_out = self.enc2(self.pool1(enc1_out))
        enc3_out = self.enc3(self.pool2(enc2_out))

        up2 = self._pad_to_match(self.upconv2(enc3_out), enc2_out)
        dec2_out = self.dec2(torch.cat([up2, enc2_out], dim=1))

        up1 = self._pad_to_match(self.upconv1(dec2_out), enc1_out)
        dec1_out = self.dec1(torch.cat([up1, enc1_out], dim=1))

        return self.final_conv(self.dropout(dec1_out))


# --- Load Trained Model ---
# model_path = "/content/drive/MyDrive/BrainTumorBrats/final_trained_models/torch_save/brain_tumor_seg_final_dice_0.8814_epoch_99_complete_20250319_125101.pt"
model_path = "brain_tumor_seg_final_dice_0.8814_epoch_99_complete_20250319_125101.pt"  # Change path if needed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sentinels so the UI can report a clear error instead of a NameError when
# loading fails.
model = None
model_data = None


def _load_model(path, map_location):
    """Build ``ImprovedUNet3D`` and load weights from a checkpoint dict.

    Expects the checkpoint to be a dict with a ``'model_state_dict'`` key.
    Returns ``(model_in_eval_mode, checkpoint_dict)``.
    """
    checkpoint = torch.load(path, map_location=map_location)
    net = ImprovedUNet3D(in_channels=4, out_channels=2, base_filters=16)
    net.load_state_dict(checkpoint['model_state_dict'])
    net = net.to(map_location)
    net.eval()
    return net, checkpoint


try:
    print(f"\nLoading model from: {model_path}")
    model, model_data = _load_model(model_path, device)
    print("\nModel loaded successfully!")
    print("Model Details:")
    # Metadata is informational only; missing keys must not look like a load failure.
    final_metrics = model_data.get('final_metrics', {}) or {}
    print(f"Architecture: {model_data.get('model_architecture', 'n/a')}")
    val_dice = final_metrics.get('final_val_dice')
    print(f"Best Validation Dice: {val_dice:.4f}" if val_dice is not None else "Best Validation Dice: n/a")
    print(f"Final Epoch: {final_metrics.get('final_epoch', 'n/a')}")
    print(f"Device: {device}")
except Exception as e:
    print(f"Error loading model: {e}")
    import traceback
    traceback.print_exc()


def _min_max_normalize(volume):
    """Scale ``volume`` to [0, 1]; a constant volume maps to all zeros (no NaN)."""
    v_min = np.min(volume)
    v_range = np.max(volume) - v_min
    if v_range == 0:
        return np.zeros_like(volume, dtype=np.float64)
    return (volume - v_min) / v_range


# --- Define Dataset ---
class BraTSDataset(Dataset):
    """Single-sample dataset stacking 4 MRI modalities into one tensor."""

    def __init__(self, file_paths, slice_range=(60, 100)):
        """
        file_paths: List of 4 file paths corresponding to the 4 MRI modalities
            (channel order follows list order).
        slice_range: Tuple of (start_slice, end_slice) along the 3rd volume axis.

        Raises ValueError if ``file_paths`` does not contain exactly 4 entries.
        """
        if len(file_paths) != 4:
            raise ValueError("Must provide exactly 4 file paths for the 4 MRI modalities")
        self.file_paths = file_paths
        self.slice_range = slice_range

    def __len__(self):
        return 1  # We're only processing one set of 4 modalities at a time

    def __getitem__(self, idx):
        """Return a float32 tensor of shape (4, D, H, W) on ``device``."""
        start, end = self.slice_range
        modalities = []
        for file_path in self.file_paths:
            nii_data = nib.load(file_path).get_fdata()
            # Normalize each modality independently (over the whole volume).
            nii_data = _min_max_normalize(nii_data)
            modalities.append(nii_data[:, :, start:end])  # (H, W, D)

        # (4, H, W, D) -> (4, D, H, W)
        combined_tensor = torch.tensor(np.stack(modalities, axis=0), dtype=torch.float32)
        combined_tensor = combined_tensor.permute(0, 3, 1, 2)
        return combined_tensor.to(device)


def get_current_utc():
    """Current UTC time formatted as ``YYYY-mm-dd HH:MM:SS``."""
    return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")


def get_current_user():
    """User name from ``$USER``, falling back to the original hard-coded name."""
    return os.getenv('USER', 'gajendra82')


def analyze_tumor_presence(pred_volume, threshold=0.5, min_volume_threshold=100):
    """
    Analyze the prediction volume to determine tumor presence and characteristics.

    Returns a tuple of (is_tumor_present, confidence_percent, voxel_count),
    where voxel_count is the number of voxels strictly above ``threshold``.
    """
    tumor_mask = pred_volume > threshold
    tumor_volume = np.sum(tumor_mask)
    max_probability = np.max(pred_volume)

    is_tumor_present = tumor_volume > min_volume_threshold and max_probability > threshold

    if is_tumor_present:
        # Only computed here: the mean of an empty selection would be NaN + warning.
        mean_probability = np.mean(pred_volume[tumor_mask])
        confidence = (mean_probability + max_probability) / 2 * 100
    else:
        confidence = (1 - max_probability) * 100

    return is_tumor_present, confidence, tumor_volume


def _figure_to_image(fig):
    """Render a Matplotlib figure to a PIL image and close the figure."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=300)
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def create_2d_visualization(pred_volume, original_scan):
    """
    Create 2D visualization of tumor segmentation overlaid on MRI scan.

    pred_volume / original_scan: 3D arrays indexed (D, H, W); the middle D
    slice is shown. Returns (PIL.Image, summary_text).
    """
    depth, _, _ = pred_volume.shape
    mid_slice = depth // 2

    is_tumor_present, confidence, tumor_volume = analyze_tumor_presence(pred_volume)

    # Explicit Figure objects: pyplot's implicit "current figure" is shared
    # across concurrent Gradio requests.
    fig, (ax_scan, ax_mask, ax_overlay) = plt.subplots(1, 3, figsize=(12, 6))

    if is_tumor_present:
        fig.suptitle(f'TUMOR DETECTED\nConfidence: {confidence:.1f}%\nEstimated Volume: {tumor_volume} voxels',
                     color='red', fontsize=14, y=1.05)
    else:
        fig.suptitle(f'NO TUMOR DETECTED\nConfidence: {confidence:.1f}%',
                     color='green', fontsize=14, y=1.05)

    scan_slice = original_scan[mid_slice]
    prob_slice = pred_volume[mid_slice]
    tumor_mask = prob_slice > 0.5

    ax_scan.set_title('Original MRI Scan')
    ax_scan.imshow(scan_slice, cmap='gray')
    ax_scan.axis('off')

    ax_mask.set_title('Tumor Segmentation')
    ax_mask.imshow(tumor_mask, cmap='hot', alpha=0.7)
    ax_mask.axis('off')

    ax_overlay.set_title('Overlay View')
    ax_overlay.imshow(scan_slice, cmap='gray')
    # Mask out non-tumour pixels so the scan stays visible, and plot actual
    # probabilities so the colorbar label is truthful.
    overlay = ax_overlay.imshow(np.ma.masked_where(~tumor_mask, prob_slice),
                                cmap='plasma', alpha=0.5, vmin=0.0, vmax=1.0)
    fig.colorbar(overlay, ax=ax_overlay, label='Tumor Probability')
    ax_overlay.axis('off')

    fig.tight_layout(rect=[0, 0, 1, 0.9])
    image = _figure_to_image(fig)

    summary = f"""
📊 Analysis Results ({get_current_utc()} UTC):
{'🚨 TUMOR DETECTED' if is_tumor_present else '✅ NO TUMOR DETECTED'}
Confidence Level: {confidence:.1f}%
{f'Estimated Tumor Volume: {tumor_volume} voxels' if is_tumor_present else ''}
"""
    return image, summary


def format_size(size):
    """Format a byte count in human readable units (B .. TB)."""
    units = ['B', 'KB', 'MB', 'GB', 'TB']
    for unit in units[:-1]:
        if size < 1024.0:
            return f"{size:.1f} {unit}"
        size /= 1024.0
    # Previously divided once more and still labelled "GB".
    return f"{size:.1f} {units[-1]}"


def create_upload_progress(files, progress=gr.Progress()):
    """Create a progress report for file upload and processing (simulated)."""
    progress_text = []
    total_size = sum(os.path.getsize(f) for f in files)
    total_files = len(files)

    progress_text.append(f"📁 Total files: {total_files} (Total size: {format_size(total_size)})")
    progress_text.append(f"⏰ Start Time: {get_current_utc()} UTC")
    progress_text.append(f"👤 User: {get_current_user()}")
    progress(0, desc="Starting upload...")

    for i, file in enumerate(files, 1):
        file_size = os.path.getsize(file)
        progress_text.append(f"⏳ Processing file {i}/{total_files}: {os.path.basename(file)} ({format_size(file_size)})")
        progress(i / total_files, desc=f"Processing file {i}/{total_files}...")
        time.sleep(0.5)  # Simulate processing time

    return "\n".join(progress_text)


def format_file_info(files):
    """Format file information with size and timestamp."""
    if not files:
        return "No files uploaded"

    info = [
        "📂 Upload Summary:",
        f"⏰ Time: {get_current_utc()} UTC",
        f"👤 User: {get_current_user()}",
        "\nFiles uploaded:",
    ]
    for i, file in enumerate(files, 1):
        info.append(f"{i}. {os.path.basename(file)} ({format_size(os.path.getsize(file))})")
    return "\n".join(info)


def create_nifti_preview(file_paths):
    """Render the middle slice (3rd axis) of each input NIfTI file side by side.

    Returns a PIL image, or None when no files are given or rendering fails.
    """
    if not file_paths:
        return None

    modality_names = ['T1', 'T2', 'FLAIR', 'T1CE']
    fig = None
    try:
        fig, axes = plt.subplots(1, 4, figsize=(16, 4))
        for ax in axes:
            ax.axis('off')

        for idx, file_path in enumerate(file_paths):
            nii_data = nib.load(file_path).get_fdata()
            mid_slice = nii_data.shape[2] // 2
            slice_data = _min_max_normalize(nii_data[:, :, mid_slice])

            ax = axes[idx]  # IndexError for >4 files is reported below, as before
            ax.set_title(f'{modality_names[idx]}')
            ax.imshow(slice_data, cmap='gray')

        fig.tight_layout()
        image = _figure_to_image(fig)
        fig = None  # already closed by _figure_to_image
        return image
    except Exception as e:
        logging.error(f"Error creating preview: {str(e)}")
        return None
    finally:
        # Always release the figure, even when loading/plotting failed.
        if fig is not None:
            plt.close(fig)


def process_and_predict(files):
    """
    Process MRI files and predict tumor segmentation.

    Returns (preview_image, result_image, analysis_text, status_message).
    """
    if not files or len(files) != 4:
        return None, None, None, "⚠️ Please upload exactly 4 `.nii` files (T1, T2, FLAIR, T1CE)."
    if model is None:
        return None, None, None, "⚠️ Model is not loaded; check the server logs for the loading error."

    try:
        # NOTE(review): channel order follows upload order; confirm the browser
        # preserves the T1, T2, FLAIR, T1CE order the model was trained with.
        file_paths = [f if isinstance(f, str) else f.name for f in files]

        preview_img = create_nifti_preview(file_paths)

        dataset = BraTSDataset(file_paths, slice_range=(60, 100))
        data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

        with torch.no_grad():
            batch = next(iter(data_loader))  # dataset has exactly one sample
            # NOTE(review): sigmoid on channel 0 of a 2-channel output is kept
            # from the original; confirm it matches the training loss.
            pred = torch.sigmoid(model(batch))

        pred_volume = pred[0, 0].cpu().numpy()
        original_scan = batch[0, 0].cpu().numpy()

        result_img, analysis_text = create_2d_visualization(pred_volume, original_scan)

        analysis_info = f"""Current Date and Time (UTC): {get_current_utc()}
User: {get_current_user()}

{analysis_text}"""

        return (
            preview_img,
            result_img,
            analysis_info,
            "✅ Analysis completed successfully!",
        )

    except Exception as e:
        logging.error(f"Error during processing: {str(e)}")
        return None, None, None, f"⚠️ Error during processing: {str(e)}"


# --- Gradio interface ---
theme = gr.themes.Default()

with gr.Blocks(theme=theme) as demo:
    gr.Markdown("""
# 🧠 Brain Tumor Segmentation and Analysis
Upload your MRI scans for automatic tumor detection and analysis.
""")

    with gr.Row():
        with gr.Column(scale=2):
            file_input = gr.File(
                file_types=[".nii"],
                label="Upload 4 MRI Files (T1, T2, FLAIR, T1CE)",
                type="filepath",
                file_count="multiple",
                elem_classes="file-upload",
            )
            submit_btn = gr.Button("Start Analysis", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown("""
### Instructions:
1. Upload all 4 `.nii` files in order:
   - T1-weighted MRI
   - T2-weighted MRI
   - FLAIR sequence
   - T1CE (with contrast)
2. Click "Start Analysis"
3. Wait for processing to complete
""")

    with gr.Row():
        preview_output = gr.Image(label="Input MRI Scans Preview", type="pil")

    with gr.Row():
        result_output = gr.Image(label="Tumor Segmentation Results", type="pil")

    with gr.Row():
        analysis_output = gr.Textbox(label="Analysis Summary", lines=10, interactive=False)

    with gr.Row():
        # Dedicated status box: previously the status string was written into
        # the button itself, replacing its "Start Analysis" label.
        status_output = gr.Textbox(label="Status", interactive=False)

    submit_btn.click(
        fn=process_and_predict,
        inputs=[file_input],
        outputs=[
            preview_output,
            result_output,
            analysis_output,
            status_output,
        ],
    )

# Launch the interface
demo.launch(
    debug=True,
    share=True,
    show_error=True,
    height=1000,
    server_port=7860,
    server_name="0.0.0.0",
)