update [from: test-sand-box] (cleaned) ✅

#3
Files changed (1)
  1. app.py +1133 -0
app.py ADDED
@@ -0,0 +1,1133 @@
import colorsys
import gc
import tempfile
import re
import json
import uuid
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from typing import Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
from molmo_utils import process_vision_info

device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID_V = "prithivMLmods/Qwen3-VL-4B-Instruct-Unredacted-MAX"  # the MAX model is trained on top of Qwen/Qwen3-VL-4B-Instruct
DTYPE = torch.float16

print(f"Loading {MODEL_ID_V}...")
processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
model_v = Qwen3VLForConditionalGeneration.from_pretrained(
    MODEL_ID_V, trust_remote_code=True, torch_dtype=DTYPE
).to(device).eval()
print("Model loaded successfully.")

MAX_SECONDS = 8.0
SYSTEM_PROMPT = """You are a helpful assistant that detects objects in images. When asked to detect elements based on a description you return bounding boxes for all elements in the form of [xmin, ymin, xmax, ymax] with the values being scaled between 0 and 1000. When there is more than one result, answer with a list of bounding boxes in the form of [[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax], ...]."""

POINT_SYSTEM_PROMPT = """You are a precise object pointing assistant. When asked to point to an object in an image, you must return ONLY the exact center coordinates of that specific object as [x, y] with values scaled between 0 and 1000 (where 0,0 is the top-left corner and 1000,1000 is the bottom-right corner).

Rules:
1. ONLY point to objects that exactly match the description given.
2. Do NOT point to background, empty areas, or unrelated objects.
3. If there are multiple matching instances, return [[x1, y1], [x2, y2], ...].
4. If no matching object is found, return an empty list [].
5. Return ONLY the coordinate numbers, no explanations or other text.
6. Be extremely precise — place the point at the exact visual center of each matching object."""

POINTS_REGEX = re.compile(r'(?:(\d+)\s*[.:])?\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)')
COORD_REGEX = re.compile(r'\[([\s\S]*?)\]')
FRAME_REGEX = re.compile(r'(\d+(?:\.\d+)?)\s*[,:]\s*([\d\s,\.]+)')
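
# Worked example (illustrative, not executed): the prompts above use Qwen-style
# 0-1000 normalized coordinates. A reply of [250, 100, 750, 900] on a 640x480
# frame maps to pixels via x_px = x / 1000 * width and y_px = y / 1000 * height:
#   x1 = 250 / 1000 * 640 = 160,  y1 = 100 / 1000 * 480 = 48
#   x2 = 750 / 1000 * 640 = 480,  y2 = 900 / 1000 * 480 = 432
# The same scaling is applied in bbox_to_mask() and render_full_video() below.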

class RadioAnimated(gr.HTML):
    def __init__(self, choices, value=None, **kwargs):
        if not choices or len(choices) < 2:
            raise ValueError("RadioAnimated requires at least 2 choices.")
        if value is None:
            value = choices[0]

        uid = uuid.uuid4().hex[:8]
        group_name = f"ra-{uid}"

        inputs_html = "\n".join(
            f"""
            <input class="ra-input" type="radio" name="{group_name}" id="{group_name}-{i}" value="{c}">
            <label class="ra-label" for="{group_name}-{i}">{c}</label>
            """
            for i, c in enumerate(choices)
        )

        html_template = f"""
        <div class="ra-wrap" data-ra="{uid}">
            <div class="ra-inner">
                <div class="ra-highlight"></div>
                {inputs_html}
            </div>
        </div>
        """

        js_on_load = r"""
        (() => {
            const wrap = element.querySelector('.ra-wrap');
            const inner = element.querySelector('.ra-inner');
            const highlight = element.querySelector('.ra-highlight');
            const inputs = Array.from(element.querySelectorAll('.ra-input'));

            if (!inputs.length) return;

            const choices = inputs.map(i => i.value);

            function setHighlightByIndex(idx) {
                const n = choices.length;
                const pct = 100 / n;
                highlight.style.width = `calc(${pct}% - 6px)`;
                highlight.style.transform = `translateX(${idx * 100}%)`;
            }

            function setCheckedByValue(val, shouldTrigger=false) {
                const idx = Math.max(0, choices.indexOf(val));
                inputs.forEach((inp, i) => { inp.checked = (i === idx); });
                setHighlightByIndex(idx);

                props.value = choices[idx];
                if (shouldTrigger) trigger('change', props.value);
            }

            setCheckedByValue(props.value ?? choices[0], false);

            inputs.forEach((inp) => {
                inp.addEventListener('change', () => {
                    setCheckedByValue(inp.value, true);
                });
            });
        })();
        """

        super().__init__(
            value=value,
            html_template=html_template,
            js_on_load=js_on_load,
            **kwargs
        )

def apply_gpu_duration(val: str):
    try:
        return int(val)
    except (TypeError, ValueError):
        return 90


def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
    cap = cv2.VideoCapture(video_path_or_url)
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    fps_val = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return frames, {"num_frames": len(frames), "fps": float(fps_val) if fps_val > 0 else None}
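
# Usage sketch for try_load_video_frames (hypothetical path, not executed at import):
#   frames, info = try_load_video_frames("examples/1.mp4")
#   # frames -> list of RGB PIL.Image frames
#   # info   -> {"num_frames": 200, "fps": 25.0}; fps is None when the container reports 0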

def parse_bboxes_from_text(text: str) -> list[list[float]]:
    text = re.sub(r'<think>.*?</think>', '', text.strip(), flags=re.DOTALL)
    nested = re.findall(r'\[\s*\[[\d\s,\.]+\](?:\s*,\s*\[[\d\s,\.]+\])*\s*\]', text)
    if nested:
        try:
            all_b = []
            for m in nested:
                parsed = json.loads(m)
                all_b.extend(parsed if isinstance(parsed[0], list) else [parsed])
            return all_b
        except (json.JSONDecodeError, IndexError):
            pass
    single = re.findall(
        r'\[\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]', text)
    if single:
        return [[float(v) for v in m] for m in single]
    nums = re.findall(r'(\d+(?:\.\d+)?)', text)
    return [[float(nums[i]), float(nums[i + 1]), float(nums[i + 2]), float(nums[i + 3])] for i in
            range(0, len(nums) - 3, 4)] if len(nums) >= 4 else []
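
# Worked examples for parse_bboxes_from_text (illustrative strings, not real model output):
#   "[[120, 80, 340, 560], [600, 90, 880, 540]]"
#       -> [[120, 80, 340, 560], [600, 90, 880, 540]]   (nested-list branch, via json.loads)
#   "The box is [12, 34, 56, 78]."
#       -> [[12.0, 34.0, 56.0, 78.0]]                   (single-box branch)
#   "12 34 56 78 90"
#       -> [[12.0, 34.0, 56.0, 78.0]]                   (bare-number fallback, grouped in fours)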

def parse_precise_points(text: str, image_w: int, image_h: int) -> list[tuple[float, float]]:
    text = re.sub(r'<think>.*?</think>', '', text.strip(), flags=re.DOTALL)
    raw_points = []

    nested = re.findall(r'\[\s*\[[\d\s,\.]+\](?:\s*,\s*\[[\d\s,\.]+\])*\s*\]', text)
    if nested:
        try:
            for m in nested:
                parsed = json.loads(m)
                if isinstance(parsed[0], list):
                    for p in parsed:
                        if len(p) >= 2:
                            raw_points.append((float(p[0]), float(p[1])))
                elif len(parsed) >= 2:
                    raw_points.append((float(parsed[0]), float(parsed[1])))
        except (json.JSONDecodeError, IndexError):
            pass

    if not raw_points:
        single = re.findall(
            r'\[\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]', text)
        if single:
            for m in single:
                raw_points.append((float(m[0]), float(m[1])))

    if not raw_points:
        for match in POINTS_REGEX.finditer(text):
            x_val = float(match.group(2))
            y_val = float(match.group(3))
            raw_points.append((x_val, y_val))

    validated = []
    for sx, sy in raw_points:
        if not (0 <= sx <= 1000 and 0 <= sy <= 1000):
            continue
        px = sx / 1000 * image_w
        py = sy / 1000 * image_h
        if 0 <= px <= image_w and 0 <= py <= image_h:
            validated.append((px, py))

    if len(validated) > 1:
        deduped = [validated[0]]
        for pt in validated[1:]:
            is_dup = False
            for existing in deduped:
                dist = ((pt[0] - existing[0]) ** 2 + (pt[1] - existing[1]) ** 2) ** 0.5
                if dist < 15:
                    is_dup = True
                    break
            if not is_dup:
                deduped.append(pt)
        validated = deduped

    return validated
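
# Illustrative behavior of parse_precise_points on a 640x480 frame:
#   "[[500, 500], [505, 502]]" -> (320.0, 240.0) and (323.2, 240.96); the second point
#   lies within the 15 px de-duplication radius of the first, so only one is kept.
#   Coordinates outside 0-1000 (e.g. [1200, 50]) are discarded before scaling.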

def bbox_to_mask(bbox_scaled: list[float], width: int, height: int) -> np.ndarray:
    mask = np.zeros((height, width), dtype=np.float32)
    x1 = max(0, min(int(bbox_scaled[0] / 1000 * width), width - 1))
    y1 = max(0, min(int(bbox_scaled[1] / 1000 * height), height - 1))
    x2 = max(0, min(int(bbox_scaled[2] / 1000 * width), width - 1))
    y2 = max(0, min(int(bbox_scaled[3] / 1000 * height), height - 1))
    mask[y1:y2, x1:x2] = 1.0
    return mask


def bbox_iou(b1, b2):
    x1 = max(b1[0], b2[0])
    y1 = max(b1[1], b2[1])
    x2 = min(b1[2], b2[2])
    y2 = min(b1[3], b2[3])
    inter = max(0, x2 - x1) * max(0, y2 - y1)
    union = (b1[2] - b1[0]) * (b1[3] - b1[1]) + (b2[2] - b2[0]) * (b2[3] - b2[1]) - inter
    return inter / union if union > 0 else 0.0


def bbox_center_distance(b1, b2):
    c1 = ((b1[0] + b1[2]) / 2, (b1[1] + b1[3]) / 2)
    c2 = ((b2[0] + b2[2]) / 2, (b2[1] + b2[3]) / 2)
    return ((c1[0] - c2[0]) ** 2 + (c1[1] - c2[1]) ** 2) ** 0.5


def pixel_point_distance(p1: tuple, p2: tuple) -> float:
    return ((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) ** 0.5
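
# Worked IoU example (values in the same 0-1000 space used throughout):
#   b1 = [0, 0, 100, 100], b2 = [50, 50, 150, 150]
#   intersection = 50 * 50 = 2500
#   union        = 10000 + 10000 - 2500 = 17500
#   bbox_iou(b1, b2) = 2500 / 17500 ≈ 0.143
# track_prompt_across_frames() below treats any IoU above 0.05 as a candidate match.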

def overlay_masks_on_frame(frame: Image.Image, masks: dict, colors_map: dict, alpha=0.45) -> Image.Image:
    base = np.array(frame).astype(np.float32) / 255
    overlay = base.copy()
    for oid, mask in masks.items():
        if mask is None:
            continue
        color = np.array(colors_map.get(oid, (255, 0, 0)), dtype=np.float32) / 255
        m = np.clip(mask, 0, 1)[..., None]
        overlay = (1 - alpha * m) * overlay + (alpha * m) * color
    return Image.fromarray(np.clip(overlay * 255, 0, 255).astype(np.uint8))


def pastel_color_for_prompt(prompt: str):
    hue = (sum(ord(c) for c in prompt) * 2654435761 % 360) / 360
    r, g, b = colorsys.hsv_to_rgb(hue, 0.5, 0.95)
    return int(r * 255), int(g * 255), int(b * 255)
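
# Compositing note: overlay_masks_on_frame implements standard alpha blending,
#   out = (1 - alpha * m) * frame + (alpha * m) * color,  with alpha = 0.45 and m in [0, 1],
# so a fully masked pixel is a 55/45 mix of the original frame and the object color.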

class AppState:
    def __init__(self):
        self.reset()

    def reset(self):
        self.video_frames: list[Image.Image] = []
        self.video_fps: float | None = None
        self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
        self.bboxes_by_frame: dict[int, dict[int, list[float]]] = {}
        self.color_by_obj: dict[int, tuple[int, int, int]] = {}
        self.color_by_prompt: dict[str, tuple[int, int, int]] = {}
        self.text_prompts_by_frame_obj: dict[int, dict[int, str]] = {}
        self.prompts: dict[str, list[int]] = {}
        self.next_obj_id: int = 1

    @property
    def num_frames(self) -> int:
        return len(self.video_frames)


class PointTrackerState:
    def __init__(self):
        self.reset()

    def reset(self):
        self.video_frames: list[Image.Image] = []
        self.video_fps: float | None = None
        self.points_by_frame: dict[int, list[tuple[float, float]]] = {}
        self.trails: list[list[tuple[int, float, float]]] = []

    @property
    def num_frames(self) -> int:
        return len(self.video_frames)

def detect_objects_in_frame(frame: Image.Image, prompt: str) -> list[list[float]]:
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user",
         "content": [{"type": "image", "image": frame}, {"type": "text", "text": f"Detect all instances of: {prompt}"}]}
    ]
    text = processor_v.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor_v(text=[text], images=[frame], padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model_v.generate(**inputs, max_new_tokens=512, do_sample=False)
    generated = out[:, inputs.input_ids.shape[1]:]
    txt = processor_v.batch_decode(generated, skip_special_tokens=True)[0]
    return parse_bboxes_from_text(txt)
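
# Usage sketch (hypothetical prompt, not executed at import):
#   boxes = detect_objects_in_frame(frame, "red car")
#   # boxes -> e.g. [[112, 305, 420, 690]] in 0-1000 coordinates, or [] when nothing matches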

def detect_precise_points_in_frame(frame: Image.Image, prompt: str) -> list[tuple[float, float]]:
    w, h = frame.size

    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user",
         "content": [{"type": "image", "image": frame},
                     {"type": "text",
                      "text": f"Detect all instances of: {prompt}. Return only bounding boxes for objects that exactly match this description."}]}
    ]
    text = processor_v.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor_v(text=[text], images=[frame], padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model_v.generate(**inputs, max_new_tokens=512, do_sample=False)
    generated = out[:, inputs.input_ids.shape[1]:]
    txt = processor_v.batch_decode(generated, skip_special_tokens=True)[0]

    bboxes = parse_bboxes_from_text(txt)

    if bboxes:
        points = []
        for b in bboxes:
            bw = abs(b[2] - b[0])
            bh = abs(b[3] - b[1])
            if bw < 5 or bh < 5:
                continue
            if bw > 950 and bh > 950:
                continue
            cx = (b[0] + b[2]) / 2 / 1000 * w
            cy = (b[1] + b[3]) / 2 / 1000 * h
            if 0 <= cx <= w and 0 <= cy <= h:
                points.append((cx, cy))

        if len(points) > 1:
            deduped = [points[0]]
            for pt in points[1:]:
                is_dup = any(pixel_point_distance(pt, ex) < 20 for ex in deduped)
                if not is_dup:
                    deduped.append(pt)
            points = deduped

        if points:
            return points

    messages2 = [
        {"role": "system", "content": [{"type": "text", "text": POINT_SYSTEM_PROMPT}]},
        {"role": "user",
         "content": [{"type": "image", "image": frame},
                     {"type": "text",
                      "text": f"Point to the exact center of each '{prompt}' in this image. Only point to objects that are clearly '{prompt}', nothing else."}]}
    ]
    text2 = processor_v.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
    inputs2 = processor_v(text=[text2], images=[frame], padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        out2 = model_v.generate(**inputs2, max_new_tokens=512, do_sample=False)
    generated2 = out2[:, inputs2.input_ids.shape[1]:]
    txt2 = processor_v.batch_decode(generated2, skip_special_tokens=True)[0]

    return parse_precise_points(txt2, w, h)

def track_prompt_across_frames(state: AppState, prompt: str):
    total = state.num_frames
    if prompt in state.prompts:
        for oid in state.prompts[prompt]:
            for f in range(total):
                state.masks_by_frame[f].pop(oid, None)
                state.bboxes_by_frame[f].pop(oid, None)
                state.text_prompts_by_frame_obj[f].pop(oid, None)
        del state.prompts[prompt]

    prev_tracks: list[tuple[int, list[float]]] = []

    for f_idx in range(total):
        frame = state.video_frames[f_idx]
        w, h = frame.size
        new_bboxes = detect_objects_in_frame(frame, prompt)

        masks_f = state.masks_by_frame.setdefault(f_idx, {})
        bboxes_f = state.bboxes_by_frame.setdefault(f_idx, {})
        texts_f = state.text_prompts_by_frame_obj.setdefault(f_idx, {})

        if not prev_tracks:
            for bbox in new_bboxes:
                oid = state.next_obj_id
                state.next_obj_id += 1
                if prompt not in state.color_by_prompt:
                    state.color_by_prompt[prompt] = pastel_color_for_prompt(prompt)
                state.color_by_obj[oid] = state.color_by_prompt[prompt]
                masks_f[oid] = bbox_to_mask(bbox, w, h)
                bboxes_f[oid] = bbox
                texts_f[oid] = prompt
                state.prompts.setdefault(prompt, []).append(oid)
                prev_tracks.append((oid, bbox))
            continue

        used = set()
        matched = {}
        scores = [(bbox_iou(pbbox, nbbox), pi, ni) for pi, (_, pbbox) in enumerate(prev_tracks) for ni, nbbox in
                  enumerate(new_bboxes)]
        scores.sort(reverse=True)
        for score, pi, ni in scores:
            if pi in matched or ni in used or score <= 0.05:
                continue
            matched[pi] = ni
            used.add(ni)

        for pi, (_, pbbox) in enumerate(prev_tracks):
            if pi in matched:
                continue
            best = min(((bbox_center_distance(pbbox, nbbox), ni) for ni, nbbox in enumerate(new_bboxes) if ni not in used),
                       default=(float('inf'), -1))
            if best[0] < 300:
                matched[pi] = best[1]
                used.add(best[1])

        new_prev = []
        for pi, (oid, _) in enumerate(prev_tracks):
            if pi in matched:
                nbbox = new_bboxes[matched[pi]]
                masks_f[oid] = bbox_to_mask(nbbox, w, h)
                bboxes_f[oid] = nbbox
                texts_f[oid] = prompt
                new_prev.append((oid, nbbox))
        for ni, nbbox in enumerate(new_bboxes):
            if ni not in used:
                oid = state.next_obj_id
                state.next_obj_id += 1
                if prompt not in state.color_by_prompt:
                    state.color_by_prompt[prompt] = pastel_color_for_prompt(prompt)
                state.color_by_obj[oid] = state.color_by_prompt[prompt]
                masks_f[oid] = bbox_to_mask(nbbox, w, h)
                bboxes_f[oid] = nbbox
                texts_f[oid] = prompt
                state.prompts.setdefault(prompt, []).append(oid)
                new_prev.append((oid, nbbox))
        prev_tracks = new_prev
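
# Matching walkthrough (illustrative numbers): with previous tracks A=[0, 0, 100, 100]
# and B=[500, 0, 600, 100], and new detections n0=[10, 5, 110, 105] and n1=[505, 2, 605, 102]:
#   1. All (IoU, prev, new) pairs are sorted descending; A<->n0 and B<->n1 win greedily.
#   2. Any previous track left unmatched falls back to center distance (< 300 units),
#      which bridges frames where a box moves too far for the IoU floor.
#   3. Detections still unmatched become new object IDs that reuse the prompt's color.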

def track_points_across_frames(pt_state: PointTrackerState, prompt: str):
    total = pt_state.num_frames
    prev_tracks: list[tuple[int, tuple[float, float]]] = []
    lost_count: dict[int, int] = {}

    for f_idx in range(total):
        frame = pt_state.video_frames[f_idx]
        w, h = frame.size

        new_points = detect_precise_points_in_frame(frame, prompt)
        points_f = pt_state.points_by_frame.setdefault(f_idx, [])

        if not prev_tracks:
            for px, py in new_points:
                track_idx = len(pt_state.trails)
                pt_state.trails.append([])
                points_f.append((px, py))
                pt_state.trails[track_idx].append((f_idx, px, py))
                prev_tracks.append((track_idx, (px, py)))
                lost_count[track_idx] = 0
            continue

        if not new_points:
            new_prev = []
            for track_idx, prev_pt in prev_tracks:
                lost_count[track_idx] = lost_count.get(track_idx, 0) + 1
                if lost_count[track_idx] > 5:
                    continue
                points_f.append(prev_pt)
                pt_state.trails[track_idx].append((f_idx, prev_pt[0], prev_pt[1]))
                new_prev.append((track_idx, prev_pt))
            prev_tracks = new_prev
            continue

        diag = (w ** 2 + h ** 2) ** 0.5
        match_threshold = diag * 0.25

        used_new = set()
        matched = {}

        dist_pairs = []
        for pi, (_, prev_pt) in enumerate(prev_tracks):
            for ni, new_pt in enumerate(new_points):
                d = pixel_point_distance(prev_pt, new_pt)
                dist_pairs.append((d, pi, ni))
        dist_pairs.sort()

        for d, pi, ni in dist_pairs:
            if pi in matched or ni in used_new:
                continue
            if d < match_threshold:
                matched[pi] = ni
                used_new.add(ni)

        new_prev = []
        for pi, (track_idx, prev_pt) in enumerate(prev_tracks):
            if pi in matched:
                ni = matched[pi]
                new_pt = new_points[ni]
                points_f.append(new_pt)
                pt_state.trails[track_idx].append((f_idx, new_pt[0], new_pt[1]))
                new_prev.append((track_idx, new_pt))
                lost_count[track_idx] = 0
            else:
                lost_count[track_idx] = lost_count.get(track_idx, 0) + 1
                if lost_count[track_idx] <= 5:
                    points_f.append(prev_pt)
                    pt_state.trails[track_idx].append((f_idx, prev_pt[0], prev_pt[1]))
                    new_prev.append((track_idx, prev_pt))

        for ni, new_pt in enumerate(new_points):
            if ni not in used_new:
                too_close = any(
                    pixel_point_distance(new_pt, prev_pt) < diag * 0.08
                    for _, prev_pt in new_prev
                )
                if not too_close:
                    track_idx = len(pt_state.trails)
                    pt_state.trails.append([])
                    points_f.append(new_pt)
                    pt_state.trails[track_idx].append((f_idx, new_pt[0], new_pt[1]))
                    new_prev.append((track_idx, new_pt))
                    lost_count[track_idx] = 0

        prev_tracks = new_prev
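
# Persistence note: a point that goes undetected is carried forward at its last known
# position for up to 5 consecutive frames (lost_count) before the track is dropped,
# trading a briefly "frozen" dot for fewer flickering ghost points.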

def render_point_tracker_video(pt_state: PointTrackerState, output_fps: int, trail_length: int = 12) -> str:
    RED = (255, 40, 40)
    DARK_RED = (180, 0, 0)
    frames_bgr = []

    for i in range(pt_state.num_frames):
        frame = pt_state.video_frames[i].copy()
        draw = ImageDraw.Draw(frame)

        points_f = pt_state.points_by_frame.get(i, [])

        for trail in pt_state.trails:
            trail_pts = [(tx, ty) for fi, tx, ty in trail if fi <= i and fi > i - trail_length]
            if len(trail_pts) >= 2:
                for t_idx in range(len(trail_pts) - 1):
                    alpha_ratio = (t_idx + 1) / len(trail_pts)
                    trail_color = (
                        int(DARK_RED[0] * alpha_ratio),
                        int(DARK_RED[1] * alpha_ratio),
                        int(DARK_RED[2] * alpha_ratio)
                    )
                    thickness = max(1, int(2 * alpha_ratio))
                    x1t, y1t = int(trail_pts[t_idx][0]), int(trail_pts[t_idx][1])
                    x2t, y2t = int(trail_pts[t_idx + 1][0]), int(trail_pts[t_idx + 1][1])
                    draw.line([(x1t, y1t), (x2t, y2t)], fill=trail_color, width=thickness)

        for (px, py) in points_f:
            r_outer = 10
            draw.ellipse(
                (px - r_outer, py - r_outer, px + r_outer, py + r_outer),
                outline="white", width=2
            )
            r = 7
            draw.ellipse(
                (px - r, py - r, px + r, py + r),
                fill=RED, outline=RED
            )
            r_inner = 2
            draw.ellipse(
                (px - r_inner, py - r_inner, px + r_inner, py + r_inner),
                fill=(255, 200, 200)
            )

        frames_bgr.append(np.array(frame)[:, :, ::-1])
        if (i + 1) % 30 == 0:
            gc.collect()

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        writer = cv2.VideoWriter(
            tmp.name, cv2.VideoWriter_fourcc(*"mp4v"), output_fps,
            (frames_bgr[0].shape[1], frames_bgr[0].shape[0])
        )
        for fr in frames_bgr:
            writer.write(fr)
        writer.release()
        return tmp.name

def render_full_video(state: AppState, output_fps: int) -> str:
    fps = output_fps
    frames_bgr = []
    for i in range(state.num_frames):
        frame = state.video_frames[i].copy()
        masks = state.masks_by_frame.get(i, {})
        if masks:
            frame = overlay_masks_on_frame(frame, masks, state.color_by_obj)
        bboxes = state.bboxes_by_frame.get(i, {})
        if bboxes:
            draw = ImageDraw.Draw(frame)
            w, h = frame.size
            for oid, bbox in bboxes.items():
                color = state.color_by_obj.get(oid, (255, 255, 255))
                x1 = int(bbox[0] / 1000 * w)
                y1 = int(bbox[1] / 1000 * h)
                x2 = int(bbox[2] / 1000 * w)
                y2 = int(bbox[3] / 1000 * h)
                draw.rectangle((x1, y1, x2, y2), outline=color, width=4)
                prompt = state.text_prompts_by_frame_obj.get(i, {}).get(oid, "")
                if prompt:
                    label = f"{prompt} - ID{oid}"
                    try:
                        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16)
                    except OSError:
                        font = ImageFont.load_default()
                    tb = draw.textbbox((x1, max(0, y1 - 30)), label, font=font)
                    draw.rectangle(tb, fill=color)
                    draw.text((x1 + 4, max(0, y1 - 27)), label, fill="white", font=font)
        frames_bgr.append(np.array(frame)[:, :, ::-1])
        if (i + 1) % 30 == 0:
            gc.collect()
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        writer = cv2.VideoWriter(tmp.name, cv2.VideoWriter_fourcc(*"mp4v"), fps,
                                 (frames_bgr[0].shape[1], frames_bgr[0].shape[0]))
        for fr in frames_bgr:
            writer.write(fr)
        writer.release()
        return tmp.name
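
# Codec note (deployment assumption): the "mp4v" FOURCC is available in most OpenCV
# builds but is not always playable in browsers. If the rendered file does not play in
# the Gradio video component, re-encoding with an H.264 FOURCC such as "avc1" (where
# that codec is present) is a common workaround.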

def calc_gpu_duration_tracking(state, video, text_prompt, output_fps, gpu_timeout):
    try:
        return int(gpu_timeout)
    except (TypeError, ValueError):
        return 90


def calc_gpu_duration_points(pt_state, video, text_prompt, output_fps, gpu_timeout):
    try:
        return int(gpu_timeout)
    except (TypeError, ValueError):
        return 90


def calc_gpu_duration_qa(video, user_text, max_new_tokens, gpu_timeout):
    try:
        return int(gpu_timeout)
    except (TypeError, ValueError):
        return 90
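
# ZeroGPU note: spaces.GPU accepts a callable for `duration`; it is invoked with the
# same arguments as the decorated function, and its return value (in seconds) sets the
# GPU reservation. That is why the three helpers above mirror the signatures of
# process_and_render, process_and_render_points, and process_video_qa even though they
# only read `gpu_timeout`.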

@spaces.GPU(duration=calc_gpu_duration_tracking)
def process_and_render(state: AppState, video, text_prompt: str, output_fps: int, gpu_timeout: int):
    if video is None:
        return "❌ Please upload a video", None
    if not text_prompt or not text_prompt.strip():
        return "❌ Please enter at least one text prompt", None

    state.reset()
    if isinstance(video, dict):
        path = video.get("name") or video.get("path") or video.get("data")
    else:
        path = video
    frames, info = try_load_video_frames(path)
    if not frames:
        return "❌ Could not load video", None
    if info["fps"] and len(frames) > MAX_SECONDS * info["fps"]:
        frames = frames[:int(MAX_SECONDS * info["fps"])]
    state.video_frames = frames
    state.video_fps = info["fps"]

    prompts = [p.strip() for p in text_prompt.split(",") if p.strip()]
    status = f"✅ Video loaded: {state.num_frames} frames\n"
    status += f"Output FPS: {output_fps}\n"
    status += f"GPU Duration: {gpu_timeout}s\n"
    status += f"Processing {len(prompts)} prompt(s) across ALL frames...\n\n"

    for p in prompts:
        track_prompt_across_frames(state, p)
        count = len(state.prompts.get(p, []))
        status += f"• '{p}': {count} object(s) tracked\n"

    status += "\n🎥 Rendering final video with overlays..."
    rendered_path = render_full_video(state, output_fps)
    status += "\n\n✅ Done! Play the video below."

    return status, rendered_path

@spaces.GPU(duration=calc_gpu_duration_points)
def process_and_render_points(pt_state: PointTrackerState, video, text_prompt: str, output_fps: int, gpu_timeout: int):
    if video is None:
        return "❌ Please upload a video", None
    if not text_prompt or not text_prompt.strip():
        return "❌ Please enter at least one text prompt", None

    pt_state.reset()
    if isinstance(video, dict):
        path = video.get("name") or video.get("path") or video.get("data")
    else:
        path = video
    frames, info = try_load_video_frames(path)
    if not frames:
        return "❌ Could not load video", None
    if info["fps"] and len(frames) > MAX_SECONDS * info["fps"]:
        frames = frames[:int(MAX_SECONDS * info["fps"])]
    pt_state.video_frames = frames
    pt_state.video_fps = info["fps"]

    prompts = [p.strip() for p in text_prompt.split(",") if p.strip()]
    status = f"✅ Video loaded: {pt_state.num_frames} frames\n"
    status += f"Output FPS: {output_fps}\n"
    status += f"GPU Duration: {gpu_timeout}s\n"
    status += f"Processing {len(prompts)} prompt(s) with point tracking...\n\n"

    for p in prompts:
        track_points_across_frames(pt_state, p)
        status += f"• '{p}': tracked\n"

    total_tracked = len(pt_state.trails)
    status += f"\n📍 Total tracked points: {total_tracked}\n"
    status += "\n🎥 Rendering video with red dot tracking..."
    rendered_path = render_point_tracker_video(pt_state, output_fps)
    status += "\n\n✅ Done! Play the video below."

    return status, rendered_path

@spaces.GPU(duration=calc_gpu_duration_qa)
def process_video_qa(video, user_text, max_new_tokens, gpu_timeout):
    if video is None:
        return "❌ Please upload a video."

    if not user_text or not user_text.strip():
        user_text = "Describe this video in detail."

    if isinstance(video, dict):
        video_path = video.get("name") or video.get("path") or video.get("data")
    else:
        video_path = video

    messages = [
        {
            "role": "user",
            "content": [
                dict(type="text", text=user_text),
                dict(type="video", video=video_path),
            ],
        }
    ]

    try:
        _, videos, video_kwargs = process_vision_info(messages)
        videos, video_metadatas = zip(*videos)
        videos, video_metadatas = list(videos), list(video_metadatas)
    except Exception as e:
        return f"❌ Error processing video frames: {e}"

    text = processor_v.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = processor_v(
        videos=videos,
        video_metadata=video_metadatas,
        text=text,
        padding=True,
        return_tensors="pt",
        **video_kwargs,
    )
    inputs = {k: v.to(model_v.device) for k, v in inputs.items()}

    with torch.inference_mode():
        generated_ids = model_v.generate(
            **inputs,
            max_new_tokens=max_new_tokens
        )

    generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
    generated_text = processor_v.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    generated_text = re.sub(r'<think>.*?</think>', '', generated_text.strip(), flags=re.DOTALL).strip()

    return generated_text

css = """
#col-container {
    margin: 0 auto;
    max-width: 800px;
}
#main-title h1 {font-size: 2.6em !important;}

/* RadioAnimated Styles */
.ra-wrap{ width: fit-content; }
.ra-inner{
    position: relative; display: inline-flex; align-items: center; gap: 0; padding: 6px;
    background: var(--neutral-200); border-radius: 9999px; overflow: hidden;
}
.ra-input{ display: none; }
.ra-label{
    position: relative; z-index: 2; padding: 8px 16px;
    font-family: inherit; font-size: 14px; font-weight: 600;
    color: var(--neutral-500); cursor: pointer; transition: color 0.2s; white-space: nowrap;
}
.ra-highlight{
    position: absolute; z-index: 1; top: 6px; left: 6px;
    height: calc(100% - 12px); border-radius: 9999px;
    background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    transition: transform 0.2s, width 0.2s;
}
.ra-input:checked + .ra-label{ color: black; }

/* Dark mode adjustments for RadioAnimated */
.dark .ra-inner { background: var(--neutral-800); }
.dark .ra-label { color: var(--neutral-400); }
.dark .ra-highlight { background: var(--neutral-600); }
.dark .ra-input:checked + .ra-label { color: white; }

#gpu-duration-container {
    padding: 16px;
    border-radius: 12px;
    background: var(--background-fill-secondary);
    border: 2px solid var(--border-color-primary);
    margin-top: 8px;
}

#gpu-info-box {
    padding: 12px;
    border-radius: 8px;
    background: var(--background-fill-primary);
    border: 1px solid var(--border-color-secondary);
}
"""

with gr.Blocks(css=css, theme=Soft(primary_hue="orange", secondary_hue="rose")) as demo:
    gr.Markdown("# **Qwen3-VL-Video-Grounding**", elem_id="main-title")

    gr.Markdown(
        """
        Perform point tracking, text-guided detection, and video question answering with the Qwen3-VL multimodal model. This demo runs the official implementation using Hugging Face Transformers, OpenCV, and the Molmo vision utilities.
        """
    )

    state = gr.State(AppState())
    pt_state = gr.State(PointTrackerState())
    gpu_duration_state = gr.State(value=90)

    with gr.Tabs():

        with gr.Tab("Text-guided Object Tracking"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """
                        **Getting started**
                        - **Upload a video** (max 8 seconds) or record from webcam.
                        - Enter **object descriptions** separated by commas (e.g. `person, red car, dog`).
                        - Each prompt can detect **multiple instances** of a class — each one gets a unique tracking **ID**.
                        """
                    )
                with gr.Column():
                    gr.Markdown(
                        """
                        **How tracking works**
                        - The model detects **bounding boxes** for each object in every frame.
                        - Objects are matched across frames using **IoU overlap** and **center-distance** tracking.
                        - Output includes colored bounding boxes, semi-transparent mask overlays, and labeled IDs.
                        """
                    )

            with gr.Column():
                with gr.Row():
                    video_in = gr.Video(label="Upload Video", sources=["upload", "webcam"], height=400)

                with gr.Row():
                    prompt_in = gr.Textbox(
                        label="Text Prompts (comma separated)",
                        placeholder="person, red car, dog, laptop, traffic light",
                        lines=3
                    )
                with gr.Row():
                    fps_slider = gr.Slider(
                        label="Output Video FPS",
                        minimum=1,
                        maximum=60,
                        value=25,
                        step=1,
                        info="Default: 25 FPS (recommended)"
                    )

                process_btn = gr.Button("Apply Detection and Render Full Video", variant="primary")

                status_out = gr.Textbox(label="Output Status", lines=3)
                rendered_out = gr.Video(label="Rendered Video with Object Tracking", height=400)

                gr.Examples(
                    examples=[
                        ["examples/1.mp4"],
                        ["examples/2.mp4"],
                        ["examples/3.mp4"],
                    ],
                    inputs=[video_in],
                    label="Examples"
                )

        with gr.Tab("Points Tracker"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """
                        **Getting started**
                        - **Upload a video** (max 8 seconds) or record from webcam.
                        - Enter **object descriptions** separated by commas (e.g. `person, ball, face`).
                        - The model locates the **center point** of each detected object and tracks it with a **red dot**.
                        """
                    )
                with gr.Column():
                    gr.Markdown(
                        """
                        **How point tracking works**
                        - Uses **bounding box detection** converted to precise **center points** for reliability.
                        - Points are matched across frames using **adaptive nearest-neighbor** tracking.
                        - Lost tracks are kept for up to 5 frames, then dropped to avoid ghost points.
                        - Clean visualization with **red dots** and subtle **motion trails**.
                        """
                    )

            with gr.Column():
                with gr.Row():
                    pt_video_in = gr.Video(label="Upload Video", sources=["upload", "webcam"], height=400)

                with gr.Row():
                    pt_prompt_in = gr.Textbox(
                        label="Text Prompts (comma separated)",
                        placeholder="person, ball, car, face, hand",
                        lines=3
                    )
                with gr.Row():
                    pt_fps_slider = gr.Slider(
                        label="Output Video FPS",
                        minimum=1,
                        maximum=60,
                        value=25,
                        step=1,
                        info="Default: 25 FPS (recommended)"
                    )

                pt_process_btn = gr.Button("Apply Point Tracking & Render Video", variant="primary")

                pt_status_out = gr.Textbox(label="Output Status", lines=5)
                pt_rendered_out = gr.Video(label="Rendered Video with Point Tracking", height=400)

                gr.Examples(
                    examples=[
                        ["examples/1.mp4"],
                        ["examples/2.mp4"],
                        ["examples/3.mp4"],
                    ],
                    inputs=[pt_video_in],
                    label="Examples"
                )

        with gr.Tab("Any Video QA"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """
                        **Getting started**
                        - **Upload a video** or record from webcam.
                        - Enter a **question or prompt** about the video content.
                        - The model will analyze the video and provide a **text answer**.
                        """
                    )
                with gr.Column():
                    gr.Markdown(
                        """
                        **How it works**
                        - The video frames are processed by the **Qwen3-VL** vision-language model.
                        - You can ask **any question** about the video: describe scenes, identify actions, count objects, etc.
                        - If no prompt is provided, the model will **describe the video in detail**.
                        """
                    )

            with gr.Column():
                with gr.Row():
                    qa_video_in = gr.Video(label="Upload Video", sources=["upload", "webcam"], height=400)

                with gr.Row():
                    qa_prompt_in = gr.Textbox(
                        label="Text Prompt / Question",
                        placeholder="Describe this video in detail. / What is happening in this video? / How many people are visible?",
                        lines=3
                    )
                with gr.Row():
                    qa_max_tokens = gr.Slider(
                        label="Max New Tokens",
                        minimum=64,
                        maximum=2048,
                        value=1024,
                        step=64,
                        info="Maximum number of tokens in the generated response"
                    )

                qa_process_btn = gr.Button("Analyze Video", variant="primary")

                qa_output = gr.Textbox(label="Model Response", lines=12)

                gr.Examples(
                    examples=[
                        ["examples/1.mp4"],
                        ["examples/2.mp4"],
                        ["examples/3.mp4"],
                    ],
                    inputs=[qa_video_in],
                    label="Examples"
                )

        with gr.Tab("ZeroGPU Duration"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """
                        ## ZeroGPU Duration Settings

                        Configure the **maximum GPU allocation time** for all processing tasks across every tab.
                        This setting is **shared globally** — changing it here affects:

                        - **Text-guided Object Tracking** (Tab 1)
                        - **Points Tracker** (Tab 2)
                        - **Any Video QA** (Tab 3)
                        """
                    )
                with gr.Column():
                    gr.Markdown(
                        """
                        ## Duration Guide

                        | Duration | Best For |
                        |----------|----------|
                        | **60s** | Short videos (1-3s), simple prompts |
                        | **120s** | Medium videos (3-5s), 1-2 prompts |
                        | **180s** | Longer videos (5-8s), multiple prompts |
                        | **240s** | Complex multi-object tracking |
                        | **300-360s** | Maximum processing time |
                        """
                    )

            with gr.Column():
                with gr.Row(elem_id="gpu-duration-container"):
                    with gr.Column():
                        gr.Markdown("### Select GPU Duration (seconds)")
                        gr.Markdown(
                            "*Choose how long the GPU will be reserved for each processing request. "
                            "Higher values allow longer/more complex videos but consume more GPU quota.*"
                        )
                        radioanimated_gpu_duration = RadioAnimated(
                            choices=["60", "90", "120", "180", "240", "300", "360"],
                            value="90",
                            elem_id="radioanimated_gpu_duration"
                        )

                with gr.Row():
                    with gr.Column(elem_id="gpu-info-box"):
                        gpu_display = gr.Markdown(
                            value="**Currently selected:** `90 seconds`"
                        )

                with gr.Row():
                    with gr.Column():
                        gr.Markdown(
                            """
                            ### Important Notes

                            - **Higher duration = more GPU quota consumed.** Choose the minimum needed for your task.
                            - On **Hugging Face ZeroGPU Spaces**, each user has a daily GPU quota. Be mindful of usage.
                            - If processing **times out**, increase the duration and retry.
                            - The duration is the **maximum allowed time** — if processing finishes early, the GPU is released.
                            - **Default: 90 seconds** is sufficient for most short video tasks.
                            """
                        )

                with gr.Row():
                    with gr.Column():
                        gr.Markdown(
                            """
                            ### 🔧 Troubleshooting

                            | Issue | Solution |
                            |-------|----------|
                            | Processing times out | Increase GPU duration to 180s or 240s |
                            | GPU quota exhausted | Wait for quota reset or use shorter durations |
                            | Video too long | Trim to under 8 seconds before uploading |
                            | Multiple prompts slow | Use fewer comma-separated prompts or increase duration |
                            """
                        )

    def update_gpu_display(val: str):
        duration = apply_gpu_duration(val)
        return duration, f"**Currently selected:** `{duration} seconds`"

    radioanimated_gpu_duration.change(
        fn=update_gpu_display,
        inputs=radioanimated_gpu_duration,
        outputs=[gpu_duration_state, gpu_display],
        api_visibility="private"
    )

    process_btn.click(
        fn=process_and_render,
        inputs=[state, video_in, prompt_in, fps_slider, gpu_duration_state],
        outputs=[status_out, rendered_out],
        show_progress="full"
    )

    pt_process_btn.click(
        fn=process_and_render_points,
        inputs=[pt_state, pt_video_in, pt_prompt_in, pt_fps_slider, gpu_duration_state],
        outputs=[pt_status_out, pt_rendered_out],
        show_progress="full"
    )

    qa_process_btn.click(
        fn=process_video_qa,
        inputs=[qa_video_in, qa_prompt_in, qa_max_tokens, gpu_duration_state],
        outputs=[qa_output],
        show_progress="full"
    )

demo.queue().launch(ssr_mode=False, mcp_server=True)