"""
Inference example for the Qwen2.5-VL TRASER model.

Usage:
    python inference.py \
        --model_path . \
        --video_path /path/to/video.mp4 \
        --mask_path /path/to/mask.json \
        --out_dir ./output
"""

import argparse
import json
import math
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from pycocotools import mask as maskUtils
from transformers import AutoProcessor, AutoTokenizer

from modeling_traser import TRASER
from qwen_vl_vsg_utils.src.qwen_vl_utils import process_vision_info
from resampler_utils.token_selection import select_tokens
from resampler_utils.token_arrangement import rearrange_token


def set_seed(seed: int):
    """Seed all RNGs so object subsampling and generation are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def load_mask_data(mask_json_path):
    """Load per-frame object masks: a list (frames) of lists (object slots),
    each entry a COCO RLE dict {"size": [h, w], "counts": ...} or null."""
    with open(mask_json_path, "r") as f:
        return json.load(f)
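

# A sketch of the expected mask JSON layout, inferred from how the helpers
# below index it (the "counts" value is an illustrative placeholder; real
# entries are pycocotools RLE strings):
#
#   [                                  # one entry per video frame
#     [                                # one entry per object slot
#       {"size": [480, 854], "counts": "<rle-string>"},
#       null                           # object absent in this frame
#     ],
#     ...
#   ]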


def has_any_mask(mask_data, obj_id):
    """Return True if the object has a non-empty RLE mask in any frame."""
    for frame in mask_data:
        if not frame or obj_id >= len(frame):
            continue
        if frame[obj_id] and frame[obj_id].get("counts"):
            return True
    return False


def build_obj_masks_tensor(mask_data, obj_ids, sampled_idx, H_rz, W_rz, device):
    """Decode the RLE masks of the sampled frames into a binary tensor of
    shape (num_objects, num_sampled_frames, H_rz, W_rz), dropping objects
    that end up with no mask in any sampled frame."""
    O, N = len(obj_ids), len(sampled_idx)
    obj_masks = torch.zeros((O, N, H_rz, W_rz), dtype=torch.float32, device=device)
    for o_i, oid in enumerate(obj_ids):
        for n_idx, fidx in enumerate(sampled_idx):
            if fidx >= len(mask_data):
                continue
            frame_objs = mask_data[fidx]
            if not frame_objs or oid >= len(frame_objs):
                continue
            rle = frame_objs[oid]
            if not rle:
                continue
            m = maskUtils.decode({"size": rle["size"], "counts": rle["counts"]})
            if m.ndim == 3:  # decode() returns (h, w, n) for list inputs
                m = m[:, :, 0]
            # Resize to the pixel canvas the vision tower sees, then re-binarize.
            m_t = torch.from_numpy(m.astype(np.uint8)).unsqueeze(0).unsqueeze(0).float().to(device)
            m_rz = F.interpolate(m_t, size=(H_rz, W_rz), mode="nearest")[0, 0]
            obj_masks[o_i, n_idx] = (m_rz > 0.5).float()

    # Keep only objects that have at least one non-empty decoded mask.
    keep_idx = (obj_masks.view(O, -1).sum(dim=1) > 0).nonzero(as_tuple=False).squeeze(1).tolist()
    if len(keep_idx) < O:
        obj_masks = obj_masks[keep_idx]
    return obj_masks, keep_idx
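

# Minimal sketch of the RLE decode used above (pycocotools): for a single RLE
# dict, decode() returns an (h, w) uint8 array; the ndim == 3 branch guards
# the (h, w, n) layout produced when a list of RLEs is decoded.
#
#   rle = {"size": [480, 854], "counts": "<rle-string>"}
#   m = maskUtils.decode(rle)   # np.ndarray, shape (480, 854), values in {0, 1}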


def run_single_video(model, processor, video_path, mask_path, out_dir, device, args):
    mask_data = load_mask_data(mask_path)
    # Consider every object slot present in the first frame; if more than
    # args.max_objects are eligible, randomly subsample down to the cap.
    # (Capping the candidate pool itself would make the branch below
    # unreachable.)
    all_ids = range(len(mask_data[0]))
    eligible = [oid for oid in all_ids if has_any_mask(mask_data, oid)]

    if len(eligible) > args.max_objects:
        random.shuffle(eligible)
        selected_obj_ids = sorted(eligible[:args.max_objects])
    else:
        selected_obj_ids = sorted(eligible)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {"type": "text", "text": "Output the video Scene Graph from the video and object trajectories:\n"},
            {"type": "video", "video": video_path},
        ]},
    ]

    prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # This fork of process_vision_info also returns the fps and the indices of
    # the frames actually sampled from the video, needed to align the masks.
    image_inputs, video_inputs, fps, selected_frame_idx = process_vision_info(messages, return_video_kwargs=True)

    proc_inputs = processor(
        text=[prompt_text], images=image_inputs, videos=video_inputs,
        padding=True, return_tensors="pt", fps=1,
    ).to(device)

    # video_grid_thw holds (temporal, height, width) of the vision token grid,
    # with height/width measured in 14x14 patches.
    video_grid_thw = proc_inputs["video_grid_thw"]
    if isinstance(video_grid_thw, list):
        video_grid_thw = torch.stack([x.to(device) for x in video_grid_thw])
    else:
        video_grid_thw = video_grid_thw.to(device)

    T_grid = int(video_grid_thw[0, 0].item())
    H_patch, W_patch = int(video_grid_thw[0, 1].item()), int(video_grid_thw[0, 2].item())

    # Qwen2.5-VL ViT patch size; H_rz/W_rz are the pixel dimensions the
    # processor resized the frames to, and the canvas the masks are mapped onto.
    patch_size = 14
    H_rz, W_rz = H_patch * patch_size, W_patch * patch_size
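
    # Worked example (illustrative numbers, not computed from the input): a
    # video resized to 448x672 px gives H_patch, W_patch = 32, 48 (448/14,
    # 672/14), so masks are resampled onto the same 448x672 canvas.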

    # Align masks with the frames the video sampler actually kept, and drop
    # objects that have no mask in any sampled frame.
    sampled_idx = selected_frame_idx[0]
    obj_masks, keep_idx = build_obj_masks_tensor(mask_data, selected_obj_ids, sampled_idx, H_rz, W_rz, device)
    selected_obj_ids = [selected_obj_ids[i] for i in keep_idx]

    # Map each object's pixel masks to indices into the flattened video-token
    # sequence (judging by the names, the union over all objects plus a
    # per-object index set).
    per_union_idx, per_obj_idx, _ = select_tokens(
        obj_masks=obj_masks,
        grid_thw=(T_grid, H_patch, W_patch),
        patch_size=patch_size,
        device=device,
    )
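
    # Sanity sketch (assumption: Qwen2.5-VL merges vision tokens 2x2
    # spatially, so the flattened video-token count is):
    #
    #   n_video_tokens = T_grid * (H_patch // 2) * (W_patch // 2)
    #
    # Every selected index should fall below this bound.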

    # One list of per-object token indices per batch element (batch size 1).
    per_obj_idx_batch = [per_obj_idx]

    # Pre-tokenize an "Object k: " label for each kept object, presumably
    # spliced next to each object's trajectory tokens by rearrange_token.
    text_token_ids_per_sample = []
    label_template = "Object {i}: "
    additional_texts = [label_template.format(i=(k + 1)) for k in range(len(per_obj_idx))]
    enc = processor.tokenizer(additional_texts, add_special_tokens=False)["input_ids"]
    text_token_ids_per_sample.append([torch.tensor(x, dtype=torch.long) for x in enc])
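
    # For example, with three kept objects:
    #   additional_texts == ["Object 1: ", "Object 2: ", "Object 3: "]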

    # Timestamp windows: with fps=1 and Qwen2.5-VL's 2-frame temporal merge,
    # each temporal grid covers ~2 s of video, so a 4 s window spans 2 grids.
    # Only len(sec_per_window) == T_grid is actually used below.
    sec_per_window = torch.arange(0, T_grid) * 2.0
    temporal_window_length = 4.0
    grids_per_window = int(temporal_window_length / 2.0)

    timestamp_token_ids_per_batch = []
    grids_per_window_batch = []

    temporal_text_list = []
    num_windows = math.ceil(len(sec_per_window) / grids_per_window)
    for w_id in range(num_windows):
        s, e = w_id * temporal_window_length, (w_id + 1) * temporal_window_length
        temporal_text_list.append(f"<{int(s)} - {int(e)} sec>")

    enc_ts = processor.tokenizer(temporal_text_list, add_special_tokens=False)["input_ids"]
    timestamp_token_ids_per_batch.append([torch.tensor(x) for x in enc_ts])
    grids_per_window_batch.append(grids_per_window)
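
    # Worked example: T_grid = 8 temporal grids -> ceil(8 / 2) = 4 windows,
    # labelled "<0 - 4 sec>", "<4 - 8 sec>", "<8 - 12 sec>", "<12 - 16 sec>".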

    # Splice object labels, trajectory tokens, and timestamp tokens into the
    # prompt embeddings, then generate. temperature ~= 0 with do_sample=True
    # is effectively greedy decoding.
    with torch.no_grad():
        new_emb, new_pid, new_mask, rope_deltas, cache_pos, _, _ = rearrange_token(
            model=model,
            input_ids=proc_inputs["input_ids"],
            attention_mask=proc_inputs["attention_mask"],
            pixel_values_videos=proc_inputs["pixel_values_videos"],
            video_grid_thw=video_grid_thw,
            image_grid_thw=None, pixel_values=None, second_per_grid_ts=None,
            obj_token_indices_per_sample=per_obj_idx_batch,
            obj_traj_start_id=args.obj_traj_start_id,
            obj_traj_end_id=args.obj_traj_end_id,
            text_token_ids_per_sample=text_token_ids_per_sample,
            timestamp_token_ids_per_batch=timestamp_token_ids_per_batch,
            grids_per_temporal_window_per_batch=grids_per_window_batch,
        )

        gen_out = model.generate(
            inputs_embeds=new_emb,
            position_ids=new_pid,
            attention_mask=new_mask.long(),
            rope_deltas=rope_deltas,
            max_new_tokens=8192,
            do_sample=True,
            top_p=0.9,
            temperature=1e-6,
            repetition_penalty=1.05,
        )
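
    # Note (assumes a recent transformers version): when generate() is driven
    # by inputs_embeds without input_ids, the returned sequences contain only
    # the newly generated tokens, so gen_out[0] decodes to just the response.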

    decoded = processor.tokenizer.decode(gen_out[0], skip_special_tokens=True)
    print(f"Generated Output:\n{decoded}")

    if out_dir:
        with open(os.path.join(out_dir, "output.txt"), "w") as f:
            f.write(decoded)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="Path to model or HF repo")
    parser.add_argument("--video_path", type=str, required=True)
    parser.add_argument("--mask_path", type=str, required=True)
    parser.add_argument("--out_dir", type=str, default="./output")
    parser.add_argument("--max_objects", type=int, default=40)
    # Token ids of the trajectory start/end special tokens in the TRASER
    # tokenizer.
    parser.add_argument("--obj_traj_start_id", type=int, default=151665)
    parser.add_argument("--obj_traj_end_id", type=int, default=151666)
    args = parser.parse_args()

    set_seed(42)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    model = TRASER.from_pretrained(args.model_path, torch_dtype=torch.bfloat16).to(device)
    # Use the base Qwen2.5-VL processor for vision preprocessing, but swap in
    # the checkpoint's tokenizer, presumably so the added trajectory special
    # tokens resolve to the ids passed above.
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    processor.tokenizer = tokenizer

    run_single_video(model, processor, args.video_path, args.mask_path, args.out_dir, device, args)


if __name__ == "__main__":
    main()