import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import Sam3VideoModel, Sam3VideoProcessor

# Select GPU when available; weights are loaded once at import time so every
# request shares the same model instance.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

print("Loading SAM3 Video Model...")
VID_MODEL = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("facebook/sam3")
print("Model loaded!")

# Every output video is re-encoded at this fixed frame rate.
OUTPUT_FPS = 24


def apply_green_mask(base_image, mask_data, opacity=0.5):
    """Composite a translucent green overlay onto a frame.

    Args:
        base_image: PIL.Image or HxWx3 RGB numpy array.
        mask_data: torch.Tensor or numpy array shaped (H, W), (1, H, W),
            (N, H, W), or (1, N, H, W); nonzero pixels count as masked.
            NOTE(review): assumed binary/0-1 or 0-255 foreground encoding —
            anything nonzero is treated as foreground.
        opacity: alpha applied to masked pixels, in [0, 1].

    Returns:
        PIL.Image in RGB mode with the overlay applied.
    """
    if isinstance(base_image, np.ndarray):
        base_image = Image.fromarray(base_image)
    base_image = base_image.convert("RGBA")

    if mask_data is None or len(mask_data) == 0:
        return base_image.convert("RGB")

    if isinstance(mask_data, torch.Tensor):
        mask_data = mask_data.cpu().numpy()

    # Collapse leading batch/singleton dims toward (H, W) or (N, H, W).
    if mask_data.ndim == 4:
        mask_data = mask_data[0]
    if mask_data.ndim == 3 and mask_data.shape[0] == 1:
        mask_data = mask_data[0]
    if mask_data.ndim == 3:
        # Multiple object masks — merge into a single combined mask.
        mask_data = np.any(mask_data > 0, axis=0)

    # FIX: binarize BEFORE scaling. The original did `astype(uint8) * 255`
    # on raw values, which wraps around (uint8 overflow) when a mask already
    # uses 255 for foreground pixels.
    mask_binary = (np.asarray(mask_data) > 0).astype(np.uint8)

    green = (0, 255, 0)
    mask_bitmap = Image.fromarray(mask_binary * 255)
    if mask_bitmap.size != base_image.size:
        # NEAREST keeps mask edges hard instead of introducing gray alpha.
        mask_bitmap = mask_bitmap.resize(base_image.size, resample=Image.NEAREST)

    color_fill = Image.new("RGBA", base_image.size, green + (0,))
    mask_alpha = mask_bitmap.point(lambda v: int(v * opacity) if v > 0 else 0)
    color_fill.putalpha(mask_alpha)
    return Image.alpha_composite(base_image, color_fill).convert("RGB")


def get_video_info(video_path):
    """Return (total_frames, fps, duration_seconds) for a video file.

    Falls back to 24 fps when OpenCV reports 0, so the duration division
    can never divide by zero.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 24
    finally:
        # FIX: release even if a property read raises — the original leaked
        # the capture handle on error.
        cap.release()
    duration = total_frames / fps
    return total_frames, fps, duration


def calc_timeout(source_vid, text_query):
    """Estimate the GPU-reservation timeout (seconds) for `spaces.GPU`.

    `text_query` is unused but must be accepted: `spaces.GPU(duration=...)`
    passes the handler's full input list through to this callable.
    """
    if not source_vid:
        return 60
    _, _, duration = get_video_info(source_vid)
    # ~3s of processing per second of video plus 30s overhead, clamped to
    # [60, 300]. (Original comment said "~2s" but the code budgets 3x.)
    return max(60, min(int(duration * 3) + 30, 300))
@spaces.GPU(duration=calc_timeout)
def run_video_segmentation(source_vid, text_query):
    """Segment objects matching `text_query` across every frame of a video.

    Decodes the clip, runs SAM3 text-prompted video propagation, overlays a
    green mask on each frame, and writes an mp4 at OUTPUT_FPS.

    Args:
        source_vid: path to the input video file.
        text_query: natural-language description of what to segment.

    Returns:
        (output_path, status_message) on success, (None, error_message) on
        failure — so the UI shows the error instead of crashing.

    Raises:
        gr.Error: when the model is unavailable or an input is missing.
    """
    if VID_MODEL is None or VID_PROCESSOR is None:
        raise gr.Error("Video model failed to load.")
    if not source_vid or not text_query:
        raise gr.Error("Please provide both a video and a text prompt.")

    try:
        # Decode the whole clip into RGB frames (model expects RGB; OpenCV
        # decodes BGR).
        cap = cv2.VideoCapture(source_vid)
        try:
            src_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            src_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            src_fps = cap.get(cv2.CAP_PROP_FPS) or 24
            video_frames = []
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # FIX: release even when decoding raises (original leaked on error).
            cap.release()

        total_frames = len(video_frames)
        duration = total_frames / src_fps
        status = f"Loaded {total_frames} frames ({duration:.1f}s @ {src_fps:.0f}fps). Processing..."
        print(status)

        session = VID_PROCESSOR.init_video_session(
            video=video_frames,
            inference_device=device,
            dtype=torch.bfloat16,
        )
        session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=text_query)

        # FIX: tempfile.mktemp is deprecated and race-prone;
        # NamedTemporaryFile reserves the path atomically. delete=False so
        # Gradio can serve the file after the handle is closed.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            temp_out = tmp.name
        writer = cv2.VideoWriter(
            temp_out, cv2.VideoWriter_fourcc(*"mp4v"), OUTPUT_FPS, (src_w, src_h)
        )
        try:
            for model_out in VID_MODEL.propagate_in_video_iterator(
                inference_session=session, max_frame_num_to_track=total_frames
            ):
                post = VID_PROCESSOR.postprocess_outputs(session, model_out)
                f_idx = model_out.frame_idx
                original = Image.fromarray(video_frames[f_idx])
                if "masks" in post:
                    masks = post["masks"]
                    if masks.ndim == 4:
                        # Drop the per-object channel dim: (N, 1, H, W) -> (N, H, W).
                        masks = masks.squeeze(1)
                    frame_out = apply_green_mask(original, masks)
                else:
                    frame_out = original
                writer.write(cv2.cvtColor(np.array(frame_out), cv2.COLOR_RGB2BGR))
        finally:
            # FIX: finalize the container even if inference fails mid-stream
            # (original skipped release() on exception).
            writer.release()

        out_info = f"Done — {total_frames} frames, {duration:.1f}s input → output at {OUTPUT_FPS}fps"
        return temp_out, out_info
    except Exception as e:
        # Best-effort handler: surface the error in the status box.
        return None, f"Error: {str(e)}"


css = """
#col-container { margin: 0 auto; max-width: 1000px; }
"""
# --- UI ---------------------------------------------------------------
# Two-column layout: inputs (video + prompt + button) on the left,
# results (rendered video + status text) on the right.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# SAM3 Video Segmentation — Green Mask")
        gr.Markdown(
            "Upload a video and describe what to segment. "
            "Output is rendered at **24fps** with a **green mask** overlay."
        )
        with gr.Row():
            # Left column: user inputs.
            with gr.Column():
                video_input = gr.Video(label="Input Video", format="mp4")
                text_prompt = gr.Textbox(
                    label="Text Prompt",
                    placeholder="e.g., person, red car, dog",
                )
                run_btn = gr.Button("Segment Video", variant="primary", size="lg")
            # Right column: outputs.
            with gr.Column():
                video_output = gr.Video(label="Segmented Video", autoplay=True)
                status_box = gr.Textbox(label="Status", interactive=False)

        # Wire the button to the GPU-backed segmentation handler.
        run_btn.click(
            fn=run_video_segmentation,
            inputs=[video_input, text_prompt],
            outputs=[video_output, status_box],
        )

if __name__ == "__main__":
    demo.launch(show_error=True)