Update cellemetry/services/sam.py

cellemetry/services/sam.py  CHANGED  (+11 -56)
@@ -1,5 +1,6 @@
 """
-SAM3 segmentation execution - Optimized
+SAM3 segmentation execution - Optimized & Simplified.
+Removed morphological filtering for maximum recall and speed.
 """
 import matplotlib
 matplotlib.use('Agg')
@@ -10,14 +11,11 @@ import torchvision
 import numpy as np
 import time
 from PIL import Image
-from skimage.measure import regionprops
 
 from ..config.schemas import ComponentRequest
 from ..config.dependencies import AnalysisDeps
 
-
-MIN_CIRCULARITY = 0.1
-MAX_DIMENSION = 1024  # <-- SPEED OPTIMIZATION: Downscale large images
+MAX_DIMENSION = 1024  # Speed optimization: Downscale large images
 
 # Use /tmp for all outputs
 OUTPUT_DIR = "/tmp"
@@ -27,6 +25,8 @@ def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
     Execute SAM3 segmentation for the given component request.
     """
     t_start = time.time()
+
+    # SIMPLIFIED PROMPT: Color + Entity only (e.g., "green cell")
     text_prompt = f"{request.color} {request.entity}"
     print(f"\n[Engine] Segmenting: '{text_prompt}' ({len(request.bboxes)} boxes).")
 
@@ -36,7 +36,7 @@ def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
     except Exception as e:
         return f"Error loading image: {e}"
 
-    # 2.
+    # 2. Resize image if too large (Critical for speed)
     w, h = raw_image.size
     scale_factor = 1.0
     if max(w, h) > MAX_DIMENSION:
@@ -45,13 +45,11 @@ def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
         new_h = int(h * scale_factor)
         raw_image = raw_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
         print(f"[Engine] ⚡ Resized image from {w}x{h} to {new_w}x{new_h} (Speedup)")
-        # Update width/height for box calculations below
         w, h = new_w, new_h
 
     # 3. Convert normalized coords (0-1000) to pixel coords
     sam_input_boxes = []
     for box in request.bboxes:
-        # Scale coords to the (possibly resized) image dimensions
         y_min = (box.ymin / 1000) * h
         x_min = (box.xmin / 1000) * w
         y_max = (box.ymax / 1000) * h
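
Note: the resize-then-convert step above can be read as one standalone helper. Below is a minimal sketch of that logic, assuming the [xmin, ymin, xmax, ymax] box layout described in _save_plot's old comments; the function name resize_and_scale_boxes is illustrative and is not part of sam.py.

from PIL import Image

MAX_DIMENSION = 1024  # same constant as in sam.py

def resize_and_scale_boxes(image, boxes_0_1000):
    """Downscale an image so its longest side is <= MAX_DIMENSION and
    convert boxes given in 0-1000 normalized coords to pixel coords."""
    w, h = image.size
    scale = 1.0
    if max(w, h) > MAX_DIMENSION:
        scale = MAX_DIMENSION / max(w, h)
        image = image.resize((int(w * scale), int(h * scale)), Image.Resampling.LANCZOS)
        w, h = image.size  # boxes must be scaled to the resized dimensions

    pixel_boxes = []
    for xmin, ymin, xmax, ymax in boxes_0_1000:
        pixel_boxes.append([
            (xmin / 1000) * w,
            (ymin / 1000) * h,
            (xmax / 1000) * w,
            (ymax / 1000) * h,
        ])
    return image, pixel_boxes

Because the scale factor is derived from the longest side, the aspect ratio is preserved and a single factor works for both axes.
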
@@ -83,7 +81,6 @@ def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
         return_tensors="pt"
     ).to(deps.device)
 
-    # Use inference_mode for slight speedup over no_grad
     with torch.inference_mode():
         outputs = deps.sam_model(**inputs)
 
@@ -95,62 +92,23 @@ def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
 
     print(f"[Engine] Inference took {time.time() - t_inf:.2f}s")
 
-    # 5.
-
-    keep_indices_morph = []
-
-    # Pre-fetch masks to cpu/numpy once
-    all_masks_np = results["masks"].detach().cpu().numpy().squeeze()
-    if all_masks_np.ndim == 2:  # Handle single mask case
-        all_masks_np = all_masks_np[np.newaxis, ...]
+    # 5. REMOVED: Morphological Filtering (Solidity/Circularity)
+    # We now trust the SAM model and the Agent's bounding box completely.
 
-
-        mask_int = mask_np.astype(int)
-
-        # Optimization: fast skip if mask is too small (noise)
-        if np.sum(mask_int) < 50:
-            keep_indices_morph.append(False)
-            continue
-
-        props = regionprops(mask_int)
-        if not props:
-            keep_indices_morph.append(False)
-            continue
-
-        prop = props[0]
-        # Fast calc circularity
-        perimeter = prop.perimeter
-        if perimeter == 0:
-            keep_indices_morph.append(False)
-            continue
-
-        circularity = (4 * np.pi * prop.area) / (perimeter ** 2)
-        keep_indices_morph.append(prop.solidity > MIN_SOLIDITY and circularity > MIN_CIRCULARITY)
-
-    if any(keep_indices_morph):
-        keep_indices_tensor = torch.tensor(keep_indices_morph, device=results["masks"].device)
-        results = _filter_results(results, keep_indices_tensor)
-
-    print(f"[Engine] Filtering took {time.time() - t_filter:.2f}s")
-
-    # 6. NMS
+    # 6. NMS (Keep this to remove duplicate detections on the same object)
     pred_boxes = results["boxes"]
     pred_scores = results["scores"]
     if len(pred_scores) > 1:
         keep_indices_nms = torchvision.ops.nms(pred_boxes, pred_scores, iou_threshold=0.3)
         results = _filter_results(results, keep_indices_nms)
 
-    # 7. Save outputs
-    # For demo purposes, we save the resized masks to keep things fast and aligned with the plot
+    # 7. Save outputs
     _save_plot(raw_image, results, sam_input_boxes, text_prompt, plot_filename)
 
     mask_count = len(results['masks'])
     if mask_count > 0:
         masks_list = [m.cpu().numpy().squeeze() for m in results['masks']]
         masks_array = np.array(masks_list)
-        # If we resized, the stats (area) will be in resized pixels.
-        # Ideally we'd resize masks back, but for a demo, just warn or accept.
-        # Alternatively, save the scale factor to adjust stats later.
         np.savez_compressed(data_filename, masks=masks_array)
     else:
         np.savez_compressed(data_filename, masks=np.array([]))
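
Note: the morphological filter deleted in this hunk amounts to the per-mask test below. This is a minimal reconstruction from the removed lines, rewritten as a standalone function because the per-mask loop header is not visible in this diff; the MIN_SOLIDITY value is referenced but never shown, so 0.85 is only an assumed placeholder.

import numpy as np
from skimage.measure import regionprops

MIN_CIRCULARITY = 0.1
MIN_SOLIDITY = 0.85  # assumed placeholder; the real value is not visible in this diff

def keep_mask(mask_bool):
    """Return True if a binary mask passes the solidity/circularity test
    that this commit removes from sam.py."""
    mask_int = mask_bool.astype(int)
    if mask_int.sum() < 50:  # fast skip for tiny, likely-noise masks
        return False
    props = regionprops(mask_int)
    if not props:
        return False
    prop = props[0]
    if prop.perimeter == 0:
        return False
    circularity = (4 * np.pi * prop.area) / (prop.perimeter ** 2)
    return prop.solidity > MIN_SOLIDITY and circularity > MIN_CIRCULARITY

With this filter gone, the only pruning left in execute_segmentation is the NMS pass at iou_threshold=0.3.
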
@@ -190,15 +148,12 @@ def _save_plot(image, results, boxes, label, filename):
 
     ax.imshow(composite)
 
-    # 2. Draw Bounding Boxes
-    # 'boxes' comes in as [[xmin, ymin, xmax, ymax], ...]
+    # 2. Draw Bounding Boxes
     for box in boxes:
         x_min, y_min, x_max, y_max = box
         width = x_max - x_min
         height = y_max - y_min
 
-        # Create a Rectangle patch
-        # red border, no fill, thick line
         rect = patches.Rectangle(
             (x_min, y_min), width, height,
             linewidth=2, edgecolor='red', facecolor='none', linestyle='--'
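
Note: _filter_results is called by both the removed filtering block and the NMS step, but its body is not part of this diff. A plausible sketch is shown below, assuming results is a flat dict of per-detection tensors (masks, boxes, scores); this is an assumption for illustration, not the file's actual implementation.

import torch

def _filter_results(results: dict, keep_indices: torch.Tensor) -> dict:
    # Index every per-detection tensor by the kept indices.
    # PyTorch indexing accepts either a boolean mask or an integer index tensor,
    # so this works for both the old solidity filter and torchvision.ops.nms output.
    return {key: value[keep_indices] for key, value in results.items()}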