import os
import cv2
import tempfile
import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import Sam3VideoModel, Sam3VideoProcessor
# Pick CUDA when available; falls back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print("Loading SAM3 Video Model...")
# Model and processor are loaded once at import time and shared across requests.
# bfloat16 halves weight memory versus float32 for inference.
VID_MODEL = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("facebook/sam3")
print("Model loaded!")
# Frame rate of the rendered output video (independent of the source fps).
OUTPUT_FPS = 24
def apply_green_mask(base_image, mask_data, opacity=0.5):
    """Composite a translucent green overlay onto a frame wherever the mask is set."""
    frame = base_image
    if isinstance(frame, np.ndarray):
        frame = Image.fromarray(frame)
    frame = frame.convert("RGBA")

    # No mask at all: return the untouched frame.
    if mask_data is None or len(mask_data) == 0:
        return frame.convert("RGB")

    mask = mask_data
    if isinstance(mask, torch.Tensor):
        mask = mask.cpu().numpy()
    mask = mask.astype(np.uint8)

    # Squeeze leading batch/channel axes down toward a single 2-D mask.
    if mask.ndim == 4:
        mask = mask[0]
    if mask.ndim == 3 and mask.shape[0] == 1:
        mask = mask[0]
    if mask.ndim == 3:
        # Several per-object masks remain — union them into one.
        mask = np.any(mask > 0, axis=0).astype(np.uint8)

    green = (0, 255, 0)
    bitmap = Image.fromarray((mask * 255).astype(np.uint8))
    if bitmap.size != frame.size:
        # Nearest-neighbour keeps the mask binary when rescaling.
        bitmap = bitmap.resize(frame.size, resample=Image.NEAREST)

    # Solid-green layer that starts fully transparent; its alpha is the
    # mask scaled by `opacity`, so only masked pixels get tinted.
    overlay = Image.new("RGBA", frame.size, green + (0,))
    alpha = bitmap.point(lambda v: int(v * opacity) if v > 0 else 0)
    overlay.putalpha(alpha)
    return Image.alpha_composite(frame, overlay).convert("RGB")
def get_video_info(video_path):
    """Probe a video file and return (total_frames, fps, duration_seconds)."""
    capture = cv2.VideoCapture(video_path)
    frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    # cv2 reports 0.0 fps for broken/odd streams; fall back to 24 to
    # avoid a zero-division below.
    frame_rate = capture.get(cv2.CAP_PROP_FPS) or 24
    capture.release()
    return frame_total, frame_rate, frame_total / frame_rate
def calc_timeout(source_vid, text_query):
    """Estimate the GPU allocation (seconds) to request for this job."""
    if not source_vid:
        return 60
    _, _, duration = get_video_info(source_vid)
    # Budget ~3s of processing per second of footage plus setup overhead,
    # clamped to a 60s floor and a 300s ceiling.
    estimate = int(duration * 3) + 30
    return min(max(estimate, 60), 300)
@spaces.GPU(duration=calc_timeout)
def run_video_segmentation(source_vid, text_query):
    """Segment objects matching `text_query` in `source_vid` and render a
    green-mask overlay video.

    Returns (output_video_path, status_message); on processing failure
    returns (None, "Error: ...") so the UI surfaces the message instead
    of crashing.

    Raises gr.Error early when the model is missing or inputs are empty.
    """
    if VID_MODEL is None or VID_PROCESSOR is None:
        raise gr.Error("Video model failed to load.")
    if not source_vid or not text_query:
        raise gr.Error("Please provide both a video and a text prompt.")
    try:
        # --- Decode the whole clip into RGB frames ---
        cap = cv2.VideoCapture(source_vid)
        try:
            src_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            src_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            src_fps = cap.get(cv2.CAP_PROP_FPS) or 24
            video_frames = []
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                # cv2 decodes BGR; the model and PIL expect RGB.
                video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Release the capture even if decoding raises.
            cap.release()

        total_frames = len(video_frames)
        if total_frames == 0:
            # Unreadable/empty input would otherwise fail deep inside the
            # processor with an opaque error.
            return None, "Error: could not read any frames from the input video."
        duration = total_frames / src_fps
        status = f"Loaded {total_frames} frames ({duration:.1f}s @ {src_fps:.0f}fps). Processing..."
        print(status)

        session = VID_PROCESSOR.init_video_session(
            video=video_frames, inference_device=device, dtype=torch.bfloat16
        )
        session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=text_query)

        # mkstemp instead of the deprecated, race-prone tempfile.mktemp;
        # close the fd so cv2 can open the path itself.
        fd, temp_out = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        writer = cv2.VideoWriter(temp_out, cv2.VideoWriter_fourcc(*"mp4v"), OUTPUT_FPS, (src_w, src_h))
        try:
            # --- Propagate the text prompt through the video frame by frame ---
            for model_out in VID_MODEL.propagate_in_video_iterator(
                inference_session=session, max_frame_num_to_track=total_frames
            ):
                post = VID_PROCESSOR.postprocess_outputs(session, model_out)
                f_idx = model_out.frame_idx
                original = Image.fromarray(video_frames[f_idx])
                if "masks" in post:
                    masks = post["masks"]
                    if masks.ndim == 4:
                        masks = masks.squeeze(1)
                    frame_out = apply_green_mask(original, masks)
                else:
                    frame_out = original
                writer.write(cv2.cvtColor(np.array(frame_out), cv2.COLOR_RGB2BGR))
        finally:
            # Flush/close the output file even when inference fails mid-way.
            writer.release()

        out_info = f"Done — {total_frames} frames, {duration:.1f}s input → output at {OUTPUT_FPS}fps"
        return temp_out, out_info
    except Exception as e:
        return None, f"Error: {str(e)}"
# Center the app column and cap its width.
css = """
#col-container { margin: 0 auto; max-width: 1000px; }
"""
# --- Gradio UI: video + text prompt in, overlaid video + status out ---
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# SAM3 Video Segmentation — Green Mask")
        gr.Markdown(
            "Upload a video and describe what to segment. "
            "Output is rendered at **24fps** with a **green mask** overlay."
        )
        with gr.Row():
            with gr.Column():
                # Inputs: the source clip and the open-vocabulary prompt.
                video_input = gr.Video(label="Input Video", format="mp4")
                text_prompt = gr.Textbox(
                    label="Text Prompt",
                    placeholder="e.g., person, red car, dog",
                )
                run_btn = gr.Button("Segment Video", variant="primary", size="lg")
            with gr.Column():
                # Outputs: rendered overlay video plus a status/error line.
                video_output = gr.Video(label="Segmented Video", autoplay=True)
                status_box = gr.Textbox(label="Status", interactive=False)
    # Wire the button to the GPU-decorated segmentation function.
    run_btn.click(
        fn=run_video_segmentation,
        inputs=[video_input, text_prompt],
        outputs=[video_output, status_box],
    )
if __name__ == "__main__":
    demo.launch(show_error=True)