#!/usr/bin/env python3
# -*- coding: utf-8 -*-
r"""
Inference example for Qwen2.5-VL TRASER model.

Usage:
    python inference.py \
        --model_path . \
        --video_path /path/to/video.mp4 \
        --mask_path /path/to/mask.json \
        --out_dir ./output

Optional flags: --max_objects, --processor_path, --max_new_tokens, --seed,
--obj_traj_start_id, --obj_traj_end_id (see ``main``).
"""

import argparse
import json
import math
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from pycocotools import mask as maskUtils
from transformers import AutoProcessor, AutoTokenizer

# Import Custom Model
from modeling_traser import TRASER

# Import Utils
from qwen_vl_vsg_utils.src.qwen_vl_utils import process_vision_info
from resampler_utils.token_selection import select_tokens
from resampler_utils.token_arrangement import rearrange_token


def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def load_mask_data(mask_json_path):
    """Load per-frame object masks from a JSON file.

    The rest of this script indexes the result as
    ``mask_data[frame_idx][obj_id]``, where each entry is either falsy
    (object absent in that frame) or a COCO RLE dict with ``"size"`` and
    ``"counts"`` keys.
    """
    with open(mask_json_path, "r", encoding="utf-8") as f:
        return json.load(f)


def has_any_mask(mask_data, obj_id):
    """Return True if ``obj_id`` has a non-empty RLE in at least one frame.

    Frames that are empty, or that have no slot ``obj_id``, are skipped.
    """
    for frame in mask_data:
        if not frame or obj_id >= len(frame):
            continue
        if frame[obj_id] and frame[obj_id].get("counts"):
            return True
    return False


def _select_object_ids(mask_data, max_objects):
    """Pick up to ``max_objects`` object ids that have a mask in some frame.

    Candidate ids are ``0 .. K-1``, where ``K`` is the largest number of object
    slots found in any frame. Objects occupying slots beyond the first frame's
    length are therefore still considered. When more than ``max_objects`` ids
    qualify, a random subset is drawn from the global ``random`` state, which
    is reproducible once ``set_seed`` has been called. The result is sorted
    ascending.
    """
    num_slots = max((len(frame) for frame in mask_data if frame), default=0)
    eligible = [oid for oid in range(num_slots) if has_any_mask(mask_data, oid)]
    if len(eligible) > max_objects:
        random.shuffle(eligible)
        eligible = eligible[:max(max_objects, 0)]
    return sorted(eligible)


def build_obj_masks_tensor(mask_data, obj_ids, sampled_idx, H_rz, W_rz, device):
    """Rasterize object masks for the sampled frames at a fixed resolution.

    Args:
        mask_data: Per-frame mask annotations as returned by ``load_mask_data``.
        obj_ids: Object ids (slot indices within each frame) to rasterize.
        sampled_idx: Source-frame index for each sampled video frame.
        H_rz, W_rz: Target mask height / width in pixels.
        device: Torch device for the output tensor.

    Returns:
        ``(obj_masks, keep_idx)``.
        ``obj_masks`` is a float32 tensor of shape
        ``(len(keep_idx), len(sampled_idx), H_rz, W_rz)`` with values in {0, 1}.
        ``keep_idx`` lists the positions in ``obj_ids`` whose mask is non-empty
        in at least one sampled frame; objects without any pixels are dropped.
        Frames or objects missing from ``mask_data`` contribute all-zero masks.
    """
    O, N = len(obj_ids), len(sampled_idx)
    obj_masks = torch.zeros((O, N, H_rz, W_rz), dtype=torch.float32, device=device)

    for o_i, oid in enumerate(obj_ids):
        for n_idx, fidx in enumerate(sampled_idx):
            if fidx >= len(mask_data):
                continue
            frame_objs = mask_data[fidx]
            if not frame_objs or oid >= len(frame_objs):
                continue
            rle = frame_objs[oid]
            if not rle:
                continue
            m = maskUtils.decode({"size": rle["size"], "counts": rle["counts"]})
            if m.ndim == 3:
                m = m[:, :, 0]
            m_t = torch.from_numpy(m.astype(np.uint8)).unsqueeze(0).unsqueeze(0).float().to(device)
            # Nearest-neighbour resizing keeps the mask binary.
            m_rz = F.interpolate(m_t, size=(H_rz, W_rz), mode="nearest")[0, 0]
            obj_masks[o_i, n_idx] = (m_rz > 0.5).float()

    # Use flatten() rather than view(O, -1): view cannot infer the -1 dimension
    # of a zero-element tensor (O == 0 or N == 0) and would raise.
    has_pixels = obj_masks.flatten(start_dim=1).sum(dim=1) > 0
    keep_idx = has_pixels.nonzero(as_tuple=False).squeeze(1).tolist()
    if len(keep_idx) < O:
        obj_masks = obj_masks[has_pixels]
    return obj_masks, keep_idx


def _object_label_token_ids(tokenizer, num_objects):
    """Tokenize the labels ``"Object 1: "``, ``"Object 2: "``, ... (no special tokens).

    Returns one LongTensor of token ids per object.
    """
    label_template = "Object {i}: "
    additional_texts = [label_template.format(i=(k + 1)) for k in range(num_objects)]
    enc = tokenizer(additional_texts, add_special_tokens=False)["input_ids"]
    return [torch.tensor(x, dtype=torch.long) for x in enc]


def _timestamp_token_ids(tokenizer, num_grids, seconds_per_grid=2.0, window_seconds=4.0):
    """Tokenize one ``"<start - end sec>"`` marker per temporal window.

    Each window spans ``int(window_seconds / seconds_per_grid)`` temporal grids.
    ``ceil(num_grids / grids_per_window)`` windows are emitted.

    Returns:
        ``(token_ids, grids_per_window)``, where ``token_ids`` holds one
        LongTensor per window.
    """
    grids_per_window = int(window_seconds / seconds_per_grid)
    num_windows = math.ceil(num_grids / grids_per_window)
    temporal_text_list = []
    for w_id in range(num_windows):
        s, e = w_id * window_seconds, (w_id + 1) * window_seconds
        temporal_text_list.append(f"<{int(s)} - {int(e)} sec>")
    enc_ts = tokenizer(temporal_text_list, add_special_tokens=False)["input_ids"]
    return [torch.tensor(x, dtype=torch.long) for x in enc_ts], grids_per_window


def run_single_video(model, processor, video_path, mask_path, out_dir, device, args):
    """Generate a video scene graph for one video with TRASER.

    Pipeline:
      1. Select objects that have masks.
      2. Run the Qwen2.5-VL processor on the video.
      3. Rasterize the object masks at the processor's patch-grid resolution.
      4. Select per-object visual tokens via ``select_tokens``.
      5. Rearrange the input embeddings with object-label and timestamp
         tokens via ``rearrange_token``.
      6. Generate.

    Reads ``args.max_objects``, ``args.obj_traj_start_id`` and
    ``args.obj_traj_end_id``. ``args.max_new_tokens`` is read if present
    (default 8192). The decoded text is printed and, when ``out_dir`` is
    truthy, written to ``<out_dir>/output.txt`` (the directory must exist).
    """
    mask_data = load_mask_data(mask_path)
    selected_obj_ids = _select_object_ids(mask_data, args.max_objects)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {"type": "text", "text": "Output the video Scene Graph from the video and object trajectories:\n"},
            {"type": "video", "video": video_path},
        ]},
    ]
    prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # NOTE(review): the third value returned here (fps) is not used; the
    # processor is called with fps=1 and the timestamp math below assumes
    # 2 seconds per temporal grid. Confirm these agree with the frame sampling.
    image_inputs, video_inputs, _fps, selected_frame_idx = process_vision_info(
        messages, return_video_kwargs=True
    )
    proc_inputs = processor(
        text=[prompt_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
        fps=1,
    ).to(device)

    video_grid_thw = proc_inputs["video_grid_thw"]
    if isinstance(video_grid_thw, list):
        video_grid_thw = torch.stack([x.to(device) for x in video_grid_thw])
    else:
        video_grid_thw = video_grid_thw.to(device)

    T_grid = int(video_grid_thw[0, 0].item())
    H_patch, W_patch = int(video_grid_thw[0, 1].item()), int(video_grid_thw[0, 2].item())

    # Calculate mask resize dimensions: patch grid size times pixels per patch.
    patch_size = 14
    H_rz, W_rz = H_patch * patch_size, W_patch * patch_size

    # Build masks for the frames the vision pipeline sampled (first/only video).
    sampled_idx = selected_frame_idx[0]
    obj_masks, keep_idx = build_obj_masks_tensor(mask_data, selected_obj_ids, sampled_idx, H_rz, W_rz, device)
    selected_obj_ids = [selected_obj_ids[i] for i in keep_idx]
    if not selected_obj_ids:
        print(f"Warning: no object in {mask_path} has a non-empty mask in the sampled frames.")

    # Select tokens (the union indices returned first are not needed here).
    _per_union_idx, per_obj_idx, _ = select_tokens(
        obj_masks=obj_masks,
        grid_thw=(T_grid, H_patch, W_patch),
        patch_size=patch_size,
        device=device,
    )

    # Batch of one sample.
    per_obj_idx_batch = [per_obj_idx]
    text_token_ids_per_sample = [_object_label_token_ids(processor.tokenizer, len(per_obj_idx))]
    ts_token_ids, grids_per_window = _timestamp_token_ids(processor.tokenizer, T_grid)
    timestamp_token_ids_per_batch = [ts_token_ids]
    grids_per_window_batch = [grids_per_window]

    # Rearrange and generate.
    with torch.no_grad():
        new_emb, new_pid, new_mask, rope_deltas, _cache_pos, _, _ = rearrange_token(
            model=model,
            input_ids=proc_inputs["input_ids"],
            attention_mask=proc_inputs["attention_mask"],
            pixel_values_videos=proc_inputs["pixel_values_videos"],
            video_grid_thw=video_grid_thw,
            image_grid_thw=None,
            pixel_values=None,
            second_per_grid_ts=None,
            obj_token_indices_per_sample=per_obj_idx_batch,
            obj_traj_start_id=args.obj_traj_start_id,
            obj_traj_end_id=args.obj_traj_end_id,
            text_token_ids_per_sample=text_token_ids_per_sample,
            timestamp_token_ids_per_batch=timestamp_token_ids_per_batch,
            grids_per_temporal_window_per_batch=grids_per_window_batch,
        )
        gen_out = model.generate(
            inputs_embeds=new_emb,
            position_ids=new_pid,
            attention_mask=new_mask.long(),
            rope_deltas=rope_deltas,
            max_new_tokens=getattr(args, "max_new_tokens", 8192),
            # With temperature this close to 0, sampling is effectively greedy.
            do_sample=True,
            top_p=0.9,
            temperature=1e-6,
            repetition_penalty=1.05,
        )

    decoded = processor.tokenizer.decode(gen_out[0], skip_special_tokens=True)
    print(f"Generated Output:\n{decoded}")

    if out_dir:
        with open(os.path.join(out_dir, "output.txt"), "w", encoding="utf-8") as f:
            f.write(decoded)


def main():
    """Parse CLI arguments, load the model and processor, and run one video."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="Path to model or HF repo")
    parser.add_argument("--video_path", type=str, required=True)
    parser.add_argument("--mask_path", type=str, required=True)
    parser.add_argument("--out_dir", type=str, default="./output")
    parser.add_argument("--max_objects", type=int, default=40)
    # NOTE(review): presumably the ids of the object-trajectory start/end
    # special tokens in the TRASER tokenizer — confirm against the tokenizer.
    parser.add_argument("--obj_traj_start_id", type=int, default=151665)
    parser.add_argument("--obj_traj_end_id", type=int, default=151666)
    parser.add_argument(
        "--processor_path",
        type=str,
        default="Qwen/Qwen2.5-VL-3B-Instruct",
        help="Path or HF repo to load the base AutoProcessor from",
    )
    parser.add_argument("--max_new_tokens", type=int, default=8192)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    set_seed(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    # Load the model with the TRASER class directly.
    # Note: if trust_remote_code=True works, you can use AutoModel instead.
    # This example loads TRASER explicitly so that it works with local weights.
    model = TRASER.from_pretrained(args.model_path, torch_dtype=torch.bfloat16).to(device)
    model.eval()

    # The base processor comes from the Qwen repo; its tokenizer is replaced
    # by the one shipped with the TRASER weights.
    processor = AutoProcessor.from_pretrained(args.processor_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    processor.tokenizer = tokenizer

    run_single_video(model, processor, args.video_path, args.mask_path, args.out_dir, device, args)


if __name__ == "__main__":
    main()