| from .base_dataset import VideoEvalDataset | |
| import os | |
| import json | |
| import numpy as np | |
class TGIFQADataset(VideoEvalDataset):
    """Evaluation dataset for the TGIF video question-answering benchmark.

    All construction work is delegated to ``VideoEvalDataset``; this
    subclass only adds an optional debug truncation and reports how many
    entries ended up in ``self.data_list``.
    """

    def __init__(self, kwargs):
        """Build the dataset from a single configuration object.

        Args:
            kwargs: Configuration forwarded unchanged to
                ``VideoEvalDataset.__init__`` as one positional argument
                (it is not unpacked with ``**``). Presumably carries the
                question/answer JSON paths, the video directory and the
                lookup keys -- confirm against the base class.
        """
        super().__init__(kwargs)

        # Debug switch: set to True to keep only the first 100 entries.
        debug_subset = False
        if debug_subset:
            self.data_list = self.data_list[:100]

        # self.data_list is expected to have been populated by the base
        # class constructor above.
        print(f"Loaded {len(self.data_list)} entries in TGIFQADataset.")
if __name__ == "__main__":
    # Manual smoke test: build the dataset from a hard-coded local
    # configuration and touch every sample once.
    config = {
        "q_json_path": "/home/user/students/ml/dataset/TGIF_Zero_Shot_QA/test_q.json",
        "a_json_path": "/home/user/students/ml/dataset/TGIF_Zero_Shot_QA/test_a.json",
        "video_path": "/home/user/students/ml/dataset/TGIF_Zero_Shot_QA/mp4",
        "data_type": "video",
        # NOTE(review): this is the string "False", not the boolean; a
        # non-empty string is truthy -- confirm how the base class reads it.
        "bound": "False",
        "question_key": "question",
        "answer_key": "answer",
        "name_key": "video_name",
        "video_postfix": ["mp4"],
        "num_segments": 8,
    }
    tgif_dataset = TGIFQADataset(config)
    print(len(tgif_dataset))

    # Index every sample in order; the result is not printed or stored,
    # so this only checks that each item can be fetched without raising.
    for index in range(len(tgif_dataset)):
        sample = tgif_dataset[index]