Spaces:
Build error
Build error
| import torchvision.io | |
| from einops import rearrange, repeat | |
| import numpy as np | |
| import inspect | |
| from typing import List, Optional, Union, Tuple | |
| import os | |
| import PIL | |
| import torch | |
| import torchaudio | |
| import torchvision.io | |
| import torchvision.transforms as transforms | |
| from transformers import ImageProcessingMixin | |
| from diffusers.loaders import TextualInversionLoaderMixin | |
| from diffusers.models import AutoencoderKL | |
| from diffusers.schedulers import KarrasDiffusionSchedulers, PNDMScheduler | |
| from diffusers.utils import logging | |
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline | |
| from diffusers.image_processor import VaeImageProcessor | |
| from unet import AudioUNet3DConditionModel | |
| from audio_encoder import ImageBindSegmaskAudioEncoder | |
| from imagebind.data import waveform2melspec | |
# Module-level logger from diffusers' logging utilities, named after this module.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def waveform_to_melspectrogram(
    waveform: Union[np.ndarray, torch.Tensor],
    num_mel_bins=128,
    target_length=204,
    sample_rate=16000,
    clip_duration=2.,
    mean=-4.268,
    std=9.138
):
    """Convert a waveform into a normalized mel-spectrogram clip.

    A window of ``clip_duration`` seconds is cut from the centre of the waveform
    (or from the start when the waveform is not longer than the window), turned
    into a spectrogram with ImageBind's ``waveform2melspec`` and normalized with
    ``mean`` / ``std``.

    Args:
        waveform: numpy array or tensor indexed as ``waveform[:, samples]``; the
            second axis is treated as time.
        num_mel_bins: number of mel bins forwarded to ``waveform2melspec``.
        target_length: number of frames forwarded to ``waveform2melspec``.
        sample_rate: sample rate of ``waveform`` in Hz.
        clip_duration: window length in seconds.
        mean: normalization mean.
        std: normalization standard deviation.

    Returns:
        The normalized spectrogram (annotated ``(1, freq, time)`` by the author).
    """
    if isinstance(waveform, np.ndarray):
        waveform = torch.from_numpy(waveform)

    window = int(clip_duration * sample_rate)
    total = waveform.shape[1]
    # Centre-crop when there is more audio than the window; otherwise start at 0.
    start = (total - window) // 2 if total > window else 0
    clip = waveform[:, start:start + window]

    melspec = waveform2melspec(clip, sample_rate, num_mel_bins, target_length)
    return transforms.Normalize(mean=mean, std=std)(melspec)
class AudioMelspectrogramExtractor(ImageProcessingMixin):
    """Feature extractor that turns raw waveforms into normalized mel spectrograms.

    Wraps ``waveform_to_melspectrogram`` with a fixed configuration so it can be
    called like a Hugging Face processor on one waveform or a batch of them.
    """

    def __init__(
        self,
        num_mel_bins=128,
        target_length=204,
        sample_rate=16000,
        clip_duration=2,
        mean=-4.268,
        std=9.138
    ):
        super().__init__()
        # Spectrogram settings forwarded verbatim to waveform_to_melspectrogram.
        self.num_mel_bins = num_mel_bins
        self.target_length = target_length
        self.sample_rate = sample_rate
        self.clip_duration = clip_duration
        self.mean = mean
        self.std = std

    def max_length_s(self) -> int:
        """Return the clip duration (in seconds) used when cropping waveforms."""
        return self.clip_duration

    def sampling_rate(self) -> int:
        """Return the sample rate (in Hz) the extractor assumes for its input."""
        return self.sample_rate

    def __call__(
        self,
        waveforms: Union[
            np.ndarray,
            torch.Tensor,
            List[np.ndarray],
            List[torch.Tensor]
        ]
    ):
        """Convert one waveform or a batch of waveforms into stacked spectrograms.

        A single 2-D array/tensor is treated as one waveform; any other input is
        iterated and each element is converted independently.

        Returns:
            ``torch.stack`` of the per-waveform spectrograms along a new leading
            axis (annotated ``(b c n t)`` by the author).
        """
        is_single = isinstance(waveforms, (np.ndarray, torch.Tensor)) and waveforms.ndim == 2
        batch = [waveforms] if is_single else waveforms
        spectrograms = [
            waveform_to_melspectrogram(
                waveform=item,
                num_mel_bins=self.num_mel_bins,
                target_length=self.target_length,
                sample_rate=self.sample_rate,
                clip_duration=self.clip_duration,
                mean=self.mean,
                std=self.std,
            )
            for item in batch
        ]
        return torch.stack(spectrograms, dim=0)
class AudioCondAnimationPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    """
    Pipeline for audio-conditioned image animation (image + audio + text -> video).

    The first latent frame is the VAE encoding of the input image and is kept fixed
    during denoising. The remaining `video_length - 1` frames start from noise and are
    denoised by the U-Net, conditioned on text encodings and on audio encodings
    produced by the audio encoder. Text and audio classifier-free guidance use
    independent scales.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet ([`AudioUNet3DConditionModel`]):
            Conditional U-Net that denoises the video latents given text and audio encodings.
        scheduler ([`KarrasDiffusionSchedulers`]):
            A scheduler to be used in combination with `unet` to denoise the video latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        audio_encoder ([`ImageBindSegmaskAudioEncoder`]):
            Encoder returning audio encodings and attention masks for mel-spectrogram inputs.
        null_text_encodings_path (`str`, *optional*):
            Path to an unconditional text encoding loadable with `torch.load` (viewed as
            `(1, 77, 768)`). It is used as the unconditional branch of text classifier-free guidance.
    """
    unet: AudioUNet3DConditionModel
    scheduler: KarrasDiffusionSchedulers
    vae: AutoencoderKL
    audio_encoder: ImageBindSegmaskAudioEncoder

    def __init__(
        self,
        unet: AudioUNet3DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vae: AutoencoderKL,
        audio_encoder: ImageBindSegmaskAudioEncoder,
        null_text_encodings_path: str = ""
    ):
        super().__init__()
        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            audio_encoder=audio_encoder
        )
        if null_text_encodings_path:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            # map_location="cpu" lets a tensor saved from a GPU load on CPU-only hosts.
            # encode_text moves it to the working device/dtype.
            self.null_text_encoding = torch.load(
                null_text_encodings_path, map_location="cpu"
            ).view(1, 77, 768)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.audio_processor = AudioMelspectrogramExtractor()
        # (mel bins, frames) of one spectrogram. Taking it from the extractor keeps the
        # all-zero "null audio" input used for audio guidance in sync with real inputs.
        self.melspectrogram_shape = (
            self.audio_processor.num_mel_bins,
            self.audio_processor.target_length,
        )

    def encode_text(
        self,
        text_encodings,
        device,
        dtype,
        do_text_classifier_free_guidance,
        do_audio_classifier_free_guidance,
    ):
        """Prepare text encodings and add the branches needed for classifier-free guidance.

        Args:
            text_encodings: a tensor, or a list/tuple of tensors concatenated along dim 0.
            device: target device of the returned tensor.
            dtype: target dtype of the returned tensor.
            do_text_classifier_free_guidance: prepend an unconditional (empty-prompt) branch.
            do_audio_classifier_free_guidance: duplicate the conditional branch so the text
                encodings line up with the null-audio / audio branches of `encode_audio`.

        Returns:
            Text encodings stacked along dim 0 in branch order:
            dual CFG `[uncond, text, text]`; text-only `[uncond, text]`;
            audio-only `[text, text]`; no CFG `text`.

        Raises:
            ValueError: if text guidance is requested but neither a `null_text_encoding`
                nor a `tokenizer`/`text_encoder` pair is available on the pipeline.
        """
        if isinstance(text_encodings, (list, tuple)):
            text_encodings = torch.cat(text_encodings)
        text_encodings = text_encodings.to(dtype=dtype, device=device)
        batch_size = len(text_encodings)
        # get unconditional embeddings for classifier free guidance
        if do_text_classifier_free_guidance:
            if hasattr(self, "null_text_encoding"):
                uncond_text_encodings = self.null_text_encoding
            elif hasattr(self, "tokenizer") and hasattr(self, "text_encoder"):
                # Encode the empty prompt, when a tokenizer/text encoder has been attached.
                uncond_input = self.tokenizer(
                    "",
                    padding="max_length",
                    max_length=text_encodings.shape[1],
                    truncation=True,
                    return_tensors="pt",
                )
                if getattr(self.text_encoder.config, "use_attention_mask", False):
                    attention_mask = uncond_input.attention_mask.to(device)
                else:
                    attention_mask = None
                uncond_text_encodings = self.text_encoder(
                    uncond_input.input_ids.to(device),
                    attention_mask=attention_mask,
                )[0]
            else:
                # The pipeline registers no tokenizer/text_encoder, so without a stored null
                # encoding there is no way to build the unconditional branch.
                raise ValueError(
                    "Text classifier-free guidance (text_guidance_scale > 1) requires an "
                    "unconditional text encoding: pass `null_text_encodings_path` when "
                    "creating the pipeline."
                )
            uncond_text_encodings = repeat(uncond_text_encodings, "1 n d -> b n d", b=batch_size).contiguous()
            uncond_text_encodings = uncond_text_encodings.to(dtype=dtype, device=device)
        if do_text_classifier_free_guidance and do_audio_classifier_free_guidance:  # dual cfg
            text_encodings = torch.cat([uncond_text_encodings, text_encodings, text_encodings])
        elif do_text_classifier_free_guidance:  # only text cfg
            text_encodings = torch.cat([uncond_text_encodings, text_encodings])
        elif do_audio_classifier_free_guidance:  # only audio cfg
            text_encodings = torch.cat([text_encodings, text_encodings])
        return text_encodings

    def encode_audio(
        self,
        audios: Union[List[np.ndarray], List[torch.Tensor]],
        video_length: int = 12,
        do_text_classifier_free_guidance: bool = False,
        do_audio_classifier_free_guidance: bool = False,
        device: torch.device = torch.device("cuda:0"),
        dtype: torch.dtype = torch.float32
    ):
        """Encode waveforms and add the branches needed for classifier-free guidance.

        Returns:
            `(audio_encodings, audio_masks)`, stacked along dim 0 in the same branch order
            as `encode_text`: dual CFG `[null, null, audio]`; text-only `[audio, audio]`;
            audio-only `[null, audio]`; no CFG `audio`. Encodings are repeated over
            `video_length` frames; masks are not. The null branch is the encoding of an
            all-zero mel spectrogram.
        """
        melspectrograms = self.audio_processor(audios).to(device=device, dtype=dtype)  # (b c n t)
        # Use the extractor's batch rather than len(audios): a single 2-D waveform is
        # treated as one sample by the extractor even when it has several rows.
        batch_size = melspectrograms.shape[0]
        # audio_encodings: (b, n, c)
        # audio_masks: (b, s, n)
        _, audio_encodings, audio_masks = self.audio_encoder(
            melspectrograms, normalize=False, return_dict=False
        )
        audio_encodings = repeat(audio_encodings, "b n c -> b f n c", f=video_length)
        if do_audio_classifier_free_guidance:
            null_melspectrograms = torch.zeros(1, 1, *self.melspectrogram_shape).to(device=device, dtype=dtype)
            _, null_audio_encodings, null_audio_masks = self.audio_encoder(
                null_melspectrograms, normalize=False, return_dict=False
            )
            null_audio_encodings = repeat(null_audio_encodings, "1 n c -> b f n c", b=batch_size, f=video_length)
            # The null input was encoded once; expand its mask to the batch as well so masks
            # concatenate to the same batch size as the encodings.
            null_audio_masks = torch.cat([null_audio_masks] * batch_size)
        if do_text_classifier_free_guidance and do_audio_classifier_free_guidance:  # dual cfg
            audio_encodings = torch.cat([null_audio_encodings, null_audio_encodings, audio_encodings])
            audio_masks = torch.cat([null_audio_masks, null_audio_masks, audio_masks])
        elif do_text_classifier_free_guidance:  # only text cfg
            audio_encodings = torch.cat([audio_encodings, audio_encodings])
            audio_masks = torch.cat([audio_masks, audio_masks])
        elif do_audio_classifier_free_guidance:  # only audio cfg
            audio_encodings = torch.cat([null_audio_encodings, audio_encodings])
            audio_masks = torch.cat([null_audio_masks, audio_masks])
        return audio_encodings, audio_masks

    def encode_latents(self, image: torch.Tensor, generator: Optional[torch.Generator] = None):
        """Encode preprocessed images into VAE latents scaled by `vae.config.scaling_factor`.

        Args:
            image: image batch as produced by `VaeImageProcessor.preprocess`.
            generator: optional RNG for sampling the VAE posterior, so that seeded runs
                are reproducible. `None` uses the global RNG.
        """
        dtype = self.vae.dtype
        image = image.to(device=self.device, dtype=dtype)
        latent_dist = self.vae.encode(image).latent_dist
        return latent_dist.sample(generator=generator) * self.vae.config.scaling_factor

    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        """Decode latents with the VAE into CPU float images clamped to `[0, 1]`."""
        dtype = next(self.vae.parameters()).dtype
        latents = latents.to(dtype=dtype)
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1).cpu().float()  # ((b t) c h w)
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_video_latents(
        self,
        image_latents: torch.Tensor,
        num_channels_latents: int,
        video_length: int = 12,
        height: int = 256,
        width: int = 256,
        device: torch.device = torch.device("cuda"),
        dtype: torch.dtype = torch.float32,
        generator: Optional[torch.Generator] = None,
    ):
        """Build the initial video latent: clean image latent at frame 0, noise elsewhere.

        Returns:
            Tensor of shape `(b, c, video_length, height // s, width // s)`, where `s` is
            `vae_scale_factor`.
        """
        batch_size = len(image_latents)
        shape = (
            batch_size,
            num_channels_latents,
            video_length - 1,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor
        )
        image_latents = image_latents.unsqueeze(2).to(device=device, dtype=dtype)  # (b c 1 h w)
        rand_noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        # Scale only the noise frames by the scheduler's initial sigma. Frame 0 is the clean
        # conditioning latent and must stay unchanged (no-op when init_noise_sigma == 1).
        rand_noise = rand_noise * self.scheduler.init_noise_sigma
        return torch.cat([image_latents, rand_noise], dim=2)

    @torch.no_grad()
    def __call__(
        self,
        images: List[PIL.Image.Image],
        audios: Union[List[np.ndarray], List[torch.Tensor]],
        text_encodings: List[torch.Tensor],
        video_length: int = 12,
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 20,
        audio_guidance_scale: float = 4.0,
        text_guidance_scale: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True
    ):
        """Animate `images` conditioned on `audios` and `text_encodings`.

        Runs without gradient tracking.

        Args:
            images: conditioning images, one per sample; each becomes frame 0 of its video.
            audios: one waveform per sample, converted by `AudioMelspectrogramExtractor`.
            text_encodings: text encodings concatenated along dim 0 to form the batch.
            video_length: number of frames, including the conditioning frame.
            height: output height in pixels (falls back to the U-Net sample size if falsy).
            width: output width in pixels (falls back to the U-Net sample size if falsy).
            num_inference_steps: number of scheduler steps.
            audio_guidance_scale: audio guidance scale; audio CFG is enabled when > 1.
            text_guidance_scale: text guidance scale; text CFG is enabled when > 1.
            generator: RNG for the VAE posterior sample and the initial noise, and for the
                scheduler step when it accepts one.
            return_dict: when true return `{"videos": videos}`, otherwise the tensor.

        Returns:
            Videos of shape `(b, f, c, h, w)` on CPU with values in `[0, 1]`.
        """
        # 0. Default height and width to unet
        device = self.device
        dtype = self.dtype
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        do_text_classifier_free_guidance = text_guidance_scale > 1.0
        do_audio_classifier_free_guidance = audio_guidance_scale > 1.0

        # 1. Encode text into ((k b), n, d), then broadcast over frames -> ((k b), f, n, d)
        text_encodings = self.encode_text(
            text_encodings=text_encodings,
            device=device,
            dtype=dtype,
            do_text_classifier_free_guidance=do_text_classifier_free_guidance,
            do_audio_classifier_free_guidance=do_audio_classifier_free_guidance
        )
        text_encodings = repeat(text_encodings, "b n d -> b t n d", t=video_length).to(device=device, dtype=dtype)

        # 2. Encode audio
        # audio_encodings: ((k b), f, n, d)
        # audio_masks: ((k b), s, n)
        audio_encodings, audio_masks = self.encode_audio(
            audios, video_length, do_text_classifier_free_guidance, do_audio_classifier_free_guidance, device, dtype
        )

        # 3. Prepare image latent (sampled with `generator` so the seed fixes it too)
        image = self.image_processor.preprocess(images)
        image_latents = self.encode_latents(image, generator=generator).to(device=device, dtype=dtype)  # (b c h w)

        # 4. Prepare unet noising video latents
        video_latents = self.prepare_video_latents(
            image_latents=image_latents,
            num_channels_latents=self.unet.config.in_channels,
            video_length=video_length,
            height=height,
            width=width,
            dtype=dtype,
            device=device,
            generator=generator,
        )  # (b c f h w)

        # 5. Prepare timesteps and extra step kwargs
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta=0.0)

        # One conditional branch plus one per enabled guidance; the branch order is fixed
        # by encode_text / encode_audio.
        num_branches = 1 + int(do_text_classifier_free_guidance) + int(do_audio_classifier_free_guidance)

        # 6. Denoising loop
        for t in self.progress_bar(timesteps):
            latent_model_input = torch.cat([video_latents] * num_branches)
            # Only the noisy frames are scaled for the scheduler; frame 0 is the clean
            # image condition and is passed through as-is.
            latent_model_input = torch.cat(
                [
                    latent_model_input[:, :, :1],
                    self.scheduler.scale_model_input(latent_model_input[:, :, 1:], t),
                ],
                dim=2,
            )
            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_encodings,
                audio_encoder_hidden_states=audio_encodings,
                audio_attention_mask=audio_masks
            ).sample
            # perform guidance
            if do_text_classifier_free_guidance and do_audio_classifier_free_guidance:  # dual cfg
                noise_pred_uncond, noise_pred_text, noise_pred_text_audio = noise_pred.chunk(3)
                noise_pred = noise_pred_uncond + \
                    text_guidance_scale * (noise_pred_text - noise_pred_uncond) + \
                    audio_guidance_scale * (noise_pred_text_audio - noise_pred_text)
            elif do_text_classifier_free_guidance:  # only text cfg
                noise_pred_audio, noise_pred_text_audio = noise_pred.chunk(2)
                noise_pred = noise_pred_audio + \
                    text_guidance_scale * (noise_pred_text_audio - noise_pred_audio)
            elif do_audio_classifier_free_guidance:  # only audio cfg
                noise_pred_text, noise_pred_text_audio = noise_pred.chunk(2)
                noise_pred = noise_pred_text + \
                    audio_guidance_scale * (noise_pred_text_audio - noise_pred_text)
            # The first frame latent always serves as the unchanged condition; only frames 1:
            # are stepped by the scheduler.
            video_latents[:, :, 1:, :, :] = self.scheduler.step(
                noise_pred[:, :, 1:, :, :], t, video_latents[:, :, 1:, :, :], **extra_step_kwargs
            ).prev_sample
            video_latents = video_latents.contiguous()

        # 7. Post-processing
        video_latents = rearrange(video_latents, "b c f h w -> (b f) c h w")
        videos = self.decode_latents(video_latents).detach().cpu()
        videos = rearrange(videos, "(b f) c h w -> b f c h w", f=video_length)  # value range [0, 1]
        if not return_dict:
            return videos
        return {"videos": videos}
def load_and_transform_images_stable_diffusion(
    images: Union[List[np.ndarray], torch.Tensor, np.ndarray],
    size=512,
    flip=False,
    randcrop=False,
    normalize=True
):
    """Centre-trim, resize and crop images into a float tensor for Stable Diffusion.

    Args:
        images: a list of ``np.uint8`` arrays shaped ``(h, w, 3)``, a ``np.uint8``
            array shaped ``(f, h, w, 3)``, or a float tensor shaped
            ``(b, 3, h, w)`` with values in ``[0, 1]``.
        size: output size; an int for a square output, otherwise an ``(h, w)`` pair.
        flip: when true, always flip horizontally.
        randcrop: use ``RandomCrop`` instead of ``CenterCrop`` after resizing.
        normalize: map ``[0, 1]`` to ``[-1, 1]`` (mean 0.5, std 0.5).

    Returns:
        The transformed image tensor.
    """
    assert isinstance(images, (List, torch.Tensor, np.ndarray)), type(images)

    def _uint8_to_float(array):
        # (f, h, w, c) uint8 -> (f, c, h, w) float in [0, 1]
        return torch.from_numpy(rearrange(array, "f h w c -> f c h w")).float() / 255.

    if isinstance(images, List):
        first = images[0]
        assert isinstance(first, np.ndarray)
        assert first.dtype == np.uint8
        assert first.shape[2] == 3
        images = _uint8_to_float(np.stack(images, axis=0))
    elif isinstance(images, np.ndarray):
        assert images.dtype == np.uint8
        assert images.shape[3] == 3
        images = _uint8_to_float(images)

    assert images.shape[1] == 3
    assert torch.all(images <= 1.0) and torch.all(images >= 0.0)

    target_h, target_w = (size, size) if isinstance(size, int) else size
    h, w = images.shape[-2:]
    target_ratio = float(target_h) / target_w
    if target_ratio >= float(h) / w:
        # Too wide for the target aspect ratio: keep the central columns.
        keep = int(h / target_ratio)
        offset = (w - keep) // 2
        images = images[:, :, :, offset:offset + keep]
    else:
        # Too tall for the target aspect ratio: keep the central rows.
        keep = int(w * target_ratio)
        offset = (h - keep) // 2
        images = images[:, :, offset:offset + keep]

    steps = [
        transforms.Resize(
            size,
            interpolation=transforms.InterpolationMode.BILINEAR,
            antialias=True
        ),
        transforms.RandomCrop(size) if randcrop else transforms.CenterCrop(size),
    ]
    if flip:
        steps.append(transforms.RandomHorizontalFlip(p=1.0))
    if normalize:
        steps.append(transforms.Normalize([0.5], [0.5]))
    return transforms.Compose(steps)(images)
def load_image(image_path, size: int = 256):
    """Load an image as RGB, resize its shorter side to ``size`` and centre-crop a square.

    Args:
        image_path: path or file object accepted by ``PIL.Image.open``.
        size: side length in pixels of the square output (default 256).

    Returns:
        A ``size`` x ``size`` RGB ``PIL.Image.Image``.
    """
    # Open the file as a context manager so its handle is closed. ``convert`` returns a
    # new, fully loaded image that stays valid after the file is closed.
    with PIL.Image.open(image_path) as src:
        image = src.convert('RGB')
    width, height = image.size
    # Use integer arithmetic for the long side. The float form ``int((256 / w) * h)``
    # can round down, e.g. to 255 for a 49x49 image, which leaves the image narrower
    # than the crop.
    if width < height:
        new_width = size
        new_height = size * height // width
    else:
        new_height = size
        new_width = size * width // height
    # Rescale the image
    image = image.resize((new_width, new_height), PIL.Image.LANCZOS)
    # Crop a size x size square from the centre using integer box coordinates
    left = (new_width - size) // 2
    top = (new_height - size) // 2
    return image.crop((left, top, left + size, top + size))
def load_audio(audio_path, sample_rate: int = 16000, num_samples: int = 32000):
    """Load an audio file as a fixed-length mono waveform.

    Channels are averaged to mono, the signal is resampled to ``sample_rate`` and
    then truncated or zero-padded to exactly ``num_samples`` samples. The defaults
    give 2 s at 16 kHz.

    Args:
        audio_path: file path accepted by ``torchaudio.load``.
        sample_rate: output sample rate in Hz.
        num_samples: output length in samples.

    Returns:
        A contiguous float tensor of shape ``(1, num_samples)``.
    """
    audio, audio_sr = torchaudio.load(audio_path)
    if audio.ndim == 1:
        audio = audio.unsqueeze(0)
    else:
        audio = audio.mean(dim=0).unsqueeze(0)
    audio = torchaudio.functional.resample(audio, orig_freq=audio_sr, new_freq=sample_rate)
    audio = audio[:, :num_samples].contiguous().float()
    missing = num_samples - audio.shape[1]
    if missing > 0:
        # Pad with silence. Padding with ones would inject a constant full-scale DC
        # offset into both the spectrogram input and the written soundtrack.
        audio = torch.cat([audio, audio.new_zeros(1, missing)], dim=1)
    return audio.contiguous()
def generate_videos(
    pipeline,
    image_path: str = '',
    audio_path: str = '',
    category_text_encoding: Optional[torch.Tensor] = None,
    image_size: Tuple[int, int] = (256, 256),
    video_fps: int = 6,
    video_num_frame: int = 12,
    audio_guidance_scale: float = 4.0,
    denoising_step: int = 20,
    text_guidance_scale: float = 1.0,
    seed: int = 0,
    save_path: str = "",
    device: torch.device = torch.device("cuda"),
):
    """Generate one video from an image and an audio clip and save it with the audio track.

    Args:
        pipeline: an ``AudioCondAnimationPipeline``-like callable.
        image_path: path of the conditioning image (loaded with ``load_image``).
        audio_path: path of the conditioning audio (loaded with ``load_audio``).
        category_text_encoding: text encoding passed to the pipeline; required.
        image_size: ``(height, width)`` forwarded to the pipeline.
        video_fps: frame rate of the written video.
        video_num_frame: number of frames to generate.
        audio_guidance_scale: forwarded to the pipeline.
        denoising_step: number of inference steps.
        text_guidance_scale: forwarded to the pipeline.
        seed: seed for the ``torch.Generator`` created on ``device``.
        save_path: output video file; missing parent directories are created.
        device: device of the random generator.

    Raises:
        TypeError: if ``category_text_encoding`` is ``None``.
    """
    if category_text_encoding is None:
        # Fail early with a clear message instead of inside torch.cat in the pipeline.
        raise TypeError("`category_text_encoding` must be a tensor, got None.")
    image = load_image(image_path)
    audio = load_audio(audio_path)
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)
    generated_video = pipeline(
        images=[image],
        audios=[audio],
        text_encodings=[category_text_encoding],
        video_length=video_num_frame,
        height=image_size[0],
        width=image_size[1],
        num_inference_steps=denoising_step,
        audio_guidance_scale=audio_guidance_scale,
        text_guidance_scale=text_guidance_scale,
        generator=generator,
        return_dict=False
    )[0]  # (f c h w) in range [0, 1]
    # (f c h w) float in [0, 1] -> (f h w c) uint8, rounding to the nearest level.
    generated_video = (generated_video.permute(0, 2, 3, 1).contiguous() * 255).round().to(torch.uint8)
    # dirname() is "" for a bare filename, and os.makedirs("") raises FileNotFoundError.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torchvision.io.write_video(
        filename=save_path,
        video_array=generated_video,
        fps=video_fps,
        audio_array=audio,
        audio_fps=16000,
        audio_codec="aac"
    )
    return