Spaces:
Sleeping
Sleeping
Update cellemetry/services/sam.py
Browse files- cellemetry/services/sam.py +22 -5
cellemetry/services/sam.py
CHANGED
|
@@ -4,6 +4,7 @@ SAM3 segmentation execution - Optimized for Speed.
|
|
| 4 |
import matplotlib
|
| 5 |
matplotlib.use('Agg')
|
| 6 |
import matplotlib.pyplot as plt
|
|
|
|
| 7 |
import torch
|
| 8 |
import torchvision
|
| 9 |
import numpy as np
|
|
@@ -169,13 +170,12 @@ def _filter_results(results, keep_indices):
|
|
| 169 |
|
| 170 |
|
| 171 |
def _save_plot(image, results, boxes, label, filename):
|
| 172 |
-
"""Save visualization of segmentation results."""
|
| 173 |
fig, ax = plt.subplots(figsize=(10, 10))
|
| 174 |
ax.imshow(image)
|
| 175 |
|
| 176 |
-
#
|
| 177 |
if len(results['scores']) > 0:
|
| 178 |
-
# Create a single composite mask image for faster plotting than individual ax.imshow calls
|
| 179 |
H, W = results['masks'][0].shape[-2:]
|
| 180 |
composite = np.zeros((H, W, 4))
|
| 181 |
|
|
@@ -190,7 +190,24 @@ def _save_plot(image, results, boxes, label, filename):
|
|
| 190 |
|
| 191 |
ax.imshow(composite)
|
| 192 |
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
ax.axis('off')
|
| 195 |
-
|
|
|
|
|
|
|
| 196 |
plt.close(fig)
|
|
|
|
| 4 |
import matplotlib
|
| 5 |
matplotlib.use('Agg')
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
+
import matplotlib.patches as patches
|
| 8 |
import torch
|
| 9 |
import torchvision
|
| 10 |
import numpy as np
|
|
|
|
| 170 |
|
| 171 |
|
| 172 |
def _save_plot(image, results, boxes, label, filename):
|
| 173 |
+
"""Save visualization of segmentation results with bounding boxes."""
|
| 174 |
fig, ax = plt.subplots(figsize=(10, 10))
|
| 175 |
ax.imshow(image)
|
| 176 |
|
| 177 |
+
# 1. Draw Segmentation Masks
|
| 178 |
if len(results['scores']) > 0:
|
|
|
|
| 179 |
H, W = results['masks'][0].shape[-2:]
|
| 180 |
composite = np.zeros((H, W, 4))
|
| 181 |
|
|
|
|
| 190 |
|
| 191 |
ax.imshow(composite)
|
| 192 |
|
| 193 |
+
# 2. Draw Bounding Boxes (NEW)
|
| 194 |
+
# 'boxes' comes in as [[xmin, ymin, xmax, ymax], ...]
|
| 195 |
+
for box in boxes:
|
| 196 |
+
x_min, y_min, x_max, y_max = box
|
| 197 |
+
width = x_max - x_min
|
| 198 |
+
height = y_max - y_min
|
| 199 |
+
|
| 200 |
+
# Create a Rectangle patch
|
| 201 |
+
# red border, no fill, thick line
|
| 202 |
+
rect = patches.Rectangle(
|
| 203 |
+
(x_min, y_min), width, height,
|
| 204 |
+
linewidth=2, edgecolor='red', facecolor='none', linestyle='--'
|
| 205 |
+
)
|
| 206 |
+
ax.add_patch(rect)
|
| 207 |
+
|
| 208 |
+
ax.set_title(f"{label} (with Agent BBoxes)")
|
| 209 |
ax.axis('off')
|
| 210 |
+
|
| 211 |
+
# Save tightly to remove whitespace
|
| 212 |
+
fig.savefig(filename, bbox_inches='tight', pad_inches=0)
|
| 213 |
plt.close(fig)
|