import colorsys
import gc
import tempfile
import re
import json
import ast
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from typing import Iterator
from gradio.themes import Soft
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID_V = "prithivMLmods/Qwen3-VL-4B-Instruct-Unredacted-MAX"
DTYPE = torch.bfloat16
print(f"Loading {MODEL_ID_V}...")
processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
model_v = Qwen3VLForConditionalGeneration.from_pretrained(
MODEL_ID_V, attn_implementation="kernels-community/flash-attn3", trust_remote_code=True, torch_dtype=DTYPE
).to(device).eval()
print("Model loaded successfully.")
MAX_SECONDS = 5.0
SYSTEM_PROMPT = """You are a helpful assistant to detect objects in images. When asked to detect elements based on a description you return bounding boxes for all elements in the form of [xmin, ymin, xmax, ymax] with the values being scaled between 0 and 1000. When there are more than one result, answer with a list of bounding boxes in the form of [[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax], ...]."""
POINT_SYSTEM_PROMPT = """You are a precise object pointing assistant. When asked to point to an object in an image, you must return ONLY the exact center coordinates of that specific object as [x, y] with values scaled between 0 and 1000 (where 0,0 is the top-left corner and 1000,1000 is the bottom-right corner).
Rules:
1. ONLY point to objects that exactly match the description given.
2. Do NOT point to background, empty areas, or unrelated objects.
3. If there are multiple matching instances, return [[x1, y1], [x2, y2], ...].
4. If no matching object is found, return an empty list [].
5. Return ONLY the coordinate numbers, no explanations or other text.
6. Be extremely precise — place the point at the exact visual center of each matching object."""
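# Fallback pattern for loose "x, y" pairs in model output, optionally prefixed by an
# index such as "1." or "2:"; used only when no JSON-style list can be parsed.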
POINTS_REGEX = re.compile(r'(?:(\d+)\s*[.:])?\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)')
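# Decode every frame of the video with OpenCV and report the frame count and fps.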
def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
cap = cv2.VideoCapture(video_path_or_url)
frames = []
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
fps_val = cap.get(cv2.CAP_PROP_FPS)
cap.release()
return frames, {"num_frames": len(frames), "fps": float(fps_val) if fps_val > 0 else None}
def parse_bboxes_from_text(text: str) -> list[list[float]]:
text = re.sub(r'<think>.*?</think>', '', text.strip(), flags=re.DOTALL)
nested = re.findall(r'\[\s*\[[\d\s,\.]+\](?:\s*,\s*\[[\d\s,\.]+\])*\s*\]', text)
if nested:
try:
all_b = []
for m in nested:
parsed = json.loads(m)
all_b.extend(parsed if isinstance(parsed[0], list) else [parsed])
return all_b
except (json.JSONDecodeError, IndexError):
pass
single = re.findall(
r'\[\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]', text)
if single:
return [[float(v) for v in m] for m in single]
nums = re.findall(r'(\d+(?:\.\d+)?)', text)
    if len(nums) < 4:
        return []
    return [[float(nums[i]), float(nums[i + 1]), float(nums[i + 2]), float(nums[i + 3])]
            for i in range(0, len(nums) - 3, 4)]
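# Parse [x, y] point predictions (0-1000 scale), convert them to pixel coordinates, and
# drop out-of-range points as well as near-duplicates closer than 15 px to each other.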
def parse_precise_points(text: str, image_w: int, image_h: int) -> list[tuple[float, float]]:
text = re.sub(r'<think>.*?</think>', '', text.strip(), flags=re.DOTALL)
raw_points = []
nested = re.findall(r'\[\s*\[[\d\s,\.]+\](?:\s*,\s*\[[\d\s,\.]+\])*\s*\]', text)
if nested:
try:
for m in nested:
parsed = json.loads(m)
if isinstance(parsed[0], list):
for p in parsed:
if len(p) >= 2:
raw_points.append((float(p[0]), float(p[1])))
elif len(parsed) >= 2:
raw_points.append((float(parsed[0]), float(parsed[1])))
except (json.JSONDecodeError, IndexError):
pass
if not raw_points:
single = re.findall(r'\[\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]', text)
if single:
for m in single:
raw_points.append((float(m[0]), float(m[1])))
if not raw_points:
for match in POINTS_REGEX.finditer(text):
raw_points.append((float(match.group(2)), float(match.group(3))))
validated = []
for sx, sy in raw_points:
if not (0 <= sx <= 1000 and 0 <= sy <= 1000):
continue
px = sx / 1000 * image_w
py = sy / 1000 * image_h
if 0 <= px <= image_w and 0 <= py <= image_h:
validated.append((px, py))
if len(validated) > 1:
deduped = [validated[0]]
for pt in validated[1:]:
if all(((pt[0] - ex[0]) ** 2 + (pt[1] - ex[1]) ** 2) ** 0.5 >= 15 for ex in deduped):
deduped.append(pt)
validated = deduped
return validated
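# Rasterize a bbox given in 0-1000 normalized coordinates into a binary float mask at the
# frame's pixel resolution.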
def bbox_to_mask(bbox_scaled: list[float], width: int, height: int) -> np.ndarray:
mask = np.zeros((height, width), dtype=np.float32)
x1 = max(0, min(int(bbox_scaled[0] / 1000 * width), width - 1))
y1 = max(0, min(int(bbox_scaled[1] / 1000 * height), height - 1))
x2 = max(0, min(int(bbox_scaled[2] / 1000 * width), width - 1))
y2 = max(0, min(int(bbox_scaled[3] / 1000 * height), height - 1))
mask[y1:y2, x1:x2] = 1.0
return mask
def bbox_iou(b1, b2):
x1 = max(b1[0], b2[0])
y1 = max(b1[1], b2[1])
x2 = min(b1[2], b2[2])
y2 = min(b1[3], b2[3])
inter = max(0, x2 - x1) * max(0, y2 - y1)
union = (b1[2] - b1[0]) * (b1[3] - b1[1]) + (b2[2] - b2[0]) * (b2[3] - b2[1]) - inter
return inter / union if union > 0 else 0.0
def bbox_center_distance(b1, b2):
c1 = ((b1[0] + b1[2]) / 2, (b1[1] + b1[3]) / 2)
c2 = ((b2[0] + b2[2]) / 2, (b2[1] + b2[3]) / 2)
return ((c1[0] - c2[0]) ** 2 + (c1[1] - c2[1]) ** 2) ** 0.5
def pixel_point_distance(p1, p2):
return ((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) ** 0.5
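# Alpha-blend each object's mask onto the frame using its assigned color.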
def overlay_masks_on_frame(frame: Image.Image, masks: dict, colors_map: dict, alpha=0.5) -> Image.Image:
base = np.array(frame).astype(np.float32) / 255
overlay = base.copy()
for oid, mask in masks.items():
if mask is None:
continue
color = np.array(colors_map.get(oid, (255, 0, 0)), dtype=np.float32) / 255
if mask.ndim == 3:
mask = mask.squeeze()
m = np.clip(mask, 0, 1)[..., None]
overlay = (1 - alpha * m) * overlay + (alpha * m) * color
return Image.fromarray(np.clip(overlay * 255, 0, 255).astype(np.uint8))
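# Derive a stable pastel color from the prompt text so the same prompt always maps to the
# same hue across frames and objects.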
def pastel_color_for_prompt(prompt: str):
hue = (sum(ord(c) for c in prompt) * 2654435761 % 360) / 360
r, g, b = colorsys.hsv_to_rgb(hue, 0.5, 0.95)
return int(r * 255), int(g * 255), int(b * 255)
def get_font(image_height: int):
font_size = max(10, int(13 * image_height / 720))
try:
font_paths = [
"/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
"/System/Library/Fonts/Helvetica.ttc",
"arial.ttf",
]
for fp in font_paths:
try:
return ImageFont.truetype(fp, font_size), font_size
except OSError:
continue
except Exception:
pass
return ImageFont.load_default(), 13
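# Single-frame grounding call: ask the model for all boxes matching the prompt (greedy
# decoding, up to 512 new tokens) and parse the textual response into bounding boxes.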
def detect_objects_in_frame(frame: Image.Image, prompt: str) -> list[list[float]]:
messages = [
{"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
{"role": "user",
"content": [{"type": "image", "image": frame},
{"type": "text", "text": f"Detect all instances of: {prompt}"}]}
]
text = processor_v.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor_v(text=[text], images=[frame], padding=True, return_tensors="pt").to(device)
with torch.no_grad():
out = model_v.generate(**inputs, max_new_tokens=512, do_sample=False)
generated = out[:, inputs.input_ids.shape[1]:]
txt = processor_v.batch_decode(generated, skip_special_tokens=True)[0]
return parse_bboxes_from_text(txt)
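# Two-stage point detection: first request bounding boxes and use their centers (discarding
# tiny or near-full-frame boxes and centers closer than 20 px); if that yields nothing,
# fall back to asking directly for [x, y] points with the pointing system prompt.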
def detect_precise_points_in_frame(frame: Image.Image, prompt: str) -> list[tuple[float, float]]:
w, h = frame.size
messages = [
{"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
{"role": "user",
"content": [{"type": "image", "image": frame},
{"type": "text",
"text": f"Detect all instances of: {prompt}. Return only bounding boxes for objects that exactly match this description."}]}
]
text = processor_v.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor_v(text=[text], images=[frame], padding=True, return_tensors="pt").to(device)
with torch.no_grad():
out = model_v.generate(**inputs, max_new_tokens=512, do_sample=False)
generated = out[:, inputs.input_ids.shape[1]:]
txt = processor_v.batch_decode(generated, skip_special_tokens=True)[0]
bboxes = parse_bboxes_from_text(txt)
if bboxes:
points = []
for b in bboxes:
bw = abs(b[2] - b[0])
bh = abs(b[3] - b[1])
if bw < 5 or bh < 5:
continue
if bw > 950 and bh > 950:
continue
cx = (b[0] + b[2]) / 2 / 1000 * w
cy = (b[1] + b[3]) / 2 / 1000 * h
if 0 <= cx <= w and 0 <= cy <= h:
points.append((cx, cy))
if len(points) > 1:
deduped = [points[0]]
for pt in points[1:]:
if all(pixel_point_distance(pt, ex) >= 20 for ex in deduped):
deduped.append(pt)
points = deduped
if points:
return points
messages2 = [
{"role": "system", "content": [{"type": "text", "text": POINT_SYSTEM_PROMPT}]},
{"role": "user",
"content": [{"type": "image", "image": frame},
{"type": "text",
"text": f"Point to the exact center of each '{prompt}' in this image. Only point to objects that are clearly '{prompt}', nothing else."}]}
]
text2 = processor_v.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
inputs2 = processor_v(text=[text2], images=[frame], padding=True, return_tensors="pt").to(device)
with torch.no_grad():
out2 = model_v.generate(**inputs2, max_new_tokens=512, do_sample=False)
generated2 = out2[:, inputs2.input_ids.shape[1]:]
txt2 = processor_v.batch_decode(generated2, skip_special_tokens=True)[0]
return parse_precise_points(txt2, w, h)
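# Generic single-image inference used by the image tabs; returns the decoded text with any
# <think>...</think> block removed.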
def run_model_inference(image: Image.Image, prompt: str) -> str:
messages = [
{"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt},
]}
]
text = processor_v.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor_v(text=[text], images=[image], padding=True, return_tensors="pt").to(device)
with torch.no_grad():
out = model_v.generate(**inputs, max_new_tokens=512, do_sample=False)
generated = out[:, inputs.input_ids.shape[1]:]
result = processor_v.batch_decode(generated, skip_special_tokens=True)[0]
result = re.sub(r'<think>.*?</think>', '', result.strip(), flags=re.DOTALL).strip()
return result
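# Parse model output as JSON, tolerating markdown code fences; fall back to
# ast.literal_eval, and return an empty dict if both fail.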
def safe_parse_json(text: str):
text = text.strip()
text = re.sub(r"^```(json)?", "", text)
text = re.sub(r"```$", "", text)
text = text.strip()
try:
return json.loads(text)
except json.JSONDecodeError:
pass
try:
return ast.literal_eval(text)
except Exception:
return {}
def annotate_image_detection(image: Image.Image, result: dict) -> Image.Image:
if not isinstance(image, Image.Image) or not isinstance(result, dict):
return image
image = image.convert("RGB")
original_width, original_height = image.size
draw = ImageDraw.Draw(image)
font, font_size = get_font(original_height)
if "objects" in result and result["objects"]:
colors_list = [
(66, 133, 244), (234, 67, 53), (251, 188, 4), (52, 168, 83),
(255, 109, 0), (171, 71, 188), (0, 172, 193), (255, 82, 82),
(46, 125, 50), (121, 85, 72),
]
for idx, obj in enumerate(result["objects"]):
x_min = int(obj.get("x_min", 0.0) * original_width)
y_min = int(obj.get("y_min", 0.0) * original_height)
x_max = int(obj.get("x_max", 0.0) * original_width)
y_max = int(obj.get("y_max", 0.0) * original_height)
color = colors_list[idx % len(colors_list)]
draw.rectangle([(x_min, y_min), (x_max, y_max)], outline=color, width=3)
label = obj.get("label", f"Object {idx + 1}")
padding = max(2, int(4 * original_height / 720))
label_y = max(0, y_min - int(20 * original_height / 720))
tb = draw.textbbox((x_min, label_y), label, font=font)
draw.rectangle(
[(tb[0] - padding, tb[1] - padding), (tb[2] + padding, tb[3] + padding)],
fill=color
)
draw.text((x_min, label_y), label, fill="white", font=font)
return image
def annotate_image_points(image: Image.Image, result: dict) -> Image.Image:
if not isinstance(image, Image.Image) or not isinstance(result, dict):
return image
image = image.convert("RGB")
original_width, original_height = image.size
draw = ImageDraw.Draw(image)
if "points" in result and result["points"]:
for idx, p in enumerate(result["points"]):
px = int(p["x"] * original_width)
py = int(p["y"] * original_height)
r_outer = max(8, int(10 * original_height / 720))
r_inner = max(5, int(7 * original_height / 720))
r_dot = max(1, int(2 * original_height / 720))
draw.ellipse((px - r_outer, py - r_outer, px + r_outer, py + r_outer), outline="white", width=2)
draw.ellipse((px - r_inner, py - r_inner, px + r_inner, py + r_inner), fill=(255, 40, 40), outline=(255, 40, 40))
draw.ellipse((px - r_dot, py - r_dot, px + r_dot, py + r_dot), fill=(255, 200, 200))
return image
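# Per-session state containers kept in gr.State: loaded frames, per-frame masks/boxes or
# points, colors, prompts, and cached composited previews.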
class TrackingState:
def __init__(self):
self.reset()
def reset(self):
self.video_frames: list[Image.Image] = []
self.video_fps: float | None = None
self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
self.bboxes_by_frame: dict[int, dict[int, list[float]]] = {}
self.color_by_obj: dict[int, tuple[int, int, int]] = {}
self.color_by_prompt: dict[str, tuple[int, int, int]] = {}
self.text_prompts_by_frame_obj: dict[int, dict[int, str]] = {}
self.composited_frames: dict[int, Image.Image] = {}
self.prompts: dict[str, list[int]] = {}
self.next_obj_id: int = 1
self.current_frame_idx: int = 0
@property
def num_frames(self) -> int:
return len(self.video_frames)
class PointTrackingState:
def __init__(self):
self.reset()
def reset(self):
self.video_frames: list[Image.Image] = []
self.video_fps: float | None = None
self.points_by_frame: dict[int, list[tuple[float, float]]] = {}
self.trails: list[list[tuple[int, float, float]]] = []
self.composited_frames: dict[int, Image.Image] = {}
self.prompt_text: str = ""
self.current_frame_idx: int = 0
@property
def num_frames(self) -> int:
return len(self.video_frames)
def compose_tracking_frame(state: TrackingState, frame_idx: int) -> Image.Image:
if state is None or not state.video_frames:
return None
frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
frame = state.video_frames[frame_idx].copy()
w, h = frame.size
masks = state.masks_by_frame.get(frame_idx, {})
if masks:
frame = overlay_masks_on_frame(frame, masks, state.color_by_obj, alpha=0.5)
bboxes = state.bboxes_by_frame.get(frame_idx, {})
if bboxes:
draw = ImageDraw.Draw(frame)
font, font_size = get_font(h)
padding = max(2, int(4 * h / 720))
vert_offset = int(20 * h / 720)
for oid, bbox in bboxes.items():
color = state.color_by_obj.get(oid, (255, 255, 255))
x1 = int(bbox[0] / 1000 * w)
y1 = int(bbox[1] / 1000 * h)
x2 = int(bbox[2] / 1000 * w)
y2 = int(bbox[3] / 1000 * h)
draw.rectangle((x1, y1, x2, y2), outline=color, width=3)
prompt = state.text_prompts_by_frame_obj.get(frame_idx, {}).get(oid, "")
if prompt:
label = f"{prompt} - ID{oid}"
label_y = max(0, y1 - vert_offset)
tb = draw.textbbox((x1, label_y), label, font=font)
draw.rectangle(
[(tb[0] - padding, tb[1] - padding), (tb[2] + padding, tb[3] + padding)],
fill=color
)
draw.text((x1, label_y), label, fill="white", font=font)
state.composited_frames[frame_idx] = frame
return frame
def compose_point_frame(pt_state: PointTrackingState, frame_idx: int, trail_length: int = 12) -> Image.Image:
if pt_state is None or not pt_state.video_frames:
return None
frame_idx = int(np.clip(frame_idx, 0, len(pt_state.video_frames) - 1))
frame = pt_state.video_frames[frame_idx].copy()
draw = ImageDraw.Draw(frame)
RED = (255, 40, 40)
DARK_RED = (180, 0, 0)
for trail in pt_state.trails:
trail_pts = [(tx, ty) for fi, tx, ty in trail if fi <= frame_idx and fi > frame_idx - trail_length]
if len(trail_pts) >= 2:
for t_idx in range(len(trail_pts) - 1):
alpha_ratio = (t_idx + 1) / len(trail_pts)
trail_color = (
int(DARK_RED[0] * alpha_ratio),
int(DARK_RED[1] * alpha_ratio),
int(DARK_RED[2] * alpha_ratio)
)
thickness = max(1, int(2 * alpha_ratio))
x1t, y1t = int(trail_pts[t_idx][0]), int(trail_pts[t_idx][1])
x2t, y2t = int(trail_pts[t_idx + 1][0]), int(trail_pts[t_idx + 1][1])
draw.line([(x1t, y1t), (x2t, y2t)], fill=trail_color, width=thickness)
points_f = pt_state.points_by_frame.get(frame_idx, [])
for (px, py) in points_f:
draw.ellipse((px - 10, py - 10, px + 10, py + 10), outline="white", width=2)
draw.ellipse((px - 7, py - 7, px + 7, py + 7), fill=RED, outline=RED)
draw.ellipse((px - 2, py - 2, px + 2, py + 2), fill=(255, 200, 200))
pt_state.composited_frames[frame_idx] = frame
return frame
def update_tracking_display(state: TrackingState, frame_idx: int) -> Image.Image:
if state is None or not state.video_frames:
return None
frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
cached = state.composited_frames.get(frame_idx)
if cached is not None:
return cached
return compose_tracking_frame(state, frame_idx)
def update_point_display(pt_state: PointTrackingState, frame_idx: int) -> Image.Image:
if pt_state is None or not pt_state.video_frames:
return None
frame_idx = int(np.clip(frame_idx, 0, len(pt_state.video_frames) - 1))
cached = pt_state.composited_frames.get(frame_idx)
if cached is not None:
return cached
return compose_point_frame(pt_state, frame_idx)
def _get_active_prompts_tracking(state: TrackingState) -> str:
if state is None or not state.prompts:
return "**Active prompts:** None"
prompts_str = ", ".join([f"'{p}' ({len(ids)} obj)" for p, ids in state.prompts.items()])
return f"**Active prompts:** {prompts_str}"
def _get_active_prompts_points(pt_state: PointTrackingState) -> str:
if pt_state is None or not pt_state.prompt_text:
return "**Active prompts:** None"
return f"**Active prompts:** '{pt_state.prompt_text}' ({len(pt_state.trails)} tracked points)"
def init_tracking_video(state: TrackingState, video) -> tuple[TrackingState, int, int, Image.Image, str]:
state.reset()
if isinstance(video, dict):
path = video.get("name") or video.get("path") or video.get("data")
else:
path = video
if not path:
raise gr.Error("Invalid video input.")
frames, info = try_load_video_frames(path)
if not frames:
raise gr.Error("No frames could be loaded from the video.")
trimmed_note = ""
fps_in = info.get("fps")
max_frames_allowed = int(MAX_SECONDS * fps_in) if fps_in else len(frames)
if len(frames) > max_frames_allowed:
frames = frames[:max_frames_allowed]
trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
state.video_frames = frames
state.video_fps = float(fps_in) if fps_in else None
first_frame = frames[0]
max_idx = len(frames) - 1
status = f"Loaded {len(frames)} frames @ {state.video_fps or 'unknown'} fps{trimmed_note}. Ready for text prompting."
return state, 0, max_idx, first_frame, status
def init_point_video(pt_state: PointTrackingState, video) -> tuple[PointTrackingState, int, int, Image.Image, str]:
pt_state.reset()
if isinstance(video, dict):
path = video.get("name") or video.get("path") or video.get("data")
else:
path = video
if not path:
raise gr.Error("Invalid video input.")
frames, info = try_load_video_frames(path)
if not frames:
raise gr.Error("No frames could be loaded from the video.")
trimmed_note = ""
fps_in = info.get("fps")
max_frames_allowed = int(MAX_SECONDS * fps_in) if fps_in else len(frames)
if len(frames) > max_frames_allowed:
frames = frames[:max_frames_allowed]
trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
pt_state.video_frames = frames
pt_state.video_fps = float(fps_in) if fps_in else None
first_frame = frames[0]
max_idx = len(frames) - 1
status = f"Loaded {len(frames)} frames @ {pt_state.video_fps or 'unknown'} fps{trimmed_note}. Ready for point tracking."
return pt_state, 0, max_idx, first_frame, status
@spaces.GPU
def apply_tracking_prompt_on_frame(
state: TrackingState,
frame_idx: int,
text_prompt: str,
) -> tuple[Image.Image, str, str, TrackingState]:
if state is None or not state.video_frames:
return None, "Upload a video first.", "**Active prompts:** None", state
if not text_prompt or not text_prompt.strip():
ap = _get_active_prompts_tracking(state)
return update_tracking_display(state, int(frame_idx)), "Please enter a text prompt.", ap, state
frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
frame = state.video_frames[frame_idx]
w, h = frame.size
prompt_texts = [p.strip() for p in text_prompt.split(",") if p.strip()]
if not prompt_texts:
ap = _get_active_prompts_tracking(state)
return update_tracking_display(state, frame_idx), "Please enter a valid text prompt.", ap, state
status_parts = [f"Processing on frame {frame_idx}:"]
for prompt in prompt_texts:
bboxes = detect_objects_in_frame(frame, prompt)
if prompt not in state.color_by_prompt:
state.color_by_prompt[prompt] = pastel_color_for_prompt(prompt)
masks_f = state.masks_by_frame.setdefault(frame_idx, {})
bboxes_f = state.bboxes_by_frame.setdefault(frame_idx, {})
texts_f = state.text_prompts_by_frame_obj.setdefault(frame_idx, {})
obj_ids_for_prompt = []
for bbox in bboxes:
oid = state.next_obj_id
state.next_obj_id += 1
state.color_by_obj[oid] = state.color_by_prompt[prompt]
masks_f[oid] = bbox_to_mask(bbox, w, h)
bboxes_f[oid] = bbox
texts_f[oid] = prompt
state.prompts.setdefault(prompt, []).append(oid)
obj_ids_for_prompt.append(oid)
if obj_ids_for_prompt:
ids_str = ", ".join(map(str, obj_ids_for_prompt))
status_parts.append(f" • '{prompt}': {len(obj_ids_for_prompt)} object(s) (IDs: {ids_str})")
else:
status_parts.append(f" • '{prompt}': No objects detected.")
state.composited_frames.pop(frame_idx, None)
status = "\n".join(status_parts)
ap = _get_active_prompts_tracking(state)
return update_tracking_display(state, frame_idx), status, ap, state
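# Propagate each prompt's detections from its seed frame forward and backward: boxes are
# re-detected on every frame and matched to existing tracks greedily by IoU (> 0.05), then
# by box-center distance (< 300 in 0-1000 units); unmatched detections spawn new object
# IDs, while unmatched tracks are dropped for that frame.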
@spaces.GPU
def propagate_tracking(state: TrackingState) -> Iterator[tuple[TrackingState, str, dict]]:
if state is None or not state.video_frames:
yield state, "Load a video first.", gr.update()
return
if not state.prompts:
yield state, "No prompts defined. Apply text prompt(s) on a frame first.", gr.update()
return
total = state.num_frames
processed = 0
yield state, f"Propagating: {processed}/{total}", gr.update()
for prompt, obj_ids in list(state.prompts.items()):
seed_frame_idx = None
seed_bboxes_by_oid = {}
for f_idx in sorted(state.bboxes_by_frame.keys()):
for oid in obj_ids:
if oid in state.bboxes_by_frame.get(f_idx, {}):
if seed_frame_idx is None:
seed_frame_idx = f_idx
if f_idx == seed_frame_idx:
seed_bboxes_by_oid[oid] = state.bboxes_by_frame[f_idx][oid]
if seed_frame_idx is None:
continue
# Forward propagation
prev_tracks = [(oid, seed_bboxes_by_oid[oid]) for oid in seed_bboxes_by_oid]
for f_idx in range(seed_frame_idx + 1, total):
frame = state.video_frames[f_idx]
w, h = frame.size
new_bboxes = detect_objects_in_frame(frame, prompt)
masks_f = state.masks_by_frame.setdefault(f_idx, {})
bboxes_f = state.bboxes_by_frame.setdefault(f_idx, {})
texts_f = state.text_prompts_by_frame_obj.setdefault(f_idx, {})
used = set()
matched = {}
scores = [
(bbox_iou(pbbox, nbbox), pi, ni)
for pi, (_, pbbox) in enumerate(prev_tracks)
for ni, nbbox in enumerate(new_bboxes)
]
scores.sort(reverse=True)
for score, pi, ni in scores:
if pi in matched or ni in used or score <= 0.05:
continue
matched[pi] = ni
used.add(ni)
for pi, (_, pbbox) in enumerate(prev_tracks):
if pi in matched:
continue
best = min(
((bbox_center_distance(pbbox, nbbox), ni) for ni, nbbox in enumerate(new_bboxes) if ni not in used),
default=(float('inf'), -1)
)
if best[0] < 300:
matched[pi] = best[1]
used.add(best[1])
new_prev = []
for pi, (oid, _) in enumerate(prev_tracks):
if pi in matched:
nbbox = new_bboxes[matched[pi]]
masks_f[oid] = bbox_to_mask(nbbox, w, h)
bboxes_f[oid] = nbbox
texts_f[oid] = prompt
new_prev.append((oid, nbbox))
for ni, nbbox in enumerate(new_bboxes):
if ni not in used:
oid = state.next_obj_id
state.next_obj_id += 1
state.color_by_obj[oid] = state.color_by_prompt.get(prompt, pastel_color_for_prompt(prompt))
masks_f[oid] = bbox_to_mask(nbbox, w, h)
bboxes_f[oid] = nbbox
texts_f[oid] = prompt
state.prompts.setdefault(prompt, []).append(oid)
new_prev.append((oid, nbbox))
prev_tracks = new_prev
state.composited_frames.pop(f_idx, None)
processed += 1
if processed % 5 == 0 or f_idx == total - 1:
yield state, f"Propagating '{prompt}' (forward): frame {f_idx}/{total}", gr.update(value=f_idx)
# Backward propagation
prev_tracks = [(oid, seed_bboxes_by_oid[oid]) for oid in seed_bboxes_by_oid]
for f_idx in range(seed_frame_idx - 1, -1, -1):
frame = state.video_frames[f_idx]
w, h = frame.size
new_bboxes = detect_objects_in_frame(frame, prompt)
masks_f = state.masks_by_frame.setdefault(f_idx, {})
bboxes_f = state.bboxes_by_frame.setdefault(f_idx, {})
texts_f = state.text_prompts_by_frame_obj.setdefault(f_idx, {})
used = set()
matched = {}
scores = [
(bbox_iou(pbbox, nbbox), pi, ni)
for pi, (_, pbbox) in enumerate(prev_tracks)
for ni, nbbox in enumerate(new_bboxes)
]
scores.sort(reverse=True)
for score, pi, ni in scores:
if pi in matched or ni in used or score <= 0.05:
continue
matched[pi] = ni
used.add(ni)
for pi, (_, pbbox) in enumerate(prev_tracks):
if pi in matched:
continue
best = min(
((bbox_center_distance(pbbox, nbbox), ni) for ni, nbbox in enumerate(new_bboxes) if ni not in used),
default=(float('inf'), -1)
)
if best[0] < 300:
matched[pi] = best[1]
used.add(best[1])
new_prev = []
for pi, (oid, _) in enumerate(prev_tracks):
if pi in matched:
nbbox = new_bboxes[matched[pi]]
masks_f[oid] = bbox_to_mask(nbbox, w, h)
bboxes_f[oid] = nbbox
texts_f[oid] = prompt
new_prev.append((oid, nbbox))
for ni, nbbox in enumerate(new_bboxes):
if ni not in used:
oid = state.next_obj_id
state.next_obj_id += 1
state.color_by_obj[oid] = state.color_by_prompt.get(prompt, pastel_color_for_prompt(prompt))
masks_f[oid] = bbox_to_mask(nbbox, w, h)
bboxes_f[oid] = nbbox
texts_f[oid] = prompt
state.prompts.setdefault(prompt, []).append(oid)
new_prev.append((oid, nbbox))
prev_tracks = new_prev
state.composited_frames.pop(f_idx, None)
processed += 1
if processed % 5 == 0 or f_idx == 0:
yield state, f"Propagating '{prompt}' (backward): frame {f_idx}/{total}", gr.update(value=f_idx)
yield state, f"✅ Propagation complete across {total} frames for {len(state.prompts)} prompt(s).", gr.update(value=0)
@spaces.GPU
def apply_point_prompt_on_frame(
pt_state: PointTrackingState,
frame_idx: int,
text_prompt: str,
) -> tuple[Image.Image, str, str, PointTrackingState]:
if pt_state is None or not pt_state.video_frames:
return None, "Upload a video first.", "**Active prompts:** None", pt_state
if not text_prompt or not text_prompt.strip():
ap = _get_active_prompts_points(pt_state)
return update_point_display(pt_state, int(frame_idx)), "Please enter a text prompt.", ap, pt_state
frame_idx = int(np.clip(frame_idx, 0, len(pt_state.video_frames) - 1))
frame = pt_state.video_frames[frame_idx]
pt_state.prompt_text = text_prompt.strip()
pt_state.points_by_frame.clear()
pt_state.trails.clear()
pt_state.composited_frames.clear()
prompts = [p.strip() for p in text_prompt.split(",") if p.strip()]
all_points = []
status_parts = [f"Point detection on frame {frame_idx}:"]
for prompt in prompts:
points = detect_precise_points_in_frame(frame, prompt)
        for pt in points:
            all_points.append(pt)
            pt_state.trails.append([(frame_idx, pt[0], pt[1])])
status_parts.append(f" • '{prompt}': {len(points)} point(s)")
pt_state.points_by_frame[frame_idx] = all_points
pt_state.composited_frames.pop(frame_idx, None)
status = "\n".join(status_parts)
ap = _get_active_prompts_points(pt_state)
return update_point_display(pt_state, frame_idx), status, ap, pt_state
@spaces.GPU
def propagate_points(pt_state: PointTrackingState) -> Iterator[tuple[PointTrackingState, str, dict]]:
if pt_state is None or not pt_state.video_frames:
yield pt_state, "Load a video first.", gr.update()
return
if not pt_state.trails:
yield pt_state, "No points defined. Apply point prompt on a frame first.", gr.update()
return
if not pt_state.prompt_text:
yield pt_state, "No prompt text. Apply a text prompt first.", gr.update()
return
total = pt_state.num_frames
prompts = [p.strip() for p in pt_state.prompt_text.split(",") if p.strip()]
seed_frame_idx = None
for f_idx in sorted(pt_state.points_by_frame.keys()):
if pt_state.points_by_frame[f_idx]:
seed_frame_idx = f_idx
break
if seed_frame_idx is None:
yield pt_state, "No seed points found.", gr.update()
return
yield pt_state, f"Propagating points: 0/{total}", gr.update()
seed_tracks = []
for trail_idx, trail in enumerate(pt_state.trails):
for fi, tx, ty in trail:
if fi == seed_frame_idx:
seed_tracks.append((trail_idx, (tx, ty)))
break
prev_tracks = list(seed_tracks)
lost_count = {t[0]: 0 for t in prev_tracks}
for f_idx in range(seed_frame_idx + 1, total):
frame = pt_state.video_frames[f_idx]
w, h = frame.size
all_new_points = []
for prompt in prompts:
pts = detect_precise_points_in_frame(frame, prompt)
all_new_points.extend(pts)
points_f = []
diag = (w ** 2 + h ** 2) ** 0.5
match_threshold = diag * 0.25
if not all_new_points:
new_prev = []
for track_idx, prev_pt in prev_tracks:
lost_count[track_idx] = lost_count.get(track_idx, 0) + 1
if lost_count[track_idx] > 5:
continue
points_f.append(prev_pt)
pt_state.trails[track_idx].append((f_idx, prev_pt[0], prev_pt[1]))
new_prev.append((track_idx, prev_pt))
prev_tracks = new_prev
else:
used_new = set()
matched = {}
dist_pairs = []
for pi, (_, prev_pt) in enumerate(prev_tracks):
for ni, new_pt in enumerate(all_new_points):
d = pixel_point_distance(prev_pt, new_pt)
dist_pairs.append((d, pi, ni))
dist_pairs.sort()
for d, pi, ni in dist_pairs:
if pi in matched or ni in used_new:
continue
if d < match_threshold:
matched[pi] = ni
used_new.add(ni)
new_prev = []
for pi, (track_idx, prev_pt) in enumerate(prev_tracks):
if pi in matched:
new_pt = all_new_points[matched[pi]]
points_f.append(new_pt)
pt_state.trails[track_idx].append((f_idx, new_pt[0], new_pt[1]))
new_prev.append((track_idx, new_pt))
lost_count[track_idx] = 0
else:
lost_count[track_idx] = lost_count.get(track_idx, 0) + 1
if lost_count[track_idx] <= 5:
points_f.append(prev_pt)
pt_state.trails[track_idx].append((f_idx, prev_pt[0], prev_pt[1]))
new_prev.append((track_idx, prev_pt))
for ni, new_pt in enumerate(all_new_points):
if ni not in used_new:
too_close = any(pixel_point_distance(new_pt, pp) < diag * 0.08 for _, pp in new_prev)
if not too_close:
track_idx = len(pt_state.trails)
pt_state.trails.append([(f_idx, new_pt[0], new_pt[1])])
points_f.append(new_pt)
new_prev.append((track_idx, new_pt))
lost_count[track_idx] = 0
prev_tracks = new_prev
pt_state.points_by_frame[f_idx] = points_f
pt_state.composited_frames.pop(f_idx, None)
if (f_idx - seed_frame_idx) % 5 == 0 or f_idx == total - 1:
yield pt_state, f"Propagating points (forward): frame {f_idx}/{total}", gr.update(value=f_idx)
prev_tracks = list(seed_tracks)
lost_count = {t[0]: 0 for t in prev_tracks}
for f_idx in range(seed_frame_idx - 1, -1, -1):
frame = pt_state.video_frames[f_idx]
w, h = frame.size
all_new_points = []
for prompt in prompts:
pts = detect_precise_points_in_frame(frame, prompt)
all_new_points.extend(pts)
points_f = []
diag = (w ** 2 + h ** 2) ** 0.5
match_threshold = diag * 0.25
if not all_new_points:
new_prev = []
for track_idx, prev_pt in prev_tracks:
lost_count[track_idx] = lost_count.get(track_idx, 0) + 1
if lost_count[track_idx] > 5:
continue
points_f.append(prev_pt)
pt_state.trails[track_idx].append((f_idx, prev_pt[0], prev_pt[1]))
new_prev.append((track_idx, prev_pt))
prev_tracks = new_prev
else:
used_new = set()
matched = {}
dist_pairs = []
for pi, (_, prev_pt) in enumerate(prev_tracks):
for ni, new_pt in enumerate(all_new_points):
d = pixel_point_distance(prev_pt, new_pt)
dist_pairs.append((d, pi, ni))
dist_pairs.sort()
for d, pi, ni in dist_pairs:
if pi in matched or ni in used_new:
continue
if d < match_threshold:
matched[pi] = ni
used_new.add(ni)
new_prev = []
for pi, (track_idx, prev_pt) in enumerate(prev_tracks):
if pi in matched:
new_pt = all_new_points[matched[pi]]
points_f.append(new_pt)
pt_state.trails[track_idx].append((f_idx, new_pt[0], new_pt[1]))
new_prev.append((track_idx, new_pt))
lost_count[track_idx] = 0
else:
lost_count[track_idx] = lost_count.get(track_idx, 0) + 1
if lost_count[track_idx] <= 5:
points_f.append(prev_pt)
pt_state.trails[track_idx].append((f_idx, prev_pt[0], prev_pt[1]))
new_prev.append((track_idx, prev_pt))
for ni, new_pt in enumerate(all_new_points):
if ni not in used_new:
too_close = any(pixel_point_distance(new_pt, pp) < diag * 0.08 for _, pp in new_prev)
if not too_close:
track_idx = len(pt_state.trails)
pt_state.trails.append([(f_idx, new_pt[0], new_pt[1])])
points_f.append(new_pt)
new_prev.append((track_idx, new_pt))
lost_count[track_idx] = 0
prev_tracks = new_prev
pt_state.points_by_frame[f_idx] = points_f
pt_state.composited_frames.pop(f_idx, None)
if (seed_frame_idx - f_idx) % 5 == 0 or f_idx == 0:
yield pt_state, f"Propagating points (backward): frame {f_idx}/{total}", gr.update(value=f_idx)
yield pt_state, f"✅ Point propagation complete across {total} frames. {len(pt_state.trails)} tracks.", gr.update(value=seed_frame_idx)
def reset_tracking_prompts(state: TrackingState) -> tuple[TrackingState, Image.Image, str, str]:
if state is None:
return state, None, "No active session.", "**Active prompts:** None"
state.masks_by_frame.clear()
state.bboxes_by_frame.clear()
state.text_prompts_by_frame_obj.clear()
state.composited_frames.clear()
state.color_by_obj.clear()
state.color_by_prompt.clear()
state.prompts.clear()
state.next_obj_id = 1
current_idx = max(0, min(getattr(state, 'current_frame_idx', 0), state.num_frames - 1))
preview = update_tracking_display(state, current_idx)
return state, preview, "Prompts and outputs reset. Video preserved.", "**Active prompts:** None"
def reset_tracking_session(state: TrackingState) -> tuple[TrackingState, Image.Image, dict, dict, str, str]:
if not state.video_frames:
return state, None, gr.update(minimum=0, maximum=0), gr.update(value=0), "Session reset.", "**Active prompts:** None"
state.masks_by_frame.clear()
state.bboxes_by_frame.clear()
state.text_prompts_by_frame_obj.clear()
state.composited_frames.clear()
state.color_by_obj.clear()
state.color_by_prompt.clear()
state.prompts.clear()
state.next_obj_id = 1
gc.collect()
current_idx = max(0, min(getattr(state, 'current_frame_idx', 0), state.num_frames - 1))
preview = update_tracking_display(state, current_idx)
return (
state, preview,
gr.update(minimum=0, maximum=max(state.num_frames - 1, 0), interactive=True),
gr.update(value=current_idx),
"Session reset. Prompts cleared; video preserved.",
"**Active prompts:** None"
)
def reset_point_prompts(pt_state: PointTrackingState) -> tuple[PointTrackingState, Image.Image, str, str]:
if pt_state is None:
return pt_state, None, "No active session.", "**Active prompts:** None"
pt_state.points_by_frame.clear()
pt_state.trails.clear()
pt_state.composited_frames.clear()
pt_state.prompt_text = ""
current_idx = max(0, min(getattr(pt_state, 'current_frame_idx', 0), pt_state.num_frames - 1))
preview = update_point_display(pt_state, current_idx)
return pt_state, preview, "Point prompts reset. Video preserved.", "**Active prompts:** None"
def reset_point_session(pt_state: PointTrackingState) -> tuple[PointTrackingState, Image.Image, dict, dict, str, str]:
if not pt_state.video_frames:
return pt_state, None, gr.update(minimum=0, maximum=0), gr.update(value=0), "Session reset.", "**Active prompts:** None"
pt_state.points_by_frame.clear()
pt_state.trails.clear()
pt_state.composited_frames.clear()
pt_state.prompt_text = ""
gc.collect()
current_idx = max(0, min(getattr(pt_state, 'current_frame_idx', 0), pt_state.num_frames - 1))
preview = update_point_display(pt_state, current_idx)
return (
pt_state, preview,
gr.update(minimum=0, maximum=max(pt_state.num_frames - 1, 0), interactive=True),
gr.update(value=current_idx),
"Session reset. Video preserved.",
"**Active prompts:** None"
)
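# Render the composited frames to an MP4 (mp4v) at the source fps, falling back to 12 fps
# when the input fps is unknown.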
def render_tracking_video(state: TrackingState) -> str:
if state is None or state.num_frames == 0:
raise gr.Error("Load a video first.")
fps = state.video_fps if state.video_fps and state.video_fps > 0 else 12
frames_bgr = []
w, h = state.video_frames[0].size
for idx in range(state.num_frames):
img = state.composited_frames.get(idx)
if img is None:
img = compose_tracking_frame(state, idx)
frames_bgr.append(np.array(img)[:, :, ::-1])
if (idx + 1) % 60 == 0:
gc.collect()
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
writer = cv2.VideoWriter(tmp.name, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
for fr in frames_bgr:
writer.write(fr)
writer.release()
return tmp.name
def render_point_video(pt_state: PointTrackingState) -> str:
if pt_state is None or pt_state.num_frames == 0:
raise gr.Error("Load a video first.")
fps = pt_state.video_fps if pt_state.video_fps and pt_state.video_fps > 0 else 12
frames_bgr = []
w, h = pt_state.video_frames[0].size
for idx in range(pt_state.num_frames):
img = pt_state.composited_frames.get(idx)
if img is None:
img = compose_point_frame(pt_state, idx)
frames_bgr.append(np.array(img)[:, :, ::-1])
if (idx + 1) % 60 == 0:
gc.collect()
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
writer = cv2.VideoWriter(tmp.name, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
for fr in frames_bgr:
writer.write(fr)
writer.release()
return tmp.name
def _on_video_change_tracking(state: TrackingState, video) -> tuple[TrackingState, dict, Image.Image, str, str]:
if video is None:
return state, gr.update(), None, "", "**Active prompts:** None"
state, min_idx, max_idx, first_frame, status = init_tracking_video(state, video)
ap = _get_active_prompts_tracking(state)
return (
state,
gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
first_frame,
status,
ap,
)
def _on_video_change_points(pt_state: PointTrackingState, video) -> tuple[PointTrackingState, dict, Image.Image, str, str]:
if video is None:
return pt_state, gr.update(), None, "", "**Active prompts:** None"
pt_state, min_idx, max_idx, first_frame, status = init_point_video(pt_state, video)
ap = _get_active_prompts_points(pt_state)
return (
pt_state,
gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
first_frame,
status,
ap,
)
@spaces.GPU
def process_image_detection(image: Image.Image, prompt: str) -> tuple[Image.Image, str]:
if image is None:
raise gr.Error("Please upload an image.")
if not prompt or not prompt.strip():
raise gr.Error("Please provide a detection prompt.")
image = image.convert("RGB")
image.thumbnail((1024, 1024))
original_width, original_height = image.size
full_prompt = f"Provide bounding box coordinates for {prompt}. Report in JSON format."
output_text = run_model_inference(image, full_prompt)
parsed_json = safe_parse_json(output_text)
objects_result = {"objects": []}
if isinstance(parsed_json, list):
for item in parsed_json:
if "bbox_2d" in item and len(item["bbox_2d"]) == 4:
xmin, ymin, xmax, ymax = item["bbox_2d"]
label = item.get("label", "object")
objects_result["objects"].append({
"x_min": xmin / 1000.0,
"y_min": ymin / 1000.0,
"x_max": xmax / 1000.0,
"y_max": ymax / 1000.0,
"label": label,
})
elif isinstance(parsed_json, dict):
if "bbox_2d" in parsed_json and len(parsed_json["bbox_2d"]) == 4:
xmin, ymin, xmax, ymax = parsed_json["bbox_2d"]
label = parsed_json.get("label", "object")
objects_result["objects"].append({
"x_min": xmin / 1000.0,
"y_min": ymin / 1000.0,
"x_max": xmax / 1000.0,
"y_max": ymax / 1000.0,
"label": label,
})
if not objects_result["objects"]:
bboxes = parse_bboxes_from_text(output_text)
        for bbox in bboxes:
objects_result["objects"].append({
"x_min": bbox[0] / 1000.0,
"y_min": bbox[1] / 1000.0,
"x_max": bbox[2] / 1000.0,
"y_max": bbox[3] / 1000.0,
"label": prompt.strip(),
})
annotated = annotate_image_detection(image.copy(), objects_result)
result_text = json.dumps(objects_result, indent=2) if objects_result["objects"] else f"No objects detected for '{prompt}'.\n\nRaw output:\n{output_text}"
return annotated, result_text
@spaces.GPU
def process_image_pointer(image: Image.Image, prompt: str) -> tuple[Image.Image, str]:
if image is None:
raise gr.Error("Please upload an image.")
if not prompt or not prompt.strip():
raise gr.Error("Please provide a pointing prompt.")
image = image.convert("RGB")
image.thumbnail((1024, 1024))
original_width, original_height = image.size
full_prompt = f"Provide 2d point coordinates for {prompt}. Report in JSON format."
output_text = run_model_inference(image, full_prompt)
parsed_json = safe_parse_json(output_text)
points_result = {"points": []}
if isinstance(parsed_json, list):
for item in parsed_json:
if "point_2d" in item and len(item["point_2d"]) == 2:
x, y = item["point_2d"]
points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0})
elif isinstance(parsed_json, dict):
if "point_2d" in parsed_json and len(parsed_json["point_2d"]) == 2:
x, y = parsed_json["point_2d"]
points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0})
if not points_result["points"]:
detected_points = parse_precise_points(output_text, original_width, original_height)
for px, py in detected_points:
points_result["points"].append({
"x": px / original_width,
"y": py / original_height,
})
if not points_result["points"]:
bboxes = parse_bboxes_from_text(output_text)
for bbox in bboxes:
cx = (bbox[0] + bbox[2]) / 2 / 1000.0
cy = (bbox[1] + bbox[3]) / 2 / 1000.0
points_result["points"].append({"x": cx, "y": cy})
annotated = annotate_image_points(image.copy(), points_result)
result_text = json.dumps(points_result, indent=2) if points_result["points"] else f"No points detected for '{prompt}'.\n\nRaw output:\n{output_text}"
return annotated, result_text
css = """
#col-container {
margin: 0 auto;
max-width: 900px;
}
#main-title h1 { font-size: 2.6em !important; }
"""
with gr.Blocks(theme=Soft(primary_hue="orange", secondary_hue="rose"), css=css) as demo:
gr.Markdown("# **Qwen3-VL-Video-Grounding**", elem_id="main-title")
gr.Markdown(
"""
Perform text-guided object tracking, point tracking, image detection, and image pointing with the Qwen3-VL multimodal model.
**Video tabs:** Upload a video → Select a frame → Apply text prompt(s) → Preview → Propagate → Render MP4.
**Image tabs:** Upload an image → Enter prompt → Get instant detection or pointing results.
Due to compute constraints, this app only supports stop-frame object detection or tracking on propagated frames. For dense-frame full video processing, please visit the [GitHub](https://github.com/PRITHIVSAKTHIUR/Qwen3-VL-Video-Grounding) page.
"""
)
tracking_state = gr.State(TrackingState())
point_state = gr.State(PointTrackingState())
with gr.Tabs() as main_tabs:
with gr.Tab("Video Object Tracking"):
with gr.Row():
with gr.Column():
gr.Markdown(
"""
**Quick start**
- **Load a video**: Upload your own or pick an example below.
- Select a frame and enter text description(s) to detect objects (e.g., "red car", "person"). Multiple prompts separated by commas.
- The text prompt detects all instances on the **selected frame only**.
"""
)
with gr.Column():
gr.Markdown(
"""
**Working with results**
- **Preview**: Use the slider to navigate frames and see the current masks/bboxes.
- **Propagate**: Click "Propagate across video" to track all defined objects through every frame.
- **Export**: Render an MP4 for smooth playback using the original video FPS.
"""
)
with gr.Row():
with gr.Column(scale=1):
video_in_tracking = gr.Video(label="Upload video", sources=["upload", "webcam"])
load_status_tracking = gr.Markdown(visible=True)
reset_btn_tracking = gr.Button("Reset Session", variant="secondary")
with gr.Column(scale=2):
preview_tracking = gr.Image(label="Preview")
with gr.Row():
frame_slider_tracking = gr.Slider(label="Frame", minimum=0, maximum=0, step=1, value=0)
with gr.Column(scale=0):
propagate_btn_tracking = gr.Button("Propagate across video", variant="primary")
propagate_status_tracking = gr.Markdown(visible=True)
with gr.Row():
text_prompt_tracking = gr.Textbox(
label="Text Prompt(s)",
placeholder="Enter text description(s) (e.g., 'person' or 'person, car, dog' for multiple)",
lines=2,
)
with gr.Column(scale=0):
apply_btn_tracking = gr.Button("Apply Text Prompt(s)", variant="primary")
reset_prompts_btn_tracking = gr.Button("Reset Prompts", variant="secondary")
active_prompts_tracking = gr.Markdown("**Active prompts:** None", visible=True)
text_status_tracking = gr.Markdown(visible=True)
with gr.Row():
render_btn_tracking = gr.Button("Render MP4 for smooth playback", variant="primary")
playback_video_tracking = gr.Video(label="Rendered Playback", interactive=False)
gr.Examples(
examples=[
["examples/1.mp4"],
["examples/2.mp4"],
["examples/3.mp4"],
],
inputs=[video_in_tracking],
label="Examples"
)
with gr.Tab("Video Points Tracker"):
with gr.Row():
with gr.Column():
gr.Markdown(
"""
**Quick start**
- **Load a video**: Upload your own or pick an example below.
- Select a frame and enter text description(s) (e.g., `person, ball`).
- The model locates the **center point** of each detected object on the **selected frame only**.
"""
)
with gr.Column():
gr.Markdown(
"""
**Working with results**
- **Preview**: Use the slider to see detected points and motion trails.
- **Propagate**: Click "Propagate across video" to track points through every frame.
- **Export**: Render an MP4 with red dot tracking and trails.
"""
)
with gr.Row():
with gr.Column(scale=1):
video_in_points = gr.Video(label="Upload video", sources=["upload", "webcam"])
load_status_points = gr.Markdown(visible=True)
reset_btn_points = gr.Button("Reset Session", variant="secondary")
with gr.Column(scale=2):
preview_points = gr.Image(label="Preview")
with gr.Row():
frame_slider_points = gr.Slider(label="Frame", minimum=0, maximum=0, step=1, value=0)
with gr.Column(scale=0):
propagate_btn_points = gr.Button("Propagate across video", variant="primary")
propagate_status_points = gr.Markdown(visible=True)
with gr.Row():
text_prompt_points = gr.Textbox(
label="Text Prompt(s)",
placeholder="Enter text description(s) (e.g., 'person' or 'person, ball' for multiple)",
lines=2,
)
with gr.Column(scale=0):
apply_btn_points = gr.Button("Apply Point Prompt(s)", variant="primary")
reset_prompts_btn_points = gr.Button("Reset Prompts", variant="secondary")
active_prompts_points = gr.Markdown("**Active prompts:** None", visible=True)
text_status_points = gr.Markdown(visible=True)
with gr.Row():
render_btn_points = gr.Button("Render MP4 for smooth playback", variant="primary")
playback_video_points = gr.Video(label="Rendered Playback", interactive=False)
gr.Examples(
examples=[
["examples/1.mp4"],
["examples/2.mp4"],
["examples/3.mp4"],
],
inputs=[video_in_points],
label="Examples"
)
with gr.Tab("Image Detection"):
with gr.Row():
with gr.Column():
gr.Markdown(
"""
**Image Object Detection**
- Upload an image and enter what you want to detect.
- The model returns bounding boxes around all matching objects.
- Results are displayed as colored boxes with labels.
"""
)
with gr.Column():
gr.Markdown(
"""
**Tips**
- Be specific: "red car" works better than just "car" if you want a specific one.
- You can detect multiple types: try "person", "headlight", "window", etc.
- The JSON output shows normalized coordinates (0-1 range).
"""
)
with gr.Row():
with gr.Column(scale=1):
img_det_input = gr.Image(type="pil", label="Upload Image", height=400)
img_det_prompt = gr.Textbox(
label="Detection Prompt",
placeholder="e.g., headlight, person, red car, laptop",
lines=2,
)
img_det_btn = gr.Button("Detect Objects", variant="primary")
with gr.Column(scale=1):
img_det_output = gr.Image(label="Detection Result", height=400)
img_det_text = gr.Textbox(label="Detection Output (JSON)", lines=10, interactive=False)
gr.Examples(
examples=[
["examples-images/5.jpg", "children"],
["examples-images/4.jpg", "headlight"],
["examples-images/3.jpg", "gun"],
["examples-images/1.jpg", "boat"],
],
inputs=[img_det_input, img_det_prompt],
label="Examples"
)
with gr.Tab("Image Pointer"):
with gr.Row():
with gr.Column():
gr.Markdown(
"""
**Image Point Detection**
- Upload an image and describe what you want to point to.
- The model returns precise center-point coordinates for each matching object.
- Results are displayed as red dots on the image.
"""
)
with gr.Column():
gr.Markdown(
"""
**Tips**
- Great for locating specific parts: "the gun held by the person", "nose of the dog".
- Multiple instances are supported: all matching objects get a point.
- The JSON output shows normalized coordinates (0-1 range).
"""
)
with gr.Row():
with gr.Column(scale=1):
img_pt_input = gr.Image(type="pil", label="Upload Image", height=400)
img_pt_prompt = gr.Textbox(
label="Pointing Prompt",
placeholder="e.g., the gun held by the person, nose of the dog",
lines=2,
)
img_pt_btn = gr.Button("Point to Objects", variant="primary")
with gr.Column(scale=1):
img_pt_output = gr.Image(label="Pointing Result", height=400)
img_pt_text = gr.Textbox(label="Points Output (JSON)", lines=10, interactive=False)
gr.Examples(
examples=[
["examples-images/5.jpg", "children who are out of focus and wearing a white T-shirt"],
["examples-images/3.jpg", "gun"],
["examples-images/4.jpg", "headlight"],
["examples-images/1.jpg", "boat"],
],
inputs=[img_pt_input, img_pt_prompt],
label="Examples"
)
video_in_tracking.change(
fn=_on_video_change_tracking,
inputs=[tracking_state, video_in_tracking],
outputs=[tracking_state, frame_slider_tracking, preview_tracking, load_status_tracking, active_prompts_tracking],
show_progress=True,
)
def _sync_tracking_frame(state_in: TrackingState, idx: int) -> Image.Image:
if state_in is not None:
state_in.current_frame_idx = int(idx)
return update_tracking_display(state_in, int(idx))
frame_slider_tracking.change(
fn=_sync_tracking_frame,
inputs=[tracking_state, frame_slider_tracking],
outputs=preview_tracking,
)
apply_btn_tracking.click(
fn=apply_tracking_prompt_on_frame,
inputs=[tracking_state, frame_slider_tracking, text_prompt_tracking],
outputs=[preview_tracking, text_status_tracking, active_prompts_tracking, tracking_state],
)
propagate_btn_tracking.click(
fn=propagate_tracking,
inputs=tracking_state,
outputs=[tracking_state, propagate_status_tracking, frame_slider_tracking],
)
reset_prompts_btn_tracking.click(
fn=reset_tracking_prompts,
inputs=tracking_state,
outputs=[tracking_state, preview_tracking, text_status_tracking, active_prompts_tracking],
)
reset_btn_tracking.click(
fn=reset_tracking_session,
inputs=tracking_state,
outputs=[tracking_state, preview_tracking, frame_slider_tracking, frame_slider_tracking, load_status_tracking, active_prompts_tracking],
)
render_btn_tracking.click(
fn=render_tracking_video,
inputs=tracking_state,
outputs=playback_video_tracking,
)
video_in_points.change(
fn=_on_video_change_points,
inputs=[point_state, video_in_points],
outputs=[point_state, frame_slider_points, preview_points, load_status_points, active_prompts_points],
show_progress=True,
)
def _sync_point_frame(state_in: PointTrackingState, idx: int) -> Image.Image:
if state_in is not None:
state_in.current_frame_idx = int(idx)
return update_point_display(state_in, int(idx))
frame_slider_points.change(
fn=_sync_point_frame,
inputs=[point_state, frame_slider_points],
outputs=preview_points,
)
apply_btn_points.click(
fn=apply_point_prompt_on_frame,
inputs=[point_state, frame_slider_points, text_prompt_points],
outputs=[preview_points, text_status_points, active_prompts_points, point_state],
)
propagate_btn_points.click(
fn=propagate_points,
inputs=point_state,
outputs=[point_state, propagate_status_points, frame_slider_points],
)
reset_prompts_btn_points.click(
fn=reset_point_prompts,
inputs=point_state,
outputs=[point_state, preview_points, text_status_points, active_prompts_points],
)
reset_btn_points.click(
fn=reset_point_session,
inputs=point_state,
outputs=[point_state, preview_points, frame_slider_points, frame_slider_points, load_status_points, active_prompts_points],
)
render_btn_points.click(
fn=render_point_video,
inputs=point_state,
outputs=playback_video_points,
)
img_det_btn.click(
fn=process_image_detection,
inputs=[img_det_input, img_det_prompt],
outputs=[img_det_output, img_det_text],
show_progress=True,
)
img_pt_btn.click(
fn=process_image_pointer,
inputs=[img_pt_input, img_pt_prompt],
outputs=[img_pt_output, img_pt_text],
show_progress=True,
)
demo.queue(api_open=False).launch(ssr_mode=False)