from .tgif_qa import TGIFQADataset
from .msvd_qa import MSVDQADataset
from .msrvtt_qa import MSRVTTQADataset
from .activitynet_qa import ActivitynetQADataset
from .mvbench_qa import MVBenchDataset
from .videochatgpt_qa import VideoChatGPTQADataset

# Registry of evaluation datasets: maps the name passed as
# dataset_args['dataset_name'] to load_eval_dataset onto the class that is
# instantiated (with the full dataset_args) for that name.
DATASET_DICT = {
    "TGIF_QA": TGIFQADataset,
    "MSVD_QA": MSVDQADataset,
    "MSRVTT_QA": MSRVTTQADataset,
    "Activitynet_QA": ActivitynetQADataset,
    "MVBench_QA": MVBenchDataset,
    # NOTE(review): this key does not follow the "<NAME>_QA" pattern used by
    # the other entries; kept as-is because callers may already rely on it.
    "VideoChatGPTQADataset": VideoChatGPTQADataset,
}
def load_eval_dataset(dataset_args):
    """Instantiate the evaluation dataset selected by ``dataset_args``.

    Args:
        dataset_args: Dict-like object with a ``'dataset_name'`` entry naming
            a key of ``DATASET_DICT``. The whole object is passed unchanged to
            the selected dataset class's constructor.

    Returns:
        An instance of the dataset class registered under that name.

    Raises:
        ValueError: If ``'dataset_name'`` is missing (``None``) or is not a
            registered dataset name. The message lists the valid names.
    """
    dataset_name = dataset_args.get('dataset_name')
    # Single lookup instead of `in` + `[]`: registry values are classes, so a
    # None result can only mean the name is not registered (or was missing).
    dataset_cls = DATASET_DICT.get(dataset_name)
    if dataset_cls is None:
        available = ", ".join(sorted(DATASET_DICT))
        raise ValueError(
            f"Dataset {dataset_name!r} not found in DATASET_DICT; "
            f"available datasets: {available}"
        )
    return dataset_cls(dataset_args)