Spaces:
Running on Zero
Running on Zero
| """ | |
| segment.py β SAM2 segmentation using bounding-box prompts. | |
| Workflow (Grounded SAM2 pattern): | |
| OWLv2 text prompts β bounding boxes | |
| SAM2 box prompts β pixel masks | |
| Model: facebook/sam2-hiera-tiny (~160 MB, fast enough for development) | |
| Each detection returned by segment_with_boxes() gains two extra fields: | |
| "mask": bool numpy array (H, W) β pixel mask in image space | |
| "segmentation": COCO polygon list [[x, y, x, y, ...], ...] | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| logger = logging.getLogger(__name__) | |
| SAM2_DEFAULT_MODEL = "facebook/sam2-hiera-tiny" | |
def load_sam2(device: str, model_id: str = SAM2_DEFAULT_MODEL):
    """Load the SAM2 processor and model onto *device*.

    Args:
        device: Torch device string, e.g. "cuda" or "cpu".
        model_id: Hugging Face model identifier of the SAM2 checkpoint.

    Returns:
        Tuple of (processor, model), with the model moved to *device*
        and switched to eval mode.
    """
    # Imported lazily so this module can be imported without transformers installed.
    from transformers import Sam2Processor, Sam2Model

    logger.info("Loading SAM2 %s on %s …", model_id, device)
    processor = Sam2Processor.from_pretrained(model_id)
    # SAM2 runs in float32 — bfloat16/float16 not reliably supported on all backends
    model = Sam2Model.from_pretrained(model_id, torch_dtype=torch.float32).to(device)
    model.eval()
    logger.info("SAM2 ready.")
    return processor, model
def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
    """Convert a boolean 2-D mask to a COCO polygon list.

    Returns a list of polygons; each polygon is a flat [x1,y1,x2,y2,…] list.
    Returns [] if cv2 is unavailable or no contour is found.
    """
    try:
        import cv2
    except ImportError:
        logger.warning("opencv-python not installed — segmentation polygons skipped.")
        return []
    mask_u8 = mask.astype(np.uint8) * 255
    contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    polygons: list[list[float]] = []
    for contour in contours:
        # Each contour point contributes (x, y); require >= 3 points (6 values)
        # for a valid polygon.
        if contour.size >= 6:
            polygons.append(contour.flatten().tolist())
    return polygons
def segment_with_boxes(
    pil_image: Image.Image,
    detections: list[dict],
    processor,
    model,
    device: str,
) -> list[dict]:
    """Run SAM2 on *pil_image* using the bounding box from each detection.

    Each detection in the returned list gains:
        "mask"         — bool numpy array (H, W)
        "segmentation" — COCO polygon list
    Detections without a valid box are passed through unchanged (no mask field).
    A per-detection SAM2 failure is logged and replaced by an all-False mask,
    so one bad box never aborts the whole batch.
    """
    if not detections:
        return detections
    augmented: list[dict] = []
    h, w = pil_image.height, pil_image.width
    for det in detections:
        box = det.get("box_xyxy")
        if box is None:
            # No box prompt available — pass the detection through untouched.
            augmented.append(det)
            continue
        x1, y1, x2, y2 = box
        try:
            # input_boxes: [batch=1, n_boxes=1, 4]
            encoding = processor(
                images=pil_image,
                input_boxes=[[[x1, y1, x2, y2]]],
                return_tensors="pt",
            )
            # transformers 5.x Sam2Processor returns: pixel_values, original_sizes,
            # input_boxes — no reshaped_input_sizes. Move all tensors to device.
            inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in encoding.items()}
            with torch.no_grad():
                outputs = model(**inputs, multimask_output=False)
            # pred_masks shape: [batch, n_boxes, n_masks, H_low, W_low]
            # post_process_masks(masks, original_sizes) — transformers 5.x API:
            # iterates over batch; each masks[i] goes through F.interpolate to
            # original_size, then optional binarise. Expects 4-D per-image tensor.
            # We pass pred_masks directly; masks[0] = [n_boxes, n_masks, H_low, W_low]
            # which F.interpolate handles as [N, C, H, W].
            original_sizes = encoding.get("original_sizes", torch.tensor([[h, w]]))
            masks = processor.post_process_masks(
                outputs.pred_masks,
                original_sizes,
            )
            # masks[0]: [n_boxes=1, n_masks=1, H_orig, W_orig]
            mask_np: np.ndarray = masks[0][0, 0].cpu().numpy().astype(bool)
        except Exception:
            logger.exception(
                "SAM2 failed for '%s' — using empty mask", det.get("label", "?")
            )
            mask_np = np.zeros((h, w), dtype=bool)
        polygons = _mask_to_polygon(mask_np)
        augmented.append({**det, "mask": mask_np, "segmentation": polygons})
    return augmented