RoadQAQ's picture
Upload folder using huggingface_hub
710b71f verified
from .base_dataset import VideoEvalDataset
import os
import json
import numpy as np
class TGIFQADataset(VideoEvalDataset):
"""
TGIFQADataset 继承自 EvalDataset 类,用于处理 TGIF 视频问答数据集。
这个类实现了数据集加载、长度查询、以及根据索引获取单个样本的功能。
"""
# 定义数据集的基本信息
# 包括问题、答案 JSON 文件路径,视频文件路径和相关键值
def __init__(self, kwargs):
"""
初始化数据集对象,加载问题和答案的 JSON 文件并将其存储在 data_list 中。
"""
super().__init__(kwargs)
# for debug
if False:
self.data_list = self.data_list[:100]
print(f"Loaded {len(self.data_list)} entries in TGIFQADataset.")
if __name__ == "__main__":
dataset_dict = dict(
q_json_path="/home/user/students/ml/dataset/TGIF_Zero_Shot_QA/test_q.json",
a_json_path="/home/user/students/ml/dataset/TGIF_Zero_Shot_QA/test_a.json",
video_path="/home/user/students/ml/dataset/TGIF_Zero_Shot_QA/mp4",
data_type="video",
bound="False",
question_key="question",
answer_key="answer",
name_key="video_name",
video_postfix=["mp4"],
num_segments=8
)
dataset = TGIFQADataset(dataset_dict)
print(len(dataset))
# 遍历 DataLoader,打印每个 batch 的内容
for i in range(len(dataset)):
sample = dataset[i]