import os
import io
import base64
import numpy as np
import torch
import gradio as gr
import pydicom
import matplotlib
matplotlib.use('Agg')  # Use the Agg backend for non-interactive environments (set before importing pyplot)
import matplotlib.pyplot as plt
from monai.networks.nets import DenseNet121
from monai.transforms import (
    Compose,
    ScaleIntensity,
    Resize,
    EnsureChannelFirst
)

# Ensure PyTorch is available
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
class DicomInterpreter:
    def __init__(self):
        self.model = None
        self.device = device
        self.initialize_model()

    def initialize_model(self):
        """Initialize a model for classification"""
        try:
            # For simplicity, using a randomly initialized DenseNet121 (no trained
            # weights are loaded here). In production, you'd use a model trained
            # on medical data.
            self.model = DenseNet121(
                spatial_dims=2,
                in_channels=1,
                out_channels=2,  # Binary classification for demo
            ).to(self.device)
            # Put model in eval mode
            self.model.eval()
            print("Model initialized successfully")
        except Exception as e:
            print(f"Model initialization error: {str(e)}")
            self.model = None
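    # A minimal sketch of how trained weights could be loaded instead of random
    # ones, assuming a checkpoint saved with torch.save(model.state_dict(), ...);
    # the file name "densenet_dicom.pt" is hypothetical and not part of this demo:
    #
    #     state_dict = torch.load("densenet_dicom.pt", map_location=self.device)
    #     self.model.load_state_dict(state_dict)
    #     self.model.eval()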
    def analyze_dicom(self, img_array):
        """Process a DICOM pixel array and return predictions"""
        try:
            if self.model is None:
                return {"Error": 1.0}
            # Preprocessing (astype handles non-native DICOM dtypes safely)
            img_tensor = torch.from_numpy(img_array.astype(np.float32))
            # Ensure 3D: [channel, height, width]
            if img_tensor.ndim == 2:
                img_tensor = img_tensor.unsqueeze(0)
            # Normalize to [0, 1]
            img_tensor = (img_tensor - img_tensor.min()) / (img_tensor.max() - img_tensor.min() + 1e-6)
            # Resize to the model's expected input size
            if img_tensor.shape[1:] != (224, 224):
                img_tensor = torch.nn.functional.interpolate(
                    img_tensor.unsqueeze(0),
                    size=(224, 224),
                    mode='bilinear',
                    align_corners=False
                ).squeeze(0)
            # Make prediction without tracking gradients
            img_tensor = img_tensor.to(self.device)
            with torch.no_grad():
                output = self.model(img_tensor.unsqueeze(0))
                probabilities = torch.nn.functional.softmax(output, dim=1)
            # Example interpretation
            class_names = ["Normal", "Abnormal"]  # Example class names
            interpretation = {
                class_name: float(prob)
                for class_name, prob in zip(class_names, probabilities[0].cpu().numpy())
            }
            return interpretation
        except Exception as e:
            print(f"Analysis error: {str(e)}")
            return {"Error": 1.0}
    def generate_heatmap(self, img_array):
        """Generate a synthetic attention heatmap"""
        try:
            # Normalize and resize the image
            img_tensor = torch.from_numpy(img_array.astype(np.float32))
            if img_tensor.ndim == 2:
                img_tensor = img_tensor.unsqueeze(0)
            img_tensor = (img_tensor - img_tensor.min()) / (img_tensor.max() - img_tensor.min() + 1e-6)
            if img_tensor.shape[1:] != (224, 224):
                img_tensor = torch.nn.functional.interpolate(
                    img_tensor.unsqueeze(0),
                    size=(224, 224),
                    mode='bilinear',
                    align_corners=False
                ).squeeze(0)

            # Create visualization
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

            # Original image
            ax1.imshow(img_tensor[0].numpy(), cmap='gray')
            ax1.set_title('Original Image')
            ax1.axis('off')

            # Create a synthetic heatmap (random for demo)
            # In production, use actual attention maps from the model
            heatmap = np.random.rand(224, 224)

            # Heatmap overlay
            ax2.imshow(img_tensor[0].numpy(), cmap='gray')
            ax2.imshow(heatmap, cmap='jet', alpha=0.5)
            ax2.set_title('AI Attention Map')
            ax2.axis('off')

            plt.tight_layout()

            # Convert matplotlib figure to an in-memory PNG
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            buf.seek(0)
            plt.close(fig)
            return buf
        except Exception as e:
            print(f"Heatmap generation error: {str(e)}")
            # Create a simple error image
            fig, ax = plt.subplots(figsize=(12, 6))
            ax.text(0.5, 0.5, f"Error generating heatmap: {str(e)}",
                    horizontalalignment='center', verticalalignment='center')
            ax.axis('off')
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            buf.seek(0)
            plt.close(fig)
            return buf
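    # Sketch of a real attention map to replace the random heatmap above, using
    # monai.visualize.GradCAM. The target layer name "class_layers.relu" is an
    # assumption about MONAI's DenseNet121 internals; verify it against your
    # MONAI version before relying on it. Not wired into the demo:
    def gradcam_heatmap(self, img_tensor):
        """Illustrative Grad-CAM attention map (unused by the demo)."""
        from monai.visualize import GradCAM  # local import keeps the demo's top-level imports unchanged
        cam = GradCAM(nn_module=self.model, target_layers="class_layers.relu")
        # Expects a [batch, channel, H, W] tensor; returns a normalized saliency map
        return cam(x=img_tensor.to(self.device))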
# Initialize the DICOM interpreter
interpreter = DicomInterpreter()


def read_dicom(dicom_file):
    """Read a DICOM file and return pixel array and metadata"""
    try:
        # Gradio may hand back a file wrapper or a plain path, depending on version
        path = dicom_file if isinstance(dicom_file, str) else dicom_file.name
        ds = pydicom.dcmread(path)
        img = ds.pixel_array
        # Extract metadata
        metadata = {
            "PatientID": str(getattr(ds, "PatientID", "N/A")),
            "Modality": str(getattr(ds, "Modality", "N/A")),
            "StudyDescription": str(getattr(ds, "StudyDescription", "N/A")),
            "SeriesDescription": str(getattr(ds, "SeriesDescription", "N/A")),
            "Dimensions": f"{img.shape[0]} x {img.shape[1]}",
            "Manufacturer": str(getattr(ds, "Manufacturer", "N/A")),
            "Filename": os.path.basename(path)
        }
        return img, metadata, None
    except Exception as e:
        error_msg = f"Error reading DICOM file: {str(e)}"
        print(error_msg)
        return None, None, error_msg
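# Note: ds.pixel_array returns raw stored values; for modalities such as CT the
# RescaleSlope/RescaleIntercept tags must be applied to get real-world units
# (e.g. Hounsfield). A minimal sketch using pydicom's modality LUT helper,
# assuming the dataset carries those tags:
#
#     from pydicom.pixel_data_handlers.util import apply_modality_lut
#     hu_img = apply_modality_lut(ds.pixel_array, ds)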
def process_dicom_files(dicom_files):
    """Process multiple DICOM files and return results as HTML"""
    if not dicom_files:
        return "<p>No files uploaded</p>"

    results = []
    all_results_html = ""

    for i, dicom_file in enumerate(dicom_files):
        path = dicom_file if isinstance(dicom_file, str) else dicom_file.name
        fname = os.path.basename(path)
        try:
            # Read DICOM
            img, metadata, error = read_dicom(dicom_file)
            if error:
                results.append({
                    "filename": fname,
                    "error": error,
                    "display_img": None,
                    "heatmap_img": None,
                    "metadata": None,
                    "interpretation": None
                })
                continue

            # Normalize for display
            display_img = (img - img.min()) / (img.max() - img.min() + 1e-6)

            # Run AI analysis
            interpretation = interpreter.analyze_dicom(img)

            # Generate heatmap visualization
            heatmap_buf = interpreter.generate_heatmap(img)

            # Create a figure with both images for this file
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
            ax1.imshow(display_img, cmap='gray')
            ax1.set_title(f"DICOM Image: {fname}")
            ax1.axis('off')

            # For the heatmap, read the buffer and display it
            heatmap_buf.seek(0)
            heatmap_img = plt.imread(heatmap_buf)
            ax2.imshow(heatmap_img)
            ax2.set_title("AI Attention Map")
            ax2.axis('off')
            plt.tight_layout()

            # Save the combined result
            result_buf = io.BytesIO()
            fig.savefig(result_buf, format='png')
            result_buf.seek(0)
            plt.close(fig)

            # Encode the PNG as base64 so it can be embedded in a data URI
            encoded_img = base64.b64encode(result_buf.getvalue()).decode('utf-8')

            # Build HTML for this result
            file_html = f"""
            <div style='margin: 20px 0; padding: 15px; border: 1px solid #ddd; border-radius: 8px;'>
                <h3>File {i+1}: {fname}</h3>
                <div style='display: flex; justify-content: center;'>
                    <img src='data:image/png;base64,{encoded_img}' style='max-width: 100%; height: auto;'>
                </div>
                <div style='display: flex; margin-top: 15px;'>
                    <div style='flex: 1; padding: 10px; background-color: #f0f0f0; border-radius: 5px; margin-right: 10px;'>
                        <h4>DICOM Metadata</h4>
                        <table style='width: 100%;'>
            """
            # Add metadata to table
            for key, value in metadata.items():
                file_html += f"<tr><td><b>{key}</b></td><td>{value}</td></tr>"
            file_html += """
                        </table>
                    </div>
                    <div style='flex: 1; padding: 10px; background-color: #f0f0f0; border-radius: 5px;'>
                        <h4>AI Interpretation</h4>
                        <table style='width: 100%;'>
            """
            # Add interpretation to table
            for label, prob in interpretation.items():
                file_html += f"<tr><td><b>{label}</b></td><td>{prob*100:.2f}%</td></tr>"
            file_html += """
                        </table>
                        <p><i>Note: This is a demonstration using a general model.</i></p>
                    </div>
                </div>
            </div>
            """
            all_results_html += file_html

            # Store the result
            results.append({
                "filename": fname,
                "display_img": display_img,
                "heatmap_img": heatmap_img,
                "metadata": metadata,
                "interpretation": interpretation
            })
        except Exception as e:
            error_msg = f"Error processing file {fname}: {str(e)}"
            print(error_msg)
            all_results_html += f"""
            <div style='margin: 20px 0; padding: 15px; border: 1px solid #f88; border-radius: 8px; background-color: #fee;'>
                <h3>Error with file {i+1}: {fname}</h3>
                <p>{error_msg}</p>
            </div>
            """

    # Create header for the results
    success_count = sum(1 for r in results if r.get("error") is None)
    summary_html = f"""
    <div style='padding: 10px; background-color: #e8f4f8; border-radius: 5px; margin-bottom: 20px;'>
        <h2>Analysis Results for {len(dicom_files)} DICOM Files</h2>
        <p>Processed {success_count} of {len(dicom_files)} files successfully. Individual results are shown below.</p>
    </div>
    """
    final_html = summary_html + all_results_html
    return final_html
# Create Gradio interface
with gr.Blocks(title="Multi-DICOM Interpreter with MONAI") as app:
    gr.Markdown("# Multi-DICOM Interpreter with MONAI")
    gr.Markdown("Upload one or more DICOM files to get AI-assisted interpretation and visualization.")

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(label="Upload DICOM Files", file_count="multiple")
            analyze_btn = gr.Button("Analyze DICOM Files", variant="primary")
        with gr.Column(scale=2):
            output = gr.HTML(label="Analysis Results")

    analyze_btn.click(
        fn=process_dicom_files,
        inputs=[file_input],
        outputs=[output]
    )
    gr.Markdown("""
    ## About This App
    This application demonstrates how to use MONAI, a PyTorch-based framework for deep learning in healthcare imaging, to analyze DICOM medical images. You can upload multiple files at once.

    ### Features:
    - Upload multiple DICOM files at once
    - View images and AI attention maps
    - Get AI interpretation for each image
    - View detailed DICOM metadata

    ### Notes:
    - This is a demonstration and should not be used for clinical purposes
    - The model is a generic deep learning model and not specifically trained for medical diagnosis
    - For actual clinical use, models should be properly trained, validated, and approved

    ### Built with:
    - MONAI: https://monai.io/
    - PyDICOM: https://pydicom.github.io/
    - Gradio: https://gradio.app/
    - PyTorch: https://pytorch.org/
    """)
# Launch the app
if __name__ == "__main__":
    app.launch()