# Reward-Forcing / videoalign / inference.py
# Migrated from GitHub (commit 34c9227, verified)
import ast
import json
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import pdb
from collections.abc import Mapping
import pandas as pd
import torch
from videoalign.vision_process import process_vision_info, process_wanvideo_tensor
from videoalign.data import DataConfig
from videoalign.utils import ModelConfig, PEFTLoraConfig, TrainingConfig
from videoalign.utils import load_model_from_checkpoint
from videoalign.train_reward import create_model_and_processor
from videoalign.prompt_template import build_prompt
import numpy as np
from PIL import Image
import imageio
def load_configs_from_json(config_path):
    """Load a serialized training configuration and split it into parts.

    Args:
        config_path: path to a JSON file containing "data_config",
            "model_config", "peft_lora_config" and optionally
            "inference_config" sections.

    Returns:
        Tuple of (data_config, None, model_config, peft_lora_config,
        inference_config) where each element is a plain dict (or None for the
        unused training-args slot and a missing "inference_config").
    """
    with open(config_path, "r") as f:
        raw = json.load(f)
    data_cfg = raw["data_config"]
    # Training-time paths are meaningless (and possibly invalid) at inference
    # time, so they are stripped before the dict is turned into a DataConfig.
    del data_cfg["meta_data"]
    del data_cfg["data_dir"]
    inference_cfg = raw.get("inference_config")
    return data_cfg, None, raw["model_config"], raw["peft_lora_config"], inference_cfg
class VideoVLMRewardInference():
    """Inference wrapper for a (LoRA-finetuned) video VLM reward model.

    Loads model, processor and configs from a checkpoint directory and scores
    videos against text prompts along three dimensions: VQ, MQ and TA
    (per the keys used throughout; presumably visual quality, motion quality
    and text alignment — confirm against the videoalign training code).
    """

    def __init__(self, load_from_pretrained, load_from_pretrained_step=-1, device='cuda', dtype=torch.bfloat16):
        """Build the reward model and its processor from a checkpoint.

        Args:
            load_from_pretrained: checkpoint directory; must contain
                ``model_config.json``.
            load_from_pretrained_step: training step to load; -1 lets
                ``load_model_from_checkpoint`` pick its default (latest).
            device: torch device string the model is moved to.
            dtype: ``torch.bfloat16`` or ``torch.float16``; selects the
                bf16/fp16 flags passed to TrainingConfig.
        """
        config_path = os.path.join(load_from_pretrained, "model_config.json")
        data_config, _, model_config, peft_lora_config, inference_config = load_configs_from_json(config_path)

        data_config = DataConfig(**data_config)
        model_config = ModelConfig(**model_config)
        peft_lora_config = PEFTLoraConfig(**peft_lora_config)
        training_args = TrainingConfig(
            load_from_pretrained=load_from_pretrained,
            load_from_pretrained_step=load_from_pretrained_step,
            gradient_checkpointing=False,
            disable_flash_attn2=False,
            bf16=dtype == torch.bfloat16,
            fp16=dtype == torch.float16,
            output_dir="",
        )

        model, processor, peft_config = create_model_and_processor(
            model_config=model_config,
            peft_lora_config=peft_lora_config,
            training_args=training_args,
        )

        self.device = device
        model, checkpoint_step = load_model_from_checkpoint(model, load_from_pretrained, load_from_pretrained_step)
        # model.eval()  # NOTE(review): deliberately left disabled upstream —
        # presumably so gradients can flow through the reward; confirm before enabling.
        self.model = model
        self.processor = processor
        self.model.to(self.device)

        self.data_config = data_config
        # Optional dict of per-dimension mean/std stats used by _norm; may be None.
        self.inference_config = inference_config

    def _norm(self, reward):
        """Z-normalize a {'VQ','MQ','TA'} reward dict in place.

        Uses mean/std statistics from ``inference_config``; a no-op when the
        checkpoint shipped no statistics. Returns the (mutated) dict.
        """
        if self.inference_config is None:
            return reward
        for dim in ('VQ', 'MQ', 'TA'):
            reward[dim] = (reward[dim] - self.inference_config[f'{dim}_mean']) / self.inference_config[f'{dim}_std']
        return reward

    def _pad_sequence(self, sequences, attention_mask, max_len, padding_side='right'):
        """Pad token ids and attention mask along dim 1 up to ``max_len``.

        Ids are padded with the tokenizer's pad token, the mask with zeros.
        Sequences already at least ``max_len`` long are returned unchanged.
        """
        assert padding_side in ['right', 'left']
        if sequences.shape[1] >= max_len:
            return sequences, attention_mask
        pad_len = max_len - sequences.shape[1]
        padding = (0, pad_len) if padding_side == 'right' else (pad_len, 0)
        sequences_padded = torch.nn.functional.pad(
            sequences, padding, 'constant', self.processor.tokenizer.pad_token_id)
        attention_mask_padded = torch.nn.functional.pad(attention_mask, padding, 'constant', 0)
        return sequences_padded, attention_mask_padded

    def _prepare_input(self, data):
        """Recursively move tensors (inside mappings/sequences) to self.device.

        Non-tensor leaves are returned unchanged; container types are preserved.
        """
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            # TODO(review): may need a dtype here for deepspeed-style setups
            # (float inputs would have to match the model's dtype).
            return data.to(device=self.device)
        return data

    def _prepare_inputs(self, inputs):
        """Device-transfer a whole batch; reject empty batches."""
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            raise ValueError
        return inputs

    def prepare_batch_from_frames(self, video_tensors, prompts):
        """Build a processor batch directly from in-memory frame tensors,
        skipping any video-file loading.

        Args:
            video_tensors: list of per-video frame tensors
                (assumed [T, C, H, W] per the upstream comment — TODO confirm
                against process_wanvideo_tensor).
            prompts: list of prompt strings, parallel to ``video_tensors``.
        """
        # The chat template needs a video entry to emit the video placeholder
        # tokens; the path is a dummy because the pixels are passed separately
        # via ``videos=`` below.
        chat_data = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "video": "file://dummy_path"},  # placeholder only
                        {"type": "text", "text": build_prompt(prompt, self.data_config.eval_dim, self.data_config.prompt_template_type)},
                    ],
                }
            ] for prompt in prompts
        ]

        # Preprocess each video independently; keep the result as a list so
        # videos of differing lengths/resolutions are not stacked.
        video_tensors = [process_wanvideo_tensor(t) for t in video_tensors]

        batch = self.processor(
            text=self.processor.apply_chat_template(chat_data, tokenize=False, add_generation_prompt=True),
            images=None,  # no image inputs, video only
            videos=video_tensors,
            padding=True,
            return_tensors="pt",
            videos_kwargs={"do_rescale": True},  # make sure pixel values get rescaled
        )
        return self._prepare_inputs(batch)

    def reward_from_frames(self, video_tensors, prompts, use_norm=True):
        """Score in-memory frame tensors; returns reward *tensors*.

        Results are kept as tensors (no ``.item()``) so the autograd graph is
        preserved and gradients can flow back into the frames.

        Returns:
            dict with tensor values 'VQ', 'MQ', 'TA' and their sum 'Overall',
            each of shape [B].
        """
        batch = self.prepare_batch_from_frames(video_tensors, prompts)
        rewards = self.model(**batch, return_dict=True)["logits"]  # [B, 3]
        # Fix: previously this indexed inference_config unconditionally and
        # crashed with TypeError when no normalization stats were shipped;
        # now it falls back to raw scores, consistent with _norm().
        if use_norm and self.inference_config is not None:
            vq = (rewards[:, 0] - self.inference_config['VQ_mean']) / self.inference_config['VQ_std']
            mq = (rewards[:, 1] - self.inference_config['MQ_mean']) / self.inference_config['MQ_std']
            ta = (rewards[:, 2] - self.inference_config['TA_mean']) / self.inference_config['TA_std']
        else:
            vq = rewards[:, 0]
            mq = rewards[:, 1]
            ta = rewards[:, 2]
        return {
            'VQ': vq,
            'MQ': mq,
            'TA': ta,
            'Overall': vq + mq + ta,
        }

    def prepare_batch(self, video_paths, prompts, fps=None, num_frames=None, max_pixels=None):
        """Load videos from disk and build a processor batch.

        Falls back to the data-config defaults for fps / num_frames /
        max_pixels when the corresponding argument is None. Frame sampling is
        driven by ``fps`` unless ``num_frames`` is set, in which case a fixed
        frame count is requested instead.
        """
        fps = self.data_config.fps if fps is None else fps
        num_frames = self.data_config.num_frames if num_frames is None else num_frames
        max_pixels = self.data_config.max_frame_pixels if max_pixels is None else max_pixels

        # The two sampling modes differ only in one key of the video entry.
        sampling = {"fps": fps} if num_frames is None else {"nframes": num_frames}
        chat_data = [
            [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "video",
                            "video": f"file://{video_path}",
                            "max_pixels": max_pixels,
                            "sample_type": self.data_config.sample_type,
                            **sampling,
                        },
                        {"type": "text", "text": build_prompt(prompt, self.data_config.eval_dim, self.data_config.prompt_template_type)},
                    ],
                },
            ] for video_path, prompt in zip(video_paths, prompts)
        ]

        image_inputs, video_inputs = process_vision_info(chat_data)
        batch = self.processor(
            text=self.processor.apply_chat_template(chat_data, tokenize=False, add_generation_prompt=True),
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
            videos_kwargs={"do_rescale": True},
        )
        return self._prepare_inputs(batch)

    def reward(self, video_paths, prompts, fps=None, num_frames=None, max_pixels=None, use_norm=True):
        """
        Inputs:
            video_paths: List[str], B paths of the videos.
            prompts: List[str], B prompts for the videos.
            fps: float, sample rate of the videos. If None, use the default value in the config.
            num_frames: int, number of frames of the videos. If None, use the default value in the config.
            max_pixels: int, maximum pixels of the videos. If None, use the default value in the config.
            use_norm: bool, whether to rescale the output rewards.
        Outputs:
            Rewards: List[dict], one dict per video with 'VQ', 'MQ', 'TA'
            floats and their sum 'Overall'.
        """
        assert fps is None or num_frames is None, "fps and num_frames cannot be set at the same time."
        batch = self.prepare_batch(video_paths, prompts, fps, num_frames, max_pixels)
        logits = self.model(
            return_dict=True,
            **batch
        )["logits"]
        rewards = [{'VQ': row[0].item(), 'MQ': row[1].item(), 'TA': row[2].item()} for row in logits]
        for i in range(len(rewards)):
            if use_norm:
                rewards[i] = self._norm(rewards[i])
            rewards[i]['Overall'] = rewards[i]['VQ'] + rewards[i]['MQ'] + rewards[i]['TA']
        return rewards