G-Paris
Initialize Space with LFS for images
19f31ed
import torch
import numpy as np
from PIL import Image
import warnings
import logging
from transformers import (
Sam3Model, Sam3Processor, # type: ignore
Sam3TrackerModel, Sam3TrackerProcessor, # type: ignore
logging as transformers_logging
)
from .schemas import ObjectState, SelectorInput
from typing import Optional, Any
# Silence two known-noisy warnings (matched by message regex) so they do not
# clutter the console output.
warnings.filterwarnings("ignore", message=".*The OrderedVocab you are attempting to save contains holes.*")
warnings.filterwarnings("ignore", message=".*You are using a model of type sam3_video to instantiate a model of type sam3_tracker.*")
# Restrict the transformers library logger to error-level messages.
transformers_logging.set_verbosity_error()
# Compute device used for every model and input tensor in this module:
# CUDA when torch reports it as available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Module-level model/processor singletons. They stay None until load_models()
# populates them, after which they are reused by every call.
_IMG_MODEL: Optional[Any] = None
_IMG_PROCESSOR: Optional[Any] = None
_TRK_MODEL: Optional[Any] = None
_TRK_PROCESSOR: Optional[Any] = None
def load_models():
    """Populate the module-level SAM3 model and processor singletons.

    Loads the selector (``Sam3Model`` / ``Sam3Processor``) and the refiner
    (``Sam3TrackerModel`` / ``Sam3TrackerProcessor``) from ``facebook/sam3``
    and moves both models to ``device``. The local cache is tried first with
    ``local_files_only=True``; if that raises ``OSError`` the load is retried
    with network access allowed.

    The call is a no-op once all four singletons are set. The globals are only
    assigned after every component has loaded, so a failure part-way through
    never leaves a half-initialised state that later calls would mistake for
    "already loaded".
    """
    global _IMG_MODEL, _IMG_PROCESSOR, _TRK_MODEL, _TRK_PROCESSOR

    # Check every handle, not just the first: a previous partial load must not
    # short-circuit this function while the tracker is still missing.
    handles = (_IMG_MODEL, _IMG_PROCESSOR, _TRK_MODEL, _TRK_PROCESSOR)
    if all(handle is not None for handle in handles):
        return

    print(f"🖥️ Using compute device: {device}")
    print("⏳ Loading SAM3 Models...")

    def _load_all(local_files_only):
        # 1. Selector (Sam3Model)
        img_model = Sam3Model.from_pretrained(
            "facebook/sam3", local_files_only=local_files_only
        ).to(device)
        img_processor = Sam3Processor.from_pretrained(
            "facebook/sam3", local_files_only=local_files_only
        )
        # 2. Refiner (Sam3TrackerModel)
        trk_model = Sam3TrackerModel.from_pretrained(
            "facebook/sam3", local_files_only=local_files_only
        ).to(device)
        trk_processor = Sam3TrackerProcessor.from_pretrained(
            "facebook/sam3", local_files_only=local_files_only
        )
        return img_model, img_processor, trk_model, trk_processor

    # Skipping network checks is faster and more consistent when the weights
    # are already cached.
    try:
        bundle = _load_all(True)
    except OSError:
        bundle = None

    if bundle is None:
        # Run the fallback outside the ``except`` block so the failed attempt's
        # traceback (and any weights it still references) is released first.
        print("⚠️ Models not in cache, downloading... (this only happens once)")
        bundle = _load_all(False)

    # Publish all four handles together.
    _IMG_MODEL, _IMG_PROCESSOR, _TRK_MODEL, _TRK_PROCESSOR = bundle
    print("✅ All models loaded!")
def get_bbox_from_mask(mask_arr):
    """Return the tight bounding box of the positive pixels in a 2-D mask.

    Args:
        mask_arr: 2-D array-like mask (boolean or numeric), or ``None``.
            Pixels with a value greater than zero count as foreground.

    Returns:
        ``[x1, y1, x2, y2]`` as plain Python ints, where ``x2``/``y2`` are the
        largest foreground indices (inclusive), or ``None`` when the mask is
        ``None``, has zero size, or contains no foreground pixel.
    """
    if mask_arr is None:
        return None

    mask = np.asarray(mask_arr)
    # A zero-size array has no maximum; treat it as "no object" rather than
    # letting a reduction raise ValueError.
    if mask.size == 0:
        return None

    rows, cols = np.nonzero(mask > 0)
    if rows.size == 0:
        return None

    # Cast to int for schema compatibility (numpy integers are not plain ints).
    return [int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max())]
def _clamp_crop_box(crop_box, image_w, image_h):
    """Return ``crop_box`` as integer bounds clamped to the image, or None if it has no area."""
    # Integer bounds are required later, when the offsets are used as slice
    # indices to paste masks back into the full-size canvas.
    left, top, right, bottom = (int(round(value)) for value in crop_box)
    left, top = max(0, left), max(0, top)
    right, bottom = min(image_w, right), min(image_h, bottom)
    if right > left and bottom > top:
        return left, top, right, bottom
    return None


def _boxes_to_crop_space(boxes, offset_x, offset_y, crop_w, crop_h):
    """Shift full-image boxes by the crop offset and clip them to the crop extent."""
    shifted = []
    for bx1, by1, bx2, by2 in boxes:
        shifted.append([
            float(max(0, min(crop_w, bx1 - offset_x))),
            float(max(0, min(crop_h, by1 - offset_y))),
            float(max(0, min(crop_w, bx2 - offset_x))),
            float(max(0, min(crop_h, by2 - offset_y))),
        ])
    return shifted


def search_objects(selector_input: SelectorInput, score_threshold: float = 0.4) -> list[ObjectState]:
    """
    Stage A: The Selector

    Runs the SAM3 image model on ``selector_input.image`` (optionally cropped
    to ``selector_input.crop_box``) using whichever prompts were supplied: the
    text, the input boxes and the box labels. Every returned mask is pasted
    back into full-image coordinates before its bounding box is computed.

    Args:
        selector_input: Carries ``image``, ``text``, ``crop_box``,
            ``input_boxes``, ``input_labels`` and ``class_name_override``.
            Boxes are given in full-image coordinates.
        score_threshold: Threshold forwarded to the processor's instance
            segmentation post-processing. Defaults to 0.4.

    Returns:
        One ``ObjectState`` per non-empty mask, with ``binary_mask`` and an
        independent ``initial_mask`` copy (kept for undo), the mask's
        ``anchor_box``, its score and the resolved class name.
    """
    if _IMG_MODEL is None or _IMG_PROCESSOR is None:
        load_models()
    assert _IMG_MODEL is not None
    assert _IMG_PROCESSOR is not None

    image = selector_input.image.convert("RGB")
    original_w, original_h = image.size

    # Handle cropping: an invalid or empty crop box leaves the image untouched.
    crop_offset_x, crop_offset_y = 0, 0
    if selector_input.crop_box:
        crop = _clamp_crop_box(selector_input.crop_box, original_w, original_h)
        if crop is not None:
            image = image.crop(crop)
            crop_offset_x, crop_offset_y = crop[0], crop[1]
            print(f"✂️ Cropped image to: {image.size} (Offset: {crop_offset_x}, {crop_offset_y})")

    # Prepare prompts.
    input_boxes = None
    input_labels = None
    if selector_input.input_boxes:
        crop_w, crop_h = image.size
        # SAM3 expects [[ [x1, y1, x2, y2], ... ]] for batch size 1.
        input_boxes = [
            _boxes_to_crop_space(
                selector_input.input_boxes, crop_offset_x, crop_offset_y, crop_w, crop_h
            )
        ]
    if selector_input.input_labels:
        # Shape: (Batch, N_boxes) -> [[1, 0, ...]]
        input_labels = [selector_input.input_labels]

    print(f"🔍 Search Inputs:")
    print(f" - Text: '{selector_input.text}'")
    print(f" - Boxes: {input_boxes}")
    print(f" - Box Labels: {input_labels}")
    print(f" - Image Size: {image.size}")

    # Build the processor arguments dynamically so that a missing prompt is
    # omitted entirely rather than passed as "" or None.
    processor_kwargs = {
        "images": image,
        "return_tensors": "pt",
    }
    if selector_input.text and selector_input.text.strip():
        processor_kwargs["text"] = [selector_input.text]
    if input_boxes is not None:
        processor_kwargs["input_boxes"] = input_boxes
    if input_labels is not None:
        processor_kwargs["input_boxes_labels"] = input_labels

    inputs = _IMG_PROCESSOR(**processor_kwargs).to(device)
    with torch.no_grad():
        outputs = _IMG_MODEL(**inputs)

    results = _IMG_PROCESSOR.post_process_instance_segmentation(
        outputs,
        threshold=score_threshold,
        target_sizes=inputs.get("original_sizes").tolist(),
    )[0]

    raw_masks = results['masks'].cpu().numpy()  # [N, H, W] or [N, 1, H, W]
    raw_scores = results['scores'].cpu().numpy()
    if raw_masks.ndim == 4:
        raw_masks = raw_masks.squeeze(1)

    # The display name does not depend on the individual mask.
    final_name = selector_input.class_name_override or selector_input.text or "Object"

    candidates = []
    for idx, mask in enumerate(raw_masks):
        # The mask is sized for the (possibly cropped) image; restore it to the
        # full image so downstream coordinates are always full-image.
        if selector_input.crop_box:
            full_mask = np.zeros((original_h, original_w), dtype=bool)
            mh, mw = mask.shape
            full_mask[crop_offset_y:crop_offset_y + mh, crop_offset_x:crop_offset_x + mw] = mask
            mask = full_mask

        anchor_box = get_bbox_from_mask(mask)
        if anchor_box is None:
            continue

        candidates.append(ObjectState(
            score=float(raw_scores[idx]),
            anchor_box=anchor_box,
            binary_mask=mask,
            # A real copy: sharing the array would let in-place edits of
            # binary_mask silently destroy the undo snapshot.
            initial_mask=mask.copy(),
            class_name=final_name,
        ))
    return candidates
def refine_object(image: Image.Image, obj_state: ObjectState) -> np.ndarray:
    """
    Stage B: The Refiner

    Re-segments a single object with the SAM3 tracker model. The image is
    cropped to the object's anchor box extended to include every input point,
    plus 25% padding. The box (and the points, if any) are translated into
    crop coordinates, and the predicted mask is pasted back into a full-size
    canvas.

    Args:
        image: The full image the object lives in.
        obj_state: Supplies ``anchor_box`` (``[x1, y1, x2, y2]`` in full-image
            coordinates) and optionally ``input_points`` with matching
            ``input_labels``. When there are no points the refinement is
            driven by the box alone.

    Returns:
        A boolean array of shape ``(image_height, image_width)`` that is False
        everywhere outside the crop region.
    """
    print(f"🔧 Refine Inputs:")
    print(f" - Anchor Box: {obj_state.anchor_box}")
    print(f" - Points: {obj_state.input_points}")
    print(f" - Point Labels: {obj_state.input_labels}")

    if _TRK_MODEL is None or _TRK_PROCESSOR is None:
        load_models()
    assert _TRK_MODEL is not None
    assert _TRK_PROCESSOR is not None

    original_w, original_h = image.size
    image = image.convert("RGB")

    # Normalise once so that a missing (None) point list behaves like an empty
    # one everywhere below.
    points = list(obj_state.input_points or [])

    # --- Dynamic Cropping Logic ---
    # 1. Refinement box: the anchor box grown to contain every input point.
    rx1, ry1, rx2, ry2 = obj_state.anchor_box
    for px, py in points:
        rx1, ry1 = min(rx1, px), min(ry1, py)
        rx2, ry2 = max(rx2, px), max(ry2, py)

    # 2. Pad by 25% of the longer side and clamp to the image.
    padding = int(max(rx2 - rx1, ry2 - ry1) * 0.25)
    cx1 = max(0, int(rx1 - padding))
    cy1 = max(0, int(ry1 - padding))
    cx2 = min(original_w, int(rx2 + padding))
    cy2 = min(original_h, int(ry2 + padding))

    # 3. Crop, or fall back to the full image when the crop has no area.
    if cx2 > cx1 and cy2 > cy1:
        image = image.crop((cx1, cy1, cx2, cy2))
        crop_offset_x, crop_offset_y = cx1, cy1
    else:
        crop_offset_x, crop_offset_y = 0, 0

    # --- Coordinate Adjustment ---
    # The tight refinement box is the box prompt, expressed in crop coordinates.
    box_float = [
        float(rx1 - crop_offset_x),
        float(ry1 - crop_offset_y),
        float(rx2 - crop_offset_x),
        float(ry2 - crop_offset_y),
    ]

    processor_kwargs = {
        "images": image,
        "input_boxes": [[box_float]],
        "return_tensors": "pt",
    }
    if points:
        # Nesting for Sam3TrackerProcessor:
        #   input_points: 4 levels [Image, Object, Point, Coords]
        #   input_labels: 3 levels [Image, Object, Label]
        points_float = [
            [float(px - crop_offset_x), float(py - crop_offset_y)] for px, py in points
        ]
        processor_kwargs["input_points"] = [[points_float]]
        processor_kwargs["input_labels"] = [[obj_state.input_labels]]
    # With no points the point prompts are omitted altogether instead of
    # passing empty nested lists.

    inputs = _TRK_PROCESSOR(**processor_kwargs).to(device)
    with torch.no_grad():
        outputs = _TRK_MODEL(**inputs, multimask_output=False)

    masks = _TRK_PROCESSOR.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"],
        binarize=True,
    )[0]
    final_mask_crop = masks[0].numpy()
    if final_mask_crop.ndim == 3:
        final_mask_crop = final_mask_crop[0]

    # --- Restore Mask to Full Size ---
    final_mask = np.zeros((original_h, original_w), dtype=bool)
    mh, mw = final_mask_crop.shape
    final_mask[crop_offset_y:crop_offset_y + mh, crop_offset_x:crop_offset_x + mw] = final_mask_crop
    return final_mask