Spaces:
Build error
Build error
| import functools | |
| import itertools | |
| import logging | |
| from tqdm import tqdm | |
| from PIL import Image | |
| from multiprocessing import Pool | |
| import multiprocessing as mp | |
| from argparse import ArgumentParser | |
| import numpy as np | |
| import torch | |
| import torchvision | |
| from decord import VideoReader, cpu | |
| import transformers | |
| from tasks.eval.model_utils import load_pllava, pllava_answer | |
| from tasks.eval.eval_utils import conv_templates | |
| from tasks.eval.mvbench import ( | |
| MVBenchDataset, | |
| check_ans, | |
| save_results, | |
| load_results, | |
| ) | |
# Configure root logging and a module-level logger for this script.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Target resolution (pixels) for the Resize transform applied to sampled video frames.
RESOLUTION = 672  #
def parse_args():
    """Parse command-line arguments for the MVBench evaluation script.

    Returns:
        argparse.Namespace with the model path, results path, frame count,
        LoRA options, conversation template name, and pooling shape.
    """
    parser = ArgumentParser()
    # NOTE: argparse ignores `default=` when `required=True`; the previous
    # defaults on the three required arguments (including a doubly-quoted
    # save path '"./test_results/..."') were dead code and are replaced
    # with help texts instead.
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        help="HF model id or local checkpoint dir, e.g. llava-hf/llava-1.5-7b-hf",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        required=True,
        help="Where to save/load results, e.g. ./test_results/test_llava_mvbench",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        required=True,
        help="Number of frames sampled per video.",
    )
    parser.add_argument(
        "--use_lora",
        action='store_true',
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        required=False,
        default=32,
    )
    parser.add_argument(
        "--weight_dir",
        type=str,
        required=False,
        default=None,
    )
    parser.add_argument(
        "--conv_mode",
        type=str,
        required=False,
        default='eval_mvbench',
    )
    parser.add_argument(
        "--pooling_shape",
        type=str,
        required=False,
        default=None,
        help='Pooling shape encoded as "t-h-w", e.g. "16-12-12".',
    )
    args = parser.parse_args()
    return args
def load_model_and_dataset(rank, world_size, pretrained_model_name_or_path, num_frames, use_lora, lora_alpha, weight_dir, pooling_shape=(16,12,12)):
    """Load the PLLaVA model/processor onto this rank's device and build the
    rank-sharded MVBench dataset.

    NOTE: once the model goes larger (30B+), memory may be heavily used up —
    even tearing nodes.
    """
    model, processor = load_pllava(
        pretrained_model_name_or_path,
        num_frames=num_frames,
        use_lora=use_lora,
        weight_dir=weight_dir,
        lora_alpha=lora_alpha,
        pooling_shape=pooling_shape,
    )
    logger.info('done loading llava')

    # Move to this rank's device and switch to inference mode.
    model = model.to(torch.device(rank)).eval()

    dataset = MVBenchDataset(num_segments=num_frames)
    dataset.set_rank_and_world_size(rank, world_size)
    return model, processor, dataset
def infer_mvbench(
    model,
    processor,
    data_sample,
    conv_mode,
    pre_query_prompt=None, # add in the head of question
    post_query_prompt=None, # add in the end of question
    answer_prompt=None, # add in the begining of answer
    return_prompt=None, # add in the begining of return message
    print_res=False,
):
    """Run the model on one MVBench sample and return its answer text."""
    conv = conv_templates[conv_mode].copy()
    conv.user_query(data_sample['question'], pre_query_prompt, post_query_prompt, is_mm=True)
    if answer_prompt is not None:
        conv.assistant_response(answer_prompt)

    llm_message, conv = pllava_answer(
        conv=conv,
        model=model,
        processor=processor,
        img_list=data_sample["video_pils"],
        max_new_tokens=100,
        do_sample=False,
        print_res=print_res,
    )

    # Drop everything up to (and including) the seeded answer prompt.
    if answer_prompt is not None:
        llm_message = ''.join(llm_message.split(answer_prompt)[1:])
    # Re-attach the leading token callers expect (e.g. '(').
    if return_prompt is not None:
        llm_message = return_prompt + llm_message
    return llm_message
def single_test(model, processor, vid_path, num_frames=4, conv_mode="plain"):
    """Smoke-test the model on a single local video: sample frames, ask for a
    description, and return the generated text.

    Args:
        model, processor: loaded PLLaVA model and its processor.
        vid_path: path to the video file.
        num_frames: frames to sample (0 disables the visual input entirely).
        conv_mode: conversation template name.

    Returns:
        The model's generated description (previously discarded).
    """
    def get_index(total_frames, num_segments):
        # Evenly spaced frame indices, centred within each of the segments.
        seg_size = float(total_frames - 1) / num_segments
        start = int(seg_size / 2)
        offsets = np.array([
            start + int(np.round(seg_size * idx)) for idx in range(num_segments)
        ])
        return offsets

    def load_video(video_path, num_segments=8, return_msg=False, resolution=336):
        # BUGFIX: the old `num_frames` parameter was immediately overwritten
        # by len(vr) and never used, so it has been removed.
        transforms = torchvision.transforms.Resize(size=resolution)
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        frame_indices = get_index(len(vr), num_segments)
        images_group = [
            transforms(Image.fromarray(vr[frame_index].asnumpy()))
            for frame_index in frame_indices
        ]
        if return_msg:
            fps = float(vr.get_avg_fps())
            sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
            # " " should be added in the start and end
            msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
            return images_group, msg
        else:
            return images_group

    if num_frames != 0:
        vid, msg = load_video(vid_path, num_segments=num_frames, return_msg=True, resolution=RESOLUTION)
    else:
        vid, msg = None, 'num_frames is 0, not inputing image'
    img_list = vid

    conv = conv_templates[conv_mode].copy()
    conv.user_query("Describe the video in details.", is_mm=True)
    llm_response, conv = pllava_answer(conv=conv, model=model, processor=processor, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)
    # Return the response so callers can inspect it (backward-compatible:
    # existing callers ignore the return value).
    return llm_response
def run(rank, args, world_size):
    """Worker entry point: evaluate this rank's shard of MVBench.

    Args:
        rank: GPU / process index (rank 0 also runs a single-video smoke test
              and owns the progress bar).
        args: parsed CLI arguments.
        world_size: total number of worker processes.

    Returns:
        List of per-sample result dicts (pred, gt, task_type, video_path, question).
    """
    if rank != 0:
        # Silence transformers chatter on non-primary ranks.
        transformers.utils.logging.set_verbosity_error()
        logger.setLevel(transformers.logging.ERROR)

    print_res = False
    conv_mode = args.conv_mode
    pre_query_prompt = None
    post_query_prompt = "\nOnly give the best option."

    # BUGFIX: `pooling_shape` was referenced unconditionally below but only
    # assigned when --pooling_shape was provided, raising UnboundLocalError
    # in the default case. Fall back to the same default shape that
    # load_model_and_dataset declares.
    pooling_shape = (16, 12, 12)
    if args.pooling_shape is not None:
        pooling_shape = tuple(int(x) for x in args.pooling_shape.split("-"))

    logger.info(f'loading model and constructing dataset to gpu {rank}...')
    model, processor, dataset = load_model_and_dataset(rank,
                                                       world_size,
                                                       pretrained_model_name_or_path=args.pretrained_model_name_or_path,
                                                       num_frames=args.num_frames,
                                                       use_lora=args.use_lora,
                                                       lora_alpha=args.lora_alpha,
                                                       weight_dir=args.weight_dir,
                                                       pooling_shape=pooling_shape)
    logger.info('done model and dataset...')
    logger.info('constructing dataset...')
    logger.info('single test...')

    vid_path = "./example/yoga.mp4"
    # vid_path = "./example/jesse_dance.mp4"
    if rank == 0:
        single_test(model,
                    processor,
                    vid_path,
                    num_frames=args.num_frames,
                    conv_mode=args.conv_mode)
        logger.info('single test done...')

    tbar = tqdm(total=len(dataset))

    correct = 0
    total = 0
    result_list = []
    acc_dict = {}
    done_count = 0
    for example in dataset:
        task_type = example['task_type']
        if task_type not in acc_dict:
            acc_dict[task_type] = [0, 0]  # [correct, total]
        acc_dict[task_type][1] += 1
        total += 1
        pred = infer_mvbench(
            model,
            processor,
            example,
            conv_mode=conv_mode,
            pre_query_prompt=pre_query_prompt,
            post_query_prompt=post_query_prompt,
            answer_prompt="Best option:(",
            return_prompt='(',
            print_res=print_res,
        )
        gt = example['answer']
        result_list.append({
            'pred': pred,
            'gt': gt,
            'task_type': task_type,
            'video_path': example['video_path'],
            'question': example['question'],
        })
        if check_ans(pred=pred, gt=gt):
            acc_dict[task_type][0] += 1
            correct += 1
        if rank == 0:
            # Advance the bar by however many samples finished since last update.
            tbar.update(len(result_list) - done_count)
            tbar.set_description_str(
                f"One Chunk--Task Type: {task_type}, Chunk Part Acc: {acc_dict[task_type][0] / acc_dict[task_type][1] * 100 :.2f}%;"
                f" Chunk Total Acc: {correct / total * 100 :.2f}%"
            )
            done_count = len(result_list)
    return result_list
def main():
    """Entry point: run the MVBench evaluation across all visible GPUs (or
    reuse cached results if present) and save them to --save_path."""
    multiprocess = True
    mp.set_start_method('spawn')
    args = parse_args()
    save_path = args.save_path

    json_data = load_results(save_path)
    if json_data is not None:
        # Cached results exist: skip inference entirely.
        logger.info(f'loaded results from {save_path}')
        result_list = json_data
    elif multiprocess:
        logger.info(f'started benchmarking, saving to: {save_path}')
        n_gpus = torch.cuda.device_count()
        # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
        world_size = n_gpus
        func = functools.partial(run, args=args, world_size=world_size)
        # One worker process per GPU; each evaluates its shard of the dataset.
        with Pool(world_size) as pool:
            result_lists = pool.map(func, range(world_size))
        logger.info('finished running')
        result_list = list(itertools.chain.from_iterable(result_lists))
    else:
        result_list = run(0, world_size=1, args=args)  # debug

    save_results(result_list, save_path)


if __name__ == "__main__":
    main()