Spaces:
Running on Zero
Running on Zero
| import colorsys | |
| import gc | |
| import tempfile | |
| import re | |
| import json | |
| import ast | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| from typing import Iterator | |
| from gradio.themes import Soft | |
| from PIL import Image, ImageDraw, ImageFont | |
| from transformers import AutoProcessor, Qwen3VLForConditionalGeneration | |
# Run on GPU when available; the model is loaded in bfloat16 either way.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Qwen3-VL instruction-tuned vision-language model used for all inference below.
MODEL_ID_V = "prithivMLmods/Qwen3-VL-4B-Instruct-Unredacted-MAX"
DTYPE = torch.bfloat16
print(f"Loading {MODEL_ID_V}...")
# Processor handles chat templating and image preprocessing for the model.
processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
model_v = Qwen3VLForConditionalGeneration.from_pretrained(
    MODEL_ID_V, attn_implementation="kernels-community/flash-attn3", trust_remote_code=True, torch_dtype=DTYPE
).to(device).eval()
print("Model loaded successfully.")
# Uploaded videos are trimmed to at most this many seconds (see init_* helpers).
MAX_SECONDS = 5.0
# System prompt that makes the model reply with [xmin, ymin, xmax, ymax] boxes on a 0-1000 scale.
SYSTEM_PROMPT = """You are a helpful assistant to detect objects in images. When asked to detect elements based on a description you return bounding boxes for all elements in the form of [xmin, ymin, xmax, ymax] with the values being scaled between 0 and 1000. When there are more than one result, answer with a list of bounding boxes in the form of [[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax], ...]."""
# Stricter fallback prompt that asks for [x, y] center points instead of boxes.
POINT_SYSTEM_PROMPT = """You are a precise object pointing assistant. When asked to point to an object in an image, you must return ONLY the exact center coordinates of that specific object as [x, y] with values scaled between 0 and 1000 (where 0,0 is the top-left corner and 1000,1000 is the bottom-right corner).
Rules:
1. ONLY point to objects that exactly match the description given.
2. Do NOT point to background, empty areas, or unrelated objects.
3. If there are multiple matching instances, return [[x1, y1], [x2, y2], ...].
4. If no matching object is found, return an empty list [].
5. Return ONLY the coordinate numbers, no explanations or other text.
6. Be extremely precise — place the point at the exact visual center of each matching object."""
# Loose "x, y" pair matcher (optionally prefixed "1." / "2:") used as a last-resort point parser.
POINTS_REGEX = re.compile(r'(?:(\d+)\s*[.:])?\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)')
def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
    """Decode every frame of a video into RGB PIL images.

    Returns the frame list plus an info dict with the frame count and the
    container FPS (None when OpenCV cannot report a positive rate).
    """
    capture = cv2.VideoCapture(video_path_or_url)
    loaded: list[Image.Image] = []
    while capture.isOpened():
        ok, bgr = capture.read()
        if not ok:
            break
        # OpenCV decodes BGR; convert to RGB before wrapping in PIL.
        loaded.append(Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)))
    fps = capture.get(cv2.CAP_PROP_FPS)
    capture.release()
    return loaded, {"num_frames": len(loaded), "fps": float(fps) if fps > 0 else None}
def parse_bboxes_from_text(text: str) -> list[list[float]]:
    """Extract [xmin, ymin, xmax, ymax] boxes (0-1000 scale) from model text.

    Tries, in order: a JSON-style nested list of boxes, individual 4-number
    bracketed lists, and finally any loose run of numbers grouped into
    quadruples. <think> blocks are stripped first. Returns [] when nothing
    parses.
    """
    cleaned = re.sub(r'<think>.*?</think>', '', text.strip(), flags=re.DOTALL)
    nested_lists = re.findall(r'\[\s*\[[\d\s,\.]+\](?:\s*,\s*\[[\d\s,\.]+\])*\s*\]', cleaned)
    if nested_lists:
        try:
            boxes: list[list[float]] = []
            for candidate in nested_lists:
                decoded = json.loads(candidate)
                if isinstance(decoded[0], list):
                    boxes.extend(decoded)
                else:
                    boxes.append(decoded)
            return boxes
        except (json.JSONDecodeError, IndexError):
            pass
    quads = re.findall(
        r'\[\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]', cleaned)
    if quads:
        return [[float(v) for v in quad] for quad in quads]
    numbers = re.findall(r'(\d+(?:\.\d+)?)', cleaned)
    if len(numbers) < 4:
        return []
    return [[float(numbers[i]), float(numbers[i + 1]), float(numbers[i + 2]), float(numbers[i + 3])]
            for i in range(0, len(numbers) - 3, 4)]
def parse_precise_points(text: str, image_w: int, image_h: int) -> list[tuple[float, float]]:
    """Parse [x, y] point predictions (0-1000 scale) into pixel coordinates.

    Falls back from nested JSON lists to single bracketed pairs to loose
    "x, y" regex matches, drops out-of-range points, and merges points
    closer than 15 px to an already kept one.
    """
    cleaned = re.sub(r'<think>.*?</think>', '', text.strip(), flags=re.DOTALL)
    scaled: list[tuple[float, float]] = []
    nested_lists = re.findall(r'\[\s*\[[\d\s,\.]+\](?:\s*,\s*\[[\d\s,\.]+\])*\s*\]', cleaned)
    if nested_lists:
        try:
            for candidate in nested_lists:
                decoded = json.loads(candidate)
                if isinstance(decoded[0], list):
                    for pair in decoded:
                        if len(pair) >= 2:
                            scaled.append((float(pair[0]), float(pair[1])))
                elif len(decoded) >= 2:
                    scaled.append((float(decoded[0]), float(decoded[1])))
        except (json.JSONDecodeError, IndexError):
            pass
    if not scaled:
        scaled = [(float(a), float(b)) for a, b in
                  re.findall(r'\[\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]', cleaned)]
    if not scaled:
        scaled = [(float(m.group(2)), float(m.group(3))) for m in POINTS_REGEX.finditer(cleaned)]
    in_bounds: list[tuple[float, float]] = []
    for sx, sy in scaled:
        if not (0 <= sx <= 1000 and 0 <= sy <= 1000):
            continue
        px = sx / 1000 * image_w
        py = sy / 1000 * image_h
        if 0 <= px <= image_w and 0 <= py <= image_h:
            in_bounds.append((px, py))
    if len(in_bounds) <= 1:
        return in_bounds
    # Merge near-duplicates: keep a point only if it is >= 15 px from every kept point.
    merged = [in_bounds[0]]
    for candidate in in_bounds[1:]:
        if all(((candidate[0] - kept[0]) ** 2 + (candidate[1] - kept[1]) ** 2) ** 0.5 >= 15 for kept in merged):
            merged.append(candidate)
    return merged
def bbox_to_mask(bbox_scaled: list[float], width: int, height: int) -> np.ndarray:
    """Rasterize a [xmin, ymin, xmax, ymax] box (0-1000 scale) to a binary mask.

    Edges are scaled to pixel coordinates and clamped to the image. The max
    edges clamp to ``width``/``height`` (exclusive slice ends), so a
    full-extent box [0, 0, 1000, 1000] covers the whole mask — the previous
    ``dim - 1`` clamp always dropped the last row and column.

    Returns a float32 array of shape (height, width) with 1.0 inside the box.
    """
    mask = np.zeros((height, width), dtype=np.float32)
    x1 = max(0, min(int(bbox_scaled[0] / 1000 * width), width))
    y1 = max(0, min(int(bbox_scaled[1] / 1000 * height), height))
    x2 = max(0, min(int(bbox_scaled[2] / 1000 * width), width))
    y2 = max(0, min(int(bbox_scaled[3] / 1000 * height), height))
    # An inverted or degenerate box yields an empty slice and an all-zero mask.
    mask[y1:y2, x1:x2] = 1.0
    return mask
def bbox_iou(b1, b2):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes (0.0 when disjoint)."""
    ix1, iy1 = max(b1[0], b2[0]), max(b1[1], b2[1])
    ix2, iy2 = min(b1[2], b2[2]), min(b1[3], b2[3])
    inter_area = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
    area2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
    denom = area1 + area2 - inter_area
    return inter_area / denom if denom > 0 else 0.0
def bbox_center_distance(b1, b2):
    """Euclidean distance between the centers of two [x1, y1, x2, y2] boxes."""
    dx = (b1[0] + b1[2] - b2[0] - b2[2]) / 2
    dy = (b1[1] + b1[3] - b2[1] - b2[3]) / 2
    return (dx * dx + dy * dy) ** 0.5
def pixel_point_distance(p1, p2):
    """Euclidean distance between two (x, y) pixel points."""
    dx, dy = p1[0] - p2[0], p1[1] - p2[1]
    return (dx * dx + dy * dy) ** 0.5
def overlay_masks_on_frame(frame: Image.Image, masks: dict, colors_map: dict, alpha=0.5) -> Image.Image:
    """Alpha-blend per-object masks onto a frame and return a new PIL image.

    ``masks`` maps object id -> (H, W) float mask (3-D masks are squeezed);
    ``colors_map`` maps object id -> RGB tuple, defaulting to red.
    """
    blended = np.array(frame).astype(np.float32) / 255
    for obj_id, obj_mask in masks.items():
        if obj_mask is None:
            continue
        rgb = np.array(colors_map.get(obj_id, (255, 0, 0)), dtype=np.float32) / 255
        if obj_mask.ndim == 3:
            obj_mask = obj_mask.squeeze()
        weight = np.clip(obj_mask, 0, 1)[..., None]
        # Standard alpha compositing; masks are applied sequentially.
        blended = (1 - alpha * weight) * blended + (alpha * weight) * rgb
    return Image.fromarray(np.clip(blended * 255, 0, 255).astype(np.uint8))
def pastel_color_for_prompt(prompt: str):
    """Deterministically map a prompt string to a pastel RGB tuple.

    The character-sum is scrambled with a large multiplicative constant to
    spread hues; saturation and value are fixed for a pastel look.
    """
    hue_deg = sum(map(ord, prompt)) * 2654435761 % 360
    red, green, blue = colorsys.hsv_to_rgb(hue_deg / 360, 0.5, 0.95)
    return int(red * 255), int(green * 255), int(blue * 255)
def get_font(image_height: int):
    """Pick a TrueType font scaled to the frame height, with a PIL fallback.

    Returns (font, font_size); the bitmap default font is always reported
    as size 13 because it cannot be resized.
    """
    size = max(10, int(13 * image_height / 720))
    candidates = [
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
        "/System/Library/Fonts/Helvetica.ttc",
        "arial.ttf",
    ]
    try:
        for path in candidates:
            try:
                return ImageFont.truetype(path, size), size
            except OSError:
                # Font not present on this platform; try the next candidate.
                continue
    except Exception:
        pass
    return ImageFont.load_default(), 13
def detect_objects_in_frame(frame: Image.Image, prompt: str) -> list[list[float]]:
    """Ask the VLM for bounding boxes of ``prompt`` in ``frame`` (0-1000 scale)."""
    chat = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user",
         "content": [{"type": "image", "image": frame},
                     {"type": "text", "text": f"Detect all instances of: {prompt}"}]}
    ]
    rendered = processor_v.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    model_inputs = processor_v(text=[rendered], images=[frame], padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model_v.generate(**model_inputs, max_new_tokens=512, do_sample=False)
    # Keep only the newly generated tokens (strip the echoed prompt).
    completion = output_ids[:, model_inputs.input_ids.shape[1]:]
    decoded = processor_v.batch_decode(completion, skip_special_tokens=True)[0]
    return parse_bboxes_from_text(decoded)
def detect_precise_points_in_frame(frame: Image.Image, prompt: str) -> list[tuple[float, float]]:
    """Find pixel-space center points of objects matching ``prompt`` in ``frame``.

    Two-stage strategy: first ask the model for bounding boxes and use the
    centers of plausible boxes as points; if that yields nothing usable,
    fall back to asking the model to point directly (POINT_SYSTEM_PROMPT)
    and parse its [x, y] replies.
    """
    w, h = frame.size
    # Stage 1: bounding-box detection via the box-oriented SYSTEM_PROMPT.
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user",
         "content": [{"type": "image", "image": frame},
                     {"type": "text",
                      "text": f"Detect all instances of: {prompt}. Return only bounding boxes for objects that exactly match this description."}]}
    ]
    text = processor_v.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor_v(text=[text], images=[frame], padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model_v.generate(**inputs, max_new_tokens=512, do_sample=False)
    # Strip the echoed prompt tokens before decoding.
    generated = out[:, inputs.input_ids.shape[1]:]
    txt = processor_v.batch_decode(generated, skip_special_tokens=True)[0]
    bboxes = parse_bboxes_from_text(txt)
    if bboxes:
        points = []
        for b in bboxes:
            bw = abs(b[2] - b[0])
            bh = abs(b[3] - b[1])
            # Discard degenerate slivers (box sides on the 0-1000 scale).
            if bw < 5 or bh < 5:
                continue
            # Discard near-full-frame boxes, which tend to be background hits.
            if bw > 950 and bh > 950:
                continue
            # Box center, converted from the 0-1000 scale to pixels.
            cx = (b[0] + b[2]) / 2 / 1000 * w
            cy = (b[1] + b[3]) / 2 / 1000 * h
            if 0 <= cx <= w and 0 <= cy <= h:
                points.append((cx, cy))
        # Merge centers within 20 px of an already kept point.
        if len(points) > 1:
            deduped = [points[0]]
            for pt in points[1:]:
                if all(pixel_point_distance(pt, ex) >= 20 for ex in deduped):
                    deduped.append(pt)
            points = deduped
        if points:
            return points
    # Stage 2: direct pointing fallback using the point-oriented system prompt.
    messages2 = [
        {"role": "system", "content": [{"type": "text", "text": POINT_SYSTEM_PROMPT}]},
        {"role": "user",
         "content": [{"type": "image", "image": frame},
                     {"type": "text",
                      "text": f"Point to the exact center of each '{prompt}' in this image. Only point to objects that are clearly '{prompt}', nothing else."}]}
    ]
    text2 = processor_v.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
    inputs2 = processor_v(text=[text2], images=[frame], padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        out2 = model_v.generate(**inputs2, max_new_tokens=512, do_sample=False)
    generated2 = out2[:, inputs2.input_ids.shape[1]:]
    txt2 = processor_v.batch_decode(generated2, skip_special_tokens=True)[0]
    return parse_precise_points(txt2, w, h)
def run_model_inference(image: Image.Image, prompt: str) -> str:
    """Run a free-form prompt against one image and return the cleaned reply text."""
    chat = [
        {"role": "user",
         "content": [
             {"type": "image", "image": image},
             {"type": "text", "text": prompt},
         ]}
    ]
    rendered = processor_v.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    model_inputs = processor_v(text=[rendered], images=[image], padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model_v.generate(**model_inputs, max_new_tokens=512, do_sample=False)
    completion = output_ids[:, model_inputs.input_ids.shape[1]:]
    reply = processor_v.batch_decode(completion, skip_special_tokens=True)[0]
    # Drop any chain-of-thought the model wrapped in <think> tags.
    return re.sub(r'<think>.*?</think>', '', reply.strip(), flags=re.DOTALL).strip()
def safe_parse_json(text: str):
    """Best-effort parse of model output into a Python object.

    Strips a surrounding markdown code fence, then tries strict JSON and
    falls back to ``ast.literal_eval``; returns {} when neither parses.
    """
    stripped = text.strip()
    stripped = re.sub(r"^```(json)?", "", stripped)
    stripped = re.sub(r"```$", "", stripped).strip()
    try:
        return json.loads(stripped)
    except json.JSONDecodeError:
        pass
    try:
        # Handles Python-literal output such as single-quoted dicts.
        return ast.literal_eval(stripped)
    except Exception:
        return {}
def annotate_image_detection(image: Image.Image, result: dict) -> Image.Image:
    """Draw labelled bounding boxes from a structured detection result.

    ``result["objects"]`` entries carry normalized (0-1) x_min/y_min/x_max/
    y_max plus an optional "label". Non-image or non-dict input is returned
    untouched; otherwise an annotated RGB copy is returned.
    """
    if not isinstance(image, Image.Image) or not isinstance(result, dict):
        return image
    image = image.convert("RGB")
    width, height = image.size
    draw = ImageDraw.Draw(image)
    font, _ = get_font(height)
    objects = result.get("objects")
    if objects:
        palette = [
            (66, 133, 244), (234, 67, 53), (251, 188, 4), (52, 168, 83),
            (255, 109, 0), (171, 71, 188), (0, 172, 193), (255, 82, 82),
            (46, 125, 50), (121, 85, 72),
        ]
        pad = max(2, int(4 * height / 720))
        label_offset = int(20 * height / 720)
        for idx, obj in enumerate(objects):
            left = int(obj.get("x_min", 0.0) * width)
            top = int(obj.get("y_min", 0.0) * height)
            right = int(obj.get("x_max", 0.0) * width)
            bottom = int(obj.get("y_max", 0.0) * height)
            box_color = palette[idx % len(palette)]
            draw.rectangle([(left, top), (right, bottom)], outline=box_color, width=3)
            label = obj.get("label", f"Object {idx + 1}")
            label_top = max(0, top - label_offset)
            bounds = draw.textbbox((left, label_top), label, font=font)
            # Filled backdrop behind the label for readability.
            draw.rectangle(
                [(bounds[0] - pad, bounds[1] - pad), (bounds[2] + pad, bounds[3] + pad)],
                fill=box_color
            )
            draw.text((left, label_top), label, fill="white", font=font)
    return image
def annotate_image_points(image: Image.Image, result: dict) -> Image.Image:
    """Draw a target-style marker at each normalized point in ``result["points"]``.

    Each point dict holds "x"/"y" in 0-1. Non-image or non-dict input is
    returned untouched; otherwise an annotated RGB copy is returned.
    """
    if not isinstance(image, Image.Image) or not isinstance(result, dict):
        return image
    image = image.convert("RGB")
    width, height = image.size
    draw = ImageDraw.Draw(image)
    points = result.get("points")
    if points:
        outer = max(8, int(10 * height / 720))
        inner = max(5, int(7 * height / 720))
        dot = max(1, int(2 * height / 720))
        for point in points:
            cx = int(point["x"] * width)
            cy = int(point["y"] * height)
            # White ring, red disc, pale center dot.
            draw.ellipse((cx - outer, cy - outer, cx + outer, cy + outer), outline="white", width=2)
            draw.ellipse((cx - inner, cy - inner, cx + inner, cy + inner), fill=(255, 40, 40), outline=(255, 40, 40))
            draw.ellipse((cx - dot, cy - dot, cx + dot, cy + dot), fill=(255, 200, 200))
    return image
class TrackingState:
    """Per-session state for text-prompted box tracking across video frames."""
    def __init__(self):
        self.reset()
    def reset(self):
        """Clear all loaded video data and tracking results."""
        # Decoded RGB frames of the loaded (possibly trimmed) video.
        self.video_frames: list[Image.Image] = []
        # Source frame rate, when OpenCV could report one.
        self.video_fps: float | None = None
        # frame index -> {object id -> (H, W) float mask}.
        self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
        # frame index -> {object id -> [xmin, ymin, xmax, ymax] on the 0-1000 scale}.
        self.bboxes_by_frame: dict[int, dict[int, list[float]]] = {}
        # object id -> RGB color used for its mask/box.
        self.color_by_obj: dict[int, tuple[int, int, int]] = {}
        # prompt text -> RGB color shared by all of that prompt's objects.
        self.color_by_prompt: dict[str, tuple[int, int, int]] = {}
        # frame index -> {object id -> prompt text used as its label}.
        self.text_prompts_by_frame_obj: dict[int, dict[int, str]] = {}
        # frame index -> fully rendered frame (UI cache; invalidated on edits).
        self.composited_frames: dict[int, Image.Image] = {}
        # prompt text -> list of object ids spawned by that prompt.
        self.prompts: dict[str, list[int]] = {}
        # Next object id to hand out; ids are never reused.
        self.next_obj_id: int = 1
        # Frame currently shown in the UI.
        self.current_frame_idx: int = 0
    def num_frames(self) -> int:
        """Number of frames in the loaded video."""
        return len(self.video_frames)
class PointTrackingState:
    """Per-session state for point-based tracking across video frames."""
    def __init__(self):
        self.reset()
    def reset(self):
        """Clear all loaded video data and point-tracking results."""
        # Decoded RGB frames of the loaded (possibly trimmed) video.
        self.video_frames: list[Image.Image] = []
        # Source frame rate, when OpenCV could report one.
        self.video_fps: float | None = None
        # frame index -> list of (x, y) pixel points visible on that frame.
        self.points_by_frame: dict[int, list[tuple[float, float]]] = {}
        # One trail per tracked point: list of (frame index, x, y) samples.
        self.trails: list[list[tuple[int, float, float]]] = []
        # frame index -> fully rendered frame (UI cache; invalidated on edits).
        self.composited_frames: dict[int, Image.Image] = {}
        # The comma-separated prompt text currently being tracked.
        self.prompt_text: str = ""
        # Frame currently shown in the UI.
        self.current_frame_idx: int = 0
    def num_frames(self) -> int:
        """Number of frames in the loaded video."""
        return len(self.video_frames)
def compose_tracking_frame(state: TrackingState, frame_idx: int) -> Image.Image:
    """Render frame ``frame_idx`` with its mask overlays and labelled boxes.

    The rendered image is cached in ``state.composited_frames``. Returns
    None when no video is loaded.
    """
    if state is None or not state.video_frames:
        return None
    idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
    rendered = state.video_frames[idx].copy()
    width, height = rendered.size
    frame_masks = state.masks_by_frame.get(idx, {})
    if frame_masks:
        rendered = overlay_masks_on_frame(rendered, frame_masks, state.color_by_obj, alpha=0.5)
    frame_boxes = state.bboxes_by_frame.get(idx, {})
    if frame_boxes:
        draw = ImageDraw.Draw(rendered)
        font, _ = get_font(height)
        pad = max(2, int(4 * height / 720))
        label_offset = int(20 * height / 720)
        for obj_id, box in frame_boxes.items():
            rgb = state.color_by_obj.get(obj_id, (255, 255, 255))
            # Boxes are stored on the 0-1000 scale; convert to pixels.
            left = int(box[0] / 1000 * width)
            top = int(box[1] / 1000 * height)
            right = int(box[2] / 1000 * width)
            bottom = int(box[3] / 1000 * height)
            draw.rectangle((left, top, right, bottom), outline=rgb, width=3)
            obj_prompt = state.text_prompts_by_frame_obj.get(idx, {}).get(obj_id, "")
            if obj_prompt:
                label = f"{obj_prompt} - ID{obj_id}"
                label_top = max(0, top - label_offset)
                bounds = draw.textbbox((left, label_top), label, font=font)
                draw.rectangle(
                    [(bounds[0] - pad, bounds[1] - pad), (bounds[2] + pad, bounds[3] + pad)],
                    fill=rgb
                )
                draw.text((left, label_top), label, fill="white", font=font)
    state.composited_frames[idx] = rendered
    return rendered
def compose_point_frame(pt_state: PointTrackingState, frame_idx: int, trail_length: int = 12) -> Image.Image:
    """Render frame ``frame_idx`` with fading motion trails and point markers.

    Trails show the last ``trail_length`` frames of each track; the result
    is cached in ``pt_state.composited_frames``. Returns None when no video
    is loaded.
    """
    if pt_state is None or not pt_state.video_frames:
        return None
    idx = int(np.clip(frame_idx, 0, len(pt_state.video_frames) - 1))
    rendered = pt_state.video_frames[idx].copy()
    draw = ImageDraw.Draw(rendered)
    red = (255, 40, 40)
    dark_red = (180, 0, 0)
    for trail in pt_state.trails:
        # Keep only samples from the recent window ending at this frame.
        recent = [(tx, ty) for fi, tx, ty in trail if idx - trail_length < fi <= idx]
        if len(recent) < 2:
            continue
        for seg in range(len(recent) - 1):
            # Older segments fade toward black and get thinner.
            fade = (seg + 1) / len(recent)
            seg_color = (
                int(dark_red[0] * fade),
                int(dark_red[1] * fade),
                int(dark_red[2] * fade)
            )
            seg_width = max(1, int(2 * fade))
            start = (int(recent[seg][0]), int(recent[seg][1]))
            end = (int(recent[seg + 1][0]), int(recent[seg + 1][1]))
            draw.line([start, end], fill=seg_color, width=seg_width)
    for px, py in pt_state.points_by_frame.get(idx, []):
        # White ring, red disc, pale center dot.
        draw.ellipse((px - 10, py - 10, px + 10, py + 10), outline="white", width=2)
        draw.ellipse((px - 7, py - 7, px + 7, py + 7), fill=red, outline=red)
        draw.ellipse((px - 2, py - 2, px + 2, py + 2), fill=(255, 200, 200))
    pt_state.composited_frames[idx] = rendered
    return rendered
def update_tracking_display(state: TrackingState, frame_idx: int) -> Image.Image:
    """Return the cached composited frame, rendering it on a cache miss."""
    if state is None or not state.video_frames:
        return None
    idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
    cached = state.composited_frames.get(idx)
    return cached if cached is not None else compose_tracking_frame(state, idx)
def update_point_display(pt_state: PointTrackingState, frame_idx: int) -> Image.Image:
    """Return the cached point-annotated frame, rendering it on a cache miss."""
    if pt_state is None or not pt_state.video_frames:
        return None
    idx = int(np.clip(frame_idx, 0, len(pt_state.video_frames) - 1))
    cached = pt_state.composited_frames.get(idx)
    return cached if cached is not None else compose_point_frame(pt_state, idx)
def _get_active_prompts_tracking(state: TrackingState) -> str:
    """Markdown summary of the tracking prompts and their object counts."""
    if state is None or not state.prompts:
        return "**Active prompts:** None"
    summary = ", ".join(f"'{p}' ({len(ids)} obj)" for p, ids in state.prompts.items())
    return f"**Active prompts:** {summary}"
def _get_active_prompts_points(pt_state: PointTrackingState) -> str:
    """Markdown summary of the point-tracking prompt and its track count."""
    if pt_state is None or not pt_state.prompt_text:
        return "**Active prompts:** None"
    return f"**Active prompts:** '{pt_state.prompt_text}' ({len(pt_state.trails)} tracked points)"
def init_tracking_video(state: TrackingState, video) -> tuple[TrackingState, int, int, Image.Image, str]:
    """Load a video into fresh tracking state, trimming it to MAX_SECONDS.

    Accepts a path string or a Gradio file dict. Returns
    (state, start_idx, max_idx, first_frame, status). Raises ``gr.Error``
    when no path is given or no frames can be decoded.
    """
    state.reset()
    path = (video.get("name") or video.get("path") or video.get("data")) if isinstance(video, dict) else video
    if not path:
        raise gr.Error("Invalid video input.")
    frames, info = try_load_video_frames(path)
    if not frames:
        raise gr.Error("No frames could be loaded from the video.")
    fps_in = info.get("fps")
    # Without a known FPS the clip cannot be time-trimmed, so keep all frames.
    limit = int(MAX_SECONDS * fps_in) if fps_in else len(frames)
    trimmed_note = ""
    if len(frames) > limit:
        frames = frames[:limit]
        trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
    state.video_frames = frames
    state.video_fps = float(fps_in) if fps_in else None
    status = f"Loaded {len(frames)} frames @ {state.video_fps or 'unknown'} fps{trimmed_note}. Ready for text prompting."
    return state, 0, len(frames) - 1, frames[0], status
def init_point_video(pt_state: PointTrackingState, video) -> tuple[PointTrackingState, int, int, Image.Image, str]:
    """Load a video into fresh point-tracking state, trimming it to MAX_SECONDS.

    Accepts a path string or a Gradio file dict. Returns
    (pt_state, start_idx, max_idx, first_frame, status). Raises ``gr.Error``
    when no path is given or no frames can be decoded.
    """
    pt_state.reset()
    path = (video.get("name") or video.get("path") or video.get("data")) if isinstance(video, dict) else video
    if not path:
        raise gr.Error("Invalid video input.")
    frames, info = try_load_video_frames(path)
    if not frames:
        raise gr.Error("No frames could be loaded from the video.")
    fps_in = info.get("fps")
    # Without a known FPS the clip cannot be time-trimmed, so keep all frames.
    limit = int(MAX_SECONDS * fps_in) if fps_in else len(frames)
    trimmed_note = ""
    if len(frames) > limit:
        frames = frames[:limit]
        trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
    pt_state.video_frames = frames
    pt_state.video_fps = float(fps_in) if fps_in else None
    status = f"Loaded {len(frames)} frames @ {pt_state.video_fps or 'unknown'} fps{trimmed_note}. Ready for point tracking."
    return pt_state, 0, len(frames) - 1, frames[0], status
def apply_tracking_prompt_on_frame(
    state: TrackingState,
    frame_idx: int,
    text_prompt: str,
) -> tuple[Image.Image, str, str, TrackingState]:
    """Detect objects for comma-separated prompts on one frame and register them.

    Each detected box gets a fresh object id, a per-prompt color, a filled
    rectangular mask, and a label. Returns (display_frame, status_text,
    active_prompts_markdown, state).
    """
    if state is None or not state.video_frames:
        return None, "Upload a video first.", "**Active prompts:** None", state
    if not text_prompt or not text_prompt.strip():
        ap = _get_active_prompts_tracking(state)
        return update_tracking_display(state, int(frame_idx)), "Please enter a text prompt.", ap, state
    frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
    frame = state.video_frames[frame_idx]
    w, h = frame.size
    # A comma-separated prompt is treated as several independent prompts.
    prompt_texts = [p.strip() for p in text_prompt.split(",") if p.strip()]
    if not prompt_texts:
        ap = _get_active_prompts_tracking(state)
        return update_tracking_display(state, frame_idx), "Please enter a valid text prompt.", ap, state
    status_parts = [f"Processing on frame {frame_idx}:"]
    for prompt in prompt_texts:
        bboxes = detect_objects_in_frame(frame, prompt)
        # Every prompt keeps one stable color for all of its objects.
        if prompt not in state.color_by_prompt:
            state.color_by_prompt[prompt] = pastel_color_for_prompt(prompt)
        masks_f = state.masks_by_frame.setdefault(frame_idx, {})
        bboxes_f = state.bboxes_by_frame.setdefault(frame_idx, {})
        texts_f = state.text_prompts_by_frame_obj.setdefault(frame_idx, {})
        obj_ids_for_prompt = []
        for bbox in bboxes:
            # Register each detection under a new, never-reused object id.
            oid = state.next_obj_id
            state.next_obj_id += 1
            state.color_by_obj[oid] = state.color_by_prompt[prompt]
            masks_f[oid] = bbox_to_mask(bbox, w, h)
            bboxes_f[oid] = bbox
            texts_f[oid] = prompt
            state.prompts.setdefault(prompt, []).append(oid)
            obj_ids_for_prompt.append(oid)
        if obj_ids_for_prompt:
            ids_str = ", ".join(map(str, obj_ids_for_prompt))
            status_parts.append(f" • '{prompt}': {len(obj_ids_for_prompt)} object(s) (IDs: {ids_str})")
        else:
            status_parts.append(f" • '{prompt}': No objects detected.")
    # Invalidate the render cache for this frame so new boxes show up.
    state.composited_frames.pop(frame_idx, None)
    status = "\n".join(status_parts)
    ap = _get_active_prompts_tracking(state)
    return update_tracking_display(state, frame_idx), status, ap, state
def propagate_tracking(state: TrackingState) -> Iterator[tuple[TrackingState, str, dict]]:
    """Propagate seeded detections to every frame, yielding UI progress.

    For each prompt, the earliest frame that holds its boxes seeds per-object
    tracks. Each other frame is re-detected and tracks are greedily matched
    to new detections by IoU, with a center-distance fallback; unmatched
    detections spawn new object ids. Runs forward from the seed frame, then
    backward, yielding (state, status_message, slider_update) tuples.

    Fix: ``total`` previously bound the ``num_frames`` method instead of
    calling it, so ``range(..., total)`` raised a TypeError. The duplicated
    forward/backward loop bodies are also factored into ``_track_frame``.
    """

    def _match_boxes(prev_tracks, new_bboxes):
        """Greedily match previous tracks to new detections: IoU first, then nearby centers."""
        used, matched = set(), {}
        scores = [
            (bbox_iou(pbbox, nbbox), pi, ni)
            for pi, (_, pbbox) in enumerate(prev_tracks)
            for ni, nbbox in enumerate(new_bboxes)
        ]
        scores.sort(reverse=True)
        for score, pi, ni in scores:
            if pi in matched or ni in used or score <= 0.05:
                continue
            matched[pi] = ni
            used.add(ni)
        # Fallback: pair leftover tracks with the nearest unused detection.
        for pi, (_, pbbox) in enumerate(prev_tracks):
            if pi in matched:
                continue
            best = min(
                ((bbox_center_distance(pbbox, nbbox), ni) for ni, nbbox in enumerate(new_bboxes) if ni not in used),
                default=(float('inf'), -1)
            )
            if best[0] < 300:  # centers within 300 units of the 0-1000 box coordinate space
                matched[pi] = best[1]
                used.add(best[1])
        return matched, used

    def _track_frame(prompt, f_idx, prev_tracks):
        """Detect ``prompt`` on frame ``f_idx``, carry matched ids over, spawn new ones."""
        frame = state.video_frames[f_idx]
        w, h = frame.size
        new_bboxes = detect_objects_in_frame(frame, prompt)
        masks_f = state.masks_by_frame.setdefault(f_idx, {})
        bboxes_f = state.bboxes_by_frame.setdefault(f_idx, {})
        texts_f = state.text_prompts_by_frame_obj.setdefault(f_idx, {})
        matched, used = _match_boxes(prev_tracks, new_bboxes)
        next_tracks = []
        for pi, (oid, _) in enumerate(prev_tracks):
            if pi not in matched:
                continue
            nbbox = new_bboxes[matched[pi]]
            masks_f[oid] = bbox_to_mask(nbbox, w, h)
            bboxes_f[oid] = nbbox
            texts_f[oid] = prompt
            next_tracks.append((oid, nbbox))
        # Detections nobody claimed become brand-new tracked objects.
        for ni, nbbox in enumerate(new_bboxes):
            if ni in used:
                continue
            oid = state.next_obj_id
            state.next_obj_id += 1
            state.color_by_obj[oid] = state.color_by_prompt.get(prompt, pastel_color_for_prompt(prompt))
            masks_f[oid] = bbox_to_mask(nbbox, w, h)
            bboxes_f[oid] = nbbox
            texts_f[oid] = prompt
            state.prompts.setdefault(prompt, []).append(oid)
            next_tracks.append((oid, nbbox))
        # Invalidate the render cache for this frame.
        state.composited_frames.pop(f_idx, None)
        return next_tracks

    if state is None or not state.video_frames:
        yield state, "Load a video first.", gr.update()
        return
    if not state.prompts:
        yield state, "No prompts defined. Apply text prompt(s) on a frame first.", gr.update()
        return
    total = state.num_frames()  # BUG FIX: was `state.num_frames` (bound method, not a count)
    processed = 0
    yield state, f"Propagating: {processed}/{total}", gr.update()
    for prompt, obj_ids in list(state.prompts.items()):
        # The earliest frame that already holds boxes for this prompt seeds the tracks.
        seed_frame_idx = None
        seed_bboxes_by_oid = {}
        for f_idx in sorted(state.bboxes_by_frame.keys()):
            for oid in obj_ids:
                if oid in state.bboxes_by_frame.get(f_idx, {}):
                    if seed_frame_idx is None:
                        seed_frame_idx = f_idx
                    if f_idx == seed_frame_idx:
                        seed_bboxes_by_oid[oid] = state.bboxes_by_frame[f_idx][oid]
        if seed_frame_idx is None:
            continue
        # Forward propagation from the seed frame.
        prev_tracks = [(oid, seed_bboxes_by_oid[oid]) for oid in seed_bboxes_by_oid]
        for f_idx in range(seed_frame_idx + 1, total):
            prev_tracks = _track_frame(prompt, f_idx, prev_tracks)
            processed += 1
            if processed % 5 == 0 or f_idx == total - 1:
                yield state, f"Propagating '{prompt}' (forward): frame {f_idx}/{total}", gr.update(value=f_idx)
        # Backward propagation, re-seeded from the same frame.
        prev_tracks = [(oid, seed_bboxes_by_oid[oid]) for oid in seed_bboxes_by_oid]
        for f_idx in range(seed_frame_idx - 1, -1, -1):
            prev_tracks = _track_frame(prompt, f_idx, prev_tracks)
            processed += 1
            if processed % 5 == 0 or f_idx == 0:
                yield state, f"Propagating '{prompt}' (backward): frame {f_idx}/{total}", gr.update(value=f_idx)
    yield state, f"✅ Propagation complete across {total} frames for {len(state.prompts)} prompt(s).", gr.update(value=0)
def apply_point_prompt_on_frame(
    pt_state: PointTrackingState,
    frame_idx: int,
    text_prompt: str,
) -> tuple[Image.Image, str, str, PointTrackingState]:
    """Detect point targets for comma-separated prompts on one frame.

    Resets any previous point/trail state, seeds one trail per detected
    point on ``frame_idx``, and returns (display_frame, status_text,
    active_prompts_markdown, pt_state).

    Fix: removed the unused ``track_idx`` local in the seeding loop.
    """
    if pt_state is None or not pt_state.video_frames:
        return None, "Upload a video first.", "**Active prompts:** None", pt_state
    if not text_prompt or not text_prompt.strip():
        ap = _get_active_prompts_points(pt_state)
        return update_point_display(pt_state, int(frame_idx)), "Please enter a text prompt.", ap, pt_state
    frame_idx = int(np.clip(frame_idx, 0, len(pt_state.video_frames) - 1))
    frame = pt_state.video_frames[frame_idx]
    # A new prompt replaces all previously tracked points and trails.
    pt_state.prompt_text = text_prompt.strip()
    pt_state.points_by_frame.clear()
    pt_state.trails.clear()
    pt_state.composited_frames.clear()
    prompts = [p.strip() for p in text_prompt.split(",") if p.strip()]
    all_points = []
    status_parts = [f"Point detection on frame {frame_idx}:"]
    for prompt in prompts:
        points = detect_precise_points_in_frame(frame, prompt)
        for pt in points:
            all_points.append(pt)
            # Each seed point starts its own trail beginning at this frame.
            pt_state.trails.append([(frame_idx, pt[0], pt[1])])
        status_parts.append(f" • '{prompt}': {len(points)} point(s)")
    pt_state.points_by_frame[frame_idx] = all_points
    pt_state.composited_frames.pop(frame_idx, None)
    status = "\n".join(status_parts)
    ap = _get_active_prompts_points(pt_state)
    return update_point_display(pt_state, frame_idx), status, ap, pt_state
def _seed_tracks_for_frame(pt_state: PointTrackingState, seed_frame_idx: int) -> list:
    """Return [(trail_idx, (x, y)), ...] for every trail that has a point on seed_frame_idx."""
    seeds = []
    for trail_idx, trail in enumerate(pt_state.trails):
        for fi, tx, ty in trail:
            if fi == seed_frame_idx:
                seeds.append((trail_idx, (tx, ty)))
                break
    return seeds


def _track_points_on_frame(pt_state: PointTrackingState, f_idx: int, prompts: list,
                           prev_tracks: list, lost_count: dict) -> list:
    """Detect and match prompt points on one frame, updating trails in place.

    Re-runs detection for every prompt on frame ``f_idx``, greedily matches
    detections to the previous frame's tracks by pixel distance (threshold:
    25% of the frame diagonal), carries unmatched tracks forward for up to 5
    consecutive missed frames, and spawns new tracks for detections that are
    not close (8% of the diagonal) to any existing track.

    Mutates ``pt_state.trails``, ``pt_state.points_by_frame``,
    ``pt_state.composited_frames`` and ``lost_count``.
    Returns the updated [(track_idx, (x, y)), ...] list for the next frame.
    """
    frame = pt_state.video_frames[f_idx]
    w, h = frame.size
    all_new_points = []
    for prompt in prompts:
        all_new_points.extend(detect_precise_points_in_frame(frame, prompt))
    points_f = []
    diag = (w ** 2 + h ** 2) ** 0.5
    match_threshold = diag * 0.25
    if not all_new_points:
        # Nothing detected: carry every live track forward, dropping tracks
        # that have now been lost for more than 5 frames.
        new_prev = []
        for track_idx, prev_pt in prev_tracks:
            lost_count[track_idx] = lost_count.get(track_idx, 0) + 1
            if lost_count[track_idx] > 5:
                continue
            points_f.append(prev_pt)
            pt_state.trails[track_idx].append((f_idx, prev_pt[0], prev_pt[1]))
            new_prev.append((track_idx, prev_pt))
    else:
        # Greedy nearest-pair matching between previous tracks and detections.
        used_new = set()
        matched = {}
        dist_pairs = []
        for pi, (_, prev_pt) in enumerate(prev_tracks):
            for ni, new_pt in enumerate(all_new_points):
                dist_pairs.append((pixel_point_distance(prev_pt, new_pt), pi, ni))
        dist_pairs.sort()
        for d, pi, ni in dist_pairs:
            if pi in matched or ni in used_new:
                continue
            if d < match_threshold:
                matched[pi] = ni
                used_new.add(ni)
        new_prev = []
        for pi, (track_idx, prev_pt) in enumerate(prev_tracks):
            if pi in matched:
                new_pt = all_new_points[matched[pi]]
                points_f.append(new_pt)
                pt_state.trails[track_idx].append((f_idx, new_pt[0], new_pt[1]))
                new_prev.append((track_idx, new_pt))
                lost_count[track_idx] = 0
            else:
                # Unmatched: keep the last position for up to 5 missed frames.
                lost_count[track_idx] = lost_count.get(track_idx, 0) + 1
                if lost_count[track_idx] <= 5:
                    points_f.append(prev_pt)
                    pt_state.trails[track_idx].append((f_idx, prev_pt[0], prev_pt[1]))
                    new_prev.append((track_idx, prev_pt))
        # Unused detections that are not near an existing track become new tracks.
        for ni, new_pt in enumerate(all_new_points):
            if ni not in used_new:
                too_close = any(pixel_point_distance(new_pt, pp) < diag * 0.08 for _, pp in new_prev)
                if not too_close:
                    track_idx = len(pt_state.trails)
                    pt_state.trails.append([(f_idx, new_pt[0], new_pt[1])])
                    points_f.append(new_pt)
                    new_prev.append((track_idx, new_pt))
                    lost_count[track_idx] = 0
    pt_state.points_by_frame[f_idx] = points_f
    pt_state.composited_frames.pop(f_idx, None)
    return new_prev


def propagate_points(pt_state: PointTrackingState) -> Iterator[tuple[PointTrackingState, str, dict]]:
    """Propagate text-prompted points from the seed frame across the whole video.

    Runs a forward pass (seed+1 .. end) then a backward pass (seed-1 .. 0),
    both restarting from the seed frame's tracks, delegating the per-frame
    detect/match/spawn work to ``_track_points_on_frame``. Yields
    (state, status message, slider update) tuples for streaming UI updates,
    reporting progress every 5 frames and on each pass's last frame.
    """
    if pt_state is None or not pt_state.video_frames:
        yield pt_state, "Load a video first.", gr.update()
        return
    if not pt_state.trails:
        yield pt_state, "No points defined. Apply point prompt on a frame first.", gr.update()
        return
    if not pt_state.prompt_text:
        yield pt_state, "No prompt text. Apply a text prompt first.", gr.update()
        return
    total = pt_state.num_frames
    prompts = [p.strip() for p in pt_state.prompt_text.split(",") if p.strip()]
    # Seed = earliest frame that actually has detected points.
    seed_frame_idx = None
    for f_idx in sorted(pt_state.points_by_frame.keys()):
        if pt_state.points_by_frame[f_idx]:
            seed_frame_idx = f_idx
            break
    if seed_frame_idx is None:
        yield pt_state, "No seed points found.", gr.update()
        return
    yield pt_state, f"Propagating points: 0/{total}", gr.update()
    seed_tracks = _seed_tracks_for_frame(pt_state, seed_frame_idx)
    passes = (
        ("forward", range(seed_frame_idx + 1, total), total - 1),
        ("backward", range(seed_frame_idx - 1, -1, -1), 0),
    )
    for direction, frame_range, last_idx in passes:
        prev_tracks = list(seed_tracks)
        lost_count = {t[0]: 0 for t in prev_tracks}
        for f_idx in frame_range:
            prev_tracks = _track_points_on_frame(pt_state, f_idx, prompts, prev_tracks, lost_count)
            if abs(f_idx - seed_frame_idx) % 5 == 0 or f_idx == last_idx:
                yield pt_state, f"Propagating points ({direction}): frame {f_idx}/{total}", gr.update(value=f_idx)
    yield pt_state, f"✅ Point propagation complete across {total} frames. {len(pt_state.trails)} tracks.", gr.update(value=seed_frame_idx)
def reset_tracking_prompts(state: TrackingState) -> tuple[TrackingState, Image.Image, str, str]:
    """Clear every prompt-derived result from the tracking session, keeping the video.

    Returns (state, preview image, status message, active-prompts markdown).
    """
    no_prompts = "**Active prompts:** None"
    if state is None:
        return state, None, "No active session.", no_prompts
    # Drop all detection outputs and cached renders; the raw frames stay loaded.
    for container in (
        state.masks_by_frame,
        state.bboxes_by_frame,
        state.text_prompts_by_frame_obj,
        state.composited_frames,
        state.color_by_obj,
        state.color_by_prompt,
        state.prompts,
    ):
        container.clear()
    state.next_obj_id = 1
    # Clamp the remembered frame index into the valid range before re-rendering.
    frame_idx = min(getattr(state, 'current_frame_idx', 0), state.num_frames - 1)
    frame_idx = max(0, frame_idx)
    preview = update_tracking_display(state, frame_idx)
    return state, preview, "Prompts and outputs reset. Video preserved.", no_prompts
def reset_tracking_session(state: TrackingState) -> tuple[TrackingState, Image.Image, dict, dict, str, str]:
    """Reset the whole tracking tab: clear prompts/outputs and reset the slider.

    Returns (state, preview, slider min/max update, slider value update,
    status message, active-prompts markdown).
    """
    # Guard None explicitly (sibling reset helpers do too); without it,
    # `state.video_frames` raises AttributeError when no session exists.
    if state is None or not state.video_frames:
        return state, None, gr.update(minimum=0, maximum=0), gr.update(value=0), "Session reset.", "**Active prompts:** None"
    state.masks_by_frame.clear()
    state.bboxes_by_frame.clear()
    state.text_prompts_by_frame_obj.clear()
    state.composited_frames.clear()
    state.color_by_obj.clear()
    state.color_by_prompt.clear()
    state.prompts.clear()
    state.next_obj_id = 1
    gc.collect()  # cleared caches can hold many large frames
    current_idx = max(0, min(getattr(state, 'current_frame_idx', 0), state.num_frames - 1))
    preview = update_tracking_display(state, current_idx)
    return (
        state, preview,
        gr.update(minimum=0, maximum=max(state.num_frames - 1, 0), interactive=True),
        gr.update(value=current_idx),
        "Session reset. Prompts cleared; video preserved.",
        "**Active prompts:** None"
    )
def reset_point_prompts(pt_state: PointTrackingState) -> tuple[PointTrackingState, Image.Image, str, str]:
    """Clear point prompts, trails and cached renders, keeping the loaded video.

    Returns (state, preview image, status message, active-prompts markdown).
    """
    no_prompts = "**Active prompts:** None"
    if pt_state is None:
        return pt_state, None, "No active session.", no_prompts
    for container in (pt_state.points_by_frame, pt_state.trails, pt_state.composited_frames):
        container.clear()
    pt_state.prompt_text = ""
    # Re-render the last-viewed frame, clamped to the valid index range.
    frame_idx = min(getattr(pt_state, 'current_frame_idx', 0), pt_state.num_frames - 1)
    frame_idx = max(0, frame_idx)
    preview = update_point_display(pt_state, frame_idx)
    return pt_state, preview, "Point prompts reset. Video preserved.", no_prompts
def reset_point_session(pt_state: PointTrackingState) -> tuple[PointTrackingState, Image.Image, dict, dict, str, str]:
    """Reset the whole points tab: clear prompts/trails and reset the slider.

    Returns (state, preview, slider min/max update, slider value update,
    status message, active-prompts markdown).
    """
    # Guard None explicitly (sibling reset helpers do too); without it,
    # `pt_state.video_frames` raises AttributeError when no session exists.
    if pt_state is None or not pt_state.video_frames:
        return pt_state, None, gr.update(minimum=0, maximum=0), gr.update(value=0), "Session reset.", "**Active prompts:** None"
    pt_state.points_by_frame.clear()
    pt_state.trails.clear()
    pt_state.composited_frames.clear()
    pt_state.prompt_text = ""
    gc.collect()  # cleared caches can hold many large frames
    current_idx = max(0, min(getattr(pt_state, 'current_frame_idx', 0), pt_state.num_frames - 1))
    preview = update_point_display(pt_state, current_idx)
    return (
        pt_state, preview,
        gr.update(minimum=0, maximum=max(pt_state.num_frames - 1, 0), interactive=True),
        gr.update(value=current_idx),
        "Session reset. Video preserved.",
        "**Active prompts:** None"
    )
def render_tracking_video(state: TrackingState) -> str:
    """Render every (composited) tracking frame to an MP4 and return its path.

    Raises:
        gr.Error: if no video is loaded.
    """
    if state is None or state.num_frames == 0:
        raise gr.Error("Load a video first.")
    # Fall back to 12 fps when the source FPS is unknown or invalid.
    fps = state.video_fps if state.video_fps and state.video_fps > 0 else 12
    w, h = state.video_frames[0].size
    # Only reserve the filename here; cv2 opens the path itself.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name
    writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    try:
        # Stream each frame straight into the writer instead of first
        # buffering the entire video as BGR arrays in memory.
        for idx in range(state.num_frames):
            img = state.composited_frames.get(idx)
            if img is None:
                img = compose_tracking_frame(state, idx)
            # PIL yields RGB; give cv2 a contiguous BGR array.
            writer.write(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
            if (idx + 1) % 60 == 0:
                gc.collect()
    finally:
        writer.release()  # always finalize the container, even on error
    return out_path
def render_point_video(pt_state: PointTrackingState) -> str:
    """Render every (composited) point-tracking frame to an MP4 and return its path.

    Raises:
        gr.Error: if no video is loaded.
    """
    if pt_state is None or pt_state.num_frames == 0:
        raise gr.Error("Load a video first.")
    # Fall back to 12 fps when the source FPS is unknown or invalid.
    fps = pt_state.video_fps if pt_state.video_fps and pt_state.video_fps > 0 else 12
    w, h = pt_state.video_frames[0].size
    # Only reserve the filename here; cv2 opens the path itself.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name
    writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    try:
        # Stream each frame straight into the writer instead of first
        # buffering the entire video as BGR arrays in memory.
        for idx in range(pt_state.num_frames):
            img = pt_state.composited_frames.get(idx)
            if img is None:
                img = compose_point_frame(pt_state, idx)
            # PIL yields RGB; give cv2 a contiguous BGR array.
            writer.write(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
            if (idx + 1) % 60 == 0:
                gc.collect()
    finally:
        writer.release()  # always finalize the container, even on error
    return out_path
def _on_video_change_tracking(state: TrackingState, video) -> tuple[TrackingState, dict, Image.Image, str, str]:
    """Handle upload/removal of a video in the tracking tab.

    Returns (state, slider update, first frame, status text, active-prompts markdown).
    """
    if video is None:
        # Upload cleared: leave the slider untouched and blank the preview.
        return state, gr.update(), None, "", "**Active prompts:** None"
    state, min_idx, max_idx, first_frame, status = init_tracking_video(state, video)
    slider_update = gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True)
    return state, slider_update, first_frame, status, _get_active_prompts_tracking(state)
def _on_video_change_points(pt_state: PointTrackingState, video) -> tuple[PointTrackingState, dict, Image.Image, str, str]:
    """Handle upload/removal of a video in the points tab.

    Returns (state, slider update, first frame, status text, active-prompts markdown).
    """
    if video is None:
        # Upload cleared: leave the slider untouched and blank the preview.
        return pt_state, gr.update(), None, "", "**Active prompts:** None"
    pt_state, min_idx, max_idx, first_frame, status = init_point_video(pt_state, video)
    slider_update = gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True)
    return pt_state, slider_update, first_frame, status, _get_active_prompts_points(pt_state)
def process_image_detection(image: Image.Image, prompt: str) -> tuple[Image.Image, str]:
    """Detect all objects matching `prompt` in `image`.

    Returns (annotated image, result text). The result text is JSON with
    coordinates normalized to 0-1 (the model reports on a 0-1000 grid), or a
    diagnostic message including the raw model output when nothing parsed.

    Raises:
        gr.Error: on missing image or empty prompt.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please provide a detection prompt.")
    image = image.convert("RGB")
    image.thumbnail((1024, 1024))  # cap resolution to bound inference cost
    full_prompt = f"Provide bounding box coordinates for {prompt}. Report in JSON format."
    output_text = run_model_inference(image, full_prompt)
    parsed_json = safe_parse_json(output_text)
    # Normalize: the model may emit a single dict or a list of dicts.
    if isinstance(parsed_json, dict):
        items = [parsed_json]
    elif isinstance(parsed_json, list):
        items = parsed_json
    else:
        items = []
    objects_result = {"objects": []}
    for item in items:
        # Skip non-dict items and malformed bbox entries instead of raising.
        if not isinstance(item, dict):
            continue
        bbox = item.get("bbox_2d")
        if not isinstance(bbox, (list, tuple)) or len(bbox) != 4:
            continue
        xmin, ymin, xmax, ymax = bbox
        objects_result["objects"].append({
            "x_min": xmin / 1000.0,
            "y_min": ymin / 1000.0,
            "x_max": xmax / 1000.0,
            "y_max": ymax / 1000.0,
            "label": item.get("label", "object"),
        })
    if not objects_result["objects"]:
        # Fallback: scrape raw [xmin, ymin, xmax, ymax] lists from the text.
        for bbox in parse_bboxes_from_text(output_text):
            objects_result["objects"].append({
                "x_min": bbox[0] / 1000.0,
                "y_min": bbox[1] / 1000.0,
                "x_max": bbox[2] / 1000.0,
                "y_max": bbox[3] / 1000.0,
                "label": prompt.strip(),
            })
    annotated = annotate_image_detection(image.copy(), objects_result)
    result_text = json.dumps(objects_result, indent=2) if objects_result["objects"] else f"No objects detected for '{prompt}'.\n\nRaw output:\n{output_text}"
    return annotated, result_text
def process_image_pointer(image: Image.Image, prompt: str) -> tuple[Image.Image, str]:
    """Point to all objects matching `prompt` in `image`.

    Returns (annotated image, result text). The result text is JSON with
    point coordinates normalized to 0-1 (the model reports on a 0-1000 grid),
    or a diagnostic message with the raw model output when nothing parsed.

    Raises:
        gr.Error: on missing image or empty prompt.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please provide a pointing prompt.")
    image = image.convert("RGB")
    image.thumbnail((1024, 1024))  # cap resolution to bound inference cost
    original_width, original_height = image.size
    full_prompt = f"Provide 2d point coordinates for {prompt}. Report in JSON format."
    output_text = run_model_inference(image, full_prompt)
    parsed_json = safe_parse_json(output_text)
    # Normalize: the model may emit a single dict or a list of dicts.
    if isinstance(parsed_json, dict):
        items = [parsed_json]
    elif isinstance(parsed_json, list):
        items = parsed_json
    else:
        items = []
    points_result = {"points": []}
    for item in items:
        # Skip non-dict items and malformed point entries instead of raising.
        if not isinstance(item, dict):
            continue
        pt = item.get("point_2d")
        if not isinstance(pt, (list, tuple)) or len(pt) != 2:
            continue
        x, y = pt
        points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0})
    if not points_result["points"]:
        # Fallback 1: pixel-space points parsed from the raw text.
        for px, py in parse_precise_points(output_text, original_width, original_height):
            points_result["points"].append({
                "x": px / original_width,
                "y": py / original_height,
            })
    if not points_result["points"]:
        # Fallback 2: centers of any bounding boxes found in the text.
        for bbox in parse_bboxes_from_text(output_text):
            points_result["points"].append({
                "x": (bbox[0] + bbox[2]) / 2 / 1000.0,
                "y": (bbox[1] + bbox[3]) / 2 / 1000.0,
            })
    annotated = annotate_image_points(image.copy(), points_result)
    result_text = json.dumps(points_result, indent=2) if points_result["points"] else f"No points detected for '{prompt}'.\n\nRaw output:\n{output_text}"
    return annotated, result_text
# Global CSS overrides for the Gradio UI (centered layout, larger title).
# NOTE(review): this is passed to `launch(css=...)` at the bottom of the file;
# Gradio conventionally takes `css` on `gr.Blocks(...)` — confirm it is
# actually applied in the Gradio version this Space runs.
css = """
#col-container {
    margin: 0 auto;
    max-width: 900px;
}
#main-title h1 { font-size: 2.6em !important; }
"""
# ---------------------------------------------------------------------------
# Gradio UI: four tabs (video object tracking, video point tracking,
# image detection, image pointing) plus the event wiring that connects the
# widgets to the handlers defined above.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# **Qwen3-VL-Video-Grounding**", elem_id="main-title")
    gr.Markdown(
        """
        Perform text-guided object tracking, point tracking, image detection, and image pointing with the Qwen3-VL multimodal model.
        **Video tabs:** Upload a video → Select a frame → Apply text prompt(s) → Preview → Propagate → Render MP4.
        **Image tabs:** Upload an image → Enter prompt → Get instant detection or pointing results.
        Due to compute constraints, this app only supports stop-frame object detection or tracking on propagated frames. For dense-frame full video processing, please visit the [GitHub](https://github.com/PRITHIVSAKTHIUR/Qwen3-VL-Video-Grounding) page.
        """
    )
    # Per-session state objects (one per video tab).
    tracking_state = gr.State(TrackingState())
    point_state = gr.State(PointTrackingState())
    with gr.Tabs() as main_tabs:
        # ------------------------- Tab 1: box tracking -------------------------
        with gr.Tab("Video Object Tracking"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """
                        **Quick start**
                        - **Load a video**: Upload your own or pick an example below.
                        - Select a frame and enter text description(s) to detect objects (e.g., "red car", "person"). Multiple prompts separated by commas.
                        - The text prompt detects all instances on the **selected frame only**.
                        """
                    )
                with gr.Column():
                    gr.Markdown(
                        """
                        **Working with results**
                        - **Preview**: Use the slider to navigate frames and see the current masks/bboxes.
                        - **Propagate**: Click "Propagate across video" to track all defined objects through every frame.
                        - **Export**: Render an MP4 for smooth playback using the original video FPS.
                        """
                    )
            with gr.Row():
                with gr.Column(scale=1):
                    video_in_tracking = gr.Video(label="Upload video", sources=["upload", "webcam"])
                    load_status_tracking = gr.Markdown(visible=True)
                    reset_btn_tracking = gr.Button("Reset Session", variant="secondary")
                with gr.Column(scale=2):
                    preview_tracking = gr.Image(label="Preview")
                    with gr.Row():
                        frame_slider_tracking = gr.Slider(label="Frame", minimum=0, maximum=0, step=1, value=0)
                        with gr.Column(scale=0):
                            propagate_btn_tracking = gr.Button("Propagate across video", variant="primary")
                    propagate_status_tracking = gr.Markdown(visible=True)
                    with gr.Row():
                        text_prompt_tracking = gr.Textbox(
                            label="Text Prompt(s)",
                            placeholder="Enter text description(s) (e.g., 'person' or 'person, car, dog' for multiple)",
                            lines=2,
                        )
                        with gr.Column(scale=0):
                            apply_btn_tracking = gr.Button("Apply Text Prompt(s)", variant="primary")
                            reset_prompts_btn_tracking = gr.Button("Reset Prompts", variant="secondary")
                    active_prompts_tracking = gr.Markdown("**Active prompts:** None", visible=True)
                    text_status_tracking = gr.Markdown(visible=True)
            with gr.Row():
                render_btn_tracking = gr.Button("Render MP4 for smooth playback", variant="primary")
            playback_video_tracking = gr.Video(label="Rendered Playback", interactive=False)
            gr.Examples(
                examples=[
                    ["examples/1.mp4"],
                    ["examples/2.mp4"],
                    ["examples/3.mp4"],
                ],
                inputs=[video_in_tracking],
                label="Examples"
            )
        # ------------------------- Tab 2: point tracking -----------------------
        with gr.Tab("Video Points Tracker"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """
                        **Quick start**
                        - **Load a video**: Upload your own or pick an example below.
                        - Select a frame and enter text description(s) (e.g., `person, ball`).
                        - The model locates the **center point** of each detected object on the **selected frame only**.
                        """
                    )
                with gr.Column():
                    gr.Markdown(
                        """
                        **Working with results**
                        - **Preview**: Use the slider to see detected points and motion trails.
                        - **Propagate**: Click "Propagate across video" to track points through every frame.
                        - **Export**: Render an MP4 with red dot tracking and trails.
                        """
                    )
            with gr.Row():
                with gr.Column(scale=1):
                    video_in_points = gr.Video(label="Upload video", sources=["upload", "webcam"])
                    load_status_points = gr.Markdown(visible=True)
                    reset_btn_points = gr.Button("Reset Session", variant="secondary")
                with gr.Column(scale=2):
                    preview_points = gr.Image(label="Preview")
                    with gr.Row():
                        frame_slider_points = gr.Slider(label="Frame", minimum=0, maximum=0, step=1, value=0)
                        with gr.Column(scale=0):
                            propagate_btn_points = gr.Button("Propagate across video", variant="primary")
                    propagate_status_points = gr.Markdown(visible=True)
                    with gr.Row():
                        text_prompt_points = gr.Textbox(
                            label="Text Prompt(s)",
                            placeholder="Enter text description(s) (e.g., 'person' or 'person, ball' for multiple)",
                            lines=2,
                        )
                        with gr.Column(scale=0):
                            apply_btn_points = gr.Button("Apply Point Prompt(s)", variant="primary")
                            reset_prompts_btn_points = gr.Button("Reset Prompts", variant="secondary")
                    active_prompts_points = gr.Markdown("**Active prompts:** None", visible=True)
                    text_status_points = gr.Markdown(visible=True)
            with gr.Row():
                render_btn_points = gr.Button("Render MP4 for smooth playback", variant="primary")
            playback_video_points = gr.Video(label="Rendered Playback", interactive=False)
            gr.Examples(
                examples=[
                    ["examples/1.mp4"],
                    ["examples/2.mp4"],
                    ["examples/3.mp4"],
                ],
                inputs=[video_in_points],
                label="Examples"
            )
        # ------------------------- Tab 3: image detection ----------------------
        with gr.Tab("Image Detection"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """
                        **Image Object Detection**
                        - Upload an image and enter what you want to detect.
                        - The model returns bounding boxes around all matching objects.
                        - Results are displayed as colored boxes with labels.
                        """
                    )
                with gr.Column():
                    gr.Markdown(
                        """
                        **Tips**
                        - Be specific: "red car" works better than just "car" if you want a specific one.
                        - You can detect multiple types: try "person", "headlight", "window", etc.
                        - The JSON output shows normalized coordinates (0-1 range).
                        """
                    )
            with gr.Row():
                with gr.Column(scale=1):
                    img_det_input = gr.Image(type="pil", label="Upload Image", height=400)
                    img_det_prompt = gr.Textbox(
                        label="Detection Prompt",
                        placeholder="e.g., headlight, person, red car, laptop",
                        lines=2,
                    )
                    img_det_btn = gr.Button("Detect Objects", variant="primary")
                with gr.Column(scale=1):
                    img_det_output = gr.Image(label="Detection Result", height=400)
                    img_det_text = gr.Textbox(label="Detection Output (JSON)", lines=10, interactive=False)
            gr.Examples(
                examples=[
                    ["examples-images/5.jpg", "children"],
                    ["examples-images/4.jpg", "headlight"],
                    ["examples-images/3.jpg", "gun"],
                    ["examples-images/1.jpg", "boat"],
                ],
                inputs=[img_det_input, img_det_prompt],
                label="Examples"
            )
        # ------------------------- Tab 4: image pointing -----------------------
        with gr.Tab("Image Pointer"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """
                        **Image Point Detection**
                        - Upload an image and describe what you want to point to.
                        - The model returns precise center-point coordinates for each matching object.
                        - Results are displayed as red dots on the image.
                        """
                    )
                with gr.Column():
                    gr.Markdown(
                        """
                        **Tips**
                        - Great for locating specific parts: "the gun held by the person", "nose of the dog".
                        - Multiple instances are supported: all matching objects get a point.
                        - The JSON output shows normalized coordinates (0-1 range).
                        """
                    )
            with gr.Row():
                with gr.Column(scale=1):
                    img_pt_input = gr.Image(type="pil", label="Upload Image", height=400)
                    img_pt_prompt = gr.Textbox(
                        label="Pointing Prompt",
                        placeholder="e.g., the gun held by the person, nose of the dog",
                        lines=2,
                    )
                    img_pt_btn = gr.Button("Point to Objects", variant="primary")
                with gr.Column(scale=1):
                    img_pt_output = gr.Image(label="Pointing Result", height=400)
                    img_pt_text = gr.Textbox(label="Points Output (JSON)", lines=10, interactive=False)
            gr.Examples(
                examples=[
                    ["examples-images/5.jpg", "children who are out of focus and wearing a white T-shirt"],
                    ["examples-images/3.jpg", "gun"],
                    ["examples-images/4.jpg", "headlight"],
                    ["examples-images/1.jpg", "boat"],
                ],
                inputs=[img_pt_input, img_pt_prompt],
                label="Examples"
            )
    # -------------------- Event wiring: object-tracking tab --------------------
    video_in_tracking.change(
        fn=_on_video_change_tracking,
        inputs=[tracking_state, video_in_tracking],
        outputs=[tracking_state, frame_slider_tracking, preview_tracking, load_status_tracking, active_prompts_tracking],
        show_progress=True,
    )
    def _sync_tracking_frame(state_in: TrackingState, idx: int) -> Image.Image:
        """Remember the slider position on the state and re-render that frame."""
        if state_in is not None:
            state_in.current_frame_idx = int(idx)
        return update_tracking_display(state_in, int(idx))
    frame_slider_tracking.change(
        fn=_sync_tracking_frame,
        inputs=[tracking_state, frame_slider_tracking],
        outputs=preview_tracking,
    )
    apply_btn_tracking.click(
        fn=apply_tracking_prompt_on_frame,
        inputs=[tracking_state, frame_slider_tracking, text_prompt_tracking],
        outputs=[preview_tracking, text_status_tracking, active_prompts_tracking, tracking_state],
    )
    propagate_btn_tracking.click(
        fn=propagate_tracking,
        inputs=tracking_state,
        outputs=[tracking_state, propagate_status_tracking, frame_slider_tracking],
    )
    reset_prompts_btn_tracking.click(
        fn=reset_tracking_prompts,
        inputs=tracking_state,
        outputs=[tracking_state, preview_tracking, text_status_tracking, active_prompts_tracking],
    )
    # NOTE(review): frame_slider_tracking appears twice in outputs (one update
    # for min/max, one for value) — some Gradio versions reject duplicate
    # output components; confirm against the deployed Gradio version.
    reset_btn_tracking.click(
        fn=reset_tracking_session,
        inputs=tracking_state,
        outputs=[tracking_state, preview_tracking, frame_slider_tracking, frame_slider_tracking, load_status_tracking, active_prompts_tracking],
    )
    render_btn_tracking.click(
        fn=render_tracking_video,
        inputs=tracking_state,
        outputs=playback_video_tracking,
    )
    # -------------------- Event wiring: points-tracking tab --------------------
    video_in_points.change(
        fn=_on_video_change_points,
        inputs=[point_state, video_in_points],
        outputs=[point_state, frame_slider_points, preview_points, load_status_points, active_prompts_points],
        show_progress=True,
    )
    def _sync_point_frame(state_in: PointTrackingState, idx: int) -> Image.Image:
        """Remember the slider position on the state and re-render that frame."""
        if state_in is not None:
            state_in.current_frame_idx = int(idx)
        return update_point_display(state_in, int(idx))
    frame_slider_points.change(
        fn=_sync_point_frame,
        inputs=[point_state, frame_slider_points],
        outputs=preview_points,
    )
    apply_btn_points.click(
        fn=apply_point_prompt_on_frame,
        inputs=[point_state, frame_slider_points, text_prompt_points],
        outputs=[preview_points, text_status_points, active_prompts_points, point_state],
    )
    propagate_btn_points.click(
        fn=propagate_points,
        inputs=point_state,
        outputs=[point_state, propagate_status_points, frame_slider_points],
    )
    reset_prompts_btn_points.click(
        fn=reset_point_prompts,
        inputs=point_state,
        outputs=[point_state, preview_points, text_status_points, active_prompts_points],
    )
    reset_btn_points.click(
        fn=reset_point_session,
        inputs=point_state,
        outputs=[point_state, preview_points, frame_slider_points, frame_slider_points, load_status_points, active_prompts_points],
    )
    render_btn_points.click(
        fn=render_point_video,
        inputs=point_state,
        outputs=playback_video_points,
    )
    # -------------------- Event wiring: image tabs --------------------
    img_det_btn.click(
        fn=process_image_detection,
        inputs=[img_det_input, img_det_prompt],
        outputs=[img_det_output, img_det_text],
        show_progress=True,
    )
    img_pt_btn.click(
        fn=process_image_pointer,
        inputs=[img_pt_input, img_pt_prompt],
        outputs=[img_pt_output, img_pt_text],
        show_progress=True,
    )
# NOTE(review): `theme` and `css` are conventionally `gr.Blocks(...)` arguments,
# not `launch(...)` arguments — verify they take effect in this Gradio version.
demo.queue(api_open=False).launch(theme=Soft(primary_hue="orange", secondary_hue="rose"), css=css, ssr_mode=False)