import functools
import itertools
import multiprocessing as mp
import copy
from argparse import ArgumentParser

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from decord import VideoReader, cpu
from tqdm import tqdm

from dataset.eval import load_eval_dataset
from util.conversation import ConversationTemplates
from util.json import load_from_json


class PromptDataset(Dataset):
    """
    Custom dataset class that loads data and generates the corresponding prompt.
    """
    message_template = [
        {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": None,
                },
                {"type": "text", "text": None},
            ],
        }
    ]

    def __init__(self, dataset, processor, prefix=""):
        """
        Initializes the dataset.

        Args:
            dataset (Dataset): The underlying evaluation dataset (e.g., TGIFQADataset).
            processor (AutoProcessor): Processor used to apply the chat template.
            prefix (str): Optional text prepended to every question.
        """
        self.dataset = dataset
        self.processor = processor
        self.prefix = prefix

    def __len__(self):
        """
        Returns the length of the dataset.
        """
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Retrieves a single sample, generates a prompt, and processes it into the
        appropriate format for model input.

        Args:
            idx (int): Index of the sample.

        Returns:
            dict: Contains video frames, prompt, question, and answer.
        """
        sample = self.dataset[idx]
        question = self.prefix + sample['question']

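        # Deep-copy the shared class-level template so per-sample edits don't leak
        # across samples.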
        now_message = copy.deepcopy(self.message_template)
        now_message[0]["content"][1]["text"] = question
        prompt = self.processor.apply_chat_template(
            now_message, tokenize=False, add_generation_prompt=True
        )
        return {
            'video_pils': sample['video_pils'],
            'prompt': prompt,
            'question': question,
            'answer': sample['answer']
        }


def load_qwen2_vl_model_and_processor(model_path, device='cpu'):
    """
    Loads the Qwen2-VL model and processor from the specified path.

    Args:
        model_path (str): Path to the pre-trained model.
        device (str): Device to load the model on (default: 'cpu').

    Returns:
        tuple: The Qwen2-VL model and its processor.
    """
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2"
    ).to(device)

    processor = AutoProcessor.from_pretrained(model_path)
    return model, processor


def load_config():
    """
    Parses user arguments and loads configuration from a JSON file.

    Returns:
        dict: The configuration arguments.
    """
    parser = ArgumentParser()
    parser.add_argument('config_path', type=str, help='Path to the configuration file')

    input_args = parser.parse_args()
    config_path = input_args.config_path

    json_args = load_from_json(config_path)
    input_args = vars(input_args)
    input_args.pop('config_path')

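    # Any remaining command-line arguments take precedence over the JSON config.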
    json_args.update(input_args)
    print(json_args)
    return json_args


def load_prompt_dataset(dataset_args, processor, prefix=""):
    """
    Loads an evaluation dataset and wraps it with prompt generation.

    Args:
        dataset_args (dict): Arguments describing the dataset to load.
        processor (AutoProcessor): Processor used to apply the chat template.
        prefix (str): Optional text prepended to every question.

    Returns:
        PromptDataset: A dataset object with prompts.
    """
    original_dataset = load_eval_dataset(dataset_args)
    return PromptDataset(original_dataset, processor, prefix)


def qwen2_vl_answer(model, processor, batch, device):
    """
    Performs batch inference using the Qwen2-VL model.

    Args:
        model (Qwen2VLForConditionalGeneration): The video-language model.
        processor (AutoProcessor): The processor for formatting inputs.
        batch (list): A batch of samples from the dataset.
        device (str): The device to run inference on.

    Returns:
        list: The decoded answers from the model's output.
    """
    prompts = [item['prompt'] for item in batch]
    videos = [item['video_pils'] for item in batch]

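    # Batch-encode the prompts together with the pre-decoded video frames; padding
    # aligns the variable-length prompts so the batch can be generated in one call.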
    inputs = processor(
        text=prompts,
        images=None,
        videos=videos,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)

    generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False).cpu()

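    # Strip the prompt tokens from each sequence so only newly generated tokens are decoded.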
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)


def batch_run(rank, args, world_size, dataset):
    """
    Executes batch inference on a single GPU device.

    Args:
        rank (int): The GPU rank for multi-process execution.
        args (dict): The configuration arguments.
        world_size (int): Total number of GPUs.
        dataset (Dataset): The dataset for inference.

    Returns:
        list: Inference results for this rank's shard of the dataset.
    """
    start_idx, end_idx = dataset.dataset.set_rank_and_world_size(rank, world_size)

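    # Materialize this rank's samples up front; samples that fail to load
    # (e.g., a missing video file) are skipped.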
    split_dataset = []
    for i in range(start_idx, end_idx):
        try:
            sample = dataset[i]
            split_dataset.append(sample)
        except AssertionError as e:
            print(str(e))
            if "No video named" in str(e):
                continue
    print(f"{rank} dataset loading finished")

    model, processor = load_qwen2_vl_model_and_processor(args["model"]["model_path"], rank)

    batch_responses = []
    batch_size = args['batch_size']
    for i in range(0, len(split_dataset), batch_size):
        batch_samples = split_dataset[i:i + batch_size]
        try:
            llm_messages = qwen2_vl_answer(model, processor, batch_samples, device=rank)
        except RuntimeError as e:
            print(f"An unexpected error occurred: {e}")
            continue

        for idx, message in enumerate(llm_messages):
            batch_responses.append({
                'question': batch_samples[idx]['question'],
                'answer': batch_samples[idx]['answer'],
                'pred': message
            })

        print(f"{rank} process has processed {i}/{len(split_dataset)}")

    return batch_responses


def main(args):
    """
    Main function for multi-process batch inference on a QA benchmark.

    Args:
        args (dict): Configuration arguments.
    """
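    # CUDA cannot be re-initialized in forked subprocesses, so spawn fresh workers.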
    mp.set_start_method('spawn')
    processor = AutoProcessor.from_pretrained(args["model"]["model_path"])
    prefix = args.get("prefix", "")
    dataset = load_prompt_dataset(args['dataset'], processor, prefix)
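    # Drop the local reference; the dataset keeps its own handle to the processor
    # for applying the chat template in __getitem__.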
    del processor
    print("-------------------------------------- finished loading dataset --------------------------------------")

    n_gpus = torch.cuda.device_count()
    world_size = n_gpus

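    # One worker per GPU: each runs batch_run on its own shard and returns its results.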
    with mp.Pool(world_size) as pool:
        func = functools.partial(batch_run, args=args, world_size=world_size, dataset=dataset)
        result_lists = pool.map(func, range(world_size))

    result_list = list(itertools.chain(*result_lists))

    from task.eval.util import save_results
    import json
    description = args['description'] + '\n\n\nArgs:' + json.dumps(args)
    save_results(result_list, args['save_path'], args['experiment_name'], description=description)


if __name__ == "__main__":
    args = load_config()
    main(args)