hmgill commited on
Commit
869a58c
·
verified ·
1 Parent(s): 206a302

Update cellemetry/services/sam.py

Browse files
Files changed (1) hide show
  1. cellemetry/services/sam.py +11 -56
cellemetry/services/sam.py CHANGED
@@ -1,5 +1,6 @@
1
  """
2
- SAM3 segmentation execution - Optimized for Speed.
 
3
  """
4
  import matplotlib
5
  matplotlib.use('Agg')
@@ -10,14 +11,11 @@ import torchvision
10
  import numpy as np
11
  import time
12
  from PIL import Image
13
- from skimage.measure import regionprops
14
 
15
  from ..config.schemas import ComponentRequest
16
  from ..config.dependencies import AnalysisDeps
17
 
18
- MIN_SOLIDITY = 0.50
19
- MIN_CIRCULARITY = 0.1
20
- MAX_DIMENSION = 1024 # <-- SPEED OPTIMIZATION: Downscale large images
21
 
22
  # Use /tmp for all outputs
23
  OUTPUT_DIR = "/tmp"
@@ -27,6 +25,8 @@ def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
27
  Execute SAM3 segmentation for the given component request.
28
  """
29
  t_start = time.time()
 
 
30
  text_prompt = f"{request.color} {request.entity}"
31
  print(f"\n[Engine] Segmenting: '{text_prompt}' ({len(request.bboxes)} boxes).")
32
 
@@ -36,7 +36,7 @@ def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
36
  except Exception as e:
37
  return f"Error loading image: {e}"
38
 
39
- # 2. SPEED FIX: Resize image if too large
40
  w, h = raw_image.size
41
  scale_factor = 1.0
42
  if max(w, h) > MAX_DIMENSION:
@@ -45,13 +45,11 @@ def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
45
  new_h = int(h * scale_factor)
46
  raw_image = raw_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
47
  print(f"[Engine] ⚡ Resized image from {w}x{h} to {new_w}x{new_h} (Speedup)")
48
- # Update width/height for box calculations below
49
  w, h = new_w, new_h
50
 
51
  # 3. Convert normalized coords (0-1000) to pixel coords
52
  sam_input_boxes = []
53
  for box in request.bboxes:
54
- # Scale coords to the (possibly resized) image dimensions
55
  y_min = (box.ymin / 1000) * h
56
  x_min = (box.xmin / 1000) * w
57
  y_max = (box.ymax / 1000) * h
@@ -83,7 +81,6 @@ def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
83
  return_tensors="pt"
84
  ).to(deps.device)
85
 
86
- # Use inference_mode for slight speedup over no_grad
87
  with torch.inference_mode():
88
  outputs = deps.sam_model(**inputs)
89
 
@@ -95,62 +92,23 @@ def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
95
 
96
  print(f"[Engine] Inference took {time.time() - t_inf:.2f}s")
97
 
98
- # 5. Morphology filtering (Optimized)
99
- t_filter = time.time()
100
- keep_indices_morph = []
101
-
102
- # Pre-fetch masks to cpu/numpy once
103
- all_masks_np = results["masks"].detach().cpu().numpy().squeeze()
104
- if all_masks_np.ndim == 2: # Handle single mask case
105
- all_masks_np = all_masks_np[np.newaxis, ...]
106
 
107
- for mask_np in all_masks_np:
108
- mask_int = mask_np.astype(int)
109
-
110
- # Optimization: fast skip if mask is too small (noise)
111
- if np.sum(mask_int) < 50:
112
- keep_indices_morph.append(False)
113
- continue
114
-
115
- props = regionprops(mask_int)
116
- if not props:
117
- keep_indices_morph.append(False)
118
- continue
119
-
120
- prop = props[0]
121
- # Fast calc circularity
122
- perimeter = prop.perimeter
123
- if perimeter == 0:
124
- keep_indices_morph.append(False)
125
- continue
126
-
127
- circularity = (4 * np.pi * prop.area) / (perimeter ** 2)
128
- keep_indices_morph.append(prop.solidity > MIN_SOLIDITY and circularity > MIN_CIRCULARITY)
129
-
130
- if any(keep_indices_morph):
131
- keep_indices_tensor = torch.tensor(keep_indices_morph, device=results["masks"].device)
132
- results = _filter_results(results, keep_indices_tensor)
133
-
134
- print(f"[Engine] Filtering took {time.time() - t_filter:.2f}s")
135
-
136
- # 6. NMS
137
  pred_boxes = results["boxes"]
138
  pred_scores = results["scores"]
139
  if len(pred_scores) > 1:
140
  keep_indices_nms = torchvision.ops.nms(pred_boxes, pred_scores, iou_threshold=0.3)
141
  results = _filter_results(results, keep_indices_nms)
142
 
143
- # 7. Save outputs (If resized, we must upscale masks back to original?
144
- # For demo purposes, we save the resized masks to keep things fast and aligned with the plot)
145
  _save_plot(raw_image, results, sam_input_boxes, text_prompt, plot_filename)
146
 
147
  mask_count = len(results['masks'])
148
  if mask_count > 0:
149
  masks_list = [m.cpu().numpy().squeeze() for m in results['masks']]
150
  masks_array = np.array(masks_list)
151
- # If we resized, the stats (area) will be in resized pixels.
152
- # Ideally we'd resize masks back, but for a demo, just warn or accept.
153
- # Alternatively, save the scale factor to adjust stats later.
154
  np.savez_compressed(data_filename, masks=masks_array)
155
  else:
156
  np.savez_compressed(data_filename, masks=np.array([]))
@@ -190,15 +148,12 @@ def _save_plot(image, results, boxes, label, filename):
190
 
191
  ax.imshow(composite)
192
 
193
- # 2. Draw Bounding Boxes (NEW)
194
- # 'boxes' comes in as [[xmin, ymin, xmax, ymax], ...]
195
  for box in boxes:
196
  x_min, y_min, x_max, y_max = box
197
  width = x_max - x_min
198
  height = y_max - y_min
199
 
200
- # Create a Rectangle patch
201
- # red border, no fill, thick line
202
  rect = patches.Rectangle(
203
  (x_min, y_min), width, height,
204
  linewidth=2, edgecolor='red', facecolor='none', linestyle='--'
 
1
  """
2
+ SAM3 segmentation execution - Optimized & Simplified.
3
+ Removed morphological filtering for maximum recall and speed.
4
  """
5
  import matplotlib
6
  matplotlib.use('Agg')
 
11
  import numpy as np
12
  import time
13
  from PIL import Image
 
14
 
15
  from ..config.schemas import ComponentRequest
16
  from ..config.dependencies import AnalysisDeps
17
 
18
+ MAX_DIMENSION = 1024 # Speed optimization: Downscale large images
 
 
19
 
20
  # Use /tmp for all outputs
21
  OUTPUT_DIR = "/tmp"
 
25
  Execute SAM3 segmentation for the given component request.
26
  """
27
  t_start = time.time()
28
+
29
+ # SIMPLIFIED PROMPT: Color + Entity only (e.g., "green cell")
30
  text_prompt = f"{request.color} {request.entity}"
31
  print(f"\n[Engine] Segmenting: '{text_prompt}' ({len(request.bboxes)} boxes).")
32
 
 
36
  except Exception as e:
37
  return f"Error loading image: {e}"
38
 
39
+ # 2. Resize image if too large (Critical for speed)
40
  w, h = raw_image.size
41
  scale_factor = 1.0
42
  if max(w, h) > MAX_DIMENSION:
 
45
  new_h = int(h * scale_factor)
46
  raw_image = raw_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
47
  print(f"[Engine] ⚡ Resized image from {w}x{h} to {new_w}x{new_h} (Speedup)")
 
48
  w, h = new_w, new_h
49
 
50
  # 3. Convert normalized coords (0-1000) to pixel coords
51
  sam_input_boxes = []
52
  for box in request.bboxes:
 
53
  y_min = (box.ymin / 1000) * h
54
  x_min = (box.xmin / 1000) * w
55
  y_max = (box.ymax / 1000) * h
 
81
  return_tensors="pt"
82
  ).to(deps.device)
83
 
 
84
  with torch.inference_mode():
85
  outputs = deps.sam_model(**inputs)
86
 
 
92
 
93
  print(f"[Engine] Inference took {time.time() - t_inf:.2f}s")
94
 
95
+ # 5. REMOVED: Morphological Filtering (Solidity/Circularity)
96
+ # We now trust the SAM model and the Agent's bounding box completely.
 
 
 
 
 
 
97
 
98
+ # 6. NMS (Keep this to remove duplicate detections on the same object)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  pred_boxes = results["boxes"]
100
  pred_scores = results["scores"]
101
  if len(pred_scores) > 1:
102
  keep_indices_nms = torchvision.ops.nms(pred_boxes, pred_scores, iou_threshold=0.3)
103
  results = _filter_results(results, keep_indices_nms)
104
 
105
+ # 7. Save outputs
 
106
  _save_plot(raw_image, results, sam_input_boxes, text_prompt, plot_filename)
107
 
108
  mask_count = len(results['masks'])
109
  if mask_count > 0:
110
  masks_list = [m.cpu().numpy().squeeze() for m in results['masks']]
111
  masks_array = np.array(masks_list)
 
 
 
112
  np.savez_compressed(data_filename, masks=masks_array)
113
  else:
114
  np.savez_compressed(data_filename, masks=np.array([]))
 
148
 
149
  ax.imshow(composite)
150
 
151
+ # 2. Draw Bounding Boxes
 
152
  for box in boxes:
153
  x_min, y_min, x_max, y_max = box
154
  width = x_max - x_min
155
  height = y_max - y_min
156
 
 
 
157
  rect = patches.Rectangle(
158
  (x_min, y_min), width, height,
159
  linewidth=2, edgecolor='red', facecolor='none', linestyle='--'