#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Inference example for Qwen2.5-VL TRASER model.

Usage:
    python inference.py \
        --model_path . \
        --video_path /path/to/video.mp4 \
        --mask_path /path/to/mask.json \
        --out_dir ./output
"""
import os
import json
import argparse
import random
import torch
import numpy as np
from transformers import AutoProcessor, AutoTokenizer
# Import Custom Model
from modeling_traser import TRASER
# Import Utils
from qwen_vl_vsg_utils.src.qwen_vl_utils import process_vision_info
from resampler_utils.token_selection import select_tokens
from resampler_utils.token_arrangement import rearrange_token
from pycocotools import mask as maskUtils
import math
import torch.nn.functional as F
def set_seed(seed: int) -> None:
    """Seed every RNG source used here (Python, NumPy, Torch CPU and CUDA)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
def load_mask_data(mask_json_path):
    """Parse and return the per-frame RLE mask annotations stored at *mask_json_path*."""
    with open(mask_json_path, "r") as fh:
        data = json.load(fh)
    return data
def has_any_mask(mask_data, obj_id):
    """Return True iff object *obj_id* has a non-empty RLE in at least one frame.

    A frame contributes only when it is non-empty, contains an entry at
    *obj_id*, and that entry has a truthy "counts" field.
    """
    return any(
        frame and obj_id < len(frame) and frame[obj_id] and frame[obj_id].get("counts")
        for frame in mask_data
    )
def build_obj_masks_tensor(mask_data, obj_ids, sampled_idx, H_rz, W_rz, device):
    """Decode per-object RLE masks for the sampled frames into a dense tensor.

    Returns a pair ``(obj_masks, keep_idx)`` where ``obj_masks`` is a float32
    tensor of binary masks shaped (kept_objects, len(sampled_idx), H_rz, W_rz)
    and ``keep_idx`` lists the positions (into *obj_ids*) of objects that had
    at least one non-empty mask; objects with no mask anywhere are dropped.
    """
    num_objs, num_frames = len(obj_ids), len(sampled_idx)
    obj_masks = torch.zeros((num_objs, num_frames, H_rz, W_rz), dtype=torch.float32, device=device)
    for obj_pos, obj_id in enumerate(obj_ids):
        for frame_pos, frame_idx in enumerate(sampled_idx):
            # Guard clauses: skip out-of-range frames, frames without this
            # object, and empty RLE entries.
            if frame_idx >= len(mask_data):
                continue
            frame_objs = mask_data[frame_idx]
            if not frame_objs or obj_id >= len(frame_objs):
                continue
            rle = frame_objs[obj_id]
            if not rle:
                continue
            decoded = maskUtils.decode({"size": rle["size"], "counts": rle["counts"]})
            if decoded.ndim == 3:
                decoded = decoded[:, :, 0]
            mask_t = torch.from_numpy(decoded.astype(np.uint8)).float().to(device)[None, None]
            # Nearest-neighbour resize keeps the mask binary; re-threshold anyway.
            resized = F.interpolate(mask_t, size=(H_rz, W_rz), mode="nearest")[0, 0]
            obj_masks[obj_pos, frame_pos] = (resized > 0.5).float()
    nonzero = obj_masks.view(num_objs, -1).sum(dim=1) > 0
    keep_idx = nonzero.nonzero(as_tuple=False).squeeze(1).tolist()
    if len(keep_idx) < num_objs:
        obj_masks = obj_masks[keep_idx]
    return obj_masks, keep_idx
def run_single_video(model, processor, video_path, mask_path, out_dir, device, args):
    """Run TRASER scene-graph inference on a single video.

    Loads the per-frame RLE masks, picks up to ``args.max_objects`` objects
    that have at least one mask, builds the vision inputs and object-token
    arrangement, generates text with the model, prints it, and (when
    *out_dir* is set) writes it to ``<out_dir>/output.txt``.

    Parameters
    ----------
    model : TRASER model already moved to *device*.
    processor : HF processor whose ``tokenizer`` matches the model weights.
    video_path : str — path to the input video.
    mask_path : str — path to the mask JSON (list of frames, each a list of RLEs).
    out_dir : str or None — output directory; falsy skips writing.
    device : str — torch device for input tensors.
    args : argparse.Namespace with max_objects / obj_traj_start_id / obj_traj_end_id.
    """
    mask_data = load_mask_data(mask_path)
    # --- Object selection --------------------------------------------------
    # Consider every object annotated in the first frame. Fix: previously the
    # candidate range itself was capped at args.max_objects, which made the
    # random-subsample branch below unreachable dead code and silently ignored
    # objects with index >= max_objects even when lower-indexed objects had no
    # mask at all. The cap is now applied after eligibility filtering.
    num_annotated = len(mask_data[0]) if mask_data else 0
    eligible = [oid for oid in range(num_annotated) if has_any_mask(mask_data, oid)]
    if len(eligible) > args.max_objects:
        random.shuffle(eligible)
        selected_obj_ids = sorted(eligible[:args.max_objects])
    else:
        selected_obj_ids = sorted(eligible)
    # --- Prompt + vision preprocessing -------------------------------------
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {"type": "text", "text": "Output the video Scene Graph from the video and object trajectories:\n"},
            {"type": "video", "video": video_path}
        ]}
    ]
    prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs, _fps, selected_frame_idx = process_vision_info(messages, return_video_kwargs=True)
    proc_inputs = processor(
        text=[prompt_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", fps=1
    ).to(device)
    video_grid_thw = proc_inputs["video_grid_thw"]
    if isinstance(video_grid_thw, list):
        video_grid_thw = torch.stack([x.to(device) for x in video_grid_thw])
    else:
        video_grid_thw = video_grid_thw.to(device)
    T_grid = int(video_grid_thw[0, 0].item())
    H_patch, W_patch = int(video_grid_thw[0, 1].item()), int(video_grid_thw[0, 2].item())
    # Masks are resized to the pixel resolution implied by the patch grid.
    patch_size = 14
    H_rz, W_rz = H_patch * patch_size, W_patch * patch_size
    # --- Build masks for the frames the processor actually sampled ---------
    sampled_idx = selected_frame_idx[0]
    obj_masks, keep_idx = build_obj_masks_tensor(mask_data, selected_obj_ids, sampled_idx, H_rz, W_rz, device)
    selected_obj_ids = [selected_obj_ids[i] for i in keep_idx]
    # --- Token selection ----------------------------------------------------
    per_union_idx, per_obj_idx, _ = select_tokens(
        obj_masks=obj_masks,
        grid_thw=(T_grid, H_patch, W_patch),
        patch_size=patch_size,
        device=device
    )
    per_obj_idx_batch = [per_obj_idx]
    # "Object k: " text labels, one per kept object (1-based numbering).
    text_token_ids_per_sample = []
    label_template = "Object {i}: "
    additional_texts = [label_template.format(i=(k + 1)) for k in range(len(per_obj_idx))]
    enc = processor.tokenizer(additional_texts, add_special_tokens=False)["input_ids"]
    text_token_ids_per_sample.append([torch.tensor(x, dtype=torch.long) for x in enc])
    # --- Timestamp windows --------------------------------------------------
    # Each temporal grid step covers 2 s; each window spans 4 s (2 grid steps).
    sec_per_window = torch.arange(0, T_grid) * 2.0
    temporal_window_length = 4.0
    grids_per_window = int(temporal_window_length / 2.0)
    timestamp_token_ids_per_batch = []
    grids_per_window_batch = []
    temporal_text_list = []
    num_windows = math.ceil(len(sec_per_window) / grids_per_window)
    for w_id in range(num_windows):
        s, e = w_id * temporal_window_length, (w_id + 1) * temporal_window_length
        temporal_text_list.append(f"<{int(s)} - {int(e)} sec>")
    enc_ts = processor.tokenizer(temporal_text_list, add_special_tokens=False)["input_ids"]
    timestamp_token_ids_per_batch.append([torch.tensor(x) for x in enc_ts])
    grids_per_window_batch.append(grids_per_window)
    # --- Rearrange tokens and generate --------------------------------------
    with torch.no_grad():
        new_emb, new_pid, new_mask, rope_deltas, cache_pos, _, _ = rearrange_token(
            model=model,
            input_ids=proc_inputs["input_ids"],
            attention_mask=proc_inputs["attention_mask"],
            pixel_values_videos=proc_inputs["pixel_values_videos"],
            video_grid_thw=video_grid_thw,
            image_grid_thw=None, pixel_values=None, second_per_grid_ts=None,
            obj_token_indices_per_sample=per_obj_idx_batch,
            obj_traj_start_id=args.obj_traj_start_id,
            obj_traj_end_id=args.obj_traj_end_id,
            text_token_ids_per_sample=text_token_ids_per_sample,
            timestamp_token_ids_per_batch=timestamp_token_ids_per_batch,
            grids_per_temporal_window_per_batch=grids_per_window_batch,
        )
        gen_out = model.generate(
            inputs_embeds=new_emb,
            position_ids=new_pid,
            attention_mask=new_mask.long(),
            rope_deltas=rope_deltas,
            max_new_tokens=8192,
            do_sample=True,
            top_p=0.9,
            temperature=1e-6,  # near-greedy decoding while keeping the sampling code path
            repetition_penalty=1.05
        )
    decoded = processor.tokenizer.decode(gen_out[0], skip_special_tokens=True)
    print(f"Generated Output:\n{decoded}")
    if out_dir:
        # Pin the encoding so the output is portable across platforms.
        with open(os.path.join(out_dir, "output.txt"), "w", encoding="utf-8") as f:
            f.write(decoded)
def main():
    """CLI entry point: parse arguments, load model and processor, run inference."""
    cli_options = [
        ("--model_path", dict(type=str, required=True, help="Path to model or HF repo")),
        ("--video_path", dict(type=str, required=True)),
        ("--mask_path", dict(type=str, required=True)),
        ("--out_dir", dict(type=str, default="./output")),
        ("--max_objects", dict(type=int, default=40)),
        ("--obj_traj_start_id", dict(type=int, default=151665)),
        ("--obj_traj_end_id", dict(type=int, default=151666)),
    ]
    parser = argparse.ArgumentParser()
    for flag, kwargs in cli_options:
        parser.add_argument(flag, **kwargs)
    args = parser.parse_args()

    set_seed(42)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    # Load the custom TRASER class explicitly so local weights work without
    # trust_remote_code. The processor comes from the base Qwen2.5-VL repo,
    # but its tokenizer is swapped for the one shipped with the weights.
    model = TRASER.from_pretrained(args.model_path, torch_dtype=torch.bfloat16).to(device)
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
    processor.tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    run_single_video(model, processor, args.video_path, args.mask_path, args.out_dir, device, args)
# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()