from typing import List, Dict, Any, Union, Optional
import torch
from torch.utils.data import DataLoader, ConcatDataset
import datasets
from diffusers import DDPMScheduler
from functools import partial
import random
import numpy as np
def collate_fn(
    batch: List[Dict[str, Any]],
    noise_scheduler: DDPMScheduler,
    num_frames: int,
    hint_spacing: Optional[int] = None,
    as_numpy: bool = True
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
    """Collate pre-encoded samples into a single diffusion training batch.

    Each sample provides a pre-encoded prompt embedding under 'prompt' and
    pre-encoded video latents under 'video'. A random window of num_frames
    frames is cut from every video, noise is added with the DDPM scheduler,
    and the model input is assembled as [noisy latents, mask, hint latents]
    along the channel dimension. When hint_spacing is unset or < 1 it defaults
    to num_frames (a single hint frame repeated across the clip); num_frames
    is assumed to be divisible by hint_spacing so the hint tensor lines up
    with the input latents.
    """
    if hint_spacing is None or hint_spacing < 1:
        hint_spacing = num_frames
    if as_numpy:
        dtype = np.float32
    else:
        dtype = torch.float32
    prompts = []
    videos = []
    for s in batch:
        # prompt
        prompts.append(torch.tensor(s['prompt']).to(dtype = torch.float32))
        # frames
        frames = torch.tensor(s['video']).to(dtype = torch.float32)
        max_frames = len(frames)
        assert max_frames >= num_frames
        video_slice = random.randint(0, max_frames - num_frames)
        frames = frames[video_slice:video_slice + num_frames]
        frames = frames.permute(1, 0, 2, 3) # f, c, h, w -> c, f, h, w
        videos.append(frames)
    encoder_hidden_states = torch.cat(prompts) # b, 77, 768
    latents = torch.stack(videos) # b, c, f, h, w
    # scale by the Stable Diffusion VAE latent scaling factor
    latents = latents * 0.18215
    # take every hint_spacing-th frame as a hint, then repeat each hint so the
    # hint tensor lines up frame-for-frame with the input latents
    hint_latents = latents[:, :, ::hint_spacing, :, :]
    hint_latents = hint_latents.repeat_interleave(hint_spacing, 2)
    #hint_latents = hint_latents[:, :, :num_frames-1, :, :]
    #input_latents = latents[:, :, 1:, :, :]
    input_latents = latents
    noise = torch.randn_like(input_latents)
    bsz = input_latents.shape[0]
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        dtype = torch.int64
    )
    noisy_latents = noise_scheduler.add_noise(input_latents, noise, timesteps)
    # all-zero mask channel, concatenated between the noisy latents and the
    # hint latents along the channel dimension
    mask = torch.zeros([
        noisy_latents.shape[0],
        1,
        noisy_latents.shape[2],
        noisy_latents.shape[3],
        noisy_latents.shape[4]
    ])
    latent_model_input = torch.cat([noisy_latents, mask, hint_latents], dim = 1)
    latent_model_input = latent_model_input.to(memory_format = torch.contiguous_format)
    encoder_hidden_states = encoder_hidden_states.to(memory_format = torch.contiguous_format)
    timesteps = timesteps.to(memory_format = torch.contiguous_format)
    noise = noise.to(memory_format = torch.contiguous_format)
    if as_numpy:
        latent_model_input = latent_model_input.numpy().astype(dtype)
        encoder_hidden_states = encoder_hidden_states.numpy().astype(dtype)
        timesteps = timesteps.numpy().astype(np.int32)
        noise = noise.numpy().astype(dtype)
    else:
        latent_model_input = latent_model_input.to(dtype = dtype)
        encoder_hidden_states = encoder_hidden_states.to(dtype = dtype)
        noise = noise.to(dtype = dtype)
    return {
        'latent_model_input': latent_model_input,
        'encoder_hidden_states': encoder_hidden_states,
        'timesteps': timesteps,
        'noise': noise
    }
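# Hedged sketch (not part of the original file): a minimal smoke test for
# collate_fn. It assumes each sample stores a pre-encoded text embedding of
# shape (1, 77, 768) and pre-encoded video latents of shape (frames, 4, h, w);
# these shapes are illustrative assumptions, not confirmed by the dataset.
def _collate_fn_smoke_test():
    scheduler = DDPMScheduler(num_train_timesteps = 1000)
    fake_batch = [
        {
            'prompt': np.random.randn(1, 77, 768).astype(np.float32),
            'video': np.random.randn(32, 4, 32, 32).astype(np.float32)
        }
        for _ in range(2)
    ]
    out = collate_fn(fake_batch, scheduler, num_frames = 24, hint_spacing = 24)
    # noisy latents (4 ch) + mask (1 ch) + hint latents (4 ch) -> 9 channels
    assert out['latent_model_input'].shape == (2, 9, 24, 32, 32)
    assert out['encoder_hidden_states'].shape == (2, 77, 768)
    assert out['timesteps'].shape == (2,)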
def worker_init_fn(worker_id: int):
    # derive a per-worker seed within numpy's valid 32-bit seed range
    wseed = torch.initial_seed() % 4294967294
    random.seed(wseed)
    np.random.seed(wseed)
def load_dataset(
    dataset_path: str,
    model_path: str,
    cache_dir: Optional[str] = None,
    batch_size: int = 1,
    num_frames: int = 24,
    hint_spacing: Optional[int] = None,
    num_workers: int = 0,
    shuffle: bool = False,
    as_numpy: bool = True,
    pin_memory: bool = False,
    pin_memory_device: str = ''
) -> DataLoader:
    noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
        model_path,
        subfolder = 'scheduler'
    )
    dataset = datasets.load_dataset(
        dataset_path,
        streaming = False,
        cache_dir = cache_dir
    )
    merged_dataset = ConcatDataset([ dataset[s] for s in dataset ])
    dataloader = DataLoader(
        merged_dataset,
        batch_size = batch_size,
        num_workers = num_workers,
        persistent_workers = num_workers > 0,
        drop_last = True,
        shuffle = shuffle,
        worker_init_fn = worker_init_fn,
        collate_fn = partial(collate_fn,
            noise_scheduler = noise_scheduler,
            num_frames = num_frames,
            hint_spacing = hint_spacing,
            as_numpy = as_numpy
        ),
        pin_memory = pin_memory,
        pin_memory_device = pin_memory_device
    )
    return dataloader
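# Hedged sketch (not part of the original file): how load_dataset is expected
# to be driven from a training loop. 'path/to/dataset' and 'path/to/model' are
# placeholder arguments, not real paths from this repository.
def _dataloader_example():
    dataloader = load_dataset(
        dataset_path = 'path/to/dataset',
        model_path = 'path/to/model',
        batch_size = 2,
        num_frames = 24,
        hint_spacing = 4,
        shuffle = True,
        as_numpy = False
    )
    for batch in dataloader:
        # each batch carries everything the denoising model needs for one step
        latent_model_input = batch['latent_model_input']     # (b, 9, f, h, w)
        encoder_hidden_states = batch['encoder_hidden_states'] # (b, 77, 768)
        timesteps = batch['timesteps']
        noise = batch['noise'] # the training target for epsilon-prediction
        break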
def validate_dataset(
    dataset_path: str
) -> List[str]:
    import os
    import json
    data_path = os.path.join(dataset_path, 'data')
    meta = set(os.path.splitext(x)[0] for x in os.listdir(os.path.join(data_path, 'metadata')))
    prompts = set(os.path.splitext(x)[0] for x in os.listdir(os.path.join(data_path, 'prompts')))
    videos = set(os.path.splitext(x)[0] for x in os.listdir(os.path.join(data_path, 'videos')))
    # ids present in all three folders are usable; everything else is reported back
    ok = meta.intersection(prompts).intersection(videos)
    all_ids = meta.union(prompts).union(videos)
    not_ok = []
    for a in all_ids:
        if a not in ok:
            not_ok.append(a)
    ok = list(ok)
    ok.sort()
    with open(os.path.join(data_path, 'id_list.json'), 'w') as f:
        json.dump(ok, f)
    return not_ok
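# Hedged sketch (not part of the original file): the directory layout
# validate_dataset assumes, and how it would be invoked. 'path/to/dataset' is
# a placeholder, and the file extensions shown are assumptions.
#
#   path/to/dataset/
#     data/
#       metadata/<id>.json
#       prompts/<id>.* (pre-encoded prompt embeddings)
#       videos/<id>.* (pre-encoded video latents)
#
# Ids present in all three folders are written to data/id_list.json; anything
# missing from one of them is returned to the caller.
def _validate_example():
    missing = validate_dataset('path/to/dataset')
    if missing:
        print(f'{len(missing)} ids are incomplete: {missing[:10]}')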