from .tgif_qa import TGIFQADataset
from .msvd_qa import MSVDQADataset
from .msrvtt_qa import MSRVTTQADataset
from .activitynet_qa import ActivitynetQADataset
from .mvbench_qa import MVBenchDataset
from .videochatgpt_qa import VideoChatGPTQADataset

# Registry mapping the ``dataset_name`` value accepted by ``load_eval_dataset``
# to the dataset class that is instantiated for it.
# NOTE(review): "VideoChatGPTQADataset" does not follow the "<Name>_QA" key
# convention used by the other entries; it is kept unchanged because callers
# may already pass this exact string.
DATASET_DICT = {
    "TGIF_QA": TGIFQADataset,
    "MSVD_QA": MSVDQADataset,
    "MSRVTT_QA": MSRVTTQADataset,
    "Activitynet_QA": ActivitynetQADataset,
    "MVBench_QA": MVBenchDataset,
    "VideoChatGPTQADataset": VideoChatGPTQADataset,
}


def load_eval_dataset(dataset_args):
    """Instantiate the evaluation dataset named by ``dataset_args``.

    Args:
        dataset_args: Mapping-like configuration object (anything with a
            ``.get`` method). Its ``'dataset_name'`` entry selects the class
            from ``DATASET_DICT``. The whole object is then passed unchanged
            to that class's constructor.

    Returns:
        An instance of the dataset class registered under ``dataset_name``.

    Raises:
        ValueError: If ``dataset_name`` is missing or is not a key of
            ``DATASET_DICT``. The message lists the registered names.
    """
    dataset_name = dataset_args.get('dataset_name')
    # Look the class up explicitly so an unknown (or missing -> None) name
    # surfaces as a descriptive ValueError rather than a bare KeyError.
    dataset_cls = DATASET_DICT.get(dataset_name)
    if dataset_cls is None:
        available = ", ".join(sorted(DATASET_DICT))
        raise ValueError(
            f"Dataset {dataset_name} not found in DATASET_DICT "
            f"(available: {available})"
        )
    return dataset_cls(dataset_args)