# NOTE: the three lines below are Hugging Face Space page scrape artifacts
# (author avatar caption, commit message, commit hash) — not Python code.
# jamepark3922's picture
# disappeared object appears in fresh trai
# 0b5792e
import functools
import math
import os
import tempfile
from collections import defaultdict
import cv2
import numpy as np
import PIL
import torch
from PIL import Image, ImageDraw, ImageFile
from transformers import AutoModelForImageTextToText, AutoProcessor
import gradio as gr
import spaces
from molmo_utils import process_vision_info
from typing import Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
# Allow arbitrarily large images and tolerate truncated files when decoding with PIL.
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
# ── Constants ──────────────────────────────────────────────────────────────────
MODEL_ID = "allenai/MolmoPoint-8B"
MAX_IMAGE_SIZE = 512  # display height (px) for image galleries in the UI
MAX_VIDEO_HEIGHT = 512  # display height (px) for video widgets in the UI
POINT_SIZE = 0.01  # point radius as a fraction of the longest image/frame side
KEYFRAME_HOLD_FRAMES = 3  # frames around a keyframe drawn as a solid (filled) dot
SHOW_TRAILS = True  # draw fading motion trails in the annotated video
MAX_NEW_TOKENS = 2048  # default generation budget
MAX_FPS = 10  # default max_fps for video frame sampling (tracking mode)
# Palette of distinct CSS-style RGB colors used to tell object instances apart.
COLORS = [
    "rgb(255, 100, 180)",
    "rgb(100, 180, 255)",
    "rgb(180, 255, 100)",
    "rgb(255, 180, 100)",
    "rgb(100, 255, 180)",
    "rgb(180, 100, 255)",
    "rgb(255, 255, 100)",
    "rgb(100, 255, 255)",
    "rgb(255, 120, 120)",
    "rgb(120, 255, 255)",
    "rgb(255, 255, 120)",
    "rgb(255, 120, 255)",
]
# ── Model loading ──────────────────────────────────────────────────────────────
print(f"Loading {MODEL_ID}...")
# Left padding so batched generation aligns the prompt ends correctly.
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    padding_side="left",
)
# bfloat16 weights, device_map="auto" to spread across available accelerators.
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    dtype="bfloat16",
    device_map="auto",
)
print("Model loaded successfully.")
# ── Helper functions ───────────────────────────────────────────────────────────
def _parse_rgb(color_str):
"""Parse 'rgb(r, g, b)' to (r, g, b) tuple."""
nums = color_str.replace("rgb(", "").replace(")", "").split(",")
return tuple(int(n.strip()) for n in nums)
# Precomputed BGR tuples for OpenCV drawing (cv2 expects BGR, not RGB).
# Parse each color string once and reverse the tuple, instead of parsing the
# same string three times as before.
COLORS_BGR = [_parse_rgb(c)[::-1] for c in COLORS]
def is_tracking_output(generated_text: str) -> bool:
    """Return True when the model output opens with a '<tracks' tag."""
    trimmed = generated_text.strip()
    return trimmed.startswith("<tracks")
def cast_float_bf16(t: torch.Tensor):
    """Cast floating-point tensors to bfloat16; return other dtypes unchanged."""
    return t.to(torch.bfloat16) if torch.is_floating_point(t) else t
def draw_points(image, points):
    """Draw uniformly-colored filled points on a copy of an image.

    Args:
        image: numpy array (converted via PIL.Image.fromarray) or a PIL image;
            the input is never modified in place.
        points: iterable of (x, y) pixel coordinates.

    Returns:
        A new PIL image with one filled circle per point.
    """
    if isinstance(image, np.ndarray):
        annotation = PIL.Image.fromarray(image)
    else:
        annotation = image.copy()
    draw = ImageDraw.Draw(annotation)
    w, h = annotation.size
    size = max(5, int(max(w, h) * POINT_SIZE))
    # All points share the first palette color (per-object coloring lives in
    # draw_points_colored).  Hoisted out of the loop: it never varies, and the
    # previous enumerate() index was unused.
    color = COLORS[0]
    for x, y in points:
        draw.ellipse((x - size, y - size, x + size, y + size), fill=color, outline=None)
    return annotation
def draw_points_colored(image, points_with_ids):
    """Draw points with per-instance-ID colors for tracking visualization."""
    if isinstance(image, np.ndarray):
        annotation = PIL.Image.fromarray(image)
    else:
        annotation = image.copy()
    draw = ImageDraw.Draw(annotation)
    width, height = annotation.size
    radius = max(5, int(max(width, height) * POINT_SIZE))
    for object_id, x, y in points_with_ids:
        # Instance IDs are 1-based; cycle through the palette.
        fill = COLORS[(object_id - 1) % len(COLORS)]
        draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=fill, outline=None)
    return annotation
def format_points_list(points, is_video=False):
    """Format extracted points as a flat Python list string.

    Video points render as [id, t.tt, x.x, y.y]; image points as
    [id, frame_index, x.x, y.y].  Empty input yields "[]".
    """
    if not points:
        return "[]"
    if is_video:
        rows = [
            "[{}, {:.2f}, {:.1f}, {:.1f}]".format(int(oid), float(ts), float(x), float(y))
            for oid, ts, x, y in points
        ]
    else:
        rows = [
            "[{}, {}, {:.1f}, {:.1f}]".format(int(oid), int(ix), float(x), float(y))
            for oid, ix, x, y in points
        ]
    return "[" + ", ".join(rows) + "]"
def _interpolate_keyframes(keyframes, total_frames, max_gap=None):
"""Linearly interpolate positions between keyframes.
keyframes: sorted list of (frame_idx, x, y)
max_gap: if set, skip interpolation (leave gap invisible) when two keyframes
are more than this many frames apart.
Returns dict {frame_idx: (x, y)} for every frame from first to last keyframe.
"""
if not keyframes:
return {}
positions = {}
for i in range(len(keyframes)):
f_idx, x, y = keyframes[i]
positions[f_idx] = (x, y)
if i + 1 < len(keyframes):
nf, nx, ny = keyframes[i + 1]
span = nf - f_idx
if span > 1 and (max_gap is None or span <= max_gap):
for t in range(1, span):
alpha = t / span
positions[f_idx + t] = (x + alpha * (nx - x), y + alpha * (ny - y))
return positions
def create_annotated_video(video_path, points, metadata, tracking):
    """Draw points on the original video with interpolation and fading trails.

    Args:
        video_path: path to the source video file.
        points: [(object_id, timestamp, x, y), ...] with coordinates in the
            processed frame space (metadata["video_size"]).
        metadata: pointing metadata dict; only "video_size" is read here.
        tracking: if True, color per object ID; otherwise one shared color.

    Returns:
        Path to the annotated .mp4 written to a temporary file.
    """
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    vid_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    vid_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Rescale model-space coordinates to the source video's pixel space.
    proc_w, proc_h = metadata["video_size"]
    scale_x = vid_w / proc_w
    scale_y = vid_h / proc_h
    # Build per-object keyframes: {obj_id: [(frame_idx, x, y), ...]}
    obj_keyframes = defaultdict(list)
    for object_id, ts, x, y in points:
        f_idx = int(round(float(ts) * fps))
        sx, sy = float(x) * scale_x, float(y) * scale_y
        obj_keyframes[int(object_id)].append((f_idx, sx, sy))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    obj_positions = {}
    obj_keyframe_set = {}
    max_gap_frames = int(fps)  # gaps > 1 second: no interpolation, object disappears
    for obj_id, kfs in obj_keyframes.items():
        kfs.sort(key=lambda k: k[0])
        obj_positions[obj_id] = _interpolate_keyframes(kfs, total_frames, max_gap=max_gap_frames)
        raw_kf = set(f_idx for f_idx, _, _ in kfs)
        # Hold each keyframe "solid" for a few frames on either side.
        obj_keyframe_set[obj_id] = set(
            f for kf in raw_kf for f in range(kf - KEYFRAME_HOLD_FRAMES, kf + KEYFRAME_HOLD_FRAMES + 1)
        )
    # mkstemp instead of the deprecated, race-prone mktemp: the file is created
    # atomically; close our descriptor and let cv2.VideoWriter reopen the path.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(out_path, fourcc, fps, (vid_w, vid_h))
    radius = max(5, int(max(vid_w, vid_h) * POINT_SIZE))
    trail_length = int(fps * 2)  # trail covers roughly the last 2 seconds
    obj_history = defaultdict(list)
    current_frame = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        for obj_id, positions in obj_positions.items():
            if current_frame in positions:
                px, py = positions[current_frame]
                # Clear trail if the object was absent in the previous frame (gap reappearance)
                if current_frame > 0 and (current_frame - 1) not in positions:
                    obj_history[obj_id] = []
                obj_history[obj_id].append((px, py))
                if len(obj_history[obj_id]) > trail_length:
                    obj_history[obj_id] = obj_history[obj_id][-trail_length:]
                if tracking:
                    color = COLORS_BGR[(obj_id - 1) % len(COLORS_BGR)]
                else:
                    color = COLORS_BGR[0]
                # Draw fading trail: older segments get dimmer color and thinner lines.
                trail = obj_history[obj_id]
                n_trail = len(trail)
                if SHOW_TRAILS and n_trail >= 2:
                    for i in range(n_trail - 1):
                        alpha = (i + 1) / n_trail
                        trail_color = tuple(int(c * alpha) for c in color)
                        thickness = max(1, int(radius * 0.6 * alpha))
                        pt1 = (int(trail[i][0]), int(trail[i][1]))
                        pt2 = (int(trail[i + 1][0]), int(trail[i + 1][1]))
                        cv2.line(frame, pt1, pt2, trail_color, thickness)
                # Solid on keyframes, outline-only on interpolated frames
                if current_frame in obj_keyframe_set[obj_id]:
                    cv2.circle(frame, (int(px), int(py)), radius, color, -1)
                    cv2.circle(frame, (int(px), int(py)), radius + 2, (255, 255, 255), 2)
                else:
                    cv2.circle(frame, (int(px), int(py)), radius, color, 2)
        out.write(frame)
        current_frame += 1
    cap.release()
    out.release()
    return out_path
# ── Inference functions ────────────────────────────────────────────────────────
@spaces.GPU
def process_images(user_text, input_images, max_tokens):
    """Run pointing inference on one or more uploaded images.

    Args:
        user_text: the user's text prompt.
        input_images: Gradio gallery items — file paths or (path, caption) tuples.
        max_tokens: generation budget (cast to int before use).

    Returns:
        (generated_text, gallery images — annotated when points were extracted,
         otherwise the originals — and the points-table string).
    """
    if not input_images:
        return "Please upload at least one image.", [], "[]"
    pil_images = []
    for img_path in input_images:
        # Gallery entries may be (path, caption) tuples; keep only the path.
        if isinstance(img_path, tuple):
            img_path = img_path[0]
        pil_images.append(Image.open(img_path).convert("RGB"))
    # Build messages
    content = [dict(type="text", text=user_text)]
    for img in pil_images:
        content.append(dict(type="image", image=img))
    messages = [{"role": "user", "content": content}]
    # Process inputs
    images, _, _ = process_vision_info(messages)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(f"Prompt: {text}")
    inputs = processor(
        images=images,
        text=text,
        padding=True,
        return_tensors="pt",
        return_pointing_metadata=True,
    )
    # Pointing metadata is consumed locally for point extraction, not by generate().
    metadata = inputs.pop("metadata")
    inputs = {k: cast_float_bf16(v.to(model.device)) for k, v in inputs.items()}
    # Generate (temperature=0.0 → greedy decoding)
    with torch.inference_mode():
        with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            output = model.generate(
                **inputs,
                logits_processor=model.build_logit_processor_from_inputs(inputs),
                max_new_tokens=int(max_tokens),
                temperature=0.0
            )
    # Slice off the prompt tokens; decode only the new continuation.
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    generated_text = processor.decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # Extract points
    points = model.extract_image_points(
        generated_text,
        metadata["token_pooling"],
        metadata["subpatch_mapping"],
        metadata["image_sizes"],
    )
    points_table = format_points_list(points, is_video=False)
    print(f"Output text: {generated_text}")
    print("Extracted points:", points_table)
    if points:
        # Group points by image index so each image gets its own overlay.
        group_by_index = defaultdict(list)
        for object_id, ix, x, y in points:
            group_by_index[ix].append((x, y))
        annotated = []
        for ix, pts in group_by_index.items():
            annotated.append(draw_points(images[ix], pts))
        return generated_text, annotated, points_table
    return generated_text, pil_images, points_table
@spaces.GPU
def process_video(user_text, video_path, frame_sample_mode, max_frames, max_fps, max_tokens):
    """Run pointing/tracking inference on an uploaded video.

    Args:
        user_text: the user's text prompt.
        video_path: path to the uploaded video file.
        frame_sample_mode: frame sampling strategy passed to the processor.
        max_frames: maximum number of frames to sample (cast to int).
        max_fps: maximum sampling rate; ignored when None or <= 0.
        max_tokens: generation budget (cast to int before use).

    Returns:
        (generated_text, annotated-video path or None,
         list of (frame, caption) gallery items, points-table string).
    """
    if not video_path:
        return "Please upload a video.", None, [], "[]"
    # Build messages
    video_kwargs_msg = {
        "num_frames": int(max_frames),
        "frame_sample_mode": frame_sample_mode,
    }
    if max_fps is not None and max_fps > 0:
        video_kwargs_msg["max_fps"] = int(max_fps)
    messages = [
        {
            "role": "user",
            "content": [
                dict(type="text", text=user_text),
                dict(type="video", video=video_path, **video_kwargs_msg),
            ],
        }
    ]
    # Process vision info
    _, videos, video_kwargs = process_vision_info(messages)
    # Each entry is (frames, metadata); split into two parallel lists.
    videos, video_metadatas = zip(*videos)
    videos, video_metadatas = list(videos), list(video_metadatas)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(f"Prompt: {text}")
    inputs = processor(
        videos=videos,
        video_metadata=video_metadatas,
        text=text,
        padding=True,
        return_tensors="pt",
        return_pointing_metadata=True,
        **video_kwargs,
    )
    # Pointing metadata is consumed locally for point extraction, not by generate().
    metadata = inputs.pop("metadata")
    inputs = {k: cast_float_bf16(v.to(model.device)) for k, v in inputs.items()}
    # Generate (temperature=0.0 → greedy decoding)
    with torch.inference_mode():
        with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            output = model.generate(
                **inputs,
                logits_processor=model.build_logit_processor_from_inputs(inputs),
                max_new_tokens=int(max_tokens),
                temperature=0.0
            )
    # Slice off the prompt tokens; decode only the new continuation.
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    generated_text = processor.decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # Extract points
    points = model.extract_video_points(
        generated_text,
        metadata["token_pooling"],
        metadata["subpatch_mapping"],
        metadata["timestamps"],
        metadata["video_size"],
    )
    tracking = is_tracking_output(generated_text)
    annotated_video = None
    annotated_frames = []
    points_table = format_points_list(points, is_video=True)
    print(f"Output text: {generated_text}")
    print("Extracted points:", points_table)
    if points:
        print(f"Extracted {len(points)} points. Tracking={tracking}")
        # Build annotated frames on sampled video frames
        if tracking:
            # Tracking output: keep object IDs so each instance gets its own color.
            group_by_time = defaultdict(list)
            for object_id, ts, x, y in points:
                group_by_time[ts].append((object_id, x, y))
            group_by_frame = defaultdict(list)
            for ts, pts_with_ids in group_by_time.items():
                # Map each timestamp to the nearest sampled frame index.
                ix = int(np.argmin(np.abs(metadata["timestamps"] - ts)))
                group_by_frame[ix] += pts_with_ids
            for ix, pts_with_ids in sorted(group_by_frame.items()):
                frame_img = draw_points_colored(videos[0][ix], pts_with_ids)
                ts = metadata["timestamps"][ix]
                annotated_frames.append((frame_img, f"t={ts:.2f}s"))
        else:
            # Pointing output: single shared color, IDs are not meaningful.
            group_by_time = defaultdict(list)
            for object_id, ts, x, y in points:
                group_by_time[ts].append((x, y))
            group_by_frame = defaultdict(list)
            for ts, pts in group_by_time.items():
                ix = int(np.argmin(np.abs(metadata["timestamps"] - ts)))
                group_by_frame[ix] += pts
            for ix, pts in sorted(group_by_frame.items()):
                frame_img = draw_points(videos[0][ix], pts)
                ts = metadata["timestamps"][ix]
                annotated_frames.append((frame_img, f"t={ts:.2f}s"))
        # Annotated video with interpolation + trails
        annotated_video = create_annotated_video(video_path, points, metadata, tracking)
    return generated_text, annotated_video, annotated_frames, points_table
# ── Gradio UI ──────────────────────────────────────────────────────────────────
# Read processor defaults for video settings
_default_frame_sample_mode = processor.video_processor.frame_sample_mode
_default_max_frames = processor.video_processor.num_frames
# Custom CSS for the app: centered layout, contained media, borderless galleries.
css = """
#col-container {
margin: 0 auto;
max-width: 960px;
}
#main-title h1 {font-size: 2.3em !important;}
#input_image image {
object-fit: contain !important;
}
#input_video video {
object-fit: contain !important;
}
.gallery-item img {
border: none !important;
outline: none !important;
}
"""
# Fix: css must be passed to gr.Blocks(css=...) — Blocks.launch() has no css
# parameter, so the previous gr.Blocks() left the stylesheet unapplied.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# **MolmoPoint-8B Demo**", elem_id="main-title")
    gr.Markdown(
        "Image & video pointing and tracking using the "
        "[MolmoPoint-8B](https://huggingface.co/allenai/MolmoPoint-8B) pointing model."
    )
    with gr.Row():
        # ── LEFT COLUMN: Inputs ──
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.TabItem("Video Tracking", id="video_tracking_tab") as video_tracking_tab:
                    video_tracking = gr.Video(label="Input Video", elem_id="input_video", height=MAX_VIDEO_HEIGHT)
                with gr.TabItem("Video Pointing", id="video_pointing_tab") as video_pointing_tab:
                    video_pointing = gr.Video(label="Input Video", elem_id="input_video_pointing", height=MAX_VIDEO_HEIGHT)
                with gr.TabItem("Image(s) Pointing", id="image_tab") as image_tab:
                    images_input = gr.Gallery(
                        label="Input Images", elem_id="input_image", type="filepath", height=MAX_IMAGE_SIZE,
                    )
            input_text = gr.Textbox(placeholder="Enter the prompt", label="Input text")
            with gr.Row(visible=True) as video_params_row:
                frame_sample_mode = gr.Dropdown(choices=[_default_frame_sample_mode, "fps"], value=_default_frame_sample_mode, label="frame_sample_mode")
                max_frames = gr.Number(value=_default_max_frames, label="max_frames")
                max_fps = gr.Number(value=MAX_FPS, label="max_fps")
            max_tok_slider = gr.Slider(label="max_tokens", minimum=1, maximum=4096, step=1, value=MAX_NEW_TOKENS)
            with gr.Row():
                submit_button = gr.Button("Submit", variant="primary", scale=3)
                clear_all_button = gr.ClearButton(
                    components=[video_tracking, video_pointing, images_input, input_text], value="Clear All", scale=1,
                )
        # ── RIGHT COLUMN: Outputs ──
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Output Text"):
                    output_text = gr.Textbox(placeholder="Output text", label="Output text", lines=10)
                with gr.TabItem("Extracted Points"):
                    output_points = gr.Textbox(
                        label="Extracted Points ([[id, time/index, x, y]])", lines=15,
                    )
            output_warning = gr.HTML(visible=False)
            with gr.Tabs(visible=True) as video_output_tabs:
                with gr.TabItem("Annotated Video"):
                    output_video = gr.Video(label="Annotated Video", height=MAX_VIDEO_HEIGHT)
                with gr.TabItem("Annotated Frames"):
                    gr.Markdown("*Click a frame to zoom in. Press Esc to go back.*")
                    output_annotations = gr.Gallery(label="Annotated Frames (Video)", height=MAX_IMAGE_SIZE)
            with gr.Group(visible=False) as image_output_group:
                gr.Markdown("*Click a frame to zoom in. Press Esc to go back.*")
                output_annotations_img = gr.Gallery(label="Annotated Images", height=MAX_IMAGE_SIZE)
    # ── Examples ──
    with gr.Group(visible=True) as video_tracking_examples_group:
        gr.Markdown("### Video Tracking Examples")
        gr.Examples(
            examples=[
                ["example-videos/us_canada_hockey.mp4", "Track the U.S. hockey players."],
                ["example-videos/penguins.mp4", "Track all the penguins."],
                ["example-videos/arena_basketball.mp4", "Track the players in yellow uniform in 1 fps."],
            ],
            inputs=[video_tracking, input_text],
            label="Video Tracking Examples",
        )
    with gr.Group(visible=False) as video_pointing_examples_group:
        gr.Markdown("### Video Pointing Examples")
        gr.Examples(
            examples=[
                ["example-videos/sports.mp4", "Point to players in white/blue jersey"],
                ["example-videos/travel.mp4", "Point to standalone spiky towers"],
            ],
            inputs=[video_pointing, input_text],
            label="Video Pointing Examples",
        )
    with gr.Group(visible=False) as image_examples_group:
        gr.Markdown("### Image Examples")
        gr.Examples(
            examples=[
                [["example-images/boat1.jpeg", "example-images/boat2.jpeg"], "Point to the boats."],
                [["example-images/messy1.jpg", "example-images/messy2.jpg", "example-images/messy3.jpg", "example-images/messy4.jpg"], "Point to the scissors."],
            ],
            inputs=[images_input, input_text],
            label="Image Pointing Examples",
        )
    # ── Tab switching: toggle visibility + track active tab ──
    active_tab = gr.State("video_tracking")

    def _select_video_tracking_tab():
        # Show tracking examples/params, hide image outputs; bump max_fps to 10.
        return (
            "video_tracking",
            gr.update(value=10),  # max_fps
            gr.update(visible=True),  # video_tracking_examples_group
            gr.update(visible=False),  # video_pointing_examples_group
            gr.update(visible=False),  # image_examples_group
            gr.update(visible=True),  # video_params_row
            gr.update(visible=True),  # video_output_tabs
            gr.update(visible=False),  # image_output_group
        )

    def _select_video_pointing_tab():
        # Show pointing examples/params, hide image outputs; drop max_fps to 2.
        return (
            "video_pointing",
            gr.update(value=2),  # max_fps
            gr.update(visible=False),  # video_tracking_examples_group
            gr.update(visible=True),  # video_pointing_examples_group
            gr.update(visible=False),  # image_examples_group
            gr.update(visible=True),  # video_params_row
            gr.update(visible=True),  # video_output_tabs
            gr.update(visible=False),  # image_output_group
        )

    def _select_image_tab():
        # Image mode: hide all video UI, show image examples and outputs.
        return (
            "image",
            gr.update(),  # max_fps unchanged
            gr.update(visible=False),  # video_tracking_examples_group
            gr.update(visible=False),  # video_pointing_examples_group
            gr.update(visible=True),  # image_examples_group
            gr.update(visible=False),  # video_params_row
            gr.update(visible=False),  # video_output_tabs
            gr.update(visible=True),  # image_output_group
        )

    tab_outputs = [
        active_tab, max_fps,
        video_tracking_examples_group, video_pointing_examples_group, image_examples_group,
        video_params_row, video_output_tabs, image_output_group,
    ]
    video_tracking_tab.select(fn=_select_video_tracking_tab, outputs=tab_outputs)
    video_pointing_tab.select(fn=_select_video_pointing_tab, outputs=tab_outputs)
    image_tab.select(fn=_select_image_tab, outputs=tab_outputs)

    _WARNING_STYLE = (
        'style="background:#fef2f2; border:1px solid #fca5a5; border-radius:6px; '
        'padding:8px 12px; color:#991b1b; font-size:14px;"'
    )

    def _fps_warning(generated_text, current_max_fps):
        """Return gr.update for the warning HTML block."""
        tracking = "<tracks" in generated_text
        pointing = "<point" in generated_text
        if pointing and int(current_max_fps) != 2:
            html = f'<div {_WARNING_STYLE}>⚠️ For best video pointing results, set <b style="color:#991b1b">max_fps=2</b>.</div>'
            return gr.update(value=html, visible=True)
        if tracking and int(current_max_fps) != 10:
            html = f'<div {_WARNING_STYLE}>⚠️ For best tracking results, set <b style="color:#991b1b">max_fps=10</b>.</div>'
            return gr.update(value=html, visible=True)
        return gr.update(value="", visible=False)

    def dispatch_submit(tab, user_text, video_tracking_path, video_pointing_path,
                        input_images, fsm, mf, mfps, max_tok):
        # Route the submit to the image or video pipeline based on the active tab.
        if tab == "image":
            text_out, img_gallery, pts = process_images(user_text, input_images, max_tok)
            return text_out, pts, gr.update(value="", visible=False), None, [], img_gallery
        else:
            video_path = video_tracking_path if tab == "video_tracking" else video_pointing_path
            text_out, ann_video, ann_frames, pts = process_video(
                user_text, video_path, fsm, mf, mfps, max_tok,
            )
            warning = _fps_warning(text_out, mfps)
            return text_out, pts, warning, ann_video, ann_frames, []

    submit_button.click(
        fn=dispatch_submit,
        inputs=[active_tab, input_text, video_tracking, video_pointing, images_input,
                frame_sample_mode, max_frames, max_fps, max_tok_slider],
        outputs=[output_text, output_points, output_warning, output_video, output_annotations, output_annotations_img],
    )
if __name__ == "__main__":
    # Fix: Blocks.launch() accepts no `css` keyword (it belongs to gr.Blocks);
    # passing it here raises a TypeError before the server starts.
    demo.launch(mcp_server=True, ssr_mode=False, show_error=True, share=True)