"""Gradio app that classifies a 3D NIfTI volume and returns a Grad-CAM heatmap.

The module loads a MONAI 3D ResNet-18, wraps it with a Grad-CAM helper, and
exposes a single ``predict`` endpoint through ``gr.Interface``.
"""
import base64
import os
import tempfile
import traceback

import gradio as gr
import nibabel as nib
import numpy as np
import torch
from monai.networks.nets import resnet18
from skimage.transform import resize

# --- CONFIGURATION ---
TARGET_SIZE = (64, 64, 64)  # Match VesselMNIST3D training data
NUM_CLASSES = 2
CLASS_LABELS = ["Normal", "Aneurysm"]

# --- LOAD MODEL (Using MONAI) ---
model = resnet18(
    spatial_dims=3,
    n_input_channels=1,
    num_classes=NUM_CLASSES
)

# True only once the trained weights have been loaded; predict() refuses to
# report scores from a randomly initialised network.
MODEL_LOADED = False

# Load the weights
try:
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects. If model.pth is a plain state_dict, consider weights_only=True.
    model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu'), weights_only=False))
    MODEL_LOADED = True
    print("Model loaded successfully from model.pth.")
except FileNotFoundError:
    print("ERROR: model.pth not found. Please ensure the file is present.")

# Always switch to inference mode, even when the weights are missing, so the
# normalisation layers never run in training mode on a single-sample batch.
model.eval()


# --- GRAD-CAM IMPLEMENTATION ---
class GradCAM3D:
    """
    Grad-CAM for 3D CNNs to generate attention heatmaps.

    Forward and backward hooks on ``target_layer`` capture its output
    activations and the gradient flowing into that output; ``generate_cam``
    combines them into a heatmap.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Register hooks (handles are kept so they can be removed later)
        self._hook_handles = [
            target_layer.register_forward_hook(self.save_activation),
            target_layer.register_full_backward_hook(self.save_gradient),
        ]

    def save_activation(self, module, input, output):
        """Forward hook: store the layer output, detached from the graph."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: store the gradient w.r.t. the layer output."""
        self.gradients = grad_output[0].detach()

    def remove_hooks(self):
        """Detach the hooks registered in ``__init__`` from the target layer."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def generate_cam(self, input_tensor, target_class=None):
        """
        Generate Grad-CAM heatmap

        Args:
            input_tensor: Input tensor [1, 1, D, H, W]
            target_class: Target class index (if None, uses predicted class)

        Returns:
            cam: 3D heatmap normalized to [0, 1]; all zeros when the map is
                flat (no spatial variation to normalise).

        Raises:
            RuntimeError: if the hooks did not capture activations/gradients.
        """
        # Clear state from any previous call so stale tensors are never reused
        self.gradients = None
        self.activations = None

        # Forward pass
        self.model.zero_grad()
        output = self.model(input_tensor)

        # Get target class
        if target_class is None:
            target_class = output.argmax(dim=1).item()

        # Backward pass
        output[0, target_class].backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations/gradients for the target layer."
            )

        # Global average pooling of gradients (importance weights)
        weights = torch.mean(self.gradients, dim=(2, 3, 4), keepdim=True)  # [1, C, 1, 1, 1]

        # Weighted combination of activation maps
        cam = torch.sum(weights * self.activations, dim=1, keepdim=True)  # [1, 1, D, H, W]

        # Apply ReLU (only positive influence)
        cam = torch.relu(cam)

        # Drop only the batch and channel axes; a plain squeeze() would also
        # remove any spatial axis that happens to have size 1.
        cam = cam[0, 0].cpu().numpy()

        # Normalize to [0, 1]; guard on the actual denominator so a constant
        # positive map cannot produce 0/0 = NaN.
        cam_min = cam.min()
        cam_max = cam.max()
        if cam_max > cam_min:
            cam = (cam - cam_min) / (cam_max - cam_min)
        else:
            cam = np.zeros_like(cam)

        return cam


# Initialize Grad-CAM (target the last conv layer before pooling)
# For MONAI ResNet18, the last conv block is layer4
grad_cam = GradCAM3D(model, target_layer=model.layer4[-1].conv2)


# --- HELPERS ---
def _failure_response():
    """Return the payload sent to the client whenever processing fails."""
    return {
        "predictions": [[label, 0.0] for label in CLASS_LABELS],
        "scan_data": None,
        "heatmap_data": None
    }


def _ensure_nifti_extension(file_path):
    """Rename ``file_path`` so its extension matches its content; return the path.

    The first two bytes are compared with the gzip magic number to decide
    between ``.nii.gz`` and ``.nii``.
    """
    # Check if file is actually gzipped by reading magic bytes
    with open(file_path, 'rb') as f:
        magic = f.read(2)
    is_gzipped = (magic == b'\x1f\x8b')
    print(f"File is gzipped: {is_gzipped}")

    lower_path = file_path.lower()
    has_nii_ext = lower_path.endswith('.nii')
    has_gz_ext = lower_path.endswith('.nii.gz')
    print(f"File extension check - ends with .nii: {has_nii_ext}, ends with .nii.gz: {has_gz_ext}")

    # Determine correct extension based on actual file content
    correct_ext = ".nii.gz" if is_gzipped else ".nii"

    if has_gz_ext and not is_gzipped:
        # Case 1: File has .nii.gz but is not gzipped - fix it
        new_path = file_path[:-3]  # Remove .gz
        os.rename(file_path, new_path)
        file_path = new_path
        print(f"Fixed incorrect .gz extension, renamed to: {file_path}")
    elif has_nii_ext and is_gzipped:
        # Case 2: File has .nii but is gzipped - fix it
        new_path = file_path + ".gz"
        os.rename(file_path, new_path)
        file_path = new_path
        print(f"Added missing .gz extension, renamed to: {file_path}")
    elif not has_nii_ext and not has_gz_ext:
        # Case 3: File has no NIfTI extension - add correct one
        new_path = file_path + correct_ext
        os.rename(file_path, new_path)
        file_path = new_path
        print(f"Added extension {correct_ext}, renamed to: {file_path}")
    else:
        print(f"File extension is correct: {file_path}")

    return file_path


def _nifti_to_base64(nifti_image):
    """Serialise a NIfTI image to an uncompressed ``.nii`` and base64-encode it."""
    # A temporary directory is removed on exit even if saving or reading
    # raises, and the file is never written while another handle holds it open.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, "volume.nii")
        nib.save(nifti_image, tmp_path)
        with open(tmp_path, 'rb') as f:
            data = f.read()
    return base64.b64encode(data).decode('utf-8')


# --- PREDICT FUNCTION ---
def predict(file_obj):
    """
    Processes a NIfTI file, runs inference, and returns:
    1. Classification results
    2. Original scan data (base64 encoded NIfTI)
    3. Heatmap data (base64 encoded NIfTI)

    Args:
        file_obj: Path string, or an object exposing the path as ``.name``.

    Returns:
        dict with "predictions" ([label, probability] pairs), "scan_data" and
        "heatmap_data"; on success also "predicted_class" and "confidence".
        On any failure the predictions are all 0.0 and both data fields are
        None.
    """
    print(f"Received file: {file_obj}")
    print(f"File type: {type(file_obj)}")
    if hasattr(file_obj, '__dict__'):
        print(f"File attributes: {file_obj.__dict__}")

    if file_obj is None:
        print("ERROR: No file received")
        return _failure_response()

    if not MODEL_LOADED:
        # Scores from an untrained network would be meaningless.
        print("ERROR: Model weights are not loaded; cannot run inference.")
        return _failure_response()

    # Handle different types of file_obj (could be string path or file object)
    if isinstance(file_obj, str):
        file_path = file_obj
    elif hasattr(file_obj, 'name'):
        file_path = file_obj.name
    else:
        print(f"ERROR: Unexpected file_obj type: {type(file_obj)}")
        return _failure_response()

    print(f"Extracted file path: {file_path}")
    print(f"File exists: {os.path.exists(file_path)}")

    # List directory contents to see what's actually there
    if not os.path.exists(file_path):
        dir_path = os.path.dirname(file_path)
        if os.path.exists(dir_path):
            print(f"Directory exists, contents: {os.listdir(dir_path)}")
        else:
            print(f"Directory doesn't exist: {dir_path}")

    # Validate file and add/fix extension
    try:
        file_path = _ensure_nifti_extension(file_path)
    except Exception as e:
        print(f"ERROR: File validation failed: {e}")
        traceback.print_exc()
        return _failure_response()

    # Preprocessing and Inference
    try:
        # Check file header to see if it's actually a NIfTI file
        with open(file_path, 'rb') as f:
            header_bytes = f.read(4)
        print(f"File header (first 4 bytes): {header_bytes}")
        print(f"File header as hex: {header_bytes.hex()}")
        print(f"File size: {os.path.getsize(file_path)} bytes")

        # Load NIfTI data
        nifti_img = nib.load(file_path)
        img_data = nifti_img.get_fdata()

        # Drop trailing singleton axes (e.g. X x Y x Z x 1) so the volume is 3D
        while img_data.ndim > 3 and img_data.shape[-1] == 1:
            img_data = img_data[..., 0]
        if img_data.ndim != 3:
            raise ValueError(f"Expected a 3D volume, got shape {img_data.shape}")
        original_shape = img_data.shape

        # Resize to TARGET_SIZE and normalize to [0, 1]
        img_data_resized = resize(img_data, TARGET_SIZE, mode='constant', anti_aliasing=True).astype(np.float32)
        resized_min = img_data_resized.min()
        resized_max = img_data_resized.max()
        if resized_max > resized_min:
            img_data_normalized = (img_data_resized - resized_min) / (resized_max - resized_min)
        else:
            # Constant volume: nothing to scale
            img_data_normalized = np.zeros(TARGET_SIZE, dtype=np.float32)

        # Convert to Tensor [Batch, Channel, Depth, Height, Width]
        inp = torch.from_numpy(img_data_normalized).float().unsqueeze(0).unsqueeze(0)
        print(f"Input tensor shape: {inp.shape}, min: {inp.min()}, max: {inp.max()}")

        # Run Inference
        with torch.no_grad():
            outputs = model(inp)
            probs = torch.nn.functional.softmax(outputs, dim=1)[0]

        # Format predictions
        predicted_class = probs.argmax().item()
        results = [[label, float(probs[i])] for i, label in enumerate(CLASS_LABELS)]

        print(f"Prediction result: {results}")
        print(f"Predicted class: {predicted_class} ({CLASS_LABELS[predicted_class]})")

        # Generate Grad-CAM heatmap for the predicted class
        # Need a separate tensor that requires gradients
        inp_grad = inp.clone().requires_grad_(True)
        heatmap = grad_cam.generate_cam(inp_grad, target_class=predicted_class)

        # Resize heatmap back to original scan size
        heatmap_resized = resize(heatmap, original_shape, mode='constant', anti_aliasing=True).astype(np.float32)

        # Create NIfTI images for scan and heatmap
        scan_nifti = nib.Nifti1Image(img_data.astype(np.float32), affine=nifti_img.affine)
        heatmap_nifti = nib.Nifti1Image(heatmap_resized, affine=nifti_img.affine)

        # Encode to base64 for transmission
        scan_b64 = _nifti_to_base64(scan_nifti)
        heatmap_b64 = _nifti_to_base64(heatmap_nifti)

        return {
            "predictions": results,
            "scan_data": scan_b64,
            "heatmap_data": heatmap_b64,
            "predicted_class": CLASS_LABELS[predicted_class],
            "confidence": float(probs[predicted_class])
        }

    except Exception as e:
        # Top-level boundary for the request: log and return the failure payload
        print(f"MODEL/NIBABEL ERROR: {str(e)}")
        traceback.print_exc()
        return _failure_response()


# --- LAUNCH INTERFACE ---
iface = gr.Interface(
    fn=predict,
    inputs=gr.File(
        label="Upload .nii or .nii.gz 3D Volume",
        type="filepath"
    ),
    outputs=gr.JSON(label="Analysis Results"),
    title="NeuroScan 3D Aneurysm Classifier with Visualization",
    description="Upload a 3D NIfTI (.nii or .nii.gz) volume for classification (Normal vs. Aneurysm) with attention visualization.",
    api_name="predict"
)

# Only start the server when run as a script, so importing this module
# (e.g. to reuse predict) does not launch the web app.
if __name__ == "__main__":
    iface.launch()