# cellemetry/services/sam.py (commit 2feca28)
"""
SAM3 segmentation execution - Optimized & Simplified.
Downscales large images for speed; applies light morphological
filtering (solidity/circularity) and NMS to the predicted masks.
"""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage.measure import regionprops
import torch
import torchvision
import numpy as np
import time
from PIL import Image
from ..config.schemas import ComponentRequest
from ..config.dependencies import AnalysisDeps
MAX_DIMENSION = 1024 # Speed optimization: Downscale large images
# Morphology thresholds: masks failing either test are treated as debris.
MIN_SOLIDITY = 0.50
MIN_CIRCULARITY = 0.1
# Use /tmp for all outputs
# NOTE(review): OUTPUT_DIR is never referenced below — output paths are
# hardcoded to "/tmp" in execute_segmentation; consider using it or removing it.
OUTPUT_DIR = "/tmp"
def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
    """
    Execute SAM3 segmentation for the given component request.

    Pipeline: load image -> downscale if larger than MAX_DIMENSION ->
    convert normalized (0-1000) boxes to pixel coords -> SAM inference ->
    morphology filter (solidity/circularity) -> NMS -> save a PNG plot
    and an NPZ mask archive under /tmp.

    Args:
        deps: Provides image_path, sam_model, sam_processor and device.
            If model or processor is None, returns a mock message instead.
        request: Provides color, entity and bboxes (coords in 0-1000 range).

    Returns:
        A human-readable status string; on success it embeds the
        MASK_FILE=... and PLOT_FILE=... output paths.
    """
    t_start = time.time()
    text_prompt = f"{request.color} {request.entity}"
    print(f"\n[Engine] Segmenting: '{text_prompt}' ({len(request.bboxes)} boxes).")
    # 1. Load Image
    try:
        raw_image = Image.open(deps.image_path).convert("RGB")
    except Exception as e:
        return f"Error loading image: {e}"
    # 2. Resize image if too large (Critical for speed)
    w, h = raw_image.size
    scale_factor = 1.0
    if max(w, h) > MAX_DIMENSION:
        scale_factor = MAX_DIMENSION / max(w, h)
        new_w = int(w * scale_factor)
        new_h = int(h * scale_factor)
        raw_image = raw_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
        print(f"[Engine] ⚡ Resized image from {w}x{h} to {new_w}x{new_h} (Speedup)")
        # Keep w/h in sync so the box conversion below targets the resized image.
        w, h = new_w, new_h
    # 3. Convert normalized coords (0-1000) to pixel coords
    sam_input_boxes = []
    for box in request.bboxes:
        y_min = (box.ymin / 1000) * h
        x_min = (box.xmin / 1000) * w
        y_max = (box.ymax / 1000) * h
        x_max = (box.xmax / 1000) * w
        sam_input_boxes.append([x_min, y_min, x_max, y_max])
    if not sam_input_boxes:
        return "No valid boxes provided."
    safe_label = f"{request.entity}".replace(" ", "_").lower()
    plot_filename = f"/tmp/out_{safe_label}.png"
    data_filename = f"/tmp/data_{safe_label}.npz"
    # Graceful degradation when no model/processor was injected (e.g. tests).
    if deps.sam_model is None or deps.sam_processor is None:
        return f"[Mock] Would segment '{text_prompt}'."
    # 4. Inference
    print("[Engine] Running Inference...")
    t_inf = time.time()
    # Batch of one image: one label list (all positive) and one box list.
    sam_input_labels = [[1] * len(sam_input_boxes)]
    input_boxes_batch = [sam_input_boxes]
    inputs = deps.sam_processor(
        images=raw_image,
        text=text_prompt,
        input_boxes=input_boxes_batch,
        input_boxes_labels=sam_input_labels,
        return_tensors="pt"
    ).to(deps.device)
    with torch.inference_mode():
        outputs = deps.sam_model(**inputs)
    results = deps.sam_processor.post_process_instance_segmentation(
        outputs,
        threshold=0.3,
        target_sizes=inputs["original_sizes"].tolist()
    )[0]
    print(f"[Engine] Inference took {time.time() - t_inf:.2f}s")
    # 5. Morphology filtering: drop masks that are not solid/round enough
    # (solidity > MIN_SOLIDITY and circularity > MIN_CIRCULARITY).
    keep_indices_morph = []
    for i, mask_tensor in enumerate(results["masks"]):
        mask_np = mask_tensor.cpu().numpy()
        mask_np = np.squeeze(mask_np).astype(int)
        if mask_np.ndim != 2:
            # Unexpected mask rank after squeeze: treat as reject.
            keep_indices_morph.append(False)
            continue
        props = regionprops(mask_np)
        if not props:
            # Empty mask (no foreground pixels): reject.
            keep_indices_morph.append(False)
            continue
        # NOTE(review): only the first region is inspected; assumes each mask
        # contains a single connected component — confirm for multi-blob masks.
        prop = props[0]
        perimeter = prop.perimeter
        # Circularity = 4*pi*A / P^2 (1.0 for a perfect circle).
        circularity = (4 * np.pi * prop.area) / (perimeter ** 2) if perimeter > 0 else 0
        is_solid = prop.solidity > MIN_SOLIDITY
        is_round_enough = circularity > MIN_CIRCULARITY
        keep_indices_morph.append(is_solid and is_round_enough)
    # NOTE(review): the filter is applied only when at least one mask survives;
    # if every mask fails morphology, ALL masks are kept unfiltered. Presumably
    # a deliberate keep-everything fallback — confirm this is intended.
    if any(keep_indices_morph):
        keep_indices_tensor = torch.tensor(keep_indices_morph, device=results["masks"].device)
        before_count = len(results["masks"])
        results = _filter_results(results, keep_indices_tensor)
        print(f"[Filter] Morphology: Dropped {before_count - len(results['masks'])} debris-like objects.")
    # 6. NMS (Keep this to remove duplicate detections on the same object)
    pred_boxes = results["boxes"]
    pred_scores = results["scores"]
    if len(pred_scores) > 1:
        keep_indices_nms = torchvision.ops.nms(pred_boxes, pred_scores, iou_threshold=0.3)
        results = _filter_results(results, keep_indices_nms)
    # 7. Save outputs
    _save_plot(raw_image, results, sam_input_boxes, text_prompt, plot_filename)
    mask_count = len(results['masks'])
    if mask_count > 0:
        masks_list = [m.cpu().numpy().squeeze() for m in results['masks']]
        masks_array = np.array(masks_list)
        np.savez_compressed(data_filename, masks=masks_array)
    else:
        # Write an empty archive so downstream readers always find the file.
        np.savez_compressed(data_filename, masks=np.array([]))
    total_time = time.time() - t_start
    print(f"[Engine] ✅ Done in {total_time:.2f}s. Saved {mask_count} masks.")
    return f"SUCCESS: Found {mask_count} '{text_prompt}' objects. MASK_FILE={data_filename} PLOT_FILE={plot_filename}"
def _filter_results(results, keep_indices):
"""Helper to slice all dictionary keys at once."""
results["masks"] = results["masks"][keep_indices]
results["scores"] = results["scores"][keep_indices]
results["boxes"] = results["boxes"][keep_indices]
return results
def _save_plot(image, results, boxes, label, filename):
    """Save visualization of segmentation results with bounding boxes.

    Fix: the `boxes` and `label` parameters were accepted but never used
    (despite the docstring and the `matplotlib.patches` import) — the input
    prompt boxes are now drawn and the label is used as the figure title.

    Args:
        image: PIL image (already resized) used as the backdrop.
        results: Dict with "masks" (tensors) and "scores" (tensor/sequence).
        boxes: Prompt boxes as [x_min, y_min, x_max, y_max] pixel coords.
        label: Text prompt; shown as the plot title.
        filename: Output PNG path.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    # 1. Draw Segmentation Masks as a single RGBA overlay (score > 0.3 only)
    if len(results['scores']) > 0:
        H, W = results['masks'][0].shape[-2:]
        composite = np.zeros((H, W, 4))
        for mask, score in zip(results['masks'], results['scores']):
            if score > 0.3:
                m = mask.cpu().numpy().squeeze()
                color = np.random.random(3)
                # Add color to mask (max blend so overlaps stay visible)
                for c in range(3):
                    composite[:, :, c] = np.maximum(composite[:, :, c], m * color[c])
                composite[:, :, 3] = np.maximum(composite[:, :, 3], m * 0.5)
        ax.imshow(composite)
    # 2. Draw the input prompt boxes (previously ignored)
    for x_min, y_min, x_max, y_max in boxes:
        rect = patches.Rectangle(
            (x_min, y_min), x_max - x_min, y_max - y_min,
            linewidth=1.5, edgecolor='lime', facecolor='none'
        )
        ax.add_patch(rect)
    ax.set_title(label)
    ax.axis('off')
    # Save tightly to remove whitespace
    fig.savefig(filename, bbox_inches='tight', pad_inches=0)
    plt.close(fig)