File size: 5,742 Bytes
74593d4
bc27759
 
0cc203f
74593d4
 
 
c7dac2c
 
74593d4
 
c7dac2c
 
 
 
 
 
 
 
 
 
 
 
bc27759
 
 
c7dac2c
bc27759
 
c7dac2c
bc27759
 
 
a3f9b81
c7dac2c
 
 
 
 
 
 
1fd4203
c7dac2c
 
 
 
28be05f
c7dac2c
 
 
28be05f
c7dac2c
28be05f
 
c7dac2c
 
 
 
 
 
 
 
74593d4
57943d6
c7dac2c
 
 
 
 
 
 
 
 
 
 
 
bc27759
c7dac2c
 
a3f9b81
c7dac2c
 
 
 
 
bc27759
c7dac2c
 
 
 
bc27759
c7dac2c
 
 
 
 
 
 
 
 
 
57943d6
c7dac2c
 
 
 
 
 
 
 
bc27759
c7dac2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3f9b81
c7dac2c
 
a3f9b81
c7dac2c
 
74593d4
 
c7dac2c
74593d4
c7dac2c
 
 
 
 
 
 
 
 
 
 
 
28be05f
c7dac2c
 
 
 
 
 
 
 
 
 
 
5a34082
74593d4
c7dac2c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
import cv2
import tempfile
import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import Sam3VideoModel, Sam3VideoProcessor

# Select the inference device; fall back to CPU when no CUDA GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

print("Loading SAM3 Video Model...")
# Weights are loaded in bfloat16 to cut GPU memory use.
VID_MODEL = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("facebook/sam3")
print("Model loaded!")

# Frame rate of the rendered output video (input fps is ignored on write).
OUTPUT_FPS = 24


def apply_green_mask(base_image, mask_data, opacity=0.5):
    """Overlay a semi-transparent green mask on a frame.

    Args:
        base_image: PIL.Image or HxW(x3) numpy array frame.
        mask_data: mask(s) as a numpy array or torch tensor shaped
            (H, W), (N, H, W) or (1, N, H, W); any nonzero pixel is
            treated as masked. May be None/empty.
        opacity: overlay alpha in [0, 1].

    Returns:
        RGB PIL.Image with the overlay applied, or the unmodified frame
        when no mask is given.
    """
    if isinstance(base_image, np.ndarray):
        base_image = Image.fromarray(base_image)
    base_image = base_image.convert("RGBA")

    if mask_data is None or len(mask_data) == 0:
        return base_image.convert("RGB")

    if isinstance(mask_data, torch.Tensor):
        mask_data = mask_data.cpu().numpy()
    mask_data = np.asarray(mask_data)

    # Collapse leading batch/singleton dims down to (H, W).
    if mask_data.ndim == 4:
        mask_data = mask_data[0]
    if mask_data.ndim == 3 and mask_data.shape[0] == 1:
        mask_data = mask_data[0]
    if mask_data.ndim == 3:
        # Multiple masks — merge into one
        mask_data = np.any(mask_data > 0, axis=0)

    # Normalize to a clean 0/255 bitmap via thresholding. The previous
    # `mask * 255` wrapped around uint8 when a mask arrived already
    # scaled to 0/255 (255 * 255 -> 1) and truncated soft float masks
    # in (0, 1) to zero; `> 0` is safe for bool, int and float inputs.
    bitmap_arr = np.where(mask_data > 0, 255, 0).astype(np.uint8)
    mask_bitmap = Image.fromarray(bitmap_arr)
    if mask_bitmap.size != base_image.size:
        # Nearest-neighbour keeps mask edges hard when resizing.
        mask_bitmap = mask_bitmap.resize(base_image.size, resample=Image.NEAREST)

    green = (0, 255, 0)
    color_fill = Image.new("RGBA", base_image.size, green + (0,))
    # Scale the binary bitmap into the requested overlay opacity.
    mask_alpha = mask_bitmap.point(lambda v: int(v * opacity) if v > 0 else 0)
    color_fill.putalpha(mask_alpha)

    return Image.alpha_composite(base_image, color_fill).convert("RGB")


def get_video_info(video_path):
    """Probe a video file and return (total_frames, fps, duration_seconds)."""
    capture = cv2.VideoCapture(video_path)
    frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    # `or 24` guards against backends that report 0.0 fps.
    frames_per_sec = capture.get(cv2.CAP_PROP_FPS) or 24
    capture.release()
    return frame_total, frames_per_sec, frame_total / frames_per_sec


def calc_timeout(source_vid, text_query):
    """Estimate a GPU-allocation timeout (seconds) for a segmentation run.

    Budgets 3s of processing per second of video plus 30s of slack,
    clamped to the [60, 300] range. Falls back to 60s when no video
    has been provided. `text_query` is unused but required so the
    signature mirrors `run_video_segmentation` for @spaces.GPU.
    """
    if not source_vid:
        return 60
    _, _, duration = get_video_info(source_vid)
    budget = int(duration * 3) + 30
    return max(60, min(budget, 300))


@spaces.GPU(duration=calc_timeout)
def run_video_segmentation(source_vid, text_query):
    """Segment a video with SAM3 from a text prompt and render a green-mask overlay.

    Args:
        source_vid: path to the input video file.
        text_query: free-text description of the object(s) to segment.

    Returns:
        (output_path, status_message): path of the rendered mp4 (or None
        on failure) and a human-readable status string.

    Raises:
        gr.Error: when the model is unavailable or inputs are missing.
    """
    if VID_MODEL is None or VID_PROCESSOR is None:
        raise gr.Error("Video model failed to load.")
    if not source_vid or not text_query:
        raise gr.Error("Please provide both a video and a text prompt.")

    try:
        # Decode every frame up front (as RGB) so both the processor and
        # the overlay step can index frames by position.
        cap = cv2.VideoCapture(source_vid)
        try:
            src_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            src_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            src_fps = cap.get(cv2.CAP_PROP_FPS) or 24

            video_frames = []
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Release the capture even if decoding raises.
            cap.release()

        total_frames = len(video_frames)
        if total_frames == 0:
            raise gr.Error("Could not read any frames from the video.")
        duration = total_frames / src_fps
        status = f"Loaded {total_frames} frames ({duration:.1f}s @ {src_fps:.0f}fps). Processing..."
        print(status)

        session = VID_PROCESSOR.init_video_session(
            video=video_frames, inference_device=device, dtype=torch.bfloat16
        )
        session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=text_query)

        # NamedTemporaryFile(delete=False) replaces the deprecated and
        # race-prone tempfile.mktemp: the file is created atomically and
        # only its path is handed to the writer.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            temp_out = tmp.name
        writer = cv2.VideoWriter(temp_out, cv2.VideoWriter_fourcc(*"mp4v"), OUTPUT_FPS, (src_w, src_h))
        try:
            for model_out in VID_MODEL.propagate_in_video_iterator(
                inference_session=session, max_frame_num_to_track=total_frames
            ):
                post = VID_PROCESSOR.postprocess_outputs(session, model_out)
                f_idx = model_out.frame_idx
                original = Image.fromarray(video_frames[f_idx])

                if "masks" in post:
                    masks = post["masks"]
                    if masks.ndim == 4:
                        masks = masks.squeeze(1)
                    frame_out = apply_green_mask(original, masks)
                else:
                    frame_out = original

                writer.write(cv2.cvtColor(np.array(frame_out), cv2.COLOR_RGB2BGR))
        finally:
            # Always release so the mp4 container is finalized (and the
            # handle freed) even when propagation raises mid-stream.
            writer.release()

        out_info = f"Done — {total_frames} frames, {duration:.1f}s input → output at {OUTPUT_FPS}fps"
        return temp_out, out_info

    except gr.Error:
        # Deliberate user-facing errors must reach the UI, not be
        # flattened into the status string by the handler below.
        raise
    except Exception as e:
        return None, f"Error: {str(e)}"


# Minimal page styling: center the layout and cap its width.
css = """
#col-container { margin: 0 auto; max-width: 1000px; }
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# SAM3 Video Segmentation — Green Mask")
        gr.Markdown(
            "Upload a video and describe what to segment. "
            "Output is rendered at **24fps** with a **green mask** overlay."
        )

        with gr.Row():
            # Left column: inputs and the trigger button.
            with gr.Column():
                src_video = gr.Video(label="Input Video", format="mp4")
                query_box = gr.Textbox(
                    label="Text Prompt",
                    placeholder="e.g., person, red car, dog",
                )
                segment_btn = gr.Button("Segment Video", variant="primary", size="lg")

            # Right column: rendered result and status readout.
            with gr.Column():
                result_video = gr.Video(label="Segmented Video", autoplay=True)
                status_text = gr.Textbox(label="Status", interactive=False)

        segment_btn.click(
            fn=run_video_segmentation,
            inputs=[src_video, query_box],
            outputs=[result_video, status_text],
        )

if __name__ == "__main__":
    demo.launch(show_error=True)