# TRASER / inference.py
# UWGZQ's picture
# Upload folder using huggingface_hub
# f72dd03 verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Inference example for Qwen2.5-VL TRASER model.
Usage:
    python inference.py \
        --model_path . \
        --video_path /path/to/video.mp4 \
        --mask_path /path/to/mask.json \
        --out_dir ./output
"""
import os
import json
import argparse
import random
import torch
import numpy as np
from transformers import AutoProcessor, AutoTokenizer
# Import Custom Model
from modeling_traser import TRASER
# Import Utils
from qwen_vl_vsg_utils.src.qwen_vl_utils import process_vision_info
from resampler_utils.token_selection import select_tokens
from resampler_utils.token_arrangement import rearrange_token
from pycocotools import mask as maskUtils
import math
import torch.nn.functional as F
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Integer seed applied to every generator, in that order.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
def load_mask_data(mask_json_path):
    """Load the per-frame object mask annotations from a JSON file.

    Args:
        mask_json_path: Path to the mask JSON file.

    Returns:
        The decoded JSON document. In this script it is indexed as
        ``mask_data[frame][object_id]``, where each entry is either empty or a
        dict with COCO-RLE ``"size"`` / ``"counts"`` fields.
    """
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform's locale encoding.
    with open(mask_json_path, "r", encoding="utf-8") as f:
        return json.load(f)
def has_any_mask(mask_data, obj_id):
    """Return True if ``obj_id`` has a non-empty RLE in at least one frame.

    Frames that are empty/None, or too short to contain ``obj_id``, are
    skipped. An entry counts only when it is truthy and its ``"counts"``
    field is truthy.
    """
    def _frame_has_mask(frame):
        if not frame or obj_id >= len(frame):
            return False
        entry = frame[obj_id]
        return bool(entry and entry.get("counts"))

    return any(_frame_has_mask(frame) for frame in mask_data)
def _lookup_rle(mask_data, fidx, oid):
    """Return the RLE dict for object ``oid`` in frame ``fidx``, or None if it has no mask."""
    if fidx >= len(mask_data):
        return None
    frame_objs = mask_data[fidx]
    if not frame_objs or oid >= len(frame_objs):
        return None
    rle = frame_objs[oid]
    # Same criterion as has_any_mask: an entry without RLE counts is "no mask".
    # Previously a missing "counts" key raised KeyError, and empty counts were
    # passed straight to the decoder.
    if not rle or not rle.get("counts"):
        return None
    return rle


def _decode_resized_mask(rle, H_rz, W_rz, device):
    """Decode one COCO RLE and nearest-resize it to a 0/1 float mask of shape (H_rz, W_rz)."""
    m = maskUtils.decode({"size": rle["size"], "counts": rle["counts"]})
    if m.ndim == 3:
        m = m[:, :, 0]
    m_t = torch.from_numpy(m.astype(np.uint8)).unsqueeze(0).unsqueeze(0).float().to(device)
    m_rz = F.interpolate(m_t, size=(H_rz, W_rz), mode="nearest")[0, 0]
    return (m_rz > 0.5).float()


def build_obj_masks_tensor(mask_data, obj_ids, sampled_idx, H_rz, W_rz, device):
    """Rasterise per-object RLE masks for the sampled frames into a single tensor.

    Args:
        mask_data: Per-frame list of per-object RLE dicts (``"size"`` / ``"counts"``).
            Frames and entries may be empty.
        obj_ids: Object indices to extract, in output order.
        sampled_idx: Source-frame indices chosen by the video sampler. Indices
            at or beyond ``len(mask_data)`` yield all-zero masks.
        H_rz, W_rz: Target mask height/width after nearest-neighbour resizing.
        device: Torch device for the returned tensor.

    Returns:
        ``(obj_masks, keep_idx)``:
        - ``obj_masks`` is a float32 0/1 tensor of shape
          ``(len(keep_idx), len(sampled_idx), H_rz, W_rz)``.
        - ``keep_idx`` lists the positions in ``obj_ids`` whose mask is
          non-empty in at least one sampled frame.
    """
    num_objs, num_frames = len(obj_ids), len(sampled_idx)
    obj_masks = torch.zeros((num_objs, num_frames, H_rz, W_rz), dtype=torch.float32, device=device)
    for o_i, oid in enumerate(obj_ids):
        for n_idx, fidx in enumerate(sampled_idx):
            rle = _lookup_rle(mask_data, fidx, oid)
            if rle is not None:
                obj_masks[o_i, n_idx] = _decode_resized_mask(rle, H_rz, W_rz, device)

    # Reduce over (frame, H, W) instead of view(O, -1): view/reshape with -1
    # raises on a zero-element tensor (no objects or no sampled frames).
    nonempty = obj_masks.sum(dim=(1, 2, 3)) > 0
    keep_idx = nonempty.nonzero(as_tuple=False).squeeze(1).tolist()
    if len(keep_idx) < num_objs:
        obj_masks = obj_masks[keep_idx]
    return obj_masks, keep_idx
def _select_object_ids(mask_data, max_objects):
    """Return, in ascending order, the object ids below ``max_objects`` that have a mask in some frame."""
    # Count ids over all frames rather than frame 0 alone. Frames may be
    # empty/None, and an empty first frame previously selected no objects
    # (a None first frame or empty mask_data raised instead).
    num_ids = max((len(frame) for frame in mask_data if frame), default=0)
    # Candidates are already capped at max_objects, so no further
    # sub-sampling is needed.
    return [oid for oid in range(min(num_ids, max_objects)) if has_any_mask(mask_data, oid)]


def _build_timestamp_token_ids(tokenizer, t_grid, seconds_per_grid=2.0, window_seconds=4.0):
    """Tokenise one ``"<start - end sec>"`` label per temporal window.

    Args:
        tokenizer: Tokenizer callable returning a mapping with ``"input_ids"``.
        t_grid: Number of temporal grid steps in the processed video.
        seconds_per_grid: Video seconds covered by one temporal grid step.
        window_seconds: Length of each labelled window in seconds.

    Returns:
        ``(timestamp_ids, grids_per_window)``:
        - ``timestamp_ids`` is a list of 1-D int64 tensors, one per window
          (``ceil(t_grid / grids_per_window)`` windows).
        - ``grids_per_window`` is the number of grid steps per window.
    """
    grids_per_window = int(window_seconds / seconds_per_grid)
    num_windows = math.ceil(t_grid / grids_per_window)
    labels = [
        f"<{int(w * window_seconds)} - {int((w + 1) * window_seconds)} sec>"
        for w in range(num_windows)
    ]
    encoded = tokenizer(labels, add_special_tokens=False)["input_ids"]
    return [torch.tensor(ids, dtype=torch.long) for ids in encoded], grids_per_window


def run_single_video(model, processor, video_path, mask_path, out_dir, device, args):
    """Generate a video scene graph for one video, conditioned on object mask trajectories.

    Args:
        model: Loaded TRASER model, used by ``rearrange_token`` and ``generate``.
        processor: Qwen2.5-VL processor whose ``tokenizer`` matches ``model``.
        video_path: Path of the input video.
        mask_path: JSON file of per-frame, per-object COCO-RLE masks.
        out_dir: Directory for ``output.txt``; falsy to skip writing the file.
        device: Torch device for all tensors.
        args: Namespace providing ``max_objects``, ``obj_traj_start_id`` and
            ``obj_traj_end_id``.

    The decoded text is printed. When ``out_dir`` is set, it is also written
    to ``<out_dir>/output.txt`` as UTF-8.
    """
    mask_data = load_mask_data(mask_path)
    selected_obj_ids = _select_object_ids(mask_data, args.max_objects)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {"type": "text", "text": "Output the video Scene Graph from the video and object trajectories:\n"},
            {"type": "video", "video": video_path},
        ]},
    ]
    prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The third return value is not used. The processor is always called
    # with fps=1, and the timestamp labels below assume that rate.
    image_inputs, video_inputs, _fps, selected_frame_idx = process_vision_info(messages, return_video_kwargs=True)
    proc_inputs = processor(
        text=[prompt_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", fps=1
    ).to(device)

    video_grid_thw = proc_inputs["video_grid_thw"]
    if isinstance(video_grid_thw, list):
        video_grid_thw = torch.stack([x.to(device) for x in video_grid_thw])
    else:
        video_grid_thw = video_grid_thw.to(device)
    T_grid = int(video_grid_thw[0, 0].item())
    H_patch, W_patch = int(video_grid_thw[0, 1].item()), int(video_grid_thw[0, 2].item())

    # Masks are resized to (patch grid) * (patch size) pixels.
    patch_size = 14
    H_rz, W_rz = H_patch * patch_size, W_patch * patch_size

    # Rasterise masks only for the frames the sampler kept; objects that are
    # empty in every sampled frame are dropped by the helper.
    sampled_idx = selected_frame_idx[0]
    obj_masks, _keep_idx = build_obj_masks_tensor(mask_data, selected_obj_ids, sampled_idx, H_rz, W_rz, device)

    _per_union_idx, per_obj_idx, _ = select_tokens(
        obj_masks=obj_masks,
        grid_thw=(T_grid, H_patch, W_patch),
        patch_size=patch_size,
        device=device,
    )
    per_obj_idx_batch = [per_obj_idx]

    # One "Object k: " text prefix per selected object (1-based).
    label_texts = [f"Object {k + 1}: " for k in range(len(per_obj_idx))]
    label_ids = processor.tokenizer(label_texts, add_special_tokens=False)["input_ids"]
    text_token_ids_per_sample = [[torch.tensor(ids, dtype=torch.long) for ids in label_ids]]

    # Timestamp labels: each temporal grid step is treated as 2 s (presumably
    # 2 frames per temporal patch at the fixed fps=1 — confirm), grouped into
    # 4 s windows.
    timestamp_ids, grids_per_window = _build_timestamp_token_ids(processor.tokenizer, T_grid)
    timestamp_token_ids_per_batch = [timestamp_ids]
    grids_per_window_batch = [grids_per_window]

    with torch.no_grad():
        new_emb, new_pid, new_mask, rope_deltas, cache_pos, _, _ = rearrange_token(
            model=model,
            input_ids=proc_inputs["input_ids"],
            attention_mask=proc_inputs["attention_mask"],
            pixel_values_videos=proc_inputs["pixel_values_videos"],
            video_grid_thw=video_grid_thw,
            image_grid_thw=None, pixel_values=None, second_per_grid_ts=None,
            obj_token_indices_per_sample=per_obj_idx_batch,
            obj_traj_start_id=args.obj_traj_start_id,
            obj_traj_end_id=args.obj_traj_end_id,
            text_token_ids_per_sample=text_token_ids_per_sample,
            timestamp_token_ids_per_batch=timestamp_token_ids_per_batch,
            grids_per_temporal_window_per_batch=grids_per_window_batch,
        )
        # Sampling with temperature 1e-6 is effectively greedy decoding.
        gen_out = model.generate(
            inputs_embeds=new_emb,
            position_ids=new_pid,
            attention_mask=new_mask.long(),
            rope_deltas=rope_deltas,
            max_new_tokens=8192,
            do_sample=True,
            top_p=0.9,
            temperature=1e-6,
            repetition_penalty=1.05,
        )

    decoded = processor.tokenizer.decode(gen_out[0], skip_special_tokens=True)
    print(f"Generated Output:\n{decoded}")
    if out_dir:
        # Explicit UTF-8: generated text may contain non-ASCII characters.
        with open(os.path.join(out_dir, "output.txt"), "w", encoding="utf-8") as f:
            f.write(decoded)
def main():
    """Command-line entry point: parse arguments, load model and processor, run one video.

    ``--processor_path`` and ``--seed`` are optional. Their defaults reproduce
    the previously hard-coded values ("Qwen/Qwen2.5-VL-3B-Instruct" and 42).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="Path to model or HF repo")
    parser.add_argument("--video_path", type=str, required=True)
    parser.add_argument("--mask_path", type=str, required=True)
    parser.add_argument("--out_dir", type=str, default="./output")
    parser.add_argument("--max_objects", type=int, default=40)
    # Presumably the ids of the custom trajectory start/end tokens in the
    # model's tokenizer — confirm against the tokenizer config.
    parser.add_argument("--obj_traj_start_id", type=int, default=151665)
    parser.add_argument("--obj_traj_end_id", type=int, default=151666)
    parser.add_argument(
        "--processor_path", type=str, default="Qwen/Qwen2.5-VL-3B-Instruct",
        help="Path or HF repo of the Qwen2.5-VL processor",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    set_seed(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    # TRASER is loaded explicitly (rather than via AutoModel with
    # trust_remote_code=True) so that local weights in model_path work.
    model = TRASER.from_pretrained(args.model_path, torch_dtype=torch.bfloat16).to(device)
    processor = AutoProcessor.from_pretrained(args.processor_path)
    # Swap in the tokenizer stored with the model. It presumably contains the
    # extra trajectory tokens referenced above — TODO confirm.
    processor.tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    run_single_video(model, processor, args.video_path, args.mask_path, args.out_dir, device, args)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()