diff --git a/app.py b/app.py index 399c014e7f96f89c97f024b3d80c6f61145b4235..510cca474d028a879dd54a17168cc5139684132e 100644 --- a/app.py +++ b/app.py @@ -1,26 +1,118 @@ import gradio as gr +from pathlib import Path + +import soundfile as sf +import torch +import torchaudio +import hydra +from omegaconf import OmegaConf +import diffusers.schedulers as noise_schedulers + +from utils.config import register_omegaconf_resolvers +from models.common import LoadPretrainedBase + +from huggingface_hub import hf_hub_download +import fairseq + +register_omegaconf_resolvers() +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +config = OmegaConf.load("configs/infer.yaml") + +ckpt_path = hf_hub_download( + repo_id="assasinatee/STAR", + filename="model.safetensors", + repo_type="model", + force_download=False +) + +exp_config = OmegaConf.load("configs/config.yaml") +if "pretrained_ckpt" in exp_config["model"]: + exp_config["model"]["pretrained_ckpt"] = ckpt_path +model: LoadPretrainedBase = hydra.utils.instantiate(exp_config["model"]) + +model = model.to(device) + +ckpt_path = hf_hub_download( + repo_id="assasinatee/STAR", + filename="hubert_large_ll60k.pt", + repo_type="model", + force_download=False +) +hubert_models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path]) +hubert_model = hubert_models[0].eval().to(device) + +scheduler = getattr( + noise_schedulers, + config["noise_scheduler"]["type"], +).from_pretrained( + config["noise_scheduler"]["name"], + subfolder="scheduler", +) + +@torch.no_grad() +def infer(audio_path: str) -> str: + waveform_tts, sample_rate = torchaudio.load(audio_path) + if sample_rate != 16000: + waveform_tts = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_tts) + if waveform_tts.shape[0] > 1: + waveform_tts = torch.mean(waveform_tts, dim=0, keepdim=True) + with torch.no_grad(): + features, _ = hubert_model.extract_features(waveform_tts.to(device)) + + kwargs = OmegaConf.to_container(config["infer_args"].copy(), resolve=True) + kwargs['content'] = [features] + kwargs['condition'] = None + kwargs['task'] = ["speech_to_audio"] + + model.eval() + waveform = model.inference( + scheduler=scheduler, + **kwargs, + ) + + output_file = "output_audio.wav" + sf.write(output_file, waveform.squeeze().cpu().numpy(), samplerate=exp_config["sample_rate"]) + + return output_file with gr.Blocks(title="STAR Online Inference", theme=gr.themes.Soft()) as demo: gr.Markdown("# STAR: Speech-to-Audio Generation via Representation Learning") gr.Markdown("""
+ ## 🗣️ Input A brief input speech utterance for the overall audio scene. -> Example:A cat meowing and young female speaking +> Example: A cat meowing and young female speaking + +### 🎙️ Input Speech Example +""") + + speech = gr.Audio(value="wav/speech.wav", label="Input Speech Example", type="filepath") + + gr.Markdown(""" +
-#### 🎙️ Input Speech Example +### 🎧️ Output Audio Example +""") + + audio = gr.Audio(value="wav/audio.wav", label="Generated Audio Example", type="filepath") -#### 🎧️ Output Audio Example + gr.Markdown("""
---
-""") +""") -if __name__ == "__main__": - demo.launch() + with gr.Column(): + input_audio = gr.Audio(label="Speech Input", type="filepath") + btn = gr.Button("🎵Generate Audio!", variant="primary") + output_audio = gr.Audio(label="Generated Audio", type="filepath") + btn.click(fn=infer, inputs=input_audio, outputs=output_audio) + +demo.launch() \ No newline at end of file diff --git a/ckpts/1m.pt b/ckpts/1m.pt new file mode 100644 index 0000000000000000000000000000000000000000..895b620b0af0420ebf153d4118e5f1e22adb6bbf --- /dev/null +++ b/ckpts/1m.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3cb13e2699fa922ce6a2b3b4f53c270ec64156e0cc3f3e3645e10cdf98b740dc +size 183037614 diff --git a/ckpts/exp0_best.pt b/ckpts/exp0_best.pt new file mode 100644 index 0000000000000000000000000000000000000000..6e25ad0e1da3a2bc155c51215d8b1da35475859e --- /dev/null +++ b/ckpts/exp0_best.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1e2dc436e6d47cb02e954a0087a3a1b4aa1d5d3e1ded4fdafb6274966264d5a7 +size 73171895 diff --git a/configs/config.yaml b/configs/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4a036c765fa906d496e5b97af0ecc9b24e292ddb --- /dev/null +++ b/configs/config.yaml @@ -0,0 +1,84 @@ +model: + autoencoder: + _target_: models.autoencoder.waveform.stable_vae.StableVAE + encoder: + _target_: models.autoencoder.waveform.stable_vae.OobleckEncoder + in_channels: 1 + channels: 128 + c_mults: + - 1 + - 2 + - 4 + - 8 + strides: + - 2 + - 4 + - 6 + - 10 + latent_dim: 256 + use_snake: true + decoder: + _target_: models.autoencoder.waveform.stable_vae.OobleckDecoder + out_channels: 1 + channels: 128 + c_mults: + - 1 + - 2 + - 4 + - 8 + strides: + - 2 + - 4 + - 6 + - 10 + latent_dim: 128 + use_snake: true + final_tanh: false + io_channels: 1 + latent_dim: 128 + downsampling_ratio: 480 + sample_rate: 24000 + pretrained_ckpt: /hpc_stor03/sjtu_home/xuenan.xu/workspace/text_to_audio_generation/ezaudio/ckpts/vae/1m.pt + bottleneck: + _target_: models.autoencoder.waveform.stable_vae.VAEBottleneck + backbone: + _target_: models.dit.mask_dit.UDiT + img_size: 500 + patch_size: 1 + in_chans: 128 + out_chans: 128 + input_type: 1d + embed_dim: 1024 + depth: 24 + num_heads: 16 + mlp_ratio: 4.0 + qkv_bias: false + qk_scale: null + qk_norm: layernorm + norm_layer: layernorm + act_layer: geglu + context_norm: true + use_checkpoint: true + time_fusion: ada_sola_bias + ada_sola_rank: 32 + ada_sola_alpha: 32 + cls_dim: null + context_dim: 1024 + context_fusion: cross + context_max_length: null + context_pe_method: none + pe_method: none + rope_mode: shared + use_conv: true + skip: true + skip_norm: true + cfg_drop_ratio: 0.2 + _target_: models.flow_matching.SingleTaskCrossAttentionAudioFlowMatching + content_encoder: + _target_: models.content_encoder.content_encoder.ContentEncoder + embed_dim: 1024 + text_encoder: None + speech_encoder: + _target_: models.content_encoder.star_encoder.star_encoder.QformerBridgeNet + load_from_pretrained: /hpc_stor03/sjtu_home/zeyu.xie/workspace/speech2audio/hear/output/qformer_caption_tts_hubert/exp0_best.pt + pretrained_ckpt: /hpc_stor03/sjtu_home/zeyu.xie/workspace/speech2audio/x2audio/x_to_audio_generation/experiments/audiocaps_fm/checkpoints/epoch_100/model.safetensors \ No newline at end of file diff --git a/configs/infer.yaml b/configs/infer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7ca815e4d0c2106979a32a7c565a1588f291cef2 --- /dev/null +++ b/configs/infer.yaml @@ -0,0 
+1,16 @@ +defaults: + - basic + - _self_ + +wav_dir: inference_delay + +noise_scheduler: + type: DDIMScheduler + name: stabilityai/stable-diffusion-2-1 + +infer_args: + num_steps: 50 + guidance_scale: 3.5 + guidance_rescale: 0.5 + use_gt_duration: false + latent_shape: [128, 500] \ No newline at end of file diff --git a/models/__pycache__/common.cpython-310.pyc b/models/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c26bf242810cd0f5a5781f1002167c5b097e0971 Binary files /dev/null and b/models/__pycache__/common.cpython-310.pyc differ diff --git a/models/__pycache__/content_adapter.cpython-310.pyc b/models/__pycache__/content_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98c0b7bf7aba309822162289ac0ef6f43e591c24 Binary files /dev/null and b/models/__pycache__/content_adapter.cpython-310.pyc differ diff --git a/models/__pycache__/diffusion.cpython-310.pyc b/models/__pycache__/diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c24a2fbbd3e433632f343efc0d7465097e288f2d Binary files /dev/null and b/models/__pycache__/diffusion.cpython-310.pyc differ diff --git a/models/__pycache__/flow_matching.cpython-310.pyc b/models/__pycache__/flow_matching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d1126e2d787c7e6493b3a377bdee61756bfa5ae Binary files /dev/null and b/models/__pycache__/flow_matching.cpython-310.pyc differ diff --git a/models/autoencoder/__pycache__/autoencoder_base.cpython-310.pyc b/models/autoencoder/__pycache__/autoencoder_base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0974bf1cb6f1d8847f573c39982c74a98e7b5bea Binary files /dev/null and b/models/autoencoder/__pycache__/autoencoder_base.cpython-310.pyc differ diff --git a/models/autoencoder/autoencoder_base.py b/models/autoencoder/autoencoder_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2852ad185b48e9595e116735baed689fa09cc0d3 --- /dev/null +++ b/models/autoencoder/autoencoder_base.py @@ -0,0 +1,22 @@ +from abc import abstractmethod, ABC +from typing import Sequence +import torch +import torch.nn as nn + + +class AutoEncoderBase(ABC): + def __init__( + self, downsampling_ratio: int, sample_rate: int, + latent_shape: Sequence[int | None] + ): + self.downsampling_ratio = downsampling_ratio + self.sample_rate = sample_rate + self.latent_token_rate = sample_rate // downsampling_ratio + self.latent_shape = latent_shape + self.time_dim = latent_shape.index(None) + 1 # the first dim is batch + + @abstractmethod + def encode( + self, waveform: torch.Tensor, waveform_lengths: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + ... 
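A quick sanity check of the latent geometry these files agree on (the numbers below come from configs/config.yaml and configs/infer.yaml in this diff; the snippet is an illustrative sketch, not code from the repo):

# config.yaml sets sample_rate=24000 and downsampling_ratio=480; StableVAE passes
# latent_shape=(latent_dim, None) to AutoEncoderBase, whose __init__ then computes:
sample_rate, downsampling_ratio = 24000, 480
latent_token_rate = sample_rate // downsampling_ratio   # 50 latent frames per second
latent_shape = (128, None)                              # (channels, time); time length varies
time_dim = latent_shape.index(None) + 1                 # 2, since dim 0 is the batch axis
# infer.yaml pins latent_shape to [128, 500], i.e. 500 / 50 = 10 s of 24 kHz audio
# per generation, which is what app.py's infer() ultimately decodes.
assert 500 // latent_token_rate == 10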
diff --git a/models/autoencoder/waveform/__pycache__/stable_vae.cpython-310.pyc b/models/autoencoder/waveform/__pycache__/stable_vae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd4019e321981d753c20a140335e93b16d957476 Binary files /dev/null and b/models/autoencoder/waveform/__pycache__/stable_vae.cpython-310.pyc differ diff --git a/models/autoencoder/waveform/dac.py b/models/autoencoder/waveform/dac.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/autoencoder/waveform/stable_vae.py b/models/autoencoder/waveform/stable_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..cbebfce624995379da5b3957e9d74d1186dce6b5 --- /dev/null +++ b/models/autoencoder/waveform/stable_vae.py @@ -0,0 +1,559 @@ +from typing import Any, Literal, Callable +import math +from pathlib import Path + +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm +import torchaudio +from alias_free_torch import Activation1d + +from models.common import LoadPretrainedBase +from models.autoencoder.autoencoder_base import AutoEncoderBase +from utils.torch_utilities import remove_key_prefix_factory, create_mask_from_length + + +# jit script make it 1.4x faster and save GPU memory +@torch.jit.script +def snake_beta(x, alpha, beta): + return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2) + + +class SnakeBeta(nn.Module): + def __init__( + self, + in_features, + alpha=1.0, + alpha_trainable=True, + alpha_logscale=True + ): + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: + # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: + # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + # self.no_div_by_zero = 0.000000001 + + def forward(self, x): + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) + # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = snake_beta(x, alpha, beta) + + return x + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +def get_activation( + activation: Literal["elu", "snake", "none"], + antialias=False, + channels=None +) -> nn.Module: + if activation == "elu": + act = nn.ELU() + elif activation == "snake": + act = SnakeBeta(channels) + elif activation == "none": + act = nn.Identity() + else: + raise ValueError(f"Unknown activation {activation}") + + if antialias: + act = Activation1d(act) + + return act + + +class ResidualUnit(nn.Module): + def __init__( + self, + in_channels, + out_channels, + dilation, + use_snake=False, + antialias_activation=False + ): + super().__init__() + + self.dilation = dilation + + padding = (dilation * (7 - 1)) // 2 + + self.layers = nn.Sequential( + get_activation( + "snake" if use_snake else "elu", + antialias=antialias_activation, + channels=out_channels + ), + WNConv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=7, 
+ dilation=dilation, + padding=padding + ), + get_activation( + "snake" if use_snake else "elu", + antialias=antialias_activation, + channels=out_channels + ), + WNConv1d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=1 + ) + ) + + def forward(self, x): + res = x + + #x = checkpoint(self.layers, x) + x = self.layers(x) + + return x + res + + +class EncoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + stride, + use_snake=False, + antialias_activation=False + ): + super().__init__() + + self.layers = nn.Sequential( + ResidualUnit( + in_channels=in_channels, + out_channels=in_channels, + dilation=1, + use_snake=use_snake + ), + ResidualUnit( + in_channels=in_channels, + out_channels=in_channels, + dilation=3, + use_snake=use_snake + ), + ResidualUnit( + in_channels=in_channels, + out_channels=in_channels, + dilation=9, + use_snake=use_snake + ), + get_activation( + "snake" if use_snake else "elu", + antialias=antialias_activation, + channels=in_channels + ), + WNConv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2) + ), + ) + + def forward(self, x): + return self.layers(x) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + stride, + use_snake=False, + antialias_activation=False, + use_nearest_upsample=False + ): + super().__init__() + + if use_nearest_upsample: + upsample_layer = nn.Sequential( + nn.Upsample(scale_factor=stride, mode="nearest"), + WNConv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=2 * stride, + stride=1, + bias=False, + padding='same' + ) + ) + else: + upsample_layer = WNConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2) + ) + + self.layers = nn.Sequential( + get_activation( + "snake" if use_snake else "elu", + antialias=antialias_activation, + channels=in_channels + ), + upsample_layer, + ResidualUnit( + in_channels=out_channels, + out_channels=out_channels, + dilation=1, + use_snake=use_snake + ), + ResidualUnit( + in_channels=out_channels, + out_channels=out_channels, + dilation=3, + use_snake=use_snake + ), + ResidualUnit( + in_channels=out_channels, + out_channels=out_channels, + dilation=9, + use_snake=use_snake + ), + ) + + def forward(self, x): + return self.layers(x) + + +class OobleckEncoder(nn.Module): + def __init__( + self, + in_channels=2, + channels=128, + latent_dim=32, + c_mults=[1, 2, 4, 8], + strides=[2, 4, 8, 8], + use_snake=False, + antialias_activation=False + ): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d( + in_channels=in_channels, + out_channels=c_mults[0] * channels, + kernel_size=7, + padding=3 + ) + ] + + for i in range(self.depth - 1): + layers += [ + EncoderBlock( + in_channels=c_mults[i] * channels, + out_channels=c_mults[i + 1] * channels, + stride=strides[i], + use_snake=use_snake + ) + ] + + layers += [ + get_activation( + "snake" if use_snake else "elu", + antialias=antialias_activation, + channels=c_mults[-1] * channels + ), + WNConv1d( + in_channels=c_mults[-1] * channels, + out_channels=latent_dim, + kernel_size=3, + padding=1 + ) + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class OobleckDecoder(nn.Module): + def __init__( + self, + out_channels=2, + channels=128, + latent_dim=32, + c_mults=[1, 2, 4, 8], + strides=[2, 4, 8, 
8], + use_snake=False, + antialias_activation=False, + use_nearest_upsample=False, + final_tanh=True + ): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d( + in_channels=latent_dim, + out_channels=c_mults[-1] * channels, + kernel_size=7, + padding=3 + ), + ] + + for i in range(self.depth - 1, 0, -1): + layers += [ + DecoderBlock( + in_channels=c_mults[i] * channels, + out_channels=c_mults[i - 1] * channels, + stride=strides[i - 1], + use_snake=use_snake, + antialias_activation=antialias_activation, + use_nearest_upsample=use_nearest_upsample + ) + ] + + layers += [ + get_activation( + "snake" if use_snake else "elu", + antialias=antialias_activation, + channels=c_mults[0] * channels + ), + WNConv1d( + in_channels=c_mults[0] * channels, + out_channels=out_channels, + kernel_size=7, + padding=3, + bias=False + ), + nn.Tanh() if final_tanh else nn.Identity() + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class Bottleneck(nn.Module): + def __init__(self, is_discrete: bool = False): + super().__init__() + + self.is_discrete = is_discrete + + def encode(self, x, return_info=False, **kwargs): + raise NotImplementedError + + def decode(self, x): + raise NotImplementedError + + +@torch.jit.script +def vae_sample(mean, scale) -> dict[str, torch.Tensor]: + stdev = nn.functional.softplus(scale) + 1e-4 + var = stdev * stdev + logvar = torch.log(var) + latents = torch.randn_like(mean) * stdev + mean + + kl = (mean * mean + var - logvar - 1).sum(1).mean() + return {"latents": latents, "kl": kl} + + +class VAEBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, + x, + return_info=False, + **kwargs) -> dict[str, torch.Tensor] | torch.Tensor: + mean, scale = x.chunk(2, dim=1) + sampled = vae_sample(mean, scale) + + if return_info: + return sampled["latents"], {"kl": sampled["kl"]} + else: + return sampled["latents"] + + def decode(self, x): + return x + + +def compute_mean_kernel(x, y): + kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1] + return torch.exp(-kernel_input).mean() + + +class Pretransform(nn.Module): + def __init__(self, enable_grad, io_channels, is_discrete): + super().__init__() + + self.is_discrete = is_discrete + self.io_channels = io_channels + self.encoded_channels = None + self.downsampling_ratio = None + + self.enable_grad = enable_grad + + def encode(self, x): + raise NotImplementedError + + def decode(self, z): + raise NotImplementedError + + def tokenize(self, x): + raise NotImplementedError + + def decode_tokens(self, tokens): + raise NotImplementedError + + +class StableVAE(LoadPretrainedBase, AutoEncoderBase): + def __init__( + self, + encoder, + decoder, + latent_dim, + downsampling_ratio, + sample_rate, + io_channels=2, + bottleneck: Bottleneck = None, + pretransform: Pretransform = None, + in_channels=None, + out_channels=None, + soft_clip=False, + pretrained_ckpt: str | Path = None + ): + LoadPretrainedBase.__init__(self) + AutoEncoderBase.__init__( + self, + downsampling_ratio=downsampling_ratio, + sample_rate=sample_rate, + latent_shape=(latent_dim, None) + ) + + self.latent_dim = latent_dim + self.io_channels = io_channels + self.in_channels = io_channels + self.out_channels = io_channels + self.min_length = self.downsampling_ratio + + if in_channels is not None: + self.in_channels = in_channels + + if out_channels is not None: + self.out_channels = out_channels + + self.bottleneck = bottleneck + 
self.encoder = encoder + self.decoder = decoder + self.pretransform = pretransform + self.soft_clip = soft_clip + self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete + + self.remove_autoencoder_prefix_fn: Callable = remove_key_prefix_factory( + "autoencoder." + ) + if pretrained_ckpt is not None: + self.load_pretrained(pretrained_ckpt) + + def process_state_dict(self, model_dict, state_dict): + state_dict = state_dict["state_dict"] + state_dict = self.remove_autoencoder_prefix_fn(model_dict, state_dict) + return state_dict + + def encode( + self, waveform: torch.Tensor, waveform_lengths: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + z = self.encoder(waveform) + z = self.bottleneck.encode(z) + z_length = waveform_lengths // self.downsampling_ratio + z_mask = create_mask_from_length(z_length) + return z, z_mask + + def decode(self, latents: torch.Tensor) -> torch.Tensor: + waveform = self.decoder(latents) + return waveform + + +class StableVAEProjectorWrapper(nn.Module): + def __init__( + self, + vae_dim: int, + embed_dim: int, + model: StableVAE | None = None, + ): + super().__init__() + self.model = model + self.proj = nn.Linear(vae_dim, embed_dim) + + def forward( + self, waveform: torch.Tensor, waveform_lengths: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + self.model.eval() + with torch.no_grad(): + z, z_mask = self.model.encode(waveform, waveform_lengths) + z = self.proj(z.transpose(1, 2)) + return {"output": z, "mask": z_mask} + + +if __name__ == '__main__': + import hydra + from utils.config import generate_config_from_command_line_overrides + model_config = generate_config_from_command_line_overrides( + "configs/model/autoencoder/stable_vae.yaml" + ) + autoencoder: StableVAE = hydra.utils.instantiate(model_config) + autoencoder.eval() + + waveform, sr = torchaudio.load( + "/hpc_stor03/sjtu_home/xuenan.xu/data/m4singer/Tenor-1#童话/0006.wav" + ) + waveform = waveform.mean(0, keepdim=True) + waveform = torchaudio.functional.resample( + waveform, sr, model_config["sample_rate"] + ) + print("waveform: ", waveform.shape) + with torch.no_grad(): + latent, latent_length = autoencoder.encode( + waveform, torch.as_tensor([waveform.shape[-1]]) + ) + print("latent: ", latent.shape) + reconstructed = autoencoder.decode(latent) + print("reconstructed: ", reconstructed.shape) + import soundfile as sf + sf.write( + "./reconstructed.wav", + reconstructed[0, 0].numpy(), + samplerate=model_config["sample_rate"] + ) diff --git a/models/common.py b/models/common.py new file mode 100644 index 0000000000000000000000000000000000000000..79afb0be423dab991b5e7f14a74ba5503ce0c426 --- /dev/null +++ b/models/common.py @@ -0,0 +1,67 @@ +from pathlib import Path +import torch +import torch.nn as nn +from utils.torch_utilities import load_pretrained_model, merge_matched_keys + + +class LoadPretrainedBase(nn.Module): + def process_state_dict( + self, model_dict: dict[str, torch.Tensor], + state_dict: dict[str, torch.Tensor] + ): + """ + Custom processing function of each model that transforms `state_dict` loaded from + checkpoints into a state dict that can be used in `load_state_dict`. + Uses `merge_matched_keys` to update parameters with matched names and shapes by + default. + + Args: + model_dict: + The state dict of the current model, which is going to load pretrained parameters + state_dict: + A dictionary of parameters from a pre-trained model.
+ + Returns: + dict[str, torch.Tensor]: + The updated state dict, where parameters with matched keys and shape are + updated with values in `state_dict`. + """ + state_dict = merge_matched_keys(model_dict, state_dict) + return state_dict + + def load_pretrained(self, ckpt_path: str | Path): + load_pretrained_model( + self, ckpt_path, state_dict_process_fn=self.process_state_dict + ) + + +class CountParamsBase(nn.Module): + def count_params(self): + num_params = 0 + trainable_params = 0 + for param in self.parameters(): + num_params += param.numel() + if param.requires_grad: + trainable_params += param.numel() + return num_params, trainable_params + + +class SaveTrainableParamsBase(nn.Module): + @property + def param_names_to_save(self): + names = [] + for name, param in self.named_parameters(): + if param.requires_grad: + names.append(name) + for name, _ in self.named_buffers(): + names.append(name) + return names + + def load_state_dict(self, state_dict, strict=True): + for key in self.param_names_to_save: + if key not in state_dict: + raise Exception( + f"{key} not found in either pre-trained models (e.g. BERT)" + " or resumed checkpoints (e.g. epoch_40/model.pt)" + ) + return super().load_state_dict(state_dict, strict) diff --git a/models/content_adapter.py b/models/content_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..083b2c9b639de703b8e63748a611509c379e513c --- /dev/null +++ b/models/content_adapter.py @@ -0,0 +1,381 @@ +import math +import torch +import torch.nn as nn + +from utils.torch_utilities import concat_non_padding, restore_from_concat + + +###################### +# fastspeech modules +###################### +class LayerNorm(nn.LayerNorm): + """Layer normalization module. + :param int nout: output dim size + :param int dim: dimension to be normalized + """ + def __init__(self, nout, dim=-1): + """Construct an LayerNorm object.""" + super(LayerNorm, self).__init__(nout, eps=1e-12) + self.dim = dim + + def forward(self, x): + """Apply layer normalization. 
+ :param torch.Tensor x: input tensor + :return: layer normalized tensor + :rtype torch.Tensor + """ + if self.dim == -1: + return super(LayerNorm, self).forward(x) + return super(LayerNorm, + self).forward(x.transpose(1, -1)).transpose(1, -1) + + +class DurationPredictor(nn.Module): + def __init__( + self, + in_channels: int, + filter_channels: int, + n_layers: int = 2, + kernel_size: int = 3, + p_dropout: float = 0.1, + padding: str = "SAME" + ): + super(DurationPredictor, self).__init__() + self.conv = nn.ModuleList() + self.kernel_size = kernel_size + self.padding = padding + for idx in range(n_layers): + in_chans = in_channels if idx == 0 else filter_channels + self.conv += [ + nn.Sequential( + nn.ConstantPad1d(((kernel_size - 1) // 2, + (kernel_size - 1) // + 2) if padding == 'SAME' else + (kernel_size - 1, 0), 0), + nn.Conv1d( + in_chans, + filter_channels, + kernel_size, + stride=1, + padding=0 + ), nn.ReLU(), LayerNorm(filter_channels, dim=1), + nn.Dropout(p_dropout) + ) + ] + self.linear = nn.Linear(filter_channels, 1) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor): + # x: [B, T, E] + x = x.transpose(1, -1) + x_mask = x_mask.unsqueeze(1).to(x.device) + for f in self.conv: + x = f(x) + x = x * x_mask.float() + + x = self.linear(x.transpose(1, -1) + ) * x_mask.transpose(1, -1).float() # [B, T, 1] + return x + + +###################### +# adapter modules +###################### + + +class ContentAdapterBase(nn.Module): + def __init__(self, d_out): + super().__init__() + self.d_out = d_out + + +class SinusoidalPositionalEmbedding(nn.Module): + def __init__(self, d_model, dropout, max_len=1000): + super().__init__() + self.dropout = nn.Dropout(dropout) + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * + (-math.log(10000.0) / d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:x.size(1), :] + return self.dropout(x) + + +class ContentAdapter(ContentAdapterBase): + def __init__( + self, + d_model: int, + d_out: int, + num_layers: int, + num_heads: int, + duration_predictor: DurationPredictor, + dropout: float = 0.1, + norm_first: bool = False, + activation: str = "gelu", + duration_grad_scale: float = 0.0, + ): + super().__init__(d_out) + self.duration_grad_scale = duration_grad_scale + self.cls_embed = nn.Parameter(torch.randn(d_model)) + if hasattr(torch, "npu") and torch.npu.is_available(): + enable_nested_tensor = False + else: + enable_nested_tensor = True + encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=num_heads, + dim_feedforward=4 * d_model, + dropout=dropout, + activation=activation, + norm_first=norm_first, + batch_first=True + ) + self.encoder_layers = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_layers, + enable_nested_tensor=enable_nested_tensor + ) + self.duration_predictor = duration_predictor + self.content_proj = nn.Conv1d(d_model, d_out, 1) + + def forward(self, x, x_mask): + batch_size = x.size(0) + cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1) + cls_embed = cls_embed.to(x.device).unsqueeze(1) + x = torch.cat([cls_embed, x], dim=1) + + cls_mask = torch.ones(batch_size, 1).to(x_mask.device) + x_mask = torch.cat([cls_mask, x_mask], dim=1) + x = self.encoder_layers(x, 
src_key_padding_mask=~x_mask.bool()) + x_grad_rescaled = x * self.duration_grad_scale + x.detach( + ) * (1 - self.duration_grad_scale) + duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1) + content = self.content_proj(x.transpose(1, 2)).transpose(1, 2) + return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:] + + +class PrefixAdapter(ContentAdapterBase): + def __init__( + self, + content_dim: int, + d_model: int, + d_out: int, + prefix_dim: int, + num_layers: int, + num_heads: int, + duration_predictor: DurationPredictor, + dropout: float = 0.1, + norm_first: bool = False, + use_last_norm: bool = True, + activation: str = "gelu", + duration_grad_scale: float = 0.1, + ): + super().__init__(d_out) + self.duration_grad_scale = duration_grad_scale + self.prefix_mlp = nn.Sequential( + nn.Linear(prefix_dim, d_model), nn.ReLU(), nn.Dropout(dropout), + nn.Linear(d_model, d_model) + ) + self.content_mlp = nn.Sequential( + nn.Linear(content_dim, d_model), nn.ReLU(), nn.Dropout(dropout), + nn.Linear(d_model, d_model) + ) + layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=num_heads, + dim_feedforward=4 * d_model, + dropout=dropout, + activation=activation, + batch_first=True, + norm_first=norm_first + ) + if hasattr(torch, "npu") and torch.npu.is_available(): + enable_nested_tensor = False + else: + enable_nested_tensor = True + self.cls_embed = nn.Parameter(torch.randn(d_model)) + # self.pos_embed = SinusoidalPositionalEmbedding(d_model, dropout) + self.layers = nn.TransformerEncoder( + encoder_layer=layer, + num_layers=num_layers, + enable_nested_tensor=enable_nested_tensor + ) + self.use_last_norm = use_last_norm + if self.use_last_norm: + self.last_norm = nn.LayerNorm(d_model) + self.duration_predictor = duration_predictor + self.content_proj = nn.Conv1d(d_model, d_out, 1) + nn.init.normal_(self.cls_embed, 0., 0.02) + nn.init.xavier_uniform_(self.content_proj.weight) + nn.init.constant_(self.content_proj.bias, 0.) 
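+ # A note on the gradient-rescaling idiom used in forward() below (and in the
+ # other adapters in this file): x_grad_rescaled = x * duration_grad_scale +
+ # x.detach() * (1 - duration_grad_scale) leaves the forward value equal to x,
+ # but scales the gradient flowing from the duration predictor back into the
+ # shared encoder by duration_grad_scale. With this class's default of 0.1 the
+ # duration loss only weakly shapes the content representation; ContentAdapter's
+ # default of 0.0 blocks it entirely.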
+ + def forward(self, content, content_mask, instruction, instruction_mask): + batch_size = content.size(0) + cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1) + cls_embed = cls_embed.to(content.device).unsqueeze(1) + content = self.content_mlp(content) + x = torch.cat([cls_embed, content], dim=1) + cls_mask = torch.ones(batch_size, 1, + dtype=bool).to(content_mask.device) + x_mask = torch.cat([cls_mask, content_mask], dim=1) + + prefix = self.prefix_mlp(instruction) + seq, seq_mask, perm = concat_non_padding( + prefix, instruction_mask, x, x_mask + ) + # seq = self.pos_embed(seq) + x = self.layers(seq, src_key_padding_mask=~seq_mask.bool()) + if self.use_last_norm: + x = self.last_norm(x) + _, x = restore_from_concat(x, instruction_mask, x_mask, perm) + + x_grad_rescaled = x * self.duration_grad_scale + x.detach( + ) * (1 - self.duration_grad_scale) + duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1) + content = self.content_proj(x.transpose(1, 2)).transpose(1, 2) + return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:] + + +class CrossAttentionAdapter(ContentAdapterBase): + def __init__( + self, + d_out: int, + content_dim: int, + prefix_dim: int, + num_heads: int, + duration_predictor: DurationPredictor, + dropout: float = 0.1, + duration_grad_scale: float = 0.1, + ): + super().__init__(d_out) + self.attn = nn.MultiheadAttention( + embed_dim=content_dim, + num_heads=num_heads, + dropout=dropout, + kdim=prefix_dim, + vdim=prefix_dim, + batch_first=True, + ) + self.duration_grad_scale = duration_grad_scale + self.duration_predictor = duration_predictor + self.global_duration_mlp = nn.Sequential( + nn.Linear(content_dim, content_dim), nn.ReLU(), + nn.Dropout(dropout), nn.Linear(content_dim, 1) + ) + self.norm = nn.LayerNorm(content_dim) + self.content_proj = nn.Conv1d(content_dim, d_out, 1) + + def forward(self, content, content_mask, prefix, prefix_mask): + attn_output, attn_output_weights = self.attn( + query=content, + key=prefix, + value=prefix, + key_padding_mask=~prefix_mask.bool() + ) + attn_output = attn_output * content_mask.unsqueeze(-1).float() + x = self.norm(attn_output + content) + x_grad_rescaled = x * self.duration_grad_scale + x.detach( + ) * (1 - self.duration_grad_scale) + x_aggregated = (x_grad_rescaled * content_mask.unsqueeze(-1).float() + ).sum(dim=1) / content_mask.sum(dim=1, + keepdim=True).float() + global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1) + local_duration = self.duration_predictor( + x_grad_rescaled, content_mask + ).squeeze(-1) + content = self.content_proj(x.transpose(1, 2)).transpose(1, 2) + return content, content_mask, global_duration, local_duration + + +class ExperimentalCrossAttentionAdapter(ContentAdapterBase): + def __init__( + self, + d_out: int, + content_dim: int, + prefix_dim: int, + num_heads: int, + duration_predictor: DurationPredictor, + dropout: float = 0.1, + duration_grad_scale: float = 0.1, + ): + super().__init__(d_out) + self.content_mlp = nn.Sequential( + nn.Linear(content_dim, content_dim), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(content_dim, content_dim), + ) + self.content_norm = nn.LayerNorm(content_dim) + self.prefix_mlp = nn.Sequential( + nn.Linear(prefix_dim, prefix_dim), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(prefix_dim, prefix_dim), + ) + self.prefix_norm = nn.LayerNorm(content_dim) + self.attn = nn.MultiheadAttention( + embed_dim=content_dim, + num_heads=num_heads, + dropout=dropout, + kdim=prefix_dim, + vdim=prefix_dim, + 
batch_first=True, + ) + self.duration_grad_scale = duration_grad_scale + self.duration_predictor = duration_predictor + self.global_duration_mlp = nn.Sequential( + nn.Linear(content_dim, content_dim), nn.ReLU(), + nn.Dropout(dropout), nn.Linear(content_dim, 1) + ) + self.content_proj = nn.Sequential( + nn.Linear(content_dim, d_out), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(d_out, d_out), + ) + self.norm1 = nn.LayerNorm(content_dim) + self.norm2 = nn.LayerNorm(d_out) + self.init_weights() + + def init_weights(self): + def _init_weights(module): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0.) + + self.apply(_init_weights) + + def forward(self, content, content_mask, prefix, prefix_mask): + content = self.content_mlp(content) + content = self.content_norm(content) + prefix = self.prefix_mlp(prefix) + prefix = self.prefix_norm(prefix) + attn_output, attn_weights = self.attn( + query=content, + key=prefix, + value=prefix, + key_padding_mask=~prefix_mask.bool(), + ) + attn_output = attn_output * content_mask.unsqueeze(-1).float() + x = attn_output + content + x = self.norm1(x) + x_grad_rescaled = x * self.duration_grad_scale + x.detach( + ) * (1 - self.duration_grad_scale) + x_aggregated = (x_grad_rescaled * content_mask.unsqueeze(-1).float() + ).sum(dim=1) / content_mask.sum(dim=1, + keepdim=True).float() + global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1) + local_duration = self.duration_predictor( + x_grad_rescaled, content_mask + ).squeeze(-1) + content = self.content_proj(x) + content = self.norm2(content) + return content, content_mask, global_duration, local_duration diff --git a/models/content_encoder/__pycache__/content_encoder.cpython-310.pyc b/models/content_encoder/__pycache__/content_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ace4e43d9626c100e9e1fcf068cd262f828629d2 Binary files /dev/null and b/models/content_encoder/__pycache__/content_encoder.cpython-310.pyc differ diff --git a/models/content_encoder/__pycache__/sketch_encoder.cpython-310.pyc b/models/content_encoder/__pycache__/sketch_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ac8dbd6065266f00836ba293f607690a9755600 Binary files /dev/null and b/models/content_encoder/__pycache__/sketch_encoder.cpython-310.pyc differ diff --git a/models/content_encoder/__pycache__/text_encoder.cpython-310.pyc b/models/content_encoder/__pycache__/text_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..991aec058fcc39d780b17a5372490355eb49e67f Binary files /dev/null and b/models/content_encoder/__pycache__/text_encoder.cpython-310.pyc differ diff --git a/models/content_encoder/content_encoder.py b/models/content_encoder/content_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fd881d8012a9c84e8c7d4f82fa10b37ebec0a588 --- /dev/null +++ b/models/content_encoder/content_encoder.py @@ -0,0 +1,280 @@ +from typing import Any +import torch +import torch.nn as nn + + +class ContentEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + text_encoder: nn.Module = None, + video_encoder: nn.Module = None, + midi_encoder: nn.Module = None, + phoneme_encoder: nn.Module = None, + pitch_encoder: nn.Module = None, + audio_encoder: nn.Module = None, + speech_encoder: nn.Module = None, + sketch_encoder: nn.Module = None, + ): + super().__init__() + self.embed_dim = 
embed_dim + self.text_encoder = text_encoder + self.midi_encoder = midi_encoder + self.phoneme_encoder = phoneme_encoder + self.pitch_encoder = pitch_encoder + self.audio_encoder = audio_encoder + self.video_encoder = video_encoder + self.speech_encoder = speech_encoder + self.sketch_encoder = sketch_encoder + + def encode_content( + self, batch_content: list[Any], batch_task: list[str], + device: str | torch.device + ): + batch_content_output = [] + batch_content_mask = [] + batch_la_content_output = [] + + zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device) + + for content, task in zip(batch_content, batch_task): + if task == "audio_super_resolution" or task == "speech_enhancement": + content_dict = { + "waveform": torch.as_tensor(content).float(), + "waveform_lengths": torch.as_tensor(content.shape[0]), + } + for key in list(content_dict.keys()): + content_dict[key] = content_dict[key].unsqueeze(0).to( + device + ) + content_output_dict = self.audio_encoder(**content_dict) + la_content_output_dict = { + "output": zero_la_content, + } + elif task == "text_to_audio" or task == "text_to_music": + content_output_dict = self.text_encoder([content]) + la_content_output_dict = { + "output": zero_la_content, + } + elif task == "speech_to_audio": + input_dict = { + "embed": content, + "embed_len": torch.tensor([content.shape[1]], dtype=torch.int).to(device), + } + content_output_dict = self.speech_encoder(input_dict) + la_content_output_dict = { + "output": zero_la_content, + } + elif task == "direct_speech_to_audio": + # content shape [1, L/T 133, dim] mask [1, L/T 133] in hubert + if len(content.shape) < 3: + content = content.unsqueeze(0) + mask = torch.ones(content.shape[:2]) + mask = (mask == 1).to(content.device) + content_output_dict = { + "output": content, + "mask": mask, + } + la_content_output_dict = { + "output": zero_la_content, + } + elif task == "sketch_to_audio": + content_output_dict = self.sketch_encoder([content["caption"]]) + content_dict = { + "f0": torch.as_tensor(content["f0"]), + "energy": torch.as_tensor(content["energy"]), + } + for key in list(content_dict.keys()): + content_dict[key] = content_dict[key].unsqueeze(0).to( + device + ) + la_content_output_dict = self.sketch_encoder.encode_sketch( + **content_dict + ) + elif task == "video_to_audio": + content_dict = { + "frames": torch.as_tensor(content).float(), + "frame_nums": torch.as_tensor(content.shape[0]), + } + for key in list(content_dict.keys()): + content_dict[key] = content_dict[key].unsqueeze(0).to( + device + ) + content_output_dict = self.video_encoder(**content_dict) + la_content_output_dict = { + "output": zero_la_content, + } + elif task == "singing_voice_synthesis": + content_dict = { + "phoneme": + torch.as_tensor(content["phoneme"]).long(), + "midi": + torch.as_tensor(content["midi"]).long(), + "midi_duration": + torch.as_tensor(content["midi_duration"]).float(), + "is_slur": + torch.as_tensor(content["is_slur"]).long() + } + if "spk" in content: + if self.midi_encoder.spk_config.encoding_format == "id": + content_dict["spk"] = torch.as_tensor(content["spk"] + ).long() + elif self.midi_encoder.spk_config.encoding_format == "embedding": + content_dict["spk"] = torch.as_tensor(content["spk"] + ).float() + for key in list(content_dict.keys()): + content_dict[key] = content_dict[key].unsqueeze(0).to( + device + ) + content_dict["lengths"] = torch.as_tensor([ + len(content["phoneme"]) + ]) + content_output_dict = self.midi_encoder(**content_dict) + la_content_output_dict = {"output": 
zero_la_content} + elif task == "text_to_speech": + content_dict = { + "phoneme": torch.as_tensor(content["phoneme"]).long(), + } + if "spk" in content: + if self.phoneme_encoder.spk_config.encoding_format == "id": + content_dict["spk"] = torch.as_tensor(content["spk"] + ).long() + elif self.phoneme_encoder.spk_config.encoding_format == "embedding": + content_dict["spk"] = torch.as_tensor(content["spk"] + ).float() + for key in list(content_dict.keys()): + content_dict[key] = content_dict[key].unsqueeze(0).to( + device + ) + content_dict["lengths"] = torch.as_tensor([ + len(content["phoneme"]) + ]) + content_output_dict = self.phoneme_encoder(**content_dict) + la_content_output_dict = {"output": zero_la_content} + elif task == "singing_acoustic_modeling": + content_dict = { + "phoneme": torch.as_tensor(content["phoneme"]).long(), + } + for key in list(content_dict.keys()): + content_dict[key] = content_dict[key].unsqueeze(0).to( + device + ) + content_dict["lengths"] = torch.as_tensor([ + len(content["phoneme"]) + ]) + content_output_dict = self.pitch_encoder(**content_dict) + + content_dict = { + "f0": torch.as_tensor(content["f0"]), + "uv": torch.as_tensor(content["uv"]), + } + for key in list(content_dict.keys()): + content_dict[key] = content_dict[key].unsqueeze(0).to( + device + ) + la_content_output_dict = self.pitch_encoder.encode_pitch( + **content_dict + ) + + batch_content_output.append(content_output_dict["output"][0]) + batch_content_mask.append(content_output_dict["mask"][0]) + batch_la_content_output.append(la_content_output_dict["output"][0]) + + batch_content_output = nn.utils.rnn.pad_sequence( + batch_content_output, batch_first=True, padding_value=0 + ) + batch_content_mask = nn.utils.rnn.pad_sequence( + batch_content_mask, batch_first=True, padding_value=False + ) + batch_la_content_output = nn.utils.rnn.pad_sequence( + batch_la_content_output, batch_first=True, padding_value=0 + ) + return { + "content": batch_content_output, + "content_mask": batch_content_mask, + "length_aligned_content": batch_la_content_output, + } + + +class BatchedContentEncoder(ContentEncoder): + def encode_content( + self, batch_content: list | dict, batch_task: list[str], + device: str | torch.device + ): + task = batch_task[0] + zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device) + if task == "audio_super_resolution" or task == "speech_enhancement": + content_dict = { + "waveform": + batch_content["content"].unsqueeze(1).float().to(device), + "waveform_lengths": + batch_content["content_lengths"].long().to(device), + } + content_output = self.audio_encoder(**content_dict) + la_content_output = zero_la_content + elif task == "text_to_audio": + content_output = self.text_encoder(batch_content) + la_content_output = zero_la_content + elif task == "video_to_audio": + content_dict = { + "frames": + batch_content["content"].float().to(device), + "frame_nums": + batch_content["content_lengths"].long().to(device), + } + content_output = self.video_encoder(**content_dict) + la_content_output = zero_la_content + elif task == "singing_voice_synthesis": + content_dict = { + "phoneme": + batch_content["phoneme"].long().to(device), + "midi": + batch_content["midi"].long().to(device), + "midi_duration": + batch_content["midi_duration"].float().to(device), + "is_slur": + batch_content["is_slur"].long().to(device), + "lengths": + batch_content["phoneme_lengths"].long().cpu(), + } + if "spk" in batch_content: + if self.midi_encoder.spk_config.encoding_format == "id": + content_dict["spk"] = 
batch_content["spk"].long( + ).to(device) + elif self.midi_encoder.spk_config.encoding_format == "embedding": + content_dict["spk"] = batch_content["spk"].float( + ).to(device) + content_output = self.midi_encoder(**content_dict) + la_content_output = zero_la_content + elif task == "text_to_speech": + content_dict = { + "phoneme": batch_content["phoneme"].long().to(device), + "lengths": batch_content["phoneme_lengths"].long().cpu(), + } + if "spk" in batch_content: + if self.phoneme_encoder.spk_config.encoding_format == "id": + content_dict["spk"] = batch_content["spk"].long( + ).to(device) + elif self.phoneme_encoder.spk_config.encoding_format == "embedding": + content_dict["spk"] = batch_content["spk"].float( + ).to(device) + content_output = self.phoneme_encoder(**content_dict) + la_content_output = zero_la_content + elif task == "singing_acoustic_modeling": + content_dict = { + "phoneme": batch_content["phoneme"].long().to(device), + "lengths": batch_content["phoneme_lengths"].long().to(device), + } + content_output = self.pitch_encoder(**content_dict) + + content_dict = { + "f0": batch_content["f0"].float().to(device), + "uv": batch_content["uv"].float().to(device), + } + la_content_output = self.pitch_encoder.encode_pitch(**content_dict) + + return { + "content": content_output["output"], + "content_mask": content_output["mask"], + "length_aligned_content": la_content_output, + } diff --git a/models/content_encoder/midi_encoder.py b/models/content_encoder/midi_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fda372db829831f44dd41a2ee50197e68f4aa618 --- /dev/null +++ b/models/content_encoder/midi_encoder.py @@ -0,0 +1,1046 @@ +from typing import Sequence +from dataclasses import dataclass +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter + +from utils.torch_utilities import create_mask_from_length +from utils.diffsinger_utilities import denorm_f0, f0_to_coarse + + +def make_positions(tensor, padding_idx): + """Replace non-padding symbols with their position numbers. + Position numbers begin at padding_idx+1. Padding symbols are ignored. + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. In particular XLA + # prefers ints, cumsum defaults to output longs, and ONNX doesn't know + # how to handle the dtype kwarg in cumsum. + mask = tensor.ne(padding_idx).int() + return (torch.cumsum(mask, dim=1).type_as(mask) * + mask).long() + padding_idx + + +def softmax(x, dim): + return F.softmax(x, dim=dim, dtype=torch.float32) + + +def LayerNorm( + normalized_shape, eps=1e-5, elementwise_affine=True, export=False +): + if not export and torch.cuda.is_available(): + try: + from apex.normalization import FusedLayerNorm + return FusedLayerNorm(normalized_shape, eps, elementwise_affine) + except ImportError: + pass + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.) 
+ return m + + +def Embedding(num_embeddings, embedding_dim, padding_idx=None): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) + if padding_idx is not None: + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +class BatchNorm1dTBC(nn.Module): + def __init__(self, c): + super(BatchNorm1dTBC, self).__init__() + self.bn = nn.BatchNorm1d(c) + + def forward(self, x): + """ + + :param x: [T, B, C] + :return: [T, B, C] + """ + x = x.permute(1, 2, 0) # [B, C, T] + x = self.bn(x) # [B, C, T] + x = x.permute(2, 0, 1) # [T, B, C] + return x + + +class PositionalEncoding(nn.Module): + """Positional encoding. + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + reverse (bool): Whether to reverse the input position. + """ + def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.reverse = reverse + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + if self.reverse: + position = torch.arange( + x.size(1) - 1, -1, -1.0, dtype=torch.float32 + ).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), + dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * + -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, :x.size(1)] + return self.dropout(x) + + +class SinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length. + + Padding symbols are ignored. + """ + def __init__(self, d_model, padding_idx, init_size=2048): + super().__init__() + self.d_model = d_model + self.padding_idx = padding_idx + self.weights = SinusoidalPositionalEmbedding.get_embedding( + init_size, + d_model, + padding_idx, + ) + self.register_buffer('_float_tensor', torch.FloatTensor(1)) + + @staticmethod + def get_embedding(num_embeddings, d_model, padding_idx=None): + """Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + half_dim = d_model // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, + dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], + dim=1).view(num_embeddings, -1) + if d_model % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb + + def forward( + self, + x, + lengths, + incremental_state=None, + timestep=None, + positions=None, + **kwargs + ): + """Input is expected to be of size [bsz x seqlen].""" + bsz, seq_len = x.shape[:2] + max_pos = self.padding_idx + 1 + seq_len + if self.weights is None or max_pos > self.weights.size(0): + # recompute/expand embeddings if needed + self.weights = SinusoidalPositionalEmbedding.get_embedding( + max_pos, + self.d_model, + self.padding_idx, + ) + self.weights = self.weights.to(self._float_tensor) + + if incremental_state is not None: + # positions is the same for every token when decoding a single step + pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len + return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) + + positions = create_mask_from_length( + lengths, max_length=x.shape[1] + ) * (torch.arange(x.shape[1]) + 1).unsqueeze(0).expand(x.shape[0], -1) + positions = positions.to(self.weights.device) + pos_emb = self.weights.index_select(0, positions.view(-1)).view( + bsz, seq_len, -1 + ).detach() + return x + pos_emb + + def max_positions(self): + """Maximum number of supported positions.""" + return int(1e5) # an arbitrary large number + + +class RelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module. + See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + def __init__(self, d_model, dropout_rate, max_len=5000): + """Initialize class.""" + super().__init__(d_model, dropout_rate, max_len, reverse=True) + + def forward(self, x, lengths): + """Compute positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). 
+ """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[:, :x.size(1)] + return self.dropout(x) + self.dropout(pos_emb) + + +class MultiheadAttention(nn.Module): + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0., + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \ + 'value to be of the same size' + + if self.qkv_same_dim: + self.in_proj_weight = Parameter( + torch.Tensor(3 * embed_dim, embed_dim) + ) + else: + self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) + self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) + self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) + + if bias: + self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) + else: + self.register_parameter('in_proj_bias', None) + + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.reset_parameters() + + self.enable_torch_version = False + if hasattr(F, "multi_head_attention_forward"): + self.enable_torch_version = True + else: + self.enable_torch_version = False + self.last_attn_probs = None + + def reset_parameters(self): + if self.qkv_same_dim: + nn.init.xavier_uniform_(self.in_proj_weight) + else: + nn.init.xavier_uniform_(self.k_proj_weight) + nn.init.xavier_uniform_(self.v_proj_weight) + nn.init.xavier_uniform_(self.q_proj_weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.in_proj_bias is not None: + nn.init.constant_(self.in_proj_bias, 0.) + nn.init.constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def forward( + self, + query, + key, + value, + key_padding_mask=None, + incremental_state=None, + need_weights=True, + static_kv=False, + attn_mask=None, + before_softmax=False, + need_head_weights=False, + enc_dec_attn_constraint_mask=None, + reset_attn_weight=None + ): + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. 
Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + + if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None: + if self.qkv_same_dim: + return F.multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, self.bias_k, + self.bias_v, self.add_zero_attn, self.dropout, + self.out_proj.weight, self.out_proj.bias, self.training, + key_padding_mask, need_weights, attn_mask + ) + else: + return F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + self.training, + key_padding_mask, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight + ) + + if incremental_state is not None: + print('Not implemented error.') + exit() + else: + saved_state = None + + if self.self_attention: + # self-attention + q, k, v = self.in_proj_qkv(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.in_proj_q(query) + if key is None: + assert value is None + k = v = None + else: + k = self.in_proj_k(key) + v = self.in_proj_v(key) + + else: + q = self.in_proj_q(query) + k = self.in_proj_k(key) + v = self.in_proj_v(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, + attn_mask.new_zeros(attn_mask.size(0), 1)], + dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros( + key_padding_mask.size(0), 1 + ) + ], + dim=1 + ) + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, + self.head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, + self.head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, + self.head_dim).transpose(0, 1) + + if saved_state is not None: + print('Not implemented error.') + exit() + + src_len = k.size(1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. 
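The fast path above hands the packed `(3 * embed_dim, embed_dim)` weight to `F.multi_head_attention_forward`; the slow path slices the same matrix via `in_proj_qkv` and the per-projection helpers defined below. A small sketch of that equivalence (sizes are illustrative):

```python
import torch
import torch.nn.functional as F

embed_dim, tgt_len, bsz = 8, 5, 2
w = torch.randn(3 * embed_dim, embed_dim)   # packed QKV weight
x = torch.randn(tgt_len, bsz, embed_dim)    # (time, batch, channel) layout

q, k, v = F.linear(x, w).chunk(3, dim=-1)           # what in_proj_qkv does
q_ref = F.linear(x, w[:embed_dim])                  # in_proj_q slice
k_ref = F.linear(x, w[embed_dim:2 * embed_dim])     # in_proj_k slice
v_ref = F.linear(x, w[2 * embed_dim:])              # in_proj_v slice
assert torch.allclose(q, q_ref)
assert torch.allclose(k, k_ref)
assert torch.allclose(v, v_ref)
```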
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size( + [] + ): + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + src_len += 1 + k = torch.cat( + [k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1 + ) + v = torch.cat( + [v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1 + ) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, + attn_mask.new_zeros(attn_mask.size(0), 1)], + dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), + 1).type_as(key_padding_mask) + ], + dim=1 + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask( + attn_weights, tgt_len, src_len, bsz + ) + + assert list(attn_weights.size()) == [ + bsz * self.num_heads, tgt_len, src_len + ] + + if attn_mask is not None: + if len(attn_mask.shape) == 2: + attn_mask = attn_mask.unsqueeze(0) + elif len(attn_mask.shape) == 3: + attn_mask = attn_mask[:, None].repeat( + [1, self.num_heads, 1, 1] + ).reshape(bsz * self.num_heads, tgt_len, src_len) + attn_weights = attn_weights + attn_mask + + if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv + attn_weights = attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.masked_fill( + enc_dec_attn_constraint_mask.unsqueeze(2).bool(), + -1e9, + ) + attn_weights = attn_weights.view( + bsz * self.num_heads, tgt_len, src_len + ) + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + -1e9, + ) + attn_weights = attn_weights.view( + bsz * self.num_heads, tgt_len, src_len + ) + + attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = softmax(attn_weights, dim=-1) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = F.dropout( + attn_weights_float.type_as(attn_weights), + p=self.dropout, + training=self.training + ) + + if reset_attn_weight is not None: + if reset_attn_weight: + self.last_attn_probs = attn_probs.detach() + else: + assert self.last_attn_probs is not None + attn_probs = self.last_attn_probs + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [ + bsz * self.num_heads, tgt_len, self.head_dim + ] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + else: + attn_weights = None + + return attn, (attn_weights, attn_logits) + + def in_proj_qkv(self, query): + return self._in_proj(query).chunk(3, dim=-1) + + def in_proj_q(self, query): + if self.qkv_same_dim: + return self._in_proj(query, end=self.embed_dim) + else: + bias = self.in_proj_bias + if bias is not None: + bias = bias[:self.embed_dim] + return F.linear(query, self.q_proj_weight, bias) + + def in_proj_k(self, key): + if self.qkv_same_dim: + return self._in_proj( + key, start=self.embed_dim, end=2 * self.embed_dim + ) + else: + weight = self.k_proj_weight + bias = self.in_proj_bias + if bias is not None: + bias 
= bias[self.embed_dim:2 * self.embed_dim] + return F.linear(key, weight, bias) + + def in_proj_v(self, value): + if self.qkv_same_dim: + return self._in_proj(value, start=2 * self.embed_dim) + else: + weight = self.v_proj_weight + bias = self.in_proj_bias + if bias is not None: + bias = bias[2 * self.embed_dim:] + return F.linear(value, weight, bias) + + def _in_proj(self, input, start=0, end=None): + weight = self.in_proj_weight + bias = self.in_proj_bias + weight = weight[start:end, :] + if bias is not None: + bias = bias[start:end] + return F.linear(input, weight, bias) + + def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz): + return attn_weights + + +class TransformerFFNLayer(nn.Module): + def __init__( + self, + hidden_size, + filter_size, + padding="SAME", + kernel_size=1, + dropout=0., + act='gelu' + ): + super().__init__() + self.kernel_size = kernel_size + self.dropout = dropout + self.act = act + if padding == 'SAME': + self.ffn_1 = nn.Conv1d( + hidden_size, + filter_size, + kernel_size, + padding=kernel_size // 2 + ) + elif padding == 'LEFT': + self.ffn_1 = nn.Sequential( + nn.ConstantPad1d((kernel_size - 1, 0), 0.0), + nn.Conv1d(hidden_size, filter_size, kernel_size) + ) + self.ffn_2 = nn.Linear(filter_size, hidden_size) + + def forward( + self, + x, + ): + # x: T x B x C + x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1) + x = x * self.kernel_size**-0.5 + + if self.act == 'gelu': + x = F.gelu(x) + if self.act == 'relu': + x = F.relu(x) + if self.act == 'swish': + x = F.silu(x) + x = F.dropout(x, self.dropout, training=self.training) + x = self.ffn_2(x) + return x + + +class EncoderSelfAttentionLayer(nn.Module): + def __init__( + self, + c, + num_heads, + dropout, + attention_dropout=0.1, + relu_dropout=0.1, + kernel_size=9, + padding='SAME', + norm='ln', + act='gelu', + padding_set_zero=True + ): + super().__init__() + self.c = c + self.dropout = dropout + self.num_heads = num_heads + self.padding_set_zero = padding_set_zero + if num_heads > 0: + if norm == 'ln': + self.layer_norm1 = LayerNorm(c) + elif norm == 'bn': + self.layer_norm1 = BatchNorm1dTBC(c) + self.self_attn = MultiheadAttention( + self.c, + num_heads=num_heads, + self_attention=True, + dropout=attention_dropout, + bias=False, + ) + if norm == 'ln': + self.layer_norm2 = LayerNorm(c) + elif norm == 'bn': + self.layer_norm2 = BatchNorm1dTBC(c) + self.ffn = TransformerFFNLayer( + c, + 4 * c, + kernel_size=kernel_size, + dropout=relu_dropout, + padding=padding, + act=act + ) + + def forward(self, x, encoder_padding_mask=None, **kwargs): + layer_norm_training = kwargs.get('layer_norm_training', None) + if layer_norm_training is not None: + self.layer_norm1.training = layer_norm_training + self.layer_norm2.training = layer_norm_training + if self.num_heads > 0: + residual = x + x = self.layer_norm1(x) + x, _, = self.self_attn( + query=x, key=x, value=x, key_padding_mask=encoder_padding_mask + ) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + if self.padding_set_zero: + x = x * (1 - encoder_padding_mask.float()).transpose(0, + 1)[..., + None] + + residual = x + x = self.layer_norm2(x) + x = self.ffn(x) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + if self.padding_set_zero: + x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., + None] + return x + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + hidden_size, + dropout, + kernel_size, + num_heads=2, + norm='ln', + padding_set_zero=True, + ): + 
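A usage sketch for the `TransformerFFNLayer` defined above, assuming the class is importable from this module (sizes are illustrative). The layer expects a time-first `(T, B, C)` input, matching the attention layers:

```python
import torch

ffn = TransformerFFNLayer(hidden_size=16, filter_size=64,
                          kernel_size=9, dropout=0.1, act='gelu')
x = torch.randn(50, 2, 16)   # (T, B, C)
y = ffn(x)
assert y.shape == x.shape    # SAME-padded Conv1d preserves length; ffn_2 maps back to C
```

Note the `x * kernel_size ** -0.5` rescaling after the convolution: it keeps the activation scale roughly constant as the kernel widens.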
super().__init__() + self.hidden_size = hidden_size + self.dropout = dropout + self.num_heads = num_heads + self.op = EncoderSelfAttentionLayer( + hidden_size, + num_heads, + dropout=dropout, + attention_dropout=0.0, + relu_dropout=dropout, + kernel_size=kernel_size, + padding="SAME", + norm=norm, + act="gelu", + padding_set_zero=padding_set_zero + ) + + def forward(self, x, **kwargs): + return self.op(x, **kwargs) + + +class FFTBlocks(nn.Module): + def __init__( + self, + hidden_size, + num_layers, + ffn_kernel_size=9, + dropout=0.1, + num_heads=2, + use_last_norm=True, + padding_set_zero=True, + ): + super().__init__() + self.num_layers = num_layers + embed_dim = self.hidden_size = hidden_size + self.dropout = dropout + self.use_last_norm = use_last_norm + self.padding_set_zero = padding_set_zero + + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + TransformerEncoderLayer( + self.hidden_size, + self.dropout, + kernel_size=ffn_kernel_size, + num_heads=num_heads, + padding_set_zero=padding_set_zero, + ) for _ in range(self.num_layers) + ] + ) + if self.use_last_norm: + self.layer_norm = nn.LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward(self, x, padding_mask=None, attn_mask=None): + """ + :param x: [B, T, C] + :param padding_mask: [B, T] + :return: [B, T, C] or [L, B, T, C] + """ + if padding_mask is None: + padding_mask = torch.zeros(x.size(0), x.size(1)).to(x.device) + nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float( + )[:, :, None] # [T, B, 1] + # B x T x C -> T x B x C + x = x.transpose(0, 1) + if self.padding_set_zero: + x = x * nonpadding_mask_TB + for layer in self.layers: + x = layer( + x, encoder_padding_mask=padding_mask, attn_mask=attn_mask + ) + if self.padding_set_zero: + x = x * nonpadding_mask_TB + if self.use_last_norm: + x = self.layer_norm(x) + if self.padding_set_zero: + x = x * nonpadding_mask_TB + + x = x.transpose(0, 1) # [B, T, C] + return x + + +class FastSpeech2EncoderBase(nn.Module): + def __init__( + self, + d_model: int, + num_layers: int, + num_heads: int, + ffn_kernel_size: int, + d_out: int, + dropout: float = 0.1, + rel_pos: bool = True, + padding_set_zero: bool = True + ): + super().__init__() + self.rel_pos = rel_pos + + if self.rel_pos: + self.pos_encoding = RelPositionalEncoding( + d_model, dropout_rate=0.0 + ) + else: + self.pos_encoding = SinusoidalPositionalEmbedding( + d_model, padding_idx=0 + ) + self.dropout = dropout + self.embed_scale = math.sqrt(d_model) + + self.layers = FFTBlocks( + hidden_size=d_model, + num_layers=num_layers, + ffn_kernel_size=ffn_kernel_size, + dropout=dropout, + num_heads=num_heads, + use_last_norm=True, + padding_set_zero=padding_set_zero + ) + + self.out_proj = nn.Linear(d_model, d_out) + self.apply(self.init_weights) + + def init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0.) 
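A usage sketch for `FFTBlocks` (batch-first in and out; the padding mask marks padded positions with True). This assumes the classes above are importable; sizes are illustrative:

```python
import torch

blocks = FFTBlocks(hidden_size=16, num_layers=2, num_heads=2)
x = torch.randn(2, 40, 16)                  # (B, T, C)
padding_mask = torch.zeros(2, 40, dtype=torch.bool)
padding_mask[1, 30:] = True                 # second item is only 30 frames long
y = blocks(x, padding_mask=padding_mask)
assert y.shape == (2, 40, 16)
assert torch.all(y[1, 30:] == 0)            # zeroed by padding_set_zero=True
```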
+ elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, mean=0, std=m.embedding_dim**-0.5) + + +@dataclass +class SpkConfig: + encoding_format: str + num_spk: int | None = None + spk_embed_dim: int | None = None + + def __post_init__(self): + allowed_formats = {"id", "embedding"} + assert self.encoding_format in allowed_formats, f"mode must be one of {allowed_formats}, got '{self.encoding_format}'" + if self.encoding_format == "id": + assert self.num_spk is not None + if self.encoding_format == "embedding": + assert self.spk_embed_dim is not None + + +class FastSpeech2PhonemeEncoder(FastSpeech2EncoderBase): + def __init__( + self, + phone_vocab_size, + d_model, + num_layers, + num_heads, + ffn_kernel_size, + d_out, + dropout=0.1, + rel_pos=False, + spk_config: SpkConfig | None = None, + padding_set_zero: bool = True + ): + super().__init__( + d_model=d_model, + num_layers=num_layers, + num_heads=num_heads, + ffn_kernel_size=ffn_kernel_size, + d_out=d_out, + dropout=dropout, + rel_pos=rel_pos, + padding_set_zero=padding_set_zero + ) + self.phone_embed = Embedding(phone_vocab_size, d_model) + self.spk_config = spk_config + if spk_config is not None: + if spk_config.encoding_format == "id": + self.spk_embed_proj = Embedding( + spk_config.num_spk + 1, d_model + ) + elif spk_config.encoding_format == "embedding": + self.spk_embed_proj = Linear(spk_config.spk_embed_dim, d_model) + + def forward( + self, phoneme: torch.Tensor, lengths: Sequence[int], spk: torch.Tensor + ): + x = self.embed_scale * self.phone_embed(phoneme) + x = self.pos_encoding(x, lengths) + x = F.dropout(x, p=self.dropout, training=self.training) + + padding_mask = ~create_mask_from_length(lengths).to(phoneme.device) + x = self.layers(x, padding_mask=padding_mask) + + if self.spk_config is not None: + spk_embed = self.spk_embed_proj(spk).unsqueeze(1) + x = x + spk_embed + + x = self.out_proj(x) + + return {"output": x, "mask": ~padding_mask} + + +class FastSpeech2MIDIEncoder(FastSpeech2PhonemeEncoder): + def __init__( + self, + phone_vocab_size: int, + midi_vocab_size: int, + slur_vocab_size: int, + spk_config: SpkConfig | None, + d_model: int, + num_layers: int, + num_heads: int, + ffn_kernel_size: int, + d_out: int, + dropout: float = 0.1, + rel_pos: bool = True, + padding_set_zero: bool = True + ): + super().__init__( + phone_vocab_size=phone_vocab_size, + d_model=d_model, + num_layers=num_layers, + num_heads=num_heads, + ffn_kernel_size=ffn_kernel_size, + d_out=d_out, + dropout=dropout, + rel_pos=rel_pos, + spk_config=spk_config, + padding_set_zero=padding_set_zero + ) + self.midi_embed = Embedding(midi_vocab_size, d_model, padding_idx=0) + self.midi_dur_embed = Linear(1, d_model) + self.is_slur_embed = Embedding(slur_vocab_size, d_model) + + def forward( + self, + phoneme: torch.Tensor, + midi: torch.Tensor, + midi_duration: torch.Tensor, + is_slur: torch.Tensor, + lengths: Sequence[int], + spk: torch.Tensor | None = None, + ): + x = self.embed_scale * self.phone_embed(phoneme) + midi_embedding = self.midi_embed(midi) + midi_dur_embedding = self.midi_dur_embed(midi_duration[:, :, None]) + slur_embedding = self.is_slur_embed(is_slur) + + x = x + midi_embedding + midi_dur_embedding + slur_embedding + x = self.pos_encoding(x, lengths) + x = F.dropout(x, p=self.dropout, training=self.training) + + padding_mask = ~create_mask_from_length(lengths).to(phoneme.device) + x = self.layers(x, padding_mask=padding_mask) + + if self.spk_config is not None: + spk_embed = self.spk_embed_proj(spk).unsqueeze(1) + x = x + spk_embed 
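The `SpkConfig` dataclass above validates its fields in `__post_init__`; a brief sketch of the two accepted speaker-encoding formats:

```python
# Speaker identity as a lookup id -> requires num_spk
spk_by_id = SpkConfig(encoding_format="id", num_spk=100)

# Speaker identity as a precomputed embedding -> requires spk_embed_dim
spk_by_embed = SpkConfig(encoding_format="embedding", spk_embed_dim=256)

# Any other format fails the assertion in __post_init__
try:
    SpkConfig(encoding_format="xvector")
except AssertionError as e:
    print(e)
```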
+ + x = self.out_proj(x) + + return {"output": x, "mask": ~padding_mask} + + +class FastSpeech2PitchEncoder(FastSpeech2EncoderBase): + def __init__( + self, + phone_vocab_size, + d_model, + num_layers, + num_heads, + ffn_kernel_size, + d_out, + dropout=0.1, + rel_pos=False, + padding_set_zero=True + ): + super().__init__( + d_model=d_model, + num_layers=num_layers, + num_heads=num_heads, + ffn_kernel_size=ffn_kernel_size, + d_out=d_out, + dropout=dropout, + rel_pos=rel_pos, + padding_set_zero=padding_set_zero + ) + self.phone_embed = Embedding(phone_vocab_size, d_model) + self.pitch_embed = Embedding(300, d_model) + + def forward(self, phoneme: torch.Tensor, lengths: Sequence[int]): + x = self.embed_scale * self.phone_embed(phoneme) + x = self.pos_encoding(x, lengths) + x = F.dropout(x, p=self.dropout, training=self.training) + + padding_mask = ~create_mask_from_length(lengths).to(phoneme.device) + x = self.layers(x, padding_mask=padding_mask) + + x = self.out_proj(x) + + return {"output": x, "mask": ~padding_mask} + + def encode_pitch(self, f0, uv): + + f0_denorm = denorm_f0(f0, uv) + pitch = f0_to_coarse(f0_denorm) + pitch_embed = self.pitch_embed(pitch) + return {"output": pitch_embed} diff --git a/models/content_encoder/sketch_encoder.py b/models/content_encoder/sketch_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..73d76bbf9de9073d97b1b5a7a244f7d90c0c6683 --- /dev/null +++ b/models/content_encoder/sketch_encoder.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn + + +try: + import torch_npu + from torch_npu.contrib import transfer_to_npu + DEVICE_TYPE = "npu" +except ModuleNotFoundError: + DEVICE_TYPE = "cuda" + +from .text_encoder import T5TextEncoder + +class SketchT5TextEncoder(T5TextEncoder): + def __init__( + self, f0_dim: int , energy_dim: int, latent_dim: int, + embed_dim: int, model_name: str = "google/flan-t5-large", + ): + super().__init__( + embed_dim = embed_dim, + model_name = model_name, + ) + self.f0_proj = nn.Linear(f0_dim, latent_dim) + self.f0_norm = nn.LayerNorm(f0_dim) + self.energy_proj = nn.Linear(energy_dim, latent_dim) + + def encode( + self, + text: list[str], + ): + with torch.no_grad(), torch.amp.autocast( + device_type=DEVICE_TYPE, enabled=False + ): + return super().encode(text) + + def encode_sketch( + self, + f0, + energy, + ): + f0_embed = self.f0_proj(self.f0_norm(f0)).unsqueeze(-1) + energy_embed = self.energy_proj(energy).unsqueeze(-1) + sketch_embed = torch.cat([f0_embed, energy_embed], dim=-1) + return {"output": sketch_embed} + + +if __name__ == "__main__": + text_encoder = T5TextEncoder(embed_dim=512) + text = ["a man is speaking", "a woman is singing while a dog is barking"] + + output = text_encoder(text) diff --git a/models/content_encoder/star_encoder/__pycache__/Qformer.cpython-310.pyc b/models/content_encoder/star_encoder/__pycache__/Qformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26cc67fbc20d872513cc1172de628e5af9b13732 Binary files /dev/null and b/models/content_encoder/star_encoder/__pycache__/Qformer.cpython-310.pyc differ diff --git a/models/content_encoder/star_encoder/__pycache__/star_encoder.cpython-310.pyc b/models/content_encoder/star_encoder/__pycache__/star_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6849f1815b014e48523e415059510f90272db373 Binary files /dev/null and b/models/content_encoder/star_encoder/__pycache__/star_encoder.cpython-310.pyc differ diff --git 
a/models/content_encoder/star_encoder/star_encoder.py b/models/content_encoder/star_encoder/star_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..59fff9482841d1366cddeec5e6ce7d7b2220fcb4 --- /dev/null +++ b/models/content_encoder/star_encoder/star_encoder.py @@ -0,0 +1,108 @@ +import torch +import torch.nn as nn + +import os +import sys +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from Qformer import BertConfig, BertLMHeadModel + + +try: + import torch_npu + from torch_npu.contrib import transfer_to_npu + DEVICE_TYPE = "npu" +except ModuleNotFoundError: + DEVICE_TYPE = "cuda" + +def generate_length_mask(lens, max_length=None): + lens = torch.as_tensor(lens) + N = lens.size(0) + if max_length is None: + max_length = max(lens) + idxs = torch.arange(max_length).repeat(N).view(N, max_length) + idxs = idxs.to(lens.device) + mask = (idxs < lens.view(-1, 1)).int() + return mask + +class QformerBridgeNet(torch.nn.Module): + def __init__(self, Qformer_model_name: str = "bert-base-uncased", num_query_token: int = 32, + hiddin_size: int = 1024, speech_width: int = 1024, freeze_QFormer: bool = True, + load_from_pretrained: str = None): + super().__init__() + + self.Qformer_model_name = Qformer_model_name + self.audio_Qformer, self.audio_query_tokens, encoder_config = self.init_Qformer(num_query_token=num_query_token, speech_width=speech_width) + self.audio_Qformer.cls = None + self.audio_Qformer.bert.embeddings.word_embeddings = None + self.audio_Qformer.bert.embeddings.position_embeddings = None + for layer in self.audio_Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + + self.freeze_QFormer = freeze_QFormer + if freeze_QFormer: + for name, param in self.audio_Qformer.named_parameters(): + param.requires_grad = False + self.audio_Qformer.eval() + self.audio_query_tokens.requires_grad = False + + self.hiddin_projection = torch.nn.Linear(encoder_config.hidden_size, hiddin_size) + #torch.nn.init.xavier_uniform_(self.hiddin_projection.weight, gain=torch.nn.init.calculate_gain("relu")) + + if load_from_pretrained: + state_dict = torch.load(load_from_pretrained) + del_key = ["projection.weight", "projection.bias"] + del_state_dict = {k:v for k, v in state_dict.items() if k not in del_key} + self.load_state_dict(del_state_dict) + print("Load adaptor_model_pt from", load_from_pretrained) + + + def init_Qformer(self, num_query_token, speech_width, num_hidden_layers=2, cross_attention_freq=2): + encoder_config = BertConfig.from_pretrained(self.Qformer_model_name) + encoder_config.num_hidden_layers = num_hidden_layers + encoder_config.encoder_width = speech_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = cross_attention_freq + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel(config=encoder_config) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + return Qformer, query_tokens, encoder_config + + def hidden(self, batch,): + audio_feature, lens = batch['embed'], batch['embed_len'] + frame_atts = generate_length_mask(lens).to(audio_feature.device) + audio_query_tokens=self.audio_query_tokens.expand(audio_feature.shape[0], -1, -1) + #frame_atts = torch.ones(audio_feature.size()[:-1], dtype=torch.long).to(audio_feature.device) + + #print(audio_query_tokens.shape, audio_feature.shape, frame_atts.shape) 
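`generate_length_mask` above mirrors the `create_mask_from_length` utility used elsewhere in the repo; a quick demonstration of its output:

```python
import torch

lens = torch.tensor([3, 1, 2])
mask = generate_length_mask(lens)
# tensor([[1, 1, 1],
#         [1, 0, 0],
#         [1, 1, 0]], dtype=torch.int32)
print(mask)
assert mask.shape == (3, 3) and mask.sum() == lens.sum()
```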
+ audio_query_output=self.audio_Qformer.bert( + query_embeds=audio_query_tokens, #[32,768] + encoder_hidden_states=audio_feature, + encoder_attention_mask=frame_atts, + return_dict=True, + ) + audio_hidden = audio_query_output.last_hidden_state + return audio_hidden + + def forward(self, batch) -> torch.Tensor: + with torch.no_grad(), torch.amp.autocast( + device_type=DEVICE_TYPE, enabled=False + ): + x = self.hidden(batch) + x = self.hiddin_projection(x) + + mask = torch.ones(x.shape[:2]) + mask = (mask == 1).to(x.device) + return {"output": x, "mask": mask} + + +if __name__ == '__main__': + text_encoder = T5TextEncoder() + text = ["a man is speaking", "a woman is singing while a dog is barking"] + text_encoder.eval() + with torch.no_grad(): + output = text_encoder(text) diff --git a/models/content_encoder/text_encoder.py b/models/content_encoder/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c64bbdd60deca4e7c170cb5b02636cb28f49e922 --- /dev/null +++ b/models/content_encoder/text_encoder.py @@ -0,0 +1,77 @@ +import torch +import torch.nn as nn +from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5EncoderModel +from transformers.modeling_outputs import BaseModelOutput + +try: + import torch_npu + from torch_npu.contrib import transfer_to_npu + DEVICE_TYPE = "npu" +except ModuleNotFoundError: + DEVICE_TYPE = "cuda" + + +class TransformersTextEncoderBase(nn.Module): + def __init__(self, model_name: str, embed_dim: int): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModel.from_pretrained(model_name) + self.proj = nn.Linear(self.model.config.hidden_size, embed_dim) + + def forward( + self, + text: list[str], + ): + output, mask = self.encode(text) + output = self.projection(output) + return {"output": output, "mask": mask} + + def encode(self, text: list[str]): + device = self.model.device + batch = self.tokenizer( + text, + max_length=self.tokenizer.model_max_length, + padding=True, + truncation=True, + return_tensors="pt", + ) + input_ids = batch.input_ids.to(device) + attention_mask = batch.attention_mask.to(device) + output: BaseModelOutput = self.model( + input_ids=input_ids, attention_mask=attention_mask + ) + output = output.last_hidden_state + mask = (attention_mask == 1).to(device) + return output, mask + + def projection(self, x): + return self.proj(x) + + +class T5TextEncoder(TransformersTextEncoderBase): + def __init__( + self, embed_dim: int, model_name: str = "google/flan-t5-large" + ): + nn.Module.__init__(self) + self.tokenizer = T5Tokenizer.from_pretrained(model_name) + self.model = T5EncoderModel.from_pretrained(model_name) + for param in self.model.parameters(): + param.requires_grad = False + self.model.eval() + self.proj = nn.Linear(self.model.config.hidden_size, embed_dim) + + def encode( + self, + text: list[str], + ): + with torch.no_grad(), torch.amp.autocast( + device_type=DEVICE_TYPE, enabled=False + ): + return super().encode(text) + + +if __name__ == "__main__": + text_encoder = T5TextEncoder(embed_dim=512) + text = ["a man is speaking", "a woman is singing while a dog is barking"] + + output = text_encoder(text) diff --git a/models/content_encoder/vision_encoder.py b/models/content_encoder/vision_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7fe5666cba2b78fc74681c2ce5cfff6c5e3730c6 --- /dev/null +++ b/models/content_encoder/vision_encoder.py @@ -0,0 +1,34 @@ +from typing import Sequence + +import torch +import torch.nn as nn 
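A usage sketch for the `T5TextEncoder` above. Instantiation downloads `google/flan-t5-large` from the Hugging Face Hub; `encode` runs frozen under `no_grad` with autocast disabled, while only `proj` remains trainable:

```python
import torch

encoder = T5TextEncoder(embed_dim=512)     # downloads google/flan-t5-large
texts = ["a man is speaking", "a dog is barking"]
out = encoder(texts)                       # frozen encode + trainable projection
print(out["output"].shape)                 # (batch, seq_len, 512)
print(out["mask"].shape)                   # (batch, seq_len), True on real tokens
assert out["output"].shape[-1] == 512
```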
+import torch.nn.functional as F + +from utils.torch_utilities import create_mask_from_length + + +class MlpVideoEncoder(nn.Module): + def __init__( + self, + video_feat_dim: int, + embed_dim: int, + ): + super().__init__() + self.mlp = nn.Linear(video_feat_dim, embed_dim) + self.init_weights() + + def init_weights(self): + def _init_weights(module): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0.) + + self.apply(_init_weights) + + def forward(self, frames: torch.Tensor, frame_nums: Sequence[int]): + device = frames.device + x = F.normalize(frames, p=2, dim=-1) + x = self.mlp(x) + mask = create_mask_from_length(frame_nums).to(device) + return {"output": x, "mask": mask} diff --git a/models/diffsinger_net.py b/models/diffsinger_net.py new file mode 100644 index 0000000000000000000000000000000000000000..f802b0e684b532b7fb141348aabadba3d105152d --- /dev/null +++ b/models/diffsinger_net.py @@ -0,0 +1,119 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Mish(nn.Module): + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super(SinusoidalPosEmb, self).__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim-1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class ResidualBlock(nn.Module): + def __init__(self, encoder_hidden, residual_channels, dilation): + super().__init__() + self.dilated_conv = nn.Conv1d( + residual_channels, + 2 * residual_channels, + 3, + padding=dilation, + dilation=dilation + ) + self.diffusion_projection = nn.Linear( + residual_channels, residual_channels + ) + self.conditioner_projection = nn.Conv1d( + encoder_hidden, 2 * residual_channels, 1 + ) + self.output_projection = nn.Conv1d( + residual_channels, 2 * residual_channels, 1 + ) + + def forward(self, x, conditioner, diffusion_step): + diffusion_step = self.diffusion_projection(diffusion_step + ).unsqueeze(-1) + conditioner = self.conditioner_projection(conditioner) + y = x + diffusion_step + + y = self.dilated_conv(y) + conditioner + + gate, filter = torch.chunk(y, 2, dim=1) + y = torch.sigmoid(gate) * torch.tanh(filter) + + y = self.output_projection(y) + residual, skip = torch.chunk(y, 2, dim=1) + return (x+residual) / math.sqrt(2.0), skip + + +class DiffSingerNet(nn.Module): + def __init__( + self, + in_dims=128, + residual_channels=256, + encoder_hidden=256, + dilation_cycle_length=4, + residual_layers=20, + ): + super().__init__() + + # self.pe_scale = pe_scale + + self.input_projection = nn.Conv1d(in_dims, residual_channels, 1) + self.time_pos_emb = SinusoidalPosEmb(residual_channels) + dim = residual_channels + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), Mish(), nn.Linear(dim * 4, dim) + ) + self.residual_layers = nn.ModuleList([ + ResidualBlock( + encoder_hidden, residual_channels, + 2**(i % dilation_cycle_length) + ) for i in range(residual_layers) + ]) + self.skip_projection = nn.Conv1d( + residual_channels, residual_channels, 1 + ) + self.output_projection = nn.Conv1d(residual_channels, in_dims, 1) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, x, timesteps, context, x_mask=None, context_mask=None): + # make it compatible with int time step 
during inference + if timesteps.dim() == 0: + timesteps = timesteps.expand(x.shape[0] + ).to(x.device, dtype=torch.long) + + x = self.input_projection(x) # x [B, residual_channel, T] + + x = F.relu(x) + + t = self.time_pos_emb(timesteps) + t = self.mlp(t) + + cond = context + + skip = [] + for layer_id, layer in enumerate(self.residual_layers): + x, skip_connection = layer(x, cond, t) + skip.append(skip_connection) + + x = torch.sum(torch.stack(skip), + dim=0) / math.sqrt(len(self.residual_layers)) + x = self.skip_projection(x) + x = F.relu(x) + x = self.output_projection(x) # [B, M, T] + return x * x_mask.unsqueeze(1) diff --git a/models/diffusion.py b/models/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4f8f05c35c048a0c9f334985870e6ea8cd7fa7 --- /dev/null +++ b/models/diffusion.py @@ -0,0 +1,1261 @@ +from typing import Sequence +import random +from typing import Any +from pathlib import Path + +from tqdm import tqdm +import torch +import torch.nn as nn +import torch.nn.functional as F +import diffusers.schedulers as noise_schedulers +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils.torch_utils import randn_tensor + +from models.autoencoder.autoencoder_base import AutoEncoderBase +from models.content_encoder.content_encoder import ContentEncoder +from models.content_adapter import ContentAdapterBase +from models.common import LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase +from utils.torch_utilities import ( + create_alignment_path, create_mask_from_length, loss_with_mask, + trim_or_pad_length +) +from safetensors.torch import load_file + +class DiffusionMixin: + def __init__( + self, + noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1", + snr_gamma: float = None, + cfg_drop_ratio: float = 0.2 + ) -> None: + self.noise_scheduler_name = noise_scheduler_name + self.snr_gamma = snr_gamma + self.classifier_free_guidance = cfg_drop_ratio > 0.0 + self.cfg_drop_ratio = cfg_drop_ratio + self.noise_scheduler = noise_schedulers.DDPMScheduler.from_pretrained( + self.noise_scheduler_name, subfolder="scheduler" + ) + + def compute_snr(self, timesteps) -> torch.Tensor: + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = self.noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod)**0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device + )[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( + device=timesteps.device + )[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., + None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. 
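A shape-level usage sketch for the `DiffSingerNet` defined in diffsinger_net.py above (channel-first latents; all sizes illustrative). `x_mask` must be provided, since the final line multiplies by `x_mask.unsqueeze(1)`:

```python
import torch

net = DiffSingerNet(in_dims=128, residual_channels=64,
                    encoder_hidden=256, residual_layers=4)
x = torch.randn(2, 128, 100)          # (B, in_dims, T) noisy latent
t = torch.randint(0, 1000, (2,))      # per-sample diffusion steps
context = torch.randn(2, 256, 100)    # (B, encoder_hidden, T) conditioner
x_mask = torch.ones(2, 100)           # (B, T) nonpadding mask
out = net(x, t, context, x_mask=x_mask)
assert out.shape == x.shape
```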
+ snr = (alpha / sigma)**2 + return snr + + def get_timesteps( + self, + batch_size: int, + device: torch.device, + training: bool = True + ) -> torch.Tensor: + if training: + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (batch_size, ), + device=device + ) + else: + # validation on half of the total timesteps + timesteps = (self.noise_scheduler.config.num_train_timesteps // + 2) * torch.ones((batch_size, ), + dtype=torch.int64, + device=device) + + timesteps = timesteps.long() + return timesteps + + def get_target( + self, latent: torch.Tensor, noise: torch.Tensor, + timesteps: torch.Tensor + ) -> torch.Tensor: + """ + Get the target for loss depending on the prediction type + """ + if self.noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif self.noise_scheduler.config.prediction_type == "v_prediction": + target = self.noise_scheduler.get_velocity( + latent, noise, timesteps + ) + else: + raise ValueError( + f"Unknown prediction type {self.noise_scheduler.config.prediction_type}" + ) + return target + + def loss_with_snr( + self, pred: torch.Tensor, target: torch.Tensor, + timesteps: torch.Tensor, mask: torch.Tensor, + loss_reduce: bool = True, + ) -> torch.Tensor: + if self.snr_gamma is None: + loss = F.mse_loss(pred.float(), target.float(), reduction="none") + loss = loss_with_mask(loss, mask, reduce=loss_reduce) + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Adapted from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L1006 + snr = self.compute_snr(timesteps) + mse_loss_weights = torch.stack( + [ + snr, + self.snr_gamma * torch.ones_like(timesteps), + ], + dim=1, + ).min(dim=1)[0] + # division by (snr + 1) does not work well, not clear about the reason + mse_loss_weights = mse_loss_weights / snr + loss = F.mse_loss(pred.float(), target.float(), reduction="none") + loss = loss_with_mask(loss, mask, reduce=False) * mse_loss_weights + if loss_reduce: + loss = loss.mean() + return loss + + def rescale_cfg( + self, pred_cond: torch.Tensor, pred_cfg: torch.Tensor, + guidance_rescale: float + ): + """ + Rescale `pred_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
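A standalone sketch of the min-SNR weighting implemented by `compute_snr` and `loss_with_snr` above: for a DDPM schedule, SNR(t) = alpha_bar_t / (1 - alpha_bar_t), and each sample's MSE is weighted by min(SNR, gamma) / SNR (default scheduler hyperparameters assumed here):

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)  # default beta schedule
timesteps = torch.randint(0, 1000, (4,))
ac = scheduler.alphas_cumprod[timesteps]
snr = ac / (1.0 - ac)                                # == (alpha / sigma) ** 2
gamma = 5.0
weights = torch.minimum(snr, torch.full_like(snr, gamma)) / snr
print(weights)   # high-SNR (low-noise, small t) steps are down-weighted
```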
See Section 3.4 + """ + std_cond = pred_cond.std( + dim=list(range(1, pred_cond.ndim)), keepdim=True + ) + std_cfg = pred_cfg.std(dim=list(range(1, pred_cfg.ndim)), keepdim=True) + + pred_rescaled = pred_cfg * (std_cond / std_cfg) + pred_cfg = guidance_rescale * pred_rescaled + ( + 1 - guidance_rescale + ) * pred_cfg + return pred_cfg + +class CrossAttentionAudioDiffusion( + LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase, + DiffusionMixin +): + def __init__( + self, + autoencoder: AutoEncoderBase, + content_encoder: ContentEncoder, + content_adapter: ContentAdapterBase, + backbone: nn.Module, + duration_offset: float = 1.0, + noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1", + snr_gamma: float = None, + cfg_drop_ratio: float = 0.2, + ): + nn.Module.__init__(self) + DiffusionMixin.__init__( + self, noise_scheduler_name, snr_gamma, cfg_drop_ratio + ) + + self.autoencoder = autoencoder + for param in self.autoencoder.parameters(): + param.requires_grad = False + + self.content_encoder = content_encoder + self.content_encoder.audio_encoder.model = self.autoencoder + self.content_adapter = content_adapter + self.backbone = backbone + self.duration_offset = duration_offset + self.dummy_param = nn.Parameter(torch.empty(0)) + + def forward( + self, content: list[Any], task: list[str], waveform: torch.Tensor, + waveform_lengths: torch.Tensor, instruction: torch.Tensor, + instruction_lengths: Sequence[int], **kwargs + ): + device = self.dummy_param.device + num_train_timesteps = self.noise_scheduler.config.num_train_timesteps + self.noise_scheduler.set_timesteps(num_train_timesteps, device=device) + + self.autoencoder.eval() + with torch.no_grad(): + latent, latent_mask = self.autoencoder.encode( + waveform.unsqueeze(1), waveform_lengths + ) + + content_output: dict[ + str, torch.Tensor] = self.content_encoder.encode_content( + content, task, device=device + ) + content, content_mask = content_output["content"], content_output[ + "content_mask"] + instruction_mask = create_mask_from_length(instruction_lengths) + content, content_mask, global_duration_pred, _ = \ + self.content_adapter(content, content_mask, instruction, instruction_mask) + global_duration_target = torch.log( + latent_mask.sum(1) / self.autoencoder.latent_token_rate + + self.duration_offset + ) + global_duration_loss = F.mse_loss( + global_duration_target, global_duration_pred + ) + + if self.training and self.classifier_free_guidance: + mask_indices = [ + k for k in range(len(waveform)) + if random.random() < self.cfg_drop_ratio + ] + if len(mask_indices) > 0: + content[mask_indices] = 0 + + batch_size = latent.shape[0] + timesteps = self.get_timesteps(batch_size, device, self.training) + noise = torch.randn_like(latent) + noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps) + target = self.get_target(latent, noise, timesteps) + + pred: torch.Tensor = self.backbone( + x=noisy_latent, + timesteps=timesteps, + context=content, + x_mask=latent_mask, + context_mask=content_mask + ) + + pred = pred.transpose(1, self.autoencoder.time_dim) + target = target.transpose(1, self.autoencoder.time_dim) + diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask) + + return { + "diff_loss": diff_loss, + "global_duration_loss": global_duration_loss, + } + + @torch.no_grad() + def inference( + self, + content: list[Any], + condition: list[Any], + task: list[str], + instruction: torch.Tensor, + instruction_lengths: Sequence[int], + scheduler: SchedulerMixin, + num_steps: int = 20, + 
guidance_scale: float = 3.0, + guidance_rescale: float = 0.0, + disable_progress: bool = True, + **kwargs + ): + device = self.dummy_param.device + classifier_free_guidance = guidance_scale > 1.0 + + content_output: dict[ + str, torch.Tensor] = self.content_encoder.encode_content( + content, task, device=device + ) + content, content_mask = content_output["content"], content_output[ + "content_mask"] + + instruction_mask = create_mask_from_length(instruction_lengths) + content, content_mask, global_duration_pred, _ = \ + self.content_adapter(content, content_mask, instruction, instruction_mask) + batch_size = content.size(0) + + if classifier_free_guidance: + uncond_content = torch.zeros_like(content) + uncond_content_mask = content_mask.detach().clone() + content = torch.cat([uncond_content, content]) + content_mask = torch.cat([uncond_content_mask, content_mask]) + + scheduler.set_timesteps(num_steps, device=device) + timesteps = scheduler.timesteps + + global_duration_pred = torch.exp( + global_duration_pred + ) - self.duration_offset + global_duration_pred *= self.autoencoder.latent_token_rate + global_duration_pred = torch.round(global_duration_pred) + + latent_shape = tuple( + int(global_duration_pred.max().item()) if dim is None else dim + for dim in self.autoencoder.latent_shape + ) + latent = self.prepare_latent( + batch_size, scheduler, latent_shape, content.dtype, device + ) + latent_mask = create_mask_from_length(global_duration_pred).to( + content_mask.device + ) + if classifier_free_guidance: + latent_mask = torch.cat([latent_mask, latent_mask]) + + num_warmup_steps = len(timesteps) - num_steps * scheduler.order + progress_bar = tqdm(range(num_steps), disable=disable_progress) + + for i, timestep in enumerate(timesteps): + # expand the latent if we are doing classifier free guidance + latent_input = torch.cat([latent, latent] + ) if classifier_free_guidance else latent + latent_input = scheduler.scale_model_input(latent_input, timestep) + + noise_pred = self.backbone( + x=latent_input, + x_mask=latent_mask, + timesteps=timestep, + context=content, + context_mask=content_mask, + ) + + # perform guidance + if classifier_free_guidance: + noise_pred_uncond, noise_pred_content = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_content - noise_pred_uncond + ) + if guidance_rescale != 0.0: + noise_pred = self.rescale_cfg( + noise_pred_content, noise_pred, guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latent = scheduler.step(noise_pred, timestep, latent).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and + (i + 1) % scheduler.order == 0): + progress_bar.update(1) + + waveform = self.autoencoder.decode(latent) + + return waveform + + def prepare_latent( + self, batch_size: int, scheduler: SchedulerMixin, + latent_shape: Sequence[int], dtype: torch.dtype, device: str + ): + shape = (batch_size, *latent_shape) + latent = randn_tensor( + shape, generator=None, device=device, dtype=dtype + ) + # scale the initial noise by the standard deviation required by the scheduler + latent = latent * scheduler.init_noise_sigma + return latent + +class SingleTaskCrossAttentionAudioDiffusion(CrossAttentionAudioDiffusion +): + def __init__( + self, + autoencoder: AutoEncoderBase, + content_encoder: ContentEncoder, + backbone: nn.Module, + pretrained_ckpt: str | Path = None, + noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1", + snr_gamma: float = None, + 
cfg_drop_ratio: float = 0.2, + ): + nn.Module.__init__(self) + DiffusionMixin.__init__( + self, noise_scheduler_name, snr_gamma, cfg_drop_ratio + ) + + self.autoencoder = autoencoder + for param in self.autoencoder.parameters(): + param.requires_grad = False + + self.backbone = backbone + if pretrained_ckpt is not None: + pretrained_state_dict = load_file(pretrained_ckpt) + self.load_pretrained(pretrained_state_dict) + + self.content_encoder = content_encoder + #self.content_encoder.audio_encoder.model = self.autoencoder + self.dummy_param = nn.Parameter(torch.empty(0)) + + def forward( + self, content: list[Any], condition: list[Any], task: list[str], waveform: torch.Tensor, + waveform_lengths: torch.Tensor, loss_reduce: bool = True, **kwargs + ): + loss_reduce = self.training or (loss_reduce and not self.training) + device = self.dummy_param.device + num_train_timesteps = self.noise_scheduler.config.num_train_timesteps + self.noise_scheduler.set_timesteps(num_train_timesteps, device=device) + + self.autoencoder.eval() + with torch.no_grad(): + latent, latent_mask = self.autoencoder.encode( + waveform.unsqueeze(1), waveform_lengths + ) + + content_output: dict[ + str, torch.Tensor] = self.content_encoder.encode_content( + content, task, device=device + ) + content, content_mask = content_output["content"], content_output[ + "content_mask"] + + if self.training and self.classifier_free_guidance: + mask_indices = [ + k for k in range(len(waveform)) + if random.random() < self.cfg_drop_ratio + ] + if len(mask_indices) > 0: + content[mask_indices] = 0 + + batch_size = latent.shape[0] + timesteps = self.get_timesteps(batch_size, device, self.training) + noise = torch.randn_like(latent) + noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps) + target = self.get_target(latent, noise, timesteps) + + pred: torch.Tensor = self.backbone( + x=noisy_latent, + timesteps=timesteps, + context=content, + x_mask=latent_mask, + context_mask=content_mask + ) + + pred = pred.transpose(1, self.autoencoder.time_dim) + target = target.transpose(1, self.autoencoder.time_dim) + diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask, loss_reduce=loss_reduce) + + return { + "diff_loss": diff_loss, + } + + @torch.no_grad() + def inference( + self, + content: list[Any], + condition: list[Any], + task: list[str], + scheduler: SchedulerMixin, + latent_shape: Sequence[int], + num_steps: int = 20, + guidance_scale: float = 3.0, + guidance_rescale: float = 0.0, + disable_progress: bool = True, + **kwargs + ): + device = self.dummy_param.device + classifier_free_guidance = guidance_scale > 1.0 + + content_output: dict[ + str, torch.Tensor] = self.content_encoder.encode_content( + content, task, device=device + ) + content, content_mask = content_output["content"], content_output[ + "content_mask"] + batch_size = content.size(0) + + if classifier_free_guidance: + uncond_content = torch.zeros_like(content) + uncond_content_mask = content_mask.detach().clone() + content = torch.cat([uncond_content, content]) + content_mask = torch.cat([uncond_content_mask, content_mask]) + + scheduler.set_timesteps(num_steps, device=device) + timesteps = scheduler.timesteps + + latent = self.prepare_latent( + batch_size, scheduler, latent_shape, content.dtype, device + ) + + num_warmup_steps = len(timesteps) - num_steps * scheduler.order + progress_bar = tqdm(range(num_steps), disable=disable_progress) + + for i, timestep in enumerate(timesteps): + # expand the latent if we are doing classifier free guidance + 
latent_input = torch.cat([latent, latent] + ) if classifier_free_guidance else latent + latent_input = scheduler.scale_model_input(latent_input, timestep) + + noise_pred = self.backbone( + x=latent_input, + timesteps=timestep, + context=content, + context_mask=content_mask, + ) + + # perform guidance + if classifier_free_guidance: + noise_pred_uncond, noise_pred_content = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_content - noise_pred_uncond + ) + if guidance_rescale != 0.0: + noise_pred = self.rescale_cfg( + noise_pred_content, noise_pred, guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latent = scheduler.step(noise_pred, timestep, latent).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and + (i + 1) % scheduler.order == 0): + progress_bar.update(1) + + waveform = self.autoencoder.decode(latent) + + return waveform + + +class DummyContentAudioDiffusion(CrossAttentionAudioDiffusion): + def __init__( + self, + autoencoder: AutoEncoderBase, + content_encoder: ContentEncoder, + content_adapter: ContentAdapterBase, + backbone: nn.Module, + content_dim: int, + frame_resolution: float, + duration_offset: float = 1.0, + noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1", + snr_gamma: float = None, + cfg_drop_ratio: float = 0.2, + ): + """ + Args: + autoencoder: + Pretrained audio autoencoder that encodes raw waveforms into latent + space and decodes latents back to waveforms. + content_encoder: + Module that produces content embeddings (e.g., from text, MIDI, or + other modalities) used to guide the diffusion. + content_adapter (ContentAdapterBase): + Adapter module that fuses task instruction embeddings and content embeddings, + and performs duration prediction for time-aligned tasks. + backbone: + U‑Net or Transformer backbone that performs the core denoising + operations in latent space. + content_dim: + Dimension of the content embeddings produced by the `content_encoder` + and `content_adapter`. + frame_resolution: + Time resolution, in seconds, of each content frame when predicting + duration alignment. Used when calculating duration loss. + duration_offset: + A small positive offset (frame number) added to predicted durations + to ensure numerical stability of log-scaled duration prediction. + noise_scheduler_name: + Identifier of the pretrained noise scheduler to use. + snr_gamma: + Clipping value in min-SNR diffusion loss weighting strategy. + cfg_drop_ratio: + Probability of dropping the content conditioning during training + to support CFG. 
+ """ + super().__init__( + autoencoder=autoencoder, + content_encoder=content_encoder, + content_adapter=content_adapter, + backbone=backbone, + duration_offset=duration_offset, + noise_scheduler_name=noise_scheduler_name, + snr_gamma=snr_gamma, + cfg_drop_ratio=cfg_drop_ratio, + ) + self.frame_resolution = frame_resolution + self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim)) + self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim)) + + def forward( + self, content, duration, task, is_time_aligned, waveform, + waveform_lengths, instruction, instruction_lengths, **kwargs + ): + device = self.dummy_param.device + num_train_timesteps = self.noise_scheduler.config.num_train_timesteps + self.noise_scheduler.set_timesteps(num_train_timesteps, device=device) + + self.autoencoder.eval() + with torch.no_grad(): + latent, latent_mask = self.autoencoder.encode( + waveform.unsqueeze(1), waveform_lengths + ) + + # content: (B, L, E) + content_output: dict[ + str, torch.Tensor] = self.content_encoder.encode_content( + content, task, device=device + ) + length_aligned_content = content_output["length_aligned_content"] + content, content_mask = content_output["content"], content_output[ + "content_mask"] + instruction_mask = create_mask_from_length(instruction_lengths) + + content, content_mask, global_duration_pred, local_duration_pred = \ + self.content_adapter(content, content_mask, instruction, instruction_mask) + + n_frames = torch.round(duration / self.frame_resolution) + local_duration_target = torch.log(n_frames + self.duration_offset) + global_duration_target = torch.log( + latent_mask.sum(1) / self.autoencoder.latent_token_rate + + self.duration_offset + ) + + # truncate unused non time aligned duration prediction + if is_time_aligned.sum() > 0: + trunc_ta_length = content_mask[is_time_aligned].sum(1).max() + else: + trunc_ta_length = content.size(1) + + # local duration loss + local_duration_pred = local_duration_pred[:, :trunc_ta_length] + ta_content_mask = content_mask[:, :trunc_ta_length] + local_duration_target = local_duration_target.to( + dtype=local_duration_pred.dtype + ) + local_duration_loss = loss_with_mask( + (local_duration_target - local_duration_pred)**2, + ta_content_mask, + reduce=False + ) + local_duration_loss *= is_time_aligned + if is_time_aligned.sum().item() == 0: + local_duration_loss *= 0.0 + local_duration_loss = local_duration_loss.mean() + else: + local_duration_loss = local_duration_loss.sum( + ) / is_time_aligned.sum() + + # global duration loss + global_duration_loss = F.mse_loss( + global_duration_target, global_duration_pred + ) + + # -------------------------------------------------------------------- + # prepare latent and diffusion-related noise + # -------------------------------------------------------------------- + + batch_size = latent.shape[0] + timesteps = self.get_timesteps(batch_size, device, self.training) + noise = torch.randn_like(latent) + noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps) + target = self.get_target(latent, noise, timesteps) + + # -------------------------------------------------------------------- + # duration adapter + # -------------------------------------------------------------------- + if is_time_aligned.sum() == 0 and \ + duration.size(1) < content_mask.size(1): + # for non time-aligned tasks like TTA, `duration` is dummy one + duration = F.pad( + duration, (0, content_mask.size(1) - duration.size(1)) + ) + n_latents = torch.round(duration * self.autoencoder.latent_token_rate) + # 
content_mask: [B, L], helper_latent_mask: [B, T] + helper_latent_mask = create_mask_from_length(n_latents.sum(1)).to( + content_mask.device + ) + attn_mask = ta_content_mask.unsqueeze( + -1 + ) * helper_latent_mask.unsqueeze(1) + # attn_mask: [B, L, T] + align_path = create_alignment_path(n_latents, attn_mask) + time_aligned_content = content[:, :trunc_ta_length] + time_aligned_content = torch.matmul( + align_path.transpose(1, 2).to(content.dtype), time_aligned_content + ) # (B, T, L) x (B, L, E) -> (B, T, E) + + # -------------------------------------------------------------------- + # prepare input to the backbone + # -------------------------------------------------------------------- + # TODO compatility for 2D spectrogram VAE + latent_length = noisy_latent.size(self.autoencoder.time_dim) + time_aligned_content = trim_or_pad_length( + time_aligned_content, latent_length, 1 + ) + length_aligned_content = trim_or_pad_length( + length_aligned_content, latent_length, 1 + ) + # time_aligned_content: from monotonic aligned input, without frame expansion (phoneme) + # length_aligned_content: from aligned input (f0/energy) + time_aligned_content = time_aligned_content + length_aligned_content + time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to( + time_aligned_content.dtype + ) + + context = content + context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype) + # only use the first dummy non time aligned embedding + context_mask = content_mask.detach().clone() + context_mask[is_time_aligned, 1:] = False + + # truncate dummy non time aligned context + if is_time_aligned.sum().item() < batch_size: + trunc_nta_length = content_mask[~is_time_aligned].sum(1).max() + else: + trunc_nta_length = content.size(1) + context = context[:, :trunc_nta_length] + context_mask = context_mask[:, :trunc_nta_length] + + # -------------------------------------------------------------------- + # classifier free guidance + # -------------------------------------------------------------------- + if self.training and self.classifier_free_guidance: + mask_indices = [ + k for k in range(len(waveform)) + if random.random() < self.cfg_drop_ratio + ] + if len(mask_indices) > 0: + context[mask_indices] = 0 + time_aligned_content[mask_indices] = 0 + + pred: torch.Tensor = self.backbone( + x=noisy_latent, + timesteps=timesteps, + time_aligned_context=time_aligned_content, + context=context, + x_mask=latent_mask, + context_mask=context_mask + ) + pred = pred.transpose(1, self.autoencoder.time_dim) + target = target.transpose(1, self.autoencoder.time_dim) + diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask) + return { + "diff_loss": diff_loss, + "local_duration_loss": local_duration_loss, + "global_duration_loss": global_duration_loss + } + + @torch.no_grad() + def inference( + self, + content: list[Any], + condition: list[Any], + task: list[str], + is_time_aligned: list[bool], + instruction: torch.Tensor, + instruction_lengths: Sequence[int], + scheduler: SchedulerMixin, + num_steps: int = 20, + guidance_scale: float = 3.0, + guidance_rescale: float = 0.0, + disable_progress: bool = True, + use_gt_duration: bool = False, + **kwargs + ): + device = self.dummy_param.device + classifier_free_guidance = guidance_scale > 1.0 + + content_output: dict[ + str, torch.Tensor] = self.content_encoder.encode_content( + content, task, device=device + ) + length_aligned_content = content_output["length_aligned_content"] + content, content_mask = content_output["content"], content_output[ + "content_mask"] 
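The duration adapter above expands token-level content to frame level by multiplying with a monotonic alignment path. A minimal single-example sketch of that expansion, assuming the path is the usual one-hot token-to-frame map that `create_alignment_path` produces (batch dimension omitted):

```python
import torch

tokens = torch.randn(3, 8)              # (L, E) token embeddings
n_latents = torch.tensor([2, 1, 3])     # frames predicted per token
T = int(n_latents.sum())

path = torch.zeros(3, T)                # (L, T) one-hot alignment path
ends = n_latents.cumsum(0)
for i, (s, e) in enumerate(zip(ends - n_latents, ends)):
    path[i, s:e] = 1.0

frames = path.transpose(0, 1) @ tokens  # (T, L) x (L, E) -> (T, E)
assert torch.equal(frames, tokens.repeat_interleave(n_latents, dim=0))
```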
+ instruction_mask = create_mask_from_length(instruction_lengths) + content, content_mask, global_duration_pred, local_duration_pred = \ + self.content_adapter(content, content_mask, instruction, instruction_mask) + + scheduler.set_timesteps(num_steps, device=device) + timesteps = scheduler.timesteps + batch_size = content.size(0) + + # truncate dummy time aligned duration prediction + is_time_aligned = torch.as_tensor(is_time_aligned) + if is_time_aligned.sum() > 0: + trunc_ta_length = content_mask[is_time_aligned].sum(1).max() + else: + trunc_ta_length = content.size(1) + + # prepare local duration + local_duration_pred = torch.exp(local_duration_pred) * content_mask + local_duration_pred = torch.ceil( + local_duration_pred + ) - self.duration_offset # frame number in `self.frame_resolution` + local_duration_pred = torch.round(local_duration_pred * self.frame_resolution * \ + self.autoencoder.latent_token_rate) + local_duration_pred = local_duration_pred[:, :trunc_ta_length] + # use ground truth duration + if use_gt_duration and "duration" in kwargs: + local_duration_pred = torch.round( + torch.as_tensor(kwargs["duration"]) * + self.autoencoder.latent_token_rate + ).to(device) + + # prepare global duration + global_duration = local_duration_pred.sum(1) + global_duration_pred = torch.exp( + global_duration_pred + ) - self.duration_offset + global_duration_pred *= self.autoencoder.latent_token_rate + global_duration_pred = torch.round(global_duration_pred) + global_duration[~is_time_aligned] = global_duration_pred[ + ~is_time_aligned] + + # -------------------------------------------------------------------- + # duration adapter + # -------------------------------------------------------------------- + time_aligned_content = content[:, :trunc_ta_length] + ta_content_mask = content_mask[:, :trunc_ta_length] + latent_mask = create_mask_from_length(global_duration).to( + content_mask.device + ) + attn_mask = ta_content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1) + # attn_mask: [B, L, T] + align_path = create_alignment_path(local_duration_pred, attn_mask) + time_aligned_content = torch.matmul( + align_path.transpose(1, 2).to(content.dtype), time_aligned_content + ) # (B, T, L) x (B, L, E) -> (B, T, E) + time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to( + time_aligned_content.dtype + ) + + length_aligned_content = trim_or_pad_length( + length_aligned_content, time_aligned_content.size(1), 1 + ) + time_aligned_content = time_aligned_content + length_aligned_content + + # -------------------------------------------------------------------- + # prepare unconditional input + # -------------------------------------------------------------------- + context = content + context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype) + context_mask = content_mask + context_mask[ + is_time_aligned, + 1:] = False # only use the first dummy non time aligned embedding + # truncate dummy non time aligned context + if is_time_aligned.sum().item() < batch_size: + trunc_nta_length = content_mask[~is_time_aligned].sum(1).max() + else: + trunc_nta_length = content.size(1) + context = context[:, :trunc_nta_length] + context_mask = context_mask[:, :trunc_nta_length] + + if classifier_free_guidance: + uncond_time_aligned_content = torch.zeros_like( + time_aligned_content + ) + uncond_context = torch.zeros_like(context) + uncond_context_mask = context_mask.detach().clone() + time_aligned_content = torch.cat([ + uncond_time_aligned_content, time_aligned_content + ]) + context = 
torch.cat([uncond_context, context]) + context_mask = torch.cat([uncond_context_mask, context_mask]) + latent_mask = torch.cat([ + latent_mask, latent_mask.detach().clone() + ]) + + # -------------------------------------------------------------------- + # prepare input to the backbone + # -------------------------------------------------------------------- + latent_shape = tuple( + int(global_duration.max().item()) if dim is None else dim + for dim in self.autoencoder.latent_shape + ) + shape = (batch_size, *latent_shape) + latent = randn_tensor( + shape, generator=None, device=device, dtype=content.dtype + ) + # scale the initial noise by the standard deviation required by the scheduler + latent = latent * scheduler.init_noise_sigma + + num_warmup_steps = len(timesteps) - num_steps * scheduler.order + progress_bar = tqdm(range(num_steps), disable=disable_progress) + # -------------------------------------------------------------------- + # iteratively denoising + # -------------------------------------------------------------------- + for i, timestep in enumerate(timesteps): + # expand the latent if we are doing classifier free guidance + if classifier_free_guidance: + latent_input = torch.cat([latent, latent]) + else: + latent_input = latent + + latent_input = scheduler.scale_model_input(latent_input, timestep) + noise_pred = self.backbone( + x=latent_input, + x_mask=latent_mask, + timesteps=timestep, + time_aligned_context=time_aligned_content, + context=context, + context_mask=context_mask + ) + + if classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) + if guidance_rescale != 0.0: + noise_pred = self.rescale_cfg( + noise_pred_cond, noise_pred, guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latent = scheduler.step(noise_pred, timestep, latent).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and + (i + 1) % scheduler.order == 0): + progress_bar.update(1) + + progress_bar.close() + + # TODO variable length decoding, using `latent_mask` + waveform = self.autoencoder.decode(latent) + return waveform + + +class DoubleContentAudioDiffusion(CrossAttentionAudioDiffusion): + def __init__( + self, + autoencoder: AutoEncoderBase, + content_encoder: ContentEncoder, + content_adapter: nn.Module, + backbone: nn.Module, + content_dim: int, + frame_resolution: float, + duration_offset: float = 1.0, + noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1", + snr_gamma: float = None, + cfg_drop_ratio: float = 0.2, + ): + super().__init__( + autoencoder=autoencoder, + content_encoder=content_encoder, + content_adapter=content_adapter, + backbone=backbone, + duration_offset=duration_offset, + noise_scheduler_name=noise_scheduler_name, + snr_gamma=snr_gamma, + cfg_drop_ratio=cfg_drop_ratio + ) + self.frame_resolution = frame_resolution + + def forward( + self, content, duration, task, is_time_aligned, waveform, + waveform_lengths, instruction, instruction_lengths, **kwargs + ): + device = self.dummy_param.device + num_train_timesteps = self.noise_scheduler.config.num_train_timesteps + self.noise_scheduler.set_timesteps(num_train_timesteps, device=device) + + self.autoencoder.eval() + with torch.no_grad(): + latent, latent_mask = self.autoencoder.encode( + waveform.unsqueeze(1), waveform_lengths + ) + + content_output: dict[ + str, torch.Tensor] = self.content_encoder.encode_content( 
+ content, task, device=device + ) + length_aligned_content = content_output["length_aligned_content"] + content, content_mask = content_output["content"], content_output[ + "content_mask"] + context_mask = content_mask.detach() + instruction_mask = create_mask_from_length(instruction_lengths) + + content, content_mask, global_duration_pred, local_duration_pred = \ + self.content_adapter(content, content_mask, instruction, instruction_mask) + + # TODO if all non time aligned, content length > duration length + + n_frames = torch.round(duration / self.frame_resolution) + local_duration_target = torch.log(n_frames + self.duration_offset) + global_duration_target = torch.log( + latent_mask.sum(1) / self.autoencoder.latent_token_rate + + self.duration_offset + ) + # truncate unused non time aligned duration prediction + if is_time_aligned.sum() > 0: + trunc_ta_length = content_mask[is_time_aligned].sum(1).max() + else: + trunc_ta_length = content.size(1) + # local duration loss + local_duration_pred = local_duration_pred[:, :trunc_ta_length] + ta_content_mask = content_mask[:, :trunc_ta_length] + local_duration_target = local_duration_target.to( + dtype=local_duration_pred.dtype + ) + local_duration_loss = loss_with_mask( + (local_duration_target - local_duration_pred)**2, + ta_content_mask, + reduce=False + ) + local_duration_loss *= is_time_aligned + if is_time_aligned.sum().item() == 0: + local_duration_loss *= 0.0 + local_duration_loss = local_duration_loss.mean() + else: + local_duration_loss = local_duration_loss.sum( + ) / is_time_aligned.sum() + + # global duration loss + global_duration_loss = F.mse_loss( + global_duration_target, global_duration_pred + ) + # -------------------------------------------------------------------- + # prepare latent and diffusion-related noise + # -------------------------------------------------------------------- + batch_size = latent.shape[0] + timesteps = self.get_timesteps(batch_size, device, self.training) + noise = torch.randn_like(latent) + noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps) + target = self.get_target(latent, noise, timesteps) + + # -------------------------------------------------------------------- + # duration adapter + # -------------------------------------------------------------------- + # content_mask: [B, L], helper_latent_mask: [B, T] + if is_time_aligned.sum() == 0 and \ + duration.size(1) < content_mask.size(1): + # for non time-aligned tasks like TTA, `duration` is dummy one + duration = F.pad( + duration, (0, content_mask.size(1) - duration.size(1)) + ) + n_latents = torch.round(duration * self.autoencoder.latent_token_rate) + helper_latent_mask = create_mask_from_length(n_latents.sum(1)).to( + content_mask.device + ) + attn_mask = ta_content_mask.unsqueeze( + -1 + ) * helper_latent_mask.unsqueeze(1) + align_path = create_alignment_path(n_latents, attn_mask) + time_aligned_content = content[:, :trunc_ta_length] + time_aligned_content = torch.matmul( + align_path.transpose(1, 2).to(content.dtype), time_aligned_content + ) + + latent_length = noisy_latent.size(self.autoencoder.time_dim) + time_aligned_content = trim_or_pad_length( + time_aligned_content, latent_length, 1 + ) + length_aligned_content = trim_or_pad_length( + length_aligned_content, latent_length, 1 + ) + time_aligned_content = time_aligned_content + length_aligned_content + context = content + # -------------------------------------------------------------------- + # classifier free guidance + # 
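+        # Illustrative note, not executed: with probability `cfg_drop_ratio`,
+        # the conditioning of a training sample is zeroed (see the block
+        # right below), so the same network also learns the unconditional
+        # score. At sampling time, `inference` recombines the two branches as
+        #
+        #   noise_pred = noise_pred_uncond + guidance_scale * (
+        #       noise_pred_cond - noise_pred_uncond
+        #   )
+        #
+        # the standard classifier-free guidance estimate; when
+        # `guidance_rescale > 0`, an extra `rescale_cfg` step is applied.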
-------------------------------------------------------------------- + if self.training and self.classifier_free_guidance: + mask_indices = [ + k for k in range(len(waveform)) + if random.random() < self.cfg_drop_ratio + ] + if len(mask_indices) > 0: + context[mask_indices] = 0 + time_aligned_content[mask_indices] = 0 + + pred: torch.Tensor = self.backbone( + x=noisy_latent, + timesteps=timesteps, + time_aligned_context=time_aligned_content, + context=context, + x_mask=latent_mask, + context_mask=context_mask, + ) + pred = pred.transpose(1, self.autoencoder.time_dim) + target = target.transpose(1, self.autoencoder.time_dim) + diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask) + return { + "diff_loss": diff_loss, + "local_duration_loss": local_duration_loss, + "global_duration_loss": global_duration_loss, + } + + @torch.no_grad() + def inference( + self, + content: list[Any], + condition: list[Any], + task: list[str], + is_time_aligned: list[bool], + instruction: torch.Tensor, + instruction_lengths: Sequence[int], + scheduler: SchedulerMixin, + num_steps: int = 20, + guidance_scale: float = 3.0, + guidance_rescale: float = 0.0, + disable_progress: bool = True, + use_gt_duration: bool = False, + **kwargs + ): + device = self.dummy_param.device + classifier_free_guidance = guidance_scale > 1.0 + + content_output: dict[ + str, torch.Tensor] = self.content_encoder.encode_content( + content, task, device=device + ) + length_aligned_content = content_output["length_aligned_content"] + content, content_mask = content_output["content"], content_output[ + "content_mask"] + instruction_mask = create_mask_from_length(instruction_lengths) + + content, content_mask, global_duration_pred, local_duration_pred = \ + self.content_adapter(content, content_mask, instruction, instruction_mask) + + scheduler.set_timesteps(num_steps, device=device) + timesteps = scheduler.timesteps + batch_size = content.size(0) + + # truncate dummy time aligned duration prediction + is_time_aligned = torch.as_tensor(is_time_aligned) + if is_time_aligned.sum() > 0: + trunc_ta_length = content_mask[is_time_aligned].sum(1).max() + else: + trunc_ta_length = content.size(1) + + # prepare local duration + local_duration_pred = torch.exp(local_duration_pred) * content_mask + local_duration_pred = torch.ceil( + local_duration_pred + ) - self.duration_offset # frame number in `self.frame_resolution` + local_duration_pred = torch.round(local_duration_pred * self.frame_resolution * \ + self.autoencoder.latent_token_rate) + local_duration_pred = local_duration_pred[:, :trunc_ta_length] + # use ground truth duration + if use_gt_duration and "duration" in kwargs: + local_duration_pred = torch.round( + torch.as_tensor(kwargs["duration"]) * + self.autoencoder.latent_token_rate + ).to(device) + + # prepare global duration + global_duration = local_duration_pred.sum(1) + global_duration_pred = torch.exp( + global_duration_pred + ) - self.duration_offset + global_duration_pred *= self.autoencoder.latent_token_rate + global_duration_pred = torch.round(global_duration_pred) + global_duration[~is_time_aligned] = global_duration_pred[ + ~is_time_aligned] + + # -------------------------------------------------------------------- + # duration adapter + # -------------------------------------------------------------------- + time_aligned_content = content[:, :trunc_ta_length] + ta_content_mask = content_mask[:, :trunc_ta_length] + latent_mask = create_mask_from_length(global_duration).to( + content_mask.device + ) + attn_mask = 
ta_content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1) + # attn_mask: [B, L, T] + align_path = create_alignment_path(local_duration_pred, attn_mask) + time_aligned_content = torch.matmul( + align_path.transpose(1, 2).to(content.dtype), time_aligned_content + ) # (B, T, L) x (B, L, E) -> (B, T, E) + + # time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to( + # time_aligned_content.dtype + # ) + + length_aligned_content = trim_or_pad_length( + length_aligned_content, time_aligned_content.size(1), 1 + ) + time_aligned_content = time_aligned_content + length_aligned_content + + # -------------------------------------------------------------------- + # prepare unconditional input + # -------------------------------------------------------------------- + context = content + # context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype) + context_mask = content_mask + # context_mask[ + # is_time_aligned, + # 1:] = False # only use the first dummy non time aligned embedding + # # truncate dummy non time aligned context + # if is_time_aligned.sum().item() < batch_size: + # trunc_nta_length = content_mask[~is_time_aligned].sum(1).max() + # else: + # trunc_nta_length = content.size(1) + # context = context[:, :trunc_nta_length] + # context_mask = context_mask[:, :trunc_nta_length] + + if classifier_free_guidance: + uncond_time_aligned_content = torch.zeros_like( + time_aligned_content + ) + uncond_context = torch.zeros_like(context) + uncond_context_mask = context_mask.detach().clone() + time_aligned_content = torch.cat([ + uncond_time_aligned_content, time_aligned_content + ]) + context = torch.cat([uncond_context, context]) + context_mask = torch.cat([uncond_context_mask, context_mask]) + latent_mask = torch.cat([ + latent_mask, latent_mask.detach().clone() + ]) + + # -------------------------------------------------------------------- + # prepare input to the backbone + # -------------------------------------------------------------------- + latent_shape = tuple( + int(global_duration.max().item()) if dim is None else dim + for dim in self.autoencoder.latent_shape + ) + shape = (batch_size, *latent_shape) + latent = randn_tensor( + shape, generator=None, device=device, dtype=content.dtype + ) + # scale the initial noise by the standard deviation required by the scheduler + latent = latent * scheduler.init_noise_sigma + + num_warmup_steps = len(timesteps) - num_steps * scheduler.order + progress_bar = tqdm(range(num_steps), disable=disable_progress) + # -------------------------------------------------------------------- + # iteratively denoising + # -------------------------------------------------------------------- + for i, timestep in enumerate(timesteps): + # expand the latent if we are doing classifier free guidance + if classifier_free_guidance: + latent_input = torch.cat([latent, latent]) + else: + latent_input = latent + + latent_input = scheduler.scale_model_input(latent_input, timestep) + noise_pred = self.backbone( + x=latent_input, + x_mask=latent_mask, + timesteps=timestep, + time_aligned_context=time_aligned_content, + context=context, + context_mask=context_mask + ) + + if classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) + if guidance_rescale != 0.0: + noise_pred = self.rescale_cfg( + noise_pred_cond, noise_pred, guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latent = scheduler.step(noise_pred, timestep, 
latent).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and + (i + 1) % scheduler.order == 0): + progress_bar.update(1) + + progress_bar.close() + + # TODO variable length decoding, using `latent_mask` + waveform = self.autoencoder.decode(latent) + return waveform diff --git a/models/dit/__pycache__/attention.cpython-310.pyc b/models/dit/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d4715a10dc05eb3bc4aa9d5c489f6294f4905e3 Binary files /dev/null and b/models/dit/__pycache__/attention.cpython-310.pyc differ diff --git a/models/dit/__pycache__/audio_dit.cpython-310.pyc b/models/dit/__pycache__/audio_dit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..251f919f74fa76c7ed76bafa2300bdc1987913b9 Binary files /dev/null and b/models/dit/__pycache__/audio_dit.cpython-310.pyc differ diff --git a/models/dit/__pycache__/mask_dit.cpython-310.pyc b/models/dit/__pycache__/mask_dit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f3709304a101c69d8d549b045c9bb9dfc8883f1 Binary files /dev/null and b/models/dit/__pycache__/mask_dit.cpython-310.pyc differ diff --git a/models/dit/__pycache__/modules.cpython-310.pyc b/models/dit/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3e698ff4e9a086e1c897ed8c02931a47c331880 Binary files /dev/null and b/models/dit/__pycache__/modules.cpython-310.pyc differ diff --git a/models/dit/__pycache__/rotary.cpython-310.pyc b/models/dit/__pycache__/rotary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9abe202a53a40f634fd99ca83c61d7c05e558b4 Binary files /dev/null and b/models/dit/__pycache__/rotary.cpython-310.pyc differ diff --git a/models/dit/__pycache__/span_mask.cpython-310.pyc b/models/dit/__pycache__/span_mask.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5809b47940f981bc83b994b07fdbf93b27a7bc7d Binary files /dev/null and b/models/dit/__pycache__/span_mask.cpython-310.pyc differ diff --git a/models/dit/attention.py b/models/dit/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..862fb2070e536f344f15565fe67b135050098c57 --- /dev/null +++ b/models/dit/attention.py @@ -0,0 +1,349 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +import einops +from einops import rearrange, repeat +from inspect import isfunction +from .rotary import RotaryEmbedding +from .modules import RMSNorm + +if hasattr(nn.functional, 'scaled_dot_product_attention'): + ATTENTION_MODE = 'flash' +else: + ATTENTION_MODE = 'math' +print(f'attention mode is {ATTENTION_MODE}') + + +def add_mask(sim, mask): + b, ndim = sim.shape[0], mask.ndim + if ndim == 3: + mask = rearrange(mask, "b n m -> b 1 n m") + if ndim == 2: + mask = repeat(mask, "n m -> b 1 n m", b=b) + max_neg_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~mask, max_neg_value) + return sim + + +def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None): + def default(val, d): + return val if val is not None else (d() if isfunction(d) else d) + + b, i, j, device = q_shape[0], q_shape[-2], k_shape[-2], device + q_mask = default( + q_mask, torch.ones((b, i), device=device, dtype=torch.bool) + ) + k_mask = default( + k_mask, torch.ones((b, j), device=device, dtype=torch.bool) + ) + attn_mask = rearrange(q_mask, 'b i -> b 1 i 
1' + ) * rearrange(k_mask, 'b j -> b 1 1 j') + return attn_mask + + +class Attention(nn.Module): + def __init__( + self, + dim, + context_dim=None, + num_heads=8, + qkv_bias=False, + qk_scale=None, + qk_norm=None, + attn_drop=0., + proj_drop=0., + rope_mode='none' + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + if context_dim is None: + self.cross_attn = False + else: + self.cross_attn = True + + context_dim = dim if context_dim is None else context_dim + + self.to_q = nn.Linear(dim, dim, bias=qkv_bias) + self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias) + self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias) + + if qk_norm is None: + self.norm_q = nn.Identity() + self.norm_k = nn.Identity() + elif qk_norm == 'layernorm': + self.norm_q = nn.LayerNorm(head_dim) + self.norm_k = nn.LayerNorm(head_dim) + elif qk_norm == 'rmsnorm': + self.norm_q = RMSNorm(head_dim) + self.norm_k = RMSNorm(head_dim) + else: + raise NotImplementedError + + self.attn_drop_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if self.cross_attn: + assert rope_mode == 'none' + self.rope_mode = rope_mode + if self.rope_mode == 'shared' or self.rope_mode == 'x_only': + self.rotary = RotaryEmbedding(dim=head_dim) + elif self.rope_mode == 'dual': + self.rotary_x = RotaryEmbedding(dim=head_dim) + self.rotary_c = RotaryEmbedding(dim=head_dim) + + def _rotary(self, q, k, extras): + if self.rope_mode == 'shared': + q, k = self.rotary(q=q, k=k) + elif self.rope_mode == 'x_only': + q_x, k_x = self.rotary( + q=q[:, :, extras:, :], k=k[:, :, extras:, :] + ) + q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :] + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'dual': + q_x, k_x = self.rotary_x( + q=q[:, :, extras:, :], k=k[:, :, extras:, :] + ) + q_c, k_c = self.rotary_c( + q=q[:, :, :extras, :], k=k[:, :, :extras, :] + ) + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'none': + pass + else: + raise NotImplementedError + return q, k + + def _attn(self, q, k, v, mask_binary): + if ATTENTION_MODE == 'flash': + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.attn_drop_p, attn_mask=mask_binary + ) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif ATTENTION_MODE == 'math': + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = add_mask( + attn, mask_binary + ) if mask_binary is not None else attn + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + else: + raise NotImplementedError + return x + + def forward(self, x, context=None, context_mask=None, extras=0): + B, L, C = x.shape + if context is None: + context = x + + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + + if context_mask is not None: + mask_binary = create_mask( + x.shape, context.shape, x.device, None, context_mask + ) + else: + mask_binary = None + + q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads) + k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads) + v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads) + + q = self.norm_q(q) + k = self.norm_k(k) + + q, k = self._rotary(q, k, extras) + + x = self._attn(q, k, v, mask_binary) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class JointAttention(nn.Module): + def __init__( + 
self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + qk_norm=None, + attn_drop=0., + proj_drop=0., + rope_mode='none' + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers( + dim, qkv_bias + ) + self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers( + dim, qkv_bias + ) + + self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim) + self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim) + + self.attn_drop_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + + self.proj_x = nn.Linear(dim, dim) + self.proj_drop_x = nn.Dropout(proj_drop) + + self.proj_c = nn.Linear(dim, dim) + self.proj_drop_c = nn.Dropout(proj_drop) + + self.rope_mode = rope_mode + if self.rope_mode == 'shared' or self.rope_mode == 'x_only': + self.rotary = RotaryEmbedding(dim=head_dim) + elif self.rope_mode == 'dual': + self.rotary_x = RotaryEmbedding(dim=head_dim) + self.rotary_c = RotaryEmbedding(dim=head_dim) + + def _make_qkv_layers(self, dim, qkv_bias): + return ( + nn.Linear(dim, dim, + bias=qkv_bias), nn.Linear(dim, dim, bias=qkv_bias), + nn.Linear(dim, dim, bias=qkv_bias) + ) + + def _make_norm_layers(self, qk_norm, head_dim): + if qk_norm is None: + norm_q = nn.Identity() + norm_k = nn.Identity() + elif qk_norm == 'layernorm': + norm_q = nn.LayerNorm(head_dim) + norm_k = nn.LayerNorm(head_dim) + elif qk_norm == 'rmsnorm': + norm_q = RMSNorm(head_dim) + norm_k = RMSNorm(head_dim) + else: + raise NotImplementedError + return norm_q, norm_k + + def _rotary(self, q, k, extras): + if self.rope_mode == 'shared': + q, k = self.rotary(q=q, k=k) + elif self.rope_mode == 'x_only': + q_x, k_x = self.rotary( + q=q[:, :, extras:, :], k=k[:, :, extras:, :] + ) + q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :] + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'dual': + q_x, k_x = self.rotary_x( + q=q[:, :, extras:, :], k=k[:, :, extras:, :] + ) + q_c, k_c = self.rotary_c( + q=q[:, :, :extras, :], k=k[:, :, :extras, :] + ) + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'none': + pass + else: + raise NotImplementedError + return q, k + + def _attn(self, q, k, v, mask_binary): + if ATTENTION_MODE == 'flash': + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.attn_drop_p, attn_mask=mask_binary + ) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif ATTENTION_MODE == 'math': + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = add_mask( + attn, mask_binary + ) if mask_binary is not None else attn + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + else: + raise NotImplementedError + return x + + def _cat_mask(self, x, context, x_mask=None, context_mask=None): + B = x.shape[0] + if x_mask is None: + x_mask = torch.ones(B, x.shape[-2], device=x.device).bool() + if context_mask is None: + context_mask = torch.ones( + B, context.shape[-2], device=context.device + ).bool() + mask = torch.cat([context_mask, x_mask], dim=1) + return mask + + def forward(self, x, context, x_mask=None, context_mask=None, extras=0): + B, Lx, C = x.shape + _, Lc, _ = context.shape + if x_mask is not None or context_mask is not None: + mask = self._cat_mask( + x, context, x_mask=x_mask, context_mask=context_mask + ) + shape = [B, Lx + Lc, C] + mask_binary = create_mask( + 
q_shape=shape, + k_shape=shape, + device=x.device, + q_mask=None, + k_mask=mask + ) + else: + mask_binary = None + + qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x) + qc, kc, vc = self.to_qc(context), self.to_kc(context + ), self.to_vc(context) + + qx, kx, vx = map( + lambda t: einops. + rearrange(t, 'B L (H D) -> B H L D', H=self.num_heads), + [qx, kx, vx] + ) + qc, kc, vc = map( + lambda t: einops. + rearrange(t, 'B L (H D) -> B H L D', H=self.num_heads), + [qc, kc, vc] + ) + + qx, kx = self.norm_qx(qx), self.norm_kx(kx) + qc, kc = self.norm_qc(qc), self.norm_kc(kc) + + q, k, v = ( + torch.cat([qc, qx], + dim=2), torch.cat([kc, kx], + dim=2), torch.cat([vc, vx], dim=2) + ) + + q, k = self._rotary(q, k, extras) + + x = self._attn(q, k, v, mask_binary) + + context, x = x[:, :Lc, :], x[:, Lc:, :] + + x = self.proj_x(x) + x = self.proj_drop_x(x) + + context = self.proj_c(context) + context = self.proj_drop_c(context) + + return x, context diff --git a/models/dit/audio_diffsingernet_dit.py b/models/dit/audio_diffsingernet_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5facb5c2316a04f0010a477ce0b6d7268d043a --- /dev/null +++ b/models/dit/audio_diffsingernet_dit.py @@ -0,0 +1,520 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from .mask_dit import DiTBlock, FinalBlock, UDiT +from .modules import ( + film_modulate, + PatchEmbed, + PE_wrapper, + TimestepEmbedder, + RMSNorm, +) + + +class AudioDiTBlock(DiTBlock): + """ + A modified DiT block with time_aligned_context add to latent. + """ + def __init__( + self, + dim, + time_aligned_context_dim, + dilation, + context_dim=None, + num_heads=8, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + qk_norm=None, + act_layer='gelu', + norm_layer=nn.LayerNorm, + time_fusion='none', + ada_sola_rank=None, + ada_sola_alpha=None, + skip=False, + skip_norm=False, + rope_mode='none', + context_norm=False, + use_checkpoint=False + ): + super().__init__( + dim=dim, + context_dim=context_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + qk_norm=qk_norm, + act_layer=act_layer, + norm_layer=norm_layer, + time_fusion=time_fusion, + ada_sola_rank=ada_sola_rank, + ada_sola_alpha=ada_sola_alpha, + skip=skip, + skip_norm=skip_norm, + rope_mode=rope_mode, + context_norm=context_norm, + use_checkpoint=use_checkpoint + ) + # time-aligned context projection + self.ta_context_projection = nn.Linear( + time_aligned_context_dim, 2 * dim + ) + self.dilated_conv = nn.Conv1d( + dim, 2 * dim, kernel_size=3, padding=dilation, dilation=dilation + ) + + def forward( + self, + x, + time_aligned_context, + time_token=None, + time_ada=None, + skip=None, + context=None, + x_mask=None, + context_mask=None, + extras=None + ): + if self.use_checkpoint: + return checkpoint( + self._forward, + x, + time_aligned_context, + time_token, + time_ada, + skip, + context, + x_mask, + context_mask, + extras, + use_reentrant=False + ) + else: + return self._forward( + x, + time_aligned_context, + time_token, + time_ada, + skip, + context, + x_mask, + context_mask, + extras, + ) + + def _forward( + self, + x, + time_aligned_context, + time_token=None, + time_ada=None, + skip=None, + context=None, + x_mask=None, + context_mask=None, + extras=None + ): + B, T, C = x.shape + if self.skip_linear is not None: + assert skip is not None + cat = torch.cat([x, skip], dim=-1) + cat = self.skip_norm(cat) + x = self.skip_linear(cat) + + if self.use_adanorm: + time_ada = 
self.adaln(time_token, time_ada) + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, + gate_mlp) = time_ada.chunk(6, dim=1) + + # self attention + if self.use_adanorm: + x_norm = film_modulate( + self.norm1(x), shift=shift_msa, scale=scale_msa + ) + x = x + (1-gate_msa) * self.attn( + x_norm, context=None, context_mask=x_mask, extras=extras + ) + else: + # TODO diffusion timestep input is not fused here + x = x + self.attn( + self.norm1(x), + context=None, + context_mask=x_mask, + extras=extras + ) + + # time-aligned context + time_aligned_context = self.ta_context_projection(time_aligned_context) + x = self.dilated_conv(x.transpose(1, 2) + ).transpose(1, 2) + time_aligned_context + + gate, filter = torch.chunk(x, 2, dim=-1) + x = torch.sigmoid(gate) * torch.tanh(filter) + + # cross attention + if self.use_context: + assert context is not None + x = x + self.cross_attn( + x=self.norm2(x), + context=self.norm_context(context), + context_mask=context_mask, + extras=extras + ) + + # mlp + if self.use_adanorm: + x_norm = film_modulate( + self.norm3(x), shift=shift_mlp, scale=scale_mlp + ) + x = x + (1-gate_mlp) * self.mlp(x_norm) + else: + x = x + self.mlp(self.norm3(x)) + + return x + + +class AudioUDiT(UDiT): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + input_type='2d', + out_chans=None, + embed_dim=768, + depth=12, + dilation_cycle_length=4, + num_heads=12, + mlp_ratio=4, + qkv_bias=False, + qk_scale=None, + qk_norm=None, + act_layer='gelu', + norm_layer='layernorm', + context_norm=False, + use_checkpoint=False, + time_fusion='token', + ada_sola_rank=None, + ada_sola_alpha=None, + cls_dim=None, + time_aligned_context_dim=768, + context_dim=768, + context_fusion='concat', + context_max_length=128, + context_pe_method='sinu', + pe_method='abs', + rope_mode='none', + use_conv=True, + skip=True, + skip_norm=True + ): + nn.Module.__init__(self) + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + # input + self.in_chans = in_chans + self.input_type = input_type + if self.input_type == '2d': + num_patches = (img_size[0] // + patch_size) * (img_size[1] // patch_size) + elif self.input_type == '1d': + num_patches = img_size // patch_size + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + input_type=input_type + ) + out_chans = in_chans if out_chans is None else out_chans + self.out_chans = out_chans + + # position embedding + self.rope = rope_mode + self.x_pe = PE_wrapper( + dim=embed_dim, method=pe_method, length=num_patches + ) + + # time embed + self.time_embed = TimestepEmbedder(embed_dim) + self.time_fusion = time_fusion + self.use_adanorm = False + + # cls embed + if cls_dim is not None: + self.cls_embed = nn.Sequential( + nn.Linear(cls_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True), + ) + else: + self.cls_embed = None + + # time fusion + if time_fusion == 'token': + # put token at the beginning of sequence + self.extras = 2 if self.cls_embed else 1 + self.time_pe = PE_wrapper( + dim=embed_dim, method='abs', length=self.extras + ) + elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']: + self.use_adanorm = True + # aviod repetitive silu for each adaln block + self.time_act = nn.SiLU() + self.extras = 0 + self.time_ada_final = nn.Linear( + embed_dim, 2 * embed_dim, bias=True + ) + if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']: + # shared adaln + self.time_ada = nn.Linear(embed_dim, 6 * 
embed_dim, bias=True) + else: + self.time_ada = None + else: + raise NotImplementedError + + # context + # use a simple projection + self.use_context = False + self.context_cross = False + self.context_max_length = context_max_length + self.context_fusion = 'none' + if context_dim is not None: + self.use_context = True + self.context_embed = nn.Sequential( + nn.Linear(context_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True), + ) + self.context_fusion = context_fusion + if context_fusion == 'concat' or context_fusion == 'joint': + self.extras += context_max_length + self.context_pe = PE_wrapper( + dim=embed_dim, + method=context_pe_method, + length=context_max_length + ) + # no cross attention layers + context_dim = None + elif context_fusion == 'cross': + self.context_pe = PE_wrapper( + dim=embed_dim, + method=context_pe_method, + length=context_max_length + ) + self.context_cross = True + context_dim = embed_dim + else: + raise NotImplementedError + + self.use_skip = skip + + # norm layers + if norm_layer == 'layernorm': + norm_layer = nn.LayerNorm + elif norm_layer == 'rmsnorm': + norm_layer = RMSNorm + else: + raise NotImplementedError + + self.in_blocks = nn.ModuleList([ + AudioDiTBlock( + dim=embed_dim, + time_aligned_context_dim=time_aligned_context_dim, + dilation=2**(i % dilation_cycle_length), + context_dim=context_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + qk_norm=qk_norm, + act_layer=act_layer, + norm_layer=norm_layer, + time_fusion=time_fusion, + ada_sola_rank=ada_sola_rank, + ada_sola_alpha=ada_sola_alpha, + skip=False, + skip_norm=False, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint + ) for i in range(depth // 2) + ]) + + self.mid_block = AudioDiTBlock( + dim=embed_dim, + time_aligned_context_dim=time_aligned_context_dim, + dilation=1, + context_dim=context_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + qk_norm=qk_norm, + act_layer=act_layer, + norm_layer=norm_layer, + time_fusion=time_fusion, + ada_sola_rank=ada_sola_rank, + ada_sola_alpha=ada_sola_alpha, + skip=False, + skip_norm=False, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint + ) + + self.out_blocks = nn.ModuleList([ + AudioDiTBlock( + dim=embed_dim, + time_aligned_context_dim=time_aligned_context_dim, + dilation=2**(i % dilation_cycle_length), + context_dim=context_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + qk_norm=qk_norm, + act_layer=act_layer, + norm_layer=norm_layer, + time_fusion=time_fusion, + ada_sola_rank=ada_sola_rank, + ada_sola_alpha=ada_sola_alpha, + skip=skip, + skip_norm=skip_norm, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint + ) for i in range(depth // 2) + ]) + + # FinalLayer block + self.use_conv = use_conv + self.final_block = FinalBlock( + embed_dim=embed_dim, + patch_size=patch_size, + img_size=img_size, + in_chans=out_chans, + input_type=input_type, + norm_layer=norm_layer, + use_conv=use_conv, + use_adanorm=self.use_adanorm + ) + self.initialize_weights() + + def forward( + self, + x, + timesteps, + time_aligned_context, + context, + x_mask=None, + context_mask=None, + cls_token=None, + controlnet_skips=None, + ): + # make it compatible with int time step during inference + if timesteps.dim() == 0: + timesteps = timesteps.expand(x.shape[0] + ).to(x.device, dtype=torch.long) + + x = 
self.patch_embed(x) + x = self.x_pe(x) + + B, L, D = x.shape + + if self.use_context: + context_token = self.context_embed(context) + context_token = self.context_pe(context_token) + if self.context_fusion == 'concat' or self.context_fusion == 'joint': + x, x_mask = self._concat_x_context( + x=x, + context=context_token, + x_mask=x_mask, + context_mask=context_mask + ) + context_token, context_mask = None, None + else: + context_token, context_mask = None, None + + time_token = self.time_embed(timesteps) + if self.cls_embed: + cls_token = self.cls_embed(cls_token) + time_ada = None + time_ada_final = None + if self.use_adanorm: + if self.cls_embed: + time_token = time_token + cls_token + time_token = self.time_act(time_token) + time_ada_final = self.time_ada_final(time_token) + if self.time_ada is not None: + time_ada = self.time_ada(time_token) + else: + time_token = time_token.unsqueeze(dim=1) + if self.cls_embed: + cls_token = cls_token.unsqueeze(dim=1) + time_token = torch.cat([time_token, cls_token], dim=1) + time_token = self.time_pe(time_token) + x = torch.cat((time_token, x), dim=1) + if x_mask is not None: + x_mask = torch.cat([ + torch.ones(B, time_token.shape[1], + device=x_mask.device).bool(), x_mask + ], + dim=1) + time_token = None + + skips = [] + for blk in self.in_blocks: + x = blk( + x=x, + time_aligned_context=time_aligned_context, + time_token=time_token, + time_ada=time_ada, + skip=None, + context=context_token, + x_mask=x_mask, + context_mask=context_mask, + extras=self.extras + ) + if self.use_skip: + skips.append(x) + + x = self.mid_block( + x=x, + time_aligned_context=time_aligned_context, + time_token=time_token, + time_ada=time_ada, + skip=None, + context=context_token, + x_mask=x_mask, + context_mask=context_mask, + extras=self.extras + ) + for blk in self.out_blocks: + if self.use_skip: + skip = skips.pop() + if controlnet_skips: + # add to skip like u-net controlnet + skip = skip + controlnet_skips.pop() + else: + skip = None + if controlnet_skips: + # directly add to x + x = x + controlnet_skips.pop() + + x = blk( + x=x, + time_aligned_context=time_aligned_context, + time_token=time_token, + time_ada=time_ada, + skip=skip, + context=context_token, + x_mask=x_mask, + context_mask=context_mask, + extras=self.extras + ) + + x = self.final_block(x, time_ada=time_ada_final, extras=self.extras) + + return x diff --git a/models/dit/audio_dit.py b/models/dit/audio_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..5151857f9756bd50ba6aa0bc9ca4f798345c0b98 --- /dev/null +++ b/models/dit/audio_dit.py @@ -0,0 +1,652 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from .mask_dit import DiTBlock, FinalBlock, UDiT +from .modules import ( + film_modulate, + PatchEmbed, + PE_wrapper, + TimestepEmbedder, + RMSNorm, +) + + +class LayerFusionDiTBlock(DiTBlock): + """ + A modified DiT block with time aligned context add to latent. 
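+
+    Two fusion modes are supported (shapes inferred from the code below,
+    with x: [B, T, dim] and time_aligned_context: [B, T', ta_context_dim]):
+
+      * ta_context_fusion="add":    x = x + Linear(ta_dim -> dim)(Norm(ctx))
+      * ta_context_fusion="concat": x = Linear(ta_dim + dim -> dim)(Norm([x; ctx]))
+
+    When T' < T, the context is padded with one leading frame before fusion.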
+ """ + def __init__( + self, + dim, + ta_context_dim, + ta_context_norm=False, + context_dim=None, + num_heads=8, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + qk_norm=None, + act_layer='gelu', + norm_layer=nn.LayerNorm, + ta_context_fusion='add', + time_fusion='none', + ada_sola_rank=None, + ada_sola_alpha=None, + skip=False, + skip_norm=False, + rope_mode='none', + context_norm=False, + use_checkpoint=False + ): + super().__init__( + dim=dim, + context_dim=context_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + qk_norm=qk_norm, + act_layer=act_layer, + norm_layer=norm_layer, + time_fusion=time_fusion, + ada_sola_rank=ada_sola_rank, + ada_sola_alpha=ada_sola_alpha, + skip=skip, + skip_norm=skip_norm, + rope_mode=rope_mode, + context_norm=context_norm, + use_checkpoint=use_checkpoint + ) + self.ta_context_fusion = ta_context_fusion + self.ta_context_norm = ta_context_norm + if self.ta_context_fusion == "add": + self.ta_context_projection = nn.Linear( + ta_context_dim, dim, bias=False + ) + self.ta_context_norm = norm_layer( + ta_context_dim + ) if self.ta_context_norm else nn.Identity() + elif self.ta_context_fusion == "concat": + self.ta_context_projection = nn.Linear(ta_context_dim + dim, dim) + self.ta_context_norm = norm_layer( + ta_context_dim + dim + ) if self.ta_context_norm else nn.Identity() + + def forward( + self, + x, + time_aligned_context, + time_token=None, + time_ada=None, + skip=None, + context=None, + x_mask=None, + context_mask=None, + extras=None + ): + if self.use_checkpoint: + return checkpoint( + self._forward, + x, + time_aligned_context, + time_token, + time_ada, + skip, + context, + x_mask, + context_mask, + extras, + use_reentrant=False + ) + else: + return self._forward( + x, + time_aligned_context, + time_token, + time_ada, + skip, + context, + x_mask, + context_mask, + extras, + ) + + def _forward( + self, + x, + time_aligned_context, + time_token=None, + time_ada=None, + skip=None, + context=None, + x_mask=None, + context_mask=None, + extras=None + ): + B, T, C = x.shape + + # skip connection + if self.skip_linear is not None: + assert skip is not None + cat = torch.cat([x, skip], dim=-1) + cat = self.skip_norm(cat) + x = self.skip_linear(cat) + + if self.use_adanorm: + time_ada = self.adaln(time_token, time_ada) + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, + gate_mlp) = time_ada.chunk(6, dim=1) + + # self attention + if self.use_adanorm: + x_norm = film_modulate( + self.norm1(x), shift=shift_msa, scale=scale_msa + ) + tanh_gate_msa = torch.tanh(1 - gate_msa) + x = x + tanh_gate_msa * self.attn( + x_norm, context=None, context_mask=x_mask, extras=extras + ) + # x = x + (1 - gate_msa) * self.attn( + # x_norm, context=None, context_mask=x_mask, extras=extras + # ) + else: + # TODO diffusion timestep input is not fused here + x = x + self.attn( + self.norm1(x), + context=None, + context_mask=x_mask, + extras=extras + ) + + # time aligned context fusion + if self.ta_context_fusion == "add": + time_aligned_context = self.ta_context_projection( + self.ta_context_norm(time_aligned_context) + ) + if time_aligned_context.size(1) < x.size(1): + time_aligned_context = nn.functional.pad( + time_aligned_context, (0, 0, 1, 0) + ) + x = x + time_aligned_context + elif self.ta_context_fusion == "concat": + if time_aligned_context.size(1) < x.size(1): + time_aligned_context = nn.functional.pad( + time_aligned_context, (0, 0, 1, 0) + ) + cat = torch.cat([x, time_aligned_context], dim=-1) + cat = 
self.ta_context_norm(cat) + x = self.ta_context_projection(cat) + + # cross attention + if self.use_context: + assert context is not None + x = x + self.cross_attn( + x=self.norm2(x), + context=self.norm_context(context), + context_mask=context_mask, + extras=extras + ) + + # mlp + if self.use_adanorm: + x_norm = film_modulate( + self.norm3(x), shift=shift_mlp, scale=scale_mlp + ) + x = x + (1 - gate_mlp) * self.mlp(x_norm) + else: + x = x + self.mlp(self.norm3(x)) + + return x + + +class LayerFusionAudioDiT(UDiT): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + input_type='2d', + out_chans=None, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=False, + qk_scale=None, + qk_norm=None, + act_layer='gelu', + norm_layer='layernorm', + context_norm=False, + use_checkpoint=False, + time_fusion='token', + ada_sola_rank=None, + ada_sola_alpha=None, + cls_dim=None, + ta_context_dim=768, + ta_context_fusion='concat', + ta_context_norm=True, + context_dim=768, + context_fusion='concat', + context_max_length=128, + context_pe_method='sinu', + pe_method='abs', + rope_mode='none', + use_conv=True, + skip=True, + skip_norm=True + ): + nn.Module.__init__(self) + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + # input + self.in_chans = in_chans + self.input_type = input_type + if self.input_type == '2d': + num_patches = (img_size[0] // + patch_size) * (img_size[1] // patch_size) + elif self.input_type == '1d': + num_patches = img_size // patch_size + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + input_type=input_type + ) + out_chans = in_chans if out_chans is None else out_chans + self.out_chans = out_chans + + # position embedding + self.rope = rope_mode + self.x_pe = PE_wrapper( + dim=embed_dim, method=pe_method, length=num_patches + ) + + # time embed + self.time_embed = TimestepEmbedder(embed_dim) + self.time_fusion = time_fusion + self.use_adanorm = False + + # cls embed + if cls_dim is not None: + self.cls_embed = nn.Sequential( + nn.Linear(cls_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True), + ) + else: + self.cls_embed = None + + # time fusion + if time_fusion == 'token': + # put token at the beginning of sequence + self.extras = 2 if self.cls_embed else 1 + self.time_pe = PE_wrapper( + dim=embed_dim, method='abs', length=self.extras + ) + elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']: + self.use_adanorm = True + # aviod repetitive silu for each adaln block + self.time_act = nn.SiLU() + self.extras = 0 + self.time_ada_final = nn.Linear( + embed_dim, 2 * embed_dim, bias=True + ) + if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']: + # shared adaln + self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True) + else: + self.time_ada = None + else: + raise NotImplementedError + + # context + # use a simple projection + self.use_context = False + self.context_cross = False + self.context_max_length = context_max_length + self.context_fusion = 'none' + if context_dim is not None: + self.use_context = True + self.context_embed = nn.Sequential( + nn.Linear(context_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True), + ) + self.context_fusion = context_fusion + if context_fusion == 'concat' or context_fusion == 'joint': + self.extras += context_max_length + self.context_pe = PE_wrapper( + dim=embed_dim, + method=context_pe_method, + 
length=context_max_length + ) + # no cross attention layers + context_dim = None + elif context_fusion == 'cross': + self.context_pe = PE_wrapper( + dim=embed_dim, + method=context_pe_method, + length=context_max_length + ) + self.context_cross = True + context_dim = embed_dim + else: + raise NotImplementedError + + self.use_skip = skip + + # norm layers + if norm_layer == 'layernorm': + norm_layer = nn.LayerNorm + elif norm_layer == 'rmsnorm': + norm_layer = RMSNorm + else: + raise NotImplementedError + + self.in_blocks = nn.ModuleList([ + LayerFusionDiTBlock( + dim=embed_dim, + ta_context_dim=ta_context_dim, + ta_context_fusion=ta_context_fusion, + ta_context_norm=ta_context_norm, + context_dim=context_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + qk_norm=qk_norm, + act_layer=act_layer, + norm_layer=norm_layer, + time_fusion=time_fusion, + ada_sola_rank=ada_sola_rank, + ada_sola_alpha=ada_sola_alpha, + skip=False, + skip_norm=False, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint + ) for i in range(depth // 2) + ]) + + self.mid_block = LayerFusionDiTBlock( + dim=embed_dim, + ta_context_dim=ta_context_dim, + context_dim=context_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + qk_norm=qk_norm, + act_layer=act_layer, + norm_layer=norm_layer, + time_fusion=time_fusion, + ada_sola_rank=ada_sola_rank, + ada_sola_alpha=ada_sola_alpha, + ta_context_fusion=ta_context_fusion, + ta_context_norm=ta_context_norm, + skip=False, + skip_norm=False, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint + ) + + self.out_blocks = nn.ModuleList([ + LayerFusionDiTBlock( + dim=embed_dim, + ta_context_dim=ta_context_dim, + context_dim=context_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + qk_norm=qk_norm, + act_layer=act_layer, + norm_layer=norm_layer, + time_fusion=time_fusion, + ada_sola_rank=ada_sola_rank, + ada_sola_alpha=ada_sola_alpha, + ta_context_fusion=ta_context_fusion, + ta_context_norm=ta_context_norm, + skip=skip, + skip_norm=skip_norm, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint + ) for i in range(depth // 2) + ]) + + # FinalLayer block + self.use_conv = use_conv + self.final_block = FinalBlock( + embed_dim=embed_dim, + patch_size=patch_size, + img_size=img_size, + in_chans=out_chans, + input_type=input_type, + norm_layer=norm_layer, + use_conv=use_conv, + use_adanorm=self.use_adanorm + ) + self.initialize_weights() + + def forward( + self, + x, + timesteps, + time_aligned_context, + context, + x_mask=None, + context_mask=None, + cls_token=None, + controlnet_skips=None, + ): + # make it compatible with int time step during inference + if timesteps.dim() == 0: + timesteps = timesteps.expand(x.shape[0] + ).to(x.device, dtype=torch.long) + + x = self.patch_embed(x) + x = self.x_pe(x) + + B, L, D = x.shape + + if self.use_context: + context_token = self.context_embed(context) + context_token = self.context_pe(context_token) + if self.context_fusion == 'concat' or self.context_fusion == 'joint': + x, x_mask = self._concat_x_context( + x=x, + context=context_token, + x_mask=x_mask, + context_mask=context_mask + ) + context_token, context_mask = None, None + else: + context_token, context_mask = None, None + + time_token = self.time_embed(timesteps) + if self.cls_embed: + cls_token = self.cls_embed(cls_token) + time_ada = None + time_ada_final = None + if 
self.use_adanorm: + if self.cls_embed: + time_token = time_token + cls_token + time_token = self.time_act(time_token) + time_ada_final = self.time_ada_final(time_token) + if self.time_ada is not None: + time_ada = self.time_ada(time_token) + else: + time_token = time_token.unsqueeze(dim=1) + if self.cls_embed: + cls_token = cls_token.unsqueeze(dim=1) + time_token = torch.cat([time_token, cls_token], dim=1) + time_token = self.time_pe(time_token) + x = torch.cat((time_token, x), dim=1) + if x_mask is not None: + x_mask = torch.cat([ + torch.ones(B, time_token.shape[1], + device=x_mask.device).bool(), x_mask + ], + dim=1) + time_token = None + + skips = [] + for blk_idx, blk in enumerate(self.in_blocks): + x = blk( + x=x, + time_aligned_context=time_aligned_context, + time_token=time_token, + time_ada=time_ada, + skip=None, + context=context_token, + x_mask=x_mask, + context_mask=context_mask, + extras=self.extras + ) + # if not self.training: + # print( + # f"in block {blk_idx}, min: {x.min().item()}, max: {x.max().item()}, std: {x.std().item()}" + # ) + if self.use_skip: + skips.append(x) + + x = self.mid_block( + x=x, + time_aligned_context=time_aligned_context, + time_token=time_token, + time_ada=time_ada, + skip=None, + context=context_token, + x_mask=x_mask, + context_mask=context_mask, + extras=self.extras + ) + for blk_idx, blk in enumerate(self.out_blocks): + if self.use_skip: + skip = skips.pop() + if controlnet_skips: + # add to skip like u-net controlnet + skip = skip + controlnet_skips.pop() + else: + skip = None + if controlnet_skips: + # directly add to x + x = x + controlnet_skips.pop() + + x = blk( + x=x, + time_aligned_context=time_aligned_context, + time_token=time_token, + time_ada=time_ada, + skip=skip, + context=context_token, + x_mask=x_mask, + context_mask=context_mask, + extras=self.extras + ) + # if not self.training: + # print( + # f"out block {blk_idx}, min: {x.min().item()}, max: {x.max().item()}, std: {x.std().item()}" + # ) + + x = self.final_block(x, time_ada=time_ada_final, extras=self.extras) + + return x + + +class InputFusionAudioDiT(UDiT): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + input_type='2d', + out_chans=None, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=False, + qk_scale=None, + qk_norm=None, + act_layer='gelu', + norm_layer='layernorm', + context_norm=False, + use_checkpoint=False, + time_fusion='token', + ada_sola_rank=None, + ada_sola_alpha=None, + cls_dim=None, + ta_context_dim=768, + context_dim=768, + context_fusion='concat', + context_max_length=128, + context_pe_method='sinu', + pe_method='abs', + rope_mode='none', + use_conv=True, + skip=True, + skip_norm=True + ): + super().__init__( + img_size, + patch_size, + in_chans, + input_type, + out_chans, + embed_dim, + depth, + num_heads, + mlp_ratio, + qkv_bias, + qk_scale, + qk_norm, + act_layer, + norm_layer, + context_norm, + use_checkpoint, + time_fusion, + ada_sola_rank, + ada_sola_alpha, + cls_dim, + context_dim, + context_fusion, + context_max_length, + context_pe_method, + pe_method, + rope_mode, + use_conv, + skip, + skip_norm, + ) + self.input_proj = nn.Linear(in_chans + ta_context_dim, in_chans) + nn.init.xavier_uniform_(self.input_proj.weight) + nn.init.constant_(self.input_proj.bias, 0) + + def forward( + self, + x, + timesteps, + time_aligned_context, + context, + x_mask=None, + context_mask=None, + cls_token=None, + controlnet_skips=None + ): + x = self.input_proj( + torch.cat([x.transpose(1, 2), time_aligned_context], 
dim=-1) + ) + x = x.transpose(1, 2) + return super().forward( + x=x, + timesteps=timesteps, + context=context, + x_mask=x_mask, + context_mask=context_mask, + cls_token=cls_token, + controlnet_skips=controlnet_skips + ) diff --git a/models/dit/mask_dit.py b/models/dit/mask_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..5c86260669d9ed0a9c4018069e55125d026894fb --- /dev/null +++ b/models/dit/mask_dit.py @@ -0,0 +1,823 @@ +import logging +import math +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from .modules import ( + film_modulate, + unpatchify, + PatchEmbed, + PE_wrapper, + TimestepEmbedder, + FeedForward, + RMSNorm, +) +from .span_mask import compute_mask_indices +from .attention import Attention + +logger = logging.Logger(__file__) + + +class AdaLN(nn.Module): + def __init__(self, dim, ada_mode='ada', r=None, alpha=None): + super().__init__() + self.ada_mode = ada_mode + self.scale_shift_table = None + if ada_mode == 'ada': + # move nn.silu outside + self.time_ada = nn.Linear(dim, 6 * dim, bias=True) + elif ada_mode == 'ada_single': + # adaln used in pixel-art alpha + self.scale_shift_table = nn.Parameter(torch.zeros(6, dim)) + elif ada_mode in ['ada_sola', 'ada_sola_bias']: + self.lora_a = nn.Linear(dim, r * 6, bias=False) + self.lora_b = nn.Linear(r * 6, dim * 6, bias=False) + self.scaling = alpha / r + if ada_mode == 'ada_sola_bias': + # take bias out for consistency + self.scale_shift_table = nn.Parameter(torch.zeros(6, dim)) + else: + raise NotImplementedError + + def forward(self, time_token=None, time_ada=None): + if self.ada_mode == 'ada': + assert time_ada is None + B = time_token.shape[0] + time_ada = self.time_ada(time_token).reshape(B, 6, -1) + elif self.ada_mode == 'ada_single': + B = time_ada.shape[0] + time_ada = time_ada.reshape(B, 6, -1) + time_ada = self.scale_shift_table[None] + time_ada + elif self.ada_mode in ['ada_sola', 'ada_sola_bias']: + B = time_ada.shape[0] + time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling + time_ada = time_ada + time_ada_lora + time_ada = time_ada.reshape(B, 6, -1) + if self.scale_shift_table is not None: + time_ada = self.scale_shift_table[None] + time_ada + else: + raise NotImplementedError + return time_ada + + +class DiTBlock(nn.Module): + """ + A modified PixArt block with adaptive layer norm (adaLN-single) conditioning. 
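+
+    Conditioning sketch, inferred from `_forward` below (`film_modulate`
+    is assumed to apply x * (1 + scale) + shift):
+
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
+            time_ada.chunk(6, dim=1)
+        x = x + (1 - gate_msa) * attn(film_modulate(norm1(x), shift_msa, scale_msa))
+        x = x + cross_attn(norm2(x), context)   # only when context_dim is given
+        x = x + (1 - gate_mlp) * mlp(film_modulate(norm3(x), shift_mlp, scale_mlp))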
+ """ + def __init__( + self, + dim, + context_dim=None, + num_heads=8, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + qk_norm=None, + act_layer='gelu', + norm_layer=nn.LayerNorm, + time_fusion='none', + ada_sola_rank=None, + ada_sola_alpha=None, + skip=False, + skip_norm=False, + rope_mode='none', + context_norm=False, + use_checkpoint=False + ): + + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + qk_norm=qk_norm, + rope_mode=rope_mode + ) + + if context_dim is not None: + self.use_context = True + self.cross_attn = Attention( + dim=dim, + num_heads=num_heads, + context_dim=context_dim, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + qk_norm=qk_norm, + rope_mode='none' + ) + self.norm2 = norm_layer(dim) + if context_norm: + self.norm_context = norm_layer(context_dim) + else: + self.norm_context = nn.Identity() + else: + self.use_context = False + + self.norm3 = norm_layer(dim) + self.mlp = FeedForward( + dim=dim, mult=mlp_ratio, activation_fn=act_layer, dropout=0 + ) + + self.use_adanorm = True if time_fusion != 'token' else False + if self.use_adanorm: + self.adaln = AdaLN( + dim, + ada_mode=time_fusion, + r=ada_sola_rank, + alpha=ada_sola_alpha + ) + if skip: + self.skip_norm = norm_layer(2 * + dim) if skip_norm else nn.Identity() + self.skip_linear = nn.Linear(2 * dim, dim) + else: + self.skip_linear = None + + self.use_checkpoint = use_checkpoint + + def forward( + self, + x, + time_token=None, + time_ada=None, + skip=None, + context=None, + x_mask=None, + context_mask=None, + extras=None + ): + if self.use_checkpoint: + return checkpoint( + self._forward, + x, + time_token, + time_ada, + skip, + context, + x_mask, + context_mask, + extras, + use_reentrant=False + ) + else: + return self._forward( + x, time_token, time_ada, skip, context, x_mask, context_mask, + extras + ) + + def _forward( + self, + x, + time_token=None, + time_ada=None, + skip=None, + context=None, + x_mask=None, + context_mask=None, + extras=None + ): + B, T, C = x.shape + if self.skip_linear is not None: + assert skip is not None + cat = torch.cat([x, skip], dim=-1) + cat = self.skip_norm(cat) + x = self.skip_linear(cat) + + if self.use_adanorm: + time_ada = self.adaln(time_token, time_ada) + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, + gate_mlp) = time_ada.chunk(6, dim=1) + + # self attention + if self.use_adanorm: + x_norm = film_modulate( + self.norm1(x), shift=shift_msa, scale=scale_msa + ) + x = x + (1 - gate_msa) * self.attn( + x_norm, context=None, context_mask=x_mask, extras=extras + ) + else: + x = x + self.attn( + self.norm1(x), + context=None, + context_mask=x_mask, + extras=extras + ) + + # cross attention + if self.use_context: + assert context is not None + x = x + self.cross_attn( + x=self.norm2(x), + context=self.norm_context(context), + context_mask=context_mask, + extras=extras + ) + + # mlp + if self.use_adanorm: + x_norm = film_modulate( + self.norm3(x), shift=shift_mlp, scale=scale_mlp + ) + x = x + (1 - gate_mlp) * self.mlp(x_norm) + else: + x = x + self.mlp(self.norm3(x)) + + return x + + +class FinalBlock(nn.Module): + def __init__( + self, + embed_dim, + patch_size, + in_chans, + img_size, + input_type='2d', + norm_layer=nn.LayerNorm, + use_conv=True, + use_adanorm=True + ): + super().__init__() + self.in_chans = in_chans + self.img_size = img_size + self.input_type = input_type + + self.norm = norm_layer(embed_dim) + if use_adanorm: + self.use_adanorm = True + else: + 
self.use_adanorm = False + + if input_type == '2d': + self.patch_dim = patch_size**2 * in_chans + self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True) + if use_conv: + self.final_layer = nn.Conv2d( + self.in_chans, self.in_chans, 3, padding=1 + ) + else: + self.final_layer = nn.Identity() + + elif input_type == '1d': + self.patch_dim = patch_size * in_chans + self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True) + if use_conv: + self.final_layer = nn.Conv1d( + self.in_chans, self.in_chans, 3, padding=1 + ) + else: + self.final_layer = nn.Identity() + + def forward(self, x, time_ada=None, extras=0): + B, T, C = x.shape + x = x[:, extras:, :] + # only handle generation target + if self.use_adanorm: + shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1) + x = film_modulate(self.norm(x), shift, scale) + else: + x = self.norm(x) + x = self.linear(x) + x = unpatchify(x, self.in_chans, self.input_type, self.img_size) + x = self.final_layer(x) + return x + + +class UDiT(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + input_type='2d', + out_chans=None, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + qk_norm=None, + act_layer='gelu', + norm_layer='layernorm', + context_norm=False, + use_checkpoint=False, + # time fusion ada or token + time_fusion='token', + ada_sola_rank=None, + ada_sola_alpha=None, + cls_dim=None, + # max length is only used for concat + context_dim=768, + context_fusion='concat', + context_max_length=128, + context_pe_method='sinu', + pe_method='abs', + rope_mode='none', + use_conv=True, + skip=True, + skip_norm=True + ): + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + # input + self.in_chans = in_chans + self.input_type = input_type + if self.input_type == '2d': + num_patches = (img_size[0] // + patch_size) * (img_size[1] // patch_size) + elif self.input_type == '1d': + num_patches = img_size // patch_size + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + input_type=input_type + ) + out_chans = in_chans if out_chans is None else out_chans + self.out_chans = out_chans + + # position embedding + self.rope = rope_mode + self.x_pe = PE_wrapper( + dim=embed_dim, method=pe_method, length=num_patches + ) + + logger.info(f'x position embedding: {pe_method}') + logger.info(f'rope mode: {self.rope}') + + # time embed + self.time_embed = TimestepEmbedder(embed_dim) + self.time_fusion = time_fusion + self.use_adanorm = False + + # cls embed + if cls_dim is not None: + self.cls_embed = nn.Sequential( + nn.Linear(cls_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True), + ) + else: + self.cls_embed = None + + # time fusion + if time_fusion == 'token': + # put token at the beginning of sequence + self.extras = 2 if self.cls_embed else 1 + self.time_pe = PE_wrapper( + dim=embed_dim, method='abs', length=self.extras + ) + elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']: + self.use_adanorm = True + # aviod repetitive silu for each adaln block + self.time_act = nn.SiLU() + self.extras = 0 + self.time_ada_final = nn.Linear( + embed_dim, 2 * embed_dim, bias=True + ) + if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']: + # shared adaln + self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True) + else: + self.time_ada = None + else: + raise NotImplementedError + logger.info(f'time fusion mode: 
{self.time_fusion}') + + # context + # use a simple projection + self.use_context = False + self.context_cross = False + self.context_max_length = context_max_length + self.context_fusion = 'none' + if context_dim is not None: + self.use_context = True + self.context_embed = nn.Sequential( + nn.Linear(context_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True), + ) + self.context_fusion = context_fusion + if context_fusion == 'concat' or context_fusion == 'joint': + self.extras += context_max_length + self.context_pe = PE_wrapper( + dim=embed_dim, + method=context_pe_method, + length=context_max_length + ) + # no cross attention layers + context_dim = None + elif context_fusion == 'cross': + self.context_pe = PE_wrapper( + dim=embed_dim, + method=context_pe_method, + length=context_max_length + ) + self.context_cross = True + context_dim = embed_dim + else: + raise NotImplementedError + logger.info(f'context fusion mode: {context_fusion}') + logger.info(f'context position embedding: {context_pe_method}') + + self.use_skip = skip + + # norm layers + if norm_layer == 'layernorm': + norm_layer = nn.LayerNorm + elif norm_layer == 'rmsnorm': + norm_layer = RMSNorm + else: + raise NotImplementedError + + logger.info(f'use long skip connection: {skip}') + self.in_blocks = nn.ModuleList([ + DiTBlock( + dim=embed_dim, + context_dim=context_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + qk_norm=qk_norm, + act_layer=act_layer, + norm_layer=norm_layer, + time_fusion=time_fusion, + ada_sola_rank=ada_sola_rank, + ada_sola_alpha=ada_sola_alpha, + skip=False, + skip_norm=False, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint + ) for _ in range(depth // 2) + ]) + + self.mid_block = DiTBlock( + dim=embed_dim, + context_dim=context_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + qk_norm=qk_norm, + act_layer=act_layer, + norm_layer=norm_layer, + time_fusion=time_fusion, + ada_sola_rank=ada_sola_rank, + ada_sola_alpha=ada_sola_alpha, + skip=False, + skip_norm=False, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint + ) + + self.out_blocks = nn.ModuleList([ + DiTBlock( + dim=embed_dim, + context_dim=context_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + qk_norm=qk_norm, + act_layer=act_layer, + norm_layer=norm_layer, + time_fusion=time_fusion, + ada_sola_rank=ada_sola_rank, + ada_sola_alpha=ada_sola_alpha, + skip=skip, + skip_norm=skip_norm, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint + ) for _ in range(depth // 2) + ]) + + # FinalLayer block + self.use_conv = use_conv + self.final_block = FinalBlock( + embed_dim=embed_dim, + patch_size=patch_size, + img_size=img_size, + in_chans=out_chans, + input_type=input_type, + norm_layer=norm_layer, + use_conv=use_conv, + use_adanorm=self.use_adanorm + ) + self.initialize_weights() + + def _init_ada(self): + if self.time_fusion == 'ada': + nn.init.constant_(self.time_ada_final.weight, 0) + nn.init.constant_(self.time_ada_final.bias, 0) + for block in self.in_blocks: + nn.init.constant_(block.adaln.time_ada.weight, 0) + nn.init.constant_(block.adaln.time_ada.bias, 0) + nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0) + nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0) + for block in self.out_blocks: + nn.init.constant_(block.adaln.time_ada.weight, 0) + 
nn.init.constant_(block.adaln.time_ada.bias, 0) + elif self.time_fusion == 'ada_single': + nn.init.constant_(self.time_ada.weight, 0) + nn.init.constant_(self.time_ada.bias, 0) + nn.init.constant_(self.time_ada_final.weight, 0) + nn.init.constant_(self.time_ada_final.bias, 0) + elif self.time_fusion in ['ada_sola', 'ada_sola_bias']: + nn.init.constant_(self.time_ada.weight, 0) + nn.init.constant_(self.time_ada.bias, 0) + nn.init.constant_(self.time_ada_final.weight, 0) + nn.init.constant_(self.time_ada_final.bias, 0) + for block in self.in_blocks: + nn.init.kaiming_uniform_( + block.adaln.lora_a.weight, a=math.sqrt(5) + ) + nn.init.constant_(block.adaln.lora_b.weight, 0) + nn.init.kaiming_uniform_( + self.mid_block.adaln.lora_a.weight, a=math.sqrt(5) + ) + nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0) + for block in self.out_blocks: + nn.init.kaiming_uniform_( + block.adaln.lora_a.weight, a=math.sqrt(5) + ) + nn.init.constant_(block.adaln.lora_b.weight, 0) + + def initialize_weights(self): + # Basic init for all layers + def _basic_init(module): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # init patch Conv like Linear + w = self.patch_embed.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.patch_embed.proj.bias, 0) + + # Zero-out AdaLN + if self.use_adanorm: + self._init_ada() + + # Zero-out Cross Attention + if self.context_cross: + for block in self.in_blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0) + nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0) + for block in self.out_blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out cls embedding + if self.cls_embed: + if self.use_adanorm: + nn.init.constant_(self.cls_embed[-1].weight, 0) + nn.init.constant_(self.cls_embed[-1].bias, 0) + + # Zero-out Output + # might not zero-out this when using v-prediction + # it could be good when using noise-prediction + # nn.init.constant_(self.final_block.linear.weight, 0) + # nn.init.constant_(self.final_block.linear.bias, 0) + # if self.use_conv: + # nn.init.constant_(self.final_block.final_layer.weight.data, 0) + # nn.init.constant_(self.final_block.final_layer.bias, 0) + + # init out Conv + if self.use_conv: + nn.init.xavier_uniform_(self.final_block.final_layer.weight) + nn.init.constant_(self.final_block.final_layer.bias, 0) + + def _concat_x_context(self, x, context, x_mask=None, context_mask=None): + assert context.shape[-2] == self.context_max_length + # Check if either x_mask or context_mask is provided + B = x.shape[0] + # Create default masks if they are not provided + if x_mask is None: + x_mask = torch.ones(B, x.shape[-2], device=x.device).bool() + if context_mask is None: + context_mask = torch.ones( + B, context.shape[-2], device=context.device + ).bool() + # Concatenate the masks along the second dimension (dim=1) + x_mask = torch.cat([context_mask, x_mask], dim=1) + # Concatenate context and x along the second dimension (dim=1) + x = torch.cat((context, x), dim=1) + return x, x_mask + + def forward( + self, + x, + timesteps, + context, + x_mask=None, + context_mask=None, + cls_token=None, + controlnet_skips=None, + ): + # make it compatible with int time step during inference + if timesteps.dim() == 
0: + timesteps = timesteps.expand(x.shape[0] + ).to(x.device, dtype=torch.long) + + x = self.patch_embed(x) + x = self.x_pe(x) + + B, L, D = x.shape + + if self.use_context: + context_token = self.context_embed(context) + context_token = self.context_pe(context_token) + if self.context_fusion == 'concat' or self.context_fusion == 'joint': + x, x_mask = self._concat_x_context( + x=x, + context=context_token, + x_mask=x_mask, + context_mask=context_mask + ) + context_token, context_mask = None, None + else: + context_token, context_mask = None, None + + time_token = self.time_embed(timesteps) + if self.cls_embed: + cls_token = self.cls_embed(cls_token) + time_ada = None + time_ada_final = None + if self.use_adanorm: + if self.cls_embed: + time_token = time_token + cls_token + time_token = self.time_act(time_token) + time_ada_final = self.time_ada_final(time_token) + if self.time_ada is not None: + time_ada = self.time_ada(time_token) + else: + time_token = time_token.unsqueeze(dim=1) + if self.cls_embed: + cls_token = cls_token.unsqueeze(dim=1) + time_token = torch.cat([time_token, cls_token], dim=1) + time_token = self.time_pe(time_token) + x = torch.cat((time_token, x), dim=1) + if x_mask is not None: + x_mask = torch.cat([ + torch.ones(B, time_token.shape[1], + device=x_mask.device).bool(), x_mask + ], + dim=1) + time_token = None + + skips = [] + for blk in self.in_blocks: + x = blk( + x=x, + time_token=time_token, + time_ada=time_ada, + skip=None, + context=context_token, + x_mask=x_mask, + context_mask=context_mask, + extras=self.extras + ) + if self.use_skip: + skips.append(x) + + x = self.mid_block( + x=x, + time_token=time_token, + time_ada=time_ada, + skip=None, + context=context_token, + x_mask=x_mask, + context_mask=context_mask, + extras=self.extras + ) + for blk in self.out_blocks: + if self.use_skip: + skip = skips.pop() + if controlnet_skips: + # add to skip like u-net controlnet + skip = skip + controlnet_skips.pop() + else: + skip = None + if controlnet_skips: + # directly add to x + x = x + controlnet_skips.pop() + + x = blk( + x=x, + time_token=time_token, + time_ada=time_ada, + skip=skip, + context=context_token, + x_mask=x_mask, + context_mask=context_mask, + extras=self.extras + ) + + x = self.final_block(x, time_ada=time_ada_final, extras=self.extras) + + return x + + +class MaskDiT(nn.Module): + def __init__( + self, + model: UDiT, + mae=False, + mae_prob=0.5, + mask_ratio=[0.25, 1.0], + mask_span=10, + ): + super().__init__() + self.model = model + self.mae = mae + if self.mae: + out_channel = model.out_chans + self.mask_embed = nn.Parameter(torch.zeros((out_channel))) + self.mae_prob = mae_prob + self.mask_ratio = mask_ratio + self.mask_span = mask_span + + def random_masking(self, gt, mask_ratios, mae_mask_infer=None): + B, D, L = gt.shape + if mae_mask_infer is None: + # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1) + mask_ratios = mask_ratios.cpu().numpy() + mask = compute_mask_indices( + shape=[B, L], + padding_mask=None, + mask_prob=mask_ratios, + mask_length=self.mask_span, + mask_type="static", + mask_other=0.0, + min_masks=1, + no_overlap=False, + min_space=0, + ) + mask = mask.unsqueeze(1).expand_as(gt) + else: + mask = mae_mask_infer + mask = mask.expand_as(gt) + gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask] + return gt, mask.type_as(gt) + + def forward( + self, + x, + timesteps, + context, + x_mask=None, + context_mask=None, + cls_token=None, + gt=None, + mae_mask_infer=None, + forward_model=True + ): + # todo: handle 
controlnet inside + mae_mask = torch.ones_like(x) + if self.mae: + if gt is not None: + B, D, L = gt.shape + mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio + ).to(gt.device) + gt, mae_mask = self.random_masking( + gt, mask_ratios, mae_mask_infer + ) + # apply mae only to the selected batches + if mae_mask_infer is None: + # determine mae batch + mae_batch = torch.rand(B) < self.mae_prob + gt[~mae_batch] = self.mask_embed.view( + 1, D, 1 + ).expand_as(gt)[~mae_batch] + mae_mask[~mae_batch] = 1.0 + else: + B, D, L = x.shape + gt = self.mask_embed.view(1, D, 1).expand_as(x) + x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1) + + if forward_model: + x = self.model( + x=x, + timesteps=timesteps, + context=context, + x_mask=x_mask, + context_mask=context_mask, + cls_token=cls_token + ) + # logger.info(mae_mask[:, 0, :].sum(dim=-1)) + return x, mae_mask diff --git a/models/dit/modules.py b/models/dit/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..61ee20880e40a491815fea142a7bf6e1b800467f --- /dev/null +++ b/models/dit/modules.py @@ -0,0 +1,445 @@ +import warnings +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.cuda.amp import autocast +import math +import einops +from einops import rearrange, repeat +from inspect import isfunction + + +def trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2 + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +# disable in checkpoint mode +# @torch.jit.script +def film_modulate(x, shift, scale): + return x * (1 + scale) + shift + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
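+
+    Example (shape check only; the embedding is cos in the first half and
+    sin in the second half of the last dimension):
+
+        >>> import torch
+        >>> emb = timestep_embedding(torch.tensor([0, 250, 999]), dim=256)
+        >>> emb.shape
+        torch.Size([3, 256])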
+ """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * + torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, + torch.zeros_like(embedding[:, :1])], + dim=-1) + return embedding + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__( + self, hidden_size, frequency_embedding_size=256, out_size=None + ): + super().__init__() + if out_size is None: + out_size = hidden_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, out_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size).type( + self.mlp[0].weight.dtype + ) + t_emb = self.mlp(t_freq) + return t_emb + + +def patchify(imgs, patch_size, input_type='2d'): + if input_type == '2d': + x = einops.rearrange( + imgs, + 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', + p1=patch_size, + p2=patch_size + ) + elif input_type == '1d': + x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size) + return x + + +def unpatchify(x, channels=3, input_type='2d', img_size=None): + if input_type == '2d': + patch_size = int((x.shape[2] // channels)**0.5) + # h = w = int(x.shape[1] ** .5) + h, w = img_size[0] // patch_size, img_size[1] // patch_size + assert h * w == x.shape[1] and patch_size**2 * channels == x.shape[2] + x = einops.rearrange( + x, + 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', + h=h, + p1=patch_size, + p2=patch_size + ) + elif input_type == '1d': + patch_size = int((x.shape[2] // channels)) + h = x.shape[1] + assert patch_size * channels == x.shape[2] + x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size) + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding + """ + def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'): + super().__init__() + self.patch_size = patch_size + self.input_type = input_type + if input_type == '2d': + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=True + ) + elif input_type == '1d': + self.proj = nn.Conv1d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=True + ) + + def forward(self, x): + if self.input_type == '2d': + B, C, H, W = x.shape + assert H % self.patch_size == 0 and W % self.patch_size == 0 + elif self.input_type == '1d': + B, C, H = x.shape + assert H % self.patch_size == 0 + + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PositionalConvEmbedding(nn.Module): + """ + Convolutional positional embedding used in F5-TTS. 
+ """ + def __init__(self, dim=768, kernel_size=31, groups=16): + super().__init__() + assert kernel_size % 2 != 0 + self.conv1d = nn.Sequential( + nn.Conv1d( + dim, dim, kernel_size, groups=groups, padding=kernel_size // 2 + ), + nn.Mish(), + nn.Conv1d( + dim, dim, kernel_size, groups=groups, padding=kernel_size // 2 + ), + nn.Mish(), + ) + + def forward(self, x): + # B T C + x = self.conv1d(x.transpose(1, 2)) + x = x.transpose(1, 2) + return x + + +class SinusoidalPositionalEncoding(nn.Module): + def __init__(self, dim, length): + super(SinusoidalPositionalEncoding, self).__init__() + self.length = length + self.dim = dim + self.register_buffer( + 'pe', self._generate_positional_encoding(length, dim) + ) + + def _generate_positional_encoding(self, length, dim): + pe = torch.zeros(length, dim) + position = torch.arange(0, length, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim) + ) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + return pe + + def forward(self, x): + x = x + self.pe[:, :x.size(1)] + return x + + +class PE_wrapper(nn.Module): + def __init__(self, dim=768, method='abs', length=None, **kwargs): + super().__init__() + self.method = method + if method == 'abs': + # init absolute pe like UViT + self.length = length + self.abs_pe = nn.Parameter(torch.zeros(1, length, dim)) + trunc_normal_(self.abs_pe, mean=0.0, std=.02, a=-.04, b=.04) + elif method == 'conv': + self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs) + elif method == 'sinu': + self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length) + elif method == 'none': + # skip pe + self.id = nn.Identity() + else: + raise NotImplementedError + + def forward(self, x): + if self.method == 'abs': + _, L, _ = x.shape + assert L <= self.length + x = x + self.abs_pe[:, :L, :] + elif self.method == 'conv': + x = x + self.conv_pe(x) + elif self.method == 'sinu': + x = self.sinu_pe(x) + elif self.method == 'none': + x = self.id(x) + else: + raise NotImplementedError + return x + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. 
+ + """ + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class GELU(nn.Module): + def __init__( + self, + dim_in: int, + dim_out: int, + approximate: str = "none", + bias: bool = True + ): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + # mps: gelu is not implemented for float16 + return F.gelu( + gate.to(dtype=torch.float32), approximate=self.approximate + ).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +# disable in checkpoint mode +# @torch.jit.script +def snake_beta(x, alpha, beta): + return x + beta * torch.sin(x * alpha).pow(2) + + +class Snake(nn.Module): + def __init__(self, dim_in, dim_out, bias, alpha_trainable=True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.alpha = nn.Parameter(torch.ones(1, 1, dim_out)) + self.beta = nn.Parameter(torch.ones(1, 1, dim_out)) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + def forward(self, x): + x = self.proj(x) + x = snake_beta(x, self.alpha, self.beta) + return x + + +class GESnake(nn.Module): + def __init__(self, dim_in, dim_out, bias, alpha_trainable=True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + self.alpha = nn.Parameter(torch.ones(1, 1, dim_out)) + self.beta = nn.Parameter(torch.ones(1, 1, dim_out)) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + def forward(self, x): + x = self.proj(x) + x, gate = x.chunk(2, dim=-1) + return x * snake_beta(gate, self.alpha, self.beta) + + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out=None, + mult=4, + dropout=0.0, + activation_fn="geglu", + final_dropout=False, + inner_dim=None, + bias=True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + elif activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + elif activation_fn == "snake": + act_fn = Snake(dim, inner_dim, bias=bias) + elif activation_fn == "gesnake": + act_fn = GESnake(dim, inner_dim, 
bias=bias)
+        else:
+            raise NotImplementedError
+
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(act_fn)
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
+        # FFs as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+        if final_dropout:
+            self.net.append(nn.Dropout(dropout))
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states
diff --git a/models/dit/rotary.py b/models/dit/rotary.py
new file mode 100644
index 0000000000000000000000000000000000000000..f539185c22e715d5a7ac66772ffa9f15a1e5df35
--- /dev/null
+++ b/models/dit/rotary.py
@@ -0,0 +1,88 @@
+import torch
+
+# this RoPE is faster than the LLaMA RoPE with jit script
+
+
+def rotate_half(x):
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+# disable in checkpoint mode
+# @torch.jit.script
+def apply_rotary_pos_emb(x, cos, sin):
+    # NOTE: This could probably be moved to Triton
+    # Handle a possible sequence length mismatch between q and k
+    cos = cos[:, :, :x.shape[-2], :]
+    sin = sin[:, :, :x.shape[-2], :]
+    return (x * cos) + (rotate_half(x) * sin)
+
+
+class RotaryEmbedding(torch.nn.Module):
+    """
+    The rotary position embeddings from RoFormer_ (Su et al.).
+    A crucial insight from the method is that the query and keys are
+    transformed by rotation matrices which depend on the relative positions.
+
+    Other implementations are available in the Rotary Transformer repo_ and in
+    GPT-NeoX_; GPT-NeoX was an inspiration.
+
+    .. _RoFormer: https://arxiv.org/abs/2104.09864
+    .. _repo: https://github.com/ZhuiyiTechnology/roformer
+    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
+
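+    A shape-preserving usage sketch (`dim` is the per-head size; the
+    (B, H, L, D) layout matches `_update_cos_sin_tables` below):
+
+        >>> import torch
+        >>> rope = RotaryEmbedding(dim=32)
+        >>> q = torch.randn(1, 4, 10, 32)   # (B, H, L, D)
+        >>> k = torch.randn(1, 4, 10, 32)
+        >>> q_rot, k_rot = rope(q, k)
+        >>> q_rot.shape
+        torch.Size([1, 4, 10, 32])
+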
+    .. warning::
+        Please note that this embedding is not registered on purpose, as it
+        is transformative (it does not create the embedding dimension) and
+        will likely be picked up (imported) on an ad-hoc basis.
+    """
+    def __init__(self, dim: int):
+        super().__init__()
+        # Generate and save the inverse frequency buffer (non-trainable)
+        inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
+        self._seq_len_cached = None
+        self._cos_cached = None
+        self._sin_cached = None
+
+    def _update_cos_sin_tables(self, x, seq_dimension=-2):
+        # expect input: B, H, L, D
+        seq_len = x.shape[seq_dimension]
+
+        # Reset the tables if the sequence length has changed,
+        # or if we're on a new device (possibly due to tracing for instance);
+        # also make sure the dtype won't change
+        if (
+            seq_len != self._seq_len_cached or
+            self._cos_cached.device != x.device or
+            self._cos_cached.dtype != x.dtype
+        ):
+            self._seq_len_cached = seq_len
+            t = torch.arange(
+                x.shape[seq_dimension], device=x.device, dtype=torch.float32
+            )
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
+            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
+
+        return self._cos_cached, self._sin_cached
+
+    def forward(self, q, k):
+        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
+            q.float(), seq_dimension=-2
+        )
+        if k is not None:
+            return (
+                apply_rotary_pos_emb(
+                    q.float(), self._cos_cached, self._sin_cached
+                ).type_as(q),
+                apply_rotary_pos_emb(
+                    k.float(), self._cos_cached, self._sin_cached
+                ).type_as(k),
+            )
+        else:
+            return (
+                apply_rotary_pos_emb(
+                    q.float(), self._cos_cached, self._sin_cached
+                ).type_as(q), None
+            )
diff --git a/models/dit/span_mask.py b/models/dit/span_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0832567c3e4dcc0c49fdd88dadff11c80d8e2a0
--- /dev/null
+++ b/models/dit/span_mask.py
@@ -0,0 +1,149 @@
+import numpy as np
+import torch
+from typing import Optional, Sequence, Tuple, Union
+
+
+def compute_mask_indices(
+    shape: Tuple[int, int],
+    padding_mask: Optional[torch.Tensor],
+    mask_prob: Union[float, Sequence[float]],
+    mask_length: int,
+    mask_type: str = "static",
+    mask_other: float = 0.0,
+    min_masks: int = 0,
+    no_overlap: bool = False,
+    min_space: int = 0,
+) -> np.ndarray:
+    """
+    Computes random mask spans for a given shape
+
+    Args:
+        shape: the shape for which to compute masks.
+            should be of size 2 where first element is batch size and 2nd is timesteps
+        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+        mask_prob: probability (a scalar, or one value per batch element) for each token to be chosen as start of
+            the span to be masked. This will be multiplied by number of timesteps divided by length of mask span
+            to mask approximately this percentage of all elements. However due to overlaps, the actual number
+            will be smaller (unless no_overlap is True)
+        mask_type: how to compute mask lengths
+            static = fixed size
+            uniform = sample from uniform distribution [mask_other, mask_length*2]
+            normal = sample from normal distribution with mean mask_length and stdev mask_other; mask is min 1 element
+            poisson = sample from Poisson distribution with lambda = mask_length
+        min_masks: minimum number of masked spans
+        no_overlap: if true, uses a recursive algorithm that prevents spans from overlapping
+        min_space: only used if no_overlap is True; this is how many elements to keep unmasked between spans
+    """
+
+    bsz, all_sz = shape
+    mask = np.full((bsz, all_sz), False)
+
+    # Convert mask_prob to a per-element NumPy array
+    # (accepts a scalar or one probability per batch element)
+    mask_prob = np.broadcast_to(
+        np.atleast_1d(np.asarray(mask_prob, dtype=float)), (bsz, )
+    )
+
+    # Calculate all_num_mask for each element in the batch
+    all_num_mask = np.floor(
+        mask_prob * all_sz / float(mask_length) + np.random.rand(bsz)
+    ).astype(int)
+
+    # Apply the max operation with min_masks for each element
+    all_num_mask = np.maximum(min_masks, all_num_mask)
+
+    mask_idcs = []
+    for i in range(bsz):
+        if padding_mask is not None:
+            sz = all_sz - padding_mask[i].long().sum().item()
+            num_mask = int(
+                # add a random number for probabilistic rounding;
+                # use this element's own probability
+                mask_prob[i] * sz / float(mask_length) + np.random.rand()
+            )
+            num_mask = max(min_masks, num_mask)
+        else:
+            sz = all_sz
+            num_mask = all_num_mask[i]
+
+        if mask_type == "static":
+            lengths = np.full(num_mask, mask_length)
+        elif mask_type == "uniform":
+            lengths = np.random.randint(
+                mask_other, mask_length*2 + 1, size=num_mask
+            )
+        elif mask_type == "normal":
+            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
+            lengths = [max(1, int(round(x))) for x in lengths]
+        elif mask_type == "poisson":
+            lengths = np.random.poisson(mask_length, size=num_mask)
+            lengths = [int(round(x)) for x in lengths]
+        else:
+            raise Exception("unknown mask selection " + mask_type)
+
+        if sum(lengths) == 0:
+            lengths[0] = min(mask_length, sz - 1)
+
+        if no_overlap:
+            mask_idc = []
+
+            def arrange(s, e, length, keep_length):
+                span_start = np.random.randint(s, e - length)
+                mask_idc.extend(span_start + i for i in range(length))
+
+                new_parts = []
+                if span_start - s - min_space >= keep_length:
+                    new_parts.append((s, span_start - min_space + 1))
+                if e - span_start - keep_length - min_space > keep_length:
+                    new_parts.append((span_start + length + min_space, e))
+                return new_parts
+
+            parts = [(0, sz)]
+            min_length = min(lengths)
+            for length in sorted(lengths, reverse=True):
+                lens = np.fromiter(
+                    (
+                        e - s if e - s >= length + min_space else 0
+                        for s, e in parts
+                    ),
+                    int,  # np.int was removed in NumPy 1.24; use the builtin
+                )
+                l_sum = np.sum(lens)
+                if l_sum == 0:
+                    break
+                probs = lens / np.sum(lens)
+                c = np.random.choice(len(parts), p=probs)
+                s, e = parts.pop(c)
+                parts.extend(arrange(s, e, length, min_length))
+            mask_idc = np.asarray(mask_idc)
+        else:
+            min_len = min(lengths)
+            if sz - min_len <= num_mask:
+                min_len = sz - num_mask - 1
+
+            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+
+            mask_idc = np.asarray([
+                mask_idc[j] + offset for j in range(len(mask_idc))
+                for offset in range(lengths[j])
+            ])
+
+        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+    # min_len = min([len(m) for m in mask_idcs])
+    for i, mask_idc in enumerate(mask_idcs):
+        # if len(mask_idc) > min_len:
+        #     mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+        mask[i, mask_idc] = True
+
+    return torch.tensor(mask)
+
+
+if __name__ == '__main__':
+    mask = compute_mask_indices(
+        shape=[4, 500],
+        padding_mask=None,
+        mask_prob=[0.65, 0.5, 0.65, 0.65],
+        mask_length=10,
+        mask_type="static",
+        mask_other=0.0,
+        min_masks=1,
+        no_overlap=False,
+        min_space=0,
+    )
+    print(mask)
+    print(mask.sum(dim=1))
diff --git a/models/flow_matching.py b/models/flow_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a07ee2c126618fe105ff640825908ddd9f562e1
--- /dev/null
+++ b/models/flow_matching.py
@@ -0,0 +1,1267 @@
+from typing import Any, Optional, Union, List, Sequence
+
+import inspect
+import random
+
+from tqdm import tqdm
+import numpy as np
+import copy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers import FlowMatchEulerDiscreteScheduler
+from diffusers.training_utils import compute_density_for_timestep_sampling
+
+from models.autoencoder.autoencoder_base import AutoEncoderBase
+from models.content_encoder.content_encoder import ContentEncoder
+from models.content_adapter import ContentAdapterBase
+from models.common import LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase
+from utils.torch_utilities import (
+    create_alignment_path, create_mask_from_length, loss_with_mask,
+    trim_or_pad_length
+)
+from safetensors.torch import load_file
+
+
+class FlowMatchingMixin:
+    def __init__(
+        self,
+        cfg_drop_ratio: float = 0.2,
+        sample_strategy: str = 'normal',
+        num_train_steps: int = 1000
+    ) -> None:
+        r"""
+        Args:
+            cfg_drop_ratio (float): Probability of dropping the condition during
+                training, which enables classifier-free guidance at inference.
+            sample_strategy (str): Sampling strategy for timesteps during training.
+            num_train_steps (int): Number of training steps for the noise scheduler.
+        """
+        self.sample_strategy = sample_strategy
+        self.infer_noise_scheduler = FlowMatchEulerDiscreteScheduler(
+            num_train_timesteps=num_train_steps
+        )
+        self.train_noise_scheduler = copy.deepcopy(self.infer_noise_scheduler)
+
+        self.classifier_free_guidance = cfg_drop_ratio > 0.0
+        self.cfg_drop_ratio = cfg_drop_ratio
+
+    def get_input_target_and_timesteps(
+        self,
+        latent: torch.Tensor,
+        training: bool = True
+    ):
+        bsz = latent.shape[0]
+        noise = torch.randn_like(latent)
+
+        if training:
+            if self.sample_strategy == 'normal':
+                u = compute_density_for_timestep_sampling(
+                    weighting_scheme="logit_normal",
+                    batch_size=bsz,
+                    logit_mean=0,
+                    logit_std=1,
+                    mode_scale=None,
+                )
+            elif self.sample_strategy == 'uniform':
+                # torch.rand, not torch.randn: u must lie in [0, 1) so that
+                # the derived timestep indices stay within range
+                u = torch.rand(bsz, )
+            else:
+                raise NotImplementedError(
+                    f"{self.sample_strategy} sampling for timesteps is not supported now"
+                )
+        else:
+            u = torch.ones(bsz, ) / 2
+
+        indices = (u * self.train_noise_scheduler.config.num_train_timesteps
+                   ).long()
+
+        # train_noise_scheduler.timesteps: descending from num_train_steps to 1
+        # with unit interval
+        timesteps = self.train_noise_scheduler.timesteps[indices].to(
+            device=latent.device
+        )
+        sigmas = self.get_sigmas(
+            timesteps, n_dim=latent.ndim, dtype=latent.dtype
+        )
+
+        noisy_latent = (1.0 - sigmas) * latent + sigmas * noise
+
+        target = noise - latent
+
+        return noisy_latent, target, timesteps
+
+    def get_sigmas(self, timesteps, n_dim=3, dtype=torch.float32):
+        device = timesteps.device
+
+        # a list from 1 declining to 1/num_train_steps
+        sigmas = self.train_noise_scheduler.sigmas.to(
+            device=device, dtype=dtype
+        )
+
+        schedule_timesteps = self.train_noise_scheduler.timesteps.to(device)
+        timesteps = timesteps.to(device)
+        step_indices = [(schedule_timesteps == t).nonzero().item()
+                        for t in timesteps]
+
+        sigma = sigmas[step_indices].flatten()
+        while len(sigma.shape) < n_dim:
+            sigma = sigma.unsqueeze(-1)
+        return sigma
+
+    def retrieve_timesteps(
+        self,
+        num_inference_steps: Optional[int] = None,
+        device: Optional[Union[str, torch.device]] = None,
+        timesteps: Optional[List[int]] = None,
+        sigmas: Optional[List[float]] = None,
+        **kwargs,
+    ):
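+        # Mirrors the `retrieve_timesteps` helper found in diffusers pipelines:
+        # at most one of `timesteps` / `sigmas` may be given to override the
+        # scheduler's default schedule.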
# used in inference, retrieve new timesteps on given inference timesteps + scheduler = self.infer_noise_scheduler + + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps( + timesteps=timesteps, device=device, **kwargs + ) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps( + num_inference_steps, device=device, **kwargs + ) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ContentEncoderAdapterMixin: + def __init__( + self, + content_encoder: ContentEncoder, + content_adapter: ContentAdapterBase | None = None + ): + self.content_encoder = content_encoder + self.content_adapter = content_adapter + + def encode_content( + self, + content: list[Any], + task: list[str], + device: str | torch.device, + instruction: torch.Tensor | None = None, + instruction_lengths: torch.Tensor | None = None + ): + content_output: dict[ + str, torch.Tensor] = self.content_encoder.encode_content( + content, task, device=device + ) + content, content_mask = content_output["content"], content_output[ + "content_mask"] + + if instruction is not None: + instruction_mask = create_mask_from_length(instruction_lengths) + ( + content, + content_mask, + global_duration_pred, + local_duration_pred, + ) = self.content_adapter( + content, content_mask, instruction, instruction_mask + ) + + return_dict = { + "content": content, + "content_mask": content_mask, + "length_aligned_content": content_output["length_aligned_content"], + } + if instruction is not None: + return_dict["global_duration_pred"] = global_duration_pred + return_dict["local_duration_pred"] = local_duration_pred + + return return_dict + + +class SingleTaskCrossAttentionAudioFlowMatching( + LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase, + FlowMatchingMixin, ContentEncoderAdapterMixin +): + def __init__( + self, + autoencoder: nn.Module, + content_encoder: ContentEncoder, + backbone: nn.Module, + cfg_drop_ratio: float = 0.2, + sample_strategy: str = 'normal', + num_train_steps: int = 1000, + pretrained_ckpt: str | None = None, + ): + nn.Module.__init__(self) + FlowMatchingMixin.__init__( + self, cfg_drop_ratio, sample_strategy, num_train_steps + ) + ContentEncoderAdapterMixin.__init__( + self, content_encoder=content_encoder + ) + + self.autoencoder = autoencoder + for param in self.autoencoder.parameters(): + param.requires_grad = False + + if hasattr(self.content_encoder, "audio_encoder"): + if self.content_encoder.audio_encoder is not None: + 
self.content_encoder.audio_encoder.model = self.autoencoder + + self.backbone = backbone + self.dummy_param = nn.Parameter(torch.empty(0)) + + if pretrained_ckpt is not None: + print(f"Load pretrain FlowMatching model from {pretrained_ckpt}") + pretrained_state_dict = load_file(pretrained_ckpt) + self.load_pretrained(pretrained_state_dict) + # missing, unexpected = self.load_state_dict(pretrained_state_dict, strict=False) + # print("Missing keys:", missing) + # print("Unexpected keys:", unexpected) + + # if content_encoder.embed_dim != 1024: + # self.context_proj = nn.Sequential( + # nn.Linear(content_encoder.embed_dim, 1024), + # nn.SiLU(), + # nn.Linear(1024, 1024), + # ) + # else: + # self.context_proj = nn.Identity() + + def forward( + self, content: list[Any], condition: list[Any], task: list[str], + waveform: torch.Tensor, waveform_lengths: torch.Tensor, loss_reduce: bool = True, **kwargs + + ): + loss_reduce = self.training or (loss_reduce and not self.training) + device = self.dummy_param.device + + self.autoencoder.eval() + with torch.no_grad(): + latent, latent_mask = self.autoencoder.encode( + waveform.unsqueeze(1), waveform_lengths + ) + + content_dict = self.encode_content(content, task, device) + content, content_mask = content_dict["content"], content_dict[ + "content_mask"] + + # content = self.context_proj(content) + + if self.training and self.classifier_free_guidance: + mask_indices = [ + k for k in range(len(waveform)) + if random.random() < self.cfg_drop_ratio + ] + if len(mask_indices) > 0: + content[mask_indices] = 0 + + noisy_latent, target, timesteps = self.get_input_target_and_timesteps( + latent, + training = self.training + ) + + pred: torch.Tensor = self.backbone( + x=noisy_latent, + timesteps=timesteps, + context=content, + x_mask=latent_mask, + context_mask=content_mask + ) + + diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none") + diff_loss = loss_with_mask(diff_loss, latent_mask.unsqueeze(1), reduce=loss_reduce) + #diff_loss = loss_with_mask(diff_loss, latent_mask.unsqueeze(1)) + output = {"diff_loss": diff_loss} + return output + + def iterative_denoise( + self, latent: torch.Tensor, timesteps: list[int], num_steps: int, + verbose: bool, cfg: bool, cfg_scale: float, backbone_input: dict + ): + progress_bar = tqdm(range(num_steps), disable=not verbose) + + for i, timestep in enumerate(timesteps): + # expand the latent if we are doing classifier free guidance + if cfg: + latent_input = torch.cat([latent, latent]) + else: + latent_input = latent + + noise_pred: torch.Tensor = self.backbone( + x=latent_input, timesteps=timestep, **backbone_input + ) + + # perform guidance + if cfg: + noise_pred_uncond, noise_pred_content = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + cfg_scale * ( + noise_pred_content - noise_pred_uncond + ) + + latent = self.infer_noise_scheduler.step( + noise_pred, timestep, latent + ).prev_sample + + progress_bar.update(1) + + progress_bar.close() + + return latent + + @torch.no_grad() + def inference( + self, + content: list[Any], + condition: list[Any], + task: list[str], + latent_shape: Sequence[int], + num_steps: int = 50, + sway_sampling_coef: float | None = -1.0, + guidance_scale: float = 3.0, + num_samples_per_content: int = 1, + disable_progress: bool = True, + **kwargs + ): + device = self.dummy_param.device + classifier_free_guidance = guidance_scale > 1.0 + batch_size = len(content) * num_samples_per_content + + if classifier_free_guidance: + content, content_mask = self.encode_content_classifier_free( 
+ content, task, device, num_samples_per_content + ) + else: + content_output: dict[ + str, torch.Tensor] = self.content_encoder.encode_content( + content, task + ) + content, content_mask = content_output["content"], content_output[ + "content_mask"] + content = content.repeat_interleave(num_samples_per_content, 0) + content_mask = content_mask.repeat_interleave( + num_samples_per_content, 0 + ) + + latent = self.prepare_latent( + batch_size, latent_shape, content.dtype, device + ) + + if not sway_sampling_coef: + sigmas = np.linspace(1.0, 1 / num_steps, num_steps) + else: + t = torch.linspace(0, 1, num_steps + 1) + t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) + sigmas = 1 - t + timesteps, num_steps = self.retrieve_timesteps( + num_steps, device, timesteps=None, sigmas=sigmas + ) + + latent = self.iterative_denoise( + latent=latent, + timesteps=timesteps, + num_steps=num_steps, + verbose=not disable_progress, + cfg=classifier_free_guidance, + cfg_scale=guidance_scale, + backbone_input={ + "context": content, + "context_mask": content_mask, + }, + ) + + waveform = self.autoencoder.decode(latent) + + return waveform + + def prepare_latent( + self, batch_size: int, latent_shape: Sequence[int], dtype: torch.dtype, + device: str + ): + shape = (batch_size, *latent_shape) + latent = randn_tensor( + shape, generator=None, device=device, dtype=dtype + ) + return latent + + def encode_content_classifier_free( + self, + content: list[Any], + task: list[str], + device, + num_samples_per_content: int = 1 + ): + content_dict = self.content_encoder.encode_content( + content, task, device + ) + content, content_mask = content_dict["content"], content_dict["content_mask"] + # content, content_mask = self.content_encoder.encode_content( + # content, task, device=device + # ) + + content = content.repeat_interleave(num_samples_per_content, 0) + content_mask = content_mask.repeat_interleave( + num_samples_per_content, 0 + ) + + # get unconditional embeddings for classifier free guidance + uncond_content = torch.zeros_like(content) + uncond_content_mask = content_mask.detach().clone() + + uncond_content = uncond_content.repeat_interleave( + num_samples_per_content, 0 + ) + uncond_content_mask = uncond_content_mask.repeat_interleave( + num_samples_per_content, 0 + ) + + # For classifier free guidance, we need to do two forward passes. 
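+        # (The unconditional-first ordering used when concatenating below must
+        # match `iterative_denoise`, which unpacks the doubled batch with
+        # `noise_pred.chunk(2)` into unconditional and conditional predictions.)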
+ # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes + content = torch.cat([uncond_content, content]) + content_mask = torch.cat([uncond_content_mask, content_mask]) + + return content, content_mask + +class MultiContentAudioFlowMatching(SingleTaskCrossAttentionAudioFlowMatching): + def __init__( + self, + autoencoder: AutoEncoderBase, + content_encoder: ContentEncoder, + backbone: nn.Module, + cfg_drop_ratio: float = 0.2, + sample_strategy: str = 'normal', + num_train_steps: int = 1000, + pretrained_ckpt: str | None = None, + embed_dim: int = 1024, + ): + super().__init__( + autoencoder=autoencoder, + content_encoder=content_encoder, + backbone=backbone, + cfg_drop_ratio=cfg_drop_ratio, + sample_strategy=sample_strategy, + num_train_steps=num_train_steps, + pretrained_ckpt=pretrained_ckpt, + ) + + def forward( + self, + content: list[Any], + duration: Sequence[float], + task: list[str], + waveform: torch.Tensor, + waveform_lengths: torch.Tensor, + loss_reduce: bool = True, + **kwargs + ): + device = self.dummy_param.device + loss_reduce = self.training or (loss_reduce and not self.training) + + self.autoencoder.eval() + + with torch.no_grad(): + latent, latent_mask = self.autoencoder.encode( + waveform.unsqueeze(1), waveform_lengths + ) # latent [B, 128, 500/T=10s], latent_mask [B, 500/T=10s] + + content_dict = self.encode_content(content, task, device) + context, context_mask, length_aligned_content = content_dict["content"], content_dict[ + "content_mask"], content_dict["length_aligned_content"] + + # -------------------------------------------------------------------- + # prepare latent and noise + # -------------------------------------------------------------------- + noisy_latent, target, timesteps = self.get_input_target_and_timesteps( + latent, + training = self.training + ) + + # -------------------------------------------------------------------- + # prepare input to the backbone + # -------------------------------------------------------------------- + # TODO compatility for 2D spectrogram VAE + + latent_length = noisy_latent.size(self.autoencoder.time_dim) + time_aligned_content = trim_or_pad_length( + length_aligned_content, latent_length, 1 + ) + + # -------------------------------------------------------------------- + # classifier free guidance + # -------------------------------------------------------------------- + if self.training and self.classifier_free_guidance: + mask_indices = [ + k for k in range(len(waveform)) + if random.random() < self.cfg_drop_ratio + ] + if len(mask_indices) > 0: + context[mask_indices] = 0 + time_aligned_content[mask_indices] = 0 + + pred: torch.Tensor = self.backbone( + x=noisy_latent, + x_mask=latent_mask, + timesteps=timesteps, + context=context, + context_mask=context_mask, + time_aligned_context=time_aligned_content, + ) + + pred = pred.transpose(1, self.autoencoder.time_dim) + target = target.transpose(1, self.autoencoder.time_dim) + diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none") + diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce) + + return { + "diff_loss": diff_loss, + } + + def inference( + self, + content: list[Any], + task: list[str], + latent_shape: Sequence[int], + num_steps: int = 50, + sway_sampling_coef: float | None = -1.0, + guidance_scale: float = 3.0, + disable_progress: bool = True, + **kwargs + ): + device = self.dummy_param.device + classifier_free_guidance = guidance_scale > 1.0 + batch_size = len(content) + + + 
content_dict: dict[ + str, torch.Tensor] = self.encode_content( + content, task, device + ) + context, context_mask, length_aligned_content = \ + content_dict["content"], content_dict[ + "content_mask"], content_dict["length_aligned_content"] + + shape = (batch_size, *latent_shape) + latent_length = shape[self.autoencoder.time_dim] + time_aligned_content = trim_or_pad_length( + length_aligned_content, latent_length, 1 + ) + + # -------------------------------------------------------------------- + # prepare unconditional input + # -------------------------------------------------------------------- + if classifier_free_guidance: + uncond_time_aligned_content = torch.zeros_like( + time_aligned_content + ) + uncond_context = torch.zeros_like(context) + uncond_context_mask = context_mask.detach().clone() + time_aligned_content = torch.cat([ + uncond_time_aligned_content, time_aligned_content + ]) + context = torch.cat([uncond_context, context]) + context_mask = torch.cat([uncond_context_mask, context_mask]) + + + latent = randn_tensor( + shape, generator=None, device=device, dtype=context.dtype + ) + + if not sway_sampling_coef: + sigmas = np.linspace(1.0, 1 / num_steps, num_steps) + else: + t = torch.linspace(0, 1, num_steps + 1) + t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) + sigmas = 1 - t + timesteps, num_steps = self.retrieve_timesteps( + num_steps, device, timesteps=None, sigmas=sigmas + ) + latent = self.iterative_denoise( + latent=latent, + timesteps=timesteps, + num_steps=num_steps, + verbose=not disable_progress, + cfg=classifier_free_guidance, + cfg_scale=guidance_scale, + backbone_input={ + "context": context, + "context_mask": context_mask, + "time_aligned_context": time_aligned_content, + } + ) + + waveform = self.autoencoder.decode(latent) + return waveform + +class DurationAdapterMixin: + def __init__( + self, + latent_token_rate: int, + offset: float = 1.0, + frame_resolution: float | None = None + ): + self.latent_token_rate = latent_token_rate + self.offset = offset + self.frame_resolution = frame_resolution + + def get_global_duration_loss( + self, + pred: torch.Tensor, + latent_mask: torch.Tensor, + reduce: bool = True, + ): + target = torch.log( + latent_mask.sum(1) / self.latent_token_rate + self.offset + ) + loss = F.mse_loss(target, pred, reduction="mean" if reduce else "none") + return loss + + def get_local_duration_loss( + self, ground_truth: torch.Tensor, pred: torch.Tensor, + mask: torch.Tensor, is_time_aligned: Sequence[bool], reduce: bool + ): + n_frames = torch.round(ground_truth / self.frame_resolution) + target = torch.log(n_frames + self.offset) + loss = loss_with_mask( + (target - pred)**2, + mask, + reduce=False, + ) + loss *= is_time_aligned + if reduce: + if is_time_aligned.sum().item() == 0: + loss *= 0.0 + loss = loss.mean() + else: + loss = loss.sum() / is_time_aligned.sum() + + return loss + + def prepare_local_duration(self, pred: torch.Tensor, mask: torch.Tensor): + pred = torch.exp(pred) * mask + pred = torch.ceil(pred) - self.offset + pred *= self.frame_resolution + return pred + + def prepare_global_duration( + self, + global_pred: torch.Tensor, + local_pred: torch.Tensor, + is_time_aligned: Sequence[bool], + use_local: bool = True, + ): + """ + global_pred: predicted duration value, processed by logarithmic and offset + local_pred: predicted latent length + """ + global_pred = torch.exp(global_pred) - self.offset + result = global_pred + # avoid error accumulation for each frame + if use_local: + pred_from_local = 
torch.round(local_pred * self.latent_token_rate) + pred_from_local = pred_from_local.sum(1) / self.latent_token_rate + result[is_time_aligned] = pred_from_local[is_time_aligned] + + return result + + def expand_by_duration( + self, + x: torch.Tensor, + content_mask: torch.Tensor, + local_duration: torch.Tensor, + global_duration: torch.Tensor | None = None, + ): + n_latents = torch.round(local_duration * self.latent_token_rate) + if global_duration is not None: + latent_length = torch.round( + global_duration * self.latent_token_rate + ) + else: + latent_length = n_latents.sum(1) + latent_mask = create_mask_from_length(latent_length).to( + content_mask.device + ) + attn_mask = content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1) + align_path = create_alignment_path(n_latents, attn_mask) + expanded_x = torch.matmul(align_path.transpose(1, 2).to(x.dtype), x) + return expanded_x, latent_mask + + +class CrossAttentionAudioFlowMatching( + SingleTaskCrossAttentionAudioFlowMatching, DurationAdapterMixin +): + def __init__( + self, + autoencoder: AutoEncoderBase, + content_encoder: ContentEncoder, + content_adapter: ContentAdapterBase, + backbone: nn.Module, + content_dim: int, + frame_resolution: float, + duration_offset: float = 1.0, + cfg_drop_ratio: float = 0.2, + sample_strategy: str = 'normal', + num_train_steps: int = 1000 + ): + super().__init__( + autoencoder=autoencoder, + content_encoder=content_encoder, + backbone=backbone, + cfg_drop_ratio=cfg_drop_ratio, + sample_strategy=sample_strategy, + num_train_steps=num_train_steps, + ) + ContentEncoderAdapterMixin.__init__( + self, + content_encoder=content_encoder, + content_adapter=content_adapter + ) + DurationAdapterMixin.__init__( + self, + latent_token_rate=autoencoder.latent_token_rate, + offset=duration_offset + ) + + def encode_content_with_instruction( + self, content: list[Any], task: list[str], device, + instruction: torch.Tensor, instruction_lengths: torch.Tensor + ): + content_dict = self.encode_content( + content, task, device, instruction, instruction_lengths + ) + return ( + content_dict["content"], content_dict["content_mask"], + content_dict["global_duration_pred"], + content_dict["local_duration_pred"], + content_dict["length_aligned_content"] + ) + + def forward( + self, + content: list[Any], + task: list[str], + waveform: torch.Tensor, + waveform_lengths: torch.Tensor, + instruction: torch.Tensor, + instruction_lengths: torch.Tensor, + loss_reduce: bool = True, + **kwargs + ): + device = self.dummy_param.device + loss_reduce = self.training or (loss_reduce and not self.training) + + self.autoencoder.eval() + with torch.no_grad(): + latent, latent_mask = self.autoencoder.encode( + waveform.unsqueeze(1), waveform_lengths + ) + + content, content_mask, global_duration_pred, _, _ = \ + self.encode_content_with_instruction( + content, task, device, instruction, instruction_lengths + ) + + global_duration_loss = self.get_global_duration_loss( + global_duration_pred, latent_mask, reduce=loss_reduce + ) + + if self.training and self.classifier_free_guidance: + mask_indices = [ + k for k in range(len(waveform)) + if random.random() < self.cfg_drop_ratio + ] + if len(mask_indices) > 0: + content[mask_indices] = 0 + + noisy_latent, target, timesteps = self.get_input_target_and_timesteps( + latent, + training = self.training + ) + + pred: torch.Tensor = self.backbone( + x=noisy_latent, + timesteps=timesteps, + context=content, + x_mask=latent_mask, + context_mask=content_mask, + ) + pred = pred.transpose(1, 
self.autoencoder.time_dim) + target = target.transpose(1, self.autoencoder.time_dim) + diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none") + diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce) + + return { + "diff_loss": diff_loss, + "global_duration_loss": global_duration_loss, + } + + @torch.no_grad() + def inference( + self, + content: list[Any], + condition: list[Any], + task: list[str], + is_time_aligned: Sequence[bool], + instruction: torch.Tensor, + instruction_lengths: torch.Tensor, + num_steps: int = 20, + sway_sampling_coef: float | None = -1.0, + guidance_scale: float = 3.0, + disable_progress=True, + use_gt_duration: bool = False, + **kwargs + ): + device = self.dummy_param.device + classifier_free_guidance = guidance_scale > 1.0 + + ( + content, + content_mask, + global_duration_pred, + local_duration_pred, + _, + ) = self.encode_content_with_instruction( + content, task, device, instruction, instruction_lengths + ) + batch_size = content.size(0) + + if use_gt_duration: + raise NotImplementedError( + "Using ground truth global duration only is not implemented yet" + ) + + # prepare global duration + global_duration = self.prepare_global_duration( + global_duration_pred, + local_duration_pred, + is_time_aligned, + use_local=False + ) + latent_length = torch.round(global_duration * self.latent_token_rate) + latent_mask = create_mask_from_length(latent_length).to(device) + max_latent_length = latent_mask.sum(1).max().item() + + # prepare latent and noise + if classifier_free_guidance: + uncond_context = torch.zeros_like(content) + uncond_content_mask = content_mask.detach().clone() + context = torch.cat([uncond_context, content]) + context_mask = torch.cat([uncond_content_mask, content_mask]) + else: + context = content + context_mask = content_mask + + latent_shape = tuple( + max_latent_length if dim is None else dim + for dim in self.autoencoder.latent_shape + ) + shape = (batch_size, *latent_shape) + latent = randn_tensor( + shape, generator=None, device=device, dtype=content.dtype + ) + if not sway_sampling_coef: + sigmas = np.linspace(1.0, 1 / num_steps, num_steps) + else: + t = torch.linspace(0, 1, num_steps + 1) + t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) + sigmas = 1 - t + timesteps, num_steps = self.retrieve_timesteps( + num_steps, device, timesteps=None, sigmas=sigmas + ) + latent = self.iterative_denoise( + latent=latent, + timesteps=timesteps, + num_steps=num_steps, + verbose=not disable_progress, + cfg=classifier_free_guidance, + cfg_scale=guidance_scale, + backbone_input={ + "x_mask": latent_mask, + "context": context, + "context_mask": context_mask, + } + ) + + waveform = self.autoencoder.decode(latent) + return waveform + + +class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching): + def __init__( + self, + autoencoder: AutoEncoderBase, + content_encoder: ContentEncoder, + content_adapter: ContentAdapterBase, + backbone: nn.Module, + content_dim: int, + frame_resolution: float, + duration_offset: float = 1.0, + cfg_drop_ratio: float = 0.2, + sample_strategy: str = 'normal', + num_train_steps: int = 1000 + ): + + super().__init__( + autoencoder=autoencoder, + content_encoder=content_encoder, + content_adapter=content_adapter, + backbone=backbone, + content_dim=content_dim, + frame_resolution=frame_resolution, + duration_offset=duration_offset, + cfg_drop_ratio=cfg_drop_ratio, + sample_strategy=sample_strategy, + num_train_steps=num_train_steps + ) + DurationAdapterMixin.__init__( + self, + 
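+        # Sway sampling (coef < 0) warps the uniform flow-matching schedule
+        # so more steps are spent near t=0; with sway_sampling_coef=-1.0 the
+        # sigmas reduce to cos(pi*t/2). Sketch for num_steps=4:
+        #   >>> import torch
+        #   >>> t = torch.linspace(0, 1, 5)
+        #   >>> 1 - (t + (-1.0) * (torch.cos(torch.pi / 2 * t) - 1 + t))
+        #   # -> approx [1.0000, 0.9239, 0.7071, 0.3827, 0.0000]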
latent_token_rate=autoencoder.latent_token_rate, + offset=duration_offset, + frame_resolution=frame_resolution + ) + self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim)) + self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim)) + + def get_backbone_input( + self, target_length: int, content: torch.Tensor, + content_mask: torch.Tensor, time_aligned_content: torch.Tensor, + length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor + ): + # TODO compatility for 2D spectrogram VAE + time_aligned_content = trim_or_pad_length( + time_aligned_content, target_length, 1 + ) + length_aligned_content = trim_or_pad_length( + length_aligned_content, target_length, 1 + ) + # time_aligned_content: from monotonic aligned input, without frame expansion (phoneme) + # length_aligned_content: from aligned input (f0/energy) + time_aligned_content = time_aligned_content + length_aligned_content + time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to( + time_aligned_content.dtype + ) + + context = content + context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype) + # only use the first dummy non time aligned embedding + context_mask = content_mask.detach().clone() + context_mask[is_time_aligned, 1:] = False + + # truncate dummy non time aligned context + if is_time_aligned.sum().item() < content.size(0): + trunc_nta_length = content_mask[~is_time_aligned].sum(1).max() + else: + trunc_nta_length = content.size(1) + context = context[:, :trunc_nta_length] + context_mask = context_mask[:, :trunc_nta_length] + + return context, context_mask, time_aligned_content + + def forward( + self, + content: list[Any], + duration: Sequence[float], + task: list[str], + is_time_aligned: Sequence[bool], + waveform: torch.Tensor, + waveform_lengths: torch.Tensor, + instruction: torch.Tensor, + instruction_lengths: torch.Tensor, + loss_reduce: bool = True, + **kwargs + ): + device = self.dummy_param.device + loss_reduce = self.training or (loss_reduce and not self.training) + + self.autoencoder.eval() + with torch.no_grad(): + latent, latent_mask = self.autoencoder.encode( + waveform.unsqueeze(1), waveform_lengths + ) + + ( + content, content_mask, global_duration_pred, local_duration_pred, + length_aligned_content + ) = self.encode_content_with_instruction( + content, task, device, instruction, instruction_lengths + ) + + # truncate unused non time aligned duration prediction + if is_time_aligned.sum() > 0: + trunc_ta_length = content_mask[is_time_aligned].sum(1).max() + else: + trunc_ta_length = content.size(1) + + # duration loss + local_duration_pred = local_duration_pred[:, :trunc_ta_length] + ta_content_mask = content_mask[:, :trunc_ta_length] + local_duration_loss = self.get_local_duration_loss( + duration, + local_duration_pred, + ta_content_mask, + is_time_aligned, + reduce=loss_reduce + ) + + global_duration_loss = self.get_global_duration_loss( + global_duration_pred, latent_mask, reduce=loss_reduce + ) + + # -------------------------------------------------------------------- + # prepare latent and noise + # -------------------------------------------------------------------- + noisy_latent, target, timesteps = self.get_input_target_and_timesteps( + latent, + training = self.training + ) + + # -------------------------------------------------------------------- + # duration adapter + # -------------------------------------------------------------------- + if is_time_aligned.sum() == 0 and \ + duration.size(1) < content_mask.size(1): + duration = F.pad( + duration, (0, 
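+            # The dummy embeddings above swap whole batch entries in or out
+            # by boolean indexing; the same pattern in isolation:
+            #   >>> import torch
+            #   >>> feats = torch.ones(2, 3, 4)          # (B, T, E)
+            #   >>> is_ta = torch.tensor([True, False])
+            #   >>> feats[~is_ta] = torch.zeros(4)       # entry 1 -> dummy
+            #   >>> feats[1].abs().sum()                 # -> tensor(0.)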
content_mask.size(1) - duration.size(1)) + ) + time_aligned_content, _ = self.expand_by_duration( + x=content[:, :trunc_ta_length], + content_mask=ta_content_mask, + local_duration=duration, + ) + + # -------------------------------------------------------------------- + # prepare input to the backbone + # -------------------------------------------------------------------- + # TODO compatility for 2D spectrogram VAE + latent_length = noisy_latent.size(self.autoencoder.time_dim) + context, context_mask, time_aligned_content = self.get_backbone_input( + latent_length, content, content_mask, time_aligned_content, + length_aligned_content, is_time_aligned + ) + + # -------------------------------------------------------------------- + # classifier free guidance + # -------------------------------------------------------------------- + if self.training and self.classifier_free_guidance: + mask_indices = [ + k for k in range(len(waveform)) + if random.random() < self.cfg_drop_ratio + ] + if len(mask_indices) > 0: + context[mask_indices] = 0 + time_aligned_content[mask_indices] = 0 + + pred: torch.Tensor = self.backbone( + x=noisy_latent, + x_mask=latent_mask, + timesteps=timesteps, + context=context, + context_mask=context_mask, + time_aligned_context=time_aligned_content, + ) + pred = pred.transpose(1, self.autoencoder.time_dim) + target = target.transpose(1, self.autoencoder.time_dim) + diff_loss = F.mse_loss(pred, target, reduction="none") + diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce) + return { + "diff_loss": diff_loss, + "local_duration_loss": local_duration_loss, + "global_duration_loss": global_duration_loss, + } + + def inference( + self, + content: list[Any], + task: list[str], + is_time_aligned: Sequence[bool], + instruction: torch.Tensor, + instruction_lengths: Sequence[int], + num_steps: int = 20, + sway_sampling_coef: float | None = -1.0, + guidance_scale: float = 3.0, + disable_progress: bool = True, + use_gt_duration: bool = False, + **kwargs + ): + device = self.dummy_param.device + classifier_free_guidance = guidance_scale > 1.0 + + ( + content, content_mask, global_duration_pred, local_duration_pred, + length_aligned_content + ) = self.encode_content_with_instruction( + content, task, device, instruction, instruction_lengths + ) + # print("content std: ", content.std()) + batch_size = content.size(0) + + # truncate dummy time aligned duration prediction + is_time_aligned = torch.as_tensor(is_time_aligned) + if is_time_aligned.sum() > 0: + trunc_ta_length = content_mask[is_time_aligned].sum(1).max() + else: + trunc_ta_length = content.size(1) + + # prepare local duration + local_duration = self.prepare_local_duration( + local_duration_pred, content_mask + ) + local_duration = local_duration[:, :trunc_ta_length] + # use ground truth duration + if use_gt_duration and "duration" in kwargs: + local_duration = torch.as_tensor(kwargs["duration"]).to(device) + + # prepare global duration + global_duration = self.prepare_global_duration( + global_duration_pred, local_duration, is_time_aligned + ) + + # -------------------------------------------------------------------- + # duration adapter + # -------------------------------------------------------------------- + time_aligned_content, latent_mask = self.expand_by_duration( + x=content[:, :trunc_ta_length], + content_mask=content_mask[:, :trunc_ta_length], + local_duration=local_duration, + global_duration=global_duration, + ) + + context, context_mask, time_aligned_content = self.get_backbone_input( + 
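+        # During training, classifier-free guidance is enabled by zeroing the
+        # condition for a random subset of the batch, so one network models
+        # both p(x|c) and p(x); a sketch assuming cfg_drop_ratio=0.2:
+        #   >>> import random, torch
+        #   >>> ctx = torch.randn(8, 10, 4)
+        #   >>> drop = [k for k in range(8) if random.random() < 0.2]
+        #   >>> if drop: ctx[drop] = 0   # zeroed rows act as "uncond" input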
+            target_length=time_aligned_content.size(1),
+            content=content,
+            content_mask=content_mask,
+            time_aligned_content=time_aligned_content,
+            length_aligned_content=length_aligned_content,
+            is_time_aligned=is_time_aligned
+        )
+
+        # --------------------------------------------------------------------
+        # prepare unconditional input
+        # --------------------------------------------------------------------
+        if classifier_free_guidance:
+            uncond_time_aligned_content = torch.zeros_like(
+                time_aligned_content
+            )
+            uncond_context = torch.zeros_like(context)
+            uncond_context_mask = context_mask.detach().clone()
+            time_aligned_content = torch.cat([
+                uncond_time_aligned_content, time_aligned_content
+            ])
+            context = torch.cat([uncond_context, context])
+            context_mask = torch.cat([uncond_context_mask, context_mask])
+            latent_mask = torch.cat([
+                latent_mask, latent_mask.detach().clone()
+            ])
+
+        # --------------------------------------------------------------------
+        # prepare input to the backbone
+        # --------------------------------------------------------------------
+        latent_length = latent_mask.sum(1).max().item()
+        latent_shape = tuple(
+            latent_length if dim is None else dim
+            for dim in self.autoencoder.latent_shape
+        )
+        shape = (batch_size, *latent_shape)
+        latent = randn_tensor(
+            shape, generator=None, device=device, dtype=content.dtype
+        )
+
+        if not sway_sampling_coef:
+            sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
+        else:
+            t = torch.linspace(0, 1, num_steps + 1)
+            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
+            sigmas = 1 - t
+        timesteps, num_steps = self.retrieve_timesteps(
+            num_steps, device, timesteps=None, sigmas=sigmas
+        )
+        latent = self.iterative_denoise(
+            latent=latent,
+            timesteps=timesteps,
+            num_steps=num_steps,
+            verbose=not disable_progress,
+            cfg=classifier_free_guidance,
+            cfg_scale=guidance_scale,
+            backbone_input={
+                "x_mask": latent_mask,
+                "context": context,
+                "context_mask": context_mask,
+                "time_aligned_context": time_aligned_content,
+            }
+        )
+
+        waveform = self.autoencoder.decode(latent)
+        return waveform
+
+
+class DoubleContentAudioFlowMatching(DummyContentAudioFlowMatching):
+    def get_backbone_input(
+        self, target_length: int, content: torch.Tensor,
+        content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
+        length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
+    ):
+        # TODO compatibility for 2D spectrogram VAE
+        time_aligned_content = trim_or_pad_length(
+            time_aligned_content, target_length, 1
+        )
+        length_aligned_content = trim_or_pad_length(
+            length_aligned_content, target_length, 1
+        )
+        # time_aligned_content: from monotonic aligned input, without frame expansion (phoneme)
+        # length_aligned_content: from aligned input (f0/energy)
+        time_aligned_content = time_aligned_content + length_aligned_content
+
+        context = content
+        context_mask = content_mask.detach().clone()
+
+        return context, context_mask, time_aligned_content
+
+
+class HybridContentAudioFlowMatching(DummyContentAudioFlowMatching):
+    def get_backbone_input(
+        self, target_length: int, content: torch.Tensor,
+        content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
+        length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
+    ):
+        # TODO compatibility for 2D spectrogram VAE
+        time_aligned_content = trim_or_pad_length(
+            time_aligned_content, target_length, 1
+        )
+        length_aligned_content = trim_or_pad_length(
+            length_aligned_content, target_length, 1
+        )
+        # time_aligned_content: from monotonic aligned input, without frame 
expansion (phoneme) + # length_aligned_content: from aligned input (f0/energy) + time_aligned_content = time_aligned_content + length_aligned_content + time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to( + time_aligned_content.dtype + ) + + context = content + context_mask = content_mask.detach().clone() + + return context, context_mask, time_aligned_content diff --git a/requirements.txt b/requirements.txt index 5149ecea6592775ad8dfd58a4abcedfd2c2cf8f4..1fa472e143c6ba0a1f855c811823f7e3b42b73d9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,149 @@ -gradio==4.4.1 -transformers==4.31.0 -huggingface-hub==0.16.4 \ No newline at end of file +absl-py==2.3.0 +accelerate==1.2.1 +alias-free-torch==0.0.6 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +astunparse==1.6.3 +attrs==22.2.0 +audioread==3.0.1 +av==11.0.0 +bitarray==3.7.1 +boto3==1.38.36 +botocore==1.38.36 +braceexpand==0.1.7 +brotlipy==0.7.0 +click==8.1.8 +colorama==0.4.6 +conda==23.1.0 +conda-build==3.23.3 +contourpy==1.2.0 +cycler==0.12.1 +dcase-util==0.2.20 +diffusers==0.33.1 +dnspython==2.3.0 +docker-pycreds==0.4.0 +einops==0.7.0 +exceptiongroup==1.1.1 +expecttest==0.1.4 +fire==0.7.0 +fonttools==4.47.2 +fsspec==2023.12.2 +ftfy==6.3.1 +future==1.0.0 +gitdb==4.0.12 +GitPython==3.1.44 +grpcio==1.73.0 +h5py==3.10.0 +huggingface-hub==0.30.2 +hydra-core==1.3.2 +hypothesis==6.70.0 +imageio==2.37.0 +importlib_metadata==8.5.0 +iniconfig==2.0.0 +ipdb==0.13.13 +jmespath==1.0.1 +joblib==1.3.2 +kiwisolver==1.4.5 +laion_clap==1.1.7 +lazy-dataset==0.0.14 +lazy_loader==0.4 +librosa==0.10.2 +llvmlite==0.42.0 +lxml==6.0.1 +Markdown==3.8 +matplotlib==3.8.2 +mkl-fft==1.3.1 +mkl-service==2.4.0 +mpmath==1.3.0 +msgpack==1.0.8 +networkx==3.0 +numba==0.59.1 +numpy==1.26.4 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.1.3 +nvidia-curand-cu12==10.3.5.147 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusparselt-cu12==0.6.2 +nvidia-ml-py==12.575.51 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 +omegaconf==2.3.0 +packaging==23.2 +pandas==2.2.0 +pathlib==1.0.1 +pillow==11.3.0 +pip-chill==1.0.3 +platformdirs==4.2.1 +pluggy==1.5.0 +pooch==1.8.1 +portalocker==3.2.0 +prettytable==3.16.0 +progressbar==2.5 +protobuf==5.29.2 +psds-eval==0.5.3 +pydantic==2.10.4 +pydantic_core==2.27.2 +pydot-ng==2.0.0 +pyecharts==2.0.8 +pynvml==12.0.0 +pyparsing==3.1.1 +pytest==8.2.0 +python-dateutil==2.8.2 +python-etcd==0.4.5 +python-magic==0.4.27 +regex==2023.12.25 +resampy==0.4.3 +s3transfer==0.13.0 +sacrebleu==2.5.1 +safetensors==0.5.0 +scikit-image==0.25.2 +scikit-learn==1.4.0 +scipy==1.12.0 +sed-eval==0.2.1 +sed-scores-eval==0.0.0 +sentence-transformers==4.1.0 +sentencepiece==0.2.0 +sentry-sdk==2.19.2 +setproctitle==1.3.4 +simplejson==3.20.1 +smmap==5.0.2 +sortedcontainers==2.4.0 +soundfile==0.12.1 +soxr==0.3.7 +swankit==0.2.3 +swanlab==0.6.3 +sympy==1.13.1 +tabulate==0.9.0 +tensorboard==2.19.0 +tensorboard-data-server==0.7.2 +termcolor==3.1.0 +threadpoolctl==3.2.0 +tifffile==2025.5.10 +timm==0.9.12 +tokenizers==0.21.1 +tomli==2.0.1 +torch==2.6.0 +torchaudio==2.6.0 +torchdata==0.10.1 +torchelastic==0.2.2 +torchlibrosa==0.1.0 +torchtext==0.15.0 +torchvision==0.21.0 +transformers==4.51.3 +triton==3.2.0 +types-dataclasses==0.6.6 +typing_extensions==4.12.2 +tzdata==2023.4 +validators==0.28.1 +wandb==0.19.1 +webdataset==1.0.2 +Werkzeug==3.1.3 
+wget==3.2 +wrapt==1.17.2 +zipp==3.21.0 diff --git a/utils/__pycache__/accelerate_utilities.cpython-310.pyc b/utils/__pycache__/accelerate_utilities.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d039d356af90e1c86d9442812b7c9462f8a99a1 Binary files /dev/null and b/utils/__pycache__/accelerate_utilities.cpython-310.pyc differ diff --git a/utils/__pycache__/config.cpython-310.pyc b/utils/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecfe36dc4fb653c0dcd67bfb6e2a63cac8b603c6 Binary files /dev/null and b/utils/__pycache__/config.cpython-310.pyc differ diff --git a/utils/__pycache__/diffsinger_utilities.cpython-310.pyc b/utils/__pycache__/diffsinger_utilities.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60e4cb6d6a33922102a8addeb45e3bf2128abc47 Binary files /dev/null and b/utils/__pycache__/diffsinger_utilities.cpython-310.pyc differ diff --git a/utils/__pycache__/general.cpython-310.pyc b/utils/__pycache__/general.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25e2e42a9dfedb3de3c42da7996a2695061f8716 Binary files /dev/null and b/utils/__pycache__/general.cpython-310.pyc differ diff --git a/utils/__pycache__/logging.cpython-310.pyc b/utils/__pycache__/logging.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cbfc641e1c3864dee6d9a31115464d39cf2a1b5 Binary files /dev/null and b/utils/__pycache__/logging.cpython-310.pyc differ diff --git a/utils/__pycache__/lr_scheduler_utilities.cpython-310.pyc b/utils/__pycache__/lr_scheduler_utilities.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c893bd72df9d75ca5c493921263e38729bfac930 Binary files /dev/null and b/utils/__pycache__/lr_scheduler_utilities.cpython-310.pyc differ diff --git a/utils/__pycache__/torch_utilities.cpython-310.pyc b/utils/__pycache__/torch_utilities.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b20d4f9f6dae86af3fe8842badd64afd81e82c72 Binary files /dev/null and b/utils/__pycache__/torch_utilities.cpython-310.pyc differ diff --git a/utils/accelerate_utilities.py b/utils/accelerate_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..99fc5aa3ad2700361e006799c6aac8119c9bb15a --- /dev/null +++ b/utils/accelerate_utilities.py @@ -0,0 +1,13 @@ +from accelerate import Accelerator + + +class AcceleratorSaveTrainableParams(Accelerator): + def get_state_dict(self, model, unwrap=True): + state_dict = super().get_state_dict(model, unwrap) + if hasattr(model, "param_names_to_save"): + param_names_to_save = model.param_names_to_save + return { + k: v + for k, v in state_dict.items() if k in param_names_to_save + } + return state_dict diff --git a/utils/audio.py b/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..350a9fe08e229a2f979f8090e314216c6b356739 --- /dev/null +++ b/utils/audio.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +import torchaudio + + +class PadCrop(nn.Module): + def __init__(self, n_samples, randomize=True): + super().__init__() + self.n_samples = n_samples + self.randomize = randomize + + def __call__(self, signal): + n, s = signal.shape + start = 0 if ( + not self.randomize + ) else torch.randint(0, + max(0, s - self.n_samples) + 1, []).item() + end = start + self.n_samples + output = signal.new_zeros([n, self.n_samples]) + output[:, :min(s, self.n_samples)] = signal[:, 
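+        # AcceleratorSaveTrainableParams (above) shrinks checkpoints by
+        # filtering the state dict against `param_names_to_save`; a sketch of
+        # the contract it assumes on the model (names illustrative):
+        #   >>> import torch.nn as nn
+        #   >>> model = nn.Linear(4, 4)
+        #   >>> model.param_names_to_save = {"weight"}
+        #   >>> sd = {k: v for k, v in model.state_dict().items()
+        #   ...       if k in model.param_names_to_save}
+        #   >>> list(sd)   # -> ['weight']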
start:end] + return output + + +def set_audio_channels(audio, target_channels): + if target_channels == 1: + # Convert to mono + audio = audio.mean(1, keepdim=True) + elif target_channels == 2: + # Convert to stereo + if audio.shape[1] == 1: + audio = audio.repeat(1, 2, 1) + elif audio.shape[1] > 2: + audio = audio[:, :2, :] + return audio + + +def prepare_audio( + audio, in_sr, target_sr, target_length, target_channels, device +): + + audio = audio.to(device) + + if in_sr != target_sr: + resample_tf = torchaudio.transforms.Resample(in_sr, + target_sr).to(device) + audio = resample_tf(audio) + + audio = PadCrop(target_length, randomize=False)(audio) + + # Add batch dimension + if audio.dim() == 1: + audio = audio.unsqueeze(0).unsqueeze(0) + elif audio.dim() == 2: + audio = audio.unsqueeze(0) + + audio = set_audio_channels(audio, target_channels) + + return audio diff --git a/utils/config.py b/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..6dfee06dbfdf633899e8d4748a06a2b2f71673c1 --- /dev/null +++ b/utils/config.py @@ -0,0 +1,53 @@ +from pathlib import Path +import sys +import os + +import hydra +import omegaconf +from omegaconf import OmegaConf + + +def multiply(*args): + result = 1 + for arg in args: + result *= arg + return result + + +def get_pitch_downsample_ratio( + autoencoder_config: dict, pitch_frame_resolution: float +): + latent_frame_resolution = autoencoder_config[ + "downsampling_ratio"] / autoencoder_config["sample_rate"] + return round(latent_frame_resolution / pitch_frame_resolution) + + +def register_omegaconf_resolvers() -> None: + """ + Register custom resolver for hydra configs, which can be used in YAML + files for dynamically setting values + """ + OmegaConf.clear_resolvers() + OmegaConf.register_new_resolver("len", len, replace=True) + OmegaConf.register_new_resolver("multiply", multiply, replace=True) + OmegaConf.register_new_resolver( + "get_pitch_downsample_ratio", get_pitch_downsample_ratio, replace=True + ) + + +def generate_config_from_command_line_overrides( + config_file: str | Path +) -> omegaconf.DictConfig: + register_omegaconf_resolvers() + + config_file = Path(config_file).resolve() + config_name = config_file.name.__str__() + config_path = config_file.parent.__str__() + config_path = os.path.relpath(config_path, Path(__file__).resolve().parent) + + overrides = sys.argv[1:] + with hydra.initialize(version_base=None, config_path=config_path): + config = hydra.compose(config_name=config_name, overrides=overrides) + omegaconf.OmegaConf.resolve(config) + + return config diff --git a/utils/diffsinger_utilities.py b/utils/diffsinger_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..8d658ec5a8306d46e65ef539e4e679c6a57d484b --- /dev/null +++ b/utils/diffsinger_utilities.py @@ -0,0 +1,551 @@ +import six +from pathlib import Path +import re +import json +from collections import OrderedDict +from typing import Union + +import numpy as np +import librosa +import torch + +PAD = "" +EOS = "" +UNK = "" +SEG = "|" +RESERVED_TOKENS = [PAD, EOS, UNK] +NUM_RESERVED_TOKENS = len(RESERVED_TOKENS) +PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0 +EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1 +UNK_ID = RESERVED_TOKENS.index(UNK) # Normally 2 + +F0_BIN = 256 +F0_MAX = 1100.0 +F0_MIN = 50.0 +F0_MEL_MIN = 1127 * np.log(1 + F0_MIN / 700) +F0_MEL_MAX = 1127 * np.log(1 + F0_MAX / 700) + + +def f0_to_coarse(f0): + is_torch = isinstance(f0, torch.Tensor) + f0_mel = 1127 * (1 + f0 / + 700).log() if is_torch else 
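+    # The resolvers registered in utils/config.py can then be referenced from
+    # YAML values; a minimal sketch (keys "a"/"b" assumed):
+    #   >>> from omegaconf import OmegaConf
+    #   >>> OmegaConf.register_new_resolver("multiply", multiply, replace=True)
+    #   >>> cfg = OmegaConf.create({"a": 3, "b": "${multiply:${a},4}"})
+    #   >>> cfg.b   # -> 12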
1127 * np.log(1 + f0 / 700) + f0_mel[f0_mel > 0 + ] = (f0_mel[f0_mel > 0] - + F0_MEL_MIN) * (F0_BIN - 2) / (F0_MEL_MAX - F0_MEL_MIN) + 1 + + f0_mel[f0_mel <= 1] = 1 + f0_mel[f0_mel > F0_BIN - 1] = F0_BIN - 1 + f0_coarse = (f0_mel + + 0.5).long() if is_torch else np.rint(f0_mel).astype(int) + assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, ( + f0_coarse.max(), f0_coarse.min() + ) + return f0_coarse + + +def norm_f0( + f0: Union[np.ndarray, torch.Tensor], + uv: Union[None, np.ndarray], + f0_mean: float, + f0_std: float, + pitch_norm: str = "log", + use_uv: bool = True +): + is_torch = isinstance(f0, torch.Tensor) + if pitch_norm == 'standard': + f0 = (f0 - f0_mean) / f0_std + if pitch_norm == 'log': + f0 = torch.log2(f0) if is_torch else np.log2(f0) + if uv is not None and use_uv: + f0[uv > 0] = 0 + return f0 + + +def norm_interp_f0( + f0: Union[np.ndarray, torch.Tensor], + f0_mean: float, + f0_std: float, + pitch_norm: str = "log", + use_uv: bool = True +): + is_torch = isinstance(f0, torch.Tensor) + if is_torch: + device = f0.device + f0 = f0.data.cpu().numpy() + uv = f0 == 0 + f0 = norm_f0(f0, uv, f0_mean, f0_std, pitch_norm, use_uv) + if sum(uv) == len(f0): + f0[uv] = 0 + elif sum(uv) > 0: + f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) + uv = torch.as_tensor(uv).float() + f0 = torch.as_tensor(f0).float() + if is_torch: + f0 = f0.to(device) + return f0, uv + + +def denorm_f0( + f0, + uv, + pitch_norm="log", + f0_mean=None, + f0_std=None, + pitch_padding=None, + min=None, + max=None, + use_uv=True +): + if pitch_norm == 'standard': + f0 = f0 * f0_std + f0_mean + if pitch_norm == 'log': + f0 = 2**f0 + if min is not None: + f0 = f0.clamp(min=min) + if max is not None: + f0 = f0.clamp(max=max) + if uv is not None and use_uv: + f0[uv > 0] = 0 + if pitch_padding is not None: + f0[pitch_padding] = 0 + return f0 + + +def librosa_pad_lr(x, fshift, pad_sides=1): + '''compute right padding (final frame) or both sides padding (first and final frames) + ''' + assert pad_sides in (1, 2) + # return int(fsize // 2) + pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0] + if pad_sides == 1: + return 0, pad + else: + return pad // 2, pad // 2 + pad % 2 + + +def get_pitch( + wav_file: Union[str, Path], sample_rate: int, frame_shift: float +): + import parselmouth + hop_size = int(frame_shift * sample_rate) + wav, _ = librosa.core.load(wav_file, sr=sample_rate) + # l_pad, r_pad = librosa_pad_lr(wav, hop_size, 1) + # wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0) + + latent_length = wav.shape[0] // hop_size + f0_min = 80 + f0_max = 750 + pad_size = 4 + + f0 = parselmouth.Sound(wav, sample_rate).to_pitch_ac( + time_step=frame_shift, + voicing_threshold=0.6, + pitch_floor=f0_min, + pitch_ceiling=f0_max + ).selected_array['frequency'] + delta_l = latent_length - len(f0) + if delta_l > 0: + f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0) + pitch_coarse = f0_to_coarse(f0) + return f0, pitch_coarse + + +def remove_empty_lines(text): + """remove empty lines""" + assert (len(text) > 0) + assert (isinstance(text, list)) + text = [t.strip() for t in text] + if "" in text: + text.remove("") + return text + + +def is_sil_phoneme(p): + return not p[0].isalpha() + + +def strip_ids(ids, ids_to_strip): + """Strip ids_to_strip from the end ids.""" + ids = list(ids) + while ids and ids[-1] in ids_to_strip: + ids.pop() + return ids + + +class TextEncoder(object): + """Base class for converting from ints to/from human readable strings.""" + def __init__(self, 
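+    # norm_f0/denorm_f0 with pitch_norm="log" form an exact round trip on
+    # voiced frames while unvoiced frames (uv > 0) are pinned to 0; a sketch
+    # (the clamp is added here only to avoid log2(0) in the toy example):
+    #   >>> import torch
+    #   >>> f0 = torch.tensor([220., 0., 440.])
+    #   >>> uv = (f0 == 0).float()
+    #   >>> norm = torch.log2(f0.clamp(min=1e-5)); norm[uv > 0] = 0
+    #   >>> denorm = 2 ** norm; denorm[uv > 0] = 0   # -> [220., 0., 440.]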
num_reserved_ids=NUM_RESERVED_TOKENS): + self._num_reserved_ids = num_reserved_ids + + @property + def num_reserved_ids(self): + return self._num_reserved_ids + + def encode(self, s): + """Transform a human-readable string into a sequence of int ids. + + The ids should be in the range [num_reserved_ids, vocab_size). Ids [0, + num_reserved_ids) are reserved. + + EOS is not appended. + + Args: + s: human-readable string to be converted. + + Returns: + ids: list of integers + """ + return [int(w) + self._num_reserved_ids for w in s.split()] + + def decode(self, ids, strip_extraneous=False): + """Transform a sequence of int ids into a human-readable string. + + EOS is not expected in ids. + + Args: + ids: list of integers to be converted. + strip_extraneous: bool, whether to strip off extraneous tokens + (EOS and PAD). + + Returns: + s: human-readable string. + """ + if strip_extraneous: + ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) + return " ".join(self.decode_list(ids)) + + def decode_list(self, ids): + """Transform a sequence of int ids into a their string versions. + + This method supports transforming individual input/output ids to their + string versions so that sequence to/from text conversions can be visualized + in a human readable format. + + Args: + ids: list of integers to be converted. + + Returns: + strs: list of human-readable string. + """ + decoded_ids = [] + for id_ in ids: + if 0 <= id_ < self._num_reserved_ids: + decoded_ids.append(RESERVED_TOKENS[int(id_)]) + else: + decoded_ids.append(id_ - self._num_reserved_ids) + return [str(d) for d in decoded_ids] + + @property + def vocab_size(self): + raise NotImplementedError() + + +class TokenTextEncoder(TextEncoder): + """Encoder based on a user-supplied vocabulary (file or list).""" + def __init__( + self, + vocab_filename, + reverse=False, + vocab_list=None, + replace_oov=None, + num_reserved_ids=NUM_RESERVED_TOKENS + ): + """Initialize from a file or list, one token per line. + + Handling of reserved tokens works as follows: + - When initializing from a list, we add reserved tokens to the vocab. + - When initializing from a file, we do not add reserved tokens to the vocab. + - When saving vocab files, we save reserved tokens to the file. + + Args: + vocab_filename: If not None, the full filename to read vocab from. If this + is not None, then vocab_list should be None. + reverse: Boolean indicating if tokens should be reversed during encoding + and decoding. + vocab_list: If not None, a list of elements of the vocabulary. If this is + not None, then vocab_filename should be None. + replace_oov: If not None, every out-of-vocabulary token seen when + encoding will be replaced by this string (which must be in vocab). + num_reserved_ids: Number of IDs to save for reserved tokens like . 
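+
+        Example (toy vocabulary; ids 0-2 stay reserved)::
+
+            >>> enc = TokenTextEncoder(None, vocab_list=["a", "b", "|"])
+            >>> enc.encode("a b")
+            [3, 4]
+            >>> enc.decode([3, 4])
+            'a b'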
+ """ + super(TokenTextEncoder, + self).__init__(num_reserved_ids=num_reserved_ids) + self._reverse = reverse + self._replace_oov = replace_oov + if vocab_filename: + self._init_vocab_from_file(vocab_filename) + else: + assert vocab_list is not None + self._init_vocab_from_list(vocab_list) + self.pad_index = self._token_to_id[PAD] + self.eos_index = self._token_to_id[EOS] + self.unk_index = self._token_to_id[UNK] + self.seg_index = self._token_to_id[ + SEG] if SEG in self._token_to_id else self.eos_index + + def encode(self, s): + """Converts a space-separated string of tokens to a list of ids.""" + sentence = s + tokens = sentence.strip().split() + if self._replace_oov is not None: + tokens = [ + t if t in self._token_to_id else self._replace_oov + for t in tokens + ] + ret = [self._token_to_id[tok] for tok in tokens] + return ret[::-1] if self._reverse else ret + + def decode(self, ids, strip_eos=False, strip_padding=False): + if strip_padding and self.pad() in list(ids): + pad_pos = list(ids).index(self.pad()) + ids = ids[:pad_pos] + if strip_eos and self.eos() in list(ids): + eos_pos = list(ids).index(self.eos()) + ids = ids[:eos_pos] + return " ".join(self.decode_list(ids)) + + def decode_list(self, ids): + seq = reversed(ids) if self._reverse else ids + return [self._safe_id_to_token(i) for i in seq] + + @property + def vocab_size(self): + return len(self._id_to_token) + + def __len__(self): + return self.vocab_size + + def _safe_id_to_token(self, idx): + return self._id_to_token.get(idx, "ID_%d" % idx) + + def _init_vocab_from_file(self, filename): + """Load vocab from a file. + + Args: + filename: The file to load vocabulary from. + """ + with open(filename) as f: + tokens = [token.strip() for token in f.readlines()] + + def token_gen(): + for token in tokens: + yield token + + self._init_vocab(token_gen(), add_reserved_tokens=False) + + def _init_vocab_from_list(self, vocab_list): + """Initialize tokens from a list of tokens. + + It is ok if reserved tokens appear in the vocab list. They will be + removed. The set of tokens in vocab_list should be unique. + + Args: + vocab_list: A list of tokens. + """ + def token_gen(): + for token in vocab_list: + if token not in RESERVED_TOKENS: + yield token + + self._init_vocab(token_gen()) + + def _init_vocab(self, token_generator, add_reserved_tokens=True): + """Initialize vocabulary with tokens from token_generator.""" + + self._id_to_token = {} + non_reserved_start_index = 0 + + if add_reserved_tokens: + self._id_to_token.update(enumerate(RESERVED_TOKENS)) + non_reserved_start_index = len(RESERVED_TOKENS) + + self._id_to_token.update( + enumerate(token_generator, start=non_reserved_start_index) + ) + + # _token_to_id is the reverse of _id_to_token + self._token_to_id = dict( + (v, k) for k, v in six.iteritems(self._id_to_token) + ) + + def pad(self): + return self.pad_index + + def eos(self): + return self.eos_index + + def unk(self): + return self.unk_index + + def seg(self): + return self.seg_index + + def store_to_file(self, filename): + """Write vocab file to disk. + + Vocab files have one token per line. The file ends in a newline. Reserved + tokens are written to the vocab file as well. + + Args: + filename: Full path of the file to store the vocab to. 
+ """ + with open(filename, "w") as f: + for i in range(len(self._id_to_token)): + f.write(self._id_to_token[i] + "\n") + + def sil_phonemes(self): + return [p for p in self._id_to_token.values() if not p[0].isalpha()] + + +class TextGrid(object): + def __init__(self, text): + text = remove_empty_lines(text) + self.text = text + self.line_count = 0 + self._get_type() + self._get_time_intval() + self._get_size() + self.tier_list = [] + self._get_item_list() + + def _extract_pattern(self, pattern, inc): + """ + Parameters + ---------- + pattern : regex to extract pattern + inc : increment of line count after extraction + Returns + ------- + group : extracted info + """ + try: + group = re.match(pattern, self.text[self.line_count]).group(1) + self.line_count += inc + except AttributeError: + raise ValueError( + "File format error at line %d:%s" % + (self.line_count, self.text[self.line_count]) + ) + return group + + def _get_type(self): + self.file_type = self._extract_pattern(r"File type = \"(.*)\"", 2) + + def _get_time_intval(self): + self.xmin = self._extract_pattern(r"xmin = (.*)", 1) + self.xmax = self._extract_pattern(r"xmax = (.*)", 2) + + def _get_size(self): + self.size = int(self._extract_pattern(r"size = (.*)", 2)) + + def _get_item_list(self): + """Only supports IntervalTier currently""" + for itemIdx in range(1, self.size + 1): + tier = OrderedDict() + item_list = [] + tier_idx = self._extract_pattern(r"item \[(.*)\]:", 1) + tier_class = self._extract_pattern(r"class = \"(.*)\"", 1) + if tier_class != "IntervalTier": + raise NotImplementedError( + "Only IntervalTier class is supported currently" + ) + tier_name = self._extract_pattern(r"name = \"(.*)\"", 1) + tier_xmin = self._extract_pattern(r"xmin = (.*)", 1) + tier_xmax = self._extract_pattern(r"xmax = (.*)", 1) + tier_size = self._extract_pattern(r"intervals: size = (.*)", 1) + for i in range(int(tier_size)): + item = OrderedDict() + item["idx"] = self._extract_pattern(r"intervals \[(.*)\]", 1) + item["xmin"] = self._extract_pattern(r"xmin = (.*)", 1) + item["xmax"] = self._extract_pattern(r"xmax = (.*)", 1) + item["text"] = self._extract_pattern(r"text = \"(.*)\"", 1) + item_list.append(item) + tier["idx"] = tier_idx + tier["class"] = tier_class + tier["name"] = tier_name + tier["xmin"] = tier_xmin + tier["xmax"] = tier_xmax + tier["size"] = tier_size + tier["items"] = item_list + self.tier_list.append(tier) + + def toJson(self): + _json = OrderedDict() + _json["file_type"] = self.file_type + _json["xmin"] = self.xmin + _json["xmax"] = self.xmax + _json["size"] = self.size + _json["tiers"] = self.tier_list + return json.dumps(_json, ensure_ascii=False, indent=2) + + +def read_duration_from_textgrid( + textgrid_path: Union[str, Path], + phoneme: str, + utterance_duration: float, +): + ph_list = phoneme.split(" ") + with open(textgrid_path, "r") as f: + textgrid = f.readlines() + textgrid = remove_empty_lines(textgrid) + textgrid = TextGrid(textgrid) + textgrid = json.loads(textgrid.toJson()) + + split = np.ones(len(ph_list) + 1, np.float32) * -1 + tg_idx = 0 + ph_idx = 0 + tg_align = [x for x in textgrid['tiers'][-1]['items']] + tg_align_ = [] + for x in tg_align: + x['xmin'] = float(x['xmin']) + x['xmax'] = float(x['xmax']) + if x['text'] in ['sil', 'sp', '', 'SIL', 'PUNC', '', '']: + x['text'] = '' + if len(tg_align_) > 0 and tg_align_[-1]['text'] == '': + tg_align_[-1]['xmax'] = x['xmax'] + continue + tg_align_.append(x) + tg_align = tg_align_ + tg_len = len([x for x in tg_align if x['text'] != '']) + ph_len = len([x for 
x in ph_list if not is_sil_phoneme(x)]) + assert tg_len == ph_len, (tg_len, ph_len, tg_align, ph_list, textgrid_path) + while tg_idx < len(tg_align) or ph_idx < len(ph_list): + if tg_idx == len(tg_align) and is_sil_phoneme(ph_list[ph_idx]): + split[ph_idx] = 1e8 + ph_idx += 1 + continue + x = tg_align[tg_idx] + if x['text'] == '' and ph_idx == len(ph_list): + tg_idx += 1 + continue + assert ph_idx < len(ph_list), ( + tg_len, ph_len, tg_align, ph_list, textgrid_path + ) + + ph = ph_list[ph_idx] + if x['text'] == '' and not is_sil_phoneme(ph): + assert False, (ph_list, tg_align) + if x['text'] != '' and is_sil_phoneme(ph): + ph_idx += 1 + else: + assert (x['text'] == '' and is_sil_phoneme(ph)) \ + or x['text'].lower() == ph.lower() \ + or x['text'].lower() == 'sil', (x['text'], ph) + split[ph_idx] = x['xmin'] + if ph_idx > 0 and split[ph_idx - 1] == -1 and is_sil_phoneme( + ph_list[ph_idx - 1] + ): + split[ph_idx - 1] = split[ph_idx] + ph_idx += 1 + tg_idx += 1 + assert tg_idx == len(tg_align), (tg_idx, [x['text'] for x in tg_align]) + assert ph_idx >= len(ph_list) - 1, ( + ph_idx, ph_list, len(ph_list), [x['text'] + for x in tg_align], textgrid_path + ) + + split[0] = 0 + split[-1] = utterance_duration + duration = np.diff(split) + return duration diff --git a/utils/general.py b/utils/general.py new file mode 100644 index 0000000000000000000000000000000000000000..3d86f1e9e9990fb416d67392b5dd65be12ce4e6b --- /dev/null +++ b/utils/general.py @@ -0,0 +1,73 @@ +import json +import re +from typing import Union, Dict +from pathlib import Path +import os + +MAX_FILE_NAME_LENGTH = 100 + + +def read_jsonl_to_mapping( + jsonl_file: Union[str, Path], + key_col: str, + value_col: str, + base_path=None, + overwrite=True, +) -> Dict[str, str]: + """ + Read two columns, indicated by `key_col` and `value_col`, from the + given jsonl file to return the mapping dict + TODO handle duplicate keys + """ + mapping = {} + with open(jsonl_file, 'r') as file: + for line in file.readlines(): + data = json.loads(line.strip()) + key = data[key_col] + value = data[value_col] + if base_path: + value = os.path.join(base_path, value) + if key not in mapping.keys() or overwrite: + mapping[key] = value + return mapping + + +def sanitize_filename(name: str, max_len: int = MAX_FILE_NAME_LENGTH) -> str: + """ + Clean and truncate a string to make it a valid and safe filename. 
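+
+    Example::
+
+        >>> sanitize_filename('a/b:c?.wav')
+        'a_b_c_.wav'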
+ """ + name = re.sub(r'[\\/*?:"<>|]', '_', name) + name = name.replace('/', '_') + max_len = min(len(name), max_len) + return name[:max_len] + + +def transform_gen_fn_to_id(audio_file: Path, task: str) -> str: + if task == "svs": + audio_id = audio_file.stem.split("_")[0] + elif task == "sr": + audio_id = audio_file.stem + elif task == "tta": + audio_id = audio_file.stem[:12] + '.wav' + elif task == "ttm": + audio_id = audio_file.stem[:11] + # audio_id = audio_file.stem[:12] + '.wav' + elif task == "v2a": + audio_id = audio_file.stem.rsplit("_", 1)[0] + ".mp4" + elif task == "sta_test" or task == "tta_test": + audio_id = audio_file.stem[:12] + '.wav' + elif task == "sta_base": + audio_id = 'Y' + audio_file.stem[:11] + '.wav' + else: + audio_id = audio_file.stem + return audio_id + + +def audio_dir_to_mapping(audio_dir: str | Path, task: str) -> dict: + mapping = {} + audio_dir = Path(audio_dir) + audio_files = sorted(audio_dir.iterdir()) + for audio_file in audio_files: + audio_id = transform_gen_fn_to_id(audio_file, task) + mapping[audio_id] = str(audio_file.resolve()) + return mapping diff --git a/utils/logging.py b/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..bcb011c02dca7c42c9b387c4f126ba3e9fe7efb4 --- /dev/null +++ b/utils/logging.py @@ -0,0 +1,23 @@ +from pathlib import Path +from dataclasses import dataclass +import logging + + +@dataclass +class LoggingLogger: + + filename: str | Path + level: str = "INFO" + + def create_instance(self, ): + filename = self.filename.__str__() + formatter = logging.Formatter("[%(asctime)s] - %(message)s") + + logger = logging.getLogger(__name__ + "." + filename) + logger.setLevel(getattr(logging, self.level)) + + file_handler = logging.FileHandler(filename) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger diff --git a/utils/lr_scheduler_utilities.py b/utils/lr_scheduler_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd0333a6c51bbc1f307213eced5d34b4b906d22 --- /dev/null +++ b/utils/lr_scheduler_utilities.py @@ -0,0 +1,154 @@ +from typing import Any +import math +import copy +from torch.utils.data import DataLoader + + +def get_warmup_steps( + dataloader_one_pass_outside_steps: int, + warmup_steps: int | None = None, + warmup_epochs: float | None = None, + epoch_length: int | None = None, +) -> int: + """ + Derive warmup steps according to step number or epoch number. + If `warmup_steps` is provided, then just return it. Otherwise, derive + the warmup steps by epoch length and warmup epoch number. + """ + if warmup_steps is not None: + return warmup_steps + else: + if epoch_length is None: + epoch_length = dataloader_one_pass_outside_steps + assert warmup_epochs is not None, "warmup_steps and warmup_epochs cannot be both None" + return int(epoch_length * warmup_epochs) + + +def get_dataloader_one_pass_outside_steps( + train_dataloader: DataLoader, + num_processes: int = 1, +): + """ + dataloader length after DDP, close to `original_length / gpu_number` + """ + return math.ceil(len(train_dataloader) / num_processes) + + +def get_total_training_steps( + train_dataloader: DataLoader, + epochs: int, + num_processes: int = 1, + epoch_length: int | None = None +): + """ + Calculate the total number of "visible" training steps. + + If `epoch_length` is provided, it is used as the fixed length for each epoch. + Otherwise, the function will determine the epoch length from `train_dataloader`. 
+ + Args: + train_dataloader: + Training dataloader object. + epochs: + The total number of epochs to run. + num_processes: + The number of parallel processes used for distributed training. + epoch_length: + A fixed number of training steps for each epoch. Defaults to None. + + Returns: + int: The total number of training steps (i.e., `epochs * epoch_length`). + """ + # `epoch_length` is not None: fixed length for each epoch + if epoch_length is None: + # `epoch_length` is the length of DDP-wrapped `train_dataloader` + epoch_length = get_dataloader_one_pass_outside_steps( + train_dataloader, num_processes + ) + return epochs * epoch_length + + +def get_dataloader_one_pass_steps_inside_accelerator( + dataloader_one_pass_steps: int, gradient_accumulation_steps: int, + num_processes: int +): + """ + Calculate the number of "visible" training steps for a single pass over the dataloader + inside an accelerator, accounting for gradient accumulation and distributed training. + + + Args: + dataloader_one_pass_steps: + The number of steps (batches) in one pass over the dataset. + gradient_accumulation_steps: + The number of steps to accumulate gradients before performing a parameter update. + num_processes: + The number of parallel processes used for distributed training. + + Returns: + int: The total number of "visible" training steps for one pass over the dataset, + multiplied by the number of processes. + """ + return math.ceil( + dataloader_one_pass_steps / gradient_accumulation_steps + ) * num_processes + + +def get_steps_inside_accelerator_from_outside_steps( + outside_steps: int, dataloader_one_pass_outside_steps: int, + dataloader_one_pass_steps_inside_accelerator: int, + gradient_accumulation_steps: int, num_processes: int +): + """ + Convert "outside" steps (as observed in wandb logger or similar context) + to the corresponding number of "inside" steps (for accelerate lr scheduler). + + Specifically, accelerate lr scheduler call `step()` `num_processes` times for + every `gradient_accumulation_steps` outside steps. + + Args: + outside_steps: + The total number of steps counted outside accelerate context. + dataloader_one_pass_outside_steps: + The number of steps (batches) to complete one pass of the dataloader + outside accelerate. + dataloader_one_pass_steps_inside_accelerator: + The number of `lr_scheduler.step()` calls inside accelerate, calculated via + `get_dataloader_one_pass_steps_inside_accelerator`. + gradient_accumulation_steps: + The number of steps to accumulate gradients. + num_processes: + The number of parallel processes (GPUs) used in distributed training. + + Returns: + int: The total number of `lr_scheduler.step()` calls inside accelerate that + correspond to the given `outside_steps`. 
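+
+    Worked example (values illustrative): with
+    `dataloader_one_pass_outside_steps=1000`,
+    `dataloader_one_pass_steps_inside_accelerator=500`,
+    `gradient_accumulation_steps=4` and `num_processes=2`,
+    `outside_steps=2500` gives 2 * 500 + (500 // 4) * 2 = 1250.
+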
+ """ + num_dataloader_epochs_passed = outside_steps // dataloader_one_pass_outside_steps + remaining_outside_steps = outside_steps % dataloader_one_pass_outside_steps + remaining_inside_accelerator_steps = ( + remaining_outside_steps // gradient_accumulation_steps * num_processes + ) + # accelerate scheduler call `step()` `num_processes` times every + # `gradient_accumulation_steps` steps: + # https://github.com/huggingface/accelerate/blob/main/src/accelerate/scheduler.py#L76 + total_steps = ( + num_dataloader_epochs_passed* + dataloader_one_pass_steps_inside_accelerator + + remaining_inside_accelerator_steps + ) + return total_steps + + +def lr_scheduler_param_adapter( + config_dict: dict[str, Any], num_training_steps: int, num_warmup_steps: int +) -> dict[str, Any]: + target_class = config_dict["_target_"] + return_dict = copy.deepcopy(config_dict) + if target_class == "transformers.get_scheduler": + return_dict.update({ + "num_training_steps": num_training_steps, + "num_warmup_steps": num_warmup_steps + }) + + return return_dict diff --git a/utils/tests/test_logging.py b/utils/tests/test_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..9ce9d080a7dd63c013ea7382e160d8b4ecfa6adb --- /dev/null +++ b/utils/tests/test_logging.py @@ -0,0 +1,19 @@ +import unittest +from pathlib import Path + +from utils.logging import LoggingLogger + + +class TestLoggingLogger(unittest.TestCase): + def setUp(self): + self.tmp_log_path = Path("./tmp_logging.txt") + + def test_logging_info(self): + logger = LoggingLogger(filename=self.tmp_log_path, + level="INFO").create_instance() + logger.info("logging information") + self.assertTrue(self.tmp_log_path.exists()) + + def tearDown(self): + if self.tmp_log_path.exists(): + self.tmp_log_path.unlink() diff --git a/utils/torch_utilities.py b/utils/torch_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..f4c46e2e2013479d608e6c1aded2126165443a9a --- /dev/null +++ b/utils/torch_utilities.py @@ -0,0 +1,288 @@ +import logging +import math +from typing import Callable +from pathlib import Path +import numpy as np +import torch +import torch.nn as nn + +logger = logging.Logger(__file__) + + +def remove_key_prefix_factory(prefix: str = "module."): + def func( + model_dict: dict[str, torch.Tensor], state_dict: dict[str, + torch.Tensor] + ) -> dict[str, torch.Tensor]: + + state_dict = { + key[len(prefix):]: value + for key, value in state_dict.items() if key.startswith(prefix) + } + return state_dict + + return func + + +def merge_matched_keys( + model_dict: dict[str, torch.Tensor], state_dict: dict[str, torch.Tensor] +) -> dict[str, torch.Tensor]: + """ + Args: + model_dict: + The state dict of the current model, which is going to load pretrained parameters + state_dict: + A dictionary of parameters from a pre-trained model. + + Returns: + dict[str, torch.Tensor]: + The updated state dict, where parameters with matched keys and shape are + updated with values in `state_dict`. 
+ """ + pretrained_dict = {} + mismatch_keys = [] + for key, value in state_dict.items(): + if key in model_dict and model_dict[key].shape == value.shape: + pretrained_dict[key] = value + else: + mismatch_keys.append(key) + logger.info( + f"Loading pre-trained model, with mismatched keys {mismatch_keys}" + ) + model_dict.update(pretrained_dict) + return model_dict + + +def load_pretrained_model( + model: nn.Module, + ckpt_or_state_dict: str | Path | dict[str, torch.Tensor], + state_dict_process_fn: Callable = merge_matched_keys +) -> None: + state_dict = ckpt_or_state_dict + if not isinstance(state_dict, dict): + state_dict = torch.load(ckpt_or_state_dict, "cpu") + + model_dict = model.state_dict() + state_dict = state_dict_process_fn(model_dict, state_dict) + model.load_state_dict(state_dict) + + +def create_mask_from_length( + lengths: torch.Tensor, max_length: int | None = None +): + if max_length is None: + max_length = max(lengths) + idxs = torch.arange(max_length).reshape(1, -1) # (1, max_length) + mask = idxs.to(lengths.device) < lengths.view(-1, 1) + # (1, max_length) < (batch_size, 1) -> (batch_size, max_length) + return mask + + +def loss_with_mask( + loss: torch.Tensor, + mask: torch.Tensor, + reduce: bool = True +) -> torch.Tensor: + """ + Apply a mask to the loss tensor and optionally reduce it. + + Args: + loss: Tensor of shape (b, t, ...) representing the loss values. + mask: Tensor of shape (b, t) where 1 indicates valid positions and 0 indicates masked positions. + reduce: If True, return a single scalar value; otherwise, return a tensor of shape (b,). + + Returns: + torch.Tensor: A scalar if reduce is True, otherwise a tensor of shape (b,). + """ + expanded_mask = mask[(..., ) + (None, ) * (loss.ndim - mask.ndim)] + expanded_mask = expanded_mask.expand_as(loss) + masked_loss = loss * expanded_mask + + sum_dims = tuple(range(1, loss.ndim)) + loss_sum = masked_loss.sum(dim=sum_dims) + mask_sum = expanded_mask.sum(dim=sum_dims) + loss = loss_sum / mask_sum + + if reduce: + return loss.mean() + else: + return loss + + +def convert_pad_shape(pad_shape: list[list[int]]): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def create_alignment_path(duration: torch.Tensor, mask: torch.Tensor): + device = duration.device + + b, t_x, t_y = mask.shape + cum_duration = torch.cumsum(duration, 1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = create_mask_from_length(cum_duration_flat, t_y).float() + path = path.view(b, t_x, t_y) + # take the diff on the `t_x` axis + path = path - torch.nn.functional.pad( + path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]) + )[:, :-1] + path = path * mask + return path + + +def trim_or_pad_length(x: torch.Tensor, target_length: int, length_dim: int): + """ + Adjusts the size of the specified dimension of tensor x to match `target_length`. + + Args: + x: + Input tensor. + target_length: + Desired size of the specified dimension. + length_dim: + The dimension to modify. + + Returns: + torch.Tensor: The adjusted tensor. 
+ """ + current_length = x.shape[length_dim] + + if current_length > target_length: + # Truncate the tensor + slices = [slice(None)] * x.ndim + slices[length_dim] = slice(0, target_length) + return x[tuple(slices)] + + elif current_length < target_length: + # Pad the tensor with zeros + pad_shape = list(x.shape) + pad_length = target_length - current_length + + pad_shape[length_dim] = pad_length # Shape for left padding + padding = torch.zeros(pad_shape, dtype=x.dtype, device=x.device) + + return torch.cat([x, padding], dim=length_dim) + + return x + + +def concat_non_padding( + seq1: torch.Tensor, mask1: torch.BoolTensor, seq2: torch.Tensor, + mask2: torch.BoolTensor +): + """ + Args + seq1 : Tensor (B, L1, E) + First sequence. + mask1 : BoolTensor (B, L1) + True for valid tokens in seq1, False for padding. + seq2 : Tensor (B, L2, E) + Second sequence. + mask2 : BoolTensor (B, L2) + True for valid tokens in seq2, False for padding. + + Returns + concat_seq : Tensor (B, L1+L2, E) + Both sequences concatenated; valid tokens are left-aligned, + padding on the right is 0. + concat_mask: BoolTensor (B, L1+L2) + Mask for the concatenated sequence. + perm : LongTensor (B, L1+L2) + Permutation that maps **original indices → new indices**. + Needed for restoring the original sequences. + """ + mask1, mask2 = mask1.bool(), mask2.bool() + B, L1, E = seq1.shape + L2 = seq2.size(1) + L = L1 + L2 + + seq_cat = torch.cat([seq1, seq2], dim=1) # (B, L, E) + mask_cat = torch.cat([mask1, mask2], dim=1) # (B, L) + + # ----- Key step: stable sort so that all valid tokens move to the left ----- + # Padding positions get +L, guaranteeing the largest “score” → sorted to the end. + positions = torch.arange(L, device=seq_cat.device).unsqueeze(0) # (1, L) + sort_score = positions + (~mask_cat) * L + perm = sort_score.argsort(dim=1, stable=True) # (B, L) + + # Build concatenated sequence & mask + gather_idx = perm.unsqueeze(-1).expand(-1, -1, E) # (B, L, E) + concat_seq = seq_cat.gather(1, gather_idx) + concat_mask = mask_cat.gather(1, perm) + + # Explicitly zero out the right-hand padding region for safety + concat_seq = concat_seq * concat_mask.unsqueeze(-1) + + return concat_seq, concat_mask, perm + + +def restore_from_concat( + concat_seq: torch.Tensor, mask1: torch.BoolTensor, mask2: torch.BoolTensor, + perm: torch.LongTensor +): + """ + Restore (seq1, seq2) from the concatenated sequence produced by + `concat_non_padding`, using the returned permutation `perm`. + Fully vectorised — no Python loops. 
+ """ + mask1, mask2 = mask1.bool(), mask2.bool() + B, L1 = mask1.shape + L2 = mask2.size(1) + E = concat_seq.size(-1) + + # Inverse permutation: maps **new_idx → old_idx** + inv_perm = torch.empty_like(perm) + inv_perm.scatter_( + 1, perm, + torch.arange(L1 + L2, device=perm.device).unsqueeze(0).expand(B, -1) + ) + + # Bring tokens back to their original order + gather_idx = inv_perm.unsqueeze(-1).expand(-1, -1, E) + seq_cat_rec = concat_seq.gather(1, gather_idx) # (B, L1+L2, E) + + # Split back into the two sequences and mask out padding positions + seq1_restore, seq2_restore = seq_cat_rec.split([L1, L2], dim=1) + seq1_restore = seq1_restore * mask1.unsqueeze(-1) + seq2_restore = seq2_restore * mask2.unsqueeze(-1) + + return seq1_restore, seq2_restore + + +def contains_nan(data): + """check if data contains NaN""" + if isinstance(data, torch.Tensor): + return torch.isnan(data).any().item() + elif isinstance(data, np.ndarray): + return np.isnan(data).any() + elif isinstance(data, float): + return math.isnan(data) + elif isinstance(data, (list, tuple)): + return any(contains_nan(x) for x in data) + elif isinstance(data, dict): + return any(contains_nan(v) for v in data.values()) + return False + + +def check_nan_in_batch(batch): + """check if batch contains NaN and return nan audio ids""" + assert type(batch)==dict,"batch type error" + nan_audio_ids=[] + audio_ids=batch["audio_id"] + audio_id2content={} + for idx,audio_id in enumerate(audio_ids): + content=[] + for k,v in batch.items(): + if k=="audio_id": + continue + content.append(v[idx]) + audio_id2content[audio_id]=content + + for audio_id,content in audio_id2content.items(): + if contains_nan(content): + nan_audio_ids.append(audio_id) + print(f"{audio_id} contains NaN") + return nan_audio_ids + diff --git a/utils/video.py b/utils/video.py new file mode 100644 index 0000000000000000000000000000000000000000..0c02a1ad1e01f00e74c39e1a60f7af533ed8b568 --- /dev/null +++ b/utils/video.py @@ -0,0 +1,44 @@ +from pathlib import Path +import os +from moviepy import VideoFileClip, AudioFileClip +from moviepy.audio.fx import AudioLoop + + +def merge_audio_video( + audio_path: str | Path, + video_path: str | Path, + target_path: str | Path, + backend: str = "moviepy", + logging: bool = False +): + """ + Merge audio and video into a single file. + + Args: + audio_path (str | Path): Path to the audio file. + video_path (str | Path): Path to the video file. + target_path (str | Path): Path to the target file. + backend (str, optional): The backend to use for merging. Defaults to "moviepy". 
+ """ + assert backend in [ + "moviepy", "ffmpeg" + ], "Backend should be moviepy or ffmpeg" + if backend == "moviepy": + video = VideoFileClip(video_path.__str__()) + audio = AudioFileClip(audio_path.__str__()) + + video = video.with_audio(audio) + + target_path = Path(target_path) + video.write_videofile( + target_path, + logger=None if not logging else "bar", + threads=8, + preset="ultrafast", + ffmpeg_params=["-crf", "23"] + ) + else: + logging_arg = "" if logging else "-loglevel quiet" + command = f"ffmpeg {logging_arg} -i '{video_path.__str__()}' -i '{audio_path.__str__()}' -c:v copy " \ + f"-c:a copy -map 0:v:0 -map 1:a:0 '{target_path.__str__()}'" + os.system(command) diff --git a/wav/audio.wav b/wav/audio.wav new file mode 100644 index 0000000000000000000000000000000000000000..a5d1b3c0a3e64ca66bc88605219bb17402867b56 --- /dev/null +++ b/wav/audio.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e947b99300cc46dde56a1367958802d7736c2e3efb640c44c58d50122e41c2a8 +size 480044 diff --git a/wav/speech.wav b/wav/speech.wav new file mode 100644 index 0000000000000000000000000000000000000000..933c2776795b9369347efb1103fd9764e6b7b861 --- /dev/null +++ b/wav/speech.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d23cb7106c66ad61d2b9717daea77385883cf71772836a8c5d18b9496dbb8d5 +size 130604