# RoadQAQ's picture
# Upload folder using huggingface_hub
# 710b71f verified
from .tgif_qa import TGIFQADataset
from .msvd_qa import MSVDQADataset
from .msrvtt_qa import MSRVTTQADataset
from .activitynet_qa import ActivitynetQADataset
from .mvbench_qa import MVBenchDataset
from .videochatgpt_qa import VideoChatGPTQADataset
# Registry mapping the dataset name used in evaluation configs to the dataset
# class that implements it. `load_eval_dataset` looks names up in this table.
# NOTE(review): the last key does not follow the "<Name>_QA" pattern used by
# the other entries; it is kept verbatim because callers may reference it by
# this exact string.
DATASET_DICT = {
    "TGIF_QA": TGIFQADataset,
    "MSVD_QA": MSVDQADataset,
    "MSRVTT_QA": MSRVTTQADataset,
    "Activitynet_QA": ActivitynetQADataset,
    "MVBench_QA": MVBenchDataset,
    "VideoChatGPTQADataset": VideoChatGPTQADataset,
}
def load_eval_dataset(dataset_args):
    """Build the evaluation dataset selected by ``dataset_args``.

    Args:
        dataset_args: Mapping-like object exposing ``.get``. Its
            ``'dataset_name'`` entry selects the registered dataset class,
            and the whole object is forwarded unchanged to that class as its
            single constructor argument.

    Returns:
        The object produced by calling the registered dataset class with
        ``dataset_args``.

    Raises:
        ValueError: If the requested name (including a missing name, which
            reads as ``None``) is not a key of ``DATASET_DICT``.
    """
    requested = dataset_args.get('dataset_name')

    # Validate the name up front so an unknown dataset surfaces as a clear
    # ValueError instead of a bare KeyError from the registry lookup.
    dataset_cls = DATASET_DICT.get(requested)
    if dataset_cls is None:
        raise ValueError(f"Dataset {requested} not found in DATASET_DICT")

    return dataset_cls(dataset_args)