"""Gradio Space for SAM3 text-prompted video segmentation and tracking."""

import inspect
import sys
import tempfile

import cv2
import gradio as gr
import matplotlib
import numpy as np
import spaces
import torch
from loguru import logger
from PIL import Image
from transformers import (
    Sam3VideoModel,
    Sam3VideoProcessor,
)

from ffmpeg_extractor import extract_frames, get_video_metadata
from toolbox.mask_encoding import b64_mask_encode
from visualizer import mask_to_xyxy
|
|
logger.remove()
logger.add(
    sys.stderr,
    format="<d>{time:YYYY-MM-DD ddd HH:mm:ss}</d> | <lvl>{level}</lvl> | <lvl>{message}</lvl>",
)
|
|
DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Device: {DEVICE}, dtype: {DTYPE}")

logger.info("Loading Models and Processors...")
try:
    VID_MODEL = Sam3VideoModel.from_pretrained("facebook/sam3").to(DEVICE, dtype=DTYPE)
    VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("facebook/sam3")
    logger.success("Models and Processors Loaded!")
except Exception as e:
    logger.error(f"❌ CRITICAL ERROR LOADING VIDEO MODELS: {e}")
    VID_MODEL = None
    VID_PROCESSOR = None
|
|
def apply_mask_overlay(base_image, mask_data, object_ids=None, opacity=0.5):
    """Draws segmentation masks on top of an image, using object IDs for coloring."""
    if isinstance(base_image, np.ndarray):
        base_image = Image.fromarray(base_image)
    base_image = base_image.convert("RGBA")

    if mask_data is None or len(mask_data) == 0:
        return base_image.convert("RGB")

    if isinstance(mask_data, torch.Tensor):
        mask_data = mask_data.cpu().numpy()
    mask_data = mask_data.astype(np.uint8)

    # Normalize mask shapes: (B, N, H, W) -> (N, H, W), then (1, H, W) -> (H, W).
    if mask_data.ndim == 4:
        mask_data = mask_data[0]
    if mask_data.ndim == 3 and mask_data.shape[0] == 1:
        mask_data = mask_data[0]

    num_masks = mask_data.shape[0] if mask_data.ndim == 3 else 1
    if mask_data.ndim == 2:
        mask_data = [mask_data]
        num_masks = 1

    # Assign one stable color per object ID so tracks keep their color across frames.
    if object_ids is not None and len(object_ids) == num_masks:
        try:
            color_map = matplotlib.colormaps["rainbow"]
        except AttributeError:
            # Fallback for matplotlib < 3.6, where the registry lives in matplotlib.cm.
            import matplotlib.cm as cm

            color_map = cm.get_cmap("rainbow")

        unique_ids = sorted(set(object_ids))
        id_to_color_idx = {oid: i for i, oid in enumerate(unique_ids)}
        rgb_colors = [
            tuple(
                int(c * 255)
                for c in color_map(id_to_color_idx[oid] / max(len(unique_ids), 1))[:3]
            )
            for oid in object_ids
        ]
    else:
        try:
            color_map = matplotlib.colormaps["rainbow"].resampled(max(num_masks, 1))
        except AttributeError:
            # Colormap.resampled() also requires matplotlib >= 3.6; the old
            # get_cmap(name, lut) signature resamples instead.
            import matplotlib.cm as cm

            color_map = cm.get_cmap("rainbow", max(num_masks, 1))
        rgb_colors = [
            tuple(int(c * 255) for c in color_map(i)[:3]) for i in range(num_masks)
        ]

    # Composite each mask as a semi-transparent colored layer over the image.
    composite_layer = Image.new("RGBA", base_image.size, (0, 0, 0, 0))

    for i, single_mask in enumerate(mask_data):
        mask_bitmap = Image.fromarray((single_mask * 255).astype(np.uint8))
        if mask_bitmap.size != base_image.size:
            mask_bitmap = mask_bitmap.resize(base_image.size, resample=Image.NEAREST)

        fill_color = rgb_colors[i]
        color_fill = Image.new("RGBA", base_image.size, fill_color + (0,))
        mask_alpha = mask_bitmap.point(lambda v: int(v * opacity) if v > 0 else 0)
        color_fill.putalpha(mask_alpha)
        composite_layer = Image.alpha_composite(composite_layer, color_fill)

    return Image.alpha_composite(base_image, composite_layer).convert("RGB")
|
|
def frames_to_vid(pil_frames, output_path: str, vid_fps: int, vid_w: int, vid_h: int):
    assert isinstance(pil_frames, list), "pil_frames must be a list"
    assert len(pil_frames) > 0, "Number of frames must be greater than 0"
    video_writer = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*"mp4v"), vid_fps, (vid_w, vid_h)
    )
    for f in pil_frames:
        # OpenCV expects BGR channel order.
        video_writer.write(cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR))
    video_writer.release()
    return output_path
|
|
def calc_timeout_duration(vid_file, *args, **kwargs):
    """Resolves the requested timeout from video_inference's call arguments.

    Passed as the dynamic ``duration`` callable to ``@spaces.GPU`` below, so the
    GPU allocation is sized per request.
    """
    sig = inspect.signature(video_inference)
    bound = sig.bind(vid_file, *args, **kwargs)
    bound.apply_defaults()
    return bound.arguments.get("timeout_duration", 60)
|
|
@spaces.GPU(duration=calc_timeout_duration)
def video_inference(
    input_video,
    prompt: str,
    timeout_duration: int = 60,
    annotation_mode: bool = False,
):
    """
    Segments objects in a video using a text prompt.

    Returns a list of detection dicts, one per tracked object per frame, e.g.
    {"frame": 0, "track_id": 1, "x": 0.1, "y": 0.2, "w": 0.3, "h": 0.4,
    "conf": 1, "mask_b64": "..."} with box coordinates normalized to [0, 1],
    or the annotated output video path when annotation_mode is True.
    """
    assert (
        VID_MODEL is not None and VID_PROCESSOR is not None
    ), "Video Models failed to load on startup."
    assert input_video and prompt, "Missing video or prompt."

    # Gradio may hand us either a file path or a dict with a "name" key.
    video_path = (
        input_video if isinstance(input_video, str) else input_video.get("name", None)
    )
    assert video_path, "Invalid video input."

    vmeta = get_video_metadata(video_path, bverbose=False)
    assert vmeta, "Failed to extract video metadata."
    vid_fps = vmeta["fps"]
    vid_w = vmeta["width"]
    vid_h = vmeta["height"]

    # Decode frames at (approximately) the native frame rate and resolution.
    pil_frames = extract_frames(
        video_path,
        fps=int(vid_fps),
        max_short_edge=min(vid_w, vid_h),
        write_timestamp=False,
        write_frame_num=False,
        output_dir=None,
    )
    assert len(pil_frames) > 0, "No frames found in video."

    video_frames = [np.array(frame.convert("RGB")) for frame in pil_frames]

    session = VID_PROCESSOR.init_video_session(
        video=video_frames, inference_device=DEVICE, dtype=DTYPE
    )
    session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=prompt)

    temp_out_path = None
    if annotation_mode:
        # NamedTemporaryFile(delete=False) reserves the output path without the
        # race condition of tempfile.mktemp.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            temp_out_path = tmp.name

    detections = []
    annotated_frames = []
    for model_out in VID_MODEL.propagate_in_video_iterator(
        inference_session=session, max_frame_num_to_track=len(video_frames)
    ):
        post_processed = VID_PROCESSOR.postprocess_outputs(session, model_out)
        f_idx = model_out.frame_idx
        original_pil = Image.fromarray(video_frames[f_idx])
        if "masks" in post_processed:
            detected_masks = post_processed["masks"]
            object_ids = [int(oid) for oid in post_processed["object_ids"]]
            if detected_masks.ndim == 4:
                detected_masks = detected_masks.squeeze(1)

            for i, mask in enumerate(detected_masks):
                mask = mask.cpu().numpy()
                mask_bin = (mask > 0.0).astype(np.uint8)
                xyxy = mask_to_xyxy(mask_bin)
                if not xyxy:
                    continue
                x0, y0, x1, y1 = xyxy
                # Box coordinates are normalized to [0, 1]; confidence is fixed at 1.
                det = {
                    "frame": f_idx,
                    "track_id": int(object_ids[i]),
                    "x": x0 / vid_w,
                    "y": y0 / vid_h,
                    "w": (x1 - x0) / vid_w,
                    "h": (y1 - y0) / vid_h,
                    "conf": 1,
                    "mask_b64": b64_mask_encode(mask_bin).decode("ascii"),
                }
                detections.append(det)

        if annotation_mode:
            final_frame = (
                apply_mask_overlay(original_pil, detected_masks, object_ids=object_ids)
                if "masks" in post_processed
                else original_pil
            )
            annotated_frames.append(final_frame)

    return (
        frames_to_vid(
            annotated_frames,
            output_path=temp_out_path,
            vid_fps=vid_fps,
            vid_h=vid_h,
            vid_w=vid_w,
        )
        if annotation_mode
        else detections
    )
|
|
def video_annotation(input_video, prompt: str, timeout_duration: int = 60):
    return video_inference(
        input_video, prompt, timeout_duration=timeout_duration, annotation_mode=True
    )
|
|
with gr.Blocks() as app:
    with gr.Tab("Video-Object Tracking"):
        gr.Interface(
            fn=video_inference,
            inputs=[
                gr.Video(label="Input Video"),
                gr.Textbox(
                    label="Prompt",
                    lines=3,
                    info="Describe the object(s) you would like to track/segment",
                    value="",
                ),
                gr.Radio([60, 120, 180, 240], value=60, label="Timeout (seconds)"),
            ],
            outputs=gr.JSON(label="Output JSON"),
            title="SAM3 Video Segmentation",
            description="Segment Objects in Video using Text Prompts",
            api_name="video_inference",
        )
    with gr.Tab("Video Annotation"):
        gr.Interface(
            fn=video_annotation,
            inputs=[
                gr.Video(label="Input Video"),
                gr.Textbox(
                    label="Prompt",
                    lines=3,
                    info="Describe the object(s) you would like to track/segment",
                    value="",
                ),
                gr.Radio([60, 120, 180, 240], value=60, label="Timeout (seconds)"),
            ],
            outputs=gr.Video(label="Processed Video"),
            title="SAM3 Video Segmentation",
            description="Segment Objects in Video using Text Prompts",
            api_name="video_annotation",
        )

app.launch(mcp_server=True, app_kwargs={"docs_url": "/docs"})
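
# Example client call (a minimal sketch, not part of the app itself). It assumes
# this Space is reachable at SPACE_URL (a hypothetical placeholder) and a recent
# gradio_client; the exact payload shape for gr.Video inputs varies across
# gradio versions, so check the Space's /docs endpoint for the live schema.
#
#     from gradio_client import Client, handle_file
#
#     client = Client(SPACE_URL)
#     detections = client.predict(
#         handle_file("clip.mp4"),  # input_video
#         "the red car",            # prompt
#         60,                       # timeout_duration (seconds)
#         api_name="/video_inference",
#     )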