from typing import Union

import numpy as np
from tqdm.auto import tqdm
import torch
import torchvision
from torchvision.transforms import ToPILImage
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler

from .model.dit import DiffusionTransformer3D
from .model.text_embedders import T5TextEmbedder
def predict_x_0(noise_scheduler, model_output, timesteps, sample, device):
    # v-prediction: x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v.
    # Note: .to(device) returns a copy, so the scheduler's own alphas_cumprod
    # tensor is left untouched.
    alphas = noise_scheduler.alphas_cumprod.to(device)
    alpha_prod_t = alphas[timesteps][:, None, None, None]
    beta_prod_t = 1 - alpha_prod_t
    pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
    return pred_original_sample
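
# Illustrative sketch (not part of the pipeline): a numerical check of the
# v-prediction identity used in predict_x_0. With x_t = sqrt(a)*x_0 + sqrt(1-a)*eps
# and v = sqrt(a)*eps - sqrt(1-a)*x_0, it follows that sqrt(a)*x_t - sqrt(1-a)*v == x_0.
def _check_v_prediction_identity():
    a = torch.rand(()).clamp(0.01, 0.99)  # stand-in for alpha_bar_t
    x_0, eps = torch.randn(4, 4), torch.randn(4, 4)
    x_t = a.sqrt() * x_0 + (1 - a).sqrt() * eps
    v = a.sqrt() * eps - (1 - a).sqrt() * x_0
    assert torch.allclose(a.sqrt() * x_t - (1 - a).sqrt() * v, x_0, atol=1e-5)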
def get_velocity(
    model, x, t, text_embed, visual_cu_seqlens, text_cu_seqlens,
    num_groups=(1, 1, 1), scale_factor=(1., 1., 1.)
):
    # Thin wrapper: the DiT predicts the velocity (v) for the current noised sample.
    pred_velocity = model(x, text_embed, t, visual_cu_seqlens, text_cu_seqlens, num_groups, scale_factor)
    return pred_velocity
def diffusion_generate_renoise(
    model, noise_scheduler, shape, device, num_steps, text_embed, visual_cu_seqlens, text_cu_seqlens,
    num_groups=(1, 1, 1), scale_factor=(1., 1., 1.), progress=False, seed=6554
):
    generator = torch.Generator()
    if seed is not None:
        generator.manual_seed(seed)
    img = torch.randn(*shape, generator=generator).to(torch.bfloat16).to(device)

    noise_scheduler.set_timesteps(num_steps, device=device)
    timesteps = noise_scheduler.timesteps
    if progress:
        timesteps = tqdm(timesteps)

    for time in timesteps:
        # One timestep per sample in the packed batch (cu_seqlens has bs + 1 entries).
        model_time = time.unsqueeze(0).repeat(visual_cu_seqlens.shape[0] - 1)
        # Renoise the current x_0 estimate up to this timestep ...
        noise = torch.randn(img.shape, generator=generator).to(torch.bfloat16).to(device)
        img = noise_scheduler.add_noise(img, noise, time)
        pred_velocity = get_velocity(
            model, img.to(torch.bfloat16), model_time,
            text_embed.to(torch.bfloat16), visual_cu_seqlens,
            text_cu_seqlens, num_groups, scale_factor
        )
        # ... then jump straight back to the predicted x_0.
        img = predict_x_0(
            noise_scheduler=noise_scheduler, model_output=pred_velocity.to(device),
            timesteps=model_time.to(device), sample=img.to(device), device=device,
        )
    return img
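
# Minimal sketch (assumptions: CPU tensors, a default-config CogVideoXDDIMScheduler,
# and a zero-velocity stand-in for the DiT). It only demonstrates the shapes and
# call convention the sampler expects; the real pipeline passes a
# DiffusionTransformer3D on CUDA instead.
def _demo_renoise_sampler():
    scheduler = CogVideoXDDIMScheduler()

    def dummy_model(x, *args):  # stands in for DiffusionTransformer3D
        return torch.zeros_like(x)

    text_embed = torch.zeros(224, 4096)  # hypothetical T5 embedding for one prompt
    visual_cu_seqlens = torch.tensor([0, 1], dtype=torch.int32)
    text_cu_seqlens = torch.tensor([0, 224], dtype=torch.int32)
    latents = diffusion_generate_renoise(
        dummy_model, scheduler, (1, 8, 8, 16), "cpu", 4,
        text_embed, visual_cu_seqlens, text_cu_seqlens,
    )
    print(latents.shape)  # torch.Size([1, 8, 8, 16])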
class Kandinsky4T2VPipeline:
    def __init__(
        self,
        device_map: Union[str, torch.device, dict],  # e.g. {"dit": "cuda:0", "vae": "cuda:1", "text_embedder": "cuda:1"}
        dit: DiffusionTransformer3D,
        text_embedder: T5TextEmbedder,
        vae: AutoencoderKLCogVideoX,
        noise_scheduler: CogVideoXDDIMScheduler,  # TODO: accept a base scheduler class
        resolution: int = 512,
        local_dit_rank=0,
        world_size=1,
    ):
        if resolution not in [512]:
            raise ValueError("Only resolution 512 is supported")

        self.dit = dit
        self.noise_scheduler = noise_scheduler
        self.text_embedder = text_embedder
        self.vae = vae
        self.resolution = resolution
        self.device_map = device_map
        self.local_dit_rank = local_dit_rank
        self.world_size = world_size

        # Supported (height, width) pairs per base resolution.
        self.RESOLUTIONS = {
            512: [(512, 512), (352, 736), (736, 352), (384, 672), (672, 384), (480, 544), (544, 480)],
        }
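
    # Hypothetical construction sketch (loading of dit / text_embedder / vae /
    # noise_scheduler happens elsewhere in the repo, not in this file):
    #
    #     pipe = Kandinsky4T2VPipeline(
    #         device_map={"dit": "cuda:0", "vae": "cuda:1", "text_embedder": "cuda:1"},
    #         dit=dit, text_embedder=text_embedder, vae=vae, noise_scheduler=scheduler,
    #     )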
    def __call__(
        self,
        text: str,
        save_path: str = "./test.mp4",
        bs: int = 1,
        time_length: int = 12,  # duration in seconds; pass 0 to generate an image
        width: int = 512,
        height: int = 512,
        seed: int = None,
        return_frames: bool = False,
    ):
        num_steps = 4

        # SEED: rank 0 draws a seed and broadcasts it so all ranks sample identically.
        if seed is None:
            if self.local_dit_rank == 0:
                seed = torch.randint(2 ** 63 - 1, (1,)).to(self.local_dit_rank)
            else:
                seed = torch.empty((1,), dtype=torch.int64).to(self.local_dit_rank)
            if self.world_size > 1:
                torch.distributed.broadcast(seed, 0)
            seed = seed.item()
        assert bs == 1
        if self.resolution != 512:
            raise NotImplementedError("Only 512 resolution is available for now")
        if (height, width) not in self.RESOLUTIONS[self.resolution]:
            raise ValueError(f"Wrong (height, width) pair. Available pairs are: {self.RESOLUTIONS[self.resolution]}")
        if num_steps != 4:
            raise NotImplementedError("In the distilled version the number of steps has to be exactly 4")

        # PREPARATION: the VAE compresses time 4x at 8 fps (plus one initial frame),
        # so e.g. 12 s -> 12 * 8 / 4 + 1 = 25 latent frames.
        num_frames = 1 if time_length == 0 else time_length * 8 // 4 + 1
        num_groups = (1, 1, 1) if self.resolution == 512 else (1, 2, 2)
        scale_factor = (1., 1., 1.) if self.resolution == 512 else (1., 2., 2.)
        # TEXT EMBEDDER: rank 0 encodes the prompt; other ranks receive the broadcast
        # into a preallocated (224, 4096) buffer.
        if self.local_dit_rank == 0:
            with torch.no_grad():
                text_embed = self.text_embedder(text).squeeze(0).to(self.local_dit_rank, dtype=torch.bfloat16)
        else:
            text_embed = torch.empty(224, 4096, dtype=torch.bfloat16).to(self.local_dit_rank)
        if self.world_size > 1:
            torch.distributed.broadcast(text_embed, 0)

        torch.cuda.empty_cache()

        # Cumulative sequence lengths delimit each sample's tokens in the packed batch.
        visual_cu_seqlens = num_frames * torch.arange(bs + 1, dtype=torch.int32, device=self.device_map["dit"])
        text_cu_seqlens = text_embed.shape[0] * torch.arange(bs + 1, dtype=torch.int32, device=self.device_map["dit"])
        bs_text_embed = text_embed.repeat(bs, 1).to(self.device_map["dit"])

        # Latents: (frames, H/8, W/8, 16 channels).
        shape = (bs * num_frames, height // 8, width // 8, 16)
        # DIT
        with torch.no_grad():
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                images = diffusion_generate_renoise(
                    self.dit, self.noise_scheduler, shape, self.device_map["dit"],
                    num_steps, bs_text_embed, visual_cu_seqlens, text_cu_seqlens,
                    num_groups, scale_factor, progress=True, seed=seed,
                )

        torch.cuda.empty_cache()
        # VAE: decode latents to pixels on rank 0 only.
        if self.local_dit_rank == 0:
            self.vae.num_latent_frames_batch_size = 1 if time_length == 0 else 2
            with torch.no_grad():
                images = 1 / self.vae.config.scaling_factor * images.to(device=self.device_map["vae"], dtype=torch.bfloat16)
                # Rearrange latents into the (B, C, T, H, W) layout the VAE expects.
                images = images.permute(0, 3, 1, 2) if time_length == 0 else images.permute(3, 0, 1, 2)
                images = self.vae.decode(images.unsqueeze(2 if time_length == 0 else 0)).sample.float()
                images = torch.clip((images + 1.) / 2., 0., 1.)

        torch.cuda.empty_cache()
        if self.local_dit_rank == 0:
            # RESULTS
            if time_length == 0:
                return_images = []
                for image in images.squeeze(2).cpu():
                    return_images.append(ToPILImage()(image))
                return return_images
            else:
                if return_frames:
                    return_images = []
                    for image in images.squeeze(0).float().permute(1, 0, 2, 3).cpu():
                        return_images.append(ToPILImage()(image))
                    return return_images
                else:
                    # write_video expects uint8 frames with shape (T, H, W, C).
                    frames = (255. * images.squeeze(0).float().permute(1, 2, 3, 0).cpu().numpy()).astype(np.uint8)
                    torchvision.io.write_video(save_path, frames, fps=8, options={"crf": "5"})
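
# Hypothetical call sketch for a constructed pipeline (see the constructor sketch
# above; the prompt and output path are placeholders):
#
#     pipe("a lighthouse at dawn", save_path="./lighthouse.mp4",
#          time_length=12, width=672, height=384)           # writes an 8 fps video
#     frames = pipe("a lighthouse at dawn", time_length=0)  # returns a list of PIL images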