import torch
import numpy as np
from PIL import Image
import warnings
from transformers import (
    Sam3Model, Sam3Processor,  # type: ignore
    Sam3TrackerModel, Sam3TrackerProcessor,  # type: ignore
    logging as transformers_logging
)
from .schemas import ObjectState, SelectorInput
from typing import Optional, Any

# Suppress specific warnings
warnings.filterwarnings("ignore", message=".*The OrderedVocab you are attempting to save contains holes.*")
warnings.filterwarnings("ignore", message=".*You are using a model of type sam3_video to instantiate a model of type sam3_tracker.*")
transformers_logging.set_verbosity_error()

device = "cuda" if torch.cuda.is_available() else "cpu"

# Global Models (loaded once)
_IMG_MODEL: Optional[Any] = None
_IMG_PROCESSOR: Optional[Any] = None
_TRK_MODEL: Optional[Any] = None
_TRK_PROCESSOR: Optional[Any] = None


def load_models():
    global _IMG_MODEL, _IMG_PROCESSOR, _TRK_MODEL, _TRK_PROCESSOR
    if _IMG_MODEL is not None:
        return

    print(f"🖥️ Using compute device: {device}")
    print("⏳ Loading SAM3 Models...")

    # Use local_files_only=True to skip network checks (faster, more consistent)
    local_only = True
    try:
        # 1. Selector (Sam3Model)
        _IMG_MODEL = Sam3Model.from_pretrained("facebook/sam3", local_files_only=local_only).to(device)
        _IMG_PROCESSOR = Sam3Processor.from_pretrained("facebook/sam3", local_files_only=local_only)

        # 2. Refiner (Sam3TrackerModel)
        _TRK_MODEL = Sam3TrackerModel.from_pretrained("facebook/sam3", local_files_only=local_only).to(device)
        _TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("facebook/sam3", local_files_only=local_only)
    except OSError:
        # Models not cached, need to download first
        print("⚠️ Models not in cache, downloading... (this only happens once)")
        _IMG_MODEL = Sam3Model.from_pretrained("facebook/sam3").to(device)
        _IMG_PROCESSOR = Sam3Processor.from_pretrained("facebook/sam3")
        _TRK_MODEL = Sam3TrackerModel.from_pretrained("facebook/sam3").to(device)
        _TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("facebook/sam3")

    print("✅ All models loaded!")


def get_bbox_from_mask(mask_arr):
    if mask_arr is None:
        return None
    if mask_arr.max() == 0:
        return None
    y_indices, x_indices = np.where(mask_arr > 0)
    if len(y_indices) == 0:
        return None
    x1, x2 = np.min(x_indices), np.max(x_indices)
    y1, y2 = np.min(y_indices), np.max(y_indices)
    # Cast to int for schema compatibility
    return [int(x1), int(y1), int(x2), int(y2)]
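
# Worked example (illustrative only, toy values): get_bbox_from_mask returns
# the tightest [x1, y1, x2, y2] box around the nonzero region, in inclusive
# pixel coordinates:
#
#   mask = np.zeros((4, 4), dtype=bool)
#   mask[1:3, 1:4] = True        # rows 1-2, cols 1-3 set
#   get_bbox_from_mask(mask)     # -> [1, 1, 3, 2]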

def search_objects(selector_input: SelectorInput) -> list[ObjectState]:
    """
    Stage A: The Selector
    """
    if _IMG_MODEL is None:
        load_models()
    assert _IMG_MODEL is not None
    assert _IMG_PROCESSOR is not None

    image = selector_input.image.convert("RGB")
    original_w, original_h = image.size

    # Handle Cropping
    crop_offset_x, crop_offset_y = 0, 0
    if selector_input.crop_box:
        cx1, cy1, cx2, cy2 = selector_input.crop_box
        # Clamp the crop to the image bounds
        cx1 = max(0, cx1)
        cy1 = max(0, cy1)
        cx2 = min(original_w, cx2)
        cy2 = min(original_h, cy2)
        if cx2 > cx1 and cy2 > cy1:
            image = image.crop((cx1, cy1, cx2, cy2))
            crop_offset_x, crop_offset_y = cx1, cy1
            print(f"✂️ Cropped image to: {image.size} (Offset: {crop_offset_x}, {crop_offset_y})")

    # Prepare inputs
    input_boxes = None
    input_labels = None
    if selector_input.input_boxes:
        # Shift prompt boxes into crop coordinates
        adjusted_boxes = []
        crop_w, crop_h = image.size
        for box in selector_input.input_boxes:
            bx1, by1, bx2, by2 = box
            # Subtract the crop offset
            bx1 -= crop_offset_x
            by1 -= crop_offset_y
            bx2 -= crop_offset_x
            by2 -= crop_offset_y
            # Clip to crop bounds (0 to crop_w/h)
            bx1 = max(0, min(crop_w, bx1))
            by1 = max(0, min(crop_h, by1))
            bx2 = max(0, min(crop_w, bx2))
            by2 = max(0, min(crop_h, by2))
            adjusted_boxes.append([float(bx1), float(by1), float(bx2), float(by2)])
        # SAM3 expects [[ [x1, y1, x2, y2], ... ]] for batch size 1
        input_boxes = [adjusted_boxes]
        if selector_input.input_labels:
            # Shape: (Batch, N_boxes) -> [[1, 0, ...]]
            input_labels = [selector_input.input_labels]

    print("🔍 Search Inputs:")
    print(f" - Text: '{selector_input.text}'")
    print(f" - Boxes: {input_boxes}")
    print(f" - Box Labels: {input_labels}")
    print(f" - Image Size: {image.size}")

    # Note: we assume Sam3Processor accepts per-box labels via `input_boxes_labels`
    # (1 = include, 0 = exclude). If the installed version ignores labels and
    # treats every box as "include", exclude boxes would have to be filtered out
    # manually instead. Text is omitted entirely when empty rather than passed
    # as "" or None.

    # Construct arguments dynamically to omit missing inputs
    processor_kwargs = {
        "images": image,
        "return_tensors": "pt"
    }
    if selector_input.text and selector_input.text.strip():
        processor_kwargs["text"] = [selector_input.text]
    if input_boxes is not None:
        processor_kwargs["input_boxes"] = input_boxes
    if input_labels is not None:
        processor_kwargs["input_boxes_labels"] = input_labels

    inputs = _IMG_PROCESSOR(**processor_kwargs).to(device)

    with torch.no_grad():
        outputs = _IMG_MODEL(**inputs)

    results = _IMG_PROCESSOR.post_process_instance_segmentation(
        outputs,
        threshold=0.4,  # TODO: make configurable
        target_sizes=inputs.get("original_sizes").tolist()
    )[0]

    candidates = []
    raw_masks = results['masks'].cpu().numpy()  # [N, H, W] or [N, 1, H, W]
    raw_scores = results['scores'].cpu().numpy()

    if raw_masks.ndim == 4:
        raw_masks = raw_masks.squeeze(1)

    for idx, mask in enumerate(raw_masks):
        # mask is boolean/binary for the CROPPED image
        # Restore to full size if cropped
        if selector_input.crop_box:
            full_mask = np.zeros((original_h, original_w), dtype=bool)
            # Paste the cropped mask back at the crop offset
            # mask shape is (crop_h, crop_w)
            mh, mw = mask.shape
            full_mask[crop_offset_y:crop_offset_y+mh, crop_offset_x:crop_offset_x+mw] = mask
            mask = full_mask

        anchor_box = get_bbox_from_mask(mask)
        if anchor_box is None:
            continue

        final_name = selector_input.class_name_override or selector_input.text or "Object"

        candidates.append(ObjectState(
            score=float(raw_scores[idx]),
            anchor_box=anchor_box,
            binary_mask=mask,
            initial_mask=mask.copy(),  # Keep an independent copy for undo
            class_name=final_name
        ))

    return candidates
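
# Coordinate bookkeeping sketch (toy numbers, illustrative only): masks are
# predicted in crop space and pasted back at the crop offset. For a 640x480
# image cropped with crop_box = (40, 30, 140, 130):
#
#   full_mask = np.zeros((480, 640), dtype=bool)   # (H, W) canvas
#   full_mask[30:130, 40:140] = crop_mask          # crop_mask is (100, 100)
#
# Conversely, a prompt box (50, 40, 120, 100) in image space becomes
# (10, 10, 80, 70) in crop space after subtracting the (40, 30) offset.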

def refine_object(image: Image.Image, obj_state: ObjectState) -> np.ndarray:
    """
    Stage B: The Refiner
    """
    print("🔧 Refine Inputs:")
    print(f" - Anchor Box: {obj_state.anchor_box}")
    print(f" - Points: {obj_state.input_points}")
    print(f" - Point Labels: {obj_state.input_labels}")

    if _TRK_MODEL is None:
        load_models()
    assert _TRK_MODEL is not None
    assert _TRK_PROCESSOR is not None

    original_w, original_h = image.size
    image = image.convert("RGB")

    # --- Dynamic Cropping Logic ---
    # 1. Determine the bounding box of interest (Anchor Box + all input points).
    # This is the "Refinement Box" that encompasses the object and new points
    rx1, ry1, rx2, ry2 = obj_state.anchor_box
    if obj_state.input_points:
        for px, py in obj_state.input_points:
            rx1 = min(rx1, px)
            ry1 = min(ry1, py)
            rx2 = max(rx2, px)
            ry2 = max(ry2, py)

    # 2. Add Padding (25%) to create the Crop Box
    width = rx2 - rx1
    height = ry2 - ry1
    padding = int(max(width, height) * 0.25)

    cx1 = max(0, int(rx1 - padding))
    cy1 = max(0, int(ry1 - padding))
    cx2 = min(original_w, int(rx2 + padding))
    cy2 = min(original_h, int(ry2 + padding))

    crop_offset_x, crop_offset_y = cx1, cy1

    # 3. Crop Image
    if cx2 > cx1 and cy2 > cy1:
        image = image.crop((cx1, cy1, cx2, cy2))
    else:
        crop_offset_x, crop_offset_y = 0, 0

    # --- Coordinate Adjustment ---
    # Use the Refinement Box (tight) as the prompt, shifted into crop coordinates
    box_float = [
        float(rx1 - crop_offset_x),
        float(ry1 - crop_offset_y),
        float(rx2 - crop_offset_x),
        float(ry2 - crop_offset_y)
    ]

    # Shift points into crop coordinates (guard against missing points, which
    # would otherwise raise a TypeError when iterating None)
    points_float = []
    if obj_state.input_points:
        for p in obj_state.input_points:
            points_float.append([float(p[0] - crop_offset_x), float(p[1] - crop_offset_y)])

    # Prepare inputs
    input_boxes = [[box_float]]

    # Nesting for Sam3TrackerProcessor:
    # input_points: 4 levels [Image, Object, Point, Coords]
    # input_labels: 3 levels [Image, Object, Label]
    # obj_state.input_points is List[List[float]] (points for 1 object) -> levels 3 & 4,
    # and obj_state.input_labels is List[int] (labels for 1 object) -> level 3,
    # so both get wrapped in [[ ... ]] for the Image and Object levels.
    # Points/labels are passed only when present so the box-only path still works.
    processor_kwargs = {
        "images": image,
        "input_boxes": input_boxes,
        "return_tensors": "pt"
    }
    if points_float:
        processor_kwargs["input_points"] = [[points_float]]
        processor_kwargs["input_labels"] = [[obj_state.input_labels]]

    inputs = _TRK_PROCESSOR(**processor_kwargs).to(device)

    with torch.no_grad():
        outputs = _TRK_MODEL(**inputs, multimask_output=False)

    masks = _TRK_PROCESSOR.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"],
        binarize=True
    )[0]

    final_mask_crop = masks[0].numpy()
    if final_mask_crop.ndim == 3:
        final_mask_crop = final_mask_crop[0]

    # --- Restore Mask to Full Size ---
    final_mask = np.zeros((original_h, original_w), dtype=bool)
    mh, mw = final_mask_crop.shape
    final_mask[crop_offset_y:crop_offset_y+mh, crop_offset_x:crop_offset_x+mw] = final_mask_crop

    return final_mask
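

if __name__ == "__main__":
    # Minimal smoke test for the pure-numpy helper (illustrative only; no SAM3
    # models are loaded here). Exercises get_bbox_from_mask on synthetic masks.
    demo = np.zeros((8, 8), dtype=bool)
    demo[2:5, 3:7] = True                      # rows 2-4, cols 3-6 set
    assert get_bbox_from_mask(demo) == [3, 2, 6, 4]
    assert get_bbox_from_mask(np.zeros((8, 8), dtype=bool)) is None
    print("get_bbox_from_mask smoke test passed")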