# video_llm_template/task/eval/qwen2_vl_qa_eval.py
import copy
import functools
import itertools
import multiprocessing as mp
from argparse import ArgumentParser

import torch
from torch.utils.data import Dataset
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

from dataset.eval import load_eval_dataset
from util.json import load_from_json

class PromptDataset(Dataset):
"""
Custom dataset class that loads data and generates the corresponding prompt.
"""
message_template = [
{
"role": "user",
"content": [
{
"type": "video",
"video": None,
},
{"type": "text", "text": None},
],
}
]

    def __init__(self, dataset, processor, prefix=""):
        """
        Initializes the dataset.
        Args:
            dataset (Dataset): The original dataset (e.g., TGIFQADataset).
            processor (AutoProcessor): Processor whose chat template is used to build prompts.
            prefix (str): Optional text prepended to every question.
        """
        self.dataset = dataset
        self.processor = processor
        self.prefix = prefix

    def __len__(self):
"""
Returns the length of the dataset.
"""
return len(self.dataset)

    def __getitem__(self, idx):
"""
Retrieves a single sample, generates a prompt, and processes it into the appropriate format for model input.
Args:
idx (int): Index of the sample.
Returns:
dict: Contains video frames, prompt, question, and answer.
"""
sample = self.dataset[idx]
question = sample['question']
# add prefix
question = self.prefix + question
# Preparation for inference
now_message = copy.deepcopy(self.message_template)
now_message[0]["content"][1]["text"] = question
prompt = self.processor.apply_chat_template(
now_message, tokenize=False, add_generation_prompt=True
)
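        # For reference, apply_chat_template renders a prompt roughly like the
        # following with the stock Qwen2-VL template (illustrative; the exact
        # text depends on the checkpoint's chat template):
        #   <|im_start|>system
        #   You are a helpful assistant.<|im_end|>
        #   <|im_start|>user
        #   <|vision_start|><|video_pad|><|vision_end|>{question}<|im_end|>
        #   <|im_start|>assistant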
return {
'video_pils': sample['video_pils'], # Video frames
'prompt': prompt,
'question': question,
'answer': sample['answer']
}

def load_qwen2_vl_model_and_processor(model_path, device='cpu'):
    """
    Loads the Qwen2-VL model and processor from the specified path.
    Args:
        model_path (str): Path to the pre-trained model.
        device (str or int): Device to load the model onto (default: 'cpu');
            batch_run passes a GPU rank here.
    Returns:
        tuple: Qwen2-VL model and processor.
    """
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2"
).to(device)
    # Load the default processor for this checkpoint
processor = AutoProcessor.from_pretrained(model_path)
return model, processor
def load_config():
"""
Parses user arguments and loads configuration from a JSON file.
Returns:
dict: The configuration arguments.
"""
parser = ArgumentParser()
parser.add_argument('config_path', type=str, help='Path to the configuration file')
# Parse command-line arguments
input_args = parser.parse_args()
config_path = input_args.config_path
# Load JSON configuration
json_args = load_from_json(config_path)
input_args = vars(input_args)
input_args.pop('config_path')
# Update JSON args with command-line arguments
json_args.update(input_args)
print(json_args)
return json_args
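
# For reference, the config JSON consumed by load_config() is expected to carry
# at least the keys read elsewhere in this script (values below are
# illustrative placeholders, not files or names shipped with the repo):
#
# {
#     "model": {"model_path": "Qwen/Qwen2-VL-7B-Instruct"},
#     "dataset": {...},
#     "batch_size": 4,
#     "prefix": "",
#     "description": "run notes",
#     "save_path": "results/",
#     "experiment_name": "qwen2_vl_qa"
# }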

def load_prompt_dataset(dataset_args, processor, prefix=""):
    """
    Loads an evaluation dataset and wraps it in a PromptDataset.
    Args:
        dataset_args (dict): Arguments forwarded to load_eval_dataset.
        processor (AutoProcessor): Processor used to build chat prompts.
        prefix (str): Optional text prepended to every question.
    Returns:
        PromptDataset: A dataset object with prompts.
    """
original_dataset = load_eval_dataset(dataset_args)
return PromptDataset(original_dataset, processor, prefix)

def qwen2_vl_answer(model, processor, batch, device):
    """
    Performs batch inference using the Qwen2-VL model.
    Args:
        model (Qwen2VLForConditionalGeneration): The video model.
        processor (AutoProcessor): The processor for formatting inputs.
        batch (list): A batch of samples from the dataset.
        device (str or int): The device to run inference on.
    Returns:
        list: The decoded answers from the model's output.
    """
prompts = [item['prompt'] for item in batch]
videos = [item['video_pils'] for item in batch]
inputs = processor(
text=prompts,
images=None,
videos=videos,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(device)
# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False).cpu()
    # generate() returns prompt + completion; keep only the newly generated tokens
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)

def batch_run(rank, args, world_size, dataset):
"""
Executes batch inference on a specified GPU device.
Args:
rank (int): The GPU rank for multi-process execution.
args (dict): The configuration arguments.
world_size (int): Total number of GPUs.
dataset (Dataset): The dataset for inference.
Returns:
        list: Inference results for this rank's shard of the dataset.
"""
start_idx, end_idx = dataset.dataset.set_rank_and_world_size(rank, world_size)
    # Split dataset for this GPU rank
    split_dataset = []
    for i in range(start_idx, end_idx):
        try:
            split_dataset.append(dataset[i])
        except AssertionError as e:
            # Skip samples whose video file is missing; the error is logged either way.
            print(str(e))
            if "No video named" in str(e):
                continue
    print(f"[rank {rank}] dataset loading finished")
    # Load model and processor
    model, processor = load_qwen2_vl_model_and_processor(args["model"]["model_path"], rank)
    # Batch inference
    batch_responses = []
    batch_size = args['batch_size']
    for i in range(0, len(split_dataset), batch_size):
        batch_samples = split_dataset[i:i + batch_size]
        try:
            llm_messages = qwen2_vl_answer(model, processor, batch_samples, device=rank)
        except RuntimeError as e:
            # Skip the batch on runtime failures (e.g. CUDA out-of-memory).
            print(f"Runtime error during generation, skipping batch: {e}")
            continue
        # Collect inference results
        for idx, message in enumerate(llm_messages):
            batch_responses.append({
                'question': batch_samples[idx]['question'],
                'answer': batch_samples[idx]['answer'],
                'pred': message
            })
        print(f"[rank {rank}] processed {min(i + batch_size, len(split_dataset))}/{len(split_dataset)} samples")
    return batch_responses

def main(args):
    """
    Main function for multi-process batch inference on a QA benchmark.
    Args:
        args (dict): Configuration arguments.
    """
    mp.set_start_method('spawn')  # 'spawn' is required when child processes use CUDA
    processor = AutoProcessor.from_pretrained(args["model"]["model_path"])
    prefix = args.get("prefix", "")
    dataset = load_prompt_dataset(args['dataset'], processor, prefix)  # Load dataset with prompts
    del processor
    print("-------------------- dataset loading finished --------------------")
    world_size = torch.cuda.device_count()  # one worker process per GPU
# Multi-process inference
with mp.Pool(world_size) as pool:
func = functools.partial(batch_run, args=args, world_size=world_size, dataset=dataset)
result_lists = pool.map(func, range(world_size))
pool.close()
pool.join()
result_list = list(itertools.chain(*result_lists)) # Combine results from all processes
from task.eval.util import save_results
import json
description = args['description'] + '\n\n\nArgs:' + json.dumps(args)
save_results(result_list, args['save_path'], args['experiment_name'], description=description) # Save results
if __name__ == "__main__":
args = load_config()
main(args)
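
# A minimal invocation sketch (the config filename is a placeholder, not a file
# shipped with this repo):
#   python task/eval/qwen2_vl_qa_eval.py configs/qwen2_vl_eval.json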