""" segment.py — SAM2 segmentation using bounding-box prompts. Workflow (Grounded SAM2 pattern): OWLv2 text prompts → bounding boxes SAM2 box prompts → pixel masks Model: facebook/sam2-hiera-tiny (~160 MB, fast enough for development) Each detection returned by segment_with_boxes() gains two extra fields: "mask": bool numpy array (H, W) — pixel mask in image space "segmentation": COCO polygon list [[x, y, x, y, ...], ...] """ from __future__ import annotations import logging from typing import Optional import numpy as np import torch from PIL import Image logger = logging.getLogger(__name__) SAM2_DEFAULT_MODEL = "facebook/sam2-hiera-tiny" def load_sam2(device: str, model_id: str = SAM2_DEFAULT_MODEL): """Load SAM2 processor and model onto *device*. Returns (processor, model).""" from transformers import Sam2Processor, Sam2Model logger.info("Loading SAM2 %s on %s …", model_id, device) processor = Sam2Processor.from_pretrained(model_id) # SAM2 runs in float32 — bfloat16/float16 not reliably supported on all backends model = Sam2Model.from_pretrained(model_id, torch_dtype=torch.float32).to(device) model.eval() logger.info("SAM2 ready.") return processor, model def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]: """Convert a boolean 2-D mask to a COCO polygon list. Returns a list of polygons; each polygon is a flat [x1,y1,x2,y2,…] list. Returns [] if cv2 is unavailable or no contour is found. """ try: import cv2 except ImportError: logger.warning("opencv-python not installed — segmentation polygons skipped.") return [] mask_u8 = mask.astype(np.uint8) * 255 contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) polygons: list[list[float]] = [] for contour in contours: if contour.size >= 6: # need at least 3 points polygons.append(contour.flatten().tolist()) return polygons def segment_with_boxes( pil_image: Image.Image, detections: list[dict], processor, model, device: str, ) -> list[dict]: """Run SAM2 on *pil_image* using the bounding box from each detection. Each detection in the returned list gains: "mask" — bool numpy array (H, W) "segmentation" — COCO polygon list Detections without a valid box are passed through unchanged (no mask field). """ if not detections: return detections augmented: list[dict] = [] h, w = pil_image.height, pil_image.width for det in detections: box = det.get("box_xyxy") if box is None: augmented.append(det) continue x1, y1, x2, y2 = box try: # input_boxes: [batch=1, n_boxes=1, 4] encoding = processor( images=pil_image, input_boxes=[[[x1, y1, x2, y2]]], return_tensors="pt", ) # transformers 5.x Sam2Processor returns: pixel_values, original_sizes, # input_boxes — no reshaped_input_sizes. Move all tensors to device. inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in encoding.items()} with torch.no_grad(): outputs = model(**inputs, multimask_output=False) # pred_masks shape: [batch, n_boxes, n_masks, H_low, W_low] # post_process_masks(masks, original_sizes) — transformers 5.x API: # iterates over batch; each masks[i] goes through F.interpolate to # original_size, then optional binarise. Expects 4-D per-image tensor. # We pass pred_masks directly; masks[0] = [n_boxes, n_masks, H_low, W_low] # which F.interpolate handles as [N, C, H, W]. original_sizes = encoding.get("original_sizes", torch.tensor([[h, w]])) masks = processor.post_process_masks( outputs.pred_masks, original_sizes, ) # masks[0]: [n_boxes=1, n_masks=1, H_orig, W_orig] mask_np: np.ndarray = masks[0][0, 0].cpu().numpy().astype(bool) except Exception: logger.exception( "SAM2 failed for '%s' — using empty mask", det.get("label", "?") ) mask_np = np.zeros((h, w), dtype=bool) polygons = _mask_to_polygon(mask_np) augmented.append({**det, "mask": mask_np, "segmentation": polygons}) return augmented