import os import io import tempfile import numpy as np import torch import gradio as gr import pydicom import matplotlib.pyplot as plt import matplotlib matplotlib.use('Agg') # Use Agg backend for non-interactive environments from monai.networks.nets import DenseNet121 from monai.transforms import ( Compose, ScaleIntensity, Resize, EnsureChannelFirst ) # Ensure PyTorch is available print(f"PyTorch version: {torch.__version__}") print(f"CUDA available: {torch.cuda.is_available()}") # Set device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") class DicomInterpreter: def __init__(self): self.model = None self.device = device self.initialize_model() def initialize_model(self): """Initialize a pretrained model for classification""" try: # For simplicity, using a pretrained DenseNet121 # In production, you'd use a model trained on medical data self.model = DenseNet121( spatial_dims=2, in_channels=1, out_channels=2, # Binary classification for demo ).to(self.device) # Put model in eval mode self.model.eval() print("Model initialized successfully") except Exception as e: print(f"Model initialization error: {str(e)}") self.model = None @torch.no_grad() def analyze_dicom(self, img_array): """Process a DICOM pixel array and return predictions""" try: # Preprocessing img_tensor = torch.from_numpy(img_array).float() # Ensure 3D: [channel, height, width] if img_tensor.ndim == 2: img_tensor = img_tensor.unsqueeze(0) # Normalize img_tensor = (img_tensor - img_tensor.min()) / (img_tensor.max() - img_tensor.min() + 1e-6) # Resize if img_tensor.shape[1:] != (224, 224): img_tensor = torch.nn.functional.interpolate( img_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False ).squeeze(0) # Make prediction img_tensor = img_tensor.to(self.device) output = self.model(img_tensor.unsqueeze(0)) probabilities = torch.nn.functional.softmax(output, dim=1) # Example interpretation class_names = ["Normal", "Abnormal"] # Example class names interpretation = { class_name: float(prob) for class_name, prob in zip(class_names, probabilities[0].cpu().numpy()) } return interpretation except Exception as e: print(f"Analysis error: {str(e)}") return {"Error": 1.0} def generate_heatmap(self, img_array): """Generate a synthetic attention heatmap""" try: # Normalize and resize the image img_tensor = torch.from_numpy(img_array).float() if img_tensor.ndim == 2: img_tensor = img_tensor.unsqueeze(0) img_tensor = (img_tensor - img_tensor.min()) / (img_tensor.max() - img_tensor.min() + 1e-6) if img_tensor.shape[1:] != (224, 224): img_tensor = torch.nn.functional.interpolate( img_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False ).squeeze(0) # Create visualization fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) # Original image ax1.imshow(img_tensor[0].numpy(), cmap='gray') ax1.set_title('Original Image') ax1.axis('off') # Create a synthetic heatmap (random for demo) # In production, use actual attention maps from the model heatmap = np.random.rand(224, 224) # Heatmap overlay ax2.imshow(img_tensor[0].numpy(), cmap='gray') ax2.imshow(heatmap, cmap='jet', alpha=0.5) ax2.set_title('AI Attention Map') ax2.axis('off') plt.tight_layout() # Convert matplotlib figure to image buf = io.BytesIO() fig.savefig(buf, format='png') buf.seek(0) plt.close(fig) return buf except Exception as e: print(f"Heatmap generation error: {str(e)}") # Create a simple error image fig, ax = plt.subplots(figsize=(12, 6)) ax.text(0.5, 0.5, f"Error generating heatmap: {str(e)}", horizontalalignment='center', verticalalignment='center') ax.axis('off') buf = io.BytesIO() fig.savefig(buf, format='png') buf.seek(0) plt.close(fig) return buf # Initialize the DICOM interpreter interpreter = DicomInterpreter() def read_dicom(dicom_file): """Read a DICOM file and return pixel array and metadata""" try: ds = pydicom.dcmread(dicom_file.name) img = ds.pixel_array # Extract metadata metadata = { "PatientID": str(getattr(ds, "PatientID", "N/A")), "Modality": str(getattr(ds, "Modality", "N/A")), "StudyDescription": str(getattr(ds, "StudyDescription", "N/A")), "SeriesDescription": str(getattr(ds, "SeriesDescription", "N/A")), "Dimensions": f"{img.shape[0]} x {img.shape[1]}", "Manufacturer": str(getattr(ds, "Manufacturer", "N/A")), "Filename": os.path.basename(dicom_file.name) } return img, metadata, None except Exception as e: error_msg = f"Error reading DICOM file: {str(e)}" print(error_msg) return None, None, error_msg def process_dicom_files(dicom_files): """Process multiple DICOM files and return results""" if not dicom_files: return None, "No files uploaded" results = [] all_results_html = "" for i, dicom_file in enumerate(dicom_files): try: # Read DICOM img, metadata, error = read_dicom(dicom_file) if error: results.append({ "filename": os.path.basename(dicom_file.name), "error": error, "display_img": None, "heatmap_img": None, "metadata": None, "interpretation": None }) continue # Normalize for display display_img = (img - img.min()) / (img.max() - img.min() + 1e-6) # Run AI analysis interpretation = interpreter.analyze_dicom(img) # Generate heatmap visualization heatmap_buf = interpreter.generate_heatmap(img) # Create a figure with both images for this file fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5)) ax1.imshow(display_img, cmap='gray') ax1.set_title(f"DICOM Image: {os.path.basename(dicom_file.name)}") ax1.axis('off') # For the heatmap, read the buffer and display it heatmap_buf.seek(0) heatmap_img = plt.imread(heatmap_buf) ax2.imshow(heatmap_img) ax2.set_title("AI Attention Map") ax2.axis('off') plt.tight_layout() # Save the combined result result_buf = io.BytesIO() fig.savefig(result_buf, format='png') result_buf.seek(0) plt.close(fig) # Build HTML for this result file_html = f"""
| {key} | {value} |
| {label} | {prob*100:.2f}% |
Note: This is a demonstration using a general model.
{error_msg}
Processed {len(results)} files successfully. Click on individual results below for details.