samwell Claude commited on
Commit
3a18164
·
1 Parent(s): 9dc78d7

fix: Improve segmentation overlay and add DICOM support

Browse files

Segmentation improvements:
- Improved matplotlib overlay rendering with better color opacity
- Added debug logging to track mask detection and alignment
- Changed to subplot-based rendering for better overlay composition
- Increased overlay opacity from 30% to 40% for better visibility
- Added mask count tracking to verify segmentation success

DICOM file support:
- Added DICOM file extensions to Gradio file upload (.dcm, .dicom)
- Updated placeholder text to indicate DICOM support
- Added DICOM file detection in chat function
- DICOM files are passed to agent for processing with DICOM tool
- Added error traceback printing for better debugging

The segmentation overlay should now properly show colored masks
on top of the X-ray image, and DICOM files can be uploaded
without errors.

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (2) hide show
  1. app.py +41 -29
  2. medrax/tools/segmentation/segmentation.py +24 -11
app.py CHANGED
@@ -177,41 +177,53 @@ def chat(message, history, mode):
177
 
178
  if files and len(files) > 0:
179
  image_path = files[0]
 
 
 
 
180
  # Store image path for tools to use
181
  # LangChain Google GenAI expects images as base64 or PIL
182
  try:
183
- # Open and encode image for Gemini
184
- with Image.open(image_path) as img:
185
- # Convert to RGB if needed
186
- if img.mode != "RGB":
187
- img = img.convert("RGB")
188
-
189
- # Resize if too large (max 4096x4096 for Gemini)
190
- max_size = 4096
191
- if img.width > max_size or img.height > max_size:
192
- img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
193
-
194
- # Store as bytes for LangChain
195
- buffered = BytesIO()
196
- img.save(buffered, format="PNG")
197
- img_bytes = buffered.getvalue()
198
- img_b64 = base64.b64encode(img_bytes).decode()
199
-
200
- # Create multimodal content for Gemini
201
- # Format: [{"type": "text", "text": "..."}, {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]
202
- image_content = {
203
- "type": "image_url",
204
- "image_url": {
205
- "url": f"data:image/png;base64,{img_b64}"
 
 
 
 
 
 
 
206
  }
207
- }
208
 
209
- # Include image path in text for tools to use
210
- text = f"[Image: {image_path}]\n\n{text}"
211
 
212
  except Exception as e:
213
  print(f"Error processing image: {e}")
214
- text = f"[Failed to load image: {image_path}]\n\n{text}"
 
 
215
 
216
  message = text
217
 
@@ -269,8 +281,8 @@ with gr.Blocks() as demo:
269
 
270
  msg = gr.MultimodalTextbox(
271
  label="Message",
272
- placeholder="Upload an X-ray image and ask a question...",
273
- file_types=["image"]
274
  )
275
 
276
  def respond(message, chat_history, mode_selection):
 
177
 
178
  if files and len(files) > 0:
179
  image_path = files[0]
180
+
181
+ # Check if it's a DICOM file
182
+ is_dicom = image_path.lower().endswith(('.dcm', '.dicom'))
183
+
184
  # Store image path for tools to use
185
  # LangChain Google GenAI expects images as base64 or PIL
186
  try:
187
+ if is_dicom:
188
+ # DICOM files need to be converted first
189
+ # We'll just pass the path and let the agent handle it
190
+ text = f"[DICOM file uploaded: {image_path}]\n\n{text}"
191
+ print(f"DICOM file detected: {image_path}")
192
+ else:
193
+ # Open and encode image for Gemini
194
+ with Image.open(image_path) as img:
195
+ # Convert to RGB if needed
196
+ if img.mode != "RGB":
197
+ img = img.convert("RGB")
198
+
199
+ # Resize if too large (max 4096x4096 for Gemini)
200
+ max_size = 4096
201
+ if img.width > max_size or img.height > max_size:
202
+ img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
203
+
204
+ # Store as bytes for LangChain
205
+ buffered = BytesIO()
206
+ img.save(buffered, format="PNG")
207
+ img_bytes = buffered.getvalue()
208
+ img_b64 = base64.b64encode(img_bytes).decode()
209
+
210
+ # Create multimodal content for Gemini
211
+ # Format: [{"type": "text", "text": "..."}, {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]
212
+ image_content = {
213
+ "type": "image_url",
214
+ "image_url": {
215
+ "url": f"data:image/png;base64,{img_b64}"
216
+ }
217
  }
 
218
 
219
+ # Include image path in text for tools to use
220
+ text = f"[Image: {image_path}]\n\n{text}"
221
 
222
  except Exception as e:
223
  print(f"Error processing image: {e}")
224
+ import traceback
225
+ traceback.print_exc()
226
+ text = f"[Failed to load image: {image_path}. Error: {str(e)}]\n\n{text}"
227
 
228
  message = text
229
 
 
281
 
282
  msg = gr.MultimodalTextbox(
283
  label="Message",
284
+ placeholder="Upload an X-ray image (JPG, PNG, DICOM) and ask a question...",
285
+ file_types=["image", ".dcm", ".dicom", ".DCM", ".DICOM"]
286
  )
287
 
288
  def respond(message, chat_history, mode_selection):
medrax/tools/segmentation/segmentation.py CHANGED
@@ -173,36 +173,49 @@ class ChestXRaySegmentationTool(BaseTool):
173
 
174
  def _save_visualization(self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]) -> str:
175
  """Save visualization of original image with segmentation masks overlaid."""
176
- plt.figure(figsize=(10, 10))
177
- plt.imshow(original_img, cmap="gray", extent=[0, original_img.shape[1], original_img.shape[0], 0])
 
 
178
 
179
  # Generate color palette for organs
180
  colors = plt.cm.rainbow(np.linspace(0, 1, len(organ_indices)))
181
 
182
  # Process and overlay each organ mask
 
183
  for idx, (organ_idx, color) in enumerate(zip(organ_indices, colors)):
184
  mask = pred_masks[0, organ_idx].cpu().numpy()
 
 
 
 
185
  if mask.sum() > 0:
 
186
  # Align the mask to the original image coordinates
187
  if mask.shape != original_img.shape:
188
  mask = self._align_mask_to_original(mask, original_img.shape)
 
189
 
190
  # Create a colored overlay with transparency
191
- colored_mask = np.zeros((*original_img.shape, 4))
192
- colored_mask[mask > 0] = (*color[:3], 0.3)
193
- plt.imshow(colored_mask, extent=[0, original_img.shape[1], original_img.shape[0], 0])
 
194
 
195
  # Add legend entry for the organ
196
  organ_name = list(self.organ_map.keys())[list(self.organ_map.values()).index(organ_idx)]
197
- plt.plot([], [], color=color, label=organ_name, linewidth=3)
 
 
198
 
199
- plt.title("Segmentation Overlay")
200
- plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
201
- plt.axis("off")
 
202
 
203
  save_path = self.temp_dir / f"segmentation_{uuid.uuid4().hex[:8]}.png"
204
- plt.savefig(save_path, bbox_inches="tight", dpi=300)
205
- plt.close()
206
 
207
  return str(save_path)
208
 
 
173
 
174
  def _save_visualization(self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]) -> str:
175
  """Save visualization of original image with segmentation masks overlaid."""
176
+ fig, ax = plt.subplots(figsize=(12, 12))
177
+
178
+ # Display original image
179
+ ax.imshow(original_img, cmap="gray")
180
 
181
  # Generate color palette for organs
182
  colors = plt.cm.rainbow(np.linspace(0, 1, len(organ_indices)))
183
 
184
  # Process and overlay each organ mask
185
+ masks_found = 0
186
  for idx, (organ_idx, color) in enumerate(zip(organ_indices, colors)):
187
  mask = pred_masks[0, organ_idx].cpu().numpy()
188
+
189
+ # Debug: print mask info
190
+ print(f"Organ index {organ_idx}: mask sum = {mask.sum()}, mask shape = {mask.shape}")
191
+
192
  if mask.sum() > 0:
193
+ masks_found += 1
194
  # Align the mask to the original image coordinates
195
  if mask.shape != original_img.shape:
196
  mask = self._align_mask_to_original(mask, original_img.shape)
197
+ print(f"Aligned mask shape: {mask.shape}, sum: {mask.sum()}")
198
 
199
  # Create a colored overlay with transparency
200
+ # Convert binary mask to RGBA overlay
201
+ overlay = np.zeros((*original_img.shape, 4))
202
+ overlay[mask > 0] = [color[0], color[1], color[2], 0.4] # 40% opacity
203
+ ax.imshow(overlay)
204
 
205
  # Add legend entry for the organ
206
  organ_name = list(self.organ_map.keys())[list(self.organ_map.values()).index(organ_idx)]
207
+ ax.plot([], [], color=color, label=organ_name, linewidth=3)
208
+
209
+ print(f"Total masks found and rendered: {masks_found}")
210
 
211
+ ax.set_title("Segmentation Overlay", fontsize=16, pad=20)
212
+ if masks_found > 0:
213
+ ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=10)
214
+ ax.axis("off")
215
 
216
  save_path = self.temp_dir / f"segmentation_{uuid.uuid4().hex[:8]}.png"
217
+ plt.savefig(save_path, bbox_inches="tight", dpi=150, facecolor='black')
218
+ plt.close(fig)
219
 
220
  return str(save_path)
221