| | from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder |
| | from ..schedulers import ContinuousODEScheduler |
| | from .base import BasePipeline |
| | import torch |
| | from tqdm import tqdm |
| | from PIL import Image |
| | import numpy as np |
| | from einops import rearrange, repeat |
| |
|
| |
|
| |
|
| | class SVDVideoPipeline(BasePipeline): |
| |
|
| | def __init__(self, device="cuda", torch_dtype=torch.float16): |
| | super().__init__(device=device, torch_dtype=torch_dtype) |
| | self.scheduler = ContinuousODEScheduler() |
| | |
| | self.image_encoder: SVDImageEncoder = None |
| | self.unet: SVDUNet = None |
| | self.vae_encoder: SVDVAEEncoder = None |
| | self.vae_decoder: SVDVAEDecoder = None |
| | |
| |
|
| | def fetch_models(self, model_manager: ModelManager): |
| | self.image_encoder = model_manager.fetch_model("svd_image_encoder") |
| | self.unet = model_manager.fetch_model("svd_unet") |
| | self.vae_encoder = model_manager.fetch_model("svd_vae_encoder") |
| | self.vae_decoder = model_manager.fetch_model("svd_vae_decoder") |
| |
|
| |
|
| | @staticmethod |
| | def from_model_manager(model_manager: ModelManager, **kwargs): |
| | pipe = SVDVideoPipeline( |
| | device=model_manager.device, |
| | torch_dtype=model_manager.torch_dtype |
| | ) |
| | pipe.fetch_models(model_manager) |
| | return pipe |
| | |
| |
|
| | def encode_image_with_clip(self, image): |
| | image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) |
| | image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224)) |
| | image = (image + 1.0) / 2.0 |
| | mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype) |
| | std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype) |
| | image = (image - mean) / std |
| | image_emb = self.image_encoder(image) |
| | return image_emb |
| | |
| |
|
| | def encode_image_with_vae(self, image, noise_aug_strength, seed=None): |
| | image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) |
| | noise = self.generate_noise(image.shape, seed=seed, device=self.device, dtype=self.torch_dtype) |
| | image = image + noise_aug_strength * noise |
| | image_emb = self.vae_encoder(image) / self.vae_encoder.scaling_factor |
| | return image_emb |
| | |
| |
|
| | def encode_video_with_vae(self, video): |
| | video = torch.concat([self.preprocess_image(frame) for frame in video], dim=0) |
| | video = rearrange(video, "T C H W -> 1 C T H W") |
| | video = video.to(device=self.device, dtype=self.torch_dtype) |
| | latents = self.vae_encoder.encode_video(video) |
| | latents = rearrange(latents[0], "C T H W -> T C H W") |
| | return latents |
| | |
| |
|
| | def tensor2video(self, frames): |
| | frames = rearrange(frames, "C T H W -> T H W C") |
| | frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8) |
| | frames = [Image.fromarray(frame) for frame in frames] |
| | return frames |
| | |
| |
|
| | def calculate_noise_pred( |
| | self, |
| | latents, |
| | timestep, |
| | add_time_id, |
| | cfg_scales, |
| | image_emb_vae_posi, image_emb_clip_posi, |
| | image_emb_vae_nega, image_emb_clip_nega |
| | ): |
| | |
| | noise_pred_posi = self.unet( |
| | torch.cat([latents, image_emb_vae_posi], dim=1), |
| | timestep, image_emb_clip_posi, add_time_id |
| | ) |
| | |
| | noise_pred_nega = self.unet( |
| | torch.cat([latents, image_emb_vae_nega], dim=1), |
| | timestep, image_emb_clip_nega, add_time_id |
| | ) |
| |
|
| | |
| | noise_pred = noise_pred_nega + cfg_scales * (noise_pred_posi - noise_pred_nega) |
| |
|
| | return noise_pred |
| | |
| |
|
| | def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0): |
| | if post_normalize: |
| | mean, std = latents.mean(), latents.std() |
| | latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean |
| | latents = latents * contrast_enhance_scale |
| | return latents |
| |
|
| |
|
| | @torch.no_grad() |
| | def __call__( |
| | self, |
| | input_image=None, |
| | input_video=None, |
| | mask_frames=[], |
| | mask_frame_ids=[], |
| | min_cfg_scale=1.0, |
| | max_cfg_scale=3.0, |
| | denoising_strength=1.0, |
| | num_frames=25, |
| | height=576, |
| | width=1024, |
| | fps=7, |
| | motion_bucket_id=127, |
| | noise_aug_strength=0.02, |
| | num_inference_steps=20, |
| | post_normalize=True, |
| | contrast_enhance_scale=1.2, |
| | seed=None, |
| | progress_bar_cmd=tqdm, |
| | progress_bar_st=None, |
| | ): |
| | height, width = self.check_resize_height_width(height, width) |
| | |
| | |
| | self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength) |
| |
|
| | |
| | noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) |
| | if denoising_strength == 1.0: |
| | latents = noise.clone() |
| | else: |
| | latents = self.encode_video_with_vae(input_video) |
| | latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0]) |
| |
|
| | |
| | if len(mask_frames) > 0: |
| | mask_latents = self.encode_video_with_vae(mask_frames) |
| |
|
| | |
| | image_emb_clip_posi = self.encode_image_with_clip(input_image) |
| | image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi) |
| | image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength, seed=seed), "B C H W -> (B T) C H W", T=num_frames) |
| | image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi) |
| |
|
| | |
| | cfg_scales = torch.linspace(min_cfg_scale, max_cfg_scale, num_frames) |
| | cfg_scales = cfg_scales.reshape(num_frames, 1, 1, 1).to(device=self.device, dtype=self.torch_dtype) |
| | |
| | |
| | add_time_id = torch.tensor([[fps-1, motion_bucket_id, noise_aug_strength]], device=self.device) |
| |
|
| | |
| | for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): |
| |
|
| | |
| | for frame_id, mask_frame_id in enumerate(mask_frame_ids): |
| | latents[mask_frame_id] = self.scheduler.add_noise(mask_latents[frame_id], noise[mask_frame_id], timestep) |
| |
|
| | |
| | noise_pred = self.calculate_noise_pred( |
| | latents, timestep, add_time_id, cfg_scales, |
| | image_emb_vae_posi, image_emb_clip_posi, image_emb_vae_nega, image_emb_clip_nega |
| | ) |
| |
|
| | |
| | latents = self.scheduler.step(noise_pred, timestep, latents) |
| | |
| | |
| | if progress_bar_st is not None: |
| | progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) |
| |
|
| | |
| | latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale) |
| | video = self.vae_decoder.decode_video(latents, progress_bar=progress_bar_cmd) |
| | video = self.tensor2video(video) |
| |
|
| | return video |
| |
|
| |
|
| |
|
| | class SVDCLIPImageProcessor: |
| | def __init__(self): |
| | pass |
| |
|
| | def resize_with_antialiasing(self, input, size, interpolation="bicubic", align_corners=True): |
| | h, w = input.shape[-2:] |
| | factors = (h / size[0], w / size[1]) |
| |
|
| | |
| | |
| | sigmas = ( |
| | max((factors[0] - 1.0) / 2.0, 0.001), |
| | max((factors[1] - 1.0) / 2.0, 0.001), |
| | ) |
| |
|
| | |
| | |
| | |
| | ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) |
| |
|
| | |
| | if (ks[0] % 2) == 0: |
| | ks = ks[0] + 1, ks[1] |
| |
|
| | if (ks[1] % 2) == 0: |
| | ks = ks[0], ks[1] + 1 |
| |
|
| | input = self._gaussian_blur2d(input, ks, sigmas) |
| |
|
| | output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) |
| | return output |
| |
|
| |
|
| | def _compute_padding(self, kernel_size): |
| | """Compute padding tuple.""" |
| | |
| | |
| | if len(kernel_size) < 2: |
| | raise AssertionError(kernel_size) |
| | computed = [k - 1 for k in kernel_size] |
| |
|
| | |
| | out_padding = 2 * len(kernel_size) * [0] |
| |
|
| | for i in range(len(kernel_size)): |
| | computed_tmp = computed[-(i + 1)] |
| |
|
| | pad_front = computed_tmp // 2 |
| | pad_rear = computed_tmp - pad_front |
| |
|
| | out_padding[2 * i + 0] = pad_front |
| | out_padding[2 * i + 1] = pad_rear |
| |
|
| | return out_padding |
| |
|
| |
|
| | def _filter2d(self, input, kernel): |
| | |
| | b, c, h, w = input.shape |
| | tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) |
| |
|
| | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) |
| |
|
| | height, width = tmp_kernel.shape[-2:] |
| |
|
| | padding_shape: list[int] = self._compute_padding([height, width]) |
| | input = torch.nn.functional.pad(input, padding_shape, mode="reflect") |
| |
|
| | |
| | tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) |
| | input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) |
| |
|
| | |
| | output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) |
| |
|
| | out = output.view(b, c, h, w) |
| | return out |
| |
|
| |
|
| | def _gaussian(self, window_size: int, sigma): |
| | if isinstance(sigma, float): |
| | sigma = torch.tensor([[sigma]]) |
| |
|
| | batch_size = sigma.shape[0] |
| |
|
| | x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) |
| |
|
| | if window_size % 2 == 0: |
| | x = x + 0.5 |
| |
|
| | gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) |
| |
|
| | return gauss / gauss.sum(-1, keepdim=True) |
| |
|
| |
|
| | def _gaussian_blur2d(self, input, kernel_size, sigma): |
| | if isinstance(sigma, tuple): |
| | sigma = torch.tensor([sigma], dtype=input.dtype) |
| | else: |
| | sigma = sigma.to(dtype=input.dtype) |
| |
|
| | ky, kx = int(kernel_size[0]), int(kernel_size[1]) |
| | bs = sigma.shape[0] |
| | kernel_x = self._gaussian(kx, sigma[:, 1].view(bs, 1)) |
| | kernel_y = self._gaussian(ky, sigma[:, 0].view(bs, 1)) |
| | out_x = self._filter2d(input, kernel_x[..., None, :]) |
| | out = self._filter2d(out_x, kernel_y[..., None]) |
| |
|
| | return out |
| |
|