prithivMLmods's picture
update [from: test-sand-box] (cleaned) ✅
1ce5b62 verified
raw
history blame
44.5 kB
import colorsys
import gc
import tempfile
import re
import json
import uuid
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from typing import Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
from molmo_utils import process_vision_info
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID_V = "prithivMLmods/Qwen3-VL-4B-Instruct-Unredacted-MAX" # @--- Max model is trained on top of - Qwen/Qwen3-VL-4B-Instruct ---@
DTYPE = torch.float16
print(f"Loading {MODEL_ID_V}...")
processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
model_v = Qwen3VLForConditionalGeneration.from_pretrained(
MODEL_ID_V, trust_remote_code=True, torch_dtype=DTYPE
).to(device).eval()
print("Model loaded successfully.")
MAX_SECONDS = 8.0
SYSTEM_PROMPT = """You are a helpful assistant to detect objects in images. When asked to detect elements based on a description you return bounding boxes for all elements in the form of [xmin, ymin, xmax, ymax] with the values being scaled between 0 and 1000. When there are more than one result, answer with a list of bounding boxes in the form of [[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax], ...]."""
POINT_SYSTEM_PROMPT = """You are a precise object pointing assistant. When asked to point to an object in an image, you must return ONLY the exact center coordinates of that specific object as [x, y] with values scaled between 0 and 1000 (where 0,0 is the top-left corner and 1000,1000 is the bottom-right corner).
Rules:
1. ONLY point to objects that exactly match the description given.
2. Do NOT point to background, empty areas, or unrelated objects.
3. If there are multiple matching instances, return [[x1, y1], [x2, y2], ...].
4. If no matching object is found, return an empty list [].
5. Return ONLY the coordinate numbers, no explanations or other text.
6. Be extremely precise — place the point at the exact visual center of each matching object."""
POINTS_REGEX = re.compile(r'(?:(\d+)\s*[.:])?\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)')
COORD_REGEX = re.compile(r'\[([\s\S]*?)\]')
FRAME_REGEX = re.compile(r'(\d+(?:\.\d+)?)\s*[,:]\s*([\d\s,\.]+)')
class RadioAnimated(gr.HTML):
def __init__(self, choices, value=None, **kwargs):
if not choices or len(choices) < 2:
raise ValueError("RadioAnimated requires at least 2 choices.")
if value is None:
value = choices[0]
uid = uuid.uuid4().hex[:8]
group_name = f"ra-{uid}"
inputs_html = "\n".join(
f"""
<input class="ra-input" type="radio" name="{group_name}" id="{group_name}-{i}" value="{c}">
<label class="ra-label" for="{group_name}-{i}">{c}</label>
"""
for i, c in enumerate(choices)
)
html_template = f"""
<div class="ra-wrap" data-ra="{uid}">
<div class="ra-inner">
<div class="ra-highlight"></div>
{inputs_html}
</div>
</div>
"""
js_on_load = r"""
(() => {
const wrap = element.querySelector('.ra-wrap');
const inner = element.querySelector('.ra-inner');
const highlight = element.querySelector('.ra-highlight');
const inputs = Array.from(element.querySelectorAll('.ra-input'));
if (!inputs.length) return;
const choices = inputs.map(i => i.value);
function setHighlightByIndex(idx) {
const n = choices.length;
const pct = 100 / n;
highlight.style.width = `calc(${pct}% - 6px)`;
highlight.style.transform = `translateX(${idx * 100}%)`;
}
function setCheckedByValue(val, shouldTrigger=false) {
const idx = Math.max(0, choices.indexOf(val));
inputs.forEach((inp, i) => { inp.checked = (i === idx); });
setHighlightByIndex(idx);
props.value = choices[idx];
if (shouldTrigger) trigger('change', props.value);
}
setCheckedByValue(props.value ?? choices[0], false);
inputs.forEach((inp) => {
inp.addEventListener('change', () => {
setCheckedByValue(inp.value, true);
});
});
})();
"""
super().__init__(
value=value,
html_template=html_template,
js_on_load=js_on_load,
**kwargs
)
def apply_gpu_duration(val: str):
try:
return int(val)
except (TypeError, ValueError):
return 90
def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
cap = cv2.VideoCapture(video_path_or_url)
frames = []
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
fps_val = cap.get(cv2.CAP_PROP_FPS)
cap.release()
return frames, {"num_frames": len(frames), "fps": float(fps_val) if fps_val > 0 else None}
def parse_bboxes_from_text(text: str) -> list[list[float]]:
text = re.sub(r'<think>.*?</think>', '', text.strip(), flags=re.DOTALL)
nested = re.findall(r'\[\s*\[[\d\s,\.]+\](?:\s*,\s*\[[\d\s,\.]+\])*\s*\]', text)
if nested:
try:
all_b = []
for m in nested:
parsed = json.loads(m)
all_b.extend(parsed if isinstance(parsed[0], list) else [parsed])
return all_b
except (json.JSONDecodeError, IndexError):
pass
single = re.findall(
r'\[\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]', text)
if single:
return [[float(v) for v in m] for m in single]
nums = re.findall(r'(\d+(?:\.\d+)?)', text)
return [[float(nums[i]), float(nums[i + 1]), float(nums[i + 2]), float(nums[i + 3])] for i in
range(0, len(nums) - 3, 4)] if len(nums) >= 4 else []
def parse_precise_points(text: str, image_w: int, image_h: int) -> list[tuple[float, float]]:
text = re.sub(r'<think>.*?</think>', '', text.strip(), flags=re.DOTALL)
raw_points = []
nested = re.findall(r'\[\s*\[[\d\s,\.]+\](?:\s*,\s*\[[\d\s,\.]+\])*\s*\]', text)
if nested:
try:
for m in nested:
parsed = json.loads(m)
if isinstance(parsed[0], list):
for p in parsed:
if len(p) >= 2:
raw_points.append((float(p[0]), float(p[1])))
elif len(parsed) >= 2:
raw_points.append((float(parsed[0]), float(parsed[1])))
except (json.JSONDecodeError, IndexError):
pass
if not raw_points:
single = re.findall(
r'\[\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]', text)
if single:
for m in single:
raw_points.append((float(m[0]), float(m[1])))
if not raw_points:
for match in POINTS_REGEX.finditer(text):
x_val = float(match.group(2))
y_val = float(match.group(3))
raw_points.append((x_val, y_val))
validated = []
for sx, sy in raw_points:
if not (0 <= sx <= 1000 and 0 <= sy <= 1000):
continue
px = sx / 1000 * image_w
py = sy / 1000 * image_h
if 0 <= px <= image_w and 0 <= py <= image_h:
validated.append((px, py))
if len(validated) > 1:
deduped = [validated[0]]
for pt in validated[1:]:
is_dup = False
for existing in deduped:
dist = ((pt[0] - existing[0]) ** 2 + (pt[1] - existing[1]) ** 2) ** 0.5
if dist < 15:
is_dup = True
break
if not is_dup:
deduped.append(pt)
validated = deduped
return validated
def bbox_to_mask(bbox_scaled: list[float], width: int, height: int) -> np.ndarray:
mask = np.zeros((height, width), dtype=np.float32)
x1 = max(0, min(int(bbox_scaled[0] / 1000 * width), width - 1))
y1 = max(0, min(int(bbox_scaled[1] / 1000 * height), height - 1))
x2 = max(0, min(int(bbox_scaled[2] / 1000 * width), width - 1))
y2 = max(0, min(int(bbox_scaled[3] / 1000 * height), height - 1))
mask[y1:y2, x1:x2] = 1.0
return mask
def bbox_iou(b1, b2):
x1 = max(b1[0], b2[0])
y1 = max(b1[1], b2[1])
x2 = min(b1[2], b2[2])
y2 = min(b1[3], b2[3])
inter = max(0, x2 - x1) * max(0, y2 - y1)
union = (b1[2] - b1[0]) * (b1[3] - b1[1]) + (b2[2] - b2[0]) * (b2[3] - b2[1]) - inter
return inter / union if union > 0 else 0.0
def bbox_center_distance(b1, b2):
c1 = ((b1[0] + b1[2]) / 2, (b1[1] + b1[3]) / 2)
c2 = ((b2[0] + b2[2]) / 2, (b2[1] + b2[3]) / 2)
return ((c1[0] - c2[0]) ** 2 + (c1[1] - c2[1]) ** 2) ** 0.5
def pixel_point_distance(p1: tuple, p2: tuple) -> float:
return ((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) ** 0.5
def overlay_masks_on_frame(frame: Image.Image, masks: dict, colors_map: dict, alpha=0.45) -> Image.Image:
base = np.array(frame).astype(np.float32) / 255
overlay = base.copy()
for oid, mask in masks.items():
if mask is None:
continue
color = np.array(colors_map.get(oid, (255, 0, 0)), dtype=np.float32) / 255
m = np.clip(mask, 0, 1)[..., None]
overlay = (1 - alpha * m) * overlay + (alpha * m) * color
return Image.fromarray(np.clip(overlay * 255, 0, 255).astype(np.uint8))
def pastel_color_for_prompt(prompt: str):
hue = (sum(ord(c) for c in prompt) * 2654435761 % 360) / 360
r, g, b = colorsys.hsv_to_rgb(hue, 0.5, 0.95)
return int(r * 255), int(g * 255), int(b * 255)
class AppState:
def __init__(self):
self.reset()
def reset(self):
self.video_frames: list[Image.Image] = []
self.video_fps: float | None = None
self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
self.bboxes_by_frame: dict[int, dict[int, list[float]]] = {}
self.color_by_obj: dict[int, tuple[int, int, int]] = {}
self.color_by_prompt: dict[str, tuple[int, int, int]] = {}
self.text_prompts_by_frame_obj: dict[int, dict[int, str]] = {}
self.prompts: dict[str, list[int]] = {}
self.next_obj_id: int = 1
@property
def num_frames(self) -> int:
return len(self.video_frames)
class PointTrackerState:
def __init__(self):
self.reset()
def reset(self):
self.video_frames: list[Image.Image] = []
self.video_fps: float | None = None
self.points_by_frame: dict[int, list[tuple[float, float]]] = {}
self.trails: list[list[tuple[int, float, float]]] = []
@property
def num_frames(self) -> int:
return len(self.video_frames)
def detect_objects_in_frame(frame: Image.Image, prompt: str) -> list[list[float]]:
messages = [
{"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
{"role": "user",
"content": [{"type": "image", "image": frame}, {"type": "text", "text": f"Detect all instances of: {prompt}"}]}
]
text = processor_v.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor_v(text=[text], images=[frame], padding=True, return_tensors="pt").to(device)
with torch.no_grad():
out = model_v.generate(**inputs, max_new_tokens=512, do_sample=False)
generated = out[:, inputs.input_ids.shape[1]:]
txt = processor_v.batch_decode(generated, skip_special_tokens=True)[0]
return parse_bboxes_from_text(txt)
def detect_precise_points_in_frame(frame: Image.Image, prompt: str) -> list[tuple[float, float]]:
w, h = frame.size
messages = [
{"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
{"role": "user",
"content": [{"type": "image", "image": frame},
{"type": "text",
"text": f"Detect all instances of: {prompt}. Return only bounding boxes for objects that exactly match this description."}]}
]
text = processor_v.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor_v(text=[text], images=[frame], padding=True, return_tensors="pt").to(device)
with torch.no_grad():
out = model_v.generate(**inputs, max_new_tokens=512, do_sample=False)
generated = out[:, inputs.input_ids.shape[1]:]
txt = processor_v.batch_decode(generated, skip_special_tokens=True)[0]
bboxes = parse_bboxes_from_text(txt)
if bboxes:
points = []
for b in bboxes:
bw = abs(b[2] - b[0])
bh = abs(b[3] - b[1])
if bw < 5 or bh < 5:
continue
if bw > 950 and bh > 950:
continue
cx = (b[0] + b[2]) / 2 / 1000 * w
cy = (b[1] + b[3]) / 2 / 1000 * h
if 0 <= cx <= w and 0 <= cy <= h:
points.append((cx, cy))
if len(points) > 1:
deduped = [points[0]]
for pt in points[1:]:
is_dup = any(pixel_point_distance(pt, ex) < 20 for ex in deduped)
if not is_dup:
deduped.append(pt)
points = deduped
if points:
return points
messages2 = [
{"role": "system", "content": [{"type": "text", "text": POINT_SYSTEM_PROMPT}]},
{"role": "user",
"content": [{"type": "image", "image": frame},
{"type": "text",
"text": f"Point to the exact center of each '{prompt}' in this image. Only point to objects that are clearly '{prompt}', nothing else."}]}
]
text2 = processor_v.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
inputs2 = processor_v(text=[text2], images=[frame], padding=True, return_tensors="pt").to(device)
with torch.no_grad():
out2 = model_v.generate(**inputs2, max_new_tokens=512, do_sample=False)
generated2 = out2[:, inputs2.input_ids.shape[1]:]
txt2 = processor_v.batch_decode(generated2, skip_special_tokens=True)[0]
return parse_precise_points(txt2, w, h)
def track_prompt_across_frames(state: AppState, prompt: str):
total = state.num_frames
if prompt in state.prompts:
for oid in state.prompts[prompt]:
for f in range(total):
state.masks_by_frame[f].pop(oid, None)
state.bboxes_by_frame[f].pop(oid, None)
state.text_prompts_by_frame_obj[f].pop(oid, None)
del state.prompts[prompt]
prev_tracks: list[tuple[int, list[float]]] = []
for f_idx in range(total):
frame = state.video_frames[f_idx]
w, h = frame.size
new_bboxes = detect_objects_in_frame(frame, prompt)
masks_f = state.masks_by_frame.setdefault(f_idx, {})
bboxes_f = state.bboxes_by_frame.setdefault(f_idx, {})
texts_f = state.text_prompts_by_frame_obj.setdefault(f_idx, {})
if not prev_tracks:
for bbox in new_bboxes:
oid = state.next_obj_id
state.next_obj_id += 1
if prompt not in state.color_by_prompt:
state.color_by_prompt[prompt] = pastel_color_for_prompt(prompt)
state.color_by_obj[oid] = state.color_by_prompt[prompt]
masks_f[oid] = bbox_to_mask(bbox, w, h)
bboxes_f[oid] = bbox
texts_f[oid] = prompt
state.prompts.setdefault(prompt, []).append(oid)
prev_tracks.append((oid, bbox))
continue
used = set()
matched = {}
scores = [(bbox_iou(pbbox, nbbox), pi, ni) for pi, (_, pbbox) in enumerate(prev_tracks) for ni, nbbox in
enumerate(new_bboxes)]
scores.sort(reverse=True)
for score, pi, ni in scores:
if pi in matched or ni in used or score <= 0.05:
continue
matched[pi] = ni
used.add(ni)
for pi, (_, pbbox) in enumerate(prev_tracks):
if pi in matched:
continue
best = min(((bbox_center_distance(pbbox, nbbox), ni) for ni, nbbox in enumerate(new_bboxes) if ni not in used),
default=(float('inf'), -1))
if best[0] < 300:
matched[pi] = best[1]
used.add(best[1])
new_prev = []
for pi, (oid, _) in enumerate(prev_tracks):
if pi in matched:
nbbox = new_bboxes[matched[pi]]
masks_f[oid] = bbox_to_mask(nbbox, w, h)
bboxes_f[oid] = nbbox
texts_f[oid] = prompt
new_prev.append((oid, nbbox))
for ni, nbbox in enumerate(new_bboxes):
if ni not in used:
oid = state.next_obj_id
state.next_obj_id += 1
if prompt not in state.color_by_prompt:
state.color_by_prompt[prompt] = pastel_color_for_prompt(prompt)
state.color_by_obj[oid] = state.color_by_prompt[prompt]
masks_f[oid] = bbox_to_mask(nbbox, w, h)
bboxes_f[oid] = nbbox
texts_f[oid] = prompt
state.prompts.setdefault(prompt, []).append(oid)
new_prev.append((oid, nbbox))
prev_tracks = new_prev
def track_points_across_frames(pt_state: PointTrackerState, prompt: str):
total = pt_state.num_frames
prev_tracks: list[tuple[int, tuple[float, float]]] = []
lost_count: dict[int, int] = {}
for f_idx in range(total):
frame = pt_state.video_frames[f_idx]
w, h = frame.size
new_points = detect_precise_points_in_frame(frame, prompt)
points_f = pt_state.points_by_frame.setdefault(f_idx, [])
if not prev_tracks:
for px, py in new_points:
track_idx = len(pt_state.trails)
pt_state.trails.append([])
points_f.append((px, py))
pt_state.trails[track_idx].append((f_idx, px, py))
prev_tracks.append((track_idx, (px, py)))
lost_count[track_idx] = 0
continue
if not new_points:
new_prev = []
for track_idx, prev_pt in prev_tracks:
lost_count[track_idx] = lost_count.get(track_idx, 0) + 1
if lost_count[track_idx] > 5:
continue
points_f.append(prev_pt)
pt_state.trails[track_idx].append((f_idx, prev_pt[0], prev_pt[1]))
new_prev.append((track_idx, prev_pt))
prev_tracks = new_prev
continue
diag = (w ** 2 + h ** 2) ** 0.5
match_threshold = diag * 0.25
used_new = set()
matched = {}
dist_pairs = []
for pi, (_, prev_pt) in enumerate(prev_tracks):
for ni, new_pt in enumerate(new_points):
d = pixel_point_distance(prev_pt, new_pt)
dist_pairs.append((d, pi, ni))
dist_pairs.sort()
for d, pi, ni in dist_pairs:
if pi in matched or ni in used_new:
continue
if d < match_threshold:
matched[pi] = ni
used_new.add(ni)
new_prev = []
for pi, (track_idx, prev_pt) in enumerate(prev_tracks):
if pi in matched:
ni = matched[pi]
new_pt = new_points[ni]
points_f.append(new_pt)
pt_state.trails[track_idx].append((f_idx, new_pt[0], new_pt[1]))
new_prev.append((track_idx, new_pt))
lost_count[track_idx] = 0
else:
lost_count[track_idx] = lost_count.get(track_idx, 0) + 1
if lost_count[track_idx] <= 5:
points_f.append(prev_pt)
pt_state.trails[track_idx].append((f_idx, prev_pt[0], prev_pt[1]))
new_prev.append((track_idx, prev_pt))
for ni, new_pt in enumerate(new_points):
if ni not in used_new:
too_close = any(
pixel_point_distance(new_pt, prev_pt) < diag * 0.08
for _, prev_pt in new_prev
)
if not too_close:
track_idx = len(pt_state.trails)
pt_state.trails.append([])
points_f.append(new_pt)
pt_state.trails[track_idx].append((f_idx, new_pt[0], new_pt[1]))
new_prev.append((track_idx, new_pt))
lost_count[track_idx] = 0
prev_tracks = new_prev
def render_point_tracker_video(pt_state: PointTrackerState, output_fps: int, trail_length: int = 12) -> str:
RED = (255, 40, 40)
DARK_RED = (180, 0, 0)
frames_bgr = []
for i in range(pt_state.num_frames):
frame = pt_state.video_frames[i].copy()
draw = ImageDraw.Draw(frame)
points_f = pt_state.points_by_frame.get(i, [])
for trail in pt_state.trails:
trail_pts = [(tx, ty) for fi, tx, ty in trail if fi <= i and fi > i - trail_length]
if len(trail_pts) >= 2:
for t_idx in range(len(trail_pts) - 1):
alpha_ratio = (t_idx + 1) / len(trail_pts)
trail_color = (
int(DARK_RED[0] * alpha_ratio),
int(DARK_RED[1] * alpha_ratio),
int(DARK_RED[2] * alpha_ratio)
)
thickness = max(1, int(2 * alpha_ratio))
x1t, y1t = int(trail_pts[t_idx][0]), int(trail_pts[t_idx][1])
x2t, y2t = int(trail_pts[t_idx + 1][0]), int(trail_pts[t_idx + 1][1])
draw.line([(x1t, y1t), (x2t, y2t)], fill=trail_color, width=thickness)
for (px, py) in points_f:
r_outer = 10
draw.ellipse(
(px - r_outer, py - r_outer, px + r_outer, py + r_outer),
outline="white", width=2
)
r = 7
draw.ellipse(
(px - r, py - r, px + r, py + r),
fill=RED, outline=RED
)
r_inner = 2
draw.ellipse(
(px - r_inner, py - r_inner, px + r_inner, py + r_inner),
fill=(255, 200, 200)
)
frames_bgr.append(np.array(frame)[:, :, ::-1])
if (i + 1) % 30 == 0:
gc.collect()
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
writer = cv2.VideoWriter(
tmp.name, cv2.VideoWriter_fourcc(*"mp4v"), output_fps,
(frames_bgr[0].shape[1], frames_bgr[0].shape[0])
)
for fr in frames_bgr:
writer.write(fr)
writer.release()
return tmp.name
def render_full_video(state: AppState, output_fps: int) -> str:
fps = output_fps
frames_bgr = []
for i in range(state.num_frames):
frame = state.video_frames[i].copy()
masks = state.masks_by_frame.get(i, {})
if masks:
frame = overlay_masks_on_frame(frame, masks, state.color_by_obj)
bboxes = state.bboxes_by_frame.get(i, {})
if bboxes:
draw = ImageDraw.Draw(frame)
w, h = frame.size
for oid, bbox in bboxes.items():
color = state.color_by_obj.get(oid, (255, 255, 255))
x1 = int(bbox[0] / 1000 * w)
y1 = int(bbox[1] / 1000 * h)
x2 = int(bbox[2] / 1000 * w)
y2 = int(bbox[3] / 1000 * h)
draw.rectangle((x1, y1, x2, y2), outline=color, width=4)
prompt = state.text_prompts_by_frame_obj.get(i, {}).get(oid, "")
if prompt:
label = f"{prompt} - ID{oid}"
try:
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16)
except OSError:
font = ImageFont.load_default()
tb = draw.textbbox((x1, max(0, y1 - 30)), label, font=font)
draw.rectangle(tb, fill=color)
draw.text((x1 + 4, max(0, y1 - 27)), label, fill="white", font=font)
frames_bgr.append(np.array(frame)[:, :, ::-1])
if (i + 1) % 30 == 0:
gc.collect()
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
writer = cv2.VideoWriter(tmp.name, cv2.VideoWriter_fourcc(*"mp4v"), fps,
(frames_bgr[0].shape[1], frames_bgr[0].shape[0]))
for fr in frames_bgr:
writer.write(fr)
writer.release()
return tmp.name
def calc_gpu_duration_tracking(state, video, text_prompt, output_fps, gpu_timeout):
try:
return int(gpu_timeout)
except (TypeError, ValueError):
return 90
def calc_gpu_duration_points(pt_state, video, text_prompt, output_fps, gpu_timeout):
try:
return int(gpu_timeout)
except (TypeError, ValueError):
return 90
def calc_gpu_duration_qa(video, user_text, max_new_tokens, gpu_timeout):
try:
return int(gpu_timeout)
except (TypeError, ValueError):
return 90
@spaces.GPU(duration=calc_gpu_duration_tracking)
def process_and_render(state: AppState, video, text_prompt: str, output_fps: int, gpu_timeout: int):
if video is None:
return "❌ Please upload a video", None
if not text_prompt or not text_prompt.strip():
return "❌ Please enter at least one text prompt", None
state.reset()
if isinstance(video, dict):
path = video.get("name") or video.get("path") or video.get("data")
else:
path = video
frames, info = try_load_video_frames(path)
if not frames:
return "❌ Could not load video", None
if info["fps"] and len(frames) > MAX_SECONDS * info["fps"]:
frames = frames[:int(MAX_SECONDS * info["fps"])]
state.video_frames = frames
state.video_fps = info["fps"]
prompts = [p.strip() for p in text_prompt.split(",") if p.strip()]
status = f"✅ Video loaded: {state.num_frames} frames\n"
status += f"Output FPS: {output_fps}\n"
status += f"GPU Duration: {gpu_timeout}s\n"
status += f"Processing {len(prompts)} prompt(s) across ALL frames...\n\n"
for p in prompts:
track_prompt_across_frames(state, p)
count = len(state.prompts.get(p, []))
status += f"• '{p}': {count} object(s) tracked\n"
status += "\n🎥 Rendering final video with overlays..."
rendered_path = render_full_video(state, output_fps)
status += "\n\n✅ Done! Play the video below."
return status, rendered_path
@spaces.GPU(duration=calc_gpu_duration_points)
def process_and_render_points(pt_state: PointTrackerState, video, text_prompt: str, output_fps: int, gpu_timeout: int):
if video is None:
return "❌ Please upload a video", None
if not text_prompt or not text_prompt.strip():
return "❌ Please enter at least one text prompt", None
pt_state.reset()
if isinstance(video, dict):
path = video.get("name") or video.get("path") or video.get("data")
else:
path = video
frames, info = try_load_video_frames(path)
if not frames:
return "❌ Could not load video", None
if info["fps"] and len(frames) > MAX_SECONDS * info["fps"]:
frames = frames[:int(MAX_SECONDS * info["fps"])]
pt_state.video_frames = frames
pt_state.video_fps = info["fps"]
prompts = [p.strip() for p in text_prompt.split(",") if p.strip()]
status = f"✅ Video loaded: {pt_state.num_frames} frames\n"
status += f"Output FPS: {output_fps}\n"
status += f"GPU Duration: {gpu_timeout}s\n"
status += f"Processing {len(prompts)} prompt(s) with point tracking...\n\n"
for p in prompts:
track_points_across_frames(pt_state, p)
status += f"• '{p}': tracked\n"
total_tracked = len(pt_state.trails)
status += f"\n📍 Total tracked points: {total_tracked}\n"
status += "\n🎥 Rendering video with red dot tracking..."
rendered_path = render_point_tracker_video(pt_state, output_fps)
status += "\n\n✅ Done! Play the video below."
return status, rendered_path
@spaces.GPU(duration=calc_gpu_duration_qa)
def process_video_qa(video, user_text, max_new_tokens, gpu_timeout):
if video is None:
return "❌ Please upload a video."
if not user_text or not user_text.strip():
user_text = "Describe this video in detail."
if isinstance(video, dict):
video_path = video.get("name") or video.get("path") or video.get("data")
else:
video_path = video
messages = [
{
"role": "user",
"content": [
dict(type="text", text=user_text),
dict(type="video", video=video_path),
],
}
]
try:
_, videos, video_kwargs = process_vision_info(messages)
videos, video_metadatas = zip(*videos)
videos, video_metadatas = list(videos), list(video_metadatas)
except Exception as e:
return f"❌ Error processing video frames: {e}"
text = processor_v.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor_v(
videos=videos,
video_metadata=video_metadatas,
text=text,
padding=True,
return_tensors="pt",
**video_kwargs,
)
inputs = {k: v.to(model_v.device) for k, v in inputs.items()}
with torch.inference_mode():
generated_ids = model_v.generate(
**inputs,
max_new_tokens=max_new_tokens
)
generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
generated_text = processor_v.tokenizer.decode(generated_tokens, skip_special_tokens=True)
generated_text = re.sub(r'<think>.*?</think>', '', generated_text.strip(), flags=re.DOTALL).strip()
return generated_text
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
#main-title h1 {font-size: 2.6em !important;}
/* RadioAnimated Styles */
.ra-wrap{ width: fit-content; }
.ra-inner{
position: relative; display: inline-flex; align-items: center; gap: 0; padding: 6px;
background: var(--neutral-200); border-radius: 9999px; overflow: hidden;
}
.ra-input{ display: none; }
.ra-label{
position: relative; z-index: 2; padding: 8px 16px;
font-family: inherit; font-size: 14px; font-weight: 600;
color: var(--neutral-500); cursor: pointer; transition: color 0.2s; white-space: nowrap;
}
.ra-highlight{
position: absolute; z-index: 1; top: 6px; left: 6px;
height: calc(100% - 12px); border-radius: 9999px;
background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
transition: transform 0.2s, width 0.2s;
}
.ra-input:checked + .ra-label{ color: black; }
/* Dark mode adjustments for RadioAnimated */
.dark .ra-inner { background: var(--neutral-800); }
.dark .ra-label { color: var(--neutral-400); }
.dark .ra-highlight { background: var(--neutral-600); }
.dark .ra-input:checked + .ra-label { color: white; }
#gpu-duration-container {
padding: 16px;
border-radius: 12px;
background: var(--background-fill-secondary);
border: 2px solid var(--border-color-primary);
margin-top: 8px;
}
#gpu-info-box {
padding: 12px;
border-radius: 8px;
background: var(--background-fill-primary);
border: 1px solid var(--border-color-secondary);
}
"""
with gr.Blocks() as demo:
gr.Markdown("# **Qwen3-VL-Video-Grounding**", elem_id="main-title")
gr.Markdown(
"""
Perform point tracking, text-guided detection, and video question answering with the Qwen3-VL multimodal model. This demo runs the official implementation using the Hugging Face Transformers, OpenCV, and Molmo libraries.
"""
)
state = gr.State(AppState())
pt_state = gr.State(PointTrackerState())
gpu_duration_state = gr.State(value=60)
with gr.Tabs():
with gr.Tab("Text-guided Object Tracking"):
with gr.Row():
with gr.Column():
gr.Markdown(
"""
**Getting started**
- **Upload a video** (max 8 seconds) or record from webcam.
- Enter **object descriptions** separated by commas (e.g. `person, red car, dog`).
- Each prompt can detect **multiple instances(classes)** — they'll each get a unique filter **ID's**.
"""
)
with gr.Column():
gr.Markdown(
"""
**How tracking works**
- The model detects **bounding boxes** for each object in every frame.
- Objects are matched across frames using **IoU overlap** and **center-distance** tracking.
- Output includes colored bounding boxes, semi-transparent mask overlays, and labeled IDs.
"""
)
with gr.Column():
with gr.Row():
video_in = gr.Video(label="Upload Video", sources=["upload", "webcam"], height=400)
with gr.Row():
prompt_in = gr.Textbox(
label="Text Prompts (comma separated)",
placeholder="person, red car, dog, laptop, traffic light",
lines=3
)
with gr.Row():
fps_slider = gr.Slider(
label="Output Video FPS",
minimum=1,
maximum=60,
value=25,
step=1,
info="Default: 25 FPS (BEST)"
)
process_btn = gr.Button("Apply Detection and Render Full Video", variant="primary")
status_out = gr.Textbox(label="Output Status", lines=3)
rendered_out = gr.Video(label="Rendered Video with Object Tracking", height=400)
gr.Examples(
examples=[
["examples/1.mp4"],
["examples/2.mp4"],
["examples/3.mp4"],
],
inputs=[video_in],
label="Examples"
)
with gr.Tab("Points Tracker"):
with gr.Row():
with gr.Column():
gr.Markdown(
"""
**Getting started**
- **Upload a video** (max 8 seconds) or record from webcam.
- Enter **object descriptions** separated by commas (e.g. `person, ball, face`).
- The model locates the **center point** of each detected object and tracks it with a **red dot**.
"""
)
with gr.Column():
gr.Markdown(
"""
**How point tracking works**
- Uses **bounding box detection** converted to precise **center points** for reliability.
- Points are matched across frames using **adaptive nearest-neighbor** tracking.
- Lost tracks are kept for up to 5 frames, then dropped to avoid ghost points.
- Clean visualization with **red dots** and subtle **motion trails**.
"""
)
with gr.Column():
with gr.Row():
pt_video_in = gr.Video(label="Upload Video", sources=["upload", "webcam"], height=400)
with gr.Row():
pt_prompt_in = gr.Textbox(
label="Text Prompts (comma separated)",
placeholder="person, ball, car, face, hand",
lines=3
)
with gr.Row():
pt_fps_slider = gr.Slider(
label="Output Video FPS",
minimum=1,
maximum=60,
value=25,
step=1,
info="Default: 25 FPS (BEST)"
)
pt_process_btn = gr.Button("Apply Point Tracking & Render Video", variant="primary")
pt_status_out = gr.Textbox(label="Output Status", lines=5)
pt_rendered_out = gr.Video(label="Rendered Video with Point Tracking", height=400)
gr.Examples(
examples=[
["examples/1.mp4"],
["examples/2.mp4"],
["examples/3.mp4"],
],
inputs=[pt_video_in],
label="Examples"
)
with gr.Tab("Any Video QA"):
with gr.Row():
with gr.Column():
gr.Markdown(
"""
**Getting started**
- **Upload a video** or record from webcam.
- Enter a **question or prompt** about the video content.
- The model will analyze the video and provide a **text answer**.
"""
)
with gr.Column():
gr.Markdown(
"""
**How it works**
- The video frames are processed by the **Qwen3-VL** vision-language model.
- You can ask **any question** about the video: describe scenes, identify actions, count objects, etc.
- If no prompt is provided, the model will **describe the video in detail**.
"""
)
with gr.Column():
with gr.Row():
qa_video_in = gr.Video(label="Upload Video", sources=["upload", "webcam"], height=400)
with gr.Row():
qa_prompt_in = gr.Textbox(
label="Text Prompt / Question",
placeholder="Describe this video in detail. / What is happening in this video? / How many people are visible?",
lines=3
)
with gr.Row():
qa_max_tokens = gr.Slider(
label="Max New Tokens",
minimum=64,
maximum=2048,
value=1024,
step=64,
info="Maximum number of tokens in the generated response"
)
qa_process_btn = gr.Button("Analyze Video", variant="primary")
qa_output = gr.Textbox(label="Model Response", lines=12)
gr.Examples(
examples=[
["examples/1.mp4"],
["examples/2.mp4"],
["examples/3.mp4"],
],
inputs=[qa_video_in],
label="Examples"
)
with gr.Tab("ZeroGPU Duration"):
with gr.Row():
with gr.Column():
gr.Markdown(
"""
## ZeroGPU Duration Settings
Configure the **maximum GPU allocation time** for all processing tasks across every tab.
This setting is **shared globally** — changing it here affects:
- **Text-guided Object Tracking** (Tab 1)
- **Points Tracker** (Tab 2)
- **Any Video QA** (Tab 3)
"""
)
with gr.Column():
gr.Markdown(
"""
## Duration Guide
| Duration | Best For |
|----------|----------|
| **60s** | Short videos (1-3s), simple prompts |
| **120s** | Medium videos (3-5s), 1-2 prompts |
| **180s** | Longer videos (5-8s), multiple prompts |
| **240s** | Complex multi-object tracking |
| **300s** | Maximum processing time |
"""
)
with gr.Column():
with gr.Row(elem_id="gpu-duration-container"):
with gr.Column():
gr.Markdown("### Select GPU Duration (seconds)")
gr.Markdown(
"*Slide to choose how long the GPU will be reserved for each processing request. "
"Higher values allow longer/more complex videos but consume more GPU quota.*"
)
radioanimated_gpu_duration = RadioAnimated(
choices=["60", "90", "120", "180", "240", "300", "360"],
value="90",
elem_id="radioanimated_gpu_duration"
)
with gr.Row():
with gr.Column(elem_id="gpu-info-box"):
gpu_display = gr.Markdown(
value="**Currently selected:** `90 seconds`"
)
with gr.Row():
with gr.Column():
gr.Markdown(
"""
### Important Notes
- **Higher duration = more GPU quota consumed.** Choose the minimum needed for your task.
- On **Hugging Face ZeroGPU Spaces**, each user has a daily GPU quota. Be mindful of usage.
- If processing **times out**, increase the duration and retry.
- The duration is the **maximum allowed time** — if processing finishes early, the GPU is released.
- **Default: 90 seconds** is sufficient for most short video tasks.
"""
)
with gr.Row():
with gr.Column():
gr.Markdown(
"""
### 🔧 Troubleshooting
| Issue | Solution |
|-------|----------|
| Processing times out | Increase GPU duration to 180s or 240s |
| GPU quota exhausted | Wait for quota reset or use shorter durations |
| Video too long | Trim to under 8 seconds before uploading |
| Multiple prompts slow | Use fewer comma-separated prompts or increase duration |
"""
)
def update_gpu_display(val: str):
duration = apply_gpu_duration(val)
return duration, f"**Currently selected:** `{duration} seconds`"
radioanimated_gpu_duration.change(
fn=update_gpu_display,
inputs=radioanimated_gpu_duration,
outputs=[gpu_duration_state, gpu_display],
api_visibility="private"
)
process_btn.click(
fn=process_and_render,
inputs=[state, video_in, prompt_in, fps_slider, gpu_duration_state],
outputs=[status_out, rendered_out],
show_progress=True
)
pt_process_btn.click(
fn=process_and_render_points,
inputs=[pt_state, pt_video_in, pt_prompt_in, pt_fps_slider, gpu_duration_state],
outputs=[pt_status_out, pt_rendered_out],
show_progress=True
)
qa_process_btn.click(
fn=process_video_qa,
inputs=[qa_video_in, qa_prompt_in, qa_max_tokens, gpu_duration_state],
outputs=[qa_output],
show_progress=True
)
demo.queue().launch(css=css, theme=Soft(primary_hue="orange", secondary_hue="rose"), ssr_mode=False, mcp_server=True)