# Hugging Face Space — SAM3 video segmentation (runs on ZeroGPU hardware).
| import os | |
| import cv2 | |
| import tempfile | |
| import spaces | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from transformers import Sam3VideoModel, Sam3VideoProcessor | |
# Select GPU when available; CPU fallback keeps the app runnable anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

print("Loading SAM3 Video Model...")
# Weights are loaded once at import time so every Gradio request reuses the
# same instances. bfloat16 halves the model's memory footprint on GPU.
VID_MODEL = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("facebook/sam3")
print("Model loaded!")

# Frame rate of the rendered output video (independent of the source fps).
OUTPUT_FPS = 24
def apply_green_mask(base_image, mask_data, opacity=0.5):
    """Overlay a semi-transparent green mask on a frame.

    Args:
        base_image: PIL.Image or HxWx3 numpy array (RGB).
        mask_data: mask(s) as numpy array or torch tensor, shaped
            (H, W), (N, H, W), (1, H, W) or (1, N, H, W); nonzero = masked.
        opacity: overlay strength in [0, 1].

    Returns:
        PIL.Image in RGB mode with the same size as ``base_image``.
    """
    if isinstance(base_image, np.ndarray):
        base_image = Image.fromarray(base_image)
    base_image = base_image.convert("RGBA")

    # Nothing to draw — return the frame untouched.
    if mask_data is None or len(mask_data) == 0:
        return base_image.convert("RGB")

    if isinstance(mask_data, torch.Tensor):
        mask_data = mask_data.cpu().numpy()

    # Drop leading batch/singleton dims down to (N, H, W) or (H, W).
    if mask_data.ndim == 4:
        mask_data = mask_data[0]
    if mask_data.ndim == 3 and mask_data.shape[0] == 1:
        mask_data = mask_data[0]
    if mask_data.ndim == 3:
        # Multiple object masks — merge into a single union mask.
        mask_data = np.any(mask_data > 0, axis=0)

    # Binarize BEFORE scaling to 255. The original cast-then-multiply
    # overflowed uint8 whenever a mask arrived already valued 0/255
    # (255 * 255 wraps to 1), making the overlay nearly invisible.
    mask_data = (mask_data > 0).astype(np.uint8)

    green = (0, 255, 0)
    mask_bitmap = Image.fromarray(mask_data * 255)
    if mask_bitmap.size != base_image.size:
        # Masks may come at model resolution; NEAREST keeps hard edges.
        mask_bitmap = mask_bitmap.resize(base_image.size, resample=Image.NEAREST)

    # Solid-green layer whose alpha channel is the opacity-scaled mask.
    color_fill = Image.new("RGBA", base_image.size, green + (0,))
    mask_alpha = mask_bitmap.point(lambda v: int(v * opacity) if v > 0 else 0)
    color_fill.putalpha(mask_alpha)
    return Image.alpha_composite(base_image, color_fill).convert("RGB")
def get_video_info(video_path):
    """Probe *video_path* and return ``(total_frames, fps, duration_seconds)``."""
    capture = cv2.VideoCapture(video_path)
    frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    # Some containers report 0 fps — fall back to 24 to avoid a zero divide.
    frame_rate = capture.get(cv2.CAP_PROP_FPS) or 24
    capture.release()
    return frame_total, frame_rate, frame_total / frame_rate
def calc_timeout(source_vid, text_query):
    """Estimate a processing-time budget (seconds) for segmenting *source_vid*.

    Scales with video duration (~3 s of processing per second of footage plus
    fixed startup overhead), clamped to the range [60, 300]. *text_query* is
    accepted so the signature mirrors run_video_segmentation's inputs.
    """
    if not source_vid:
        return 60
    duration = get_video_info(source_vid)[2]
    estimate = int(duration * 3) + 30
    if estimate < 60:
        return 60
    if estimate > 300:
        return 300
    return estimate
# Restores the ZeroGPU decorator: `spaces` is imported and `calc_timeout`
# (whose signature mirrors this function's inputs) was otherwise unused —
# the dynamic-duration GPU allocation was evidently intended.
@spaces.GPU(duration=calc_timeout)
def run_video_segmentation(source_vid, text_query):
    """Segment every object matching *text_query* across *source_vid*.

    Runs SAM3 video propagation with a text prompt and renders each frame
    with a green mask overlay into a new mp4 at OUTPUT_FPS.

    Args:
        source_vid: path to the input video file.
        text_query: open-vocabulary text prompt (e.g. "person").

    Returns:
        ``(output_path, status)`` on success, ``(None, error_message)`` on
        failure inside the pipeline.

    Raises:
        gr.Error: when the model is unavailable or an input is missing
            (raised before the try block so it reaches the UI directly).
    """
    if VID_MODEL is None or VID_PROCESSOR is None:
        raise gr.Error("Video model failed to load.")
    if not source_vid or not text_query:
        raise gr.Error("Please provide both a video and a text prompt.")

    writer = None
    try:
        # --- decode the whole clip into RGB frames -----------------------
        cap = cv2.VideoCapture(source_vid)
        try:
            src_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            src_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            src_fps = cap.get(cv2.CAP_PROP_FPS) or 24  # 0 fps fallback
            video_frames = []
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            cap.release()

        total_frames = len(video_frames)
        duration = total_frames / src_fps
        print(f"Loaded {total_frames} frames ({duration:.1f}s @ {src_fps:.0f}fps). Processing...")

        # --- prime the SAM3 session with the text prompt -----------------
        session = VID_PROCESSOR.init_video_session(
            video=video_frames, inference_device=device, dtype=torch.bfloat16
        )
        session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=text_query)

        # NamedTemporaryFile replaces the deprecated, race-prone mktemp.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            temp_out = tmp.name
        writer = cv2.VideoWriter(
            temp_out, cv2.VideoWriter_fourcc(*"mp4v"), OUTPUT_FPS, (src_w, src_h)
        )

        # --- propagate masks frame by frame and render -------------------
        for model_out in VID_MODEL.propagate_in_video_iterator(
            inference_session=session, max_frame_num_to_track=total_frames
        ):
            post = VID_PROCESSOR.postprocess_outputs(session, model_out)
            original = Image.fromarray(video_frames[model_out.frame_idx])
            if "masks" in post:
                masks = post["masks"]
                if masks.ndim == 4:
                    masks = masks.squeeze(1)  # drop channel dim -> (N, H, W)
                frame_out = apply_green_mask(original, masks)
            else:
                frame_out = original  # nothing detected on this frame
            writer.write(cv2.cvtColor(np.array(frame_out), cv2.COLOR_RGB2BGR))

        out_info = f"Done — {total_frames} frames, {duration:.1f}s input → output at {OUTPUT_FPS}fps"
        return temp_out, out_info
    except Exception as e:
        # Surface the failure in the status box rather than crashing the UI.
        return None, f"Error: {str(e)}"
    finally:
        # Release the writer even when an exception fires mid-render
        # (the original leaked it on any error after creation).
        if writer is not None:
            writer.release()
# Custom page styling: center the layout and cap its width.
page_css = """
#col-container { margin: 0 auto; max-width: 1000px; }
"""

# --- Gradio UI ------------------------------------------------------------
with gr.Blocks(css=page_css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# SAM3 Video Segmentation — Green Mask")
        gr.Markdown(
            "Upload a video and describe what to segment. "
            "Output is rendered at **24fps** with a **green mask** overlay."
        )
        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                in_video = gr.Video(label="Input Video", format="mp4")
                in_prompt = gr.Textbox(
                    label="Text Prompt",
                    placeholder="e.g., person, red car, dog",
                )
                segment_btn = gr.Button("Segment Video", variant="primary", size="lg")
            # Right column: results.
            with gr.Column():
                out_video = gr.Video(label="Segmented Video", autoplay=True)
                out_status = gr.Textbox(label="Status", interactive=False)

        # Wire the button to the segmentation pipeline.
        segment_btn.click(
            fn=run_video_segmentation,
            inputs=[in_video, in_prompt],
            outputs=[out_video, out_status],
        )

if __name__ == "__main__":
    demo.launch(show_error=True)