File size: 6,773 Bytes
710b71f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from torch.utils.data import Dataset
from typing import Any, List
import os
import av
from util.json import load_from_json
from util.read_video import read_video, read_gif, read_frame

class VideoEvalDataset(Dataset):
    """
    VideoEvalDataset 是一个继承自 PyTorch 的 Dataset 的基类,用于处理视频、GIF、图像帧数据。
    提供了不同的解码方法来读取这些格式的数据并返回处理后的图像组。
    """

    def __init__(self, kwargs):
        """
        初始化 EvalDataset 类,设置分段数,并初始化支持的视频、GIF、帧读取方法。

        参数:
        num_segments (int): 将视频或GIF划分为的段数,默认为8。
        """
        super().__init__()
        
        self.data_list_info = dict(
            q_json_path="Dataset/test_q.json",  # 问题 JSON 文件路径
            a_json_path="Dataset/test_a.json",  # 答案 JSON 文件路径
            video_path="Dataset/mp4",          # 视频文件存储目录
            data_type="video",                           # 数据类型是视频
            bound=False,                                 # 是否有开始和结束时间,默认无
            question_key='question',                     # 问题字段的键
            answer_key='answer',                         # 答案字段的键
            name_key='video_name',                       # 视频名称字段的键
            video_postfix=['mp4'],                      # 视频文件后缀
            num_segments=1
        )

        # 抽帧数
        self.num_segments = kwargs["num_segments"]  # 定义将视频划分的段数
        
        # 定义不同类型数据的解码方法
        self.decord_method = {
            'video': read_video,
            'gif': read_gif,
            'frame': read_frame,
        }

        # 使用传入的 kwargs 来更新 data_list_info
        self.data_list_info.update(kwargs)

        self.data_list = []  # 存储加载的数据
        
        # 构造问题 JSON 文件的完整路径
        q_json_path = self.data_list_info['q_json_path']
        questions_json_data = load_from_json(q_json_path)  # 加载问题数据

        # 构造答案 JSON 文件的完整路径
        a_json_path = self.data_list_info['a_json_path']
        answers_json_data = load_from_json(a_json_path)  # 加载答案数据

        # 将问题和答案整合到 data_list 中,每个问题-答案对作为一个字典存储
        for i in range(len(questions_json_data)):
            question_data = questions_json_data[i]  # 获取第 i 个问题
            answer_data = answers_json_data[i]      # 获取第 i 个答案
            data = {}
            data.update(question_data)              # 更新字典,加入问题数据
            data.update(answer_data)                # 更新字典,加入答案数据
            self.data_list.append(data)             # 将整合后的数据加入列表
        
    def __len__(self):
        """
        返回数据集的长度,即总共包含多少个问题-答案对。
        """
        return len(self.data_list)

    def __getitem__(self, idx):
        """
        根据索引 idx 返回数据集中对应的样本,包括视频帧、问题和答案。
        
        参数:
        idx (int): 数据集中样本的索引。

        返回:
        dict: 包含视频帧图像数据、问题文本、视频路径和答案文本的字典。
            'video_pils': images_group,  # 视频帧图像数据
            'question': question,        # 问题文本
            'video_path': video_path,    # 视频路径
            'answer': answer             # 答案文本
        """
        # 获取视频名称的键,并通过键从数据列表中获取对应视频名称
        video_name_key = self.data_list_info['name_key']
        video_name = self.data_list[idx][video_name_key]

        # 获取视频文件的后缀(可能有多个后缀)
        video_postfixs = self.data_list_info['video_postfix']

        # 查找存在的视频路径,考虑可能有多个后缀的视频文件
        video_paths = []
        for p in video_postfixs:
            video_path = os.path.join(self.data_list_info['video_path'], f"{video_name}{p}")
            if os.path.exists(video_path):  # 检查视频文件是否存在
                video_paths.append(video_path)
        
        # 确保至少找到一个有效的视频文件路径,否则抛出错误
        assert len(video_paths) > 0, f"No video named {video_name} found."

        # 使用第一个找到的视频文件路径
        video_path = video_paths[0]
        
        # 读取视频数据,bound 用于定义是否截取视频的某个范围,这里为 None 表示不限制
        bound = None
        if self.data_list_info['bound'] == "True":
            bound = (
                self.data_list[idx]['start'],
                self.data_list[idx]['end'],
            )
        this_decord_method = self.decord_method[self.data_list_info['data_type']]
        images_group = this_decord_method(video_path, self.num_segments, bound)  # 解码视频获取帧数据
    
        # 从 data_list 中提取问题和答案
        question_key = self.data_list_info['question_key']
        answer_key = self.data_list_info['answer_key']
        question = self.data_list[idx][question_key]  # 获取问题文本
        answer = self.data_list[idx][answer_key]      # 获取答案文本
        
        # 返回包含视频帧、问题、视频路径和答案的字典
        return {
            'video_pils': images_group,  # 视频帧图像数据
            'question': question,        # 问题文本
            'video_path': video_path,    # 视频路径
            'answer': answer             # 答案文本
        }

    def set_rank_and_world_size(self, rank, world_size):
        """
        设置数据集的 rank 和 world_size 以用于分布式推理,并确保数据集按顺序分配。

        参数:
        rank (int): 当前进程的 rank 值。
        world_size (int): 总的进程数。
        """
        self.rank = rank
        self.world_size = world_size

        # 计算每个进程应分配的数据样本数量
        total_samples = len(self.data_list)
        samples_per_rank = total_samples // world_size
        remainder = total_samples % world_size

        # 确定当前 rank 的起始和结束索引
        if rank < remainder:
            start_idx = rank * (samples_per_rank + 1)
            end_idx = start_idx + samples_per_rank + 1
        else:
            start_idx = rank * samples_per_rank + remainder
            end_idx = start_idx + samples_per_rank

        return start_idx, end_idx