|
|
from torch.utils.data import Dataset |
|
|
from typing import Any, List |
|
|
import os |
|
|
import av |
|
|
from util.json import load_from_json |
|
|
from util.read_video import read_video, read_gif, read_frame |
|
|
|
|
|
class VideoEvalDataset(Dataset):
    """Evaluation dataset pairing question/answer annotations with media files.

    Questions and answers are loaded from two JSON files and merged entry by
    entry. On access, the sample's media file is located under ``video_path``
    and decoded with the reader registered for ``data_type`` (``'video'``,
    ``'gif'`` or ``'frame'``).
    """

    def __init__(self, kwargs=None):
        """Build the configuration and load the merged annotations.

        Args:
            kwargs (dict, optional): Overrides for the defaults stored in
                ``self.data_list_info``. Recognised keys:

                - ``q_json_path`` / ``a_json_path``: JSON files holding the
                  questions and the answers. Entries are merged pairwise, so
                  both files must contain the same number of entries
                  (presumably lists of dicts -- confirm against
                  ``load_from_json``).
                - ``video_path``: directory that holds the media files.
                - ``data_type``: ``'video'``, ``'gif'`` or ``'frame'``; selects
                  the reader used in ``__getitem__``.
                - ``bound``: when ``True`` (bool, or the string ``"True"``,
                  case-insensitive), each sample's ``'start'`` / ``'end'``
                  fields are passed to the reader.
                - ``question_key`` / ``answer_key`` / ``name_key``: field names
                  of the question text, the answer text and the media base name.
                - ``video_postfix``: candidate file suffixes, tried in order.
                - ``num_segments``: number of segments to split the media into,
                  passed to the reader (default 1).

        Raises:
            ValueError: if the question and answer files contain different
                numbers of entries.
        """
        super().__init__()

        self.data_list_info = dict(
            q_json_path="Dataset/test_q.json",
            a_json_path="Dataset/test_a.json",
            video_path="Dataset/mp4",
            data_type="video",
            bound=False,
            question_key='question',
            answer_key='answer',
            name_key='video_name',
            video_postfix=['mp4'],
            num_segments=1,
        )
        # Caller overrides take precedence over the defaults above.
        self.data_list_info.update(kwargs or {})

        # Read from the merged config so an omitted key falls back to the
        # default instead of raising KeyError.
        self.num_segments = self.data_list_info['num_segments']

        # Maps ``data_type`` to the reader that decodes one sample.
        self.decord_method = {
            'video': read_video,
            'gif': read_gif,
            'frame': read_frame,
        }

        self.data_list = self._load_annotations(
            self.data_list_info['q_json_path'],
            self.data_list_info['a_json_path'],
        )

    @staticmethod
    def _load_annotations(q_json_path, a_json_path):
        """Load both JSON files and merge their entries pairwise.

        For each pair, keys from the answer entry override same-named keys
        from the question entry.

        Raises:
            ValueError: if the two files have different numbers of entries.
        """
        questions = load_from_json(q_json_path)
        answers = load_from_json(a_json_path)
        if len(questions) != len(answers):
            raise ValueError(
                f"Question file {q_json_path!r} has {len(questions)} entries "
                f"but answer file {a_json_path!r} has {len(answers)}."
            )
        return [{**question, **answer} for question, answer in zip(questions, answers)]

    def __len__(self):
        """Return the number of question/answer pairs."""
        return len(self.data_list)

    def _resolve_video_path(self, video_name):
        """Return the first existing media file for ``video_name``.

        Each postfix is first appended verbatim (``name + postfix``), which
        preserves the original lookup order. If none of those exist, postfixes
        without a leading dot are retried as ``name + '.' + postfix``, so the
        default ``['mp4']`` resolves to ``name.mp4``.

        Raises:
            FileNotFoundError: if no candidate file exists.
        """
        root = self.data_list_info['video_path']
        postfixes = self.data_list_info['video_postfix']
        if isinstance(postfixes, str):
            # A bare string would otherwise be iterated character by character.
            postfixes = [postfixes]

        candidates = [os.path.join(root, f"{video_name}{p}") for p in postfixes]
        candidates += [
            os.path.join(root, f"{video_name}.{p}")
            for p in postfixes
            if isinstance(p, str) and p and not p.startswith('.')
        ]
        for path in candidates:
            if os.path.exists(path):
                return path
        raise FileNotFoundError(f"No video named {video_name} found in {root!r}.")

    def _get_bound(self, sample):
        """Return ``(start, end)`` from ``sample`` if bounding is enabled, else ``None``."""
        flag = self.data_list_info['bound']
        if isinstance(flag, str):
            # Accept string configs such as "True" / "true".
            enabled = flag.strip().lower() == "true"
        else:
            enabled = flag is True
        if not enabled:
            return None
        return (sample['start'], sample['end'])

    def __getitem__(self, idx):
        """Return the decoded media, question, media path and answer for ``idx``.

        Args:
            idx (int): Index of the sample.

        Returns:
            dict: with keys
                ``'video_pils'`` -- whatever the selected reader returns
                (presumably a group of PIL images -- confirm in util.read_video),
                ``'question'`` -- question text,
                ``'video_path'`` -- path of the decoded media file,
                ``'answer'`` -- answer text.

        Raises:
            FileNotFoundError: if no media file matches any ``video_postfix``.
            KeyError: if ``data_type`` has no registered reader.
        """
        info = self.data_list_info
        sample = self.data_list[idx]

        video_path = self._resolve_video_path(sample[info['name_key']])
        bound = self._get_bound(sample)

        data_type = info['data_type']
        try:
            decode = self.decord_method[data_type]
        except KeyError:
            raise KeyError(
                f"Unsupported data_type {data_type!r}; "
                f"expected one of {sorted(self.decord_method)}"
            ) from None
        images_group = decode(video_path, self.num_segments, bound)

        return {
            'video_pils': images_group,
            'question': sample[info['question_key']],
            'video_path': video_path,
            'answer': sample[info['answer_key']],
        }

    def set_rank_and_world_size(self, rank, world_size):
        """Record the distributed rank/world size and return this rank's shard.

        Samples are split into ``world_size`` contiguous, in-order ranges.
        Range sizes differ by at most one, and the first
        ``len(self) % world_size`` ranks receive one extra sample.
        ``self.data_list`` itself is not sliced; callers use the returned
        indices.

        Args:
            rank (int): Rank of the current process, in ``[0, world_size)``.
            world_size (int): Total number of processes (>= 1).

        Returns:
            tuple: ``(start_idx, end_idx)``, the half-open index range for ``rank``.

        Raises:
            ValueError: if ``world_size < 1`` or ``rank`` is out of range.
        """
        if world_size < 1:
            raise ValueError(f"world_size must be >= 1, got {world_size}.")
        if not 0 <= rank < world_size:
            raise ValueError(f"rank must be in [0, {world_size}), got {rank}.")

        self.rank = rank
        self.world_size = world_size

        base, extra = divmod(len(self.data_list), world_size)
        # Every lower rank that received an extra sample shifts this rank's
        # start by one.
        start_idx = rank * base + min(rank, extra)
        end_idx = start_idx + base + (1 if rank < extra else 0)
        return start_idx, end_idx
|
|
|
|
|
|