import torch
import torch.distributed as dist

from vlmeval.config import supported_VLM
from vlmeval.dataset import build_dataset
from vlmeval.inference import infer_data_job
from vlmeval.inference_video import infer_data_job_video
from vlmeval.inference_mt import infer_data_job_mt
# NOTE: the star import below is relied on for the bare names used throughout
# this file (argparse, os, osp, datetime, json, pd, tabulate, get_logger,
# get_rank_and_world_size, listinstr, proxy_set, MMBenchOfficialServer,
# load_env), all re-exported by vlmeval.smp.
from vlmeval.smp import *
from vlmeval.utils.result_transfer import MMMU_result_transfer, MMTBench_result_transfer


def parse_args():
    parser = argparse.ArgumentParser()
    # Essential args
    parser.add_argument('--data', type=str, nargs='+', required=True)
    parser.add_argument('--model', type=str, nargs='+', required=True)
    # Args that only apply to video datasets
    parser.add_argument('--nframe', type=int, default=8)
    parser.add_argument('--pack', action='store_true')
    parser.add_argument('--use-subtitle', action='store_true')
    # Work dir
    parser.add_argument('--work-dir', type=str, default='./outputs', help='select the output directory')
    # Infer + eval, or infer only
    parser.add_argument('--mode', type=str, default='all', choices=['all', 'infer'])
    # API kwargs, applied to API VLMs and judge API LLMs
    parser.add_argument('--nproc', type=int, default=4, help='Parallel API calling')
    parser.add_argument('--retry', type=int, default=None, help='retry numbers for API VLMs')
    # Explicitly set the judge model
    parser.add_argument('--judge', type=str, default=None)
    # Logging utils
    parser.add_argument('--verbose', action='store_true')
    # Resume configuration: --ignore skips failed indices when resuming,
    # --rerun removes all evaluation temp files before running
    parser.add_argument('--ignore', action='store_true', help='Ignore failed indices. ')
    parser.add_argument('--rerun', action='store_true')
    args = parser.parse_args()
    return args
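

# Illustrative invocations (the model/dataset names below are examples, not
# the full supported list; see vlmeval.config.supported_VLM for valid models):
#   python run.py --data MMBench_DEV_EN --model qwen_chat --verbose
#   torchrun --nproc-per-node=2 run.py --data MMBench_DEV_EN Video-MME --model qwen_chat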


def main():
    logger = get_logger('RUN')

    args = parse_args()
    assert len(args.data), '--data should be a list of data files'

    if args.retry is not None:
        # Entries in supported_VLM are typically functools.partial constructors
        # (hence the .keywords check); patch their keyword defaults in place so
        # every model picks up --retry / --verbose
        for k, v in supported_VLM.items():
            if hasattr(v, 'keywords') and 'retry' in v.keywords:
                v.keywords['retry'] = args.retry
                supported_VLM[k] = v
            if hasattr(v, 'keywords') and 'verbose' in v.keywords:
                v.keywords['verbose'] = args.verbose
                supported_VLM[k] = v

    rank, world_size = get_rank_and_world_size()
    if world_size > 1:
        local_rank = os.environ.get('LOCAL_RANK', 0)
        torch.cuda.set_device(int(local_rank))
        # Generous 3-hour timeout: rank 0 may spend a long time preparing a
        # dataset while the other ranks wait at the barrier below
        dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=10800))
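    # Note: under `torchrun --nproc-per-node=N run.py ...` the launcher sets
    # RANK / WORLD_SIZE / LOCAL_RANK in the environment, which is presumably
    # what get_rank_and_world_size() and the LOCAL_RANK lookup above read;
    # a plain `python run.py ...` runs single-process (rank 0, world size 1).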

    for model_name in args.model:
        model = None

        pred_root = osp.join(args.work_dir, model_name)
        os.makedirs(pred_root, exist_ok=True)

        for dataset_name in args.data:
            dataset_kwargs = {}
            if dataset_name in ['MMLongBench_DOC', 'DUDE', 'DUDE_MINI', 'SLIDEVQA', 'SLIDEVQA_MINI']:
                dataset_kwargs['model'] = model_name
            if dataset_name == 'MMBench-Video':
                dataset_kwargs['pack'] = args.pack
            if dataset_name == 'Video-MME':
                dataset_kwargs['use_subtitle'] = args.use_subtitle

            # In distributed runs, only rank 0 builds the dataset (which may
            # download or preprocess data); the built object is then pickled
            # and broadcast to the remaining ranks
            if world_size > 1:
                dataset = build_dataset(dataset_name, **dataset_kwargs) if rank == 0 else None
                dist.barrier()
                dataset_list = [dataset]
                dist.broadcast_object_list(dataset_list, src=0)
                dataset = dataset_list[0]
            else:
                dataset = build_dataset(dataset_name, **dataset_kwargs)
            if dataset is None:
                logger.error(f'Dataset {dataset_name} is not valid, will be skipped. ')
                continue

            result_file = f'{pred_root}/{model_name}_{dataset_name}.xlsx'
            if dataset_name in ['MMBench-Video']:
                packstr = 'pack' if args.pack else 'nopack'
                result_file = f'{pred_root}/{model_name}_{dataset_name}_{args.nframe}frame_{packstr}.xlsx'
            elif dataset.MODALITY == 'VIDEO':
                if args.pack:
                    logger.info(f'{dataset_name} does not support pack mode, falling back to unpack')
                    args.pack = False
                packstr = 'pack' if args.pack else 'nopack'
                result_file = f'{pred_root}/{model_name}_{dataset_name}_{args.nframe}frame_{packstr}.xlsx'
                if dataset_name in ['Video-MME']:
                    subtitlestr = 'subs' if args.use_subtitle else 'nosubs'
                    result_file = result_file.replace('.xlsx', f'_{subtitlestr}.xlsx')

            # Multi-turn datasets are saved as TSV instead of XLSX
            if dataset.TYPE == 'MT':
                result_file = result_file.replace('.xlsx', '.tsv')
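            # Resulting names look like (illustrative model/dataset names):
            #   {work_dir}/qwen_chat/qwen_chat_MMBench_DEV_EN.xlsx
            #   {work_dir}/qwen_chat/qwen_chat_Video-MME_8frame_nopack_nosubs.xlsx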

            # --rerun: clear judge temp files left over from a previous eval
            if osp.exists(result_file) and args.rerun:
                for keyword in ['openai', 'gpt', 'auxmatch']:
                    os.system(f'rm {pred_root}/{model_name}_{dataset_name}_{keyword}*')

            # The inference helpers below accept either a model name (on the
            # first dataset) or an already-built model object, and return the
            # built model so it can be reused for subsequent datasets
            if model is None:
                model = model_name

            # Dispatch by dataset modality/type
            if dataset.MODALITY == 'VIDEO':
                model = infer_data_job_video(
                    model,
                    work_dir=pred_root,
                    model_name=model_name,
                    dataset=dataset,
                    nframe=args.nframe,
                    pack=args.pack,
                    verbose=args.verbose,
                    subtitle=args.use_subtitle,
                    api_nproc=args.nproc)
            elif dataset.TYPE == 'MT':
                model = infer_data_job_mt(
                    model,
                    work_dir=pred_root,
                    model_name=model_name,
                    dataset=dataset,
                    verbose=args.verbose,
                    api_nproc=args.nproc,
                    ignore_failed=args.ignore)
            else:
                model = infer_data_job(
                    model,
                    work_dir=pred_root,
                    model_name=model_name,
                    dataset=dataset,
                    verbose=args.verbose,
                    api_nproc=args.nproc,
                    ignore_failed=args.ignore)

            # Assemble judge (evaluator) kwargs; when --judge is not given,
            # pick a default judge model based on the dataset
            judge_kwargs = {
                'nproc': args.nproc,
                'verbose': args.verbose,
            }
            if args.retry is not None:
                judge_kwargs['retry'] = args.retry
            if args.judge is not None:
                judge_kwargs['model'] = args.judge
            else:
                if dataset.TYPE in ['MCQ', 'Y/N']:
                    judge_kwargs['model'] = 'chatgpt-0125'
                elif listinstr(['MMVet', 'MathVista', 'LLaVABench', 'MMBench-Video', 'MathVision'], dataset_name):
                    judge_kwargs['model'] = 'gpt-4-turbo'
                elif listinstr(['MMLongBench', 'MMDU', 'DUDE', 'DUDE_MINI', 'SLIDEVQA', 'SLIDEVQA_MINI'], dataset_name):
                    judge_kwargs['model'] = 'gpt-4o'
            if 'OPENAI_API_KEY_JUDGE' in os.environ and len(os.environ['OPENAI_API_KEY_JUDGE']):
                judge_kwargs['key'] = os.environ['OPENAI_API_KEY_JUDGE']
            if 'OPENAI_API_BASE_JUDGE' in os.environ and len(os.environ['OPENAI_API_BASE_JUDGE']):
                judge_kwargs['api_base'] = os.environ['OPENAI_API_BASE_JUDGE']
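            # A dedicated key/endpoint for the judge can thus be supplied via
            # the environment, e.g. (placeholder values):
            #   export OPENAI_API_KEY_JUDGE=sk-xxxx
            #   export OPENAI_API_BASE_JUDGE=https://api.openai.com/v1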

            # Datasets scored on official servers: convert predictions to the
            # expected submission format and skip local evaluation
            if rank == 0:
                if dataset_name in ['MMMU_TEST']:
                    result_json = MMMU_result_transfer(result_file)
                    logger.info(f'Transferred the MMMU_TEST result to json for official evaluation, '
                                f'json file saved in {result_json}')
                    continue
                elif 'MMT-Bench_ALL' in dataset_name:
                    submission_file = MMTBench_result_transfer(result_file, **judge_kwargs)
                    logger.info(f'Extracted options from MMT-Bench FULL-split predictions for official evaluation '
                                f'(https://eval.ai/web/challenges/challenge-page/2328/overview), '
                                f'submission file saved in {submission_file}')
                    continue
                elif 'MLLMGuard_DS' in dataset_name:
                    logger.info('The evaluation of MLLMGuard_DS is not supported yet. ')
                    continue
                elif dataset_name == 'AesBench_TEST':
                    logger.info(f'The results are saved in {result_file}. '
                                f'Please send them to the AesBench team via huangyipo@hotmail.com.')
                    continue

            # These MMBench splits have held-out answers and can only be
            # evaluated on an official server
            if dataset_name in [
                'MMBench_TEST_CN', 'MMBench_TEST_EN', 'MMBench', 'MMBench_CN',
                'MMBench_TEST_CN_V11', 'MMBench_TEST_EN_V11', 'MMBench_V11', 'MMBench_CN_V11'
            ]:
                if not MMBenchOfficialServer(dataset_name):
                    logger.error(
                        f'Cannot evaluate {dataset_name} on non-official servers, '
                        'will skip the evaluation. '
                    )
                    continue

            eval_proxy = os.environ.get('EVAL_PROXY', None)
            old_proxy = os.environ.get('HTTP_PROXY', '')

            if rank == 0 and args.mode == 'all':
                # Optionally route judge traffic through EVAL_PROXY for the
                # duration of the evaluation
                if eval_proxy is not None:
                    proxy_set(eval_proxy)

                eval_results = dataset.evaluate(result_file, **judge_kwargs)
                if eval_results is not None:
                    assert isinstance(eval_results, (dict, pd.DataFrame))
                    logger.info(f'The evaluation of model {model_name} x dataset {dataset_name} has finished! ')
                    logger.info('Evaluation Results:')
                    if isinstance(eval_results, dict):
                        logger.info('\n' + json.dumps(eval_results, indent=4))
                    elif isinstance(eval_results, pd.DataFrame):
                        # Transpose wide single-row frames so they log readably
                        if len(eval_results) < len(eval_results.columns):
                            eval_results = eval_results.T
                        logger.info('\n' + tabulate(eval_results))

                # Restore the original proxy
                if eval_proxy is not None:
                    proxy_set(old_proxy)


if __name__ == '__main__':
    load_env()
    main()