# Hugging Face Spaces scrape artifact removed ("Spaces: Running on Zero").
# This app targets ZeroGPU hardware on Hugging Face Spaces.
| import functools | |
| import math | |
| import os | |
| import tempfile | |
| from collections import defaultdict | |
| import cv2 | |
| import numpy as np | |
| import PIL | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFile | |
| from transformers import AutoModelForImageTextToText, AutoProcessor | |
| import gradio as gr | |
| import spaces | |
| from molmo_utils import process_vision_info | |
| from typing import Iterable | |
| from gradio.themes import Soft | |
| from gradio.themes.utils import colors, fonts, sizes | |
# Allow arbitrarily large images: disables Pillow's decompression-bomb guard.
Image.MAX_IMAGE_PIXELS = None
# Tolerate partially-downloaded / truncated image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# ── Constants ──────────────────────────────────────────────────────────────────
MODEL_ID = "allenai/MolmoPoint-8B"
MAX_IMAGE_SIZE = 512        # display height (px) for image galleries in the UI
MAX_VIDEO_HEIGHT = 512      # display height (px) for video components in the UI
POINT_SIZE = 0.01           # marker radius as a fraction of the longer image side
KEYFRAME_HOLD_FRAMES = 3    # frames around a keyframe drawn as "solid" markers
SHOW_TRAILS = True          # draw fading motion trails in annotated videos
MAX_NEW_TOKENS = 2048       # default generation budget
MAX_FPS = 10                # default max_fps for video sampling (tracking preset)
# Per-object color palette, CSS 'rgb(r, g, b)' strings (cycled by object id).
COLORS = [
    "rgb(255, 100, 180)",
    "rgb(100, 180, 255)",
    "rgb(180, 255, 100)",
    "rgb(255, 180, 100)",
    "rgb(100, 255, 180)",
    "rgb(180, 100, 255)",
    "rgb(255, 255, 100)",
    "rgb(100, 255, 255)",
    "rgb(255, 120, 120)",
    "rgb(120, 255, 255)",
    "rgb(255, 255, 120)",
    "rgb(255, 120, 255)",
]
# ── Model loading ──────────────────────────────────────────────────────────────
print(f"Loading {MODEL_ID}...")
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    padding_side="left",  # left-pad so batched generation aligns at the sequence end
)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    dtype="bfloat16",
    device_map="auto",  # let accelerate place weights on available devices
)
print("Model loaded successfully.")
| # ββ Helper functions βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _parse_rgb(color_str): | |
| """Parse 'rgb(r, g, b)' to (r, g, b) tuple.""" | |
| nums = color_str.replace("rgb(", "").replace(")", "").split(",") | |
| return tuple(int(n.strip()) for n in nums) | |
# Pre-computed BGR palette for OpenCV drawing (cv2 expects BGR channel order).
# Parse each color once and reverse the tuple, instead of parsing it three times.
COLORS_BGR = [_parse_rgb(c)[::-1] for c in COLORS]
def is_tracking_output(generated_text: str) -> bool:
    """Return True when the model's output is a tracking response (opens with a <tracks tag)."""
    # Only the leading whitespace matters for a prefix check.
    return generated_text.lstrip().startswith("<tracks")
def cast_float_bf16(t: torch.Tensor) -> torch.Tensor:
    """Cast floating-point tensors to bfloat16; return integer/bool tensors unchanged."""
    return t.to(torch.bfloat16) if torch.is_floating_point(t) else t
def draw_points(image, points):
    """Draw uniformly-colored circular markers at the given points.

    Args:
        image: np.ndarray (converted via PIL) or a PIL.Image; never mutated —
            a copy is annotated.
        points: iterable of (x, y) pixel coordinates.

    Returns:
        A new PIL.Image with filled circles drawn at each point.
    """
    if isinstance(image, np.ndarray):
        annotation = PIL.Image.fromarray(image)
    else:
        annotation = image.copy()
    draw = ImageDraw.Draw(annotation)
    w, h = annotation.size
    # Marker radius scales with the longer image side, with a 5px floor.
    size = max(5, int(max(w, h) * POINT_SIZE))
    # Non-tracking pointing always uses the first palette color; hoist it out
    # of the loop (the original recomputed it per point and never used the
    # enumerate index).
    color = COLORS[0]
    for x, y in points:
        draw.ellipse((x - size, y - size, x + size, y + size), fill=color, outline=None)
    return annotation
def draw_points_colored(image, points_with_ids):
    """Draw points with per-instance-ID colors for tracking visualization.

    Args:
        image: np.ndarray or PIL.Image; a copy is annotated, not the input.
        points_with_ids: iterable of (object_id, x, y); object_id picks the
            palette color (1-based, cycled through COLORS).

    Returns:
        A new PIL.Image with one filled circle per point.
    """
    annotation = PIL.Image.fromarray(image) if isinstance(image, np.ndarray) else image.copy()
    draw = ImageDraw.Draw(annotation)
    width, height = annotation.size
    radius = max(5, int(max(width, height) * POINT_SIZE))
    palette_len = len(COLORS)
    for object_id, x, y in points_with_ids:
        fill = COLORS[(object_id - 1) % palette_len]
        bbox = (x - radius, y - radius, x + radius, y + radius)
        draw.ellipse(bbox, fill=fill, outline=None)
    return annotation
def format_points_list(points, is_video=False):
    """Serialize extracted points as a flat Python-list string.

    Video rows render as "[id, t.tt, x.x, y.y]" (timestamp to 2 decimals);
    image rows render as "[id, index, x.x, y.y]" (frame/image index as int).
    Returns "[]" for an empty input.
    """
    if not points:
        return "[]"

    def _row(object_id, second_field, x, y):
        # Second field is a timestamp for video, an image index otherwise.
        mid = f"{float(second_field):.2f}" if is_video else f"{int(second_field)}"
        return f"[{int(object_id)}, {mid}, {float(x):.1f}, {float(y):.1f}]"

    return "[" + ", ".join(_row(*p) for p in points) + "]"
| def _interpolate_keyframes(keyframes, total_frames, max_gap=None): | |
| """Linearly interpolate positions between keyframes. | |
| keyframes: sorted list of (frame_idx, x, y) | |
| max_gap: if set, skip interpolation (leave gap invisible) when two keyframes | |
| are more than this many frames apart. | |
| Returns dict {frame_idx: (x, y)} for every frame from first to last keyframe. | |
| """ | |
| if not keyframes: | |
| return {} | |
| positions = {} | |
| for i in range(len(keyframes)): | |
| f_idx, x, y = keyframes[i] | |
| positions[f_idx] = (x, y) | |
| if i + 1 < len(keyframes): | |
| nf, nx, ny = keyframes[i + 1] | |
| span = nf - f_idx | |
| if span > 1 and (max_gap is None or span <= max_gap): | |
| for t in range(1, span): | |
| alpha = t / span | |
| positions[f_idx + t] = (x + alpha * (nx - x), y + alpha * (ny - y)) | |
| return positions | |
def create_annotated_video(video_path, points, metadata, tracking):
    """Draw points on the original video with interpolation and fading trails.

    Args:
        video_path: Path to the source video.
        points: [(object_id, timestamp_sec, x, y), ...] — coordinates are in
            the processed-frame space given by ``metadata["video_size"]``.
        metadata: Pointing metadata dict; only "video_size" is read here.
        tracking: If True, color markers per object id; otherwise use one color.

    Returns:
        Path to a newly written annotated .mp4 file.
    """
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    vid_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    vid_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Points are in processed-frame space; rescale into original resolution.
    proc_w, proc_h = metadata["video_size"]
    scale_x = vid_w / proc_w
    scale_y = vid_h / proc_h
    # Build per-object keyframes: {obj_id: [(frame_idx, x, y), ...]}
    obj_keyframes = defaultdict(list)
    for object_id, ts, x, y in points:
        f_idx = int(round(float(ts) * fps))
        sx, sy = float(x) * scale_x, float(y) * scale_y
        obj_keyframes[int(object_id)].append((f_idx, sx, sy))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    obj_positions = {}
    obj_keyframe_set = {}
    max_gap_frames = int(fps)  # gaps > 1 second: no interpolation, object disappears
    for obj_id, kfs in obj_keyframes.items():
        kfs.sort(key=lambda k: k[0])
        obj_positions[obj_id] = _interpolate_keyframes(kfs, total_frames, max_gap=max_gap_frames)
        raw_kf = set(f_idx for f_idx, _, _ in kfs)
        # Hold the "solid keyframe" look for a few frames around each true keyframe.
        obj_keyframe_set[obj_id] = set(
            f for kf in raw_kf for f in range(kf - KEYFRAME_HOLD_FRAMES, kf + KEYFRAME_HOLD_FRAMES + 1)
        )
    # tempfile.mktemp is deprecated and race-prone; create the file atomically
    # and let cv2.VideoWriter reopen it by name.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(out_path, fourcc, fps, (vid_w, vid_h))
    radius = max(5, int(max(vid_w, vid_h) * POINT_SIZE))
    trail_length = int(fps * 2)  # trails persist ~2 seconds
    obj_history = defaultdict(list)
    current_frame = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        for obj_id, positions in obj_positions.items():
            if current_frame in positions:
                px, py = positions[current_frame]
                # Clear trail if the object was absent in the previous frame (gap reappearance)
                if current_frame > 0 and (current_frame - 1) not in positions:
                    obj_history[obj_id] = []
                obj_history[obj_id].append((px, py))
                if len(obj_history[obj_id]) > trail_length:
                    obj_history[obj_id] = obj_history[obj_id][-trail_length:]
                if tracking:
                    color = COLORS_BGR[(obj_id - 1) % len(COLORS_BGR)]
                else:
                    color = COLORS_BGR[0]
                # Draw fading trail: older segments are darker and thinner.
                trail = obj_history[obj_id]
                n_trail = len(trail)
                if SHOW_TRAILS and n_trail >= 2:
                    for i in range(n_trail - 1):
                        alpha = (i + 1) / n_trail
                        trail_color = tuple(int(c * alpha) for c in color)
                        thickness = max(1, int(radius * 0.6 * alpha))
                        pt1 = (int(trail[i][0]), int(trail[i][1]))
                        pt2 = (int(trail[i + 1][0]), int(trail[i + 1][1]))
                        cv2.line(frame, pt1, pt2, trail_color, thickness)
                # Solid on keyframes, outline-only on interpolated frames
                if current_frame in obj_keyframe_set[obj_id]:
                    cv2.circle(frame, (int(px), int(py)), radius, color, -1)
                    cv2.circle(frame, (int(px), int(py)), radius + 2, (255, 255, 255), 2)
                else:
                    cv2.circle(frame, (int(px), int(py)), radius, color, 2)
        out.write(frame)
        current_frame += 1
    cap.release()
    out.release()
    return out_path
# ── Inference functions ────────────────────────────────────────────────────────
def process_images(user_text, input_images, max_tokens):
    """Run pointing inference on one or more images.

    Args:
        user_text: The user prompt (e.g. "Point to the scissors.").
        input_images: Gallery items — file paths or (path, caption) tuples.
        max_tokens: Generation budget (cast to int before use).

    Returns:
        (generated_text, annotated_images, points_table_str) — annotated
        images fall back to the unannotated inputs when no points were found.
    """
    # NOTE(review): `spaces` is imported but no @spaces.GPU decorator is
    # applied here — confirm whether ZeroGPU allocation is handled elsewhere.
    if not input_images:
        return "Please upload at least one image.", [], "[]"
    pil_images = []
    for img_path in input_images:
        # Gallery entries may be (path, caption) tuples; keep only the path.
        if isinstance(img_path, tuple):
            img_path = img_path[0]
        pil_images.append(Image.open(img_path).convert("RGB"))
    # Build messages: one text item followed by every image.
    content = [dict(type="text", text=user_text)]
    for img in pil_images:
        content.append(dict(type="image", image=img))
    messages = [{"role": "user", "content": content}]
    # Process inputs
    images, _, _ = process_vision_info(messages)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(f"Prompt: {text}")
    inputs = processor(
        images=images,
        text=text,
        padding=True,
        return_tensors="pt",
        return_pointing_metadata=True,  # needed by extract_image_points below
    )
    # Pull pointing metadata out before moving tensors to the model device.
    metadata = inputs.pop("metadata")
    inputs = {k: cast_float_bf16(v.to(model.device)) for k, v in inputs.items()}
    # Generate (greedy: temperature=0.0) under bf16 autocast.
    with torch.inference_mode():
        with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            output = model.generate(
                **inputs,
                logits_processor=model.build_logit_processor_from_inputs(inputs),
                max_new_tokens=int(max_tokens),
                temperature=0.0
            )
    # Strip the prompt tokens; decode only newly generated ones.
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    generated_text = processor.decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # Extract points from the generated text using pointing metadata.
    points = model.extract_image_points(
        generated_text,
        metadata["token_pooling"],
        metadata["subpatch_mapping"],
        metadata["image_sizes"],
    )
    points_table = format_points_list(points, is_video=False)
    print(f"Output text: {generated_text}")
    print("Extracted points:", points_table)
    if points:
        # Group points by image index so each image is annotated once.
        group_by_index = defaultdict(list)
        for object_id, ix, x, y in points:
            group_by_index[ix].append((x, y))
        annotated = []
        for ix, pts in group_by_index.items():
            annotated.append(draw_points(images[ix], pts))
        return generated_text, annotated, points_table
    return generated_text, pil_images, points_table
def process_video(user_text, video_path, frame_sample_mode, max_frames, max_fps, max_tokens):
    """Run pointing/tracking inference on a video.

    Args:
        user_text: The user prompt (e.g. "Track all the penguins.").
        video_path: Path to the uploaded video, or falsy if none.
        frame_sample_mode: Processor frame sampling mode (e.g. default or "fps").
        max_frames: Maximum number of frames to sample (cast to int).
        max_fps: Optional FPS cap for sampling; ignored when None or <= 0.
        max_tokens: Generation budget (cast to int before use).

    Returns:
        (generated_text, annotated_video_path_or_None, annotated_frames,
        points_table_str).
    """
    # NOTE(review): no @spaces.GPU decorator here either — confirm ZeroGPU
    # allocation is handled elsewhere.
    if not video_path:
        return "Please upload a video.", None, [], "[]"
    # Build messages with per-video sampling kwargs.
    video_kwargs_msg = {
        "num_frames": int(max_frames),
        "frame_sample_mode": frame_sample_mode,
    }
    if max_fps is not None and max_fps > 0:
        video_kwargs_msg["max_fps"] = int(max_fps)
    messages = [
        {
            "role": "user",
            "content": [
                dict(type="text", text=user_text),
                dict(type="video", video=video_path, **video_kwargs_msg),
            ],
        }
    ]
    # Process vision info: returns (frames, metadata) pairs per video.
    _, videos, video_kwargs = process_vision_info(messages)
    videos, video_metadatas = zip(*videos)
    videos, video_metadatas = list(videos), list(video_metadatas)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(f"Prompt: {text}")
    inputs = processor(
        videos=videos,
        video_metadata=video_metadatas,
        text=text,
        padding=True,
        return_tensors="pt",
        return_pointing_metadata=True,  # needed by extract_video_points below
        **video_kwargs,
    )
    # Pull pointing metadata out before moving tensors to the model device.
    metadata = inputs.pop("metadata")
    inputs = {k: cast_float_bf16(v.to(model.device)) for k, v in inputs.items()}
    # Generate (greedy: temperature=0.0) under bf16 autocast.
    with torch.inference_mode():
        with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            output = model.generate(
                **inputs,
                logits_processor=model.build_logit_processor_from_inputs(inputs),
                max_new_tokens=int(max_tokens),
                temperature=0.0
            )
    # Strip the prompt tokens; decode only newly generated ones.
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    generated_text = processor.decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # Extract (object_id, timestamp, x, y) points from the generated text.
    points = model.extract_video_points(
        generated_text,
        metadata["token_pooling"],
        metadata["subpatch_mapping"],
        metadata["timestamps"],
        metadata["video_size"],
    )
    tracking = is_tracking_output(generated_text)
    annotated_video = None
    annotated_frames = []
    points_table = format_points_list(points, is_video=True)
    print(f"Output text: {generated_text}")
    print("Extracted points:", points_table)
    if points:
        print(f"Extracted {len(points)} points. Tracking={tracking}")
        # Build annotated frames on sampled video frames.
        if tracking:
            # Tracking: keep object ids so each instance gets its own color.
            group_by_time = defaultdict(list)
            for object_id, ts, x, y in points:
                group_by_time[ts].append((object_id, x, y))
            group_by_frame = defaultdict(list)
            for ts, pts_with_ids in group_by_time.items():
                # Map each timestamp to the nearest sampled frame index.
                ix = int(np.argmin(np.abs(metadata["timestamps"] - ts)))
                group_by_frame[ix] += pts_with_ids
            for ix, pts_with_ids in sorted(group_by_frame.items()):
                frame_img = draw_points_colored(videos[0][ix], pts_with_ids)
                ts = metadata["timestamps"][ix]
                annotated_frames.append((frame_img, f"t={ts:.2f}s"))
        else:
            # Pointing: ids are irrelevant; all points share one color.
            group_by_time = defaultdict(list)
            for object_id, ts, x, y in points:
                group_by_time[ts].append((x, y))
            group_by_frame = defaultdict(list)
            for ts, pts in group_by_time.items():
                # Map each timestamp to the nearest sampled frame index.
                ix = int(np.argmin(np.abs(metadata["timestamps"] - ts)))
                group_by_frame[ix] += pts
            for ix, pts in sorted(group_by_frame.items()):
                frame_img = draw_points(videos[0][ix], pts)
                ts = metadata["timestamps"][ix]
                annotated_frames.append((frame_img, f"t={ts:.2f}s"))
        # Annotated video with interpolation + trails on the original video.
        annotated_video = create_annotated_video(video_path, points, metadata, tracking)
    return generated_text, annotated_video, annotated_frames, points_table
# ── Gradio UI ──────────────────────────────────────────────────────────────────
# Read processor defaults for video settings (used to seed the UI controls).
_default_frame_sample_mode = processor.video_processor.frame_sample_mode
_default_max_frames = processor.video_processor.num_frames
# Custom CSS: centered layout, contained media, borderless gallery thumbnails.
css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
#main-title h1 {font-size: 2.3em !important;}
#input_image image {
    object-fit: contain !important;
}
#input_video video {
    object-fit: contain !important;
}
.gallery-item img {
    border: none !important;
    outline: none !important;
}
"""
# Custom CSS must be passed to the Blocks constructor — Blocks.launch() has no
# `css` parameter, so the original launch(css=...) call would fail.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# **MolmoPoint-8B Demo**", elem_id="main-title")
    gr.Markdown(
        "Image & video pointing and tracking using the "
        "[MolmoPoint-8B](https://huggingface.co/allenai/MolmoPoint-8B) pointing model."
    )
    with gr.Row():
        # ── LEFT COLUMN: Inputs ──
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.TabItem("Video Tracking", id="video_tracking_tab") as video_tracking_tab:
                    video_tracking = gr.Video(label="Input Video", elem_id="input_video", height=MAX_VIDEO_HEIGHT)
                with gr.TabItem("Video Pointing", id="video_pointing_tab") as video_pointing_tab:
                    video_pointing = gr.Video(label="Input Video", elem_id="input_video_pointing", height=MAX_VIDEO_HEIGHT)
                with gr.TabItem("Image(s) Pointing", id="image_tab") as image_tab:
                    images_input = gr.Gallery(
                        label="Input Images", elem_id="input_image", type="filepath", height=MAX_IMAGE_SIZE,
                    )
            input_text = gr.Textbox(placeholder="Enter the prompt", label="Input text")
            # Video sampling controls; hidden when the image tab is active.
            with gr.Row(visible=True) as video_params_row:
                frame_sample_mode = gr.Dropdown(choices=[_default_frame_sample_mode, "fps"], value=_default_frame_sample_mode, label="frame_sample_mode")
                max_frames = gr.Number(value=_default_max_frames, label="max_frames")
                max_fps = gr.Number(value=MAX_FPS, label="max_fps")
            max_tok_slider = gr.Slider(label="max_tokens", minimum=1, maximum=4096, step=1, value=MAX_NEW_TOKENS)
            with gr.Row():
                submit_button = gr.Button("Submit", variant="primary", scale=3)
                clear_all_button = gr.ClearButton(
                    components=[video_tracking, video_pointing, images_input, input_text], value="Clear All", scale=1,
                )
        # ── RIGHT COLUMN: Outputs ──
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Output Text"):
                    output_text = gr.Textbox(placeholder="Output text", label="Output text", lines=10)
                with gr.TabItem("Extracted Points"):
                    output_points = gr.Textbox(
                        label="Extracted Points ([[id, time/index, x, y]])", lines=15,
                    )
            # Inline warning banner (max_fps advice); hidden until needed.
            output_warning = gr.HTML(visible=False)
            with gr.Tabs(visible=True) as video_output_tabs:
                with gr.TabItem("Annotated Video"):
                    output_video = gr.Video(label="Annotated Video", height=MAX_VIDEO_HEIGHT)
                with gr.TabItem("Annotated Frames"):
                    gr.Markdown("*Click a frame to zoom in. Press Esc to go back.*")
                    output_annotations = gr.Gallery(label="Annotated Frames (Video)", height=MAX_IMAGE_SIZE)
            with gr.Group(visible=False) as image_output_group:
                gr.Markdown("*Click a frame to zoom in. Press Esc to go back.*")
                output_annotations_img = gr.Gallery(label="Annotated Images", height=MAX_IMAGE_SIZE)
    # ── Examples (one group per tab; visibility toggled on tab select) ──
    with gr.Group(visible=True) as video_tracking_examples_group:
        gr.Markdown("### Video Tracking Examples")
        gr.Examples(
            examples=[
                ["example-videos/us_canada_hockey.mp4", "Track the U.S. hockey players."],
                ["example-videos/penguins.mp4", "Track all the penguins."],
                ["example-videos/arena_basketball.mp4", "Track the players in yellow uniform in 1 fps."],
            ],
            inputs=[video_tracking, input_text],
            label="Video Tracking Examples",
        )
    with gr.Group(visible=False) as video_pointing_examples_group:
        gr.Markdown("### Video Pointing Examples")
        gr.Examples(
            examples=[
                ["example-videos/sports.mp4", "Point to players in white/blue jersey"],
                ["example-videos/travel.mp4", "Point to standalone spiky towers"],
            ],
            inputs=[video_pointing, input_text],
            label="Video Pointing Examples",
        )
    with gr.Group(visible=False) as image_examples_group:
        gr.Markdown("### Image Examples")
        gr.Examples(
            examples=[
                [["example-images/boat1.jpeg", "example-images/boat2.jpeg"], "Point to the boats."],
                [["example-images/messy1.jpg", "example-images/messy2.jpg", "example-images/messy3.jpg", "example-images/messy4.jpg"], "Point to the scissors."],
            ],
            inputs=[images_input, input_text],
            label="Image Pointing Examples",
        )
    # ── Tab switching: toggle visibility + track active tab ──
    active_tab = gr.State("video_tracking")

    def _select_video_tracking_tab():
        """Activate tracking tab: show tracking examples, preset max_fps=10."""
        return (
            "video_tracking",
            gr.update(value=10),       # max_fps
            gr.update(visible=True),   # video_tracking_examples_group
            gr.update(visible=False),  # video_pointing_examples_group
            gr.update(visible=False),  # image_examples_group
            gr.update(visible=True),   # video_params_row
            gr.update(visible=True),   # video_output_tabs
            gr.update(visible=False),  # image_output_group
        )

    def _select_video_pointing_tab():
        """Activate pointing tab: show pointing examples, preset max_fps=2."""
        return (
            "video_pointing",
            gr.update(value=2),        # max_fps
            gr.update(visible=False),  # video_tracking_examples_group
            gr.update(visible=True),   # video_pointing_examples_group
            gr.update(visible=False),  # image_examples_group
            gr.update(visible=True),   # video_params_row
            gr.update(visible=True),   # video_output_tabs
            gr.update(visible=False),  # image_output_group
        )

    def _select_image_tab():
        """Activate image tab: hide video controls/outputs, show image ones."""
        return (
            "image",
            gr.update(),               # max_fps unchanged
            gr.update(visible=False),  # video_tracking_examples_group
            gr.update(visible=False),  # video_pointing_examples_group
            gr.update(visible=True),   # image_examples_group
            gr.update(visible=False),  # video_params_row
            gr.update(visible=False),  # video_output_tabs
            gr.update(visible=True),   # image_output_group
        )

    tab_outputs = [
        active_tab, max_fps,
        video_tracking_examples_group, video_pointing_examples_group, image_examples_group,
        video_params_row, video_output_tabs, image_output_group,
    ]
    video_tracking_tab.select(fn=_select_video_tracking_tab, outputs=tab_outputs)
    video_pointing_tab.select(fn=_select_video_pointing_tab, outputs=tab_outputs)
    image_tab.select(fn=_select_image_tab, outputs=tab_outputs)

    _WARNING_STYLE = (
        'style="background:#fef2f2; border:1px solid #fca5a5; border-radius:6px; '
        'padding:8px 12px; color:#991b1b; font-size:14px;"'
    )

    def _fps_warning(generated_text, current_max_fps):
        """Return gr.update for the warning HTML block.

        Suggests the recommended max_fps (2 for pointing, 10 for tracking)
        when the model output type does not match the current setting.
        """
        tracking = "<tracks" in generated_text
        pointing = "<point" in generated_text
        if pointing and int(current_max_fps) != 2:
            html = f'<div {_WARNING_STYLE}>⚠️ For best video pointing results, set <b style="color:#991b1b">max_fps=2</b>.</div>'
            return gr.update(value=html, visible=True)
        if tracking and int(current_max_fps) != 10:
            html = f'<div {_WARNING_STYLE}>⚠️ For best tracking results, set <b style="color:#991b1b">max_fps=10</b>.</div>'
            return gr.update(value=html, visible=True)
        return gr.update(value="", visible=False)

    def dispatch_submit(tab, user_text, video_tracking_path, video_pointing_path,
                        input_images, fsm, mf, mfps, max_tok):
        """Route a submit click to image or video inference based on the active tab.

        Returns values for (output_text, output_points, output_warning,
        output_video, output_annotations, output_annotations_img).
        """
        if tab == "image":
            text_out, img_gallery, pts = process_images(user_text, input_images, max_tok)
            return text_out, pts, gr.update(value="", visible=False), None, [], img_gallery
        else:
            video_path = video_tracking_path if tab == "video_tracking" else video_pointing_path
            text_out, ann_video, ann_frames, pts = process_video(
                user_text, video_path, fsm, mf, mfps, max_tok,
            )
            warning = _fps_warning(text_out, mfps)
            return text_out, pts, warning, ann_video, ann_frames, []

    submit_button.click(
        fn=dispatch_submit,
        inputs=[active_tab, input_text, video_tracking, video_pointing, images_input,
                frame_sample_mode, max_frames, max_fps, max_tok_slider],
        outputs=[output_text, output_points, output_warning, output_video, output_annotations, output_annotations_img],
    )

if __name__ == "__main__":
    # css is applied via gr.Blocks(css=...) above; launch() does not accept it.
    demo.launch(mcp_server=True, ssr_mode=False, show_error=True, share=True)