Spaces:
Paused
Paused
| import os, io, csv, math, random | |
| import json | |
| import numpy as np | |
| from einops import rearrange | |
| from decord import VideoReader | |
| import torch | |
| import torchvision.transforms as transforms | |
| from torch.utils.data.dataset import Dataset | |
| from diffusers.utils import logging | |
# Module-level logger, created via the `logging` helper imported from diffusers.utils; used by the dataset classes in this file for progress messages.
| logger = logging.get_logger(__name__) | |
class WebVid10M(Dataset):
    """Video/caption dataset driven by a JSON-lines annotation file.

    Every non-blank line of ``json_path`` must be a JSON object providing a
    ``file`` entry (path of the video) and a ``text`` entry (its caption).

    Args:
        json_path: Path of the JSON-lines annotation file.
        video_folder: Optional root folder. When given, relative paths are
            joined onto it and absolute paths are reduced to their basename
            before being joined onto it.
        sample_size: Output size, either an int (square) or a
            ``(height, width)`` pair.
        sample_stride: Frame stride, either an int or a ``(low, high)`` pair
            from which a stride is drawn uniformly for every sample.
        sample_n_frames: Number of frames per clip.
        is_image: When true, two frames (start and end of a short random
            span) are returned instead of a full clip.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``is_image`` is set and ``sample_n_frames`` is below 2
            (the frame gap is drawn from ``[2, sample_n_frames]``).
    """

    def __init__(
        self,
        json_path, video_folder=None,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        is_image=False,
        **kwargs,
    ):
        # Reject a configuration that could never yield a sample; otherwise
        # the retry loop in __getitem__ would spin forever.
        if is_image and sample_n_frames < 2:
            raise ValueError("sample_n_frames should be at least 2 when is_image is True")

        logger.info(f"loading annotations from {json_path} ...")
        with open(json_path, 'rb') as json_file:
            # Blank lines (e.g. a trailing newline) are not valid JSON; skip them.
            self.dataset = [json.loads(line) for line in json_file if line.strip()]
        self.length = len(self.dataset)
        logger.info(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride if isinstance(sample_stride, int) else tuple(sample_stride)
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0], antialias=None),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def _resolve_video_path(self, video_relative_path):
        """Map an annotation path onto ``self.video_folder`` (if one is set)."""
        if self.video_folder is None:
            return video_relative_path
        # startswith() also copes with an empty path, unlike indexing [0].
        if video_relative_path.startswith('/'):
            return os.path.join(self.video_folder, os.path.basename(video_relative_path))
        return os.path.join(self.video_folder, video_relative_path)

    def _draw_stride(self):
        """Return the fixed stride, or draw one from the ``(low, high)`` range."""
        if isinstance(self.sample_stride, int):
            return self.sample_stride
        return random.randint(self.sample_stride[0], self.sample_stride[1])

    def _sample_frame_indices(self, video_length):
        """Pick the frame indices to read from a video of ``video_length`` frames.

        Returns an integer array of ``sample_n_frames`` indices in clip mode,
        or a two-element list ``[first, last]`` in image mode.
        """
        if video_length <= 0:
            raise ValueError("video contains no frames")

        if not self.is_image:
            stride = self._draw_stride()
            clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            return np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)

        frame_difference = random.randint(2, self.sample_n_frames)
        # Resolve the stride to an int here as well: multiplying by a
        # (low, high) tuple would raise TypeError on every sample.
        stride = self._draw_stride()
        clip_length = min(video_length, (frame_difference - 1) * stride + 1)
        start_idx = random.randint(0, video_length - clip_length)
        return [start_idx, start_idx + clip_length - 1]

    def get_batch(self, idx):
        """Decode the frames of sample ``idx``.

        Returns:
            ``(pixel_values, name)`` where ``pixel_values`` is the decoded
            frame tensor divided by 255 with the last axis moved to position
            1 (presumably (frames, channels, height, width) -- confirm against
            the decord output layout) and ``name`` is the caption.
        """
        video_dict = self.dataset[idx]
        video_relative_path, name = video_dict['file'], video_dict['text']

        video_reader = VideoReader(self._resolve_video_path(video_relative_path))
        batch_index = self._sample_frame_indices(len(video_reader))

        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader
        return pixel_values, name

    def __len__(self):
        """Return the number of annotation entries."""
        return self.length

    def __getitem__(self, idx):
        """Return ``dict(pixel_values=..., text=...)`` for sample ``idx``.

        If the sample cannot be loaded, the failure is logged and a random
        replacement index is tried until one succeeds.
        """
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception as e:
                # Best-effort loading is kept, but the failure is now reported
                # instead of being discarded silently.
                logger.warning(f"failed to load sample {idx} ({e!r}); retrying with a random index")
                idx = random.randint(0, self.length - 1)

        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name)
        return sample
class Pexels(Dataset):
    """Video/caption dataset built from two JSON-lines files.

    ``caption_json_path`` holds objects with ``id`` and ``text``;
    ``json_path`` holds the video annotations, which are read for ``id``,
    ``file``, ``fps``, ``height`` and ``width``. Only videos whose
    ``height / width`` ratio is below 0.625 are kept.

    Args:
        json_path: Path of the JSON-lines video annotation file.
        caption_json_path: Path of the JSON-lines caption file.
        video_folder: Optional root folder. When given, relative paths are
            joined onto it and absolute paths are reduced to their basename
            before being joined onto it.
        sample_size: Output size, either an int (square) or a
            ``(height, width)`` pair.
        sample_duration: Clip duration, multiplied by the per-video ``fps``
            to get the span of source frames to sample from.
        sample_fps: Target frame rate; ``sample_duration * sample_fps`` frames
            are returned per clip.
        is_image: When true, two frames (start and end of a short random
            span) are returned instead of a full clip.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``is_image`` is set and fewer than 2 frames per clip
            are configured.
        KeyError: If an annotation has no matching caption.
    """

    def __init__(
        self,
        json_path, caption_json_path, video_folder=None,
        sample_size=256, sample_duration=1, sample_fps=8,
        is_image=False,
        **kwargs,
    ):
        # A fractional duration makes the product a float, which neither
        # np.linspace nor random.randint accept as a count; store an int.
        sample_n_frames = int(round(sample_duration * sample_fps))
        # Image mode draws a gap from [2, n] and divides by (n - 1); with
        # fewer than 2 frames every sample would fail and be retried forever.
        if is_image and sample_n_frames < 2:
            raise ValueError("sample_duration * sample_fps should be at least 2 when is_image is True")

        logger.info(f"loading captions from {caption_json_path} ...")
        self.caption_dict = {}
        with open(caption_json_path, 'rb') as caption_json_file:
            for line in caption_json_file:
                if not line.strip():
                    continue  # blank lines are not valid JSON
                caption = json.loads(line)  # parse each line only once
                self.caption_dict[caption['id']] = caption['text']

        logger.info(f"loading annotations from {json_path} ...")
        with open(json_path, 'rb') as json_file:
            dataset = [json.loads(line) for line in json_file if line.strip()]

        self.dataset = []
        for data in dataset:
            data['text'] = self.caption_dict[data['id']]
            if data['height'] / data['width'] < 0.625:
                self.dataset.append(data)
        self.length = len(self.dataset)
        logger.info(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_duration = sample_duration
        self.sample_fps = sample_fps
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0], antialias=None),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def _resolve_video_path(self, video_relative_path):
        """Map an annotation path onto ``self.video_folder`` (if one is set)."""
        if self.video_folder is None:
            return video_relative_path
        # startswith() also copes with an empty path, unlike indexing [0].
        if video_relative_path.startswith('/'):
            return os.path.join(self.video_folder, os.path.basename(video_relative_path))
        return os.path.join(self.video_folder, video_relative_path)

    def _sample_frame_indices(self, video_length, fps):
        """Pick the frame indices to read from a video.

        Args:
            video_length: Number of frames in the video.
            fps: Frame rate recorded in the video's annotation.

        Returns:
            An integer array of ``sample_n_frames`` indices in clip mode, or
            a two-element list ``[first, last]`` in image mode.
        """
        if video_length <= 0:
            raise ValueError("video contains no frames")

        if not self.is_image:
            clip_length = min(video_length, math.ceil(fps * self.sample_duration))
            start_idx = random.randint(0, video_length - clip_length)
            return np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)

        frame_difference = random.randint(2, self.sample_n_frames)
        sample_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1)
        clip_length = min(video_length, (frame_difference - 1) * sample_stride + 1)
        start_idx = random.randint(0, video_length - clip_length)
        return [start_idx, start_idx + clip_length - 1]

    def get_batch(self, idx):
        """Decode the frames of sample ``idx``.

        Returns:
            ``(pixel_values, name)`` where ``pixel_values`` is the decoded
            frame tensor divided by 255 with the last axis moved to position
            1 (presumably (frames, channels, height, width) -- confirm against
            the decord output layout) and ``name`` is the caption.
        """
        video_dict = self.dataset[idx]
        video_relative_path, name = video_dict['file'], video_dict['text']
        fps = video_dict['fps']

        video_reader = VideoReader(self._resolve_video_path(video_relative_path))
        batch_index = self._sample_frame_indices(len(video_reader), fps)

        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader
        return pixel_values, name

    def __len__(self):
        """Return the number of retained annotation entries."""
        return self.length

    def __getitem__(self, idx):
        """Return ``dict(pixel_values=..., text=...)`` for sample ``idx``.

        If the sample cannot be loaded, the failure is logged and a random
        replacement index is tried until one succeeds.
        """
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception as e:
                # Best-effort loading is kept, but the failure is now reported
                # instead of being discarded silently.
                logger.warning(f"failed to load sample {idx} ({e!r}); retrying with a random index")
                idx = random.randint(0, self.length - 1)

        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name)
        return sample
class JointDataset(Dataset):
    """Concatenation of the Pexels and WebVid annotation sets.

    Each sub-dataset is configured by an object exposing ``enable``,
    ``json_path`` and ``video_folder`` (plus ``caption_json_path`` for
    Pexels). Pexels entries are kept only when ``height / width`` is below
    0.625 and receive their caption from the caption file; WebVid entries are
    added unchanged and must already carry a ``text`` entry. Every entry is
    tagged with a ``dataset`` key (``'pexels'`` or ``'webvid'``).

    Frames are sampled in one of two mutually exclusive ways:

    * duration mode (``sample_duration`` and ``sample_fps`` both set): a span
      of ``fps * sample_duration`` source frames is sampled, using the
      per-video ``fps`` annotation;
    * stride mode (``sample_stride`` set, together with ``sample_n_frames``):
      frames are taken ``stride`` apart, where ``sample_stride`` is an int or
      a ``(low, high)`` pair from which a stride is drawn per sample.

    Args:
        webvid_config: WebVid configuration object (see above).
        pexels_config: Pexels configuration object (see above).
        sample_size: Output size, either an int (square) or a
            ``(height, width)`` pair.
        sample_duration: Clip duration for duration mode.
        sample_fps: Target frame rate for duration mode.
        sample_stride: Stride (int or ``(low, high)``) for stride mode.
        sample_n_frames: Frames per clip; defaults to
            ``sample_duration * sample_fps`` in duration mode.
        is_image: When true, two frames (start and end of a short random
            span) are returned instead of a full clip.
        **kwargs: Accepted and ignored.

    Raises:
        AssertionError: If duration mode and stride mode arguments are mixed.
        ValueError: If no sampling mode is configured, if the frame count
            cannot be determined, or if ``is_image`` is set with fewer than
            2 frames per clip.
    """

    def __init__(
        self,
        webvid_config, pexels_config,
        sample_size=256,
        sample_duration=None, sample_fps=None, sample_stride=None, sample_n_frames=None,
        is_image=False,
        **kwargs,
    ):
        assert (sample_duration is None and sample_fps is None) or (sample_duration is not None and sample_fps is not None), "sample_duration and sample_fps should be both None or not None"
        if sample_duration is not None and sample_fps is not None:
            assert sample_stride is None, "when sample_duration and sample_fps are not None, sample_stride should be None"
        if sample_stride is not None:
            assert sample_fps is None and sample_duration is None, "when sample_stride is not None, sample_duration and sample_fps should be both None"

        # Without either mode no clip length can be computed, so every sample
        # would fail and __getitem__ would retry forever. Fail early instead.
        if sample_duration is None and sample_stride is None:
            raise ValueError("either sample_duration and sample_fps, or sample_stride, should be provided")
        if sample_n_frames is None:
            if sample_duration is None:
                raise ValueError("sample_n_frames should be provided when sample_duration and sample_fps are None")
            # A fractional duration makes the product a float, which is not a
            # valid frame count for np.linspace / random.randint.
            sample_n_frames = int(round(sample_duration * sample_fps))
        if is_image and sample_n_frames < 2:
            raise ValueError("sample_n_frames should be at least 2 when is_image is True")

        self.dataset = []
        # Always defined, even when the pexels part is disabled.
        self.caption_dict = {}

        if pexels_config.enable:
            logger.info(f"loading pexels dataset")
            logger.info(f"loading captions from {pexels_config.caption_json_path} ...")
            with open(pexels_config.caption_json_path, 'rb') as caption_json_file:
                for line in caption_json_file:
                    if not line.strip():
                        continue  # blank lines are not valid JSON
                    caption = json.loads(line)  # parse each line only once
                    self.caption_dict[caption['id']] = caption['text']

            logger.info(f"loading annotations from {pexels_config.json_path} ...")
            with open(pexels_config.json_path, 'rb') as json_file:
                dataset = [json.loads(line) for line in json_file if line.strip()]
            for data in dataset:
                data['text'] = self.caption_dict[data['id']]
                data['dataset'] = 'pexels'
                if data['height'] / data['width'] < 0.625:
                    self.dataset.append(data)

        if webvid_config.enable:
            logger.info(f"loading webvid dataset")
            logger.info(f"loading annotations from {webvid_config.json_path} ...")
            with open(webvid_config.json_path, 'rb') as json_file:
                dataset = [json.loads(line) for line in json_file if line.strip()]
            for data in dataset:
                data['dataset'] = 'webvid'
            self.dataset.extend(dataset)

        self.length = len(self.dataset)
        logger.info(f"data scale: {self.length}")

        self.pexels_folder = pexels_config.video_folder
        self.webvid_folder = webvid_config.video_folder
        self.sample_duration = sample_duration
        self.sample_fps = sample_fps
        self.sample_n_frames = sample_n_frames
        self.sample_stride = sample_stride if (sample_stride is None or isinstance(sample_stride, int)) else tuple(sample_stride)
        self.is_image = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0], antialias=None),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    @staticmethod
    def _resolve_video_path(video_folder, video_relative_path):
        """Map an annotation path onto ``video_folder`` (if one is given)."""
        if video_folder is None:
            return video_relative_path
        # startswith() also copes with an empty path, unlike indexing [0].
        if video_relative_path.startswith('/'):
            return os.path.join(video_folder, os.path.basename(video_relative_path))
        return os.path.join(video_folder, video_relative_path)

    def _draw_stride(self):
        """Return the fixed stride, or draw one from the ``(low, high)`` range."""
        if isinstance(self.sample_stride, int):
            return self.sample_stride
        return random.randint(self.sample_stride[0], self.sample_stride[1])

    def _sample_frame_indices(self, video_length, fps):
        """Pick the frame indices to read from a video.

        Args:
            video_length: Number of frames in the video.
            fps: Per-video frame rate; only used in duration mode.

        Returns:
            ``(indices, stride)``. ``indices`` is an integer array of
            ``sample_n_frames`` entries in clip mode or a two-element list in
            image mode. ``stride`` is the stride used in clip/stride mode and
            ``None`` otherwise.
        """
        if video_length <= 0:
            raise ValueError("video contains no frames")

        if not self.is_image:
            stride = None
            if self.sample_duration is not None:
                clip_length = min(video_length, math.ceil(fps * self.sample_duration))
            elif self.sample_stride is not None:
                stride = self._draw_stride()
                clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1)
            else:
                raise ValueError("neither sample_duration nor sample_stride is configured")
            start_idx = random.randint(0, video_length - clip_length)
            indices = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
            return indices, stride

        frame_difference = random.randint(2, self.sample_n_frames)
        if self.sample_duration is not None:
            pair_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1)
        elif self.sample_stride is not None:
            # Resolve a (low, high) range to an int; multiplying by the tuple
            # itself would raise TypeError on every sample.
            pair_stride = self._draw_stride()
        else:
            raise ValueError("neither sample_duration nor sample_stride is configured")
        clip_length = min(video_length, (frame_difference - 1) * pair_stride + 1)
        start_idx = random.randint(0, video_length - clip_length)
        # The stride is not reported in image mode.
        return [start_idx, start_idx + clip_length - 1], None

    def get_batch(self, idx):
        """Decode the frames of sample ``idx``.

        Returns:
            ``(pixel_values, name, stride)`` where ``pixel_values`` is the
            decoded frame tensor divided by 255 with the last axis moved to
            position 1 (presumably (frames, channels, height, width) --
            confirm against the decord output layout), ``name`` is the
            caption and ``stride`` is the stride used (``None`` in duration
            mode and in image mode).

        Raises:
            NotImplementedError: If the entry's ``dataset`` tag is unknown.
        """
        video_dict = self.dataset[idx]
        video_relative_path, name = video_dict['file'], video_dict['text']

        if video_dict['dataset'] == 'pexels':
            video_folder = self.pexels_folder
        elif video_dict['dataset'] == 'webvid':
            video_folder = self.webvid_folder
        else:
            raise NotImplementedError

        video_reader = VideoReader(self._resolve_video_path(video_folder, video_relative_path))
        # The fps annotation is only required (and only read) in duration mode.
        fps = video_dict['fps'] if self.sample_duration is not None else None
        batch_index, stride = self._sample_frame_indices(len(video_reader), fps)

        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader
        return pixel_values, name, stride

    def __len__(self):
        """Return the total number of entries across enabled datasets."""
        return self.length

    def __getitem__(self, idx):
        """Return ``dict(pixel_values=..., text=..., stride=...)`` for ``idx``.

        If the sample cannot be loaded, the failure is logged and a random
        replacement index is tried until one succeeds.
        """
        while True:
            try:
                pixel_values, name, stride = self.get_batch(idx)
                break
            except Exception as e:
                # Best-effort loading is kept, but the failure is now reported
                # instead of being discarded silently.
                logger.warning(f"failed to load sample {idx} ({e!r}); retrying with a random index")
                idx = random.randint(0, self.length - 1)

        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name, stride=stride)
        return sample