Ravindu9904 commited on
Commit
663f5a4
·
verified ·
1 Parent(s): 2b48069

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -148
app.py CHANGED
@@ -32,187 +32,302 @@ class DicomInterpreter:
32
 
33
  def initialize_model(self):
34
  """Initialize a pretrained model for classification"""
35
- # For simplicity, using a pretrained DenseNet121
36
- # In production, you'd use a model trained on medical data
37
- self.model = DenseNet121(
38
- spatial_dims=2,
39
- in_channels=1,
40
- out_channels=2, # Binary classification for demo
41
- ).to(self.device)
42
-
43
- # Put model in eval mode
44
- self.model.eval()
45
-
46
- print("Model initialized successfully")
 
 
 
47
 
48
  @torch.no_grad()
49
  def analyze_dicom(self, img_array):
50
  """Process a DICOM pixel array and return predictions"""
51
- # Preprocessing
52
- img_tensor = torch.from_numpy(img_array).float()
53
-
54
- # Ensure 3D: [channel, height, width]
55
- if img_tensor.ndim == 2:
56
- img_tensor = img_tensor.unsqueeze(0)
57
 
58
- # Normalize
59
- img_tensor = (img_tensor - img_tensor.min()) / (img_tensor.max() - img_tensor.min() + 1e-6)
60
-
61
- # Resize
62
- if img_tensor.shape[1:] != (224, 224):
63
- img_tensor = torch.nn.functional.interpolate(
64
- img_tensor.unsqueeze(0),
65
- size=(224, 224),
66
- mode='bilinear',
67
- align_corners=False
68
- ).squeeze(0)
69
-
70
- # Make prediction
71
- img_tensor = img_tensor.to(self.device)
72
- output = self.model(img_tensor.unsqueeze(0))
73
- probabilities = torch.nn.functional.softmax(output, dim=1)
74
-
75
- # Example interpretation
76
- class_names = ["Normal", "Abnormal"] # Example class names
77
- interpretation = {
78
- class_name: float(prob)
79
- for class_name, prob in zip(class_names, probabilities[0].cpu().numpy())
80
- }
81
-
82
- return interpretation
 
 
 
 
 
 
 
83
 
84
  def generate_heatmap(self, img_array):
85
  """Generate a synthetic attention heatmap"""
86
- # In a real application, use techniques like Grad-CAM
87
- # For demo, creating a simple synthetic heatmap
88
-
89
- # Normalize and resize the image
90
- img_tensor = torch.from_numpy(img_array).float()
91
- if img_tensor.ndim == 2:
92
- img_tensor = img_tensor.unsqueeze(0)
93
-
94
- img_tensor = (img_tensor - img_tensor.min()) / (img_tensor.max() - img_tensor.min() + 1e-6)
95
- if img_tensor.shape[1:] != (224, 224):
96
- img_tensor = torch.nn.functional.interpolate(
97
- img_tensor.unsqueeze(0),
98
- size=(224, 224),
99
- mode='bilinear',
100
- align_corners=False
101
- ).squeeze(0)
102
-
103
- # Create visualization
104
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
105
-
106
- # Original image
107
- ax1.imshow(img_tensor[0].numpy(), cmap='gray')
108
- ax1.set_title('Original Image')
109
- ax1.axis('off')
110
-
111
- # Create a synthetic heatmap (random for demo)
112
- # In production, use actual attention maps from the model
113
- heatmap = np.random.rand(224, 224)
114
-
115
- # Heatmap overlay
116
- ax2.imshow(img_tensor[0].numpy(), cmap='gray')
117
- ax2.imshow(heatmap, cmap='jet', alpha=0.5)
118
- ax2.set_title('AI Attention Map')
119
- ax2.axis('off')
120
-
121
- plt.tight_layout()
122
-
123
- # Convert matplotlib figure to image
124
- buf = io.BytesIO()
125
- fig.savefig(buf, format='png')
126
- buf.seek(0)
127
- plt.close(fig)
128
-
129
- return buf
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  # Initialize the DICOM interpreter
132
  interpreter = DicomInterpreter()
133
 
134
  def read_dicom(dicom_file):
135
  """Read a DICOM file and return pixel array and metadata"""
136
- ds = pydicom.dcmread(dicom_file.name)
137
- img = ds.pixel_array
138
-
139
- # Extract metadata
140
- metadata = {
141
- "PatientID": str(getattr(ds, "PatientID", "N/A")),
142
- "Modality": str(getattr(ds, "Modality", "N/A")),
143
- "StudyDescription": str(getattr(ds, "StudyDescription", "N/A")),
144
- "SeriesDescription": str(getattr(ds, "SeriesDescription", "N/A")),
145
- "Dimensions": f"{img.shape[0]} x {img.shape[1]}",
146
- "Manufacturer": str(getattr(ds, "Manufacturer", "N/A"))
147
- }
148
-
149
- return img, metadata
150
-
151
- def process_dicom(dicom_file):
152
- """Process a DICOM file and return visualization and analysis"""
153
- if dicom_file is None:
154
- return None, None, "No file uploaded"
155
-
156
  try:
157
- # Read DICOM
158
- img, metadata = read_dicom(dicom_file)
159
-
160
- # Normalize for display
161
- display_img = (img - img.min()) / (img.max() - img.min() + 1e-6)
162
-
163
- # Run AI analysis
164
- interpretation = interpreter.analyze_dicom(img)
165
-
166
- # Generate heatmap visualization
167
- heatmap_buf = interpreter.generate_heatmap(img)
168
-
169
- # Format metadata as HTML
170
- metadata_html = "<div style='text-align: left; padding: 10px; background-color: #f0f0f0; border-radius: 5px;'>"
171
- metadata_html += "<h3>DICOM Metadata</h3>"
172
- for key, value in metadata.items():
173
- metadata_html += f"<b>{key}</b>: {value}<br>"
174
- metadata_html += "</div>"
175
-
176
- # Format interpretation as HTML
177
- interp_html = "<div style='text-align: left; padding: 10px; background-color: #f0f0f0; border-radius: 5px;'>"
178
- interp_html += "<h3>AI Interpretation</h3>"
179
- for label, prob in interpretation.items():
180
- interp_html += f"<b>{label}</b>: {prob*100:.2f}%<br>"
181
- interp_html += "<p><i>Note: This is a demonstration using a general model. Actual medical applications require properly trained models.</i></p>"
182
- interp_html += "</div>"
183
 
184
- return display_img, heatmap_buf, metadata_html + interp_html
 
 
 
 
 
 
 
 
 
185
 
 
186
  except Exception as e:
187
- return None, None, f"Error processing DICOM file: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  # Create Gradio interface
190
- with gr.Blocks(title="DICOM Interpreter with MONAI") as app:
191
- gr.Markdown("# DICOM Interpreter with MONAI")
192
- gr.Markdown("Upload a DICOM file to get AI-assisted interpretation and visualization")
193
 
194
  with gr.Row():
195
  with gr.Column(scale=1):
196
- file_input = gr.File(label="Upload DICOM File")
197
- analyze_btn = gr.Button("Analyze DICOM", variant="primary")
198
 
199
  with gr.Column(scale=2):
200
- with gr.Tab("Results"):
201
- with gr.Row():
202
- image_output = gr.Image(label="DICOM Image", type="numpy")
203
- heatmap_output = gr.Image(label="AI Attention Map", type="numpy")
204
- info_output = gr.HTML(label="Analysis Results")
205
 
206
  analyze_btn.click(
207
- fn=process_dicom,
208
  inputs=[file_input],
209
- outputs=[image_output, heatmap_output, info_output]
210
  )
211
 
212
  gr.Markdown("""
213
  ## About This App
214
 
215
- This application demonstrates how to use MONAI, a PyTorch-based framework for deep learning in healthcare imaging, to analyze DICOM medical images.
 
 
 
 
 
 
216
 
217
  ### Notes:
218
  - This is a demonstration and should not be used for clinical purposes
 
32
 
33
  def initialize_model(self):
34
  """Initialize a pretrained model for classification"""
35
+ try:
36
+ # For simplicity, using a pretrained DenseNet121
37
+ # In production, you'd use a model trained on medical data
38
+ self.model = DenseNet121(
39
+ spatial_dims=2,
40
+ in_channels=1,
41
+ out_channels=2, # Binary classification for demo
42
+ ).to(self.device)
43
+
44
+ # Put model in eval mode
45
+ self.model.eval()
46
+ print("Model initialized successfully")
47
+ except Exception as e:
48
+ print(f"Model initialization error: {str(e)}")
49
+ self.model = None
50
 
51
  @torch.no_grad()
52
  def analyze_dicom(self, img_array):
53
  """Process a DICOM pixel array and return predictions"""
54
+ try:
55
+ # Preprocessing
56
+ img_tensor = torch.from_numpy(img_array).float()
 
 
 
57
 
58
+ # Ensure 3D: [channel, height, width]
59
+ if img_tensor.ndim == 2:
60
+ img_tensor = img_tensor.unsqueeze(0)
61
+
62
+ # Normalize
63
+ img_tensor = (img_tensor - img_tensor.min()) / (img_tensor.max() - img_tensor.min() + 1e-6)
64
+
65
+ # Resize
66
+ if img_tensor.shape[1:] != (224, 224):
67
+ img_tensor = torch.nn.functional.interpolate(
68
+ img_tensor.unsqueeze(0),
69
+ size=(224, 224),
70
+ mode='bilinear',
71
+ align_corners=False
72
+ ).squeeze(0)
73
+
74
+ # Make prediction
75
+ img_tensor = img_tensor.to(self.device)
76
+ output = self.model(img_tensor.unsqueeze(0))
77
+ probabilities = torch.nn.functional.softmax(output, dim=1)
78
+
79
+ # Example interpretation
80
+ class_names = ["Normal", "Abnormal"] # Example class names
81
+ interpretation = {
82
+ class_name: float(prob)
83
+ for class_name, prob in zip(class_names, probabilities[0].cpu().numpy())
84
+ }
85
+
86
+ return interpretation
87
+ except Exception as e:
88
+ print(f"Analysis error: {str(e)}")
89
+ return {"Error": 1.0}
90
 
91
  def generate_heatmap(self, img_array):
92
  """Generate a synthetic attention heatmap"""
93
+ try:
94
+ # Normalize and resize the image
95
+ img_tensor = torch.from_numpy(img_array).float()
96
+ if img_tensor.ndim == 2:
97
+ img_tensor = img_tensor.unsqueeze(0)
98
+
99
+ img_tensor = (img_tensor - img_tensor.min()) / (img_tensor.max() - img_tensor.min() + 1e-6)
100
+ if img_tensor.shape[1:] != (224, 224):
101
+ img_tensor = torch.nn.functional.interpolate(
102
+ img_tensor.unsqueeze(0),
103
+ size=(224, 224),
104
+ mode='bilinear',
105
+ align_corners=False
106
+ ).squeeze(0)
107
+
108
+ # Create visualization
109
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
110
+
111
+ # Original image
112
+ ax1.imshow(img_tensor[0].numpy(), cmap='gray')
113
+ ax1.set_title('Original Image')
114
+ ax1.axis('off')
115
+
116
+ # Create a synthetic heatmap (random for demo)
117
+ # In production, use actual attention maps from the model
118
+ heatmap = np.random.rand(224, 224)
119
+
120
+ # Heatmap overlay
121
+ ax2.imshow(img_tensor[0].numpy(), cmap='gray')
122
+ ax2.imshow(heatmap, cmap='jet', alpha=0.5)
123
+ ax2.set_title('AI Attention Map')
124
+ ax2.axis('off')
125
+
126
+ plt.tight_layout()
127
+
128
+ # Convert matplotlib figure to image
129
+ buf = io.BytesIO()
130
+ fig.savefig(buf, format='png')
131
+ buf.seek(0)
132
+ plt.close(fig)
133
+
134
+ return buf
135
+ except Exception as e:
136
+ print(f"Heatmap generation error: {str(e)}")
137
+ # Create a simple error image
138
+ fig, ax = plt.subplots(figsize=(12, 6))
139
+ ax.text(0.5, 0.5, f"Error generating heatmap: {str(e)}",
140
+ horizontalalignment='center', verticalalignment='center')
141
+ ax.axis('off')
142
+
143
+ buf = io.BytesIO()
144
+ fig.savefig(buf, format='png')
145
+ buf.seek(0)
146
+ plt.close(fig)
147
+
148
+ return buf
149
 
150
  # Initialize the DICOM interpreter
151
  interpreter = DicomInterpreter()
152
 
153
  def read_dicom(dicom_file):
154
  """Read a DICOM file and return pixel array and metadata"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  try:
156
+ ds = pydicom.dcmread(dicom_file.name)
157
+ img = ds.pixel_array
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
+ # Extract metadata
160
+ metadata = {
161
+ "PatientID": str(getattr(ds, "PatientID", "N/A")),
162
+ "Modality": str(getattr(ds, "Modality", "N/A")),
163
+ "StudyDescription": str(getattr(ds, "StudyDescription", "N/A")),
164
+ "SeriesDescription": str(getattr(ds, "SeriesDescription", "N/A")),
165
+ "Dimensions": f"{img.shape[0]} x {img.shape[1]}",
166
+ "Manufacturer": str(getattr(ds, "Manufacturer", "N/A")),
167
+ "Filename": os.path.basename(dicom_file.name)
168
+ }
169
 
170
+ return img, metadata, None
171
  except Exception as e:
172
+ error_msg = f"Error reading DICOM file: {str(e)}"
173
+ print(error_msg)
174
+ return None, None, error_msg
175
+
176
+ def process_dicom_files(dicom_files):
177
+ """Process multiple DICOM files and return results"""
178
+ if not dicom_files:
179
+ return None, "No files uploaded"
180
+
181
+ results = []
182
+ all_results_html = ""
183
+
184
+ for i, dicom_file in enumerate(dicom_files):
185
+ try:
186
+ # Read DICOM
187
+ img, metadata, error = read_dicom(dicom_file)
188
+
189
+ if error:
190
+ results.append({
191
+ "filename": os.path.basename(dicom_file.name),
192
+ "error": error,
193
+ "display_img": None,
194
+ "heatmap_img": None,
195
+ "metadata": None,
196
+ "interpretation": None
197
+ })
198
+ continue
199
+
200
+ # Normalize for display
201
+ display_img = (img - img.min()) / (img.max() - img.min() + 1e-6)
202
+
203
+ # Run AI analysis
204
+ interpretation = interpreter.analyze_dicom(img)
205
+
206
+ # Generate heatmap visualization
207
+ heatmap_buf = interpreter.generate_heatmap(img)
208
+
209
+ # Create a figure with both images for this file
210
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
211
+ ax1.imshow(display_img, cmap='gray')
212
+ ax1.set_title(f"DICOM Image: {os.path.basename(dicom_file.name)}")
213
+ ax1.axis('off')
214
+
215
+ # For the heatmap, read the buffer and display it
216
+ heatmap_buf.seek(0)
217
+ heatmap_img = plt.imread(heatmap_buf)
218
+ ax2.imshow(heatmap_img)
219
+ ax2.set_title("AI Attention Map")
220
+ ax2.axis('off')
221
+
222
+ plt.tight_layout()
223
+
224
+ # Save the combined result
225
+ result_buf = io.BytesIO()
226
+ fig.savefig(result_buf, format='png')
227
+ result_buf.seek(0)
228
+ plt.close(fig)
229
+
230
+ # Build HTML for this result
231
+ file_html = f"""
232
+ <div style='margin: 20px 0; padding: 15px; border: 1px solid #ddd; border-radius: 8px;'>
233
+ <h3>File {i+1}: {os.path.basename(dicom_file.name)}</h3>
234
+ <div style='display: flex; justify-content: center;'>
235
+ <img src='data:image/png;base64,{io.BytesIO(result_buf.read()).getvalue().hex()}' style='max-width: 100%; height: auto;'>
236
+ </div>
237
+
238
+ <div style='display: flex; margin-top: 15px;'>
239
+ <div style='flex: 1; padding: 10px; background-color: #f0f0f0; border-radius: 5px; margin-right: 10px;'>
240
+ <h4>DICOM Metadata</h4>
241
+ <table style='width: 100%;'>
242
+ """
243
+
244
+ # Add metadata to table
245
+ for key, value in metadata.items():
246
+ file_html += f"<tr><td><b>{key}</b></td><td>{value}</td></tr>"
247
+
248
+ file_html += """
249
+ </table>
250
+ </div>
251
+
252
+ <div style='flex: 1; padding: 10px; background-color: #f0f0f0; border-radius: 5px;'>
253
+ <h4>AI Interpretation</h4>
254
+ <table style='width: 100%;'>
255
+ """
256
+
257
+ # Add interpretation to table
258
+ for label, prob in interpretation.items():
259
+ file_html += f"<tr><td><b>{label}</b></td><td>{prob*100:.2f}%</td></tr>"
260
+
261
+ file_html += """
262
+ </table>
263
+ <p><i>Note: This is a demonstration using a general model.</i></p>
264
+ </div>
265
+ </div>
266
+ </div>
267
+ """
268
+
269
+ all_results_html += file_html
270
+
271
+ # Store the result
272
+ results.append({
273
+ "filename": os.path.basename(dicom_file.name),
274
+ "display_img": display_img,
275
+ "heatmap_img": heatmap_img,
276
+ "metadata": metadata,
277
+ "interpretation": interpretation
278
+ })
279
+
280
+ except Exception as e:
281
+ error_msg = f"Error processing file {os.path.basename(dicom_file.name)}: {str(e)}"
282
+ print(error_msg)
283
+ all_results_html += f"""
284
+ <div style='margin: 20px 0; padding: 15px; border: 1px solid #f88; border-radius: 8px; background-color: #fee;'>
285
+ <h3>Error with file {i+1}: {os.path.basename(dicom_file.name)}</h3>
286
+ <p>{error_msg}</p>
287
+ </div>
288
+ """
289
+
290
+ # Create header for the results
291
+ summary_html = f"""
292
+ <div style='padding: 10px; background-color: #e8f4f8; border-radius: 5px; margin-bottom: 20px;'>
293
+ <h2>Analysis Results for {len(dicom_files)} DICOM Files</h2>
294
+ <p>Processed {len(results)} files successfully. Click on individual results below for details.</p>
295
+ </div>
296
+ """
297
+
298
+ final_html = summary_html + all_results_html
299
+
300
+ return final_html
301
 
302
  # Create Gradio interface
303
+ with gr.Blocks(title="Multi-DICOM Interpreter with MONAI") as app:
304
+ gr.Markdown("# Multi-DICOM Interpreter with MONAI")
305
+ gr.Markdown("Upload one or more DICOM files to get AI-assisted interpretation and visualization")
306
 
307
  with gr.Row():
308
  with gr.Column(scale=1):
309
+ file_input = gr.File(label="Upload DICOM Files", file_count="multiple")
310
+ analyze_btn = gr.Button("Analyze DICOM Files", variant="primary")
311
 
312
  with gr.Column(scale=2):
313
+ output = gr.HTML(label="Analysis Results")
 
 
 
 
314
 
315
  analyze_btn.click(
316
+ fn=process_dicom_files,
317
  inputs=[file_input],
318
+ outputs=[output]
319
  )
320
 
321
  gr.Markdown("""
322
  ## About This App
323
 
324
+ This application demonstrates how to use MONAI, a PyTorch-based framework for deep learning in healthcare imaging, to analyze DICOM medical images. You can upload multiple files at once.
325
+
326
+ ### Features:
327
+ - Upload multiple DICOM files at once
328
+ - View images and AI attention maps
329
+ - Get AI interpretation for each image
330
+ - View detailed DICOM metadata
331
 
332
  ### Notes:
333
  - This is a demonstration and should not be used for clinical purposes