# WebVid video-text dataset loader.
# Standard library
import os
import random
import bisect

# Third-party
import pandas as pd
import omegaconf
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms._transforms_video as transforms_video
from decord import VideoReader, cpu
class WebVid(Dataset):
    """WebVid video-text dataset.

    Assumes WebVid data is structured as follows:
        Webvid/
            videos/
                000001_000050/          ($page_dir)
                    1.mp4               (videoid.mp4)
                    ...
                    5000.mp4
                ...

    Each item is a dict with keys:
        'video'        float tensor [c, t, h, w], scaled to [-1, 1]
        'caption'      caption string (+ trigger word)
        'path'         full path of the decoded video file
        'fps'          effective clip fps after striding (clamped by fps_max)
        'frame_stride' the stride actually used for this item
    """

    def __init__(self,
                 meta_path,
                 data_dir,
                 subsample=None,
                 video_length=16,
                 resolution=[256, 512],
                 frame_stride=1,
                 spatial_transform=None,
                 crop_resolution=None,
                 fps_max=None,
                 load_raw_resolution=False,
                 fps_schedule=None,
                 fs_probs=None,
                 bs_per_gpu=None,
                 trigger_word='',
                 dataname='',
                 ):
        """
        Args:
            meta_path: CSV metadata file; must contain 'name' and 'videoid'
                columns ('name' is renamed to 'caption').
            data_dir: root directory of the video files.
            subsample: if set, randomly keep only this many metadata rows
                (fixed random_state for reproducibility).
            video_length: number of frames per returned clip.
            resolution: target [height, width]; an int means a square size.
            frame_stride: temporal stride; may be a list/ListConfig, in which
                case a stride is chosen per item (schedule > probs > uniform).
            spatial_transform: 'random_crop' | 'resize_center_crop' | None.
            crop_resolution: crop size for 'random_crop'.
            fps_max: if set, clamp the reported clip fps to this value.
            load_raw_resolution: if True, decode at native video resolution.
            fps_schedule: global-step milestones for stride scheduling
                (requires bs_per_gpu).
            fs_probs: sampling weights used when frame_stride is a list.
            bs_per_gpu: per-GPU batch size, used to derive the global step.
            trigger_word: string appended to every caption.
            dataname: 'loradata' selects a flat videoid.mp4 layout; anything
                else expects videos/$page_dir/videoid.mp4.
        """
        self.meta_path = meta_path
        self.data_dir = data_dir
        self.subsample = subsample
        self.video_length = video_length
        self.resolution = [resolution, resolution] if isinstance(resolution, int) else resolution
        self.frame_stride = frame_stride
        self.fps_max = fps_max
        self.load_raw_resolution = load_raw_resolution
        self.fs_probs = fs_probs
        self.trigger_word = trigger_word
        self.dataname = dataname

        self._load_metadata()

        if spatial_transform is not None:
            if spatial_transform == "random_crop":
                self.spatial_transform = transforms_video.RandomCropVideo(crop_resolution)
            elif spatial_transform == "resize_center_crop":
                # Center-crop only makes sense for a square target resolution.
                assert (self.resolution[0] == self.resolution[1])
                self.spatial_transform = transforms.Compose([
                    transforms.Resize(resolution),
                    transforms_video.CenterCropVideo(resolution),
                ])
            else:
                raise NotImplementedError
        else:
            self.spatial_transform = None

        self.fps_schedule = fps_schedule
        self.bs_per_gpu = bs_per_gpu
        if self.fps_schedule is not None:
            assert (self.bs_per_gpu is not None)
            # counter tracks items served; counter // bs_per_gpu ~ global step.
            self.counter = 0
            self.stage_idx = 0

    def _load_metadata(self):
        """Load the CSV metadata, optionally subsample, and drop NaN rows."""
        metadata = pd.read_csv(self.meta_path)
        if self.subsample is not None:
            metadata = metadata.sample(self.subsample, random_state=0)
        metadata['caption'] = metadata['name']
        del metadata['name']
        self.metadata = metadata
        self.metadata.dropna(inplace=True)
        # self.metadata['caption'] = self.metadata['caption'].str[:350]

    def _get_video_path(self, sample):
        """Return (full_path, relative_path) for a metadata row.

        'loradata' uses a flat <data_dir>/<videoid>.mp4 layout; everything
        else uses <data_dir>/videos/<page_dir>/<videoid>.mp4.
        """
        if self.dataname == "loradata":
            rel_video_fp = str(sample['videoid']) + '.mp4'
            full_video_fp = os.path.join(self.data_dir, rel_video_fp)
        else:
            rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
            full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp)
        return full_video_fp, rel_video_fp

    def get_fs_based_on_schedule(self, frame_strides, schedule):
        """Pick a frame stride based on the current global training step."""
        # nstage = len(fps_schedule) + 1
        assert (len(frame_strides) == len(schedule) + 1)
        global_step = self.counter // self.bs_per_gpu  # TODO: support resume.
        stage_idx = bisect.bisect(schedule, global_step)
        frame_stride = frame_strides[stage_idx]
        # log stage change
        if stage_idx != self.stage_idx:
            print(f'fps stage: {stage_idx} start ... new frame stride = {frame_stride}')
            self.stage_idx = stage_idx
        return frame_stride

    def get_fs_based_on_probs(self, frame_strides, probs):
        """Pick a frame stride at random, weighted by `probs`."""
        assert (len(frame_strides) == len(probs))
        return random.choices(frame_strides, weights=probs)[0]

    def get_fs_randomly(self, frame_strides):
        """Pick a frame stride uniformly at random."""
        return random.choice(frame_strides)

    def __getitem__(self, index):
        """Decode one strided clip; on failure, advance to the next row."""
        # Resolve the frame stride for this item (list => dynamic selection).
        if isinstance(self.frame_stride, list) or isinstance(self.frame_stride, omegaconf.listconfig.ListConfig):
            if self.fps_schedule is not None:
                frame_stride = self.get_fs_based_on_schedule(self.frame_stride, self.fps_schedule)
            elif self.fs_probs is not None:
                frame_stride = self.get_fs_based_on_probs(self.frame_stride, self.fs_probs)
            else:
                frame_stride = self.get_fs_randomly(self.frame_stride)
        else:
            frame_stride = self.frame_stride
        assert (isinstance(frame_stride, int)), type(frame_stride)

        while True:
            index = index % len(self.metadata)
            sample = self.metadata.iloc[index]
            video_path, rel_fp = self._get_video_path(sample)
            caption = sample['caption'] + self.trigger_word
            # make reader
            try:
                if self.load_raw_resolution:
                    video_reader = VideoReader(video_path, ctx=cpu(0))
                else:
                    video_reader = VideoReader(video_path, ctx=cpu(0), width=self.resolution[1], height=self.resolution[0])
                if len(video_reader) < self.video_length:
                    print(f"video length ({len(video_reader)}) is smaller than target length({self.video_length})")
                    index += 1
                    continue
            except Exception:
                # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate.
                index += 1
                print(f"Load video failed! path = {video_path}")
                continue

            # sample strided frames
            all_frames = list(range(0, len(video_reader), frame_stride))
            if len(all_frames) < self.video_length:  # recal a max fs
                frame_stride = len(video_reader) // self.video_length
                assert (frame_stride != 0)
                all_frames = list(range(0, len(video_reader), frame_stride))

            # select a random clip
            rand_idx = random.randint(0, len(all_frames) - self.video_length)
            frame_indices = all_frames[rand_idx:rand_idx + self.video_length]
            try:
                frames = video_reader.get_batch(frame_indices)
                break
            except Exception:
                # Narrowed from a bare `except:`; decoding can fail on corrupt files.
                print(f"Get frames failed! path = {video_path}")
                index += 1
                continue

        assert (frames.shape[0] == self.video_length), f'{len(frames)}, self.video_length={self.video_length}'
        frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()  # [t,h,w,c] -> [c,t,h,w]
        if self.spatial_transform is not None:
            frames = self.spatial_transform(frames)
        if self.resolution is not None:
            assert (frames.shape[2] == self.resolution[0] and frames.shape[3] == self.resolution[1]), f'frames={frames.shape}, self.resolution={self.resolution}'
        # Scale uint8 pixel range [0, 255] to [-1, 1].
        frames = (frames / 255 - 0.5) * 2

        fps_ori = video_reader.get_avg_fps()
        fps_clip = fps_ori // frame_stride
        if self.fps_max is not None and fps_clip > self.fps_max:
            fps_clip = self.fps_max

        data = {'video': frames, 'caption': caption, 'path': video_path, 'fps': fps_clip, 'frame_stride': frame_stride}
        if self.fps_schedule is not None:
            self.counter += 1
        return data

    def __len__(self):
        return len(self.metadata)