# DICOM_AI / app.py
# Last commit: 663f5a4 "Update app.py" by Ravindu9904 (verified)
import base64
import html
import io
import os
import tempfile

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pydicom
import torch
matplotlib.use('Agg') # Use Agg backend for non-interactive environments
from monai.networks.nets import DenseNet121
from monai.transforms import (
Compose,
ScaleIntensity,
Resize,
EnsureChannelFirst
)
# Log the runtime environment, then choose the compute device shared by the
# whole module: the GPU when CUDA is usable, otherwise the CPU.
print(f"PyTorch version: {torch.__version__}")
_cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {_cuda_ok}")
device = torch.device("cuda" if _cuda_ok else "cpu")
print(f"Using device: {device}")
class DicomInterpreter:
    """Demo classifier and visualiser for single DICOM pixel arrays.

    The network is a MONAI ``DenseNet121`` constructed without loading any
    pretrained weights, so its predictions are not meaningful; it only
    demonstrates the inference plumbing. The "attention map" produced by
    ``generate_heatmap`` is likewise random noise.
    """

    # Spatial size every image is resized to before inference and display.
    IMAGE_SIZE = (224, 224)
    # Labels for the model's output channels, in channel order.
    CLASS_NAMES = ("Normal", "Abnormal")

    def __init__(self):
        self.model = None
        self.device = device  # module-level device selected at import time
        self.initialize_model()

    def initialize_model(self):
        """Build the classifier on ``self.device`` and switch it to eval mode.

        On failure the error is printed and ``self.model`` stays ``None``;
        ``analyze_dicom`` then reports an error instead of predicting.
        """
        try:
            # NOTE: no pretrained weights are requested here, so the network is
            # randomly initialised. A real deployment must load weights trained
            # on medical data.
            self.model = DenseNet121(
                spatial_dims=2,
                in_channels=1,
                out_channels=len(self.CLASS_NAMES),  # binary classification demo
            ).to(self.device)
            self.model.eval()
            print("Model initialized successfully")
        except Exception as e:
            print(f"Model initialization error: {str(e)}")
            self.model = None

    def _preprocess(self, img_array):
        """Convert a pixel array into a ``(1, 224, 224)`` float32 tensor scaled to [0, 1].

        Accepts 2-D greyscale arrays, ``(rows, cols, 3|4)`` colour arrays,
        ``(frames, rows, cols)`` multi-frame arrays and
        ``(frames, rows, cols, 3|4)`` arrays.

        Raises:
            ValueError: if the array cannot be reduced to a single 2-D image.
        """
        # Cast through NumPy first: torch.from_numpy rejects some dtypes that
        # are common in DICOM (e.g. uint16 on older PyTorch, big-endian data).
        arr = np.asarray(img_array, dtype=np.float32)
        # Collapse colour samples (trailing axis of size 3/4) to one grey channel.
        if arr.ndim >= 3 and arr.shape[-1] in (3, 4):
            arr = arr[..., :3].mean(axis=-1)
        # Multi-frame data: analyse the middle frame as a representative slice.
        if arr.ndim == 3:
            arr = arr[arr.shape[0] // 2]
        if arr.ndim != 2:
            raise ValueError(f"Unsupported pixel array shape: {np.shape(img_array)}")

        tensor = torch.from_numpy(np.ascontiguousarray(arr)).unsqueeze(0)  # (1, H, W)
        # Min-max normalise; the epsilon avoids division by zero on flat images.
        tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min() + 1e-6)
        if tuple(tensor.shape[1:]) != self.IMAGE_SIZE:
            # interpolate expects a batch axis: (1, 1, H, W) -> (1, 1, 224, 224)
            tensor = torch.nn.functional.interpolate(
                tensor.unsqueeze(0),
                size=self.IMAGE_SIZE,
                mode="bilinear",
                align_corners=False,
            ).squeeze(0)
        return tensor

    @torch.no_grad()
    def analyze_dicom(self, img_array):
        """Classify one DICOM pixel array.

        Returns:
            dict mapping each name in ``CLASS_NAMES`` to its softmax
            probability, or ``{"Error": 1.0}`` if the model is unavailable or
            anything fails (the reason is printed).
        """
        if self.model is None:
            print("Analysis error: model is not initialized")
            return {"Error": 1.0}
        try:
            img_tensor = self._preprocess(img_array).to(self.device)
            output = self.model(img_tensor.unsqueeze(0))  # add batch axis
            probabilities = torch.nn.functional.softmax(output, dim=1)[0].cpu()
            return {
                class_name: float(prob)
                for class_name, prob in zip(self.CLASS_NAMES, probabilities.tolist())
            }
        except Exception as e:
            print(f"Analysis error: {str(e)}")
            return {"Error": 1.0}

    def generate_heatmap(self, img_array):
        """Render the image beside a *synthetic* (random) heatmap overlay.

        Returns:
            io.BytesIO positioned at offset 0 containing a PNG. If rendering
            fails, the PNG instead shows the error message.
        """
        fig = None
        try:
            img = self._preprocess(img_array)[0].numpy()
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
            ax1.imshow(img, cmap='gray')
            ax1.set_title('Original Image')
            ax1.axis('off')
            # Placeholder only: random noise, NOT derived from the model.
            # TODO: replace with a real saliency method (e.g. Grad-CAM).
            heatmap = np.random.rand(*self.IMAGE_SIZE)
            ax2.imshow(img, cmap='gray')
            ax2.imshow(heatmap, cmap='jet', alpha=0.5)
            ax2.set_title('AI Attention Map')
            ax2.axis('off')
            fig.tight_layout()
            return self._figure_to_png(fig)
        except Exception as e:
            print(f"Heatmap generation error: {str(e)}")
            if fig is not None:
                plt.close(fig)  # don't leak the half-built figure
            fig, ax = plt.subplots(figsize=(12, 6))
            ax.text(0.5, 0.5, f"Error generating heatmap: {str(e)}",
                    horizontalalignment='center', verticalalignment='center')
            ax.axis('off')
            return self._figure_to_png(fig)

    @staticmethod
    def _figure_to_png(fig):
        """Save *fig* as PNG into a rewound BytesIO, always closing the figure."""
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format='png')
        finally:
            plt.close(fig)
        buf.seek(0)
        return buf
# Initialize the DICOM interpreter.
# A single shared instance is created at import time, so the model is built
# once and then reused by every call to process_dicom_files.
interpreter = DicomInterpreter()
def read_dicom(dicom_file):
    """Read one uploaded DICOM file and extract its pixels and display metadata.

    Args:
        dicom_file: a filesystem path (``str`` / ``os.PathLike``) or an upload
            object exposing its path as ``.name``.

    Returns:
        tuple: ``(pixel_array, metadata, None)`` on success, or
        ``(None, None, error_message)`` if the file cannot be read or has no
        decodable pixel data. ``metadata`` maps display labels to strings, in
        a fixed order.
    """
    try:
        # Plain paths are used as-is; upload wrappers carry the path in `.name`.
        if isinstance(dicom_file, (str, os.PathLike)):
            path = dicom_file
        else:
            path = dicom_file.name
        ds = pydicom.dcmread(path)
        img = ds.pixel_array
        metadata = {
            tag: str(getattr(ds, tag, "N/A"))
            for tag in ("PatientID", "Modality", "StudyDescription", "SeriesDescription")
        }
        # List every axis so multi-frame / colour arrays are described correctly
        # (a 2-D image still renders as "rows x cols").
        metadata["Dimensions"] = " x ".join(str(n) for n in img.shape)
        metadata["Manufacturer"] = str(getattr(ds, "Manufacturer", "N/A"))
        metadata["Filename"] = os.path.basename(path)
        return img, metadata, None
    except Exception as e:
        error_msg = f"Error reading DICOM file: {str(e)}"
        print(error_msg)
        return None, None, error_msg
def process_dicom_files(dicom_files):
"""Process multiple DICOM files and return results"""
if not dicom_files:
return None, "No files uploaded"
results = []
all_results_html = ""
for i, dicom_file in enumerate(dicom_files):
try:
# Read DICOM
img, metadata, error = read_dicom(dicom_file)
if error:
results.append({
"filename": os.path.basename(dicom_file.name),
"error": error,
"display_img": None,
"heatmap_img": None,
"metadata": None,
"interpretation": None
})
continue
# Normalize for display
display_img = (img - img.min()) / (img.max() - img.min() + 1e-6)
# Run AI analysis
interpretation = interpreter.analyze_dicom(img)
# Generate heatmap visualization
heatmap_buf = interpreter.generate_heatmap(img)
# Create a figure with both images for this file
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.imshow(display_img, cmap='gray')
ax1.set_title(f"DICOM Image: {os.path.basename(dicom_file.name)}")
ax1.axis('off')
# For the heatmap, read the buffer and display it
heatmap_buf.seek(0)
heatmap_img = plt.imread(heatmap_buf)
ax2.imshow(heatmap_img)
ax2.set_title("AI Attention Map")
ax2.axis('off')
plt.tight_layout()
# Save the combined result
result_buf = io.BytesIO()
fig.savefig(result_buf, format='png')
result_buf.seek(0)
plt.close(fig)
# Build HTML for this result
file_html = f"""
<div style='margin: 20px 0; padding: 15px; border: 1px solid #ddd; border-radius: 8px;'>
<h3>File {i+1}: {os.path.basename(dicom_file.name)}</h3>
<div style='display: flex; justify-content: center;'>
<img src='data:image/png;base64,{io.BytesIO(result_buf.read()).getvalue().hex()}' style='max-width: 100%; height: auto;'>
</div>
<div style='display: flex; margin-top: 15px;'>
<div style='flex: 1; padding: 10px; background-color: #f0f0f0; border-radius: 5px; margin-right: 10px;'>
<h4>DICOM Metadata</h4>
<table style='width: 100%;'>
"""
# Add metadata to table
for key, value in metadata.items():
file_html += f"<tr><td><b>{key}</b></td><td>{value}</td></tr>"
file_html += """
</table>
</div>
<div style='flex: 1; padding: 10px; background-color: #f0f0f0; border-radius: 5px;'>
<h4>AI Interpretation</h4>
<table style='width: 100%;'>
"""
# Add interpretation to table
for label, prob in interpretation.items():
file_html += f"<tr><td><b>{label}</b></td><td>{prob*100:.2f}%</td></tr>"
file_html += """
</table>
<p><i>Note: This is a demonstration using a general model.</i></p>
</div>
</div>
</div>
"""
all_results_html += file_html
# Store the result
results.append({
"filename": os.path.basename(dicom_file.name),
"display_img": display_img,
"heatmap_img": heatmap_img,
"metadata": metadata,
"interpretation": interpretation
})
except Exception as e:
error_msg = f"Error processing file {os.path.basename(dicom_file.name)}: {str(e)}"
print(error_msg)
all_results_html += f"""
<div style='margin: 20px 0; padding: 15px; border: 1px solid #f88; border-radius: 8px; background-color: #fee;'>
<h3>Error with file {i+1}: {os.path.basename(dicom_file.name)}</h3>
<p>{error_msg}</p>
</div>
"""
# Create header for the results
summary_html = f"""
<div style='padding: 10px; background-color: #e8f4f8; border-radius: 5px; margin-bottom: 20px;'>
<h2>Analysis Results for {len(dicom_files)} DICOM Files</h2>
<p>Processed {len(results)} files successfully. Click on individual results below for details.</p>
</div>
"""
final_html = summary_html + all_results_html
return final_html
# Create Gradio interface
with gr.Blocks(title="Multi-DICOM Interpreter with MONAI") as app:
    gr.Markdown("# Multi-DICOM Interpreter with MONAI")
    gr.Markdown("Upload one or more DICOM files to get AI-assisted interpretation and visualization")
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(label="Upload DICOM Files", file_count="multiple")
            analyze_btn = gr.Button("Analyze DICOM Files", variant="primary")
        with gr.Column(scale=2):
            output = gr.HTML(label="Analysis Results")
    # process_dicom_files returns one HTML string for the single output.
    analyze_btn.click(
        fn=process_dicom_files,
        inputs=[file_input],
        outputs=[output],
    )
    # The notes below describe the demo honestly: the classifier has no trained
    # weights and the "attention map" is random noise, not a model explanation.
    gr.Markdown("""
## About This App
This application demonstrates how to use MONAI, a PyTorch-based framework for deep learning in healthcare imaging, to analyze DICOM medical images. You can upload multiple files at once.
### Features:
- Upload multiple DICOM files at once
- View images with a placeholder attention-map overlay
- Get a demonstration classification score for each image
- View detailed DICOM metadata
### Notes:
- This is a demonstration and must not be used for clinical purposes
- The classifier is a MONAI DenseNet121 with randomly initialised (untrained) weights, so its "Normal"/"Abnormal" scores carry no diagnostic meaning
- The "AI Attention Map" is random noise used as a placeholder; it is not derived from the model
- For actual clinical use, models should be properly trained, validated, and approved
### Built with:
- MONAI: https://monai.io/
- PyDICOM: https://pydicom.github.io/
- Gradio: https://gradio.app/
- PyTorch: https://pytorch.org/
""")

# Launch the app
if __name__ == "__main__":
    app.launch()