File size: 914 Bytes
710b71f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from .tgif_qa import TGIFQADataset
from .msvd_qa import MSVDQADataset
from .msrvtt_qa import MSRVTTQADataset
from .activitynet_qa import ActivitynetQADataset
from .mvbench_qa import MVBenchDataset
from .videochatgpt_qa import VideoChatGPTQADataset

# Registry of evaluation datasets: maps the value expected in
# dataset_args['dataset_name'] to the dataset class that load_eval_dataset
# instantiates with the full dataset_args mapping.
# NOTE(review): every key follows the "<Name>_QA" pattern except
# "VideoChatGPTQADataset"; it is kept unchanged because existing configs or
# callers may already reference it — confirm before renaming.
DATASET_DICT = dict(
    TGIF_QA=TGIFQADataset,
    MSVD_QA=MSVDQADataset,
    MSRVTT_QA=MSRVTTQADataset,
    Activitynet_QA=ActivitynetQADataset,
    MVBench_QA=MVBenchDataset,
    VideoChatGPTQADataset=VideoChatGPTQADataset,
)

def load_eval_dataset(dataset_args, *, registry=None):
    """Instantiate the evaluation dataset selected by ``dataset_args``.

    Args:
        dataset_args: Mapping with a ``'dataset_name'`` entry that selects the
            dataset class. The whole mapping is passed unchanged to that
            class's constructor.
        registry: Optional mapping from dataset names to dataset classes.
            Defaults to ``DATASET_DICT`` when omitted or ``None``.

    Returns:
        An instance of the selected dataset class, built as
        ``cls(dataset_args)``.

    Raises:
        ValueError: If ``dataset_args`` has no ``'dataset_name'`` entry, or if
            the name is not a key of the registry.
    """
    if registry is None:
        registry = DATASET_DICT

    dataset_name = dataset_args.get('dataset_name')
    if dataset_name is None:
        # Report the missing key explicitly instead of "Dataset None not found".
        raise ValueError("dataset_args must contain a 'dataset_name' entry")

    # A single lookup replaces the membership test followed by indexing.
    dataset_cls = registry.get(dataset_name)
    if dataset_cls is None:
        available = ", ".join(map(str, registry))
        raise ValueError(
            f"Dataset {dataset_name!r} not found; available datasets: {available}"
        )

    return dataset_cls(dataset_args)