Spaces:
Runtime error
Runtime error
| import os | |
| from tracemalloc import start | |
| import warnings | |
| import glob | |
| import random | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from torch.utils.data import Dataset | |
| import torchvision | |
| import torch.distributed as dist | |
| from decord import VideoReader | |
| from pcache_fileio import fileio | |
| from pcache_fileio.oss_conf import OssConfigFactory | |
class SakugaRefDataset(Dataset):
    """Video dataset yielding a reference frame followed by a clip of frames.

    Each item is a dict ``{'pixel_values': ...}``. Before ``transform`` is
    applied, the frame stack has ``video_frames + 1`` frames in
    ``(f, c, h, w)`` order:
    - index 0 is the reference frame at ``ref_frame_ind``;
    - indices 1.. are ``video_frames`` frames spread evenly over
      ``[ref_frame_ind + ref_jump_frames, end_frame_ind - 1]``.
    """

    def __init__(
        self,
        video_frames=25,
        ref_jump_frames=36,
        base_folder='data/samples/',
        file_list=None,
        temporal_sample=None,
        transform=None,
        seed=42,
        max_retries=None,
    ):
        """
        Args:
            video_frames (int): Number of clip frames sampled after the
                reference frame. The returned stack has ``video_frames + 1``
                frames.
            ref_jump_frames (int): Offset between the reference frame and the
                first clip frame.
            base_folder (str): Directory that video paths are resolved against.
            file_list (str | None): Optional text file listing one video path
                per line, relative to ``base_folder``. Blank lines are ignored.
                When given, videos are opened with ``open(path, 'rb')`` and
                decoded with decord's ``VideoReader``. Otherwise every
                ``*.mp4`` directly inside ``base_folder`` is used (sorted) and
                decoded with ``torchvision.io.read_video``.
            temporal_sample (callable | None): Called as
                ``temporal_sample(total_frames)``. It must return
                ``(ref_frame_ind, end_frame_ind)``, where ``end_frame_ind`` is
                exclusive. ``None`` uses the whole video, ``(0, total_frames)``.
            transform (callable | None): Applied to the ``(f, c, h, w)`` frame
                tensor. ``None`` returns the frames unchanged.
            seed (int): Stored as ``self.seed``; currently unused here.
            max_retries (int | None): How many replacement samples
                ``__getitem__`` tries after a failed load before raising.
                ``None`` retries indefinitely.
        """
        self.base_folder = base_folder
        self.file_list = file_list
        if file_list is None:
            # Sorted so the index -> video mapping is reproducible across
            # machines (glob order depends on the filesystem).
            self.video_lists = sorted(glob.glob(os.path.join(self.base_folder, '*.mp4')))
        else:
            # Read from file_list.txt: one relative video path per line.
            with open(file_list, 'r') as f:
                self.video_lists = [
                    os.path.join(self.base_folder, line.strip())
                    for line in f
                    # A blank line would resolve to base_folder itself,
                    # which can never be loaded as a video.
                    if line.strip()
                ]
        self.num_samples = len(self.video_lists)
        self.channels = 3
        self.video_frames = video_frames
        self.ref_jump_frames = ref_jump_frames
        self.temporal_sample = temporal_sample
        self.transform = transform
        self.seed = seed
        self.max_retries = max_retries

    def __len__(self):
        """Return the number of listed videos."""
        return self.num_samples

    def get_sample(self, idx):
        """Load video ``idx`` and return its reference frame plus clip frames.

        Args:
            idx (int): Index into ``self.video_lists``.

        Returns:
            dict: ``{'pixel_values': pixel_values}``. ``pixel_values`` is
            ``transform(video)``, or ``video`` itself when ``transform`` is
            ``None``. ``video`` is the ``(video_frames + 1, c, h, w)`` frame
            stack: index 0 is the reference image and the remaining indices
            are the clip frames.

        Raises:
            ValueError: If the sampled window is shorter than
                ``video_frames + ref_jump_frames`` frames.
        """
        path = self.video_lists[idx]
        use_file_reader = self.file_list is not None  # read from pcache
        if use_file_reader:
            with open(path, 'rb') as f:
                # NOTE(review): the file is closed before vframes is queried
                # below; this relies on VideoReader consuming the data up
                # front — confirm.
                vframes = VideoReader(f)
        else:
            vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)

        # Sampling video frames; end_frame_ind is exclusive.
        if self.temporal_sample is None:
            ref_frame_ind, end_frame_ind = 0, total_frames
        else:
            ref_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        if end_frame_ind - ref_frame_ind < self.video_frames + self.ref_jump_frames:
            raise ValueError(f'video {path} does not have enough frames')

        start_frame_ind = ref_frame_ind + self.ref_jump_frames
        # video_frames indices spread evenly over [start, end - 1] inclusive.
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.video_frames, dtype=int)
        # Prepend the reference frame so it lands at position 0.
        frame_indice = np.insert(frame_indice, 0, ref_frame_ind)

        if use_file_reader:
            # Batch comes back as (f, h, w, c); permute to (f, c, h, w).
            video = torch.from_numpy(vframes.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
        else:
            video = vframes[frame_indice]
        # (f c h w)
        pixel_values = video if self.transform is None else self.transform(video)
        # The [0] index for pixel_values is the reference image; the other
        # indexes are the video frames.
        return {'pixel_values': pixel_values}

    def __getitem__(self, idx):
        """Return ``get_sample(idx)``, substituting random videos on failure.

        If loading raises any ``Exception`` (unreadable file, too-short
        video, transform error, ...), a warning is issued and a uniformly
        random index over the whole dataset is tried instead.
        ``KeyboardInterrupt`` and ``SystemExit`` are not caught.

        Raises:
            RuntimeError: If ``self.max_retries`` is not ``None`` and that
                many replacement samples also fail. The last loading error
                is chained as the cause.
        """
        retries = 0
        while True:
            try:
                return self.get_sample(idx)
            except Exception as err:
                if self.max_retries is not None and retries >= self.max_retries:
                    raise RuntimeError(
                        f'loading failed for {retries + 1} consecutive samples (last index {idx})'
                    ) from err
                warnings.warn(f'loading {idx} failed ({err!r}), retrying...')
                retries += 1
                # randint's upper bound is exclusive: pass len() so every
                # video (including the last one) can be picked.
                idx = np.random.randint(0, len(self.video_lists))