File size: 6,773 Bytes
710b71f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 | from torch.utils.data import Dataset
from typing import Any, List
import os
import av
from util.json import load_from_json
from util.read_video import read_video, read_gif, read_frame
class VideoEvalDataset(Dataset):
"""
VideoEvalDataset 是一个继承自 PyTorch 的 Dataset 的基类,用于处理视频、GIF、图像帧数据。
提供了不同的解码方法来读取这些格式的数据并返回处理后的图像组。
"""
def __init__(self, kwargs):
"""
初始化 EvalDataset 类,设置分段数,并初始化支持的视频、GIF、帧读取方法。
参数:
num_segments (int): 将视频或GIF划分为的段数,默认为8。
"""
super().__init__()
self.data_list_info = dict(
q_json_path="Dataset/test_q.json", # 问题 JSON 文件路径
a_json_path="Dataset/test_a.json", # 答案 JSON 文件路径
video_path="Dataset/mp4", # 视频文件存储目录
data_type="video", # 数据类型是视频
bound=False, # 是否有开始和结束时间,默认无
question_key='question', # 问题字段的键
answer_key='answer', # 答案字段的键
name_key='video_name', # 视频名称字段的键
video_postfix=['mp4'], # 视频文件后缀
num_segments=1
)
# 抽帧数
self.num_segments = kwargs["num_segments"] # 定义将视频划分的段数
# 定义不同类型数据的解码方法
self.decord_method = {
'video': read_video,
'gif': read_gif,
'frame': read_frame,
}
# 使用传入的 kwargs 来更新 data_list_info
self.data_list_info.update(kwargs)
self.data_list = [] # 存储加载的数据
# 构造问题 JSON 文件的完整路径
q_json_path = self.data_list_info['q_json_path']
questions_json_data = load_from_json(q_json_path) # 加载问题数据
# 构造答案 JSON 文件的完整路径
a_json_path = self.data_list_info['a_json_path']
answers_json_data = load_from_json(a_json_path) # 加载答案数据
# 将问题和答案整合到 data_list 中,每个问题-答案对作为一个字典存储
for i in range(len(questions_json_data)):
question_data = questions_json_data[i] # 获取第 i 个问题
answer_data = answers_json_data[i] # 获取第 i 个答案
data = {}
data.update(question_data) # 更新字典,加入问题数据
data.update(answer_data) # 更新字典,加入答案数据
self.data_list.append(data) # 将整合后的数据加入列表
def __len__(self):
"""
返回数据集的长度,即总共包含多少个问题-答案对。
"""
return len(self.data_list)
def __getitem__(self, idx):
"""
根据索引 idx 返回数据集中对应的样本,包括视频帧、问题和答案。
参数:
idx (int): 数据集中样本的索引。
返回:
dict: 包含视频帧图像数据、问题文本、视频路径和答案文本的字典。
'video_pils': images_group, # 视频帧图像数据
'question': question, # 问题文本
'video_path': video_path, # 视频路径
'answer': answer # 答案文本
"""
# 获取视频名称的键,并通过键从数据列表中获取对应视频名称
video_name_key = self.data_list_info['name_key']
video_name = self.data_list[idx][video_name_key]
# 获取视频文件的后缀(可能有多个后缀)
video_postfixs = self.data_list_info['video_postfix']
# 查找存在的视频路径,考虑可能有多个后缀的视频文件
video_paths = []
for p in video_postfixs:
video_path = os.path.join(self.data_list_info['video_path'], f"{video_name}{p}")
if os.path.exists(video_path): # 检查视频文件是否存在
video_paths.append(video_path)
# 确保至少找到一个有效的视频文件路径,否则抛出错误
assert len(video_paths) > 0, f"No video named {video_name} found."
# 使用第一个找到的视频文件路径
video_path = video_paths[0]
# 读取视频数据,bound 用于定义是否截取视频的某个范围,这里为 None 表示不限制
bound = None
if self.data_list_info['bound'] == "True":
bound = (
self.data_list[idx]['start'],
self.data_list[idx]['end'],
)
this_decord_method = self.decord_method[self.data_list_info['data_type']]
images_group = this_decord_method(video_path, self.num_segments, bound) # 解码视频获取帧数据
# 从 data_list 中提取问题和答案
question_key = self.data_list_info['question_key']
answer_key = self.data_list_info['answer_key']
question = self.data_list[idx][question_key] # 获取问题文本
answer = self.data_list[idx][answer_key] # 获取答案文本
# 返回包含视频帧、问题、视频路径和答案的字典
return {
'video_pils': images_group, # 视频帧图像数据
'question': question, # 问题文本
'video_path': video_path, # 视频路径
'answer': answer # 答案文本
}
def set_rank_and_world_size(self, rank, world_size):
"""
设置数据集的 rank 和 world_size 以用于分布式推理,并确保数据集按顺序分配。
参数:
rank (int): 当前进程的 rank 值。
world_size (int): 总的进程数。
"""
self.rank = rank
self.world_size = world_size
# 计算每个进程应分配的数据样本数量
total_samples = len(self.data_list)
samples_per_rank = total_samples // world_size
remainder = total_samples % world_size
# 确定当前 rank 的起始和结束索引
if rank < remainder:
start_idx = rank * (samples_per_rank + 1)
end_idx = start_idx + samples_per_rank + 1
else:
start_idx = rank * samples_per_rank + remainder
end_idx = start_idx + samples_per_rank
return start_idx, end_idx
|