# LabelPlayground / autolabel / segment.py
# Author: Erick
# Uploaded via huggingface_hub (commit 47cb9bd, verified)
"""
segment.py β€” SAM2 segmentation using bounding-box prompts.
Workflow (Grounded SAM2 pattern):
OWLv2 text prompts β†’ bounding boxes
SAM2 box prompts β†’ pixel masks
Model: facebook/sam2-hiera-tiny (~160 MB, fast enough for development)
Each detection returned by segment_with_boxes() gains two extra fields:
"mask": bool numpy array (H, W) β€” pixel mask in image space
"segmentation": COCO polygon list [[x, y, x, y, ...], ...]
"""
from __future__ import annotations
import logging
from typing import Optional
import numpy as np
import torch
from PIL import Image
logger = logging.getLogger(__name__)
SAM2_DEFAULT_MODEL = "facebook/sam2-hiera-tiny"
def load_sam2(device: str, model_id: str = SAM2_DEFAULT_MODEL):
    """Instantiate the SAM2 processor/model pair on *device*.

    Returns:
        (processor, model) — the model is moved to *device* and put in eval mode.
    """
    from transformers import Sam2Processor, Sam2Model

    logger.info("Loading SAM2 %s on %s …", model_id, device)
    proc = Sam2Processor.from_pretrained(model_id)
    # SAM2 runs in float32 — bfloat16/float16 not reliably supported on all backends
    sam = Sam2Model.from_pretrained(model_id, torch_dtype=torch.float32).to(device)
    sam.eval()
    logger.info("SAM2 ready.")
    return proc, sam
def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
    """Trace a boolean 2-D mask into a COCO polygon list.

    Each polygon is a flat [x1, y1, x2, y2, …] coordinate list. Returns []
    when cv2 cannot be imported or no usable contour is found.
    """
    try:
        import cv2
    except ImportError:
        logger.warning("opencv-python not installed — segmentation polygons skipped.")
        return []
    binary = mask.astype(np.uint8) * 255
    found, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Each contour is an (N, 1, 2) point array; size >= 6 means at least 3 vertices.
    return [contour.flatten().tolist() for contour in found if contour.size >= 6]
def segment_with_boxes(
    pil_image: Image.Image,
    detections: list[dict],
    processor,
    model,
    device: str,
) -> list[dict]:
    """Attach SAM2 pixel masks to *detections*, prompting with their boxes.

    Each detection carrying a "box_xyxy" gains two fields in the returned list:
        "mask"         — bool numpy array (H, W)
        "segmentation" — COCO polygon list
    Detections without a valid box are forwarded untouched (no mask field).
    """
    if not detections:
        return detections

    height, width = pil_image.height, pil_image.width
    results: list[dict] = []
    for detection in detections:
        bbox = detection.get("box_xyxy")
        if bbox is None:
            results.append(detection)
            continue
        left, top, right, bottom = bbox
        try:
            # input_boxes: [batch=1, n_boxes=1, 4]
            encoded = processor(
                images=pil_image,
                input_boxes=[[[left, top, right, bottom]]],
                return_tensors="pt",
            )
            # transformers 5.x Sam2Processor returns: pixel_values, original_sizes,
            # input_boxes — no reshaped_input_sizes. Move all tensors to device.
            batch = {
                key: (val.to(device) if hasattr(val, "to") else val)
                for key, val in encoded.items()
            }
            with torch.no_grad():
                outputs = model(**batch, multimask_output=False)
            # pred_masks shape: [batch, n_boxes, n_masks, H_low, W_low].
            # post_process_masks (transformers 5.x) iterates the batch and runs
            # F.interpolate on each per-image 4-D tensor
            # ([n_boxes, n_masks, H_low, W_low] — handled as [N, C, H, W])
            # up to its original size, then optionally binarises.
            sizes = encoded.get("original_sizes", torch.tensor([[height, width]]))
            upscaled = processor.post_process_masks(outputs.pred_masks, sizes)
            # upscaled[0]: [n_boxes=1, n_masks=1, H_orig, W_orig]
            mask_arr: np.ndarray = upscaled[0][0, 0].cpu().numpy().astype(bool)
        except Exception:
            logger.exception(
                "SAM2 failed for '%s' — using empty mask", detection.get("label", "?")
            )
            mask_arr = np.zeros((height, width), dtype=bool)
        results.append(
            {**detection, "mask": mask_arr, "segmentation": _mask_to_polygon(mask_arr)}
        )
    return results