hmgill commited on
Commit
ba2f58e
·
verified ·
1 Parent(s): 7f36c33

Update cellemetry/services/sam.py

Browse files
Files changed (1) hide show
  1. cellemetry/services/sam.py +22 -5
cellemetry/services/sam.py CHANGED
@@ -4,6 +4,7 @@ SAM3 segmentation execution - Optimized for Speed.
4
  import matplotlib
5
  matplotlib.use('Agg')
6
  import matplotlib.pyplot as plt
 
7
  import torch
8
  import torchvision
9
  import numpy as np
@@ -169,13 +170,12 @@ def _filter_results(results, keep_indices):
169
 
170
 
171
  def _save_plot(image, results, boxes, label, filename):
172
- """Save visualization of segmentation results."""
173
  fig, ax = plt.subplots(figsize=(10, 10))
174
  ax.imshow(image)
175
 
176
- # Batch visualization for speed
177
  if len(results['scores']) > 0:
178
- # Create a single composite mask image for faster plotting than individual ax.imshow calls
179
  H, W = results['masks'][0].shape[-2:]
180
  composite = np.zeros((H, W, 4))
181
 
@@ -190,7 +190,24 @@ def _save_plot(image, results, boxes, label, filename):
190
 
191
  ax.imshow(composite)
192
 
193
- ax.set_title(f"{label}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  ax.axis('off')
195
- fig.savefig(filename)
 
 
196
  plt.close(fig)
 
4
  import matplotlib
5
  matplotlib.use('Agg')
6
  import matplotlib.pyplot as plt
7
+ import matplotlib.patches as patches
8
  import torch
9
  import torchvision
10
  import numpy as np
 
170
 
171
 
172
  def _save_plot(image, results, boxes, label, filename):
173
+ """Save visualization of segmentation results with bounding boxes."""
174
  fig, ax = plt.subplots(figsize=(10, 10))
175
  ax.imshow(image)
176
 
177
+ # 1. Draw Segmentation Masks
178
  if len(results['scores']) > 0:
 
179
  H, W = results['masks'][0].shape[-2:]
180
  composite = np.zeros((H, W, 4))
181
 
 
190
 
191
  ax.imshow(composite)
192
 
193
+ # 2. Draw Bounding Boxes (NEW)
194
+ # 'boxes' comes in as [[xmin, ymin, xmax, ymax], ...]
195
+ for box in boxes:
196
+ x_min, y_min, x_max, y_max = box
197
+ width = x_max - x_min
198
+ height = y_max - y_min
199
+
200
+ # Create a Rectangle patch
201
+ # red border, no fill, thick line
202
+ rect = patches.Rectangle(
203
+ (x_min, y_min), width, height,
204
+ linewidth=2, edgecolor='red', facecolor='none', linestyle='--'
205
+ )
206
+ ax.add_patch(rect)
207
+
208
+ ax.set_title(f"{label} (with Agent BBoxes)")
209
  ax.axis('off')
210
+
211
+ # Save tightly to remove whitespace
212
+ fig.savefig(filename, bbox_inches='tight', pad_inches=0)
213
  plt.close(fig)