update
Browse files
app.py
CHANGED
|
@@ -125,35 +125,30 @@ def get_prompt(click_state, click_input):
|
|
| 125 |
|
| 126 |
def load_video(video_input, video_state):
|
| 127 |
"""
|
| 128 |
-
Load video
|
| 129 |
"""
|
| 130 |
if video_input is None:
|
| 131 |
return video_state, None, \
|
| 132 |
gr.update(visible=False), gr.update(visible=False), \
|
| 133 |
gr.update(visible=False), gr.update(visible=False)
|
| 134 |
|
| 135 |
-
# Extract
|
| 136 |
-
|
| 137 |
-
cap = cv2.VideoCapture(video_input)
|
| 138 |
-
ret, first_frame = cap.read()
|
| 139 |
-
cap.release()
|
| 140 |
|
| 141 |
-
if
|
| 142 |
return video_state, None, \
|
| 143 |
gr.update(visible=False), gr.update(visible=False), \
|
| 144 |
gr.update(visible=False), gr.update(visible=False)
|
| 145 |
-
|
| 146 |
-
first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
|
| 147 |
|
| 148 |
-
# Initialize video state
|
| 149 |
video_state = {
|
| 150 |
-
"
|
| 151 |
-
"
|
| 152 |
"first_frame_mask": None,
|
| 153 |
"masks": None,
|
| 154 |
}
|
| 155 |
|
| 156 |
-
first_frame_pil = Image.fromarray(
|
| 157 |
|
| 158 |
return video_state, first_frame_pil, \
|
| 159 |
gr.update(visible=True), gr.update(visible=True), \
|
|
@@ -161,6 +156,17 @@ def load_video(video_input, video_state):
|
|
| 161 |
|
| 162 |
|
| 163 |
@spaces.GPU
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
|
| 165 |
"""
|
| 166 |
Add click and update mask on first frame
|
|
@@ -171,10 +177,7 @@ def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
|
|
| 171 |
click_state: [[points], [labels]]
|
| 172 |
evt: Gradio SelectData event with click coordinates
|
| 173 |
"""
|
| 174 |
-
|
| 175 |
-
initialize_models()
|
| 176 |
-
|
| 177 |
-
if video_state is None or "first_frame" not in video_state: # Check for first_frame
|
| 178 |
return None, video_state, click_state
|
| 179 |
|
| 180 |
# Add new click
|
|
@@ -186,13 +189,9 @@ def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
|
|
| 186 |
|
| 187 |
print(f"Added {point_prompt} click at ({x}, {y}). Total clicks: {len(click_state[0])}")
|
| 188 |
|
| 189 |
-
# Generate mask with SAM2
|
| 190 |
-
first_frame = video_state["
|
| 191 |
-
mask =
|
| 192 |
-
frame=first_frame,
|
| 193 |
-
points=click_state[0],
|
| 194 |
-
labels=click_state[1]
|
| 195 |
-
)
|
| 196 |
|
| 197 |
# Store mask in video state
|
| 198 |
video_state["first_frame_mask"] = mask
|
|
@@ -280,37 +279,46 @@ def propagate_masks(video_state, click_state):
|
|
| 280 |
|
| 281 |
|
| 282 |
@spaces.GPU(duration=120)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
def run_videomama_with_sam2(video_state, click_state):
|
| 284 |
"""
|
| 285 |
Run SAM2 propagation and VideoMaMa inference together
|
| 286 |
"""
|
| 287 |
-
|
| 288 |
-
initialize_models()
|
| 289 |
-
|
| 290 |
-
if video_state is None or "video_path" not in video_state:
|
| 291 |
return video_state, None, None, None, "⚠️ No video loaded"
|
| 292 |
|
| 293 |
if len(click_state[0]) == 0:
|
| 294 |
return video_state, None, None, None, "⚠️ Please add at least one point first"
|
| 295 |
|
| 296 |
-
|
| 297 |
-
print(f"Loading frames from {video_state['video_path']}...")
|
| 298 |
-
frames, fps = extract_frames_from_video(video_state["video_path"], max_frames=50)
|
| 299 |
|
| 300 |
-
#
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
labels=click_state[1]
|
| 306 |
)
|
| 307 |
|
| 308 |
video_state["masks"] = masks
|
| 309 |
-
print(f"✓ Generated {len(masks)} masks")
|
| 310 |
-
|
| 311 |
-
# Step 2: Run VideoMaMa
|
| 312 |
-
print(f"🎨 Running VideoMaMa on {len(frames)} frames...")
|
| 313 |
-
output_frames = videomama(videomama_pipeline, frames, masks)
|
| 314 |
|
| 315 |
# Save output videos
|
| 316 |
output_dir = Path("outputs")
|
|
@@ -513,4 +521,4 @@ if __name__ == "__main__":
|
|
| 513 |
# server_port=7860,
|
| 514 |
# share=True
|
| 515 |
# )
|
| 516 |
-
demo.launch()
|
|
|
|
| 125 |
|
| 126 |
def load_video(video_input, video_state):
|
| 127 |
"""
|
| 128 |
+
Load video and extract first frame for mask generation
|
| 129 |
"""
|
| 130 |
if video_input is None:
|
| 131 |
return video_state, None, \
|
| 132 |
gr.update(visible=False), gr.update(visible=False), \
|
| 133 |
gr.update(visible=False), gr.update(visible=False)
|
| 134 |
|
| 135 |
+
# Extract frames
|
| 136 |
+
frames, fps = extract_frames_from_video(video_input, max_frames=50)
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
+
if len(frames) == 0:
|
| 139 |
return video_state, None, \
|
| 140 |
gr.update(visible=False), gr.update(visible=False), \
|
| 141 |
gr.update(visible=False), gr.update(visible=False)
|
|
|
|
|
|
|
| 142 |
|
| 143 |
+
# Initialize video state
|
| 144 |
video_state = {
|
| 145 |
+
"frames": frames,
|
| 146 |
+
"fps": fps,
|
| 147 |
"first_frame_mask": None,
|
| 148 |
"masks": None,
|
| 149 |
}
|
| 150 |
|
| 151 |
+
first_frame_pil = Image.fromarray(frames[0])
|
| 152 |
|
| 153 |
return video_state, first_frame_pil, \
|
| 154 |
gr.update(visible=True), gr.update(visible=True), \
|
|
|
|
| 156 |
|
| 157 |
|
| 158 |
@spaces.GPU
|
| 159 |
+
def generate_sam2_mask(first_frame, points, labels):
|
| 160 |
+
"""GPU-intensive SAM2 mask generation"""
|
| 161 |
+
initialize_models()
|
| 162 |
+
mask = sam2_tracker.get_first_frame_mask(
|
| 163 |
+
frame=first_frame,
|
| 164 |
+
points=points,
|
| 165 |
+
labels=labels
|
| 166 |
+
)
|
| 167 |
+
return mask
|
| 168 |
+
|
| 169 |
+
|
| 170 |
def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
|
| 171 |
"""
|
| 172 |
Add click and update mask on first frame
|
|
|
|
| 177 |
click_state: [[points], [labels]]
|
| 178 |
evt: Gradio SelectData event with click coordinates
|
| 179 |
"""
|
| 180 |
+
if video_state is None or "frames" not in video_state:
|
|
|
|
|
|
|
|
|
|
| 181 |
return None, video_state, click_state
|
| 182 |
|
| 183 |
# Add new click
|
|
|
|
| 189 |
|
| 190 |
print(f"Added {point_prompt} click at ({x}, {y}). Total clicks: {len(click_state[0])}")
|
| 191 |
|
| 192 |
+
# Generate mask with SAM2 (GPU operation)
|
| 193 |
+
first_frame = video_state["frames"][0]
|
| 194 |
+
mask = generate_sam2_mask(first_frame, click_state[0], click_state[1])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
# Store mask in video state
|
| 197 |
video_state["first_frame_mask"] = mask
|
|
|
|
| 279 |
|
| 280 |
|
| 281 |
@spaces.GPU(duration=120)
|
| 282 |
+
def process_video_with_models(frames, points, labels):
|
| 283 |
+
"""GPU-intensive video processing with SAM2 and VideoMaMa"""
|
| 284 |
+
initialize_models()
|
| 285 |
+
|
| 286 |
+
# Step 1: Track through video with SAM2
|
| 287 |
+
print(f"🎯 Tracking object through {len(frames)} frames with SAM2...")
|
| 288 |
+
masks = sam2_tracker.track_video(
|
| 289 |
+
frames=frames,
|
| 290 |
+
points=points,
|
| 291 |
+
labels=labels
|
| 292 |
+
)
|
| 293 |
+
print(f"✓ Generated {len(masks)} masks")
|
| 294 |
+
|
| 295 |
+
# Step 2: Run VideoMaMa
|
| 296 |
+
print(f"🎨 Running VideoMaMa on {len(frames)} frames...")
|
| 297 |
+
output_frames = videomama(videomama_pipeline, frames, masks)
|
| 298 |
+
|
| 299 |
+
return masks, output_frames
|
| 300 |
+
|
| 301 |
+
|
| 302 |
def run_videomama_with_sam2(video_state, click_state):
|
| 303 |
"""
|
| 304 |
Run SAM2 propagation and VideoMaMa inference together
|
| 305 |
"""
|
| 306 |
+
if video_state is None or "frames" not in video_state:
|
|
|
|
|
|
|
|
|
|
| 307 |
return video_state, None, None, None, "⚠️ No video loaded"
|
| 308 |
|
| 309 |
if len(click_state[0]) == 0:
|
| 310 |
return video_state, None, None, None, "⚠️ Please add at least one point first"
|
| 311 |
|
| 312 |
+
frames = video_state["frames"]
|
|
|
|
|
|
|
| 313 |
|
| 314 |
+
# Run GPU-intensive processing
|
| 315 |
+
masks, output_frames = process_video_with_models(
|
| 316 |
+
frames,
|
| 317 |
+
click_state[0],
|
| 318 |
+
click_state[1]
|
|
|
|
| 319 |
)
|
| 320 |
|
| 321 |
video_state["masks"] = masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
|
| 323 |
# Save output videos
|
| 324 |
output_dir = Path("outputs")
|
|
|
|
| 521 |
# server_port=7860,
|
| 522 |
# share=True
|
| 523 |
# )
|
| 524 |
+
demo.launch()
|