diff --git a/app.py b/app.py
index 399c014e7f96f89c97f024b3d80c6f61145b4235..510cca474d028a879dd54a17168cc5139684132e 100644
--- a/app.py
+++ b/app.py
@@ -1,26 +1,118 @@
import gradio as gr
+
+import soundfile as sf
+import torch
+import torchaudio
+import hydra
+from omegaconf import OmegaConf
+import diffusers.schedulers as noise_schedulers
+
+from utils.config import register_omegaconf_resolvers
+from models.common import LoadPretrainedBase
+
+from huggingface_hub import hf_hub_download
+import fairseq
+
+register_omegaconf_resolvers()
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+config = OmegaConf.load("configs/infer.yaml")
+
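+# Fetch the STAR model weights from the Hugging Face Hub (cached locally after the first download).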
+ckpt_path = hf_hub_download(
+ repo_id="assasinatee/STAR",
+ filename="model.safetensors",
+ repo_type="model",
+ force_download=False
+)
+
+exp_config = OmegaConf.load("configs/config.yaml")
+if "pretrained_ckpt" in exp_config["model"]:
+ exp_config["model"]["pretrained_ckpt"] = ckpt_path
+model: LoadPretrainedBase = hydra.utils.instantiate(exp_config["model"])
+
+model = model.to(device)
+
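+# Fetch the HuBERT-Large (LL-60k) speech encoder checkpoint and load it with fairseq.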
+ckpt_path = hf_hub_download(
+ repo_id="assasinatee/STAR",
+ filename="hubert_large_ll60k.pt",
+ repo_type="model",
+ force_download=False
+)
+hubert_models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
+hubert_model = hubert_models[0].eval().to(device)
+
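+# Resolve the scheduler class named in the config and load its settings from the
+# "scheduler" subfolder of the configured diffusers repo.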
+scheduler = getattr(
+ noise_schedulers,
+ config["noise_scheduler"]["type"],
+).from_pretrained(
+ config["noise_scheduler"]["name"],
+ subfolder="scheduler",
+)
+
+@torch.no_grad()
+def infer(audio_path: str) -> str:
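+    # HuBERT expects 16 kHz mono input: resample and downmix the uploaded audio if needed.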
+ waveform_tts, sample_rate = torchaudio.load(audio_path)
+ if sample_rate != 16000:
+ waveform_tts = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_tts)
+ if waveform_tts.shape[0] > 1:
+ waveform_tts = torch.mean(waveform_tts, dim=0, keepdim=True)
+    features, _ = hubert_model.extract_features(waveform_tts.to(device))
+
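+    # Assemble inference kwargs from the config and attach the extracted speech features as content.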
+ kwargs = OmegaConf.to_container(config["infer_args"].copy(), resolve=True)
+ kwargs['content'] = [features]
+ kwargs['condition'] = None
+ kwargs['task'] = ["speech_to_audio"]
+
+ model.eval()
+ waveform = model.inference(
+ scheduler=scheduler,
+ **kwargs,
+ )
+
+ output_file = "output_audio.wav"
+    sf.write(output_file, waveform.squeeze().cpu().numpy(), samplerate=exp_config["model"]["autoencoder"]["sample_rate"])
+
+ return output_file
with gr.Blocks(title="STAR Online Inference", theme=gr.themes.Soft()) as demo:
gr.Markdown("# STAR: Speech-to-Audio Generation via Representation Learning")
gr.Markdown("""
+
## 🗣️ Input
A brief input speech utterance describing the overall audio scene.
-> Example:A cat meowing and young female speaking
+> Example: A cat meowing and young female speaking
+
+### 🎙️ Input Speech Example
+""")
+
+ speech = gr.Audio(value="wav/speech.wav", label="Input Speech Example", type="filepath")
+
+ gr.Markdown("""
+
-#### 🎙️ Input Speech Example
+### 🎧️ Output Audio Example
+""")
+
+ audio = gr.Audio(value="wav/audio.wav", label="Generated Audio Example", type="filepath")
-#### 🎧️ Output Audio Example
+ gr.Markdown("""
---
-""")
+""")
-if __name__ == "__main__":
- demo.launch()
+ with gr.Column():
+ input_audio = gr.Audio(label="Speech Input", type="filepath")
+ btn = gr.Button("🎵Generate Audio!", variant="primary")
+ output_audio = gr.Audio(label="Generated Audio", type="filepath")
+ btn.click(fn=infer, inputs=input_audio, outputs=output_audio)
+
+demo.launch()
\ No newline at end of file
diff --git a/ckpts/1m.pt b/ckpts/1m.pt
new file mode 100644
index 0000000000000000000000000000000000000000..895b620b0af0420ebf153d4118e5f1e22adb6bbf
--- /dev/null
+++ b/ckpts/1m.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3cb13e2699fa922ce6a2b3b4f53c270ec64156e0cc3f3e3645e10cdf98b740dc
+size 183037614
diff --git a/ckpts/exp0_best.pt b/ckpts/exp0_best.pt
new file mode 100644
index 0000000000000000000000000000000000000000..6e25ad0e1da3a2bc155c51215d8b1da35475859e
--- /dev/null
+++ b/ckpts/exp0_best.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e2dc436e6d47cb02e954a0087a3a1b4aa1d5d3e1ded4fdafb6274966264d5a7
+size 73171895
diff --git a/configs/config.yaml b/configs/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4a036c765fa906d496e5b97af0ecc9b24e292ddb
--- /dev/null
+++ b/configs/config.yaml
@@ -0,0 +1,84 @@
+model:
+ autoencoder:
+ _target_: models.autoencoder.waveform.stable_vae.StableVAE
+ encoder:
+ _target_: models.autoencoder.waveform.stable_vae.OobleckEncoder
+ in_channels: 1
+ channels: 128
+ c_mults:
+ - 1
+ - 2
+ - 4
+ - 8
+ strides:
+ - 2
+ - 4
+ - 6
+ - 10
+ latent_dim: 256
+ use_snake: true
+ decoder:
+ _target_: models.autoencoder.waveform.stable_vae.OobleckDecoder
+ out_channels: 1
+ channels: 128
+ c_mults:
+ - 1
+ - 2
+ - 4
+ - 8
+ strides:
+ - 2
+ - 4
+ - 6
+ - 10
+ latent_dim: 128
+ use_snake: true
+ final_tanh: false
+ io_channels: 1
+ latent_dim: 128
+ downsampling_ratio: 480
+ sample_rate: 24000
+ pretrained_ckpt: /hpc_stor03/sjtu_home/xuenan.xu/workspace/text_to_audio_generation/ezaudio/ckpts/vae/1m.pt
+ bottleneck:
+ _target_: models.autoencoder.waveform.stable_vae.VAEBottleneck
+ backbone:
+ _target_: models.dit.mask_dit.UDiT
+ img_size: 500
+ patch_size: 1
+ in_chans: 128
+ out_chans: 128
+ input_type: 1d
+ embed_dim: 1024
+ depth: 24
+ num_heads: 16
+ mlp_ratio: 4.0
+ qkv_bias: false
+ qk_scale: null
+ qk_norm: layernorm
+ norm_layer: layernorm
+ act_layer: geglu
+ context_norm: true
+ use_checkpoint: true
+ time_fusion: ada_sola_bias
+ ada_sola_rank: 32
+ ada_sola_alpha: 32
+ cls_dim: null
+ context_dim: 1024
+ context_fusion: cross
+ context_max_length: null
+ context_pe_method: none
+ pe_method: none
+ rope_mode: shared
+ use_conv: true
+ skip: true
+ skip_norm: true
+ cfg_drop_ratio: 0.2
+ _target_: models.flow_matching.SingleTaskCrossAttentionAudioFlowMatching
+ content_encoder:
+ _target_: models.content_encoder.content_encoder.ContentEncoder
+ embed_dim: 1024
+    text_encoder: null
+ speech_encoder:
+ _target_: models.content_encoder.star_encoder.star_encoder.QformerBridgeNet
+ load_from_pretrained: /hpc_stor03/sjtu_home/zeyu.xie/workspace/speech2audio/hear/output/qformer_caption_tts_hubert/exp0_best.pt
+ pretrained_ckpt: /hpc_stor03/sjtu_home/zeyu.xie/workspace/speech2audio/x2audio/x_to_audio_generation/experiments/audiocaps_fm/checkpoints/epoch_100/model.safetensors
\ No newline at end of file
diff --git a/configs/infer.yaml b/configs/infer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7ca815e4d0c2106979a32a7c565a1588f291cef2
--- /dev/null
+++ b/configs/infer.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - basic
+ - _self_
+
+wav_dir: inference_delay
+
+noise_scheduler:
+ type: DDIMScheduler
+ name: stabilityai/stable-diffusion-2-1
+
+infer_args:
+ num_steps: 50
+ guidance_scale: 3.5
+ guidance_rescale: 0.5
+ use_gt_duration: false
+ latent_shape: [128, 500]
\ No newline at end of file
diff --git a/models/__pycache__/common.cpython-310.pyc b/models/__pycache__/common.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c26bf242810cd0f5a5781f1002167c5b097e0971
Binary files /dev/null and b/models/__pycache__/common.cpython-310.pyc differ
diff --git a/models/__pycache__/content_adapter.cpython-310.pyc b/models/__pycache__/content_adapter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98c0b7bf7aba309822162289ac0ef6f43e591c24
Binary files /dev/null and b/models/__pycache__/content_adapter.cpython-310.pyc differ
diff --git a/models/__pycache__/diffusion.cpython-310.pyc b/models/__pycache__/diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c24a2fbbd3e433632f343efc0d7465097e288f2d
Binary files /dev/null and b/models/__pycache__/diffusion.cpython-310.pyc differ
diff --git a/models/__pycache__/flow_matching.cpython-310.pyc b/models/__pycache__/flow_matching.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d1126e2d787c7e6493b3a377bdee61756bfa5ae
Binary files /dev/null and b/models/__pycache__/flow_matching.cpython-310.pyc differ
diff --git a/models/autoencoder/__pycache__/autoencoder_base.cpython-310.pyc b/models/autoencoder/__pycache__/autoencoder_base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0974bf1cb6f1d8847f573c39982c74a98e7b5bea
Binary files /dev/null and b/models/autoencoder/__pycache__/autoencoder_base.cpython-310.pyc differ
diff --git a/models/autoencoder/autoencoder_base.py b/models/autoencoder/autoencoder_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..2852ad185b48e9595e116735baed689fa09cc0d3
--- /dev/null
+++ b/models/autoencoder/autoencoder_base.py
@@ -0,0 +1,22 @@
+from abc import abstractmethod, ABC
+from typing import Sequence
+import torch
+import torch.nn as nn
+
+
+class AutoEncoderBase(ABC):
+ def __init__(
+ self, downsampling_ratio: int, sample_rate: int,
+ latent_shape: Sequence[int | None]
+ ):
+ self.downsampling_ratio = downsampling_ratio
+ self.sample_rate = sample_rate
+ self.latent_token_rate = sample_rate // downsampling_ratio
+ self.latent_shape = latent_shape
+ self.time_dim = latent_shape.index(None) + 1 # the first dim is batch
+
+ @abstractmethod
+ def encode(
+ self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ ...
diff --git a/models/autoencoder/waveform/__pycache__/stable_vae.cpython-310.pyc b/models/autoencoder/waveform/__pycache__/stable_vae.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd4019e321981d753c20a140335e93b16d957476
Binary files /dev/null and b/models/autoencoder/waveform/__pycache__/stable_vae.cpython-310.pyc differ
diff --git a/models/autoencoder/waveform/dac.py b/models/autoencoder/waveform/dac.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/autoencoder/waveform/stable_vae.py b/models/autoencoder/waveform/stable_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbebfce624995379da5b3957e9d74d1186dce6b5
--- /dev/null
+++ b/models/autoencoder/waveform/stable_vae.py
@@ -0,0 +1,559 @@
+from typing import Any, Literal, Callable
+import math
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+from torch.nn.utils import weight_norm
+import torchaudio
+from alias_free_torch import Activation1d
+
+from models.common import LoadPretrainedBase
+from models.autoencoder.autoencoder_base import AutoEncoderBase
+from utils.torch_utilities import remove_key_prefix_factory, create_mask_from_length
+
+
+# torch.jit.script makes it 1.4x faster and saves GPU memory
+@torch.jit.script
+def snake_beta(x, alpha, beta):
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
+
+
+class SnakeBeta(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ alpha=1.0,
+ alpha_trainable=True,
+ alpha_logscale=True
+ ):
+ super(SnakeBeta, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale:
+ # log scale alphas initialized to zeros
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
+ else:
+ # linear scale alphas initialized to ones
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ # self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
+ # line up with x to [B, C, T]
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ beta = torch.exp(beta)
+ x = snake_beta(x, alpha, beta)
+
+ return x
+
+
+def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+def get_activation(
+ activation: Literal["elu", "snake", "none"],
+ antialias=False,
+ channels=None
+) -> nn.Module:
+ if activation == "elu":
+ act = nn.ELU()
+ elif activation == "snake":
+ act = SnakeBeta(channels)
+ elif activation == "none":
+ act = nn.Identity()
+ else:
+ raise ValueError(f"Unknown activation {activation}")
+
+ if antialias:
+ act = Activation1d(act)
+
+ return act
+
+
+class ResidualUnit(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ dilation,
+ use_snake=False,
+ antialias_activation=False
+ ):
+ super().__init__()
+
+ self.dilation = dilation
+
+ padding = (dilation * (7 - 1)) // 2
+
+ self.layers = nn.Sequential(
+ get_activation(
+ "snake" if use_snake else "elu",
+ antialias=antialias_activation,
+ channels=out_channels
+ ),
+ WNConv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=7,
+ dilation=dilation,
+ padding=padding
+ ),
+ get_activation(
+ "snake" if use_snake else "elu",
+ antialias=antialias_activation,
+ channels=out_channels
+ ),
+ WNConv1d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=1
+ )
+ )
+
+ def forward(self, x):
+ res = x
+
+ #x = checkpoint(self.layers, x)
+ x = self.layers(x)
+
+ return x + res
+
+
+class EncoderBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ use_snake=False,
+ antialias_activation=False
+ ):
+ super().__init__()
+
+ self.layers = nn.Sequential(
+ ResidualUnit(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dilation=1,
+ use_snake=use_snake
+ ),
+ ResidualUnit(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dilation=3,
+ use_snake=use_snake
+ ),
+ ResidualUnit(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dilation=9,
+ use_snake=use_snake
+ ),
+ get_activation(
+ "snake" if use_snake else "elu",
+ antialias=antialias_activation,
+ channels=in_channels
+ ),
+ WNConv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=math.ceil(stride / 2)
+ ),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class DecoderBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ use_snake=False,
+ antialias_activation=False,
+ use_nearest_upsample=False
+ ):
+ super().__init__()
+
+ if use_nearest_upsample:
+ upsample_layer = nn.Sequential(
+ nn.Upsample(scale_factor=stride, mode="nearest"),
+ WNConv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=2 * stride,
+ stride=1,
+ bias=False,
+ padding='same'
+ )
+ )
+ else:
+ upsample_layer = WNConvTranspose1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=math.ceil(stride / 2)
+ )
+
+ self.layers = nn.Sequential(
+ get_activation(
+ "snake" if use_snake else "elu",
+ antialias=antialias_activation,
+ channels=in_channels
+ ),
+ upsample_layer,
+ ResidualUnit(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ dilation=1,
+ use_snake=use_snake
+ ),
+ ResidualUnit(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ dilation=3,
+ use_snake=use_snake
+ ),
+ ResidualUnit(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ dilation=9,
+ use_snake=use_snake
+ ),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class OobleckEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=2,
+ channels=128,
+ latent_dim=32,
+ c_mults=[1, 2, 4, 8],
+ strides=[2, 4, 8, 8],
+ use_snake=False,
+ antialias_activation=False
+ ):
+ super().__init__()
+
+ c_mults = [1] + c_mults
+
+ self.depth = len(c_mults)
+
+ layers = [
+ WNConv1d(
+ in_channels=in_channels,
+ out_channels=c_mults[0] * channels,
+ kernel_size=7,
+ padding=3
+ )
+ ]
+
+ for i in range(self.depth - 1):
+ layers += [
+ EncoderBlock(
+ in_channels=c_mults[i] * channels,
+ out_channels=c_mults[i + 1] * channels,
+ stride=strides[i],
+ use_snake=use_snake
+ )
+ ]
+
+ layers += [
+ get_activation(
+ "snake" if use_snake else "elu",
+ antialias=antialias_activation,
+ channels=c_mults[-1] * channels
+ ),
+ WNConv1d(
+ in_channels=c_mults[-1] * channels,
+ out_channels=latent_dim,
+ kernel_size=3,
+ padding=1
+ )
+ ]
+
+ self.layers = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class OobleckDecoder(nn.Module):
+ def __init__(
+ self,
+ out_channels=2,
+ channels=128,
+ latent_dim=32,
+ c_mults=[1, 2, 4, 8],
+ strides=[2, 4, 8, 8],
+ use_snake=False,
+ antialias_activation=False,
+ use_nearest_upsample=False,
+ final_tanh=True
+ ):
+ super().__init__()
+
+ c_mults = [1] + c_mults
+
+ self.depth = len(c_mults)
+
+ layers = [
+ WNConv1d(
+ in_channels=latent_dim,
+ out_channels=c_mults[-1] * channels,
+ kernel_size=7,
+ padding=3
+ ),
+ ]
+
+ for i in range(self.depth - 1, 0, -1):
+ layers += [
+ DecoderBlock(
+ in_channels=c_mults[i] * channels,
+ out_channels=c_mults[i - 1] * channels,
+ stride=strides[i - 1],
+ use_snake=use_snake,
+ antialias_activation=antialias_activation,
+ use_nearest_upsample=use_nearest_upsample
+ )
+ ]
+
+ layers += [
+ get_activation(
+ "snake" if use_snake else "elu",
+ antialias=antialias_activation,
+ channels=c_mults[0] * channels
+ ),
+ WNConv1d(
+ in_channels=c_mults[0] * channels,
+ out_channels=out_channels,
+ kernel_size=7,
+ padding=3,
+ bias=False
+ ),
+ nn.Tanh() if final_tanh else nn.Identity()
+ ]
+
+ self.layers = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class Bottleneck(nn.Module):
+ def __init__(self, is_discrete: bool = False):
+ super().__init__()
+
+ self.is_discrete = is_discrete
+
+ def encode(self, x, return_info=False, **kwargs):
+ raise NotImplementedError
+
+ def decode(self, x):
+ raise NotImplementedError
+
+
+@torch.jit.script
+def vae_sample(mean, scale) -> dict[str, torch.Tensor]:
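+    # Reparameterization trick: sample latents from N(mean, stdev^2) and
+    # return the KL divergence to a standard normal prior (without the 1/2 factor).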
+ stdev = nn.functional.softplus(scale) + 1e-4
+ var = stdev * stdev
+ logvar = torch.log(var)
+ latents = torch.randn_like(mean) * stdev + mean
+
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
+ return {"latents": latents, "kl": kl}
+
+
+class VAEBottleneck(Bottleneck):
+ def __init__(self):
+ super().__init__(is_discrete=False)
+
+ def encode(self,
+ x,
+ return_info=False,
+ **kwargs) -> dict[str, torch.Tensor] | torch.Tensor:
+ mean, scale = x.chunk(2, dim=1)
+ sampled = vae_sample(mean, scale)
+
+ if return_info:
+ return sampled["latents"], {"kl": sampled["kl"]}
+ else:
+ return sampled["latents"]
+
+ def decode(self, x):
+ return x
+
+
+def compute_mean_kernel(x, y):
+ kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
+ return torch.exp(-kernel_input).mean()
+
+
+class Pretransform(nn.Module):
+ def __init__(self, enable_grad, io_channels, is_discrete):
+ super().__init__()
+
+ self.is_discrete = is_discrete
+ self.io_channels = io_channels
+ self.encoded_channels = None
+ self.downsampling_ratio = None
+
+ self.enable_grad = enable_grad
+
+ def encode(self, x):
+ raise NotImplementedError
+
+ def decode(self, z):
+ raise NotImplementedError
+
+ def tokenize(self, x):
+ raise NotImplementedError
+
+ def decode_tokens(self, tokens):
+ raise NotImplementedError
+
+
+class StableVAE(LoadPretrainedBase, AutoEncoderBase):
+ def __init__(
+ self,
+ encoder,
+ decoder,
+ latent_dim,
+ downsampling_ratio,
+ sample_rate,
+ io_channels=2,
+ bottleneck: Bottleneck = None,
+ pretransform: Pretransform = None,
+ in_channels=None,
+ out_channels=None,
+ soft_clip=False,
+ pretrained_ckpt: str | Path = None
+ ):
+ LoadPretrainedBase.__init__(self)
+ AutoEncoderBase.__init__(
+ self,
+ downsampling_ratio=downsampling_ratio,
+ sample_rate=sample_rate,
+ latent_shape=(latent_dim, None)
+ )
+
+ self.latent_dim = latent_dim
+ self.io_channels = io_channels
+ self.in_channels = io_channels
+ self.out_channels = io_channels
+ self.min_length = self.downsampling_ratio
+
+ if in_channels is not None:
+ self.in_channels = in_channels
+
+ if out_channels is not None:
+ self.out_channels = out_channels
+
+ self.bottleneck = bottleneck
+ self.encoder = encoder
+ self.decoder = decoder
+ self.pretransform = pretransform
+ self.soft_clip = soft_clip
+ self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
+
+ self.remove_autoencoder_prefix_fn: Callable = remove_key_prefix_factory(
+ "autoencoder."
+ )
+ if pretrained_ckpt is not None:
+ self.load_pretrained(pretrained_ckpt)
+
+ def process_state_dict(self, model_dict, state_dict):
+ state_dict = state_dict["state_dict"]
+ state_dict = self.remove_autoencoder_prefix_fn(model_dict, state_dict)
+ return state_dict
+
+ def encode(
+ self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ z = self.encoder(waveform)
+ z = self.bottleneck.encode(z)
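+        # One latent frame covers downsampling_ratio waveform samples.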
+ z_length = waveform_lengths // self.downsampling_ratio
+ z_mask = create_mask_from_length(z_length)
+ return z, z_mask
+
+ def decode(self, latents: torch.Tensor) -> torch.Tensor:
+ waveform = self.decoder(latents)
+ return waveform
+
+
+class StableVAEProjectorWrapper(nn.Module):
+ def __init__(
+ self,
+ vae_dim: int,
+ embed_dim: int,
+ model: StableVAE | None = None,
+ ):
+ super().__init__()
+ self.model = model
+ self.proj = nn.Linear(vae_dim, embed_dim)
+
+ def forward(
+ self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
+    ) -> dict[str, torch.Tensor]:
+ self.model.eval()
+ with torch.no_grad():
+ z, z_mask = self.model.encode(waveform, waveform_lengths)
+ z = self.proj(z.transpose(1, 2))
+ return {"output": z, "mask": z_mask}
+
+
+if __name__ == '__main__':
+ import hydra
+ from utils.config import generate_config_from_command_line_overrides
+ model_config = generate_config_from_command_line_overrides(
+ "configs/model/autoencoder/stable_vae.yaml"
+ )
+ autoencoder: StableVAE = hydra.utils.instantiate(model_config)
+ autoencoder.eval()
+
+ waveform, sr = torchaudio.load(
+ "/hpc_stor03/sjtu_home/xuenan.xu/data/m4singer/Tenor-1#童话/0006.wav"
+ )
+ waveform = waveform.mean(0, keepdim=True)
+ waveform = torchaudio.functional.resample(
+ waveform, sr, model_config["sample_rate"]
+ )
+ print("waveform: ", waveform.shape)
+ with torch.no_grad():
+        latent, latent_mask = autoencoder.encode(
+ waveform, torch.as_tensor([waveform.shape[-1]])
+ )
+ print("latent: ", latent.shape)
+ reconstructed = autoencoder.decode(latent)
+ print("reconstructed: ", reconstructed.shape)
+ import soundfile as sf
+ sf.write(
+ "./reconstructed.wav",
+ reconstructed[0, 0].numpy(),
+ samplerate=model_config["sample_rate"]
+ )
diff --git a/models/common.py b/models/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..79afb0be423dab991b5e7f14a74ba5503ce0c426
--- /dev/null
+++ b/models/common.py
@@ -0,0 +1,67 @@
+from pathlib import Path
+import torch
+import torch.nn as nn
+from utils.torch_utilities import load_pretrained_model, merge_matched_keys
+
+
+class LoadPretrainedBase(nn.Module):
+ def process_state_dict(
+ self, model_dict: dict[str, torch.Tensor],
+ state_dict: dict[str, torch.Tensor]
+ ):
+ """
+        Per-model hook that transforms a `state_dict` loaded from a checkpoint
+        into the form expected by `load_state_dict`.
+        By default, `merge_matched_keys` updates parameters whose names and
+        shapes match.
+
+        Args:
+            model_dict:
+                The state dict of the current model, which is going to receive
+                the pretrained parameters.
+            state_dict:
+                A dictionary of parameters from a pre-trained model.
+
+        Returns:
+            dict[str, torch.Tensor]:
+                The updated state dict, where parameters with matching keys and
+                shapes are updated with values from `state_dict`.
+ """
+ state_dict = merge_matched_keys(model_dict, state_dict)
+ return state_dict
+
+ def load_pretrained(self, ckpt_path: str | Path):
+ load_pretrained_model(
+ self, ckpt_path, state_dict_process_fn=self.process_state_dict
+ )
+
+
+class CountParamsBase(nn.Module):
+ def count_params(self):
+ num_params = 0
+ trainable_params = 0
+ for param in self.parameters():
+ num_params += param.numel()
+ if param.requires_grad:
+ trainable_params += param.numel()
+ return num_params, trainable_params
+
+
+class SaveTrainableParamsBase(nn.Module):
+ @property
+ def param_names_to_save(self):
+ names = []
+ for name, param in self.named_parameters():
+ if param.requires_grad:
+ names.append(name)
+ for name, _ in self.named_buffers():
+ names.append(name)
+ return names
+
+ def load_state_dict(self, state_dict, strict=True):
+ for key in self.param_names_to_save:
+ if key not in state_dict:
+ raise Exception(
+ f"{key} not found in either pre-trained models (e.g. BERT)"
+ " or resumed checkpoints (e.g. epoch_40/model.pt)"
+ )
+ return super().load_state_dict(state_dict, strict)
diff --git a/models/content_adapter.py b/models/content_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..083b2c9b639de703b8e63748a611509c379e513c
--- /dev/null
+++ b/models/content_adapter.py
@@ -0,0 +1,381 @@
+import math
+import torch
+import torch.nn as nn
+
+from utils.torch_utilities import concat_non_padding, restore_from_concat
+
+
+######################
+# fastspeech modules
+######################
+class LayerNorm(nn.LayerNorm):
+ """Layer normalization module.
+ :param int nout: output dim size
+ :param int dim: dimension to be normalized
+ """
+ def __init__(self, nout, dim=-1):
+        """Construct a LayerNorm object."""
+ super(LayerNorm, self).__init__(nout, eps=1e-12)
+ self.dim = dim
+
+ def forward(self, x):
+ """Apply layer normalization.
+ :param torch.Tensor x: input tensor
+ :return: layer normalized tensor
+        :rtype: torch.Tensor
+ """
+ if self.dim == -1:
+ return super(LayerNorm, self).forward(x)
+ return super(LayerNorm,
+ self).forward(x.transpose(1, -1)).transpose(1, -1)
+
+
+class DurationPredictor(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ filter_channels: int,
+ n_layers: int = 2,
+ kernel_size: int = 3,
+ p_dropout: float = 0.1,
+ padding: str = "SAME"
+ ):
+ super(DurationPredictor, self).__init__()
+ self.conv = nn.ModuleList()
+ self.kernel_size = kernel_size
+ self.padding = padding
+ for idx in range(n_layers):
+ in_chans = in_channels if idx == 0 else filter_channels
+ self.conv += [
+ nn.Sequential(
+ nn.ConstantPad1d(((kernel_size - 1) // 2,
+ (kernel_size - 1) //
+ 2) if padding == 'SAME' else
+ (kernel_size - 1, 0), 0),
+ nn.Conv1d(
+ in_chans,
+ filter_channels,
+ kernel_size,
+ stride=1,
+ padding=0
+ ), nn.ReLU(), LayerNorm(filter_channels, dim=1),
+ nn.Dropout(p_dropout)
+ )
+ ]
+ self.linear = nn.Linear(filter_channels, 1)
+
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
+ # x: [B, T, E]
+ x = x.transpose(1, -1)
+ x_mask = x_mask.unsqueeze(1).to(x.device)
+ for f in self.conv:
+ x = f(x)
+ x = x * x_mask.float()
+
+ x = self.linear(x.transpose(1, -1)
+ ) * x_mask.transpose(1, -1).float() # [B, T, 1]
+ return x
+
+
+######################
+# adapter modules
+######################
+
+
+class ContentAdapterBase(nn.Module):
+ def __init__(self, d_out):
+ super().__init__()
+ self.d_out = d_out
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+ def __init__(self, d_model, dropout, max_len=1000):
+ super().__init__()
+ self.dropout = nn.Dropout(dropout)
+ pe = torch.zeros(max_len, d_model)
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, d_model, 2).float() *
+ (-math.log(10000.0) / d_model)
+ )
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0).transpose(0, 1)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ x = x + self.pe[:x.size(1), :]
+ return self.dropout(x)
+
+
+class ContentAdapter(ContentAdapterBase):
+ def __init__(
+ self,
+ d_model: int,
+ d_out: int,
+ num_layers: int,
+ num_heads: int,
+ duration_predictor: DurationPredictor,
+ dropout: float = 0.1,
+ norm_first: bool = False,
+ activation: str = "gelu",
+ duration_grad_scale: float = 0.0,
+ ):
+ super().__init__(d_out)
+ self.duration_grad_scale = duration_grad_scale
+ self.cls_embed = nn.Parameter(torch.randn(d_model))
+ if hasattr(torch, "npu") and torch.npu.is_available():
+ enable_nested_tensor = False
+ else:
+ enable_nested_tensor = True
+ encoder_layer = nn.TransformerEncoderLayer(
+ d_model=d_model,
+ nhead=num_heads,
+ dim_feedforward=4 * d_model,
+ dropout=dropout,
+ activation=activation,
+ norm_first=norm_first,
+ batch_first=True
+ )
+ self.encoder_layers = nn.TransformerEncoder(
+ encoder_layer=encoder_layer,
+ num_layers=num_layers,
+ enable_nested_tensor=enable_nested_tensor
+ )
+ self.duration_predictor = duration_predictor
+ self.content_proj = nn.Conv1d(d_model, d_out, 1)
+
+ def forward(self, x, x_mask):
+ batch_size = x.size(0)
+ cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1)
+ cls_embed = cls_embed.to(x.device).unsqueeze(1)
+ x = torch.cat([cls_embed, x], dim=1)
+
+ cls_mask = torch.ones(batch_size, 1).to(x_mask.device)
+ x_mask = torch.cat([cls_mask, x_mask], dim=1)
+ x = self.encoder_layers(x, src_key_padding_mask=~x_mask.bool())
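+        # Blend x with its detached copy so that only duration_grad_scale of the
+        # duration predictor's gradient flows back into the encoder.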
+ x_grad_rescaled = x * self.duration_grad_scale + x.detach(
+ ) * (1 - self.duration_grad_scale)
+ duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1)
+ content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
+ return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:]
+
+
+class PrefixAdapter(ContentAdapterBase):
+ def __init__(
+ self,
+ content_dim: int,
+ d_model: int,
+ d_out: int,
+ prefix_dim: int,
+ num_layers: int,
+ num_heads: int,
+ duration_predictor: DurationPredictor,
+ dropout: float = 0.1,
+ norm_first: bool = False,
+ use_last_norm: bool = True,
+ activation: str = "gelu",
+ duration_grad_scale: float = 0.1,
+ ):
+ super().__init__(d_out)
+ self.duration_grad_scale = duration_grad_scale
+ self.prefix_mlp = nn.Sequential(
+ nn.Linear(prefix_dim, d_model), nn.ReLU(), nn.Dropout(dropout),
+ nn.Linear(d_model, d_model)
+ )
+ self.content_mlp = nn.Sequential(
+ nn.Linear(content_dim, d_model), nn.ReLU(), nn.Dropout(dropout),
+ nn.Linear(d_model, d_model)
+ )
+ layer = nn.TransformerEncoderLayer(
+ d_model=d_model,
+ nhead=num_heads,
+ dim_feedforward=4 * d_model,
+ dropout=dropout,
+ activation=activation,
+ batch_first=True,
+ norm_first=norm_first
+ )
+ if hasattr(torch, "npu") and torch.npu.is_available():
+ enable_nested_tensor = False
+ else:
+ enable_nested_tensor = True
+ self.cls_embed = nn.Parameter(torch.randn(d_model))
+ # self.pos_embed = SinusoidalPositionalEmbedding(d_model, dropout)
+ self.layers = nn.TransformerEncoder(
+ encoder_layer=layer,
+ num_layers=num_layers,
+ enable_nested_tensor=enable_nested_tensor
+ )
+ self.use_last_norm = use_last_norm
+ if self.use_last_norm:
+ self.last_norm = nn.LayerNorm(d_model)
+ self.duration_predictor = duration_predictor
+ self.content_proj = nn.Conv1d(d_model, d_out, 1)
+ nn.init.normal_(self.cls_embed, 0., 0.02)
+ nn.init.xavier_uniform_(self.content_proj.weight)
+ nn.init.constant_(self.content_proj.bias, 0.)
+
+ def forward(self, content, content_mask, instruction, instruction_mask):
+ batch_size = content.size(0)
+ cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1)
+ cls_embed = cls_embed.to(content.device).unsqueeze(1)
+ content = self.content_mlp(content)
+ x = torch.cat([cls_embed, content], dim=1)
+ cls_mask = torch.ones(batch_size, 1,
+ dtype=bool).to(content_mask.device)
+ x_mask = torch.cat([cls_mask, content_mask], dim=1)
+
+ prefix = self.prefix_mlp(instruction)
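+        # Pack prefix and content tokens into one sequence with no interior padding;
+        # perm records the permutation needed to restore the original layout.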
+ seq, seq_mask, perm = concat_non_padding(
+ prefix, instruction_mask, x, x_mask
+ )
+ # seq = self.pos_embed(seq)
+ x = self.layers(seq, src_key_padding_mask=~seq_mask.bool())
+ if self.use_last_norm:
+ x = self.last_norm(x)
+ _, x = restore_from_concat(x, instruction_mask, x_mask, perm)
+
+ x_grad_rescaled = x * self.duration_grad_scale + x.detach(
+ ) * (1 - self.duration_grad_scale)
+ duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1)
+ content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
+ return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:]
+
+
+class CrossAttentionAdapter(ContentAdapterBase):
+ def __init__(
+ self,
+ d_out: int,
+ content_dim: int,
+ prefix_dim: int,
+ num_heads: int,
+ duration_predictor: DurationPredictor,
+ dropout: float = 0.1,
+ duration_grad_scale: float = 0.1,
+ ):
+ super().__init__(d_out)
+ self.attn = nn.MultiheadAttention(
+ embed_dim=content_dim,
+ num_heads=num_heads,
+ dropout=dropout,
+ kdim=prefix_dim,
+ vdim=prefix_dim,
+ batch_first=True,
+ )
+ self.duration_grad_scale = duration_grad_scale
+ self.duration_predictor = duration_predictor
+ self.global_duration_mlp = nn.Sequential(
+ nn.Linear(content_dim, content_dim), nn.ReLU(),
+ nn.Dropout(dropout), nn.Linear(content_dim, 1)
+ )
+ self.norm = nn.LayerNorm(content_dim)
+ self.content_proj = nn.Conv1d(content_dim, d_out, 1)
+
+ def forward(self, content, content_mask, prefix, prefix_mask):
+ attn_output, attn_output_weights = self.attn(
+ query=content,
+ key=prefix,
+ value=prefix,
+ key_padding_mask=~prefix_mask.bool()
+ )
+ attn_output = attn_output * content_mask.unsqueeze(-1).float()
+ x = self.norm(attn_output + content)
+ x_grad_rescaled = x * self.duration_grad_scale + x.detach(
+ ) * (1 - self.duration_grad_scale)
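+        # Masked mean pooling over time gives a global vector for predicting the overall duration.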
+ x_aggregated = (x_grad_rescaled * content_mask.unsqueeze(-1).float()
+ ).sum(dim=1) / content_mask.sum(dim=1,
+ keepdim=True).float()
+ global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1)
+ local_duration = self.duration_predictor(
+ x_grad_rescaled, content_mask
+ ).squeeze(-1)
+ content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
+ return content, content_mask, global_duration, local_duration
+
+
+class ExperimentalCrossAttentionAdapter(ContentAdapterBase):
+ def __init__(
+ self,
+ d_out: int,
+ content_dim: int,
+ prefix_dim: int,
+ num_heads: int,
+ duration_predictor: DurationPredictor,
+ dropout: float = 0.1,
+ duration_grad_scale: float = 0.1,
+ ):
+ super().__init__(d_out)
+ self.content_mlp = nn.Sequential(
+ nn.Linear(content_dim, content_dim),
+ nn.ReLU(),
+ nn.Dropout(dropout),
+ nn.Linear(content_dim, content_dim),
+ )
+ self.content_norm = nn.LayerNorm(content_dim)
+ self.prefix_mlp = nn.Sequential(
+ nn.Linear(prefix_dim, prefix_dim),
+ nn.ReLU(),
+ nn.Dropout(dropout),
+ nn.Linear(prefix_dim, prefix_dim),
+ )
+ self.prefix_norm = nn.LayerNorm(content_dim)
+ self.attn = nn.MultiheadAttention(
+ embed_dim=content_dim,
+ num_heads=num_heads,
+ dropout=dropout,
+ kdim=prefix_dim,
+ vdim=prefix_dim,
+ batch_first=True,
+ )
+ self.duration_grad_scale = duration_grad_scale
+ self.duration_predictor = duration_predictor
+ self.global_duration_mlp = nn.Sequential(
+ nn.Linear(content_dim, content_dim), nn.ReLU(),
+ nn.Dropout(dropout), nn.Linear(content_dim, 1)
+ )
+ self.content_proj = nn.Sequential(
+ nn.Linear(content_dim, d_out),
+ nn.ReLU(),
+ nn.Dropout(dropout),
+ nn.Linear(d_out, d_out),
+ )
+ self.norm1 = nn.LayerNorm(content_dim)
+ self.norm2 = nn.LayerNorm(d_out)
+ self.init_weights()
+
+ def init_weights(self):
+ def _init_weights(module):
+ if isinstance(module, nn.Linear):
+ nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0.)
+
+ self.apply(_init_weights)
+
+ def forward(self, content, content_mask, prefix, prefix_mask):
+ content = self.content_mlp(content)
+ content = self.content_norm(content)
+ prefix = self.prefix_mlp(prefix)
+ prefix = self.prefix_norm(prefix)
+ attn_output, attn_weights = self.attn(
+ query=content,
+ key=prefix,
+ value=prefix,
+ key_padding_mask=~prefix_mask.bool(),
+ )
+ attn_output = attn_output * content_mask.unsqueeze(-1).float()
+ x = attn_output + content
+ x = self.norm1(x)
+ x_grad_rescaled = x * self.duration_grad_scale + x.detach(
+ ) * (1 - self.duration_grad_scale)
+ x_aggregated = (x_grad_rescaled * content_mask.unsqueeze(-1).float()
+ ).sum(dim=1) / content_mask.sum(dim=1,
+ keepdim=True).float()
+ global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1)
+ local_duration = self.duration_predictor(
+ x_grad_rescaled, content_mask
+ ).squeeze(-1)
+ content = self.content_proj(x)
+ content = self.norm2(content)
+ return content, content_mask, global_duration, local_duration
diff --git a/models/content_encoder/__pycache__/content_encoder.cpython-310.pyc b/models/content_encoder/__pycache__/content_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ace4e43d9626c100e9e1fcf068cd262f828629d2
Binary files /dev/null and b/models/content_encoder/__pycache__/content_encoder.cpython-310.pyc differ
diff --git a/models/content_encoder/__pycache__/sketch_encoder.cpython-310.pyc b/models/content_encoder/__pycache__/sketch_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ac8dbd6065266f00836ba293f607690a9755600
Binary files /dev/null and b/models/content_encoder/__pycache__/sketch_encoder.cpython-310.pyc differ
diff --git a/models/content_encoder/__pycache__/text_encoder.cpython-310.pyc b/models/content_encoder/__pycache__/text_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..991aec058fcc39d780b17a5372490355eb49e67f
Binary files /dev/null and b/models/content_encoder/__pycache__/text_encoder.cpython-310.pyc differ
diff --git a/models/content_encoder/content_encoder.py b/models/content_encoder/content_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd881d8012a9c84e8c7d4f82fa10b37ebec0a588
--- /dev/null
+++ b/models/content_encoder/content_encoder.py
@@ -0,0 +1,280 @@
+from typing import Any
+import torch
+import torch.nn as nn
+
+
+class ContentEncoder(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ text_encoder: nn.Module = None,
+ video_encoder: nn.Module = None,
+ midi_encoder: nn.Module = None,
+ phoneme_encoder: nn.Module = None,
+ pitch_encoder: nn.Module = None,
+ audio_encoder: nn.Module = None,
+ speech_encoder: nn.Module = None,
+ sketch_encoder: nn.Module = None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.text_encoder = text_encoder
+ self.midi_encoder = midi_encoder
+ self.phoneme_encoder = phoneme_encoder
+ self.pitch_encoder = pitch_encoder
+ self.audio_encoder = audio_encoder
+ self.video_encoder = video_encoder
+ self.speech_encoder = speech_encoder
+ self.sketch_encoder = sketch_encoder
+
+ def encode_content(
+ self, batch_content: list[Any], batch_task: list[str],
+ device: str | torch.device
+ ):
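+        # Route each (content, task) pair to the matching modality encoder, then
+        # pad the variable-length outputs into batched tensors.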
+ batch_content_output = []
+ batch_content_mask = []
+ batch_la_content_output = []
+
+ zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)
+
+ for content, task in zip(batch_content, batch_task):
+ if task == "audio_super_resolution" or task == "speech_enhancement":
+ content_dict = {
+ "waveform": torch.as_tensor(content).float(),
+ "waveform_lengths": torch.as_tensor(content.shape[0]),
+ }
+ for key in list(content_dict.keys()):
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
+ device
+ )
+ content_output_dict = self.audio_encoder(**content_dict)
+ la_content_output_dict = {
+ "output": zero_la_content,
+ }
+ elif task == "text_to_audio" or task == "text_to_music":
+ content_output_dict = self.text_encoder([content])
+ la_content_output_dict = {
+ "output": zero_la_content,
+ }
+ elif task == "speech_to_audio":
+ input_dict = {
+ "embed": content,
+ "embed_len": torch.tensor([content.shape[1]], dtype=torch.int).to(device),
+ }
+ content_output_dict = self.speech_encoder(input_dict)
+ la_content_output_dict = {
+ "output": zero_la_content,
+ }
+ elif task == "direct_speech_to_audio":
+            # HuBERT features: content has shape [1, T, dim]; mask has shape [1, T]
+ if len(content.shape) < 3:
+ content = content.unsqueeze(0)
+            mask = torch.ones(content.shape[:2], dtype=torch.bool, device=content.device)
+ content_output_dict = {
+ "output": content,
+ "mask": mask,
+ }
+ la_content_output_dict = {
+ "output": zero_la_content,
+ }
+ elif task == "sketch_to_audio":
+ content_output_dict = self.sketch_encoder([content["caption"]])
+ content_dict = {
+ "f0": torch.as_tensor(content["f0"]),
+ "energy": torch.as_tensor(content["energy"]),
+ }
+ for key in list(content_dict.keys()):
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
+ device
+ )
+ la_content_output_dict = self.sketch_encoder.encode_sketch(
+ **content_dict
+ )
+ elif task == "video_to_audio":
+ content_dict = {
+ "frames": torch.as_tensor(content).float(),
+ "frame_nums": torch.as_tensor(content.shape[0]),
+ }
+ for key in list(content_dict.keys()):
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
+ device
+ )
+ content_output_dict = self.video_encoder(**content_dict)
+ la_content_output_dict = {
+ "output": zero_la_content,
+ }
+ elif task == "singing_voice_synthesis":
+ content_dict = {
+ "phoneme":
+ torch.as_tensor(content["phoneme"]).long(),
+ "midi":
+ torch.as_tensor(content["midi"]).long(),
+ "midi_duration":
+ torch.as_tensor(content["midi_duration"]).float(),
+ "is_slur":
+ torch.as_tensor(content["is_slur"]).long()
+ }
+ if "spk" in content:
+ if self.midi_encoder.spk_config.encoding_format == "id":
+ content_dict["spk"] = torch.as_tensor(content["spk"]
+ ).long()
+ elif self.midi_encoder.spk_config.encoding_format == "embedding":
+ content_dict["spk"] = torch.as_tensor(content["spk"]
+ ).float()
+ for key in list(content_dict.keys()):
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
+ device
+ )
+ content_dict["lengths"] = torch.as_tensor([
+ len(content["phoneme"])
+ ])
+ content_output_dict = self.midi_encoder(**content_dict)
+ la_content_output_dict = {"output": zero_la_content}
+ elif task == "text_to_speech":
+ content_dict = {
+ "phoneme": torch.as_tensor(content["phoneme"]).long(),
+ }
+ if "spk" in content:
+ if self.phoneme_encoder.spk_config.encoding_format == "id":
+ content_dict["spk"] = torch.as_tensor(content["spk"]
+ ).long()
+ elif self.phoneme_encoder.spk_config.encoding_format == "embedding":
+ content_dict["spk"] = torch.as_tensor(content["spk"]
+ ).float()
+ for key in list(content_dict.keys()):
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
+ device
+ )
+ content_dict["lengths"] = torch.as_tensor([
+ len(content["phoneme"])
+ ])
+ content_output_dict = self.phoneme_encoder(**content_dict)
+ la_content_output_dict = {"output": zero_la_content}
+ elif task == "singing_acoustic_modeling":
+ content_dict = {
+ "phoneme": torch.as_tensor(content["phoneme"]).long(),
+ }
+ for key in list(content_dict.keys()):
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
+ device
+ )
+ content_dict["lengths"] = torch.as_tensor([
+ len(content["phoneme"])
+ ])
+ content_output_dict = self.pitch_encoder(**content_dict)
+
+ content_dict = {
+ "f0": torch.as_tensor(content["f0"]),
+ "uv": torch.as_tensor(content["uv"]),
+ }
+ for key in list(content_dict.keys()):
+ content_dict[key] = content_dict[key].unsqueeze(0).to(
+ device
+ )
+ la_content_output_dict = self.pitch_encoder.encode_pitch(
+ **content_dict
+ )
+
+ batch_content_output.append(content_output_dict["output"][0])
+ batch_content_mask.append(content_output_dict["mask"][0])
+ batch_la_content_output.append(la_content_output_dict["output"][0])
+
+ batch_content_output = nn.utils.rnn.pad_sequence(
+ batch_content_output, batch_first=True, padding_value=0
+ )
+ batch_content_mask = nn.utils.rnn.pad_sequence(
+ batch_content_mask, batch_first=True, padding_value=False
+ )
+ batch_la_content_output = nn.utils.rnn.pad_sequence(
+ batch_la_content_output, batch_first=True, padding_value=0
+ )
+ return {
+ "content": batch_content_output,
+ "content_mask": batch_content_mask,
+ "length_aligned_content": batch_la_content_output,
+ }
+
+
+class BatchedContentEncoder(ContentEncoder):
+ def encode_content(
+ self, batch_content: list | dict, batch_task: list[str],
+ device: str | torch.device
+ ):
+ task = batch_task[0]
+ zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)
+ if task == "audio_super_resolution" or task == "speech_enhancement":
+ content_dict = {
+ "waveform":
+ batch_content["content"].unsqueeze(1).float().to(device),
+ "waveform_lengths":
+ batch_content["content_lengths"].long().to(device),
+ }
+ content_output = self.audio_encoder(**content_dict)
+ la_content_output = zero_la_content
+ elif task == "text_to_audio":
+ content_output = self.text_encoder(batch_content)
+ la_content_output = zero_la_content
+ elif task == "video_to_audio":
+ content_dict = {
+ "frames":
+ batch_content["content"].float().to(device),
+ "frame_nums":
+ batch_content["content_lengths"].long().to(device),
+ }
+ content_output = self.video_encoder(**content_dict)
+ la_content_output = zero_la_content
+ elif task == "singing_voice_synthesis":
+ content_dict = {
+ "phoneme":
+ batch_content["phoneme"].long().to(device),
+ "midi":
+ batch_content["midi"].long().to(device),
+ "midi_duration":
+ batch_content["midi_duration"].float().to(device),
+ "is_slur":
+ batch_content["is_slur"].long().to(device),
+ "lengths":
+ batch_content["phoneme_lengths"].long().cpu(),
+ }
+ if "spk" in batch_content:
+ if self.midi_encoder.spk_config.encoding_format == "id":
+ content_dict["spk"] = batch_content["spk"].long(
+ ).to(device)
+ elif self.midi_encoder.spk_config.encoding_format == "embedding":
+ content_dict["spk"] = batch_content["spk"].float(
+ ).to(device)
+ content_output = self.midi_encoder(**content_dict)
+ la_content_output = zero_la_content
+ elif task == "text_to_speech":
+ content_dict = {
+ "phoneme": batch_content["phoneme"].long().to(device),
+ "lengths": batch_content["phoneme_lengths"].long().cpu(),
+ }
+ if "spk" in batch_content:
+ if self.phoneme_encoder.spk_config.encoding_format == "id":
+ content_dict["spk"] = batch_content["spk"].long(
+ ).to(device)
+ elif self.phoneme_encoder.spk_config.encoding_format == "embedding":
+ content_dict["spk"] = batch_content["spk"].float(
+ ).to(device)
+ content_output = self.phoneme_encoder(**content_dict)
+ la_content_output = zero_la_content
+ elif task == "singing_acoustic_modeling":
+ content_dict = {
+ "phoneme": batch_content["phoneme"].long().to(device),
+ "lengths": batch_content["phoneme_lengths"].long().to(device),
+ }
+ content_output = self.pitch_encoder(**content_dict)
+
+ content_dict = {
+ "f0": batch_content["f0"].float().to(device),
+ "uv": batch_content["uv"].float().to(device),
+ }
+ la_content_output = self.pitch_encoder.encode_pitch(**content_dict)
+
+ return {
+ "content": content_output["output"],
+ "content_mask": content_output["mask"],
+ "length_aligned_content": la_content_output,
+ }
diff --git a/models/content_encoder/midi_encoder.py b/models/content_encoder/midi_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..fda372db829831f44dd41a2ee50197e68f4aa618
--- /dev/null
+++ b/models/content_encoder/midi_encoder.py
@@ -0,0 +1,1046 @@
+from typing import Sequence
+from dataclasses import dataclass
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Parameter
+
+from utils.torch_utilities import create_mask_from_length
+from utils.diffsinger_utilities import denorm_f0, f0_to_coarse
+
+
+def make_positions(tensor, padding_idx):
+ """Replace non-padding symbols with their position numbers.
+ Position numbers begin at padding_idx+1. Padding symbols are ignored.
+ """
+ # The series of casts and type-conversions here are carefully
+ # balanced to both work with ONNX export and XLA. In particular XLA
+ # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
+ # how to handle the dtype kwarg in cumsum.
+ mask = tensor.ne(padding_idx).int()
+ return (torch.cumsum(mask, dim=1).type_as(mask) *
+ mask).long() + padding_idx
+
+
+def softmax(x, dim):
+ return F.softmax(x, dim=dim, dtype=torch.float32)
+
+
+def LayerNorm(
+ normalized_shape, eps=1e-5, elementwise_affine=True, export=False
+):
+ if not export and torch.cuda.is_available():
+ try:
+ from apex.normalization import FusedLayerNorm
+ return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
+ except ImportError:
+ pass
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
+
+
+def Linear(in_features, out_features, bias=True):
+ m = nn.Linear(in_features, out_features, bias)
+ nn.init.xavier_uniform_(m.weight)
+ if bias:
+ nn.init.constant_(m.bias, 0.)
+ return m
+
+
+def Embedding(num_embeddings, embedding_dim, padding_idx=None):
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
+ if padding_idx is not None:
+ nn.init.constant_(m.weight[padding_idx], 0)
+ return m
+
+
+class BatchNorm1dTBC(nn.Module):
+ def __init__(self, c):
+ super(BatchNorm1dTBC, self).__init__()
+ self.bn = nn.BatchNorm1d(c)
+
+ def forward(self, x):
+ """
+
+ :param x: [T, B, C]
+ :return: [T, B, C]
+ """
+ x = x.permute(1, 2, 0) # [B, C, T]
+ x = self.bn(x) # [B, C, T]
+ x = x.permute(2, 0, 1) # [T, B, C]
+ return x
+
+
+class PositionalEncoding(nn.Module):
+ """Positional encoding.
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ reverse (bool): Whether to reverse the input position.
+ """
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
+        """Construct a PositionalEncoding object."""
+ super(PositionalEncoding, self).__init__()
+ self.d_model = d_model
+ self.reverse = reverse
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+ def extend_pe(self, x):
+ """Reset the positional encodings."""
+ if self.pe is not None:
+ if self.pe.size(1) >= x.size(1):
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ pe = torch.zeros(x.size(1), self.d_model)
+ if self.reverse:
+ position = torch.arange(
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
+ ).unsqueeze(1)
+ else:
+ position = torch.arange(0, x.size(1),
+ dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
+ -(math.log(10000.0) / self.d_model)
+ )
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self, x: torch.Tensor):
+ """Add positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ """
+ self.extend_pe(x)
+ x = x * self.xscale + self.pe[:, :x.size(1)]
+ return self.dropout(x)
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length.
+
+ Padding symbols are ignored.
+ """
+ def __init__(self, d_model, padding_idx, init_size=2048):
+ super().__init__()
+ self.d_model = d_model
+ self.padding_idx = padding_idx
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ init_size,
+ d_model,
+ padding_idx,
+ )
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+ @staticmethod
+ def get_embedding(num_embeddings, d_model, padding_idx=None):
+ """Build sinusoidal embeddings.
+
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ half_dim = d_model // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+ emb = torch.arange(num_embeddings,
+ dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)],
+ dim=1).view(num_embeddings, -1)
+ if d_model % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ if padding_idx is not None:
+ emb[padding_idx, :] = 0
+ return emb
+
+ def forward(
+ self,
+ x,
+ lengths,
+ incremental_state=None,
+ timestep=None,
+ positions=None,
+ **kwargs
+ ):
+ """Input is expected to be of size [bsz x seqlen]."""
+ bsz, seq_len = x.shape[:2]
+ max_pos = self.padding_idx + 1 + seq_len
+ if self.weights is None or max_pos > self.weights.size(0):
+ # recompute/expand embeddings if needed
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ max_pos,
+ self.d_model,
+ self.padding_idx,
+ )
+ self.weights = self.weights.to(self._float_tensor)
+
+ if incremental_state is not None:
+ # positions is the same for every token when decoding a single step
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
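+        # Positions are 1-based for valid tokens and 0 for padded ones.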
+ positions = create_mask_from_length(
+ lengths, max_length=x.shape[1]
+ ) * (torch.arange(x.shape[1]) + 1).unsqueeze(0).expand(x.shape[0], -1)
+ positions = positions.to(self.weights.device)
+ pos_emb = self.weights.index_select(0, positions.view(-1)).view(
+ bsz, seq_len, -1
+ ).detach()
+ return x + pos_emb
+
+ def max_positions(self):
+ """Maximum number of supported positions."""
+ return int(1e5) # an arbitrary large number
+
+
+class RelPositionalEncoding(PositionalEncoding):
+ """Relative positional encoding module.
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ """
+ def __init__(self, d_model, dropout_rate, max_len=5000):
+ """Initialize class."""
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
+
+ def forward(self, x, lengths):
+ """Compute positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
+ """
+ self.extend_pe(x)
+ x = x * self.xscale
+ pos_emb = self.pe[:, :x.size(1)]
+ return self.dropout(x) + self.dropout(pos_emb)
+
+
+class MultiheadAttention(nn.Module):
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ kdim=None,
+ vdim=None,
+ dropout=0.,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False,
+ self_attention=False,
+ encoder_decoder_attention=False
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim**-0.5
+
+ self.self_attention = self_attention
+ self.encoder_decoder_attention = encoder_decoder_attention
+
+ assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
+ 'value to be of the same size'
+
+ if self.qkv_same_dim:
+ self.in_proj_weight = Parameter(
+ torch.Tensor(3 * embed_dim, embed_dim)
+ )
+ else:
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+
+ if bias:
+ self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
+ else:
+ self.register_parameter('in_proj_bias', None)
+
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+
+ self.reset_parameters()
+
+ self.enable_torch_version = False
+ if hasattr(F, "multi_head_attention_forward"):
+ self.enable_torch_version = True
+ else:
+ self.enable_torch_version = False
+ self.last_attn_probs = None
+
+ def reset_parameters(self):
+ if self.qkv_same_dim:
+ nn.init.xavier_uniform_(self.in_proj_weight)
+ else:
+ nn.init.xavier_uniform_(self.k_proj_weight)
+ nn.init.xavier_uniform_(self.v_proj_weight)
+ nn.init.xavier_uniform_(self.q_proj_weight)
+
+ nn.init.xavier_uniform_(self.out_proj.weight)
+ if self.in_proj_bias is not None:
+ nn.init.constant_(self.in_proj_bias, 0.)
+ nn.init.constant_(self.out_proj.bias, 0.)
+ if self.bias_k is not None:
+ nn.init.xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ nn.init.xavier_normal_(self.bias_v)
+
+ def forward(
+ self,
+ query,
+ key,
+ value,
+ key_padding_mask=None,
+ incremental_state=None,
+ need_weights=True,
+ static_kv=False,
+ attn_mask=None,
+ before_softmax=False,
+ need_head_weights=False,
+ enc_dec_attn_constraint_mask=None,
+ reset_attn_weight=None
+ ):
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+            need_weights (bool, optional): return the attention weights,
+                averaged over heads (default: True).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+ if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
+ if self.qkv_same_dim:
+ return F.multi_head_attention_forward(
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias, self.bias_k,
+ self.bias_v, self.add_zero_attn, self.dropout,
+ self.out_proj.weight, self.out_proj.bias, self.training,
+ key_padding_mask, need_weights, attn_mask
+ )
+ else:
+ return F.multi_head_attention_forward(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ torch.empty([0]),
+ self.in_proj_bias,
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ self.training,
+ key_padding_mask,
+ need_weights,
+ attn_mask,
+ use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj_weight,
+ k_proj_weight=self.k_proj_weight,
+ v_proj_weight=self.v_proj_weight
+ )
+
+ if incremental_state is not None:
+            raise NotImplementedError("incremental decoding is not supported")
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ # self-attention
+ q, k, v = self.in_proj_qkv(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.in_proj_q(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(key)
+
+ else:
+ q = self.in_proj_q(query)
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(value)
+ q *= self.scaling
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat(
+ [attn_mask,
+ attn_mask.new_zeros(attn_mask.size(0), 1)],
+ dim=1
+ )
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [
+ key_padding_mask,
+ key_padding_mask.new_zeros(
+ key_padding_mask.size(0), 1
+ )
+ ],
+ dim=1
+ )
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads,
+ self.head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * self.num_heads,
+ self.head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * self.num_heads,
+ self.head_dim).transpose(0, 1)
+
+ if saved_state is not None:
+            raise NotImplementedError("cached (incremental) states are not supported")
+
+ src_len = k.size(1)
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size(
+ []
+ ):
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ src_len += 1
+ k = torch.cat(
+ [k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1
+ )
+ v = torch.cat(
+ [v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1
+ )
+ if attn_mask is not None:
+ attn_mask = torch.cat(
+ [attn_mask,
+ attn_mask.new_zeros(attn_mask.size(0), 1)],
+ dim=1
+ )
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [
+ key_padding_mask,
+ torch.zeros(key_padding_mask.size(0),
+ 1).type_as(key_padding_mask)
+ ],
+ dim=1
+ )
+
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ attn_weights = self.apply_sparse_mask(
+ attn_weights, tgt_len, src_len, bsz
+ )
+
+ assert list(attn_weights.size()) == [
+ bsz * self.num_heads, tgt_len, src_len
+ ]
+
+ if attn_mask is not None:
+ if len(attn_mask.shape) == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ elif len(attn_mask.shape) == 3:
+ attn_mask = attn_mask[:, None].repeat(
+ [1, self.num_heads, 1, 1]
+ ).reshape(bsz * self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights + attn_mask
+
+ if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
+ attn_weights = attn_weights.view(
+ bsz, self.num_heads, tgt_len, src_len
+ )
+ attn_weights = attn_weights.masked_fill(
+ enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
+ -1e9,
+ )
+ attn_weights = attn_weights.view(
+ bsz * self.num_heads, tgt_len, src_len
+ )
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(
+ bsz, self.num_heads, tgt_len, src_len
+ )
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ -1e9,
+ )
+ attn_weights = attn_weights.view(
+ bsz * self.num_heads, tgt_len, src_len
+ )
+
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v
+
+ attn_weights_float = softmax(attn_weights, dim=-1)
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = F.dropout(
+ attn_weights_float.type_as(attn_weights),
+ p=self.dropout,
+ training=self.training
+ )
+
+ if reset_attn_weight is not None:
+ if reset_attn_weight:
+ self.last_attn_probs = attn_probs.detach()
+ else:
+ assert self.last_attn_probs is not None
+ attn_probs = self.last_attn_probs
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [
+ bsz * self.num_heads, tgt_len, self.head_dim
+ ]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+
+ if need_weights:
+ attn_weights = attn_weights_float.view(
+ bsz, self.num_heads, tgt_len, src_len
+ ).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+ else:
+ attn_weights = None
+
+ return attn, (attn_weights, attn_logits)
+
+ def in_proj_qkv(self, query):
+ return self._in_proj(query).chunk(3, dim=-1)
+
+ def in_proj_q(self, query):
+ if self.qkv_same_dim:
+ return self._in_proj(query, end=self.embed_dim)
+ else:
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[:self.embed_dim]
+ return F.linear(query, self.q_proj_weight, bias)
+
+ def in_proj_k(self, key):
+ if self.qkv_same_dim:
+ return self._in_proj(
+ key, start=self.embed_dim, end=2 * self.embed_dim
+ )
+ else:
+ weight = self.k_proj_weight
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[self.embed_dim:2 * self.embed_dim]
+ return F.linear(key, weight, bias)
+
+ def in_proj_v(self, value):
+ if self.qkv_same_dim:
+ return self._in_proj(value, start=2 * self.embed_dim)
+ else:
+ weight = self.v_proj_weight
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[2 * self.embed_dim:]
+ return F.linear(value, weight, bias)
+
+ def _in_proj(self, input, start=0, end=None):
+ weight = self.in_proj_weight
+ bias = self.in_proj_bias
+ weight = weight[start:end, :]
+ if bias is not None:
+ bias = bias[start:end]
+ return F.linear(input, weight, bias)
+
+ def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
+ return attn_weights
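+
+# Usage sketch (illustrative only, not executed): inputs are time-first
+# (T, B, C). The fast torch path returns (attn, averaged_weights); the slow
+# path returns (attn, (attn_weights, attn_logits)).
+#   mha = MultiheadAttention(embed_dim=256, num_heads=4, self_attention=True)
+#   x = torch.randn(50, 8, 256)
+#   attn, _ = mha(query=x, key=x, value=x)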
+
+
+class TransformerFFNLayer(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ filter_size,
+ padding="SAME",
+ kernel_size=1,
+ dropout=0.,
+ act='gelu'
+ ):
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.dropout = dropout
+ self.act = act
+ if padding == 'SAME':
+ self.ffn_1 = nn.Conv1d(
+ hidden_size,
+ filter_size,
+ kernel_size,
+ padding=kernel_size // 2
+ )
+ elif padding == 'LEFT':
+ self.ffn_1 = nn.Sequential(
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
+ nn.Conv1d(hidden_size, filter_size, kernel_size)
+ )
+ self.ffn_2 = nn.Linear(filter_size, hidden_size)
+
+ def forward(
+ self,
+ x,
+ ):
+ # x: T x B x C
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
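+        # rescale by kernel_size**-0.5 so the conv output magnitude stays
+        # comparable across kernel sizes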
+ x = x * self.kernel_size**-0.5
+
+        if self.act == 'gelu':
+            x = F.gelu(x)
+        elif self.act == 'relu':
+            x = F.relu(x)
+        elif self.act == 'swish':
+            x = F.silu(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = self.ffn_2(x)
+ return x
+
+
+class EncoderSelfAttentionLayer(nn.Module):
+ def __init__(
+ self,
+ c,
+ num_heads,
+ dropout,
+ attention_dropout=0.1,
+ relu_dropout=0.1,
+ kernel_size=9,
+ padding='SAME',
+ norm='ln',
+ act='gelu',
+ padding_set_zero=True
+ ):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.padding_set_zero = padding_set_zero
+ if num_heads > 0:
+ if norm == 'ln':
+ self.layer_norm1 = LayerNorm(c)
+ elif norm == 'bn':
+ self.layer_norm1 = BatchNorm1dTBC(c)
+ self.self_attn = MultiheadAttention(
+ self.c,
+ num_heads=num_heads,
+ self_attention=True,
+ dropout=attention_dropout,
+ bias=False,
+ )
+ if norm == 'ln':
+ self.layer_norm2 = LayerNorm(c)
+ elif norm == 'bn':
+ self.layer_norm2 = BatchNorm1dTBC(c)
+ self.ffn = TransformerFFNLayer(
+ c,
+ 4 * c,
+ kernel_size=kernel_size,
+ dropout=relu_dropout,
+ padding=padding,
+ act=act
+ )
+
+ def forward(self, x, encoder_padding_mask=None, **kwargs):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ if self.num_heads > 0:
+ residual = x
+ x = self.layer_norm1(x)
+            x, _ = self.self_attn(
+ query=x, key=x, value=x, key_padding_mask=encoder_padding_mask
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+            if self.padding_set_zero:
+                nonpadding = (1 - encoder_padding_mask.float()).transpose(0, 1)
+                x = x * nonpadding[..., None]
+
+ residual = x
+ x = self.layer_norm2(x)
+ x = self.ffn(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+        if self.padding_set_zero:
+            nonpadding = (1 - encoder_padding_mask.float()).transpose(0, 1)
+            x = x * nonpadding[..., None]
+ return x
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ dropout,
+ kernel_size,
+ num_heads=2,
+ norm='ln',
+ padding_set_zero=True,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.op = EncoderSelfAttentionLayer(
+ hidden_size,
+ num_heads,
+ dropout=dropout,
+ attention_dropout=0.0,
+ relu_dropout=dropout,
+ kernel_size=kernel_size,
+ padding="SAME",
+ norm=norm,
+ act="gelu",
+ padding_set_zero=padding_set_zero
+ )
+
+ def forward(self, x, **kwargs):
+ return self.op(x, **kwargs)
+
+
+class FFTBlocks(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_layers,
+ ffn_kernel_size=9,
+ dropout=0.1,
+ num_heads=2,
+ use_last_norm=True,
+ padding_set_zero=True,
+ ):
+ super().__init__()
+ self.num_layers = num_layers
+ embed_dim = self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.use_last_norm = use_last_norm
+ self.padding_set_zero = padding_set_zero
+
+ self.layers = nn.ModuleList([])
+ self.layers.extend(
+ [
+ TransformerEncoderLayer(
+ self.hidden_size,
+ self.dropout,
+ kernel_size=ffn_kernel_size,
+ num_heads=num_heads,
+ padding_set_zero=padding_set_zero,
+ ) for _ in range(self.num_layers)
+ ]
+ )
+ if self.use_last_norm:
+ self.layer_norm = nn.LayerNorm(embed_dim)
+ else:
+ self.layer_norm = None
+
+ def forward(self, x, padding_mask=None, attn_mask=None):
+ """
+ :param x: [B, T, C]
+ :param padding_mask: [B, T]
+ :return: [B, T, C] or [L, B, T, C]
+ """
+        if padding_mask is None:
+            padding_mask = torch.zeros(
+                x.size(0), x.size(1), dtype=torch.bool, device=x.device
+            )
+        nonpadding_mask_TB = 1 - padding_mask.transpose(
+            0, 1
+        ).float()[:, :, None]  # [T, B, 1]
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+ if self.padding_set_zero:
+ x = x * nonpadding_mask_TB
+ for layer in self.layers:
+ x = layer(
+ x, encoder_padding_mask=padding_mask, attn_mask=attn_mask
+ )
+ if self.padding_set_zero:
+ x = x * nonpadding_mask_TB
+ if self.use_last_norm:
+ x = self.layer_norm(x)
+ if self.padding_set_zero:
+ x = x * nonpadding_mask_TB
+
+ x = x.transpose(0, 1) # [B, T, C]
+ return x
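+
+# Usage sketch (illustrative shapes only):
+#   blocks = FFTBlocks(hidden_size=256, num_layers=4)
+#   x = torch.randn(8, 100, 256)                 # [B, T, C]
+#   pad = torch.zeros(8, 100, dtype=torch.bool)  # True marks padded frames
+#   y = blocks(x, padding_mask=pad)              # [B, T, C]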
+
+
+class FastSpeech2EncoderBase(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ num_layers: int,
+ num_heads: int,
+ ffn_kernel_size: int,
+ d_out: int,
+ dropout: float = 0.1,
+ rel_pos: bool = True,
+ padding_set_zero: bool = True
+ ):
+ super().__init__()
+ self.rel_pos = rel_pos
+
+ if self.rel_pos:
+ self.pos_encoding = RelPositionalEncoding(
+ d_model, dropout_rate=0.0
+ )
+ else:
+ self.pos_encoding = SinusoidalPositionalEmbedding(
+ d_model, padding_idx=0
+ )
+ self.dropout = dropout
+ self.embed_scale = math.sqrt(d_model)
+
+ self.layers = FFTBlocks(
+ hidden_size=d_model,
+ num_layers=num_layers,
+ ffn_kernel_size=ffn_kernel_size,
+ dropout=dropout,
+ num_heads=num_heads,
+ use_last_norm=True,
+ padding_set_zero=padding_set_zero
+ )
+
+ self.out_proj = nn.Linear(d_model, d_out)
+ self.apply(self.init_weights)
+
+ def init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0.)
+ elif isinstance(m, nn.Embedding):
+ nn.init.normal_(m.weight, mean=0, std=m.embedding_dim**-0.5)
+
+
+@dataclass
+class SpkConfig:
+ encoding_format: str
+ num_spk: int | None = None
+ spk_embed_dim: int | None = None
+
+ def __post_init__(self):
+ allowed_formats = {"id", "embedding"}
+        assert self.encoding_format in allowed_formats, \
+            f"encoding_format must be one of {allowed_formats}, got '{self.encoding_format}'"
+ if self.encoding_format == "id":
+ assert self.num_spk is not None
+ if self.encoding_format == "embedding":
+ assert self.spk_embed_dim is not None
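+
+    # Examples (illustrative):
+    #   SpkConfig("id", num_spk=100)               # speaker-id lookup table
+    #   SpkConfig("embedding", spk_embed_dim=256)  # precomputed embeddings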
+
+
+class FastSpeech2PhonemeEncoder(FastSpeech2EncoderBase):
+ def __init__(
+ self,
+ phone_vocab_size,
+ d_model,
+ num_layers,
+ num_heads,
+ ffn_kernel_size,
+ d_out,
+ dropout=0.1,
+ rel_pos=False,
+ spk_config: SpkConfig | None = None,
+ padding_set_zero: bool = True
+ ):
+ super().__init__(
+ d_model=d_model,
+ num_layers=num_layers,
+ num_heads=num_heads,
+ ffn_kernel_size=ffn_kernel_size,
+ d_out=d_out,
+ dropout=dropout,
+ rel_pos=rel_pos,
+ padding_set_zero=padding_set_zero
+ )
+ self.phone_embed = Embedding(phone_vocab_size, d_model)
+ self.spk_config = spk_config
+ if spk_config is not None:
+ if spk_config.encoding_format == "id":
+ self.spk_embed_proj = Embedding(
+ spk_config.num_spk + 1, d_model
+ )
+ elif spk_config.encoding_format == "embedding":
+ self.spk_embed_proj = Linear(spk_config.spk_embed_dim, d_model)
+
+ def forward(
+ self, phoneme: torch.Tensor, lengths: Sequence[int], spk: torch.Tensor
+ ):
+ x = self.embed_scale * self.phone_embed(phoneme)
+ x = self.pos_encoding(x, lengths)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+
+ padding_mask = ~create_mask_from_length(lengths).to(phoneme.device)
+ x = self.layers(x, padding_mask=padding_mask)
+
+ if self.spk_config is not None:
+ spk_embed = self.spk_embed_proj(spk).unsqueeze(1)
+ x = x + spk_embed
+
+ x = self.out_proj(x)
+
+ return {"output": x, "mask": ~padding_mask}
+
+
+class FastSpeech2MIDIEncoder(FastSpeech2PhonemeEncoder):
+ def __init__(
+ self,
+ phone_vocab_size: int,
+ midi_vocab_size: int,
+ slur_vocab_size: int,
+ spk_config: SpkConfig | None,
+ d_model: int,
+ num_layers: int,
+ num_heads: int,
+ ffn_kernel_size: int,
+ d_out: int,
+ dropout: float = 0.1,
+ rel_pos: bool = True,
+ padding_set_zero: bool = True
+ ):
+ super().__init__(
+ phone_vocab_size=phone_vocab_size,
+ d_model=d_model,
+ num_layers=num_layers,
+ num_heads=num_heads,
+ ffn_kernel_size=ffn_kernel_size,
+ d_out=d_out,
+ dropout=dropout,
+ rel_pos=rel_pos,
+ spk_config=spk_config,
+ padding_set_zero=padding_set_zero
+ )
+ self.midi_embed = Embedding(midi_vocab_size, d_model, padding_idx=0)
+ self.midi_dur_embed = Linear(1, d_model)
+ self.is_slur_embed = Embedding(slur_vocab_size, d_model)
+
+ def forward(
+ self,
+ phoneme: torch.Tensor,
+ midi: torch.Tensor,
+ midi_duration: torch.Tensor,
+ is_slur: torch.Tensor,
+ lengths: Sequence[int],
+ spk: torch.Tensor | None = None,
+ ):
+ x = self.embed_scale * self.phone_embed(phoneme)
+ midi_embedding = self.midi_embed(midi)
+ midi_dur_embedding = self.midi_dur_embed(midi_duration[:, :, None])
+ slur_embedding = self.is_slur_embed(is_slur)
+
+ x = x + midi_embedding + midi_dur_embedding + slur_embedding
+ x = self.pos_encoding(x, lengths)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+
+ padding_mask = ~create_mask_from_length(lengths).to(phoneme.device)
+ x = self.layers(x, padding_mask=padding_mask)
+
+ if self.spk_config is not None:
+ spk_embed = self.spk_embed_proj(spk).unsqueeze(1)
+ x = x + spk_embed
+
+ x = self.out_proj(x)
+
+ return {"output": x, "mask": ~padding_mask}
+
+
+class FastSpeech2PitchEncoder(FastSpeech2EncoderBase):
+ def __init__(
+ self,
+ phone_vocab_size,
+ d_model,
+ num_layers,
+ num_heads,
+ ffn_kernel_size,
+ d_out,
+ dropout=0.1,
+ rel_pos=False,
+ padding_set_zero=True
+ ):
+ super().__init__(
+ d_model=d_model,
+ num_layers=num_layers,
+ num_heads=num_heads,
+ ffn_kernel_size=ffn_kernel_size,
+ d_out=d_out,
+ dropout=dropout,
+ rel_pos=rel_pos,
+ padding_set_zero=padding_set_zero
+ )
+ self.phone_embed = Embedding(phone_vocab_size, d_model)
+ self.pitch_embed = Embedding(300, d_model)
+
+ def forward(self, phoneme: torch.Tensor, lengths: Sequence[int]):
+ x = self.embed_scale * self.phone_embed(phoneme)
+ x = self.pos_encoding(x, lengths)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+
+ padding_mask = ~create_mask_from_length(lengths).to(phoneme.device)
+ x = self.layers(x, padding_mask=padding_mask)
+
+ x = self.out_proj(x)
+
+ return {"output": x, "mask": ~padding_mask}
+
+    def encode_pitch(self, f0, uv):
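+        # denorm_f0 restores absolute f0 from normalized values using the
+        # voiced/unvoiced mask; f0_to_coarse then quantizes it into the 300
+        # bins expected by self.pitch_embed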
+ f0_denorm = denorm_f0(f0, uv)
+ pitch = f0_to_coarse(f0_denorm)
+ pitch_embed = self.pitch_embed(pitch)
+ return {"output": pitch_embed}
diff --git a/models/content_encoder/sketch_encoder.py b/models/content_encoder/sketch_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..73d76bbf9de9073d97b1b5a7a244f7d90c0c6683
--- /dev/null
+++ b/models/content_encoder/sketch_encoder.py
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+
+
+try:
+ import torch_npu
+ from torch_npu.contrib import transfer_to_npu
+ DEVICE_TYPE = "npu"
+except ModuleNotFoundError:
+ DEVICE_TYPE = "cuda"
+
+from .text_encoder import T5TextEncoder
+
+class SketchT5TextEncoder(T5TextEncoder):
+ def __init__(
+        self, f0_dim: int, energy_dim: int, latent_dim: int,
+        embed_dim: int, model_name: str = "google/flan-t5-large",
+    ):
+        super().__init__(
+            embed_dim=embed_dim,
+            model_name=model_name,
+ )
+ self.f0_proj = nn.Linear(f0_dim, latent_dim)
+ self.f0_norm = nn.LayerNorm(f0_dim)
+ self.energy_proj = nn.Linear(energy_dim, latent_dim)
+
+ def encode(
+ self,
+ text: list[str],
+ ):
+ with torch.no_grad(), torch.amp.autocast(
+ device_type=DEVICE_TYPE, enabled=False
+ ):
+ return super().encode(text)
+
+ def encode_sketch(
+ self,
+ f0,
+ energy,
+ ):
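+        # f0/energy frames are each projected to latent_dim and stacked on a
+        # new trailing axis, giving an output of shape (B, T, latent_dim, 2)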
+ f0_embed = self.f0_proj(self.f0_norm(f0)).unsqueeze(-1)
+ energy_embed = self.energy_proj(energy).unsqueeze(-1)
+ sketch_embed = torch.cat([f0_embed, energy_embed], dim=-1)
+ return {"output": sketch_embed}
+
+
+if __name__ == "__main__":
+ text_encoder = T5TextEncoder(embed_dim=512)
+ text = ["a man is speaking", "a woman is singing while a dog is barking"]
+
+ output = text_encoder(text)
diff --git a/models/content_encoder/star_encoder/star_encoder.py b/models/content_encoder/star_encoder/star_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..59fff9482841d1366cddeec5e6ce7d7b2220fcb4
--- /dev/null
+++ b/models/content_encoder/star_encoder/star_encoder.py
@@ -0,0 +1,108 @@
+import torch
+import torch.nn as nn
+
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+from Qformer import BertConfig, BertLMHeadModel
+
+
+try:
+ import torch_npu
+ from torch_npu.contrib import transfer_to_npu
+ DEVICE_TYPE = "npu"
+except ModuleNotFoundError:
+ DEVICE_TYPE = "cuda"
+
+def generate_length_mask(lens, max_length=None):
+ lens = torch.as_tensor(lens)
+ N = lens.size(0)
+ if max_length is None:
+ max_length = max(lens)
+ idxs = torch.arange(max_length).repeat(N).view(N, max_length)
+ idxs = idxs.to(lens.device)
+ mask = (idxs < lens.view(-1, 1)).int()
+ return mask
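+
+# e.g. generate_length_mask(torch.tensor([2, 4])) ->
+# tensor([[1, 1, 0, 0],
+#         [1, 1, 1, 1]], dtype=torch.int32)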
+
+class QformerBridgeNet(torch.nn.Module):
+ def __init__(self, Qformer_model_name: str = "bert-base-uncased", num_query_token: int = 32,
+ hiddin_size: int = 1024, speech_width: int = 1024, freeze_QFormer: bool = True,
+ load_from_pretrained: str = None):
+ super().__init__()
+
+ self.Qformer_model_name = Qformer_model_name
+ self.audio_Qformer, self.audio_query_tokens, encoder_config = self.init_Qformer(num_query_token=num_query_token, speech_width=speech_width)
+ self.audio_Qformer.cls = None
+ self.audio_Qformer.bert.embeddings.word_embeddings = None
+ self.audio_Qformer.bert.embeddings.position_embeddings = None
+ for layer in self.audio_Qformer.bert.encoder.layer:
+ layer.output = None
+ layer.intermediate = None
+
+ self.freeze_QFormer = freeze_QFormer
+ if freeze_QFormer:
+            for param in self.audio_Qformer.parameters():
+                param.requires_grad = False
+ self.audio_Qformer.eval()
+ self.audio_query_tokens.requires_grad = False
+
+        self.hiddin_projection = torch.nn.Linear(encoder_config.hidden_size, hiddin_size)
+
+        if load_from_pretrained:
+            state_dict = torch.load(load_from_pretrained, map_location="cpu")
+            # drop the old projection head keys that this module does not use
+            del_key = ["projection.weight", "projection.bias"]
+            del_state_dict = {k: v for k, v in state_dict.items() if k not in del_key}
+            self.load_state_dict(del_state_dict)
+            print("Load adaptor_model_pt from", load_from_pretrained)
+
+
+ def init_Qformer(self, num_query_token, speech_width, num_hidden_layers=2, cross_attention_freq=2):
+ encoder_config = BertConfig.from_pretrained(self.Qformer_model_name)
+ encoder_config.num_hidden_layers = num_hidden_layers
+ encoder_config.encoder_width = speech_width
+ # insert cross-attention layer every other block
+ encoder_config.add_cross_attention = True
+ encoder_config.cross_attention_freq = cross_attention_freq
+ encoder_config.query_length = num_query_token
+ Qformer = BertLMHeadModel(config=encoder_config)
+ query_tokens = nn.Parameter(
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
+ )
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+ return Qformer, query_tokens, encoder_config
+
+    def hidden(self, batch):
+        audio_feature, lens = batch['embed'], batch['embed_len']
+        frame_atts = generate_length_mask(lens).to(audio_feature.device)
+        audio_query_tokens = self.audio_query_tokens.expand(
+            audio_feature.shape[0], -1, -1
+        )
+
+        audio_query_output = self.audio_Qformer.bert(
+            query_embeds=audio_query_tokens,  # (B, num_query_token, hidden_size)
+ encoder_hidden_states=audio_feature,
+ encoder_attention_mask=frame_atts,
+ return_dict=True,
+ )
+ audio_hidden = audio_query_output.last_hidden_state
+ return audio_hidden
+
+ def forward(self, batch) -> torch.Tensor:
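+        # the frozen Q-Former runs in fp32 under no_grad; only the projection
+        # below receives gradients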
+ with torch.no_grad(), torch.amp.autocast(
+ device_type=DEVICE_TYPE, enabled=False
+ ):
+ x = self.hidden(batch)
+ x = self.hiddin_projection(x)
+
+        mask = torch.ones(x.shape[:2], dtype=torch.bool, device=x.device)
+ return {"output": x, "mask": mask}
+
+
+if __name__ == '__main__':
+    # Smoke test for the bridge network; feature sizes are illustrative only
+    # and assume the default speech_width of 1024.
+    bridge = QformerBridgeNet()
+    batch = {
+        "embed": torch.randn(2, 50, 1024),
+        "embed_len": torch.tensor([50, 30]),
+    }
+    bridge.eval()
+    with torch.no_grad():
+        output = bridge(batch)
diff --git a/models/content_encoder/text_encoder.py b/models/content_encoder/text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c64bbdd60deca4e7c170cb5b02636cb28f49e922
--- /dev/null
+++ b/models/content_encoder/text_encoder.py
@@ -0,0 +1,77 @@
+import torch
+import torch.nn as nn
+from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5EncoderModel
+from transformers.modeling_outputs import BaseModelOutput
+
+try:
+ import torch_npu
+ from torch_npu.contrib import transfer_to_npu
+ DEVICE_TYPE = "npu"
+except ModuleNotFoundError:
+ DEVICE_TYPE = "cuda"
+
+
+class TransformersTextEncoderBase(nn.Module):
+ def __init__(self, model_name: str, embed_dim: int):
+ super().__init__()
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+ self.model = AutoModel.from_pretrained(model_name)
+ self.proj = nn.Linear(self.model.config.hidden_size, embed_dim)
+
+ def forward(
+ self,
+ text: list[str],
+ ):
+ output, mask = self.encode(text)
+ output = self.projection(output)
+ return {"output": output, "mask": mask}
+
+ def encode(self, text: list[str]):
+ device = self.model.device
+ batch = self.tokenizer(
+ text,
+ max_length=self.tokenizer.model_max_length,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ )
+ input_ids = batch.input_ids.to(device)
+ attention_mask = batch.attention_mask.to(device)
+ output: BaseModelOutput = self.model(
+ input_ids=input_ids, attention_mask=attention_mask
+ )
+ output = output.last_hidden_state
+ mask = (attention_mask == 1).to(device)
+ return output, mask
+
+ def projection(self, x):
+ return self.proj(x)
+
+
+class T5TextEncoder(TransformersTextEncoderBase):
+ def __init__(
+ self, embed_dim: int, model_name: str = "google/flan-t5-large"
+ ):
+ nn.Module.__init__(self)
+ self.tokenizer = T5Tokenizer.from_pretrained(model_name)
+ self.model = T5EncoderModel.from_pretrained(model_name)
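+        # freeze the T5 encoder; only the projection layer below is trained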
+ for param in self.model.parameters():
+ param.requires_grad = False
+ self.model.eval()
+ self.proj = nn.Linear(self.model.config.hidden_size, embed_dim)
+
+ def encode(
+ self,
+ text: list[str],
+ ):
+ with torch.no_grad(), torch.amp.autocast(
+ device_type=DEVICE_TYPE, enabled=False
+ ):
+ return super().encode(text)
+
+
+if __name__ == "__main__":
+ text_encoder = T5TextEncoder(embed_dim=512)
+ text = ["a man is speaking", "a woman is singing while a dog is barking"]
+
+ output = text_encoder(text)
diff --git a/models/content_encoder/vision_encoder.py b/models/content_encoder/vision_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fe5666cba2b78fc74681c2ce5cfff6c5e3730c6
--- /dev/null
+++ b/models/content_encoder/vision_encoder.py
@@ -0,0 +1,34 @@
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from utils.torch_utilities import create_mask_from_length
+
+
+class MlpVideoEncoder(nn.Module):
+ def __init__(
+ self,
+ video_feat_dim: int,
+ embed_dim: int,
+ ):
+ super().__init__()
+ self.mlp = nn.Linear(video_feat_dim, embed_dim)
+ self.init_weights()
+
+ def init_weights(self):
+ def _init_weights(module):
+ if isinstance(module, nn.Linear):
+ nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0.)
+
+ self.apply(_init_weights)
+
+ def forward(self, frames: torch.Tensor, frame_nums: Sequence[int]):
+ device = frames.device
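+        # L2-normalize each frame feature before the linear projection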
+ x = F.normalize(frames, p=2, dim=-1)
+ x = self.mlp(x)
+ mask = create_mask_from_length(frame_nums).to(device)
+ return {"output": x, "mask": mask}
diff --git a/models/diffsinger_net.py b/models/diffsinger_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..f802b0e684b532b7fb141348aabadba3d105152d
--- /dev/null
+++ b/models/diffsinger_net.py
@@ -0,0 +1,119 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Mish(nn.Module):
+ def forward(self, x):
+ return x * torch.tanh(F.softplus(x))
+
+
+class SinusoidalPosEmb(nn.Module):
+ def __init__(self, dim):
+        super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ device = x.device
+ half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
+ emb = x.unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb
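+
+# e.g. SinusoidalPosEmb(128)(torch.tensor([0, 999])) gives a (2, 128) tensor of
+# standard transformer sin/cos encodings of the diffusion step.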
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, encoder_hidden, residual_channels, dilation):
+ super().__init__()
+ self.dilated_conv = nn.Conv1d(
+ residual_channels,
+ 2 * residual_channels,
+ 3,
+ padding=dilation,
+ dilation=dilation
+ )
+ self.diffusion_projection = nn.Linear(
+ residual_channels, residual_channels
+ )
+ self.conditioner_projection = nn.Conv1d(
+ encoder_hidden, 2 * residual_channels, 1
+ )
+ self.output_projection = nn.Conv1d(
+ residual_channels, 2 * residual_channels, 1
+ )
+
+ def forward(self, x, conditioner, diffusion_step):
+        diffusion_step = self.diffusion_projection(
+            diffusion_step
+        ).unsqueeze(-1)
+ conditioner = self.conditioner_projection(conditioner)
+ y = x + diffusion_step
+
+ y = self.dilated_conv(y) + conditioner
+
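+        # gated activation unit (WaveNet-style): sigmoid gate times tanh filter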
+        gate, filt = torch.chunk(y, 2, dim=1)
+        y = torch.sigmoid(gate) * torch.tanh(filt)
+
+ y = self.output_projection(y)
+ residual, skip = torch.chunk(y, 2, dim=1)
+        return (x + residual) / math.sqrt(2.0), skip
+
+
+class DiffSingerNet(nn.Module):
+ def __init__(
+ self,
+ in_dims=128,
+ residual_channels=256,
+ encoder_hidden=256,
+ dilation_cycle_length=4,
+ residual_layers=20,
+ ):
+ super().__init__()
+
+ self.input_projection = nn.Conv1d(in_dims, residual_channels, 1)
+ self.time_pos_emb = SinusoidalPosEmb(residual_channels)
+ dim = residual_channels
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, dim * 4), Mish(), nn.Linear(dim * 4, dim)
+ )
+ self.residual_layers = nn.ModuleList([
+ ResidualBlock(
+ encoder_hidden, residual_channels,
+ 2**(i % dilation_cycle_length)
+ ) for i in range(residual_layers)
+ ])
+ self.skip_projection = nn.Conv1d(
+ residual_channels, residual_channels, 1
+ )
+ self.output_projection = nn.Conv1d(residual_channels, in_dims, 1)
+ nn.init.zeros_(self.output_projection.weight)
+
+ def forward(self, x, timesteps, context, x_mask=None, context_mask=None):
+ # make it compatible with int time step during inference
+ if timesteps.dim() == 0:
+ timesteps = timesteps.expand(x.shape[0]
+ ).to(x.device, dtype=torch.long)
+
+ x = self.input_projection(x) # x [B, residual_channel, T]
+
+ x = F.relu(x)
+
+ t = self.time_pos_emb(timesteps)
+ t = self.mlp(t)
+
+ cond = context
+
+ skip = []
+        for layer in self.residual_layers:
+ x, skip_connection = layer(x, cond, t)
+ skip.append(skip_connection)
+
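+        # aggregate skip connections, scaled by 1/sqrt(num_layers) as in DiffWave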
+ x = torch.sum(torch.stack(skip),
+ dim=0) / math.sqrt(len(self.residual_layers))
+ x = self.skip_projection(x)
+ x = F.relu(x)
+        x = self.output_projection(x)  # [B, M, T]
+        if x_mask is None:
+            return x
+        return x * x_mask.unsqueeze(1)
diff --git a/models/diffusion.py b/models/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca4f8f05c35c048a0c9f334985870e6ea8cd7fa7
--- /dev/null
+++ b/models/diffusion.py
@@ -0,0 +1,1261 @@
+from typing import Any, Sequence
+import random
+from pathlib import Path
+
+from tqdm import tqdm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import diffusers.schedulers as noise_schedulers
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from diffusers.utils.torch_utils import randn_tensor
+
+from models.autoencoder.autoencoder_base import AutoEncoderBase
+from models.content_encoder.content_encoder import ContentEncoder
+from models.content_adapter import ContentAdapterBase
+from models.common import LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase
+from utils.torch_utilities import (
+ create_alignment_path, create_mask_from_length, loss_with_mask,
+ trim_or_pad_length
+)
+from safetensors.torch import load_file
+
+class DiffusionMixin:
+ def __init__(
+ self,
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
+        snr_gamma: float | None = None,
+ cfg_drop_ratio: float = 0.2
+ ) -> None:
+ self.noise_scheduler_name = noise_scheduler_name
+ self.snr_gamma = snr_gamma
+ self.classifier_free_guidance = cfg_drop_ratio > 0.0
+ self.cfg_drop_ratio = cfg_drop_ratio
+ self.noise_scheduler = noise_schedulers.DDPMScheduler.from_pretrained(
+ self.noise_scheduler_name, subfolder="scheduler"
+ )
+
+ def compute_snr(self, timesteps) -> torch.Tensor:
+ """
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+ """
+ alphas_cumprod = self.noise_scheduler.alphas_cumprod
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod)**0.5
+
+ # Expand the tensors.
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device
+ )[timesteps].float()
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
+
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
+ device=timesteps.device
+ )[timesteps].float()
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[...,
+ None]
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
+
+ # Compute SNR.
+ snr = (alpha / sigma)**2
+ return snr
+
+ def get_timesteps(
+ self,
+ batch_size: int,
+ device: torch.device,
+ training: bool = True
+ ) -> torch.Tensor:
+ if training:
+ timesteps = torch.randint(
+ 0,
+ self.noise_scheduler.config.num_train_timesteps,
+ (batch_size, ),
+ device=device
+ )
+ else:
+ # validation on half of the total timesteps
+ timesteps = (self.noise_scheduler.config.num_train_timesteps //
+ 2) * torch.ones((batch_size, ),
+ dtype=torch.int64,
+ device=device)
+
+ timesteps = timesteps.long()
+ return timesteps
+
+ def get_target(
+ self, latent: torch.Tensor, noise: torch.Tensor,
+ timesteps: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Get the target for loss depending on the prediction type
+ """
+ if self.noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif self.noise_scheduler.config.prediction_type == "v_prediction":
+ target = self.noise_scheduler.get_velocity(
+ latent, noise, timesteps
+ )
+ else:
+ raise ValueError(
+ f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
+ )
+ return target
+
+ def loss_with_snr(
+ self, pred: torch.Tensor, target: torch.Tensor,
+ timesteps: torch.Tensor, mask: torch.Tensor,
+ loss_reduce: bool = True,
+ ) -> torch.Tensor:
+ if self.snr_gamma is None:
+ loss = F.mse_loss(pred.float(), target.float(), reduction="none")
+ loss = loss_with_mask(loss, mask, reduce=loss_reduce)
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L1006
+ snr = self.compute_snr(timesteps)
+ mse_loss_weights = torch.stack(
+ [
+ snr,
+ self.snr_gamma * torch.ones_like(timesteps),
+ ],
+ dim=1,
+ ).min(dim=1)[0]
+            # dividing by (snr + 1) does not work well; the reason is unclear
+ mse_loss_weights = mse_loss_weights / snr
+ loss = F.mse_loss(pred.float(), target.float(), reduction="none")
+ loss = loss_with_mask(loss, mask, reduce=False) * mse_loss_weights
+ if loss_reduce:
+ loss = loss.mean()
+ return loss
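+
+    # e.g. with snr_gamma = 5.0, a timestep with SNR 20 is down-weighted by
+    # min(20, 5) / 20 = 0.25, while one with SNR 2 keeps weight
+    # min(2, 5) / 2 = 1.0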
+
+ def rescale_cfg(
+ self, pred_cond: torch.Tensor, pred_cfg: torch.Tensor,
+ guidance_rescale: float
+ ):
+ """
+ Rescale `pred_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_cond = pred_cond.std(
+ dim=list(range(1, pred_cond.ndim)), keepdim=True
+ )
+ std_cfg = pred_cfg.std(dim=list(range(1, pred_cfg.ndim)), keepdim=True)
+
+ pred_rescaled = pred_cfg * (std_cond / std_cfg)
+ pred_cfg = guidance_rescale * pred_rescaled + (
+ 1 - guidance_rescale
+ ) * pred_cfg
+ return pred_cfg
+
+class CrossAttentionAudioDiffusion(
+ LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
+ DiffusionMixin
+):
+ def __init__(
+ self,
+ autoencoder: AutoEncoderBase,
+ content_encoder: ContentEncoder,
+ content_adapter: ContentAdapterBase,
+ backbone: nn.Module,
+ duration_offset: float = 1.0,
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
+        snr_gamma: float | None = None,
+ cfg_drop_ratio: float = 0.2,
+ ):
+ nn.Module.__init__(self)
+ DiffusionMixin.__init__(
+ self, noise_scheduler_name, snr_gamma, cfg_drop_ratio
+ )
+
+ self.autoencoder = autoencoder
+ for param in self.autoencoder.parameters():
+ param.requires_grad = False
+
+ self.content_encoder = content_encoder
+ self.content_encoder.audio_encoder.model = self.autoencoder
+ self.content_adapter = content_adapter
+ self.backbone = backbone
+ self.duration_offset = duration_offset
+ self.dummy_param = nn.Parameter(torch.empty(0))
+
+ def forward(
+ self, content: list[Any], task: list[str], waveform: torch.Tensor,
+ waveform_lengths: torch.Tensor, instruction: torch.Tensor,
+ instruction_lengths: Sequence[int], **kwargs
+ ):
+ device = self.dummy_param.device
+ num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
+ self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
+
+ self.autoencoder.eval()
+ with torch.no_grad():
+ latent, latent_mask = self.autoencoder.encode(
+ waveform.unsqueeze(1), waveform_lengths
+ )
+
+ content_output: dict[
+ str, torch.Tensor] = self.content_encoder.encode_content(
+ content, task, device=device
+ )
+ content, content_mask = content_output["content"], content_output[
+ "content_mask"]
+ instruction_mask = create_mask_from_length(instruction_lengths)
+ content, content_mask, global_duration_pred, _ = \
+ self.content_adapter(content, content_mask, instruction, instruction_mask)
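+        # regress log(duration-in-seconds + offset), where duration is the
+        # number of valid latent frames divided by the latent token rate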
+ global_duration_target = torch.log(
+ latent_mask.sum(1) / self.autoencoder.latent_token_rate +
+ self.duration_offset
+ )
+ global_duration_loss = F.mse_loss(
+ global_duration_target, global_duration_pred
+ )
+
+ if self.training and self.classifier_free_guidance:
+ mask_indices = [
+ k for k in range(len(waveform))
+ if random.random() < self.cfg_drop_ratio
+ ]
+ if len(mask_indices) > 0:
+ content[mask_indices] = 0
+
+ batch_size = latent.shape[0]
+ timesteps = self.get_timesteps(batch_size, device, self.training)
+ noise = torch.randn_like(latent)
+ noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
+ target = self.get_target(latent, noise, timesteps)
+
+ pred: torch.Tensor = self.backbone(
+ x=noisy_latent,
+ timesteps=timesteps,
+ context=content,
+ x_mask=latent_mask,
+ context_mask=content_mask
+ )
+
+ pred = pred.transpose(1, self.autoencoder.time_dim)
+ target = target.transpose(1, self.autoencoder.time_dim)
+ diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask)
+
+ return {
+ "diff_loss": diff_loss,
+ "global_duration_loss": global_duration_loss,
+ }
+
+ @torch.no_grad()
+ def inference(
+ self,
+ content: list[Any],
+ condition: list[Any],
+ task: list[str],
+ instruction: torch.Tensor,
+ instruction_lengths: Sequence[int],
+ scheduler: SchedulerMixin,
+ num_steps: int = 20,
+ guidance_scale: float = 3.0,
+ guidance_rescale: float = 0.0,
+ disable_progress: bool = True,
+ **kwargs
+ ):
+ device = self.dummy_param.device
+ classifier_free_guidance = guidance_scale > 1.0
+
+ content_output: dict[
+ str, torch.Tensor] = self.content_encoder.encode_content(
+ content, task, device=device
+ )
+ content, content_mask = content_output["content"], content_output[
+ "content_mask"]
+
+ instruction_mask = create_mask_from_length(instruction_lengths)
+ content, content_mask, global_duration_pred, _ = \
+ self.content_adapter(content, content_mask, instruction, instruction_mask)
+ batch_size = content.size(0)
+
+ if classifier_free_guidance:
+ uncond_content = torch.zeros_like(content)
+ uncond_content_mask = content_mask.detach().clone()
+ content = torch.cat([uncond_content, content])
+ content_mask = torch.cat([uncond_content_mask, content_mask])
+
+ scheduler.set_timesteps(num_steps, device=device)
+ timesteps = scheduler.timesteps
+
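+        # invert the log-duration parameterization: seconds = exp(pred) - offset,
+        # then convert seconds to latent frames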
+ global_duration_pred = torch.exp(
+ global_duration_pred
+ ) - self.duration_offset
+ global_duration_pred *= self.autoencoder.latent_token_rate
+ global_duration_pred = torch.round(global_duration_pred)
+
+ latent_shape = tuple(
+ int(global_duration_pred.max().item()) if dim is None else dim
+ for dim in self.autoencoder.latent_shape
+ )
+ latent = self.prepare_latent(
+ batch_size, scheduler, latent_shape, content.dtype, device
+ )
+ latent_mask = create_mask_from_length(global_duration_pred).to(
+ content_mask.device
+ )
+ if classifier_free_guidance:
+ latent_mask = torch.cat([latent_mask, latent_mask])
+
+ num_warmup_steps = len(timesteps) - num_steps * scheduler.order
+ progress_bar = tqdm(range(num_steps), disable=disable_progress)
+
+ for i, timestep in enumerate(timesteps):
+ # expand the latent if we are doing classifier free guidance
+ latent_input = torch.cat([latent, latent]
+ ) if classifier_free_guidance else latent
+ latent_input = scheduler.scale_model_input(latent_input, timestep)
+
+ noise_pred = self.backbone(
+ x=latent_input,
+ x_mask=latent_mask,
+ timesteps=timestep,
+ context=content,
+ context_mask=content_mask,
+ )
+
+ # perform guidance
+ if classifier_free_guidance:
+ noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_content - noise_pred_uncond
+ )
+ if guidance_rescale != 0.0:
+ noise_pred = self.rescale_cfg(
+ noise_pred_content, noise_pred, guidance_rescale
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latent = scheduler.step(noise_pred, timestep, latent).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
+ (i + 1) % scheduler.order == 0):
+ progress_bar.update(1)
+
+ waveform = self.autoencoder.decode(latent)
+
+ return waveform
+
+ def prepare_latent(
+ self, batch_size: int, scheduler: SchedulerMixin,
+ latent_shape: Sequence[int], dtype: torch.dtype, device: str
+ ):
+ shape = (batch_size, *latent_shape)
+ latent = randn_tensor(
+ shape, generator=None, device=device, dtype=dtype
+ )
+ # scale the initial noise by the standard deviation required by the scheduler
+ latent = latent * scheduler.init_noise_sigma
+ return latent
+
+class SingleTaskCrossAttentionAudioDiffusion(CrossAttentionAudioDiffusion):
+ def __init__(
+ self,
+ autoencoder: AutoEncoderBase,
+ content_encoder: ContentEncoder,
+ backbone: nn.Module,
+        pretrained_ckpt: str | Path | None = None,
+        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
+        snr_gamma: float | None = None,
+ cfg_drop_ratio: float = 0.2,
+ ):
+ nn.Module.__init__(self)
+ DiffusionMixin.__init__(
+ self, noise_scheduler_name, snr_gamma, cfg_drop_ratio
+ )
+
+ self.autoencoder = autoencoder
+ for param in self.autoencoder.parameters():
+ param.requires_grad = False
+
+ self.backbone = backbone
+ if pretrained_ckpt is not None:
+ pretrained_state_dict = load_file(pretrained_ckpt)
+ self.load_pretrained(pretrained_state_dict)
+
+ self.content_encoder = content_encoder
+ #self.content_encoder.audio_encoder.model = self.autoencoder
+ self.dummy_param = nn.Parameter(torch.empty(0))
+
+ def forward(
+ self, content: list[Any], condition: list[Any], task: list[str], waveform: torch.Tensor,
+ waveform_lengths: torch.Tensor, loss_reduce: bool = True, **kwargs
+ ):
+        # always reduce the loss during training; otherwise honor `loss_reduce`
+        loss_reduce = self.training or loss_reduce
+ device = self.dummy_param.device
+ num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
+ self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
+
+ self.autoencoder.eval()
+ with torch.no_grad():
+ latent, latent_mask = self.autoencoder.encode(
+ waveform.unsqueeze(1), waveform_lengths
+ )
+
+ content_output: dict[
+ str, torch.Tensor] = self.content_encoder.encode_content(
+ content, task, device=device
+ )
+ content, content_mask = content_output["content"], content_output[
+ "content_mask"]
+
+ if self.training and self.classifier_free_guidance:
+ mask_indices = [
+ k for k in range(len(waveform))
+ if random.random() < self.cfg_drop_ratio
+ ]
+ if len(mask_indices) > 0:
+ content[mask_indices] = 0
+
+ batch_size = latent.shape[0]
+ timesteps = self.get_timesteps(batch_size, device, self.training)
+ noise = torch.randn_like(latent)
+ noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
+ target = self.get_target(latent, noise, timesteps)
+
+ pred: torch.Tensor = self.backbone(
+ x=noisy_latent,
+ timesteps=timesteps,
+ context=content,
+ x_mask=latent_mask,
+ context_mask=content_mask
+ )
+
+ pred = pred.transpose(1, self.autoencoder.time_dim)
+ target = target.transpose(1, self.autoencoder.time_dim)
+ diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask, loss_reduce=loss_reduce)
+
+ return {
+ "diff_loss": diff_loss,
+ }
+
+ @torch.no_grad()
+ def inference(
+ self,
+ content: list[Any],
+ condition: list[Any],
+ task: list[str],
+ scheduler: SchedulerMixin,
+ latent_shape: Sequence[int],
+ num_steps: int = 20,
+ guidance_scale: float = 3.0,
+ guidance_rescale: float = 0.0,
+ disable_progress: bool = True,
+ **kwargs
+ ):
+ device = self.dummy_param.device
+ classifier_free_guidance = guidance_scale > 1.0
+
+ content_output: dict[
+ str, torch.Tensor] = self.content_encoder.encode_content(
+ content, task, device=device
+ )
+ content, content_mask = content_output["content"], content_output[
+ "content_mask"]
+ batch_size = content.size(0)
+
+ if classifier_free_guidance:
+ uncond_content = torch.zeros_like(content)
+ uncond_content_mask = content_mask.detach().clone()
+ content = torch.cat([uncond_content, content])
+ content_mask = torch.cat([uncond_content_mask, content_mask])
+
+ scheduler.set_timesteps(num_steps, device=device)
+ timesteps = scheduler.timesteps
+
+ latent = self.prepare_latent(
+ batch_size, scheduler, latent_shape, content.dtype, device
+ )
+
+ num_warmup_steps = len(timesteps) - num_steps * scheduler.order
+ progress_bar = tqdm(range(num_steps), disable=disable_progress)
+
+ for i, timestep in enumerate(timesteps):
+ # expand the latent if we are doing classifier free guidance
+ latent_input = torch.cat([latent, latent]
+ ) if classifier_free_guidance else latent
+ latent_input = scheduler.scale_model_input(latent_input, timestep)
+
+ noise_pred = self.backbone(
+ x=latent_input,
+ timesteps=timestep,
+ context=content,
+ context_mask=content_mask,
+ )
+
+ # perform guidance
+ if classifier_free_guidance:
+ noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_content - noise_pred_uncond
+ )
+ if guidance_rescale != 0.0:
+ noise_pred = self.rescale_cfg(
+ noise_pred_content, noise_pred, guidance_rescale
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latent = scheduler.step(noise_pred, timestep, latent).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
+ (i + 1) % scheduler.order == 0):
+ progress_bar.update(1)
+
+ waveform = self.autoencoder.decode(latent)
+
+ return waveform
+
+
+class DummyContentAudioDiffusion(CrossAttentionAudioDiffusion):
+ def __init__(
+ self,
+ autoencoder: AutoEncoderBase,
+ content_encoder: ContentEncoder,
+ content_adapter: ContentAdapterBase,
+ backbone: nn.Module,
+ content_dim: int,
+ frame_resolution: float,
+ duration_offset: float = 1.0,
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
+        snr_gamma: float | None = None,
+ cfg_drop_ratio: float = 0.2,
+ ):
+ """
+ Args:
+ autoencoder:
+ Pretrained audio autoencoder that encodes raw waveforms into latent
+ space and decodes latents back to waveforms.
+ content_encoder:
+ Module that produces content embeddings (e.g., from text, MIDI, or
+ other modalities) used to guide the diffusion.
+ content_adapter (ContentAdapterBase):
+ Adapter module that fuses task instruction embeddings and content embeddings,
+ and performs duration prediction for time-aligned tasks.
+ backbone:
+                U-Net or Transformer backbone that performs the core denoising
+ operations in latent space.
+ content_dim:
+ Dimension of the content embeddings produced by the `content_encoder`
+ and `content_adapter`.
+ frame_resolution:
+ Time resolution, in seconds, of each content frame when predicting
+ duration alignment. Used when calculating duration loss.
+ duration_offset:
+ A small positive offset (frame number) added to predicted durations
+ to ensure numerical stability of log-scaled duration prediction.
+ noise_scheduler_name:
+ Identifier of the pretrained noise scheduler to use.
+ snr_gamma:
+ Clipping value in min-SNR diffusion loss weighting strategy.
+ cfg_drop_ratio:
+ Probability of dropping the content conditioning during training
+ to support CFG.
+ """
+ super().__init__(
+ autoencoder=autoencoder,
+ content_encoder=content_encoder,
+ content_adapter=content_adapter,
+ backbone=backbone,
+ duration_offset=duration_offset,
+ noise_scheduler_name=noise_scheduler_name,
+ snr_gamma=snr_gamma,
+ cfg_drop_ratio=cfg_drop_ratio,
+ )
+ self.frame_resolution = frame_resolution
+ self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim))
+ self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim))
+
+ def forward(
+ self, content, duration, task, is_time_aligned, waveform,
+ waveform_lengths, instruction, instruction_lengths, **kwargs
+ ):
+ device = self.dummy_param.device
+ num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
+ self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
+
+ self.autoencoder.eval()
+ with torch.no_grad():
+ latent, latent_mask = self.autoencoder.encode(
+ waveform.unsqueeze(1), waveform_lengths
+ )
+
+ # content: (B, L, E)
+ content_output: dict[
+ str, torch.Tensor] = self.content_encoder.encode_content(
+ content, task, device=device
+ )
+ length_aligned_content = content_output["length_aligned_content"]
+ content, content_mask = content_output["content"], content_output[
+ "content_mask"]
+ instruction_mask = create_mask_from_length(instruction_lengths)
+
+ content, content_mask, global_duration_pred, local_duration_pred = \
+ self.content_adapter(content, content_mask, instruction, instruction_mask)
+
+ n_frames = torch.round(duration / self.frame_resolution)
+ local_duration_target = torch.log(n_frames + self.duration_offset)
+ global_duration_target = torch.log(
+ latent_mask.sum(1) / self.autoencoder.latent_token_rate +
+ self.duration_offset
+ )
+
+ # truncate unused non time aligned duration prediction
+ if is_time_aligned.sum() > 0:
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
+ else:
+ trunc_ta_length = content.size(1)
+
+ # local duration loss
+ local_duration_pred = local_duration_pred[:, :trunc_ta_length]
+ ta_content_mask = content_mask[:, :trunc_ta_length]
+ local_duration_target = local_duration_target.to(
+ dtype=local_duration_pred.dtype
+ )
+ local_duration_loss = loss_with_mask(
+ (local_duration_target - local_duration_pred)**2,
+ ta_content_mask,
+ reduce=False
+ )
+ local_duration_loss *= is_time_aligned
+ if is_time_aligned.sum().item() == 0:
+ local_duration_loss *= 0.0
+ local_duration_loss = local_duration_loss.mean()
+ else:
+ local_duration_loss = local_duration_loss.sum(
+ ) / is_time_aligned.sum()
+
+ # global duration loss
+ global_duration_loss = F.mse_loss(
+ global_duration_target, global_duration_pred
+ )
+
+ # --------------------------------------------------------------------
+ # prepare latent and diffusion-related noise
+ # --------------------------------------------------------------------
+
+ batch_size = latent.shape[0]
+ timesteps = self.get_timesteps(batch_size, device, self.training)
+ noise = torch.randn_like(latent)
+ noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
+ target = self.get_target(latent, noise, timesteps)
+
+ # --------------------------------------------------------------------
+ # duration adapter
+ # --------------------------------------------------------------------
+ if is_time_aligned.sum() == 0 and \
+ duration.size(1) < content_mask.size(1):
+ # for non time-aligned tasks like TTA, `duration` is dummy one
+ duration = F.pad(
+ duration, (0, content_mask.size(1) - duration.size(1))
+ )
+ n_latents = torch.round(duration * self.autoencoder.latent_token_rate)
+ # content_mask: [B, L], helper_latent_mask: [B, T]
+ helper_latent_mask = create_mask_from_length(n_latents.sum(1)).to(
+ content_mask.device
+ )
+ attn_mask = ta_content_mask.unsqueeze(
+ -1
+ ) * helper_latent_mask.unsqueeze(1)
+ # attn_mask: [B, L, T]
+ align_path = create_alignment_path(n_latents, attn_mask)
+ time_aligned_content = content[:, :trunc_ta_length]
+ time_aligned_content = torch.matmul(
+ align_path.transpose(1, 2).to(content.dtype), time_aligned_content
+ ) # (B, T, L) x (B, L, E) -> (B, T, E)
+
+ # --------------------------------------------------------------------
+ # prepare input to the backbone
+ # --------------------------------------------------------------------
+        # TODO: compatibility with 2D spectrogram VAE
+ latent_length = noisy_latent.size(self.autoencoder.time_dim)
+ time_aligned_content = trim_or_pad_length(
+ time_aligned_content, latent_length, 1
+ )
+ length_aligned_content = trim_or_pad_length(
+ length_aligned_content, latent_length, 1
+ )
+ # time_aligned_content: from monotonic aligned input, without frame expansion (phoneme)
+ # length_aligned_content: from aligned input (f0/energy)
+ time_aligned_content = time_aligned_content + length_aligned_content
+ time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
+ time_aligned_content.dtype
+ )
+
+ context = content
+ context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype)
+ # only use the first dummy non time aligned embedding
+ context_mask = content_mask.detach().clone()
+ context_mask[is_time_aligned, 1:] = False
+
+ # truncate dummy non time aligned context
+ if is_time_aligned.sum().item() < batch_size:
+ trunc_nta_length = content_mask[~is_time_aligned].sum(1).max()
+ else:
+ trunc_nta_length = content.size(1)
+ context = context[:, :trunc_nta_length]
+ context_mask = context_mask[:, :trunc_nta_length]
+
+ # --------------------------------------------------------------------
+ # classifier free guidance
+ # --------------------------------------------------------------------
+ if self.training and self.classifier_free_guidance:
+ mask_indices = [
+ k for k in range(len(waveform))
+ if random.random() < self.cfg_drop_ratio
+ ]
+ if len(mask_indices) > 0:
+ context[mask_indices] = 0
+ time_aligned_content[mask_indices] = 0
+
+ pred: torch.Tensor = self.backbone(
+ x=noisy_latent,
+ timesteps=timesteps,
+ time_aligned_context=time_aligned_content,
+ context=context,
+ x_mask=latent_mask,
+ context_mask=context_mask
+ )
+ pred = pred.transpose(1, self.autoencoder.time_dim)
+ target = target.transpose(1, self.autoencoder.time_dim)
+ diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask)
+ return {
+ "diff_loss": diff_loss,
+ "local_duration_loss": local_duration_loss,
+ "global_duration_loss": global_duration_loss
+ }
+
+ @torch.no_grad()
+ def inference(
+ self,
+ content: list[Any],
+ condition: list[Any],
+ task: list[str],
+ is_time_aligned: list[bool],
+ instruction: torch.Tensor,
+ instruction_lengths: Sequence[int],
+ scheduler: SchedulerMixin,
+ num_steps: int = 20,
+ guidance_scale: float = 3.0,
+ guidance_rescale: float = 0.0,
+ disable_progress: bool = True,
+ use_gt_duration: bool = False,
+ **kwargs
+ ):
+ device = self.dummy_param.device
+ classifier_free_guidance = guidance_scale > 1.0
+
+        content_output: dict[str, torch.Tensor] = \
+            self.content_encoder.encode_content(content, task, device=device)
+        length_aligned_content = content_output["length_aligned_content"]
+        content = content_output["content"]
+        content_mask = content_output["content_mask"]
+ instruction_mask = create_mask_from_length(instruction_lengths)
+ content, content_mask, global_duration_pred, local_duration_pred = \
+ self.content_adapter(content, content_mask, instruction, instruction_mask)
+
+ scheduler.set_timesteps(num_steps, device=device)
+ timesteps = scheduler.timesteps
+ batch_size = content.size(0)
+
+        # truncate the dummy time-aligned duration prediction
+ is_time_aligned = torch.as_tensor(is_time_aligned)
+ if is_time_aligned.sum() > 0:
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
+ else:
+ trunc_ta_length = content.size(1)
+
+ # prepare local duration
+ local_duration_pred = torch.exp(local_duration_pred) * content_mask
+ local_duration_pred = torch.ceil(
+ local_duration_pred
+ ) - self.duration_offset # frame number in `self.frame_resolution`
+        local_duration_pred = torch.round(
+            local_duration_pred * self.frame_resolution *
+            self.autoencoder.latent_token_rate
+        )
+ local_duration_pred = local_duration_pred[:, :trunc_ta_length]
+ # use ground truth duration
+ if use_gt_duration and "duration" in kwargs:
+ local_duration_pred = torch.round(
+ torch.as_tensor(kwargs["duration"]) *
+ self.autoencoder.latent_token_rate
+ ).to(device)
+
+ # prepare global duration
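+        # time-aligned samples take their total length from the summed local
+        # durations; non-time-aligned samples use the global prediction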
+ global_duration = local_duration_pred.sum(1)
+ global_duration_pred = torch.exp(
+ global_duration_pred
+ ) - self.duration_offset
+ global_duration_pred *= self.autoencoder.latent_token_rate
+ global_duration_pred = torch.round(global_duration_pred)
+ global_duration[~is_time_aligned] = global_duration_pred[
+ ~is_time_aligned]
+
+ # --------------------------------------------------------------------
+ # duration adapter
+ # --------------------------------------------------------------------
+ time_aligned_content = content[:, :trunc_ta_length]
+ ta_content_mask = content_mask[:, :trunc_ta_length]
+ latent_mask = create_mask_from_length(global_duration).to(
+ content_mask.device
+ )
+ attn_mask = ta_content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
+ # attn_mask: [B, L, T]
+ align_path = create_alignment_path(local_duration_pred, attn_mask)
+ time_aligned_content = torch.matmul(
+ align_path.transpose(1, 2).to(content.dtype), time_aligned_content
+ ) # (B, T, L) x (B, L, E) -> (B, T, E)
+ time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
+ time_aligned_content.dtype
+ )
+
+ length_aligned_content = trim_or_pad_length(
+ length_aligned_content, time_aligned_content.size(1), 1
+ )
+ time_aligned_content = time_aligned_content + length_aligned_content
+
+ # --------------------------------------------------------------------
+ # prepare unconditional input
+ # --------------------------------------------------------------------
+ context = content
+ context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype)
+        context_mask = content_mask
+        # only use the first dummy non-time-aligned embedding
+        context_mask[is_time_aligned, 1:] = False
+        # truncate the dummy non-time-aligned context
+ if is_time_aligned.sum().item() < batch_size:
+ trunc_nta_length = content_mask[~is_time_aligned].sum(1).max()
+ else:
+ trunc_nta_length = content.size(1)
+ context = context[:, :trunc_nta_length]
+ context_mask = context_mask[:, :trunc_nta_length]
+
+ if classifier_free_guidance:
+ uncond_time_aligned_content = torch.zeros_like(
+ time_aligned_content
+ )
+ uncond_context = torch.zeros_like(context)
+ uncond_context_mask = context_mask.detach().clone()
+ time_aligned_content = torch.cat([
+ uncond_time_aligned_content, time_aligned_content
+ ])
+ context = torch.cat([uncond_context, context])
+ context_mask = torch.cat([uncond_context_mask, context_mask])
+ latent_mask = torch.cat([
+ latent_mask, latent_mask.detach().clone()
+ ])
+
+ # --------------------------------------------------------------------
+ # prepare input to the backbone
+ # --------------------------------------------------------------------
+ latent_shape = tuple(
+ int(global_duration.max().item()) if dim is None else dim
+ for dim in self.autoencoder.latent_shape
+ )
+ shape = (batch_size, *latent_shape)
+ latent = randn_tensor(
+ shape, generator=None, device=device, dtype=content.dtype
+ )
+ # scale the initial noise by the standard deviation required by the scheduler
+ latent = latent * scheduler.init_noise_sigma
+
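+        # for first-order schedulers len(timesteps) == num_steps * order, so
+        # this is typically 0; it only matters for higher-order schedulers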
+ num_warmup_steps = len(timesteps) - num_steps * scheduler.order
+ progress_bar = tqdm(range(num_steps), disable=disable_progress)
+ # --------------------------------------------------------------------
+        # iterative denoising
+ # --------------------------------------------------------------------
+ for i, timestep in enumerate(timesteps):
+ # expand the latent if we are doing classifier free guidance
+ if classifier_free_guidance:
+ latent_input = torch.cat([latent, latent])
+ else:
+ latent_input = latent
+
+ latent_input = scheduler.scale_model_input(latent_input, timestep)
+ noise_pred = self.backbone(
+ x=latent_input,
+ x_mask=latent_mask,
+ timesteps=timestep,
+ time_aligned_context=time_aligned_content,
+ context=context,
+ context_mask=context_mask
+ )
+
+ if classifier_free_guidance:
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_cond - noise_pred_uncond
+ )
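+                # optional guidance rescaling (Lin et al., 2023) to curb the
+                # over-saturation that large guidance scales can cause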
+ if guidance_rescale != 0.0:
+ noise_pred = self.rescale_cfg(
+ noise_pred_cond, noise_pred, guidance_rescale
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latent = scheduler.step(noise_pred, timestep, latent).prev_sample
+
+            # update the progress bar on completed scheduler steps
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
+ (i + 1) % scheduler.order == 0):
+ progress_bar.update(1)
+
+ progress_bar.close()
+
+ # TODO variable length decoding, using `latent_mask`
+ waveform = self.autoencoder.decode(latent)
+ return waveform
+
+
+class DoubleContentAudioDiffusion(CrossAttentionAudioDiffusion):
+ def __init__(
+ self,
+ autoencoder: AutoEncoderBase,
+ content_encoder: ContentEncoder,
+ content_adapter: nn.Module,
+ backbone: nn.Module,
+ content_dim: int,
+ frame_resolution: float,
+ duration_offset: float = 1.0,
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
+ snr_gamma: float = None,
+ cfg_drop_ratio: float = 0.2,
+ ):
+ super().__init__(
+ autoencoder=autoencoder,
+ content_encoder=content_encoder,
+ content_adapter=content_adapter,
+ backbone=backbone,
+ duration_offset=duration_offset,
+ noise_scheduler_name=noise_scheduler_name,
+ snr_gamma=snr_gamma,
+ cfg_drop_ratio=cfg_drop_ratio
+ )
+ self.frame_resolution = frame_resolution
+
+ def forward(
+ self, content, duration, task, is_time_aligned, waveform,
+ waveform_lengths, instruction, instruction_lengths, **kwargs
+ ):
+ device = self.dummy_param.device
+ num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
+ self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
+
+ self.autoencoder.eval()
+ with torch.no_grad():
+ latent, latent_mask = self.autoencoder.encode(
+ waveform.unsqueeze(1), waveform_lengths
+ )
+
+        content_output: dict[str, torch.Tensor] = \
+            self.content_encoder.encode_content(content, task, device=device)
+        length_aligned_content = content_output["length_aligned_content"]
+        content = content_output["content"]
+        content_mask = content_output["content_mask"]
+ context_mask = content_mask.detach()
+ instruction_mask = create_mask_from_length(instruction_lengths)
+
+ content, content_mask, global_duration_pred, local_duration_pred = \
+ self.content_adapter(content, content_mask, instruction, instruction_mask)
+
+        # TODO if all samples are non-time-aligned, content length can exceed
+        # duration length
+
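+        # durations are regressed in the log domain; the offset keeps log(0)
+        # out of reach for zero-length tokens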
+ n_frames = torch.round(duration / self.frame_resolution)
+ local_duration_target = torch.log(n_frames + self.duration_offset)
+ global_duration_target = torch.log(
+ latent_mask.sum(1) / self.autoencoder.latent_token_rate +
+ self.duration_offset
+ )
+        # truncate the unused non-time-aligned duration prediction
+ if is_time_aligned.sum() > 0:
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
+ else:
+ trunc_ta_length = content.size(1)
+ # local duration loss
+ local_duration_pred = local_duration_pred[:, :trunc_ta_length]
+ ta_content_mask = content_mask[:, :trunc_ta_length]
+ local_duration_target = local_duration_target.to(
+ dtype=local_duration_pred.dtype
+ )
+ local_duration_loss = loss_with_mask(
+ (local_duration_target - local_duration_pred)**2,
+ ta_content_mask,
+ reduce=False
+ )
+ local_duration_loss *= is_time_aligned
+ if is_time_aligned.sum().item() == 0:
+ local_duration_loss *= 0.0
+ local_duration_loss = local_duration_loss.mean()
+ else:
+            local_duration_loss = \
+                local_duration_loss.sum() / is_time_aligned.sum()
+
+ # global duration loss
+ global_duration_loss = F.mse_loss(
+ global_duration_target, global_duration_pred
+ )
+ # --------------------------------------------------------------------
+ # prepare latent and diffusion-related noise
+ # --------------------------------------------------------------------
+ batch_size = latent.shape[0]
+ timesteps = self.get_timesteps(batch_size, device, self.training)
+ noise = torch.randn_like(latent)
+ noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
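+        # the regression target typically depends on the scheduler's
+        # prediction type (e.g. the noise for epsilon-prediction)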
+ target = self.get_target(latent, noise, timesteps)
+
+ # --------------------------------------------------------------------
+ # duration adapter
+ # --------------------------------------------------------------------
+ # content_mask: [B, L], helper_latent_mask: [B, T]
+ if is_time_aligned.sum() == 0 and \
+ duration.size(1) < content_mask.size(1):
+            # for non-time-aligned tasks like TTA, `duration` is a dummy placeholder
+ duration = F.pad(
+ duration, (0, content_mask.size(1) - duration.size(1))
+ )
+ n_latents = torch.round(duration * self.autoencoder.latent_token_rate)
+ helper_latent_mask = create_mask_from_length(n_latents.sum(1)).to(
+ content_mask.device
+ )
+ attn_mask = ta_content_mask.unsqueeze(
+ -1
+ ) * helper_latent_mask.unsqueeze(1)
+ align_path = create_alignment_path(n_latents, attn_mask)
+ time_aligned_content = content[:, :trunc_ta_length]
+ time_aligned_content = torch.matmul(
+ align_path.transpose(1, 2).to(content.dtype), time_aligned_content
+ )
+
+ latent_length = noisy_latent.size(self.autoencoder.time_dim)
+ time_aligned_content = trim_or_pad_length(
+ time_aligned_content, latent_length, 1
+ )
+ length_aligned_content = trim_or_pad_length(
+ length_aligned_content, latent_length, 1
+ )
+ time_aligned_content = time_aligned_content + length_aligned_content
+ context = content
+ # --------------------------------------------------------------------
+ # classifier free guidance
+ # --------------------------------------------------------------------
+ if self.training and self.classifier_free_guidance:
+ mask_indices = [
+ k for k in range(len(waveform))
+ if random.random() < self.cfg_drop_ratio
+ ]
+ if len(mask_indices) > 0:
+ context[mask_indices] = 0
+ time_aligned_content[mask_indices] = 0
+
+ pred: torch.Tensor = self.backbone(
+ x=noisy_latent,
+ timesteps=timesteps,
+ time_aligned_context=time_aligned_content,
+ context=context,
+ x_mask=latent_mask,
+ context_mask=context_mask,
+ )
+ pred = pred.transpose(1, self.autoencoder.time_dim)
+ target = target.transpose(1, self.autoencoder.time_dim)
+ diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask)
+ return {
+ "diff_loss": diff_loss,
+ "local_duration_loss": local_duration_loss,
+ "global_duration_loss": global_duration_loss,
+ }
+
+ @torch.no_grad()
+ def inference(
+ self,
+ content: list[Any],
+ condition: list[Any],
+ task: list[str],
+ is_time_aligned: list[bool],
+ instruction: torch.Tensor,
+ instruction_lengths: Sequence[int],
+ scheduler: SchedulerMixin,
+ num_steps: int = 20,
+ guidance_scale: float = 3.0,
+ guidance_rescale: float = 0.0,
+ disable_progress: bool = True,
+ use_gt_duration: bool = False,
+ **kwargs
+ ):
+ device = self.dummy_param.device
+ classifier_free_guidance = guidance_scale > 1.0
+
+        content_output: dict[str, torch.Tensor] = \
+            self.content_encoder.encode_content(content, task, device=device)
+        length_aligned_content = content_output["length_aligned_content"]
+        content = content_output["content"]
+        content_mask = content_output["content_mask"]
+ instruction_mask = create_mask_from_length(instruction_lengths)
+
+ content, content_mask, global_duration_pred, local_duration_pred = \
+ self.content_adapter(content, content_mask, instruction, instruction_mask)
+
+ scheduler.set_timesteps(num_steps, device=device)
+ timesteps = scheduler.timesteps
+ batch_size = content.size(0)
+
+        # truncate the dummy time-aligned duration prediction
+ is_time_aligned = torch.as_tensor(is_time_aligned)
+ if is_time_aligned.sum() > 0:
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
+ else:
+ trunc_ta_length = content.size(1)
+
+ # prepare local duration
+ local_duration_pred = torch.exp(local_duration_pred) * content_mask
+ local_duration_pred = torch.ceil(
+ local_duration_pred
+ ) - self.duration_offset # frame number in `self.frame_resolution`
+        local_duration_pred = torch.round(
+            local_duration_pred * self.frame_resolution *
+            self.autoencoder.latent_token_rate
+        )
+ local_duration_pred = local_duration_pred[:, :trunc_ta_length]
+ # use ground truth duration
+ if use_gt_duration and "duration" in kwargs:
+ local_duration_pred = torch.round(
+ torch.as_tensor(kwargs["duration"]) *
+ self.autoencoder.latent_token_rate
+ ).to(device)
+
+ # prepare global duration
+ global_duration = local_duration_pred.sum(1)
+ global_duration_pred = torch.exp(
+ global_duration_pred
+ ) - self.duration_offset
+ global_duration_pred *= self.autoencoder.latent_token_rate
+ global_duration_pred = torch.round(global_duration_pred)
+ global_duration[~is_time_aligned] = global_duration_pred[
+ ~is_time_aligned]
+
+ # --------------------------------------------------------------------
+ # duration adapter
+ # --------------------------------------------------------------------
+ time_aligned_content = content[:, :trunc_ta_length]
+ ta_content_mask = content_mask[:, :trunc_ta_length]
+ latent_mask = create_mask_from_length(global_duration).to(
+ content_mask.device
+ )
+ attn_mask = ta_content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
+ # attn_mask: [B, L, T]
+ align_path = create_alignment_path(local_duration_pred, attn_mask)
+ time_aligned_content = torch.matmul(
+ align_path.transpose(1, 2).to(content.dtype), time_aligned_content
+ ) # (B, T, L) x (B, L, E) -> (B, T, E)
+
+        # unlike the parent class, non-time-aligned samples keep the expanded
+        # content instead of being replaced by the dummy time-aligned embedding
+
+ length_aligned_content = trim_or_pad_length(
+ length_aligned_content, time_aligned_content.size(1), 1
+ )
+ time_aligned_content = time_aligned_content + length_aligned_content
+
+ # --------------------------------------------------------------------
+ # prepare unconditional input
+ # --------------------------------------------------------------------
+        context = content
+        context_mask = content_mask
+        # unlike the parent class, the full content is kept as cross-attention
+        # context: no dummy non-time-aligned embedding and no truncation
+
+ if classifier_free_guidance:
+ uncond_time_aligned_content = torch.zeros_like(
+ time_aligned_content
+ )
+ uncond_context = torch.zeros_like(context)
+ uncond_context_mask = context_mask.detach().clone()
+ time_aligned_content = torch.cat([
+ uncond_time_aligned_content, time_aligned_content
+ ])
+ context = torch.cat([uncond_context, context])
+ context_mask = torch.cat([uncond_context_mask, context_mask])
+ latent_mask = torch.cat([
+ latent_mask, latent_mask.detach().clone()
+ ])
+
+ # --------------------------------------------------------------------
+ # prepare input to the backbone
+ # --------------------------------------------------------------------
+ latent_shape = tuple(
+ int(global_duration.max().item()) if dim is None else dim
+ for dim in self.autoencoder.latent_shape
+ )
+ shape = (batch_size, *latent_shape)
+ latent = randn_tensor(
+ shape, generator=None, device=device, dtype=content.dtype
+ )
+ # scale the initial noise by the standard deviation required by the scheduler
+ latent = latent * scheduler.init_noise_sigma
+
+ num_warmup_steps = len(timesteps) - num_steps * scheduler.order
+ progress_bar = tqdm(range(num_steps), disable=disable_progress)
+ # --------------------------------------------------------------------
+        # iterative denoising
+ # --------------------------------------------------------------------
+ for i, timestep in enumerate(timesteps):
+ # expand the latent if we are doing classifier free guidance
+ if classifier_free_guidance:
+ latent_input = torch.cat([latent, latent])
+ else:
+ latent_input = latent
+
+ latent_input = scheduler.scale_model_input(latent_input, timestep)
+ noise_pred = self.backbone(
+ x=latent_input,
+ x_mask=latent_mask,
+ timesteps=timestep,
+ time_aligned_context=time_aligned_content,
+ context=context,
+ context_mask=context_mask
+ )
+
+ if classifier_free_guidance:
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_cond - noise_pred_uncond
+ )
+ if guidance_rescale != 0.0:
+ noise_pred = self.rescale_cfg(
+ noise_pred_cond, noise_pred, guidance_rescale
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latent = scheduler.step(noise_pred, timestep, latent).prev_sample
+
+            # update the progress bar on completed scheduler steps
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
+ (i + 1) % scheduler.order == 0):
+ progress_bar.update(1)
+
+ progress_bar.close()
+
+ # TODO variable length decoding, using `latent_mask`
+ waveform = self.autoencoder.decode(latent)
+ return waveform
diff --git a/models/dit/__pycache__/attention.cpython-310.pyc b/models/dit/__pycache__/attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d4715a10dc05eb3bc4aa9d5c489f6294f4905e3
Binary files /dev/null and b/models/dit/__pycache__/attention.cpython-310.pyc differ
diff --git a/models/dit/__pycache__/audio_dit.cpython-310.pyc b/models/dit/__pycache__/audio_dit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..251f919f74fa76c7ed76bafa2300bdc1987913b9
Binary files /dev/null and b/models/dit/__pycache__/audio_dit.cpython-310.pyc differ
diff --git a/models/dit/__pycache__/mask_dit.cpython-310.pyc b/models/dit/__pycache__/mask_dit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1f3709304a101c69d8d549b045c9bb9dfc8883f1
Binary files /dev/null and b/models/dit/__pycache__/mask_dit.cpython-310.pyc differ
diff --git a/models/dit/__pycache__/modules.cpython-310.pyc b/models/dit/__pycache__/modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3e698ff4e9a086e1c897ed8c02931a47c331880
Binary files /dev/null and b/models/dit/__pycache__/modules.cpython-310.pyc differ
diff --git a/models/dit/__pycache__/rotary.cpython-310.pyc b/models/dit/__pycache__/rotary.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9abe202a53a40f634fd99ca83c61d7c05e558b4
Binary files /dev/null and b/models/dit/__pycache__/rotary.cpython-310.pyc differ
diff --git a/models/dit/__pycache__/span_mask.cpython-310.pyc b/models/dit/__pycache__/span_mask.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5809b47940f981bc83b994b07fdbf93b27a7bc7d
Binary files /dev/null and b/models/dit/__pycache__/span_mask.cpython-310.pyc differ
diff --git a/models/dit/attention.py b/models/dit/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..862fb2070e536f344f15565fe67b135050098c57
--- /dev/null
+++ b/models/dit/attention.py
@@ -0,0 +1,349 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import einops
+from einops import rearrange, repeat
+from inspect import isfunction
+from .rotary import RotaryEmbedding
+from .modules import RMSNorm
+
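+# prefer PyTorch's fused scaled_dot_product_attention when available and fall
+# back to an explicit (math) implementation otherwise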
+if hasattr(nn.functional, 'scaled_dot_product_attention'):
+ ATTENTION_MODE = 'flash'
+else:
+ ATTENTION_MODE = 'math'
+print(f'attention mode is {ATTENTION_MODE}')
+
+
+def add_mask(sim, mask):
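+    # broadcast the mask to [B, 1, N, M] and fill masked positions with the
+    # largest negative value so softmax assigns them zero weight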
+ b, ndim = sim.shape[0], mask.ndim
+ if ndim == 3:
+ mask = rearrange(mask, "b n m -> b 1 n m")
+ if ndim == 2:
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
+ max_neg_value = -torch.finfo(sim.dtype).max
+ sim = sim.masked_fill(~mask, max_neg_value)
+ return sim
+
+
+def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
+ def default(val, d):
+ return val if val is not None else (d() if isfunction(d) else d)
+
+    b, i, j = q_shape[0], q_shape[-2], k_shape[-2]
+ q_mask = default(
+ q_mask, torch.ones((b, i), device=device, dtype=torch.bool)
+ )
+ k_mask = default(
+ k_mask, torch.ones((b, j), device=device, dtype=torch.bool)
+ )
+    attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * \
+        rearrange(k_mask, 'b j -> b 1 1 j')
+ return attn_mask
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ context_dim=None,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ qk_norm=None,
+ attn_drop=0.,
+ proj_drop=0.,
+ rope_mode='none'
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ if context_dim is None:
+ self.cross_attn = False
+ else:
+ self.cross_attn = True
+
+ context_dim = dim if context_dim is None else context_dim
+
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
+ self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
+
+ if qk_norm is None:
+ self.norm_q = nn.Identity()
+ self.norm_k = nn.Identity()
+ elif qk_norm == 'layernorm':
+ self.norm_q = nn.LayerNorm(head_dim)
+ self.norm_k = nn.LayerNorm(head_dim)
+ elif qk_norm == 'rmsnorm':
+ self.norm_q = RMSNorm(head_dim)
+ self.norm_k = RMSNorm(head_dim)
+ else:
+ raise NotImplementedError
+
+ self.attn_drop_p = attn_drop
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ if self.cross_attn:
+ assert rope_mode == 'none'
+ self.rope_mode = rope_mode
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
+ self.rotary = RotaryEmbedding(dim=head_dim)
+ elif self.rope_mode == 'dual':
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
+
+ def _rotary(self, q, k, extras):
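+        # 'shared' rotates the whole sequence; 'x_only' leaves the first
+        # `extras` prefix tokens (time/context) unrotated; 'dual' applies
+        # separate rotary embeddings to the prefix and the x tokens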
+ if self.rope_mode == 'shared':
+ q, k = self.rotary(q=q, k=k)
+ elif self.rope_mode == 'x_only':
+ q_x, k_x = self.rotary(
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
+ )
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
+ q = torch.cat((q_c, q_x), dim=2)
+ k = torch.cat((k_c, k_x), dim=2)
+ elif self.rope_mode == 'dual':
+ q_x, k_x = self.rotary_x(
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
+ )
+ q_c, k_c = self.rotary_c(
+ q=q[:, :, :extras, :], k=k[:, :, :extras, :]
+ )
+ q = torch.cat((q_c, q_x), dim=2)
+ k = torch.cat((k_c, k_x), dim=2)
+ elif self.rope_mode == 'none':
+ pass
+ else:
+ raise NotImplementedError
+ return q, k
+
+ def _attn(self, q, k, v, mask_binary):
+ if ATTENTION_MODE == 'flash':
+ x = F.scaled_dot_product_attention(
+ q, k, v, dropout_p=self.attn_drop_p, attn_mask=mask_binary
+ )
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ elif ATTENTION_MODE == 'math':
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = add_mask(
+ attn, mask_binary
+ ) if mask_binary is not None else attn
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2)
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ else:
+ raise NotImplementedError
+ return x
+
+ def forward(self, x, context=None, context_mask=None, extras=0):
+ B, L, C = x.shape
+ if context is None:
+ context = x
+
+ q = self.to_q(x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if context_mask is not None:
+ mask_binary = create_mask(
+ x.shape, context.shape, x.device, None, context_mask
+ )
+ else:
+ mask_binary = None
+
+ q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
+ k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
+ v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)
+
+ q = self.norm_q(q)
+ k = self.norm_k(k)
+
+ q, k = self._rotary(q, k, extras)
+
+ x = self._attn(q, k, v, mask_binary)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class JointAttention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ qk_norm=None,
+ attn_drop=0.,
+ proj_drop=0.,
+ rope_mode='none'
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(
+ dim, qkv_bias
+ )
+ self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(
+ dim, qkv_bias
+ )
+
+ self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim)
+ self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim)
+
+ self.attn_drop_p = attn_drop
+ self.attn_drop = nn.Dropout(attn_drop)
+
+ self.proj_x = nn.Linear(dim, dim)
+ self.proj_drop_x = nn.Dropout(proj_drop)
+
+ self.proj_c = nn.Linear(dim, dim)
+ self.proj_drop_c = nn.Dropout(proj_drop)
+
+ self.rope_mode = rope_mode
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
+ self.rotary = RotaryEmbedding(dim=head_dim)
+ elif self.rope_mode == 'dual':
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
+
+ def _make_qkv_layers(self, dim, qkv_bias):
+        return (
+            nn.Linear(dim, dim, bias=qkv_bias),
+            nn.Linear(dim, dim, bias=qkv_bias),
+            nn.Linear(dim, dim, bias=qkv_bias),
+        )
+
+ def _make_norm_layers(self, qk_norm, head_dim):
+ if qk_norm is None:
+ norm_q = nn.Identity()
+ norm_k = nn.Identity()
+ elif qk_norm == 'layernorm':
+ norm_q = nn.LayerNorm(head_dim)
+ norm_k = nn.LayerNorm(head_dim)
+ elif qk_norm == 'rmsnorm':
+ norm_q = RMSNorm(head_dim)
+ norm_k = RMSNorm(head_dim)
+ else:
+ raise NotImplementedError
+ return norm_q, norm_k
+
+ def _rotary(self, q, k, extras):
+ if self.rope_mode == 'shared':
+ q, k = self.rotary(q=q, k=k)
+ elif self.rope_mode == 'x_only':
+ q_x, k_x = self.rotary(
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
+ )
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
+ q = torch.cat((q_c, q_x), dim=2)
+ k = torch.cat((k_c, k_x), dim=2)
+ elif self.rope_mode == 'dual':
+ q_x, k_x = self.rotary_x(
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
+ )
+ q_c, k_c = self.rotary_c(
+ q=q[:, :, :extras, :], k=k[:, :, :extras, :]
+ )
+ q = torch.cat((q_c, q_x), dim=2)
+ k = torch.cat((k_c, k_x), dim=2)
+ elif self.rope_mode == 'none':
+ pass
+ else:
+ raise NotImplementedError
+ return q, k
+
+ def _attn(self, q, k, v, mask_binary):
+ if ATTENTION_MODE == 'flash':
+ x = F.scaled_dot_product_attention(
+ q, k, v, dropout_p=self.attn_drop_p, attn_mask=mask_binary
+ )
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ elif ATTENTION_MODE == 'math':
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = add_mask(
+ attn, mask_binary
+ ) if mask_binary is not None else attn
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2)
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ else:
+ raise NotImplementedError
+ return x
+
+ def _cat_mask(self, x, context, x_mask=None, context_mask=None):
+ B = x.shape[0]
+ if x_mask is None:
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
+ if context_mask is None:
+ context_mask = torch.ones(
+ B, context.shape[-2], device=context.device
+ ).bool()
+ mask = torch.cat([context_mask, x_mask], dim=1)
+ return mask
+
+ def forward(self, x, context, x_mask=None, context_mask=None, extras=0):
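+        # joint attention: x and context get separate QKV projections, attend
+        # over the concatenated sequence, and are split apart again afterwards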
+ B, Lx, C = x.shape
+ _, Lc, _ = context.shape
+ if x_mask is not None or context_mask is not None:
+ mask = self._cat_mask(
+ x, context, x_mask=x_mask, context_mask=context_mask
+ )
+ shape = [B, Lx + Lc, C]
+ mask_binary = create_mask(
+ q_shape=shape,
+ k_shape=shape,
+ device=x.device,
+ q_mask=None,
+ k_mask=mask
+ )
+ else:
+ mask_binary = None
+
+ qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x)
+        qc, kc, vc = (
+            self.to_qc(context), self.to_kc(context), self.to_vc(context)
+        )
+
+        qx, kx, vx = map(
+            lambda t: einops.rearrange(
+                t, 'B L (H D) -> B H L D', H=self.num_heads
+            ), [qx, kx, vx]
+        )
+        qc, kc, vc = map(
+            lambda t: einops.rearrange(
+                t, 'B L (H D) -> B H L D', H=self.num_heads
+            ), [qc, kc, vc]
+        )
+
+ qx, kx = self.norm_qx(qx), self.norm_kx(kx)
+ qc, kc = self.norm_qc(qc), self.norm_kc(kc)
+
+        q = torch.cat([qc, qx], dim=2)
+        k = torch.cat([kc, kx], dim=2)
+        v = torch.cat([vc, vx], dim=2)
+
+ q, k = self._rotary(q, k, extras)
+
+ x = self._attn(q, k, v, mask_binary)
+
+ context, x = x[:, :Lc, :], x[:, Lc:, :]
+
+ x = self.proj_x(x)
+ x = self.proj_drop_x(x)
+
+ context = self.proj_c(context)
+ context = self.proj_drop_c(context)
+
+ return x, context
diff --git a/models/dit/audio_diffsingernet_dit.py b/models/dit/audio_diffsingernet_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a5facb5c2316a04f0010a477ce0b6d7268d043a
--- /dev/null
+++ b/models/dit/audio_diffsingernet_dit.py
@@ -0,0 +1,520 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from .mask_dit import DiTBlock, FinalBlock, UDiT
+from .modules import (
+ film_modulate,
+ PatchEmbed,
+ PE_wrapper,
+ TimestepEmbedder,
+ RMSNorm,
+)
+
+
+class AudioDiTBlock(DiTBlock):
+ """
+    A modified DiT block that adds time_aligned_context to the latent.
+ """
+ def __init__(
+ self,
+ dim,
+ time_aligned_context_dim,
+ dilation,
+ context_dim=None,
+ num_heads=8,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ qk_norm=None,
+ act_layer='gelu',
+ norm_layer=nn.LayerNorm,
+ time_fusion='none',
+ ada_sola_rank=None,
+ ada_sola_alpha=None,
+ skip=False,
+ skip_norm=False,
+ rope_mode='none',
+ context_norm=False,
+ use_checkpoint=False
+ ):
+ super().__init__(
+ dim=dim,
+ context_dim=context_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ time_fusion=time_fusion,
+ ada_sola_rank=ada_sola_rank,
+ ada_sola_alpha=ada_sola_alpha,
+ skip=skip,
+ skip_norm=skip_norm,
+ rope_mode=rope_mode,
+ context_norm=context_norm,
+ use_checkpoint=use_checkpoint
+ )
+ # time-aligned context projection
+ self.ta_context_projection = nn.Linear(
+ time_aligned_context_dim, 2 * dim
+ )
+ self.dilated_conv = nn.Conv1d(
+ dim, 2 * dim, kernel_size=3, padding=dilation, dilation=dilation
+ )
+
+ def forward(
+ self,
+ x,
+ time_aligned_context,
+ time_token=None,
+ time_ada=None,
+ skip=None,
+ context=None,
+ x_mask=None,
+ context_mask=None,
+ extras=None
+ ):
+ if self.use_checkpoint:
+ return checkpoint(
+ self._forward,
+ x,
+ time_aligned_context,
+ time_token,
+ time_ada,
+ skip,
+ context,
+ x_mask,
+ context_mask,
+ extras,
+ use_reentrant=False
+ )
+ else:
+ return self._forward(
+ x,
+ time_aligned_context,
+ time_token,
+ time_ada,
+ skip,
+ context,
+ x_mask,
+ context_mask,
+ extras,
+ )
+
+ def _forward(
+ self,
+ x,
+ time_aligned_context,
+ time_token=None,
+ time_ada=None,
+ skip=None,
+ context=None,
+ x_mask=None,
+ context_mask=None,
+ extras=None
+ ):
+ B, T, C = x.shape
+ if self.skip_linear is not None:
+ assert skip is not None
+ cat = torch.cat([x, skip], dim=-1)
+ cat = self.skip_norm(cat)
+ x = self.skip_linear(cat)
+
+ if self.use_adanorm:
+ time_ada = self.adaln(time_token, time_ada)
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
+ gate_mlp) = time_ada.chunk(6, dim=1)
+
+ # self attention
+ if self.use_adanorm:
+ x_norm = film_modulate(
+ self.norm1(x), shift=shift_msa, scale=scale_msa
+ )
+            x = x + (1 - gate_msa) * self.attn(
+ x_norm, context=None, context_mask=x_mask, extras=extras
+ )
+ else:
+ # TODO diffusion timestep input is not fused here
+ x = x + self.attn(
+ self.norm1(x),
+ context=None,
+ context_mask=x_mask,
+ extras=extras
+ )
+
+ # time-aligned context
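+        # WaveNet-style fusion (as in DiffSinger): the dilated conv expands x
+        # to 2*dim channels, the projected context is added, and a gated
+        # activation (sigmoid gate * tanh filter) maps it back to dim channels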
+ time_aligned_context = self.ta_context_projection(time_aligned_context)
+        x = self.dilated_conv(x.transpose(1, 2)).transpose(1, 2)
+        x = x + time_aligned_context
+
+        gate, filter_ = torch.chunk(x, 2, dim=-1)
+        x = torch.sigmoid(gate) * torch.tanh(filter_)
+
+ # cross attention
+ if self.use_context:
+ assert context is not None
+ x = x + self.cross_attn(
+ x=self.norm2(x),
+ context=self.norm_context(context),
+ context_mask=context_mask,
+ extras=extras
+ )
+
+ # mlp
+ if self.use_adanorm:
+ x_norm = film_modulate(
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
+ )
+            x = x + (1 - gate_mlp) * self.mlp(x_norm)
+ else:
+ x = x + self.mlp(self.norm3(x))
+
+ return x
+
+
+class AudioUDiT(UDiT):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ input_type='2d',
+ out_chans=None,
+ embed_dim=768,
+ depth=12,
+ dilation_cycle_length=4,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=False,
+ qk_scale=None,
+ qk_norm=None,
+ act_layer='gelu',
+ norm_layer='layernorm',
+ context_norm=False,
+ use_checkpoint=False,
+ time_fusion='token',
+ ada_sola_rank=None,
+ ada_sola_alpha=None,
+ cls_dim=None,
+ time_aligned_context_dim=768,
+ context_dim=768,
+ context_fusion='concat',
+ context_max_length=128,
+ context_pe_method='sinu',
+ pe_method='abs',
+ rope_mode='none',
+ use_conv=True,
+ skip=True,
+ skip_norm=True
+ ):
+ nn.Module.__init__(self)
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ # input
+ self.in_chans = in_chans
+ self.input_type = input_type
+ if self.input_type == '2d':
+ num_patches = (img_size[0] //
+ patch_size) * (img_size[1] // patch_size)
+ elif self.input_type == '1d':
+ num_patches = img_size // patch_size
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ input_type=input_type
+ )
+ out_chans = in_chans if out_chans is None else out_chans
+ self.out_chans = out_chans
+
+ # position embedding
+ self.rope = rope_mode
+ self.x_pe = PE_wrapper(
+ dim=embed_dim, method=pe_method, length=num_patches
+ )
+
+ # time embed
+ self.time_embed = TimestepEmbedder(embed_dim)
+ self.time_fusion = time_fusion
+ self.use_adanorm = False
+
+ # cls embed
+ if cls_dim is not None:
+ self.cls_embed = nn.Sequential(
+ nn.Linear(cls_dim, embed_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=True),
+ )
+ else:
+ self.cls_embed = None
+
+ # time fusion
+ if time_fusion == 'token':
+ # put token at the beginning of sequence
+ self.extras = 2 if self.cls_embed else 1
+ self.time_pe = PE_wrapper(
+ dim=embed_dim, method='abs', length=self.extras
+ )
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
+ self.use_adanorm = True
+            # avoid a repeated SiLU in every AdaLN block
+ self.time_act = nn.SiLU()
+ self.extras = 0
+ self.time_ada_final = nn.Linear(
+ embed_dim, 2 * embed_dim, bias=True
+ )
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
+ # shared adaln
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
+ else:
+ self.time_ada = None
+ else:
+ raise NotImplementedError
+
+ # context
+ # use a simple projection
+ self.use_context = False
+ self.context_cross = False
+ self.context_max_length = context_max_length
+ self.context_fusion = 'none'
+ if context_dim is not None:
+ self.use_context = True
+ self.context_embed = nn.Sequential(
+ nn.Linear(context_dim, embed_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=True),
+ )
+ self.context_fusion = context_fusion
+ if context_fusion == 'concat' or context_fusion == 'joint':
+ self.extras += context_max_length
+ self.context_pe = PE_wrapper(
+ dim=embed_dim,
+ method=context_pe_method,
+ length=context_max_length
+ )
+ # no cross attention layers
+ context_dim = None
+ elif context_fusion == 'cross':
+ self.context_pe = PE_wrapper(
+ dim=embed_dim,
+ method=context_pe_method,
+ length=context_max_length
+ )
+ self.context_cross = True
+ context_dim = embed_dim
+ else:
+ raise NotImplementedError
+
+ self.use_skip = skip
+
+ # norm layers
+ if norm_layer == 'layernorm':
+ norm_layer = nn.LayerNorm
+ elif norm_layer == 'rmsnorm':
+ norm_layer = RMSNorm
+ else:
+ raise NotImplementedError
+
+ self.in_blocks = nn.ModuleList([
+ AudioDiTBlock(
+ dim=embed_dim,
+ time_aligned_context_dim=time_aligned_context_dim,
+ dilation=2**(i % dilation_cycle_length),
+ context_dim=context_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ time_fusion=time_fusion,
+ ada_sola_rank=ada_sola_rank,
+ ada_sola_alpha=ada_sola_alpha,
+ skip=False,
+ skip_norm=False,
+ rope_mode=self.rope,
+ context_norm=context_norm,
+ use_checkpoint=use_checkpoint
+ ) for i in range(depth // 2)
+ ])
+
+ self.mid_block = AudioDiTBlock(
+ dim=embed_dim,
+ time_aligned_context_dim=time_aligned_context_dim,
+ dilation=1,
+ context_dim=context_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ time_fusion=time_fusion,
+ ada_sola_rank=ada_sola_rank,
+ ada_sola_alpha=ada_sola_alpha,
+ skip=False,
+ skip_norm=False,
+ rope_mode=self.rope,
+ context_norm=context_norm,
+ use_checkpoint=use_checkpoint
+ )
+
+ self.out_blocks = nn.ModuleList([
+ AudioDiTBlock(
+ dim=embed_dim,
+ time_aligned_context_dim=time_aligned_context_dim,
+ dilation=2**(i % dilation_cycle_length),
+ context_dim=context_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ time_fusion=time_fusion,
+ ada_sola_rank=ada_sola_rank,
+ ada_sola_alpha=ada_sola_alpha,
+ skip=skip,
+ skip_norm=skip_norm,
+ rope_mode=self.rope,
+ context_norm=context_norm,
+ use_checkpoint=use_checkpoint
+ ) for i in range(depth // 2)
+ ])
+
+ # FinalLayer block
+ self.use_conv = use_conv
+ self.final_block = FinalBlock(
+ embed_dim=embed_dim,
+ patch_size=patch_size,
+ img_size=img_size,
+ in_chans=out_chans,
+ input_type=input_type,
+ norm_layer=norm_layer,
+ use_conv=use_conv,
+ use_adanorm=self.use_adanorm
+ )
+ self.initialize_weights()
+
+ def forward(
+ self,
+ x,
+ timesteps,
+ time_aligned_context,
+ context,
+ x_mask=None,
+ context_mask=None,
+ cls_token=None,
+ controlnet_skips=None,
+ ):
+ # make it compatible with int time step during inference
+ if timesteps.dim() == 0:
+ timesteps = timesteps.expand(x.shape[0]
+ ).to(x.device, dtype=torch.long)
+
+ x = self.patch_embed(x)
+ x = self.x_pe(x)
+
+ B, L, D = x.shape
+
+ if self.use_context:
+ context_token = self.context_embed(context)
+ context_token = self.context_pe(context_token)
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
+ x, x_mask = self._concat_x_context(
+ x=x,
+ context=context_token,
+ x_mask=x_mask,
+ context_mask=context_mask
+ )
+ context_token, context_mask = None, None
+ else:
+ context_token, context_mask = None, None
+
+ time_token = self.time_embed(timesteps)
+ if self.cls_embed:
+ cls_token = self.cls_embed(cls_token)
+ time_ada = None
+ time_ada_final = None
+ if self.use_adanorm:
+ if self.cls_embed:
+ time_token = time_token + cls_token
+ time_token = self.time_act(time_token)
+ time_ada_final = self.time_ada_final(time_token)
+ if self.time_ada is not None:
+ time_ada = self.time_ada(time_token)
+ else:
+ time_token = time_token.unsqueeze(dim=1)
+ if self.cls_embed:
+ cls_token = cls_token.unsqueeze(dim=1)
+ time_token = torch.cat([time_token, cls_token], dim=1)
+ time_token = self.time_pe(time_token)
+ x = torch.cat((time_token, x), dim=1)
+ if x_mask is not None:
+                time_mask = torch.ones(
+                    B, time_token.shape[1], device=x_mask.device
+                ).bool()
+                x_mask = torch.cat([time_mask, x_mask], dim=1)
+ time_token = None
+
+ skips = []
+ for blk in self.in_blocks:
+ x = blk(
+ x=x,
+ time_aligned_context=time_aligned_context,
+ time_token=time_token,
+ time_ada=time_ada,
+ skip=None,
+ context=context_token,
+ x_mask=x_mask,
+ context_mask=context_mask,
+ extras=self.extras
+ )
+ if self.use_skip:
+ skips.append(x)
+
+ x = self.mid_block(
+ x=x,
+ time_aligned_context=time_aligned_context,
+ time_token=time_token,
+ time_ada=time_ada,
+ skip=None,
+ context=context_token,
+ x_mask=x_mask,
+ context_mask=context_mask,
+ extras=self.extras
+ )
+ for blk in self.out_blocks:
+ if self.use_skip:
+ skip = skips.pop()
+ if controlnet_skips:
+ # add to skip like u-net controlnet
+ skip = skip + controlnet_skips.pop()
+ else:
+ skip = None
+ if controlnet_skips:
+ # directly add to x
+ x = x + controlnet_skips.pop()
+
+ x = blk(
+ x=x,
+ time_aligned_context=time_aligned_context,
+ time_token=time_token,
+ time_ada=time_ada,
+ skip=skip,
+ context=context_token,
+ x_mask=x_mask,
+ context_mask=context_mask,
+ extras=self.extras
+ )
+
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
+
+ return x
diff --git a/models/dit/audio_dit.py b/models/dit/audio_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..5151857f9756bd50ba6aa0bc9ca4f798345c0b98
--- /dev/null
+++ b/models/dit/audio_dit.py
@@ -0,0 +1,652 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from .mask_dit import DiTBlock, FinalBlock, UDiT
+from .modules import (
+ film_modulate,
+ PatchEmbed,
+ PE_wrapper,
+ TimestepEmbedder,
+ RMSNorm,
+)
+
+
+class LayerFusionDiTBlock(DiTBlock):
+ """
+    A modified DiT block that adds a time-aligned context to the latent.
+ """
+ def __init__(
+ self,
+ dim,
+ ta_context_dim,
+ ta_context_norm=False,
+ context_dim=None,
+ num_heads=8,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ qk_norm=None,
+ act_layer='gelu',
+ norm_layer=nn.LayerNorm,
+ ta_context_fusion='add',
+ time_fusion='none',
+ ada_sola_rank=None,
+ ada_sola_alpha=None,
+ skip=False,
+ skip_norm=False,
+ rope_mode='none',
+ context_norm=False,
+ use_checkpoint=False
+ ):
+ super().__init__(
+ dim=dim,
+ context_dim=context_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ time_fusion=time_fusion,
+ ada_sola_rank=ada_sola_rank,
+ ada_sola_alpha=ada_sola_alpha,
+ skip=skip,
+ skip_norm=skip_norm,
+ rope_mode=rope_mode,
+ context_norm=context_norm,
+ use_checkpoint=use_checkpoint
+ )
+ self.ta_context_fusion = ta_context_fusion
+ self.ta_context_norm = ta_context_norm
+ if self.ta_context_fusion == "add":
+ self.ta_context_projection = nn.Linear(
+ ta_context_dim, dim, bias=False
+ )
+ self.ta_context_norm = norm_layer(
+ ta_context_dim
+ ) if self.ta_context_norm else nn.Identity()
+ elif self.ta_context_fusion == "concat":
+ self.ta_context_projection = nn.Linear(ta_context_dim + dim, dim)
+ self.ta_context_norm = norm_layer(
+ ta_context_dim + dim
+ ) if self.ta_context_norm else nn.Identity()
+
+ def forward(
+ self,
+ x,
+ time_aligned_context,
+ time_token=None,
+ time_ada=None,
+ skip=None,
+ context=None,
+ x_mask=None,
+ context_mask=None,
+ extras=None
+ ):
+ if self.use_checkpoint:
+ return checkpoint(
+ self._forward,
+ x,
+ time_aligned_context,
+ time_token,
+ time_ada,
+ skip,
+ context,
+ x_mask,
+ context_mask,
+ extras,
+ use_reentrant=False
+ )
+ else:
+ return self._forward(
+ x,
+ time_aligned_context,
+ time_token,
+ time_ada,
+ skip,
+ context,
+ x_mask,
+ context_mask,
+ extras,
+ )
+
+ def _forward(
+ self,
+ x,
+ time_aligned_context,
+ time_token=None,
+ time_ada=None,
+ skip=None,
+ context=None,
+ x_mask=None,
+ context_mask=None,
+ extras=None
+ ):
+ B, T, C = x.shape
+
+ # skip connection
+ if self.skip_linear is not None:
+ assert skip is not None
+ cat = torch.cat([x, skip], dim=-1)
+ cat = self.skip_norm(cat)
+ x = self.skip_linear(cat)
+
+ if self.use_adanorm:
+ time_ada = self.adaln(time_token, time_ada)
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
+ gate_mlp) = time_ada.chunk(6, dim=1)
+
+ # self attention
+ if self.use_adanorm:
+ x_norm = film_modulate(
+ self.norm1(x), shift=shift_msa, scale=scale_msa
+ )
+            # note: a tanh-squashed gate, tanh(1 - gate_msa), replaces the
+            # plain (1 - gate_msa) gating used in the other DiT blocks
+            tanh_gate_msa = torch.tanh(1 - gate_msa)
+            x = x + tanh_gate_msa * self.attn(
+                x_norm, context=None, context_mask=x_mask, extras=extras
+            )
+ else:
+ # TODO diffusion timestep input is not fused here
+ x = x + self.attn(
+ self.norm1(x),
+ context=None,
+ context_mask=x_mask,
+ extras=extras
+ )
+
+        # time-aligned context fusion
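+        # (a context one frame shorter than x is front-padded, presumably to
+        # line up with the prepended time token)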
+ if self.ta_context_fusion == "add":
+ time_aligned_context = self.ta_context_projection(
+ self.ta_context_norm(time_aligned_context)
+ )
+ if time_aligned_context.size(1) < x.size(1):
+ time_aligned_context = nn.functional.pad(
+ time_aligned_context, (0, 0, 1, 0)
+ )
+ x = x + time_aligned_context
+ elif self.ta_context_fusion == "concat":
+ if time_aligned_context.size(1) < x.size(1):
+ time_aligned_context = nn.functional.pad(
+ time_aligned_context, (0, 0, 1, 0)
+ )
+ cat = torch.cat([x, time_aligned_context], dim=-1)
+ cat = self.ta_context_norm(cat)
+ x = self.ta_context_projection(cat)
+
+ # cross attention
+ if self.use_context:
+ assert context is not None
+ x = x + self.cross_attn(
+ x=self.norm2(x),
+ context=self.norm_context(context),
+ context_mask=context_mask,
+ extras=extras
+ )
+
+ # mlp
+ if self.use_adanorm:
+ x_norm = film_modulate(
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
+ )
+ x = x + (1 - gate_mlp) * self.mlp(x_norm)
+ else:
+ x = x + self.mlp(self.norm3(x))
+
+ return x
+
+
+class LayerFusionAudioDiT(UDiT):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ input_type='2d',
+ out_chans=None,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=False,
+ qk_scale=None,
+ qk_norm=None,
+ act_layer='gelu',
+ norm_layer='layernorm',
+ context_norm=False,
+ use_checkpoint=False,
+ time_fusion='token',
+ ada_sola_rank=None,
+ ada_sola_alpha=None,
+ cls_dim=None,
+ ta_context_dim=768,
+ ta_context_fusion='concat',
+ ta_context_norm=True,
+ context_dim=768,
+ context_fusion='concat',
+ context_max_length=128,
+ context_pe_method='sinu',
+ pe_method='abs',
+ rope_mode='none',
+ use_conv=True,
+ skip=True,
+ skip_norm=True
+ ):
+ nn.Module.__init__(self)
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ # input
+ self.in_chans = in_chans
+ self.input_type = input_type
+ if self.input_type == '2d':
+ num_patches = (img_size[0] //
+ patch_size) * (img_size[1] // patch_size)
+ elif self.input_type == '1d':
+ num_patches = img_size // patch_size
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ input_type=input_type
+ )
+ out_chans = in_chans if out_chans is None else out_chans
+ self.out_chans = out_chans
+
+ # position embedding
+ self.rope = rope_mode
+ self.x_pe = PE_wrapper(
+ dim=embed_dim, method=pe_method, length=num_patches
+ )
+
+ # time embed
+ self.time_embed = TimestepEmbedder(embed_dim)
+ self.time_fusion = time_fusion
+ self.use_adanorm = False
+
+ # cls embed
+ if cls_dim is not None:
+ self.cls_embed = nn.Sequential(
+ nn.Linear(cls_dim, embed_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=True),
+ )
+ else:
+ self.cls_embed = None
+
+ # time fusion
+ if time_fusion == 'token':
+ # put token at the beginning of sequence
+ self.extras = 2 if self.cls_embed else 1
+ self.time_pe = PE_wrapper(
+ dim=embed_dim, method='abs', length=self.extras
+ )
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
+ self.use_adanorm = True
+            # avoid a repeated SiLU in every AdaLN block
+ self.time_act = nn.SiLU()
+ self.extras = 0
+ self.time_ada_final = nn.Linear(
+ embed_dim, 2 * embed_dim, bias=True
+ )
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
+ # shared adaln
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
+ else:
+ self.time_ada = None
+ else:
+ raise NotImplementedError
+
+ # context
+ # use a simple projection
+ self.use_context = False
+ self.context_cross = False
+ self.context_max_length = context_max_length
+ self.context_fusion = 'none'
+ if context_dim is not None:
+ self.use_context = True
+ self.context_embed = nn.Sequential(
+ nn.Linear(context_dim, embed_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=True),
+ )
+ self.context_fusion = context_fusion
+ if context_fusion == 'concat' or context_fusion == 'joint':
+ self.extras += context_max_length
+ self.context_pe = PE_wrapper(
+ dim=embed_dim,
+ method=context_pe_method,
+ length=context_max_length
+ )
+ # no cross attention layers
+ context_dim = None
+ elif context_fusion == 'cross':
+ self.context_pe = PE_wrapper(
+ dim=embed_dim,
+ method=context_pe_method,
+ length=context_max_length
+ )
+ self.context_cross = True
+ context_dim = embed_dim
+ else:
+ raise NotImplementedError
+
+ self.use_skip = skip
+
+ # norm layers
+ if norm_layer == 'layernorm':
+ norm_layer = nn.LayerNorm
+ elif norm_layer == 'rmsnorm':
+ norm_layer = RMSNorm
+ else:
+ raise NotImplementedError
+
+ self.in_blocks = nn.ModuleList([
+ LayerFusionDiTBlock(
+ dim=embed_dim,
+ ta_context_dim=ta_context_dim,
+ ta_context_fusion=ta_context_fusion,
+ ta_context_norm=ta_context_norm,
+ context_dim=context_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ time_fusion=time_fusion,
+ ada_sola_rank=ada_sola_rank,
+ ada_sola_alpha=ada_sola_alpha,
+ skip=False,
+ skip_norm=False,
+ rope_mode=self.rope,
+ context_norm=context_norm,
+ use_checkpoint=use_checkpoint
+ ) for i in range(depth // 2)
+ ])
+
+ self.mid_block = LayerFusionDiTBlock(
+ dim=embed_dim,
+ ta_context_dim=ta_context_dim,
+ context_dim=context_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ time_fusion=time_fusion,
+ ada_sola_rank=ada_sola_rank,
+ ada_sola_alpha=ada_sola_alpha,
+ ta_context_fusion=ta_context_fusion,
+ ta_context_norm=ta_context_norm,
+ skip=False,
+ skip_norm=False,
+ rope_mode=self.rope,
+ context_norm=context_norm,
+ use_checkpoint=use_checkpoint
+ )
+
+ self.out_blocks = nn.ModuleList([
+ LayerFusionDiTBlock(
+ dim=embed_dim,
+ ta_context_dim=ta_context_dim,
+ context_dim=context_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ time_fusion=time_fusion,
+ ada_sola_rank=ada_sola_rank,
+ ada_sola_alpha=ada_sola_alpha,
+ ta_context_fusion=ta_context_fusion,
+ ta_context_norm=ta_context_norm,
+ skip=skip,
+ skip_norm=skip_norm,
+ rope_mode=self.rope,
+ context_norm=context_norm,
+ use_checkpoint=use_checkpoint
+ ) for i in range(depth // 2)
+ ])
+
+ # FinalLayer block
+ self.use_conv = use_conv
+ self.final_block = FinalBlock(
+ embed_dim=embed_dim,
+ patch_size=patch_size,
+ img_size=img_size,
+ in_chans=out_chans,
+ input_type=input_type,
+ norm_layer=norm_layer,
+ use_conv=use_conv,
+ use_adanorm=self.use_adanorm
+ )
+ self.initialize_weights()
+
+ def forward(
+ self,
+ x,
+ timesteps,
+ time_aligned_context,
+ context,
+ x_mask=None,
+ context_mask=None,
+ cls_token=None,
+ controlnet_skips=None,
+ ):
+ # make it compatible with int time step during inference
+ if timesteps.dim() == 0:
+ timesteps = timesteps.expand(x.shape[0]
+ ).to(x.device, dtype=torch.long)
+
+ x = self.patch_embed(x)
+ x = self.x_pe(x)
+
+ B, L, D = x.shape
+
+ if self.use_context:
+ context_token = self.context_embed(context)
+ context_token = self.context_pe(context_token)
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
+ x, x_mask = self._concat_x_context(
+ x=x,
+ context=context_token,
+ x_mask=x_mask,
+ context_mask=context_mask
+ )
+ context_token, context_mask = None, None
+ else:
+ context_token, context_mask = None, None
+
+ time_token = self.time_embed(timesteps)
+ if self.cls_embed:
+ cls_token = self.cls_embed(cls_token)
+ time_ada = None
+ time_ada_final = None
+ if self.use_adanorm:
+ if self.cls_embed:
+ time_token = time_token + cls_token
+ time_token = self.time_act(time_token)
+ time_ada_final = self.time_ada_final(time_token)
+ if self.time_ada is not None:
+ time_ada = self.time_ada(time_token)
+ else:
+ time_token = time_token.unsqueeze(dim=1)
+ if self.cls_embed:
+ cls_token = cls_token.unsqueeze(dim=1)
+ time_token = torch.cat([time_token, cls_token], dim=1)
+ time_token = self.time_pe(time_token)
+ x = torch.cat((time_token, x), dim=1)
+ if x_mask is not None:
+                time_mask = torch.ones(
+                    B, time_token.shape[1], device=x_mask.device
+                ).bool()
+                x_mask = torch.cat([time_mask, x_mask], dim=1)
+ time_token = None
+
+ skips = []
+        for blk in self.in_blocks:
+ x = blk(
+ x=x,
+ time_aligned_context=time_aligned_context,
+ time_token=time_token,
+ time_ada=time_ada,
+ skip=None,
+ context=context_token,
+ x_mask=x_mask,
+ context_mask=context_mask,
+ extras=self.extras
+ )
+ if self.use_skip:
+ skips.append(x)
+
+ x = self.mid_block(
+ x=x,
+ time_aligned_context=time_aligned_context,
+ time_token=time_token,
+ time_ada=time_ada,
+ skip=None,
+ context=context_token,
+ x_mask=x_mask,
+ context_mask=context_mask,
+ extras=self.extras
+ )
+        for blk in self.out_blocks:
+ if self.use_skip:
+ skip = skips.pop()
+ if controlnet_skips:
+ # add to skip like u-net controlnet
+ skip = skip + controlnet_skips.pop()
+ else:
+ skip = None
+ if controlnet_skips:
+ # directly add to x
+ x = x + controlnet_skips.pop()
+
+ x = blk(
+ x=x,
+ time_aligned_context=time_aligned_context,
+ time_token=time_token,
+ time_ada=time_ada,
+ skip=skip,
+ context=context_token,
+ x_mask=x_mask,
+ context_mask=context_mask,
+ extras=self.extras
+ )
+
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
+
+ return x
+
+
+class InputFusionAudioDiT(UDiT):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ input_type='2d',
+ out_chans=None,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=False,
+ qk_scale=None,
+ qk_norm=None,
+ act_layer='gelu',
+ norm_layer='layernorm',
+ context_norm=False,
+ use_checkpoint=False,
+ time_fusion='token',
+ ada_sola_rank=None,
+ ada_sola_alpha=None,
+ cls_dim=None,
+ ta_context_dim=768,
+ context_dim=768,
+ context_fusion='concat',
+ context_max_length=128,
+ context_pe_method='sinu',
+ pe_method='abs',
+ rope_mode='none',
+ use_conv=True,
+ skip=True,
+ skip_norm=True
+ ):
+ super().__init__(
+ img_size,
+ patch_size,
+ in_chans,
+ input_type,
+ out_chans,
+ embed_dim,
+ depth,
+ num_heads,
+ mlp_ratio,
+ qkv_bias,
+ qk_scale,
+ qk_norm,
+ act_layer,
+ norm_layer,
+ context_norm,
+ use_checkpoint,
+ time_fusion,
+ ada_sola_rank,
+ ada_sola_alpha,
+ cls_dim,
+ context_dim,
+ context_fusion,
+ context_max_length,
+ context_pe_method,
+ pe_method,
+ rope_mode,
+ use_conv,
+ skip,
+ skip_norm,
+ )
+ self.input_proj = nn.Linear(in_chans + ta_context_dim, in_chans)
+ nn.init.xavier_uniform_(self.input_proj.weight)
+ nn.init.constant_(self.input_proj.bias, 0)
+
+ def forward(
+ self,
+ x,
+ timesteps,
+ time_aligned_context,
+ context,
+ x_mask=None,
+ context_mask=None,
+ cls_token=None,
+ controlnet_skips=None
+ ):
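+        # input-level fusion: concatenate the time-aligned context with the
+        # latent along the channel dim and project back to in_chans before
+        # running the standard UDiT forward pass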
+ x = self.input_proj(
+ torch.cat([x.transpose(1, 2), time_aligned_context], dim=-1)
+ )
+ x = x.transpose(1, 2)
+ return super().forward(
+ x=x,
+ timesteps=timesteps,
+ context=context,
+ x_mask=x_mask,
+ context_mask=context_mask,
+ cls_token=cls_token,
+ controlnet_skips=controlnet_skips
+ )
diff --git a/models/dit/mask_dit.py b/models/dit/mask_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c86260669d9ed0a9c4018069e55125d026894fb
--- /dev/null
+++ b/models/dit/mask_dit.py
@@ -0,0 +1,823 @@
+import logging
+import math
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from .modules import (
+ film_modulate,
+ unpatchify,
+ PatchEmbed,
+ PE_wrapper,
+ TimestepEmbedder,
+ FeedForward,
+ RMSNorm,
+)
+from .span_mask import compute_mask_indices
+from .attention import Attention
+
+logger = logging.getLogger(__name__)
+
+
+class AdaLN(nn.Module):
+ def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
+ super().__init__()
+ self.ada_mode = ada_mode
+ self.scale_shift_table = None
+ if ada_mode == 'ada':
+ # move nn.silu outside
+ self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
+ elif ada_mode == 'ada_single':
+ # adaln used in pixel-art alpha
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
+ elif ada_mode in ['ada_sola', 'ada_sola_bias']:
+ self.lora_a = nn.Linear(dim, r * 6, bias=False)
+ self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
+ self.scaling = alpha / r
+ if ada_mode == 'ada_sola_bias':
+ # take bias out for consistency
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
+ else:
+ raise NotImplementedError
+
+ def forward(self, time_token=None, time_ada=None):
+ if self.ada_mode == 'ada':
+ assert time_ada is None
+ B = time_token.shape[0]
+ time_ada = self.time_ada(time_token).reshape(B, 6, -1)
+ elif self.ada_mode == 'ada_single':
+ B = time_ada.shape[0]
+ time_ada = time_ada.reshape(B, 6, -1)
+ time_ada = self.scale_shift_table[None] + time_ada
+ elif self.ada_mode in ['ada_sola', 'ada_sola_bias']:
+ B = time_ada.shape[0]
+ time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
+ time_ada = time_ada + time_ada_lora
+ time_ada = time_ada.reshape(B, 6, -1)
+ if self.scale_shift_table is not None:
+ time_ada = self.scale_shift_table[None] + time_ada
+ else:
+ raise NotImplementedError
+ return time_ada
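+
+
+def _adaln_usage_sketch():
+    # A minimal usage sketch (illustrative, not called anywhere in this
+    # module; shapes are assumed): 'ada' consumes a per-sample time token,
+    # while 'ada_single' consumes a shared [B, 6 * dim] modulation tensor;
+    # both return a [B, 6, dim] pack of scales, shifts, and gates.
+    dim, batch = 8, 2
+    ada = AdaLN(dim, ada_mode='ada')
+    mod = ada(time_token=torch.randn(batch, dim))  # [B, 6, dim]
+    ada_single = AdaLN(dim, ada_mode='ada_single')
+    mod_single = ada_single(time_ada=torch.randn(batch, 6 * dim))
+    return mod, mod_single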
+
+
+class DiTBlock(nn.Module):
+ """
+    A modified PixArt block with adaptive layer norm (adaLN) conditioning,
+    supporting the 'ada', 'ada_single', and 'ada_sola' variants.
+ """
+ def __init__(
+ self,
+ dim,
+ context_dim=None,
+ num_heads=8,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ qk_norm=None,
+ act_layer='gelu',
+ norm_layer=nn.LayerNorm,
+ time_fusion='none',
+ ada_sola_rank=None,
+ ada_sola_alpha=None,
+ skip=False,
+ skip_norm=False,
+ rope_mode='none',
+ context_norm=False,
+ use_checkpoint=False
+ ):
+
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim=dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ rope_mode=rope_mode
+ )
+
+ if context_dim is not None:
+ self.use_context = True
+ self.cross_attn = Attention(
+ dim=dim,
+ num_heads=num_heads,
+ context_dim=context_dim,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ rope_mode='none'
+ )
+ self.norm2 = norm_layer(dim)
+ if context_norm:
+ self.norm_context = norm_layer(context_dim)
+ else:
+ self.norm_context = nn.Identity()
+ else:
+ self.use_context = False
+
+ self.norm3 = norm_layer(dim)
+ self.mlp = FeedForward(
+ dim=dim, mult=mlp_ratio, activation_fn=act_layer, dropout=0
+ )
+
+ self.use_adanorm = True if time_fusion != 'token' else False
+ if self.use_adanorm:
+ self.adaln = AdaLN(
+ dim,
+ ada_mode=time_fusion,
+ r=ada_sola_rank,
+ alpha=ada_sola_alpha
+ )
+ if skip:
+ self.skip_norm = norm_layer(2 *
+ dim) if skip_norm else nn.Identity()
+ self.skip_linear = nn.Linear(2 * dim, dim)
+ else:
+ self.skip_linear = None
+
+ self.use_checkpoint = use_checkpoint
+
+ def forward(
+ self,
+ x,
+ time_token=None,
+ time_ada=None,
+ skip=None,
+ context=None,
+ x_mask=None,
+ context_mask=None,
+ extras=None
+ ):
+ if self.use_checkpoint:
+ return checkpoint(
+ self._forward,
+ x,
+ time_token,
+ time_ada,
+ skip,
+ context,
+ x_mask,
+ context_mask,
+ extras,
+ use_reentrant=False
+ )
+ else:
+ return self._forward(
+ x, time_token, time_ada, skip, context, x_mask, context_mask,
+ extras
+ )
+
+ def _forward(
+ self,
+ x,
+ time_token=None,
+ time_ada=None,
+ skip=None,
+ context=None,
+ x_mask=None,
+ context_mask=None,
+ extras=None
+ ):
+ B, T, C = x.shape
+ if self.skip_linear is not None:
+ assert skip is not None
+ cat = torch.cat([x, skip], dim=-1)
+ cat = self.skip_norm(cat)
+ x = self.skip_linear(cat)
+
+ if self.use_adanorm:
+ time_ada = self.adaln(time_token, time_ada)
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
+ gate_mlp) = time_ada.chunk(6, dim=1)
+
+ # self attention
+ if self.use_adanorm:
+ x_norm = film_modulate(
+ self.norm1(x), shift=shift_msa, scale=scale_msa
+ )
+ x = x + (1 - gate_msa) * self.attn(
+ x_norm, context=None, context_mask=x_mask, extras=extras
+ )
+ else:
+ x = x + self.attn(
+ self.norm1(x),
+ context=None,
+ context_mask=x_mask,
+ extras=extras
+ )
+
+ # cross attention
+ if self.use_context:
+ assert context is not None
+ x = x + self.cross_attn(
+ x=self.norm2(x),
+ context=self.norm_context(context),
+ context_mask=context_mask,
+ extras=extras
+ )
+
+ # mlp
+ if self.use_adanorm:
+ x_norm = film_modulate(
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
+ )
+ x = x + (1 - gate_mlp) * self.mlp(x_norm)
+ else:
+ x = x + self.mlp(self.norm3(x))
+
+ return x
+
+
+class FinalBlock(nn.Module):
+ def __init__(
+ self,
+ embed_dim,
+ patch_size,
+ in_chans,
+ img_size,
+ input_type='2d',
+ norm_layer=nn.LayerNorm,
+ use_conv=True,
+ use_adanorm=True
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.img_size = img_size
+ self.input_type = input_type
+
+ self.norm = norm_layer(embed_dim)
+ if use_adanorm:
+ self.use_adanorm = True
+ else:
+ self.use_adanorm = False
+
+ if input_type == '2d':
+ self.patch_dim = patch_size**2 * in_chans
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
+ if use_conv:
+ self.final_layer = nn.Conv2d(
+ self.in_chans, self.in_chans, 3, padding=1
+ )
+ else:
+ self.final_layer = nn.Identity()
+
+ elif input_type == '1d':
+ self.patch_dim = patch_size * in_chans
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
+ if use_conv:
+ self.final_layer = nn.Conv1d(
+ self.in_chans, self.in_chans, 3, padding=1
+ )
+ else:
+ self.final_layer = nn.Identity()
+
+ def forward(self, x, time_ada=None, extras=0):
+ B, T, C = x.shape
+ x = x[:, extras:, :]
+ # only handle generation target
+ if self.use_adanorm:
+ shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
+ x = film_modulate(self.norm(x), shift, scale)
+ else:
+ x = self.norm(x)
+ x = self.linear(x)
+ x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
+ x = self.final_layer(x)
+ return x
+
+
+class UDiT(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ input_type='2d',
+ out_chans=None,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ qk_norm=None,
+ act_layer='gelu',
+ norm_layer='layernorm',
+ context_norm=False,
+ use_checkpoint=False,
+ # time fusion ada or token
+ time_fusion='token',
+ ada_sola_rank=None,
+ ada_sola_alpha=None,
+ cls_dim=None,
+ # max length is only used for concat
+ context_dim=768,
+ context_fusion='concat',
+ context_max_length=128,
+ context_pe_method='sinu',
+ pe_method='abs',
+ rope_mode='none',
+ use_conv=True,
+ skip=True,
+ skip_norm=True
+ ):
+ super().__init__()
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ # input
+ self.in_chans = in_chans
+ self.input_type = input_type
+ if self.input_type == '2d':
+ num_patches = (img_size[0] //
+ patch_size) * (img_size[1] // patch_size)
+ elif self.input_type == '1d':
+ num_patches = img_size // patch_size
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ input_type=input_type
+ )
+ out_chans = in_chans if out_chans is None else out_chans
+ self.out_chans = out_chans
+
+ # position embedding
+ self.rope = rope_mode
+ self.x_pe = PE_wrapper(
+ dim=embed_dim, method=pe_method, length=num_patches
+ )
+
+ logger.info(f'x position embedding: {pe_method}')
+ logger.info(f'rope mode: {self.rope}')
+
+ # time embed
+ self.time_embed = TimestepEmbedder(embed_dim)
+ self.time_fusion = time_fusion
+ self.use_adanorm = False
+
+ # cls embed
+ if cls_dim is not None:
+ self.cls_embed = nn.Sequential(
+ nn.Linear(cls_dim, embed_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=True),
+ )
+ else:
+ self.cls_embed = None
+
+ # time fusion
+ if time_fusion == 'token':
+ # put token at the beginning of sequence
+ self.extras = 2 if self.cls_embed else 1
+ self.time_pe = PE_wrapper(
+ dim=embed_dim, method='abs', length=self.extras
+ )
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
+ self.use_adanorm = True
+            # avoid a repetitive SiLU in each AdaLN block
+ self.time_act = nn.SiLU()
+ self.extras = 0
+ self.time_ada_final = nn.Linear(
+ embed_dim, 2 * embed_dim, bias=True
+ )
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
+ # shared adaln
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
+ else:
+ self.time_ada = None
+ else:
+ raise NotImplementedError
+ logger.info(f'time fusion mode: {self.time_fusion}')
+
+ # context
+ # use a simple projection
+ self.use_context = False
+ self.context_cross = False
+ self.context_max_length = context_max_length
+ self.context_fusion = 'none'
+ if context_dim is not None:
+ self.use_context = True
+ self.context_embed = nn.Sequential(
+ nn.Linear(context_dim, embed_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=True),
+ )
+ self.context_fusion = context_fusion
+ if context_fusion == 'concat' or context_fusion == 'joint':
+ self.extras += context_max_length
+ self.context_pe = PE_wrapper(
+ dim=embed_dim,
+ method=context_pe_method,
+ length=context_max_length
+ )
+ # no cross attention layers
+ context_dim = None
+ elif context_fusion == 'cross':
+ self.context_pe = PE_wrapper(
+ dim=embed_dim,
+ method=context_pe_method,
+ length=context_max_length
+ )
+ self.context_cross = True
+ context_dim = embed_dim
+ else:
+ raise NotImplementedError
+ logger.info(f'context fusion mode: {context_fusion}')
+ logger.info(f'context position embedding: {context_pe_method}')
+
+ self.use_skip = skip
+
+ # norm layers
+ if norm_layer == 'layernorm':
+ norm_layer = nn.LayerNorm
+ elif norm_layer == 'rmsnorm':
+ norm_layer = RMSNorm
+ else:
+ raise NotImplementedError
+
+ logger.info(f'use long skip connection: {skip}')
+ self.in_blocks = nn.ModuleList([
+ DiTBlock(
+ dim=embed_dim,
+ context_dim=context_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ time_fusion=time_fusion,
+ ada_sola_rank=ada_sola_rank,
+ ada_sola_alpha=ada_sola_alpha,
+ skip=False,
+ skip_norm=False,
+ rope_mode=self.rope,
+ context_norm=context_norm,
+ use_checkpoint=use_checkpoint
+ ) for _ in range(depth // 2)
+ ])
+
+ self.mid_block = DiTBlock(
+ dim=embed_dim,
+ context_dim=context_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ time_fusion=time_fusion,
+ ada_sola_rank=ada_sola_rank,
+ ada_sola_alpha=ada_sola_alpha,
+ skip=False,
+ skip_norm=False,
+ rope_mode=self.rope,
+ context_norm=context_norm,
+ use_checkpoint=use_checkpoint
+ )
+
+ self.out_blocks = nn.ModuleList([
+ DiTBlock(
+ dim=embed_dim,
+ context_dim=context_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ time_fusion=time_fusion,
+ ada_sola_rank=ada_sola_rank,
+ ada_sola_alpha=ada_sola_alpha,
+ skip=skip,
+ skip_norm=skip_norm,
+ rope_mode=self.rope,
+ context_norm=context_norm,
+ use_checkpoint=use_checkpoint
+ ) for _ in range(depth // 2)
+ ])
+
+ # FinalLayer block
+ self.use_conv = use_conv
+ self.final_block = FinalBlock(
+ embed_dim=embed_dim,
+ patch_size=patch_size,
+ img_size=img_size,
+ in_chans=out_chans,
+ input_type=input_type,
+ norm_layer=norm_layer,
+ use_conv=use_conv,
+ use_adanorm=self.use_adanorm
+ )
+ self.initialize_weights()
+
+ def _init_ada(self):
+ if self.time_fusion == 'ada':
+ nn.init.constant_(self.time_ada_final.weight, 0)
+ nn.init.constant_(self.time_ada_final.bias, 0)
+ for block in self.in_blocks:
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
+ nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0)
+ nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0)
+ for block in self.out_blocks:
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
+ elif self.time_fusion == 'ada_single':
+ nn.init.constant_(self.time_ada.weight, 0)
+ nn.init.constant_(self.time_ada.bias, 0)
+ nn.init.constant_(self.time_ada_final.weight, 0)
+ nn.init.constant_(self.time_ada_final.bias, 0)
+ elif self.time_fusion in ['ada_sola', 'ada_sola_bias']:
+ nn.init.constant_(self.time_ada.weight, 0)
+ nn.init.constant_(self.time_ada.bias, 0)
+ nn.init.constant_(self.time_ada_final.weight, 0)
+ nn.init.constant_(self.time_ada_final.bias, 0)
+ for block in self.in_blocks:
+ nn.init.kaiming_uniform_(
+ block.adaln.lora_a.weight, a=math.sqrt(5)
+ )
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
+ nn.init.kaiming_uniform_(
+ self.mid_block.adaln.lora_a.weight, a=math.sqrt(5)
+ )
+ nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0)
+ for block in self.out_blocks:
+ nn.init.kaiming_uniform_(
+ block.adaln.lora_a.weight, a=math.sqrt(5)
+ )
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
+
+ def initialize_weights(self):
+ # Basic init for all layers
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ # init patch Conv like Linear
+ w = self.patch_embed.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.patch_embed.proj.bias, 0)
+
+ # Zero-out AdaLN
+ if self.use_adanorm:
+ self._init_ada()
+
+ # Zero-out Cross Attention
+ if self.context_cross:
+ for block in self.in_blocks:
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
+ nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0)
+ nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0)
+ for block in self.out_blocks:
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
+
+ # Zero-out cls embedding
+ if self.cls_embed:
+ if self.use_adanorm:
+ nn.init.constant_(self.cls_embed[-1].weight, 0)
+ nn.init.constant_(self.cls_embed[-1].bias, 0)
+
+ # Zero-out Output
+        # We may not want to zero this out when using v-prediction;
+        # it can help when using noise-prediction.
+ # nn.init.constant_(self.final_block.linear.weight, 0)
+ # nn.init.constant_(self.final_block.linear.bias, 0)
+ # if self.use_conv:
+ # nn.init.constant_(self.final_block.final_layer.weight.data, 0)
+ # nn.init.constant_(self.final_block.final_layer.bias, 0)
+
+ # init out Conv
+ if self.use_conv:
+ nn.init.xavier_uniform_(self.final_block.final_layer.weight)
+ nn.init.constant_(self.final_block.final_layer.bias, 0)
+
+ def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
+ assert context.shape[-2] == self.context_max_length
+        B = x.shape[0]
+        # Create default all-ones masks when none are provided
+ if x_mask is None:
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
+ if context_mask is None:
+ context_mask = torch.ones(
+ B, context.shape[-2], device=context.device
+ ).bool()
+ # Concatenate the masks along the second dimension (dim=1)
+ x_mask = torch.cat([context_mask, x_mask], dim=1)
+ # Concatenate context and x along the second dimension (dim=1)
+ x = torch.cat((context, x), dim=1)
+ return x, x_mask
+
+ def forward(
+ self,
+ x,
+ timesteps,
+ context,
+ x_mask=None,
+ context_mask=None,
+ cls_token=None,
+ controlnet_skips=None,
+ ):
+ # make it compatible with int time step during inference
+ if timesteps.dim() == 0:
+ timesteps = timesteps.expand(x.shape[0]
+ ).to(x.device, dtype=torch.long)
+
+ x = self.patch_embed(x)
+ x = self.x_pe(x)
+
+ B, L, D = x.shape
+
+ if self.use_context:
+ context_token = self.context_embed(context)
+ context_token = self.context_pe(context_token)
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
+ x, x_mask = self._concat_x_context(
+ x=x,
+ context=context_token,
+ x_mask=x_mask,
+ context_mask=context_mask
+ )
+ context_token, context_mask = None, None
+ else:
+ context_token, context_mask = None, None
+
+ time_token = self.time_embed(timesteps)
+ if self.cls_embed:
+ cls_token = self.cls_embed(cls_token)
+ time_ada = None
+ time_ada_final = None
+ if self.use_adanorm:
+ if self.cls_embed:
+ time_token = time_token + cls_token
+ time_token = self.time_act(time_token)
+ time_ada_final = self.time_ada_final(time_token)
+ if self.time_ada is not None:
+ time_ada = self.time_ada(time_token)
+ else:
+ time_token = time_token.unsqueeze(dim=1)
+ if self.cls_embed:
+ cls_token = cls_token.unsqueeze(dim=1)
+ time_token = torch.cat([time_token, cls_token], dim=1)
+ time_token = self.time_pe(time_token)
+ x = torch.cat((time_token, x), dim=1)
+ if x_mask is not None:
+ x_mask = torch.cat([
+ torch.ones(B, time_token.shape[1],
+ device=x_mask.device).bool(), x_mask
+ ],
+ dim=1)
+ time_token = None
+
+ skips = []
+ for blk in self.in_blocks:
+ x = blk(
+ x=x,
+ time_token=time_token,
+ time_ada=time_ada,
+ skip=None,
+ context=context_token,
+ x_mask=x_mask,
+ context_mask=context_mask,
+ extras=self.extras
+ )
+ if self.use_skip:
+ skips.append(x)
+
+ x = self.mid_block(
+ x=x,
+ time_token=time_token,
+ time_ada=time_ada,
+ skip=None,
+ context=context_token,
+ x_mask=x_mask,
+ context_mask=context_mask,
+ extras=self.extras
+ )
+ for blk in self.out_blocks:
+ if self.use_skip:
+ skip = skips.pop()
+ if controlnet_skips:
+ # add to skip like u-net controlnet
+ skip = skip + controlnet_skips.pop()
+ else:
+ skip = None
+ if controlnet_skips:
+ # directly add to x
+ x = x + controlnet_skips.pop()
+
+ x = blk(
+ x=x,
+ time_token=time_token,
+ time_ada=time_ada,
+ skip=skip,
+ context=context_token,
+ x_mask=x_mask,
+ context_mask=context_mask,
+ extras=self.extras
+ )
+
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
+
+ return x
+
+
+class MaskDiT(nn.Module):
+ def __init__(
+ self,
+ model: UDiT,
+ mae=False,
+ mae_prob=0.5,
+ mask_ratio=[0.25, 1.0],
+ mask_span=10,
+ ):
+ super().__init__()
+ self.model = model
+ self.mae = mae
+ if self.mae:
+ out_channel = model.out_chans
+ self.mask_embed = nn.Parameter(torch.zeros((out_channel)))
+ self.mae_prob = mae_prob
+ self.mask_ratio = mask_ratio
+ self.mask_span = mask_span
+
+ def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
+ B, D, L = gt.shape
+ if mae_mask_infer is None:
+ # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
+ mask_ratios = mask_ratios.cpu().numpy()
+ mask = compute_mask_indices(
+ shape=[B, L],
+ padding_mask=None,
+ mask_prob=mask_ratios,
+ mask_length=self.mask_span,
+ mask_type="static",
+ mask_other=0.0,
+ min_masks=1,
+ no_overlap=False,
+ min_space=0,
+ )
+            # compute_mask_indices returns a CPU tensor; move it to the
+            # latent's device before boolean indexing
+            mask = mask.to(gt.device).unsqueeze(1).expand_as(gt)
+ else:
+ mask = mae_mask_infer
+ mask = mask.expand_as(gt)
+ gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
+ return gt, mask.type_as(gt)
+
+ def forward(
+ self,
+ x,
+ timesteps,
+ context,
+ x_mask=None,
+ context_mask=None,
+ cls_token=None,
+ gt=None,
+ mae_mask_infer=None,
+ forward_model=True
+ ):
+ # todo: handle controlnet inside
+ mae_mask = torch.ones_like(x)
+ if self.mae:
+ if gt is not None:
+ B, D, L = gt.shape
+ mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio
+ ).to(gt.device)
+ gt, mae_mask = self.random_masking(
+ gt, mask_ratios, mae_mask_infer
+ )
+ # apply mae only to the selected batches
+ if mae_mask_infer is None:
+ # determine mae batch
+ mae_batch = torch.rand(B) < self.mae_prob
+ gt[~mae_batch] = self.mask_embed.view(
+ 1, D, 1
+ ).expand_as(gt)[~mae_batch]
+ mae_mask[~mae_batch] = 1.0
+ else:
+ B, D, L = x.shape
+ gt = self.mask_embed.view(1, D, 1).expand_as(x)
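+            # The backbone input stacks [noisy latent (D channels), masked
+            # ground truth (D channels), one mask channel] along dim 1, so
+            # the backbone must be built with in_chans = 2 * D + 1 in MAE mode.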
+ x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)
+
+ if forward_model:
+ x = self.model(
+ x=x,
+ timesteps=timesteps,
+ context=context,
+ x_mask=x_mask,
+ context_mask=context_mask,
+ cls_token=cls_token
+ )
+ # logger.info(mae_mask[:, 0, :].sum(dim=-1))
+ return x, mae_mask
diff --git a/models/dit/modules.py b/models/dit/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..61ee20880e40a491815fea142a7bf6e1b800467f
--- /dev/null
+++ b/models/dit/modules.py
@@ -0,0 +1,445 @@
+import warnings
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.cuda.amp import autocast
+import math
+import einops
+from einops import rearrange, repeat
+from inspect import isfunction
+
+
+def trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2
+ )
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+# disable in checkpoint mode
+# @torch.jit.script
+def film_modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) *
+ torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding,
+ torch.zeros_like(embedding[:, :1])],
+ dim=-1)
+ return embedding
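+
+
+def _timestep_embedding_sketch():
+    # A minimal sketch (illustrative, not called by the module): four scalar
+    # timesteps embed into a [4, 256] table whose first half along the last
+    # dim holds cosines and second half sines.
+    t = torch.tensor([0, 10, 100, 999])
+    emb = timestep_embedding(t, dim=256)
+    assert emb.shape == (4, 256)
+    return emb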
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(
+ self, hidden_size, frequency_embedding_size=256, out_size=None
+ ):
+ super().__init__()
+ if out_size is None:
+ out_size = hidden_size
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, out_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ def forward(self, t):
+ t_freq = timestep_embedding(t, self.frequency_embedding_size).type(
+ self.mlp[0].weight.dtype
+ )
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+def patchify(imgs, patch_size, input_type='2d'):
+ if input_type == '2d':
+ x = einops.rearrange(
+ imgs,
+ 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)',
+ p1=patch_size,
+ p2=patch_size
+ )
+ elif input_type == '1d':
+ x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size)
+ return x
+
+
+def unpatchify(x, channels=3, input_type='2d', img_size=None):
+ if input_type == '2d':
+ patch_size = int((x.shape[2] // channels)**0.5)
+ # h = w = int(x.shape[1] ** .5)
+ h, w = img_size[0] // patch_size, img_size[1] // patch_size
+ assert h * w == x.shape[1] and patch_size**2 * channels == x.shape[2]
+ x = einops.rearrange(
+ x,
+ 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)',
+ h=h,
+ p1=patch_size,
+ p2=patch_size
+ )
+ elif input_type == '1d':
+ patch_size = int((x.shape[2] // channels))
+ h = x.shape[1]
+ assert patch_size * channels == x.shape[2]
+ x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size)
+ return x
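+
+
+def _patchify_roundtrip_sketch():
+    # A minimal sketch (illustrative, assumed shapes): for the '1d' layout
+    # used by this model ([B, C, T] with T divisible by the patch size),
+    # patchify and unpatchify are exact inverses.
+    x = torch.randn(2, 8, 100)                            # B, C, T
+    patches = patchify(x, patch_size=4, input_type='1d')  # [2, 25, 32]
+    recon = unpatchify(patches, channels=8, input_type='1d')
+    assert torch.equal(recon, x)
+    return recon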
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding
+ """
+ def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'):
+ super().__init__()
+ self.patch_size = patch_size
+ self.input_type = input_type
+ if input_type == '2d':
+ self.proj = nn.Conv2d(
+ in_chans,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=True
+ )
+ elif input_type == '1d':
+ self.proj = nn.Conv1d(
+ in_chans,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=True
+ )
+
+ def forward(self, x):
+ if self.input_type == '2d':
+ B, C, H, W = x.shape
+ assert H % self.patch_size == 0 and W % self.patch_size == 0
+ elif self.input_type == '1d':
+ B, C, H = x.shape
+ assert H % self.patch_size == 0
+
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class PositionalConvEmbedding(nn.Module):
+ """
+ Convolutional positional embedding used in F5-TTS.
+ """
+ def __init__(self, dim=768, kernel_size=31, groups=16):
+ super().__init__()
+ assert kernel_size % 2 != 0
+ self.conv1d = nn.Sequential(
+ nn.Conv1d(
+ dim, dim, kernel_size, groups=groups, padding=kernel_size // 2
+ ),
+ nn.Mish(),
+ nn.Conv1d(
+ dim, dim, kernel_size, groups=groups, padding=kernel_size // 2
+ ),
+ nn.Mish(),
+ )
+
+ def forward(self, x):
+ # B T C
+ x = self.conv1d(x.transpose(1, 2))
+ x = x.transpose(1, 2)
+ return x
+
+
+class SinusoidalPositionalEncoding(nn.Module):
+ def __init__(self, dim, length):
+ super(SinusoidalPositionalEncoding, self).__init__()
+ self.length = length
+ self.dim = dim
+ self.register_buffer(
+ 'pe', self._generate_positional_encoding(length, dim)
+ )
+
+ def _generate_positional_encoding(self, length, dim):
+ pe = torch.zeros(length, dim)
+ position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)
+ )
+
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+
+ pe = pe.unsqueeze(0)
+ return pe
+
+ def forward(self, x):
+ x = x + self.pe[:, :x.size(1)]
+ return x
+
+
+class PE_wrapper(nn.Module):
+ def __init__(self, dim=768, method='abs', length=None, **kwargs):
+ super().__init__()
+ self.method = method
+ if method == 'abs':
+ # init absolute pe like UViT
+ self.length = length
+ self.abs_pe = nn.Parameter(torch.zeros(1, length, dim))
+ trunc_normal_(self.abs_pe, mean=0.0, std=.02, a=-.04, b=.04)
+ elif method == 'conv':
+ self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs)
+ elif method == 'sinu':
+ self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length)
+ elif method == 'none':
+ # skip pe
+ self.id = nn.Identity()
+ else:
+ raise NotImplementedError
+
+ def forward(self, x):
+ if self.method == 'abs':
+ _, L, _ = x.shape
+ assert L <= self.length
+ x = x + self.abs_pe[:, :L, :]
+ elif self.method == 'conv':
+ x = x + self.conv_pe(x)
+ elif self.method == 'sinu':
+ x = self.sinu_pe(x)
+ elif self.method == 'none':
+ x = self.id(x)
+ else:
+ raise NotImplementedError
+ return x
+
+
+class RMSNorm(torch.nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6):
+ """
+ Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ """
+ Apply the RMSNorm normalization to the input tensor.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The normalized tensor.
+
+ """
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ """
+ Forward pass through the RMSNorm layer.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The output tensor after applying RMSNorm.
+
+ """
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
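+
+
+def _rmsnorm_sketch():
+    # A minimal sketch (illustrative): RMSNorm rescales each vector by
+    # 1 / sqrt(mean(x^2) + eps) and applies a learned per-dim gain, without
+    # the mean-centering of LayerNorm; at init the output RMS is ~1.
+    norm = RMSNorm(dim=16)
+    y = norm(torch.randn(2, 10, 16))
+    return y, y.pow(2).mean(-1).sqrt()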
+
+
+class GELU(nn.Module):
+ def __init__(
+ self,
+ dim_in: int,
+ dim_out: int,
+ approximate: str = "none",
+ bias: bool = True
+ ):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+ self.approximate = approximate
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate, approximate=self.approximate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(
+ gate.to(dtype=torch.float32), approximate=self.approximate
+ ).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+ return hidden_states
+
+
+class GEGLU(nn.Module):
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
+ return hidden_states * self.gelu(gate)
+
+
+class ApproximateGELU(nn.Module):
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ return x * torch.sigmoid(1.702 * x)
+
+
+# disable in checkpoint mode
+# @torch.jit.script
+def snake_beta(x, alpha, beta):
+ return x + beta * torch.sin(x * alpha).pow(2)
+
+
+class Snake(nn.Module):
+ def __init__(self, dim_in, dim_out, bias, alpha_trainable=True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ def forward(self, x):
+ x = self.proj(x)
+ x = snake_beta(x, self.alpha, self.beta)
+ return x
+
+
+class GESnake(nn.Module):
+ def __init__(self, dim_in, dim_out, bias, alpha_trainable=True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ def forward(self, x):
+ x = self.proj(x)
+ x, gate = x.chunk(2, dim=-1)
+ return x * snake_beta(gate, self.alpha, self.beta)
+
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_out=None,
+ mult=4,
+ dropout=0.0,
+ activation_fn="geglu",
+ final_dropout=False,
+ inner_dim=None,
+ bias=True,
+ ):
+ super().__init__()
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim, bias=bias)
+ elif activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+ elif activation_fn == "snake":
+ act_fn = Snake(dim, inner_dim, bias=bias)
+ elif activation_fn == "gesnake":
+ act_fn = GESnake(dim, inner_dim, bias=bias)
+ else:
+ raise NotImplementedError
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
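+
+
+def _feedforward_sketch():
+    # A minimal sketch (illustrative, assumed sizes): the default 'geglu'
+    # FeedForward expands dim -> mult * dim through a gated GELU, then
+    # projects back down to dim, preserving the input shape.
+    ff = FeedForward(dim=64, mult=4, activation_fn="geglu")
+    y = ff(torch.randn(2, 10, 64))  # -> [2, 10, 64]
+    return y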
diff --git a/models/dit/rotary.py b/models/dit/rotary.py
new file mode 100644
index 0000000000000000000000000000000000000000..f539185c22e715d5a7ac66772ffa9f15a1e5df35
--- /dev/null
+++ b/models/dit/rotary.py
@@ -0,0 +1,88 @@
+import torch
+
+# This RoPE implementation is faster than the LLaMA-style RoPE when used
+# with torch.jit.script.
+
+
+def rotate_half(x):
+ x1, x2 = x.chunk(2, dim=-1)
+ return torch.cat((-x2, x1), dim=-1)
+
+
+# disable in checkpoint mode
+# @torch.jit.script
+def apply_rotary_pos_emb(x, cos, sin):
+ # NOTE: This could probably be moved to Triton
+ # Handle a possible sequence length mismatch in between q and k
+ cos = cos[:, :, :x.shape[-2], :]
+ sin = sin[:, :, :x.shape[-2], :]
+ return (x*cos) + (rotate_half(x) * sin)
+
+
+class RotaryEmbedding(torch.nn.Module):
+ """
+ The rotary position embeddings from RoFormer_ (Su et. al).
+ A crucial insight from the method is that the query and keys are
+ transformed by rotation matrices which depend on the relative positions.
+
+ Other implementations are available in the Rotary Transformer repo_ and in
+    GPT-NeoX_; the GPT-NeoX version was an inspiration for this one.
+
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
+
+
+    .. warning:: Please note that this embedding is not registered on purpose, as it is transformative
+       (it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis
+ """
+ def __init__(self, dim: int):
+ super().__init__()
+ # Generate and save the inverse frequency buffer (non trainable)
+ inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+ self._seq_len_cached = None
+ self._cos_cached = None
+ self._sin_cached = None
+
+ def _update_cos_sin_tables(self, x, seq_dimension=-2):
+ # expect input: B, H, L, D
+ seq_len = x.shape[seq_dimension]
+
+ # Reset the tables if the sequence length has changed,
+ # or if we're on a new device (possibly due to tracing for instance)
+        # and also make sure the dtype won't change
+ if (
+ seq_len != self._seq_len_cached or
+ self._cos_cached.device != x.device or
+ self._cos_cached.dtype != x.dtype
+ ):
+ self._seq_len_cached = seq_len
+ t = torch.arange(
+ x.shape[seq_dimension], device=x.device, dtype=torch.float32
+ )
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+ self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
+ self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
+
+ return self._cos_cached, self._sin_cached
+
+ def forward(self, q, k):
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
+ q.float(), seq_dimension=-2
+ )
+ if k is not None:
+ return (
+ apply_rotary_pos_emb(
+ q.float(), self._cos_cached, self._sin_cached
+ ).type_as(q),
+ apply_rotary_pos_emb(
+ k.float(), self._cos_cached, self._sin_cached
+ ).type_as(k),
+ )
+ else:
+ return (
+ apply_rotary_pos_emb(
+ q.float(), self._cos_cached, self._sin_cached
+ ).type_as(q), None
+ )
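+
+
+def _rope_usage_sketch():
+    # A minimal sketch (illustrative, assumed shapes): RotaryEmbedding takes
+    # [B, H, L, D] queries/keys and returns them with positions encoded as
+    # rotations; the cos/sin tables are cached across same-length calls.
+    rope = RotaryEmbedding(dim=64)
+    q = torch.randn(2, 8, 100, 64)
+    k = torch.randn(2, 8, 100, 64)
+    q_rot, k_rot = rope(q, k)  # same shapes as q and k
+    return q_rot, k_rot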
diff --git a/models/dit/span_mask.py b/models/dit/span_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0832567c3e4dcc0c49fdd88dadff11c80d8e2a0
--- /dev/null
+++ b/models/dit/span_mask.py
@@ -0,0 +1,149 @@
+import numpy as np
+import torch
+from typing import Optional, Sequence, Tuple, Union
+
+
+def compute_mask_indices(
+ shape: Tuple[int, int],
+ padding_mask: Optional[torch.Tensor],
+ mask_prob: float,
+ mask_length: int,
+ mask_type: str = "static",
+ mask_other: float = 0.0,
+ min_masks: int = 0,
+ no_overlap: bool = False,
+ min_space: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape
+
+ Args:
+        shape: the shape for which to compute masks.
+            Should be of size 2, where the first element is the batch size and the second is the number of timesteps.
+        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+        mask_prob: probability for each token to be chosen as the start of a span to be masked. This will be multiplied by
+            the number of timesteps divided by the mask span length, to mask approximately this percentage of all elements.
+            However, due to overlaps, the actual number will be smaller (unless no_overlap is True).
+        mask_type: how to compute mask lengths
+            static = fixed size
+            uniform = sample from uniform distribution [mask_other, mask_length*2]
+            normal = sample from normal distribution with mean mask_length and stdev mask_other; a mask is at least 1 element long
+            poisson = sample from Poisson distribution with lambda = mask_length
+        min_masks: minimum number of masked spans
+        no_overlap: if True, switches to an alternative recursive algorithm that prevents spans from overlapping
+        min_space: only used if no_overlap is True; how many elements to keep unmasked between spans
+ """
+
+ bsz, all_sz = shape
+ mask = np.full((bsz, all_sz), False)
+
+ # Convert mask_prob to a NumPy array
+ mask_prob = np.array(mask_prob)
+
+ # Calculate all_num_mask for each element in the batch
+ all_num_mask = np.floor(
+ mask_prob * all_sz / float(mask_length) + np.random.rand(bsz)
+ ).astype(int)
+
+ # Apply the max operation with min_masks for each element
+ all_num_mask = np.maximum(min_masks, all_num_mask)
+
+ mask_idcs = []
+ for i in range(bsz):
+ if padding_mask is not None:
+ sz = all_sz - padding_mask[i].long().sum().item()
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * sz / float(mask_length) + np.random.rand()
+ )
+ num_mask = max(min_masks, num_mask)
+ else:
+ sz = all_sz
+ num_mask = all_num_mask[i]
+
+ if mask_type == "static":
+ lengths = np.full(num_mask, mask_length)
+ elif mask_type == "uniform":
+ lengths = np.random.randint(
+ mask_other, mask_length*2 + 1, size=num_mask
+ )
+ elif mask_type == "normal":
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
+ lengths = [max(1, int(round(x))) for x in lengths]
+ elif mask_type == "poisson":
+ lengths = np.random.poisson(mask_length, size=num_mask)
+ lengths = [int(round(x)) for x in lengths]
+ else:
+ raise Exception("unknown mask selection " + mask_type)
+
+ if sum(lengths) == 0:
+ lengths[0] = min(mask_length, sz - 1)
+
+ if no_overlap:
+ mask_idc = []
+
+ def arrange(s, e, length, keep_length):
+ span_start = np.random.randint(s, e - length)
+ mask_idc.extend(span_start + i for i in range(length))
+
+ new_parts = []
+ if span_start - s - min_space >= keep_length:
+ new_parts.append((s, span_start - min_space + 1))
+ if e - span_start - keep_length - min_space > keep_length:
+ new_parts.append((span_start + length + min_space, e))
+ return new_parts
+
+ parts = [(0, sz)]
+ min_length = min(lengths)
+ for length in sorted(lengths, reverse=True):
+ lens = np.fromiter(
+ (
+ e - s if e - s >= length + min_space else 0
+ for s, e in parts
+ ),
+                    int,  # np.int was removed in NumPy 1.24
+ )
+ l_sum = np.sum(lens)
+ if l_sum == 0:
+ break
+ probs = lens / np.sum(lens)
+ c = np.random.choice(len(parts), p=probs)
+ s, e = parts.pop(c)
+ parts.extend(arrange(s, e, length, min_length))
+ mask_idc = np.asarray(mask_idc)
+ else:
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+
+ mask_idc = np.asarray([
+ mask_idc[j] + offset for j in range(len(mask_idc))
+ for offset in range(lengths[j])
+ ])
+
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+ # min_len = min([len(m) for m in mask_idcs])
+ for i, mask_idc in enumerate(mask_idcs):
+ # if len(mask_idc) > min_len:
+ # mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+ mask[i, mask_idc] = True
+
+ return torch.tensor(mask)
+
+
+if __name__ == '__main__':
+ mask = compute_mask_indices(
+ shape=[4, 500],
+ padding_mask=None,
+ mask_prob=[0.65, 0.5, 0.65, 0.65],
+ mask_length=10,
+ mask_type="static",
+ mask_other=0.0,
+ min_masks=1,
+ no_overlap=False,
+ min_space=0,
+ )
+ print(mask)
+ print(mask.sum(dim=1))
diff --git a/models/flow_matching.py b/models/flow_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a07ee2c126618fe105ff640825908ddd9f562e1
--- /dev/null
+++ b/models/flow_matching.py
@@ -0,0 +1,1267 @@
+from typing import Any, Optional, Union, List, Sequence
+
+import inspect
+import random
+
+from tqdm import tqdm
+import numpy as np
+import copy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers import FlowMatchEulerDiscreteScheduler
+from diffusers.training_utils import compute_density_for_timestep_sampling
+
+from models.autoencoder.autoencoder_base import AutoEncoderBase
+from models.content_encoder.content_encoder import ContentEncoder
+from models.content_adapter import ContentAdapterBase
+from models.common import LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase
+from utils.torch_utilities import (
+ create_alignment_path, create_mask_from_length, loss_with_mask,
+ trim_or_pad_length
+)
+from safetensors.torch import load_file
+
+class FlowMatchingMixin:
+ def __init__(
+ self,
+ cfg_drop_ratio: float = 0.2,
+ sample_strategy: str = 'normal',
+ num_train_steps: int = 1000
+ ) -> None:
+ r"""
+ Args:
+            cfg_drop_ratio (float): Probability of dropping the condition during
+                training, for classifier-free guidance.
+ sample_strategy (str): Sampling strategy for timesteps during training.
+ num_train_steps (int): Number of training steps for the noise scheduler.
+ """
+ self.sample_strategy = sample_strategy
+ self.infer_noise_scheduler = FlowMatchEulerDiscreteScheduler(
+ num_train_timesteps=num_train_steps
+ )
+ self.train_noise_scheduler = copy.deepcopy(self.infer_noise_scheduler)
+
+ self.classifier_free_guidance = cfg_drop_ratio > 0.0
+ self.cfg_drop_ratio = cfg_drop_ratio
+
+ def get_input_target_and_timesteps(
+ self,
+ latent: torch.Tensor,
+ training: bool = True
+ ):
+ bsz = latent.shape[0]
+ noise = torch.randn_like(latent)
+
+ if training:
+ if self.sample_strategy == 'normal':
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme="logit_normal",
+ batch_size=bsz,
+ logit_mean=0,
+ logit_std=1,
+ mode_scale=None,
+ )
+ elif self.sample_strategy == 'uniform':
+                # sample uniformly over [0, 1); randn would yield indices
+                # outside the valid timestep range
+                u = torch.rand(bsz, )
+ else:
+ raise NotImplementedError(
+ f"{self.sample_strategy} samlping for timesteps is not supported now"
+ )
+ else:
+ u = torch.ones(bsz, ) / 2
+
+ indices = (u * self.train_noise_scheduler.config.num_train_timesteps
+ ).long()
+
+        # train_noise_scheduler.timesteps: values from num_train_steps down to 1 with a step of 1
+ timesteps = self.train_noise_scheduler.timesteps[indices].to(
+ device=latent.device
+ )
+ sigmas = self.get_sigmas(
+ timesteps, n_dim=latent.ndim, dtype=latent.dtype
+ )
+
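+        # Rectified-flow interpolation: x_t = (1 - sigma) * x_0 + sigma * noise,
+        # so the regression target below is the constant velocity noise - x_0.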
+ noisy_latent = (1.0 - sigmas) * latent + sigmas * noise
+
+ target = noise - latent
+
+ return noisy_latent, target, timesteps
+
+ def get_sigmas(self, timesteps, n_dim=3, dtype=torch.float32):
+ device = timesteps.device
+
+        # sigmas decline from 1 to 1/num_train_steps
+ sigmas = self.train_noise_scheduler.sigmas.to(
+ device=device, dtype=dtype
+ )
+
+ schedule_timesteps = self.train_noise_scheduler.timesteps.to(device)
+ timesteps = timesteps.to(device)
+ step_indices = [(schedule_timesteps == t).nonzero().item()
+ for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ def retrieve_timesteps(
+ self,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+ ):
+ # used in inference, retrieve new timesteps on given inference timesteps
+ scheduler = self.infer_noise_scheduler
+
+ if timesteps is not None and sigmas is not None:
+ raise ValueError(
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
+ )
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
+ )
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(
+ timesteps=timesteps, device=device, **kwargs
+ )
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
+ )
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(
+ num_inference_steps, device=device, **kwargs
+ )
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class ContentEncoderAdapterMixin:
+ def __init__(
+ self,
+ content_encoder: ContentEncoder,
+ content_adapter: ContentAdapterBase | None = None
+ ):
+ self.content_encoder = content_encoder
+ self.content_adapter = content_adapter
+
+ def encode_content(
+ self,
+ content: list[Any],
+ task: list[str],
+ device: str | torch.device,
+ instruction: torch.Tensor | None = None,
+ instruction_lengths: torch.Tensor | None = None
+ ):
+ content_output: dict[
+ str, torch.Tensor] = self.content_encoder.encode_content(
+ content, task, device=device
+ )
+ content, content_mask = content_output["content"], content_output[
+ "content_mask"]
+
+ if instruction is not None:
+ instruction_mask = create_mask_from_length(instruction_lengths)
+ (
+ content,
+ content_mask,
+ global_duration_pred,
+ local_duration_pred,
+ ) = self.content_adapter(
+ content, content_mask, instruction, instruction_mask
+ )
+
+ return_dict = {
+ "content": content,
+ "content_mask": content_mask,
+ "length_aligned_content": content_output["length_aligned_content"],
+ }
+ if instruction is not None:
+ return_dict["global_duration_pred"] = global_duration_pred
+ return_dict["local_duration_pred"] = local_duration_pred
+
+ return return_dict
+
+
+class SingleTaskCrossAttentionAudioFlowMatching(
+ LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
+ FlowMatchingMixin, ContentEncoderAdapterMixin
+):
+ def __init__(
+ self,
+ autoencoder: nn.Module,
+ content_encoder: ContentEncoder,
+ backbone: nn.Module,
+ cfg_drop_ratio: float = 0.2,
+ sample_strategy: str = 'normal',
+ num_train_steps: int = 1000,
+ pretrained_ckpt: str | None = None,
+ ):
+ nn.Module.__init__(self)
+ FlowMatchingMixin.__init__(
+ self, cfg_drop_ratio, sample_strategy, num_train_steps
+ )
+ ContentEncoderAdapterMixin.__init__(
+ self, content_encoder=content_encoder
+ )
+
+ self.autoencoder = autoencoder
+ for param in self.autoencoder.parameters():
+ param.requires_grad = False
+
+ if hasattr(self.content_encoder, "audio_encoder"):
+ if self.content_encoder.audio_encoder is not None:
+ self.content_encoder.audio_encoder.model = self.autoencoder
+
+ self.backbone = backbone
+ self.dummy_param = nn.Parameter(torch.empty(0))
+
+ if pretrained_ckpt is not None:
+ print(f"Load pretrain FlowMatching model from {pretrained_ckpt}")
+ pretrained_state_dict = load_file(pretrained_ckpt)
+ self.load_pretrained(pretrained_state_dict)
+ # missing, unexpected = self.load_state_dict(pretrained_state_dict, strict=False)
+ # print("Missing keys:", missing)
+ # print("Unexpected keys:", unexpected)
+
+ # if content_encoder.embed_dim != 1024:
+ # self.context_proj = nn.Sequential(
+ # nn.Linear(content_encoder.embed_dim, 1024),
+ # nn.SiLU(),
+ # nn.Linear(1024, 1024),
+ # )
+ # else:
+ # self.context_proj = nn.Identity()
+
+    def forward(
+        self, content: list[Any], condition: list[Any], task: list[str],
+        waveform: torch.Tensor, waveform_lengths: torch.Tensor,
+        loss_reduce: bool = True, **kwargs
+    ):
+        # always reduce the loss during training; honor the flag at eval time
+        loss_reduce = self.training or loss_reduce
+ device = self.dummy_param.device
+
+ self.autoencoder.eval()
+ with torch.no_grad():
+ latent, latent_mask = self.autoencoder.encode(
+ waveform.unsqueeze(1), waveform_lengths
+ )
+
+ content_dict = self.encode_content(content, task, device)
+ content, content_mask = content_dict["content"], content_dict[
+ "content_mask"]
+
+ # content = self.context_proj(content)
+
+ if self.training and self.classifier_free_guidance:
+ mask_indices = [
+ k for k in range(len(waveform))
+ if random.random() < self.cfg_drop_ratio
+ ]
+ if len(mask_indices) > 0:
+ content[mask_indices] = 0
+
+ noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
+ latent,
+ training = self.training
+ )
+
+ pred: torch.Tensor = self.backbone(
+ x=noisy_latent,
+ timesteps=timesteps,
+ context=content,
+ x_mask=latent_mask,
+ context_mask=content_mask
+ )
+
+ diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none")
+ diff_loss = loss_with_mask(diff_loss, latent_mask.unsqueeze(1), reduce=loss_reduce)
+ output = {"diff_loss": diff_loss}
+ return output
+
+ def iterative_denoise(
+ self, latent: torch.Tensor, timesteps: list[int], num_steps: int,
+ verbose: bool, cfg: bool, cfg_scale: float, backbone_input: dict
+ ):
+ progress_bar = tqdm(range(num_steps), disable=not verbose)
+
+ for i, timestep in enumerate(timesteps):
+ # expand the latent if we are doing classifier free guidance
+ if cfg:
+ latent_input = torch.cat([latent, latent])
+ else:
+ latent_input = latent
+
+ noise_pred: torch.Tensor = self.backbone(
+ x=latent_input, timesteps=timestep, **backbone_input
+ )
+
+ # perform guidance
+ if cfg:
+ noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + cfg_scale * (
+ noise_pred_content - noise_pred_uncond
+ )
+
+ latent = self.infer_noise_scheduler.step(
+ noise_pred, timestep, latent
+ ).prev_sample
+
+ progress_bar.update(1)
+
+ progress_bar.close()
+
+ return latent
+
+ @torch.no_grad()
+ def inference(
+ self,
+ content: list[Any],
+ condition: list[Any],
+ task: list[str],
+ latent_shape: Sequence[int],
+ num_steps: int = 50,
+ sway_sampling_coef: float | None = -1.0,
+ guidance_scale: float = 3.0,
+ num_samples_per_content: int = 1,
+ disable_progress: bool = True,
+ **kwargs
+ ):
+ device = self.dummy_param.device
+ classifier_free_guidance = guidance_scale > 1.0
+ batch_size = len(content) * num_samples_per_content
+
+ if classifier_free_guidance:
+ content, content_mask = self.encode_content_classifier_free(
+ content, task, device, num_samples_per_content
+ )
+ else:
+            content_output: dict[
+                str, torch.Tensor] = self.content_encoder.encode_content(
+                    content, task, device=device
+                )
+ content, content_mask = content_output["content"], content_output[
+ "content_mask"]
+ content = content.repeat_interleave(num_samples_per_content, 0)
+ content_mask = content_mask.repeat_interleave(
+ num_samples_per_content, 0
+ )
+
+ latent = self.prepare_latent(
+ batch_size, latent_shape, content.dtype, device
+ )
+
+ if not sway_sampling_coef:
+ sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
+ else:
+ t = torch.linspace(0, 1, num_steps + 1)
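+            # Sway sampling (as in F5-TTS): warp the uniform grid by
+            # s * (cos(pi/2 * t) - 1 + t); a negative coefficient spends more
+            # inference steps in the high-noise region.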
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
+ sigmas = 1 - t
+ timesteps, num_steps = self.retrieve_timesteps(
+ num_steps, device, timesteps=None, sigmas=sigmas
+ )
+
+ latent = self.iterative_denoise(
+ latent=latent,
+ timesteps=timesteps,
+ num_steps=num_steps,
+ verbose=not disable_progress,
+ cfg=classifier_free_guidance,
+ cfg_scale=guidance_scale,
+ backbone_input={
+ "context": content,
+ "context_mask": content_mask,
+ },
+ )
+
+ waveform = self.autoencoder.decode(latent)
+
+ return waveform
+
+ def prepare_latent(
+ self, batch_size: int, latent_shape: Sequence[int], dtype: torch.dtype,
+ device: str
+ ):
+ shape = (batch_size, *latent_shape)
+ latent = randn_tensor(
+ shape, generator=None, device=device, dtype=dtype
+ )
+ return latent
+
+ def encode_content_classifier_free(
+ self,
+ content: list[Any],
+ task: list[str],
+ device,
+ num_samples_per_content: int = 1
+ ):
+ content_dict = self.content_encoder.encode_content(
+ content, task, device
+ )
+ content, content_mask = content_dict["content"], content_dict["content_mask"]
+ # content, content_mask = self.content_encoder.encode_content(
+ # content, task, device=device
+ # )
+
+ content = content.repeat_interleave(num_samples_per_content, 0)
+ content_mask = content_mask.repeat_interleave(
+ num_samples_per_content, 0
+ )
+
+ # get unconditional embeddings for classifier free guidance
+ uncond_content = torch.zeros_like(content)
+ uncond_content_mask = content_mask.detach().clone()
+
+ uncond_content = uncond_content.repeat_interleave(
+ num_samples_per_content, 0
+ )
+ uncond_content_mask = uncond_content_mask.repeat_interleave(
+ num_samples_per_content, 0
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
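+        # NOTE: the unconditional half goes first; iterative_denoise relies on
+        # this ordering when it splits the prediction with chunk(2).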
+ content = torch.cat([uncond_content, content])
+ content_mask = torch.cat([uncond_content_mask, content_mask])
+
+ return content, content_mask
+
+
+class MultiContentAudioFlowMatching(SingleTaskCrossAttentionAudioFlowMatching):
+ def __init__(
+ self,
+ autoencoder: AutoEncoderBase,
+ content_encoder: ContentEncoder,
+ backbone: nn.Module,
+ cfg_drop_ratio: float = 0.2,
+ sample_strategy: str = 'normal',
+ num_train_steps: int = 1000,
+ pretrained_ckpt: str | None = None,
+ embed_dim: int = 1024,
+ ):
+ super().__init__(
+ autoencoder=autoencoder,
+ content_encoder=content_encoder,
+ backbone=backbone,
+ cfg_drop_ratio=cfg_drop_ratio,
+ sample_strategy=sample_strategy,
+ num_train_steps=num_train_steps,
+ pretrained_ckpt=pretrained_ckpt,
+ )
+
+ def forward(
+ self,
+ content: list[Any],
+ duration: Sequence[float],
+ task: list[str],
+ waveform: torch.Tensor,
+ waveform_lengths: torch.Tensor,
+ loss_reduce: bool = True,
+ **kwargs
+ ):
+ device = self.dummy_param.device
+        # always reduce the loss during training; honor the flag at eval time
+        loss_reduce = self.training or loss_reduce
+
+ self.autoencoder.eval()
+
+ with torch.no_grad():
+ latent, latent_mask = self.autoencoder.encode(
+ waveform.unsqueeze(1), waveform_lengths
+ ) # latent [B, 128, 500/T=10s], latent_mask [B, 500/T=10s]
+
+ content_dict = self.encode_content(content, task, device)
+        context = content_dict["content"]
+        context_mask = content_dict["content_mask"]
+        length_aligned_content = content_dict["length_aligned_content"]
+
+ # --------------------------------------------------------------------
+ # prepare latent and noise
+ # --------------------------------------------------------------------
+ noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
+ latent,
+ training = self.training
+ )
+
+ # --------------------------------------------------------------------
+ # prepare input to the backbone
+ # --------------------------------------------------------------------
+        # TODO: compatibility with 2D spectrogram VAEs
+
+ latent_length = noisy_latent.size(self.autoencoder.time_dim)
+ time_aligned_content = trim_or_pad_length(
+ length_aligned_content, latent_length, 1
+ )
+
+ # --------------------------------------------------------------------
+ # classifier free guidance
+ # --------------------------------------------------------------------
+ if self.training and self.classifier_free_guidance:
+ mask_indices = [
+ k for k in range(len(waveform))
+ if random.random() < self.cfg_drop_ratio
+ ]
+ if len(mask_indices) > 0:
+ context[mask_indices] = 0
+ time_aligned_content[mask_indices] = 0
+
+ pred: torch.Tensor = self.backbone(
+ x=noisy_latent,
+ x_mask=latent_mask,
+ timesteps=timesteps,
+ context=context,
+ context_mask=context_mask,
+ time_aligned_context=time_aligned_content,
+ )
+
+ pred = pred.transpose(1, self.autoencoder.time_dim)
+ target = target.transpose(1, self.autoencoder.time_dim)
+ diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none")
+ diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce)
+
+ return {
+ "diff_loss": diff_loss,
+ }
+
+ @torch.no_grad()
+ def inference(
+ self,
+ content: list[Any],
+ task: list[str],
+ latent_shape: Sequence[int],
+ num_steps: int = 50,
+ sway_sampling_coef: float | None = -1.0,
+ guidance_scale: float = 3.0,
+ disable_progress: bool = True,
+ **kwargs
+ ):
+ device = self.dummy_param.device
+ classifier_free_guidance = guidance_scale > 1.0
+ batch_size = len(content)
+
+ content_dict: dict[str, torch.Tensor] = self.encode_content(
+ content, task, device
+ )
+ context = content_dict["content"]
+ context_mask = content_dict["content_mask"]
+ length_aligned_content = content_dict["length_aligned_content"]
+
+ shape = (batch_size, *latent_shape)
+ latent_length = shape[self.autoencoder.time_dim]
+ time_aligned_content = trim_or_pad_length(
+ length_aligned_content, latent_length, 1
+ )
+
+ # --------------------------------------------------------------------
+ # prepare unconditional input
+ # --------------------------------------------------------------------
+ if classifier_free_guidance:
+ uncond_time_aligned_content = torch.zeros_like(
+ time_aligned_content
+ )
+ uncond_context = torch.zeros_like(context)
+ uncond_context_mask = context_mask.detach().clone()
+ time_aligned_content = torch.cat([
+ uncond_time_aligned_content, time_aligned_content
+ ])
+ context = torch.cat([uncond_context, context])
+ context_mask = torch.cat([uncond_context_mask, context_mask])
+
+ latent = randn_tensor(
+ shape, generator=None, device=device, dtype=context.dtype
+ )
+
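+ # sigma schedule: uniform when sway sampling is disabled; otherwise warp
+ # the uniform grid with t + s * (cos(pi/2 * t) - 1 + t), so that for a
+ # negative coefficient the steps are spaced more finely at the high-noise
+ # end of the trajectory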
+ if not sway_sampling_coef:
+ sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
+ else:
+ t = torch.linspace(0, 1, num_steps + 1)
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
+ sigmas = 1 - t
+ timesteps, num_steps = self.retrieve_timesteps(
+ num_steps, device, timesteps=None, sigmas=sigmas
+ )
+ latent = self.iterative_denoise(
+ latent=latent,
+ timesteps=timesteps,
+ num_steps=num_steps,
+ verbose=not disable_progress,
+ cfg=classifier_free_guidance,
+ cfg_scale=guidance_scale,
+ backbone_input={
+ "context": context,
+ "context_mask": context_mask,
+ "time_aligned_context": time_aligned_content,
+ }
+ )
+
+ waveform = self.autoencoder.decode(latent)
+ return waveform
+
+class DurationAdapterMixin:
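+ """
+ Helper mixin converting between durations in seconds and latent frame
+ counts. Duration targets and predictions live in the log(x + offset)
+ domain; `latent_token_rate` is the number of latent frames per second and
+ `frame_resolution` the duration in seconds of one content frame. For
+ example, with latent_token_rate=25, a 4 s clip corresponds to 100 latent
+ frames and a global-duration target of log(4 + offset).
+ """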
+ def __init__(
+ self,
+ latent_token_rate: int,
+ offset: float = 1.0,
+ frame_resolution: float | None = None
+ ):
+ self.latent_token_rate = latent_token_rate
+ self.offset = offset
+ self.frame_resolution = frame_resolution
+
+ def get_global_duration_loss(
+ self,
+ pred: torch.Tensor,
+ latent_mask: torch.Tensor,
+ reduce: bool = True,
+ ):
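+ # regression target: total duration in seconds (valid latent frames
+ # divided by the latent frame rate), in the log(x + offset) domain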
+ target = torch.log(
+ latent_mask.sum(1) / self.latent_token_rate + self.offset
+ )
+ loss = F.mse_loss(target, pred, reduction="mean" if reduce else "none")
+ return loss
+
+ def get_local_duration_loss(
+ self, ground_truth: torch.Tensor, pred: torch.Tensor,
+ mask: torch.Tensor, is_time_aligned: Sequence[bool], reduce: bool
+ ):
+ # coerce to a tensor so the masking and `.sum()` calls below also work
+ # when a plain list of bools is passed
+ is_time_aligned = torch.as_tensor(is_time_aligned, device=pred.device)
+ n_frames = torch.round(ground_truth / self.frame_resolution)
+ target = torch.log(n_frames + self.offset)
+ loss = loss_with_mask(
+ (target - pred)**2,
+ mask,
+ reduce=False,
+ )
+ loss *= is_time_aligned
+ if reduce:
+ if is_time_aligned.sum().item() == 0:
+ loss *= 0.0
+ loss = loss.mean()
+ else:
+ loss = loss.sum() / is_time_aligned.sum()
+
+ return loss
+
+ def prepare_local_duration(self, pred: torch.Tensor, mask: torch.Tensor):
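+ # invert the log-domain prediction: exp(pred) approximates
+ # n_frames + offset, so subtract the offset and convert the frame count
+ # back to seconds via the content frame resolution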
+ pred = torch.exp(pred) * mask
+ pred = torch.ceil(pred) - self.offset
+ pred *= self.frame_resolution
+ return pred
+
+ def prepare_global_duration(
+ self,
+ global_pred: torch.Tensor,
+ local_pred: torch.Tensor,
+ is_time_aligned: Sequence[bool],
+ use_local: bool = True,
+ ):
+ """
+ global_pred: predicted duration value, processed by logarithmic and offset
+ local_pred: predicted latent length
+ """
+ global_pred = torch.exp(global_pred) - self.offset
+ result = global_pred
+ # avoid error accumulation for each frame
+ if use_local:
+ pred_from_local = torch.round(local_pred * self.latent_token_rate)
+ pred_from_local = pred_from_local.sum(1) / self.latent_token_rate
+ result[is_time_aligned] = pred_from_local[is_time_aligned]
+
+ return result
+
+ def expand_by_duration(
+ self,
+ x: torch.Tensor,
+ content_mask: torch.Tensor,
+ local_duration: torch.Tensor,
+ global_duration: torch.Tensor | None = None,
+ ):
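+ """
+ Expand token-level content to frame level: each token is repeated for
+ round(duration * latent_token_rate) frames along a hard monotonic
+ alignment path, implemented as a matmul with the 0/1 alignment matrix.
+ """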
+ n_latents = torch.round(local_duration * self.latent_token_rate)
+ if global_duration is not None:
+ latent_length = torch.round(
+ global_duration * self.latent_token_rate
+ )
+ else:
+ latent_length = n_latents.sum(1)
+ latent_mask = create_mask_from_length(latent_length).to(
+ content_mask.device
+ )
+ attn_mask = content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
+ align_path = create_alignment_path(n_latents, attn_mask)
+ expanded_x = torch.matmul(align_path.transpose(1, 2).to(x.dtype), x)
+ return expanded_x, latent_mask
+
+
+class CrossAttentionAudioFlowMatching(
+ SingleTaskCrossAttentionAudioFlowMatching, DurationAdapterMixin
+):
+ def __init__(
+ self,
+ autoencoder: AutoEncoderBase,
+ content_encoder: ContentEncoder,
+ content_adapter: ContentAdapterBase,
+ backbone: nn.Module,
+ content_dim: int,
+ frame_resolution: float,
+ duration_offset: float = 1.0,
+ cfg_drop_ratio: float = 0.2,
+ sample_strategy: str = 'normal',
+ num_train_steps: int = 1000
+ ):
+ super().__init__(
+ autoencoder=autoencoder,
+ content_encoder=content_encoder,
+ backbone=backbone,
+ cfg_drop_ratio=cfg_drop_ratio,
+ sample_strategy=sample_strategy,
+ num_train_steps=num_train_steps,
+ )
+ ContentEncoderAdapterMixin.__init__(
+ self,
+ content_encoder=content_encoder,
+ content_adapter=content_adapter
+ )
+ DurationAdapterMixin.__init__(
+ self,
+ latent_token_rate=autoencoder.latent_token_rate,
+ offset=duration_offset,
+ # forward the constructor's `frame_resolution` (previously dropped)
+ # so the local-duration helpers work if they are ever called
+ frame_resolution=frame_resolution
+ )
+
+ def encode_content_with_instruction(
+ self, content: list[Any], task: list[str], device,
+ instruction: torch.Tensor, instruction_lengths: torch.Tensor
+ ):
+ content_dict = self.encode_content(
+ content, task, device, instruction, instruction_lengths
+ )
+ return (
+ content_dict["content"], content_dict["content_mask"],
+ content_dict["global_duration_pred"],
+ content_dict["local_duration_pred"],
+ content_dict["length_aligned_content"]
+ )
+
+ def forward(
+ self,
+ content: list[Any],
+ task: list[str],
+ waveform: torch.Tensor,
+ waveform_lengths: torch.Tensor,
+ instruction: torch.Tensor,
+ instruction_lengths: torch.Tensor,
+ loss_reduce: bool = True,
+ **kwargs
+ ):
+ device = self.dummy_param.device
+ # always reduce the loss during training; otherwise respect the caller's flag
+ loss_reduce = self.training or loss_reduce
+
+ self.autoencoder.eval()
+ with torch.no_grad():
+ latent, latent_mask = self.autoencoder.encode(
+ waveform.unsqueeze(1), waveform_lengths
+ )
+
+ content, content_mask, global_duration_pred, _, _ = \
+ self.encode_content_with_instruction(
+ content, task, device, instruction, instruction_lengths
+ )
+
+ global_duration_loss = self.get_global_duration_loss(
+ global_duration_pred, latent_mask, reduce=loss_reduce
+ )
+
+ if self.training and self.classifier_free_guidance:
+ mask_indices = [
+ k for k in range(len(waveform))
+ if random.random() < self.cfg_drop_ratio
+ ]
+ if len(mask_indices) > 0:
+ content[mask_indices] = 0
+
+ noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
+ latent,
+ training = self.training
+ )
+
+ pred: torch.Tensor = self.backbone(
+ x=noisy_latent,
+ timesteps=timesteps,
+ context=content,
+ x_mask=latent_mask,
+ context_mask=content_mask,
+ )
+ pred = pred.transpose(1, self.autoencoder.time_dim)
+ target = target.transpose(1, self.autoencoder.time_dim)
+ diff_loss = F.mse_loss(pred.float(), target.float(), reduction="none")
+ diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce)
+
+ return {
+ "diff_loss": diff_loss,
+ "global_duration_loss": global_duration_loss,
+ }
+
+ @torch.no_grad()
+ def inference(
+ self,
+ content: list[Any],
+ condition: list[Any],
+ task: list[str],
+ is_time_aligned: Sequence[bool],
+ instruction: torch.Tensor,
+ instruction_lengths: torch.Tensor,
+ num_steps: int = 20,
+ sway_sampling_coef: float | None = -1.0,
+ guidance_scale: float = 3.0,
+ disable_progress=True,
+ use_gt_duration: bool = False,
+ **kwargs
+ ):
+ device = self.dummy_param.device
+ classifier_free_guidance = guidance_scale > 1.0
+
+ (
+ content,
+ content_mask,
+ global_duration_pred,
+ local_duration_pred,
+ _,
+ ) = self.encode_content_with_instruction(
+ content, task, device, instruction, instruction_lengths
+ )
+ batch_size = content.size(0)
+
+ if use_gt_duration:
+ raise NotImplementedError(
+ "Using ground truth global duration only is not implemented yet"
+ )
+
+ # prepare global duration
+ global_duration = self.prepare_global_duration(
+ global_duration_pred,
+ local_duration_pred,
+ is_time_aligned,
+ use_local=False
+ )
+ latent_length = torch.round(global_duration * self.latent_token_rate)
+ latent_mask = create_mask_from_length(latent_length).to(device)
+ max_latent_length = latent_mask.sum(1).max().item()
+
+ # prepare latent and noise
+ if classifier_free_guidance:
+ uncond_context = torch.zeros_like(content)
+ uncond_content_mask = content_mask.detach().clone()
+ context = torch.cat([uncond_context, content])
+ context_mask = torch.cat([uncond_content_mask, content_mask])
+ else:
+ context = content
+ context_mask = content_mask
+
+ latent_shape = tuple(
+ max_latent_length if dim is None else dim
+ for dim in self.autoencoder.latent_shape
+ )
+ shape = (batch_size, *latent_shape)
+ latent = randn_tensor(
+ shape, generator=None, device=device, dtype=content.dtype
+ )
+ if not sway_sampling_coef:
+ sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
+ else:
+ t = torch.linspace(0, 1, num_steps + 1)
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
+ sigmas = 1 - t
+ timesteps, num_steps = self.retrieve_timesteps(
+ num_steps, device, timesteps=None, sigmas=sigmas
+ )
+ latent = self.iterative_denoise(
+ latent=latent,
+ timesteps=timesteps,
+ num_steps=num_steps,
+ verbose=not disable_progress,
+ cfg=classifier_free_guidance,
+ cfg_scale=guidance_scale,
+ backbone_input={
+ "x_mask": latent_mask,
+ "context": context,
+ "context_mask": context_mask,
+ }
+ )
+
+ waveform = self.autoencoder.decode(latent)
+ return waveform
+
+
+class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
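+ """
+ Multi-task variant that handles time-aligned (e.g. phoneme) and
+ non-time-aligned (e.g. caption) content in one batch: samples missing one
+ content type are filled with the learnable dummy embeddings
+ `dummy_ta_embed` / `dummy_nta_embed`, so a single backbone signature
+ serves both task families.
+ """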
+ def __init__(
+ self,
+ autoencoder: AutoEncoderBase,
+ content_encoder: ContentEncoder,
+ content_adapter: ContentAdapterBase,
+ backbone: nn.Module,
+ content_dim: int,
+ frame_resolution: float,
+ duration_offset: float = 1.0,
+ cfg_drop_ratio: float = 0.2,
+ sample_strategy: str = 'normal',
+ num_train_steps: int = 1000
+ ):
+
+ super().__init__(
+ autoencoder=autoencoder,
+ content_encoder=content_encoder,
+ content_adapter=content_adapter,
+ backbone=backbone,
+ content_dim=content_dim,
+ frame_resolution=frame_resolution,
+ duration_offset=duration_offset,
+ cfg_drop_ratio=cfg_drop_ratio,
+ sample_strategy=sample_strategy,
+ num_train_steps=num_train_steps
+ )
+ DurationAdapterMixin.__init__(
+ self,
+ latent_token_rate=autoencoder.latent_token_rate,
+ offset=duration_offset,
+ frame_resolution=frame_resolution
+ )
+ self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim))
+ self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim))
+
+ def get_backbone_input(
+ self, target_length: int, content: torch.Tensor,
+ content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
+ length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
+ ):
+ # TODO compatility for 2D spectrogram VAE
+ time_aligned_content = trim_or_pad_length(
+ time_aligned_content, target_length, 1
+ )
+ length_aligned_content = trim_or_pad_length(
+ length_aligned_content, target_length, 1
+ )
+ # time_aligned_content: duration-expanded content from monotonically aligned input (e.g. phoneme)
+ # length_aligned_content: content that is already frame-aligned (e.g. f0/energy)
+ time_aligned_content = time_aligned_content + length_aligned_content
+ time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
+ time_aligned_content.dtype
+ )
+
+ context = content
+ context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype)
+ # only use the first dummy non time aligned embedding
+ context_mask = content_mask.detach().clone()
+ context_mask[is_time_aligned, 1:] = False
+
+ # truncate dummy non time aligned context
+ if is_time_aligned.sum().item() < content.size(0):
+ trunc_nta_length = content_mask[~is_time_aligned].sum(1).max()
+ else:
+ trunc_nta_length = content.size(1)
+ context = context[:, :trunc_nta_length]
+ context_mask = context_mask[:, :trunc_nta_length]
+
+ return context, context_mask, time_aligned_content
+
+ def forward(
+ self,
+ content: list[Any],
+ duration: torch.Tensor,  # (B, max_tokens) per-token durations in seconds
+ task: list[str],
+ is_time_aligned: Sequence[bool],
+ waveform: torch.Tensor,
+ waveform_lengths: torch.Tensor,
+ instruction: torch.Tensor,
+ instruction_lengths: torch.Tensor,
+ loss_reduce: bool = True,
+ **kwargs
+ ):
+ device = self.dummy_param.device
+ # always reduce the loss during training; otherwise respect the caller's flag
+ loss_reduce = self.training or loss_reduce
+ # coerce to a tensor for the boolean masking below (mirrors `inference`)
+ is_time_aligned = torch.as_tensor(is_time_aligned)
+
+ self.autoencoder.eval()
+ with torch.no_grad():
+ latent, latent_mask = self.autoencoder.encode(
+ waveform.unsqueeze(1), waveform_lengths
+ )
+
+ (
+ content, content_mask, global_duration_pred, local_duration_pred,
+ length_aligned_content
+ ) = self.encode_content_with_instruction(
+ content, task, device, instruction, instruction_lengths
+ )
+
+ # truncate unused non time aligned duration prediction
+ if is_time_aligned.sum() > 0:
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
+ else:
+ trunc_ta_length = content.size(1)
+
+ # duration loss
+ local_duration_pred = local_duration_pred[:, :trunc_ta_length]
+ ta_content_mask = content_mask[:, :trunc_ta_length]
+ local_duration_loss = self.get_local_duration_loss(
+ duration,
+ local_duration_pred,
+ ta_content_mask,
+ is_time_aligned,
+ reduce=loss_reduce
+ )
+
+ global_duration_loss = self.get_global_duration_loss(
+ global_duration_pred, latent_mask, reduce=loss_reduce
+ )
+
+ # --------------------------------------------------------------------
+ # prepare latent and noise
+ # --------------------------------------------------------------------
+ noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
+ latent,
+ training = self.training
+ )
+
+ # --------------------------------------------------------------------
+ # duration adapter
+ # --------------------------------------------------------------------
+ if is_time_aligned.sum() == 0 and \
+ duration.size(1) < content_mask.size(1):
+ duration = F.pad(
+ duration, (0, content_mask.size(1) - duration.size(1))
+ )
+ time_aligned_content, _ = self.expand_by_duration(
+ x=content[:, :trunc_ta_length],
+ content_mask=ta_content_mask,
+ local_duration=duration,
+ )
+
+ # --------------------------------------------------------------------
+ # prepare input to the backbone
+ # --------------------------------------------------------------------
+ # TODO compatility for 2D spectrogram VAE
+ latent_length = noisy_latent.size(self.autoencoder.time_dim)
+ context, context_mask, time_aligned_content = self.get_backbone_input(
+ latent_length, content, content_mask, time_aligned_content,
+ length_aligned_content, is_time_aligned
+ )
+
+ # --------------------------------------------------------------------
+ # classifier free guidance
+ # --------------------------------------------------------------------
+ if self.training and self.classifier_free_guidance:
+ mask_indices = [
+ k for k in range(len(waveform))
+ if random.random() < self.cfg_drop_ratio
+ ]
+ if len(mask_indices) > 0:
+ context[mask_indices] = 0
+ time_aligned_content[mask_indices] = 0
+
+ pred: torch.Tensor = self.backbone(
+ x=noisy_latent,
+ x_mask=latent_mask,
+ timesteps=timesteps,
+ context=context,
+ context_mask=context_mask,
+ time_aligned_context=time_aligned_content,
+ )
+ pred = pred.transpose(1, self.autoencoder.time_dim)
+ target = target.transpose(1, self.autoencoder.time_dim)
+ diff_loss = F.mse_loss(pred, target, reduction="none")
+ diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce)
+ return {
+ "diff_loss": diff_loss,
+ "local_duration_loss": local_duration_loss,
+ "global_duration_loss": global_duration_loss,
+ }
+
+ @torch.no_grad()
+ def inference(
+ self,
+ content: list[Any],
+ task: list[str],
+ is_time_aligned: Sequence[bool],
+ instruction: torch.Tensor,
+ instruction_lengths: Sequence[int],
+ num_steps: int = 20,
+ sway_sampling_coef: float | None = -1.0,
+ guidance_scale: float = 3.0,
+ disable_progress: bool = True,
+ use_gt_duration: bool = False,
+ **kwargs
+ ):
+ device = self.dummy_param.device
+ classifier_free_guidance = guidance_scale > 1.0
+
+ (
+ content, content_mask, global_duration_pred, local_duration_pred,
+ length_aligned_content
+ ) = self.encode_content_with_instruction(
+ content, task, device, instruction, instruction_lengths
+ )
+ # print("content std: ", content.std())
+ batch_size = content.size(0)
+
+ # truncate dummy time aligned duration prediction
+ is_time_aligned = torch.as_tensor(is_time_aligned)
+ if is_time_aligned.sum() > 0:
+ trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
+ else:
+ trunc_ta_length = content.size(1)
+
+ # prepare local duration
+ local_duration = self.prepare_local_duration(
+ local_duration_pred, content_mask
+ )
+ local_duration = local_duration[:, :trunc_ta_length]
+ # use ground truth duration
+ if use_gt_duration and "duration" in kwargs:
+ local_duration = torch.as_tensor(kwargs["duration"]).to(device)
+
+ # prepare global duration
+ global_duration = self.prepare_global_duration(
+ global_duration_pred, local_duration, is_time_aligned
+ )
+
+ # --------------------------------------------------------------------
+ # duration adapter
+ # --------------------------------------------------------------------
+ time_aligned_content, latent_mask = self.expand_by_duration(
+ x=content[:, :trunc_ta_length],
+ content_mask=content_mask[:, :trunc_ta_length],
+ local_duration=local_duration,
+ global_duration=global_duration,
+ )
+
+ context, context_mask, time_aligned_content = self.get_backbone_input(
+ target_length=time_aligned_content.size(1),
+ content=content,
+ content_mask=content_mask,
+ time_aligned_content=time_aligned_content,
+ length_aligned_content=length_aligned_content,
+ is_time_aligned=is_time_aligned
+ )
+
+ # --------------------------------------------------------------------
+ # prepare unconditional input
+ # --------------------------------------------------------------------
+ if classifier_free_guidance:
+ uncond_time_aligned_content = torch.zeros_like(
+ time_aligned_content
+ )
+ uncond_context = torch.zeros_like(context)
+ uncond_context_mask = context_mask.detach().clone()
+ time_aligned_content = torch.cat([
+ uncond_time_aligned_content, time_aligned_content
+ ])
+ context = torch.cat([uncond_context, context])
+ context_mask = torch.cat([uncond_context_mask, context_mask])
+ latent_mask = torch.cat([
+ latent_mask, latent_mask.detach().clone()
+ ])
+
+ # --------------------------------------------------------------------
+ # prepare input to the backbone
+ # --------------------------------------------------------------------
+ latent_length = latent_mask.sum(1).max().item()
+ latent_shape = tuple(
+ latent_length if dim is None else dim
+ for dim in self.autoencoder.latent_shape
+ )
+ shape = (batch_size, *latent_shape)
+ latent = randn_tensor(
+ shape, generator=None, device=device, dtype=content.dtype
+ )
+
+ if not sway_sampling_coef:
+ sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
+ else:
+ t = torch.linspace(0, 1, num_steps + 1)
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
+ sigmas = 1 - t
+ timesteps, num_steps = self.retrieve_timesteps(
+ num_steps, device, timesteps=None, sigmas=sigmas
+ )
+ latent = self.iterative_denoise(
+ latent=latent,
+ timesteps=timesteps,
+ num_steps=num_steps,
+ verbose=not disable_progress,
+ cfg=classifier_free_guidance,
+ cfg_scale=guidance_scale,
+ backbone_input={
+ "x_mask": latent_mask,
+ "context": context,
+ "context_mask": context_mask,
+ "time_aligned_context": time_aligned_content,
+ }
+ )
+
+ waveform = self.autoencoder.decode(latent)
+ return waveform
+
+
+class DoubleContentAudioFlowMatching(DummyContentAudioFlowMatching):
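+ """
+ Variant that feeds both the cross-attention context and the time-aligned
+ context to the backbone unchanged, without substituting dummy embeddings
+ or truncating the context.
+ """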
+ def get_backbone_input(
+ self, target_length: int, content: torch.Tensor,
+ content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
+ length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
+ ):
+ # TODO compatility for 2D spectrogram VAE
+ time_aligned_content = trim_or_pad_length(
+ time_aligned_content, target_length, 1
+ )
+ length_aligned_content = trim_or_pad_length(
+ length_aligned_content, target_length, 1
+ )
+ # time_aligned_content: duration-expanded content from monotonically aligned input (e.g. phoneme)
+ # length_aligned_content: content that is already frame-aligned (e.g. f0/energy)
+ time_aligned_content = time_aligned_content + length_aligned_content
+
+ context = content
+ context_mask = content_mask.detach().clone()
+
+ return context, context_mask, time_aligned_content
+
+
+class HybridContentAudioFlowMatching(DummyContentAudioFlowMatching):
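+ """
+ Middle ground between the dummy and double variants: non-time-aligned
+ samples receive the dummy time-aligned embedding, but the full
+ cross-attention context is kept for every sample.
+ """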
+ def get_backbone_input(
+ self, target_length: int, content: torch.Tensor,
+ content_mask: torch.Tensor, time_aligned_content: torch.Tensor,
+ length_aligned_content: torch.Tensor, is_time_aligned: torch.Tensor
+ ):
+ # TODO compatility for 2D spectrogram VAE
+ time_aligned_content = trim_or_pad_length(
+ time_aligned_content, target_length, 1
+ )
+ length_aligned_content = trim_or_pad_length(
+ length_aligned_content, target_length, 1
+ )
+ # time_aligned_content: duration-expanded content from monotonically aligned input (e.g. phoneme)
+ # length_aligned_content: content that is already frame-aligned (e.g. f0/energy)
+ time_aligned_content = time_aligned_content + length_aligned_content
+ time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
+ time_aligned_content.dtype
+ )
+
+ context = content
+ context_mask = content_mask.detach().clone()
+
+ return context, context_mask, time_aligned_content
diff --git a/requirements.txt b/requirements.txt
index 5149ecea6592775ad8dfd58a4abcedfd2c2cf8f4..1fa472e143c6ba0a1f855c811823f7e3b42b73d9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,149 @@
-gradio==4.4.1
-transformers==4.31.0
-huggingface-hub==0.16.4
\ No newline at end of file
+absl-py==2.3.0
+accelerate==1.2.1
+alias-free-torch==0.0.6
+annotated-types==0.7.0
+antlr4-python3-runtime==4.9.3
+astunparse==1.6.3
+attrs==22.2.0
+audioread==3.0.1
+av==11.0.0
+bitarray==3.7.1
+boto3==1.38.36
+botocore==1.38.36
+braceexpand==0.1.7
+brotlipy==0.7.0
+click==8.1.8
+colorama==0.4.6
+conda==23.1.0
+conda-build==3.23.3
+contourpy==1.2.0
+cycler==0.12.1
+dcase-util==0.2.20
+diffusers==0.33.1
+dnspython==2.3.0
+docker-pycreds==0.4.0
+einops==0.7.0
+exceptiongroup==1.1.1
+expecttest==0.1.4
+fire==0.7.0
+fonttools==4.47.2
+fsspec==2023.12.2
+ftfy==6.3.1
+future==1.0.0
+gitdb==4.0.12
+GitPython==3.1.44
+grpcio==1.73.0
+h5py==3.10.0
+huggingface-hub==0.30.2
+hydra-core==1.3.2
+hypothesis==6.70.0
+imageio==2.37.0
+importlib_metadata==8.5.0
+iniconfig==2.0.0
+ipdb==0.13.13
+jmespath==1.0.1
+joblib==1.3.2
+kiwisolver==1.4.5
+laion_clap==1.1.7
+lazy-dataset==0.0.14
+lazy_loader==0.4
+librosa==0.10.2
+llvmlite==0.42.0
+lxml==6.0.1
+Markdown==3.8
+matplotlib==3.8.2
+mkl-fft==1.3.1
+mkl-service==2.4.0
+mpmath==1.3.0
+msgpack==1.0.8
+networkx==3.0
+numba==0.59.1
+numpy==1.26.4
+nvidia-cublas-cu12==12.4.5.8
+nvidia-cuda-cupti-cu12==12.4.127
+nvidia-cuda-nvrtc-cu12==12.4.127
+nvidia-cuda-runtime-cu12==12.4.127
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.2.1.3
+nvidia-curand-cu12==10.3.5.147
+nvidia-cusolver-cu12==11.6.1.9
+nvidia-cusparse-cu12==12.3.1.170
+nvidia-cusparselt-cu12==0.6.2
+nvidia-ml-py==12.575.51
+nvidia-nccl-cu12==2.21.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.4.127
+omegaconf==2.3.0
+packaging==23.2
+pandas==2.2.0
+pathlib==1.0.1
+pillow==11.3.0
+pip-chill==1.0.3
+platformdirs==4.2.1
+pluggy==1.5.0
+pooch==1.8.1
+portalocker==3.2.0
+prettytable==3.16.0
+progressbar==2.5
+protobuf==5.29.2
+psds-eval==0.5.3
+pydantic==2.10.4
+pydantic_core==2.27.2
+pydot-ng==2.0.0
+pyecharts==2.0.8
+pynvml==12.0.0
+pyparsing==3.1.1
+pytest==8.2.0
+python-dateutil==2.8.2
+python-etcd==0.4.5
+python-magic==0.4.27
+regex==2023.12.25
+resampy==0.4.3
+s3transfer==0.13.0
+sacrebleu==2.5.1
+safetensors==0.5.0
+scikit-image==0.25.2
+scikit-learn==1.4.0
+scipy==1.12.0
+sed-eval==0.2.1
+sed-scores-eval==0.0.0
+sentence-transformers==4.1.0
+sentencepiece==0.2.0
+sentry-sdk==2.19.2
+setproctitle==1.3.4
+simplejson==3.20.1
+smmap==5.0.2
+sortedcontainers==2.4.0
+soundfile==0.12.1
+soxr==0.3.7
+swankit==0.2.3
+swanlab==0.6.3
+sympy==1.13.1
+tabulate==0.9.0
+tensorboard==2.19.0
+tensorboard-data-server==0.7.2
+termcolor==3.1.0
+threadpoolctl==3.2.0
+tifffile==2025.5.10
+timm==0.9.12
+tokenizers==0.21.1
+tomli==2.0.1
+torch==2.6.0
+torchaudio==2.6.0
+torchdata==0.10.1
+torchelastic==0.2.2
+torchlibrosa==0.1.0
+torchtext==0.15.0
+torchvision==0.21.0
+transformers==4.51.3
+triton==3.2.0
+types-dataclasses==0.6.6
+typing_extensions==4.12.2
+tzdata==2023.4
+validators==0.28.1
+wandb==0.19.1
+webdataset==1.0.2
+Werkzeug==3.1.3
+wget==3.2
+wrapt==1.17.2
+zipp==3.21.0
diff --git a/utils/__pycache__/accelerate_utilities.cpython-310.pyc b/utils/__pycache__/accelerate_utilities.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d039d356af90e1c86d9442812b7c9462f8a99a1
Binary files /dev/null and b/utils/__pycache__/accelerate_utilities.cpython-310.pyc differ
diff --git a/utils/__pycache__/config.cpython-310.pyc b/utils/__pycache__/config.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ecfe36dc4fb653c0dcd67bfb6e2a63cac8b603c6
Binary files /dev/null and b/utils/__pycache__/config.cpython-310.pyc differ
diff --git a/utils/__pycache__/diffsinger_utilities.cpython-310.pyc b/utils/__pycache__/diffsinger_utilities.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..60e4cb6d6a33922102a8addeb45e3bf2128abc47
Binary files /dev/null and b/utils/__pycache__/diffsinger_utilities.cpython-310.pyc differ
diff --git a/utils/__pycache__/general.cpython-310.pyc b/utils/__pycache__/general.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..25e2e42a9dfedb3de3c42da7996a2695061f8716
Binary files /dev/null and b/utils/__pycache__/general.cpython-310.pyc differ
diff --git a/utils/__pycache__/logging.cpython-310.pyc b/utils/__pycache__/logging.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0cbfc641e1c3864dee6d9a31115464d39cf2a1b5
Binary files /dev/null and b/utils/__pycache__/logging.cpython-310.pyc differ
diff --git a/utils/__pycache__/lr_scheduler_utilities.cpython-310.pyc b/utils/__pycache__/lr_scheduler_utilities.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c893bd72df9d75ca5c493921263e38729bfac930
Binary files /dev/null and b/utils/__pycache__/lr_scheduler_utilities.cpython-310.pyc differ
diff --git a/utils/__pycache__/torch_utilities.cpython-310.pyc b/utils/__pycache__/torch_utilities.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b20d4f9f6dae86af3fe8842badd64afd81e82c72
Binary files /dev/null and b/utils/__pycache__/torch_utilities.cpython-310.pyc differ
diff --git a/utils/accelerate_utilities.py b/utils/accelerate_utilities.py
new file mode 100644
index 0000000000000000000000000000000000000000..99fc5aa3ad2700361e006799c6aac8119c9bb15a
--- /dev/null
+++ b/utils/accelerate_utilities.py
@@ -0,0 +1,13 @@
+from accelerate import Accelerator
+
+
+class AcceleratorSaveTrainableParams(Accelerator):
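+ """
+ Accelerator subclass that, when a model exposes `param_names_to_save`,
+ checkpoints only those (e.g. trainable) parameters instead of the full
+ state dict.
+ """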
+ def get_state_dict(self, model, unwrap=True):
+ state_dict = super().get_state_dict(model, unwrap)
+ if hasattr(model, "param_names_to_save"):
+ param_names_to_save = model.param_names_to_save
+ return {
+ k: v
+ for k, v in state_dict.items() if k in param_names_to_save
+ }
+ return state_dict
diff --git a/utils/audio.py b/utils/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..350a9fe08e229a2f979f8090e314216c6b356739
--- /dev/null
+++ b/utils/audio.py
@@ -0,0 +1,58 @@
+import torch
+import torch.nn as nn
+import torchaudio
+
+
+class PadCrop(nn.Module):
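+ """
+ Pad or crop a (channels, samples) signal to exactly `n_samples`: shorter
+ inputs are zero-padded at the end, longer ones are cropped from a random
+ start (or from sample 0 when `randomize=False`).
+ """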
+ def __init__(self, n_samples, randomize=True):
+ super().__init__()
+ self.n_samples = n_samples
+ self.randomize = randomize
+
+ def __call__(self, signal):
+ n, s = signal.shape
+ if self.randomize:
+ start = torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
+ else:
+ start = 0
+ end = start + self.n_samples
+ output = signal.new_zeros([n, self.n_samples])
+ output[:, :min(s, self.n_samples)] = signal[:, start:end]
+ return output
+
+
+def set_audio_channels(audio, target_channels):
+ if target_channels == 1:
+ # Convert to mono
+ audio = audio.mean(1, keepdim=True)
+ elif target_channels == 2:
+ # Convert to stereo
+ if audio.shape[1] == 1:
+ audio = audio.repeat(1, 2, 1)
+ elif audio.shape[1] > 2:
+ audio = audio[:, :2, :]
+ return audio
+
+
+def prepare_audio(
+ audio, in_sr, target_sr, target_length, target_channels, device
+):
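+ """
+ Move audio to `device`, resample from `in_sr` to `target_sr`, pad/crop it
+ to `target_length` samples, add a batch dimension if missing, and convert
+ to `target_channels` channels.
+ """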
+
+ audio = audio.to(device)
+
+ if in_sr != target_sr:
+ resample_tf = torchaudio.transforms.Resample(in_sr,
+ target_sr).to(device)
+ audio = resample_tf(audio)
+
+ audio = PadCrop(target_length, randomize=False)(audio)
+
+ # Add batch dimension
+ if audio.dim() == 1:
+ audio = audio.unsqueeze(0).unsqueeze(0)
+ elif audio.dim() == 2:
+ audio = audio.unsqueeze(0)
+
+ audio = set_audio_channels(audio, target_channels)
+
+ return audio
diff --git a/utils/config.py b/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dfee06dbfdf633899e8d4748a06a2b2f71673c1
--- /dev/null
+++ b/utils/config.py
@@ -0,0 +1,53 @@
+from pathlib import Path
+import sys
+import os
+
+import hydra
+import omegaconf
+from omegaconf import OmegaConf
+
+
+def multiply(*args):
+ result = 1
+ for arg in args:
+ result *= arg
+ return result
+
+
+def get_pitch_downsample_ratio(
+ autoencoder_config: dict, pitch_frame_resolution: float
+):
+ latent_frame_resolution = autoencoder_config[
+ "downsampling_ratio"] / autoencoder_config["sample_rate"]
+ return round(latent_frame_resolution / pitch_frame_resolution)
+
+
+def register_omegaconf_resolvers() -> None:
+ """
+ Register custom resolver for hydra configs, which can be used in YAML
+ files for dynamically setting values
+ """
+ OmegaConf.clear_resolvers()
+ OmegaConf.register_new_resolver("len", len, replace=True)
+ OmegaConf.register_new_resolver("multiply", multiply, replace=True)
+ OmegaConf.register_new_resolver(
+ "get_pitch_downsample_ratio", get_pitch_downsample_ratio, replace=True
+ )
+
+
+def generate_config_from_command_line_overrides(
+ config_file: str | Path
+) -> omegaconf.DictConfig:
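+ """
+ Compose a hydra config from `config_file`, applying any `key=value`
+ overrides passed on the command line (hydra's standard override syntax,
+ e.g. `python train.py model.cfg_drop_ratio=0.1`; the key shown here is
+ only an illustration).
+ """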
+ register_omegaconf_resolvers()
+
+ config_file = Path(config_file).resolve()
+ config_name = config_file.name
+ config_path = str(config_file.parent)
+ config_path = os.path.relpath(config_path, Path(__file__).resolve().parent)
+
+ overrides = sys.argv[1:]
+ with hydra.initialize(version_base=None, config_path=config_path):
+ config = hydra.compose(config_name=config_name, overrides=overrides)
+ omegaconf.OmegaConf.resolve(config)
+
+ return config
diff --git a/utils/diffsinger_utilities.py b/utils/diffsinger_utilities.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d658ec5a8306d46e65ef539e4e679c6a57d484b
--- /dev/null
+++ b/utils/diffsinger_utilities.py
@@ -0,0 +1,551 @@
+from pathlib import Path
+import re
+import json
+from collections import OrderedDict
+from typing import Union
+
+import numpy as np
+import librosa
+import torch
+
+PAD = ""
+EOS = ""
+UNK = ""
+SEG = "|"
+RESERVED_TOKENS = [PAD, EOS, UNK]
+NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
+PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0
+EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1
+UNK_ID = RESERVED_TOKENS.index(UNK) # Normally 2
+
+F0_BIN = 256
+F0_MAX = 1100.0
+F0_MIN = 50.0
+F0_MEL_MIN = 1127 * np.log(1 + F0_MIN / 700)
+F0_MEL_MAX = 1127 * np.log(1 + F0_MAX / 700)
+
+
+def f0_to_coarse(f0):
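+ """
+ Quantize f0 (Hz) into F0_BIN coarse bins on the mel scale: voiced values
+ are mapped onto bins 1..F0_BIN-1, and unvoiced frames (f0 = 0) fall into
+ bin 1.
+ """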
+ is_torch = isinstance(f0, torch.Tensor)
+ f0_mel = 1127 * (1 + f0 /
+ 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
+ f0_mel[f0_mel > 0
+ ] = (f0_mel[f0_mel > 0] -
+ F0_MEL_MIN) * (F0_BIN - 2) / (F0_MEL_MAX - F0_MEL_MIN) + 1
+
+ f0_mel[f0_mel <= 1] = 1
+ f0_mel[f0_mel > F0_BIN - 1] = F0_BIN - 1
+ f0_coarse = (f0_mel +
+ 0.5).long() if is_torch else np.rint(f0_mel).astype(int)
+ assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (
+ f0_coarse.max(), f0_coarse.min()
+ )
+ return f0_coarse
+
+
+def norm_f0(
+ f0: Union[np.ndarray, torch.Tensor],
+ uv: Union[None, np.ndarray],
+ f0_mean: float,
+ f0_std: float,
+ pitch_norm: str = "log",
+ use_uv: bool = True
+):
+ is_torch = isinstance(f0, torch.Tensor)
+ if pitch_norm == 'standard':
+ f0 = (f0 - f0_mean) / f0_std
+ if pitch_norm == 'log':
+ f0 = torch.log2(f0) if is_torch else np.log2(f0)
+ if uv is not None and use_uv:
+ f0[uv > 0] = 0
+ return f0
+
+
+def norm_interp_f0(
+ f0: Union[np.ndarray, torch.Tensor],
+ f0_mean: float,
+ f0_std: float,
+ pitch_norm: str = "log",
+ use_uv: bool = True
+):
+ is_torch = isinstance(f0, torch.Tensor)
+ if is_torch:
+ device = f0.device
+ f0 = f0.data.cpu().numpy()
+ uv = f0 == 0
+ f0 = norm_f0(f0, uv, f0_mean, f0_std, pitch_norm, use_uv)
+ if sum(uv) == len(f0):
+ f0[uv] = 0
+ elif sum(uv) > 0:
+ f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
+ uv = torch.as_tensor(uv).float()
+ f0 = torch.as_tensor(f0).float()
+ if is_torch:
+ f0 = f0.to(device)
+ return f0, uv
+
+
+def denorm_f0(
+ f0,
+ uv,
+ pitch_norm="log",
+ f0_mean=None,
+ f0_std=None,
+ pitch_padding=None,
+ min=None,
+ max=None,
+ use_uv=True
+):
+ if pitch_norm == 'standard':
+ f0 = f0 * f0_std + f0_mean
+ if pitch_norm == 'log':
+ f0 = 2**f0
+ if min is not None:
+ f0 = f0.clamp(min=min)
+ if max is not None:
+ f0 = f0.clamp(max=max)
+ if uv is not None and use_uv:
+ f0[uv > 0] = 0
+ if pitch_padding is not None:
+ f0[pitch_padding] = 0
+ return f0
+
+
+def librosa_pad_lr(x, fshift, pad_sides=1):
+ '''compute right padding (final frame) or both sides padding (first and final frames)
+ '''
+ assert pad_sides in (1, 2)
+ pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0]
+ if pad_sides == 1:
+ return 0, pad
+ else:
+ return pad // 2, pad // 2 + pad % 2
+
+
+def get_pitch(
+ wav_file: Union[str, Path], sample_rate: int, frame_shift: float
+):
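+ """
+ Extract an f0 contour with Praat's autocorrelation pitch tracker (via
+ parselmouth) at `frame_shift` resolution, pad it to the expected frame
+ count, and also return the coarse mel-scale bins from `f0_to_coarse`.
+ """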
+ import parselmouth
+ hop_size = int(frame_shift * sample_rate)
+ wav, _ = librosa.core.load(wav_file, sr=sample_rate)
+ # l_pad, r_pad = librosa_pad_lr(wav, hop_size, 1)
+ # wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
+
+ latent_length = wav.shape[0] // hop_size
+ f0_min = 80
+ f0_max = 750
+
+ f0 = parselmouth.Sound(wav, sample_rate).to_pitch_ac(
+ time_step=frame_shift,
+ voicing_threshold=0.6,
+ pitch_floor=f0_min,
+ pitch_ceiling=f0_max
+ ).selected_array['frequency']
+ delta_l = latent_length - len(f0)
+ if delta_l > 0:
+ f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
+ pitch_coarse = f0_to_coarse(f0)
+ return f0, pitch_coarse
+
+
+def remove_empty_lines(text):
+ """remove all empty lines"""
+ assert (len(text) > 0)
+ assert (isinstance(text, list))
+ text = [t.strip() for t in text]
+ # drop every empty line, not just the first occurrence
+ text = [t for t in text if t != ""]
+ return text
+
+
+def is_sil_phoneme(p):
+ return not p[0].isalpha()
+
+
+def strip_ids(ids, ids_to_strip):
+ """Strip ids_to_strip from the end ids."""
+ ids = list(ids)
+ while ids and ids[-1] in ids_to_strip:
+ ids.pop()
+ return ids
+
+
+class TextEncoder(object):
+ """Base class for converting from ints to/from human readable strings."""
+ def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
+ self._num_reserved_ids = num_reserved_ids
+
+ @property
+ def num_reserved_ids(self):
+ return self._num_reserved_ids
+
+ def encode(self, s):
+ """Transform a human-readable string into a sequence of int ids.
+
+ The ids should be in the range [num_reserved_ids, vocab_size). Ids [0,
+ num_reserved_ids) are reserved.
+
+ EOS is not appended.
+
+ Args:
+ s: human-readable string to be converted.
+
+ Returns:
+ ids: list of integers
+ """
+ return [int(w) + self._num_reserved_ids for w in s.split()]
+
+ def decode(self, ids, strip_extraneous=False):
+ """Transform a sequence of int ids into a human-readable string.
+
+ EOS is not expected in ids.
+
+ Args:
+ ids: list of integers to be converted.
+ strip_extraneous: bool, whether to strip off extraneous tokens
+ (EOS and PAD).
+
+ Returns:
+ s: human-readable string.
+ """
+ if strip_extraneous:
+ ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
+ return " ".join(self.decode_list(ids))
+
+ def decode_list(self, ids):
+ """Transform a sequence of int ids into a their string versions.
+
+ This method supports transforming individual input/output ids to their
+ string versions so that sequence to/from text conversions can be visualized
+ in a human readable format.
+
+ Args:
+ ids: list of integers to be converted.
+
+ Returns:
+ strs: list of human-readable string.
+ """
+ decoded_ids = []
+ for id_ in ids:
+ if 0 <= id_ < self._num_reserved_ids:
+ decoded_ids.append(RESERVED_TOKENS[int(id_)])
+ else:
+ decoded_ids.append(id_ - self._num_reserved_ids)
+ return [str(d) for d in decoded_ids]
+
+ @property
+ def vocab_size(self):
+ raise NotImplementedError()
+
+
+class TokenTextEncoder(TextEncoder):
+ """Encoder based on a user-supplied vocabulary (file or list)."""
+ def __init__(
+ self,
+ vocab_filename,
+ reverse=False,
+ vocab_list=None,
+ replace_oov=None,
+ num_reserved_ids=NUM_RESERVED_TOKENS
+ ):
+ """Initialize from a file or list, one token per line.
+
+ Handling of reserved tokens works as follows:
+ - When initializing from a list, we add reserved tokens to the vocab.
+ - When initializing from a file, we do not add reserved tokens to the vocab.
+ - When saving vocab files, we save reserved tokens to the file.
+
+ Args:
+ vocab_filename: If not None, the full filename to read vocab from. If this
+ is not None, then vocab_list should be None.
+ reverse: Boolean indicating if tokens should be reversed during encoding
+ and decoding.
+ vocab_list: If not None, a list of elements of the vocabulary. If this is
+ not None, then vocab_filename should be None.
+ replace_oov: If not None, every out-of-vocabulary token seen when
+ encoding will be replaced by this string (which must be in vocab).
+ num_reserved_ids: Number of IDs to save for reserved tokens like <EOS>.
+ """
+ super(TokenTextEncoder,
+ self).__init__(num_reserved_ids=num_reserved_ids)
+ self._reverse = reverse
+ self._replace_oov = replace_oov
+ if vocab_filename:
+ self._init_vocab_from_file(vocab_filename)
+ else:
+ assert vocab_list is not None
+ self._init_vocab_from_list(vocab_list)
+ self.pad_index = self._token_to_id[PAD]
+ self.eos_index = self._token_to_id[EOS]
+ self.unk_index = self._token_to_id[UNK]
+ self.seg_index = self._token_to_id[
+ SEG] if SEG in self._token_to_id else self.eos_index
+
+ def encode(self, s):
+ """Converts a space-separated string of tokens to a list of ids."""
+ sentence = s
+ tokens = sentence.strip().split()
+ if self._replace_oov is not None:
+ tokens = [
+ t if t in self._token_to_id else self._replace_oov
+ for t in tokens
+ ]
+ ret = [self._token_to_id[tok] for tok in tokens]
+ return ret[::-1] if self._reverse else ret
+
+ def decode(self, ids, strip_eos=False, strip_padding=False):
+ if strip_padding and self.pad() in list(ids):
+ pad_pos = list(ids).index(self.pad())
+ ids = ids[:pad_pos]
+ if strip_eos and self.eos() in list(ids):
+ eos_pos = list(ids).index(self.eos())
+ ids = ids[:eos_pos]
+ return " ".join(self.decode_list(ids))
+
+ def decode_list(self, ids):
+ seq = reversed(ids) if self._reverse else ids
+ return [self._safe_id_to_token(i) for i in seq]
+
+ @property
+ def vocab_size(self):
+ return len(self._id_to_token)
+
+ def __len__(self):
+ return self.vocab_size
+
+ def _safe_id_to_token(self, idx):
+ return self._id_to_token.get(idx, "ID_%d" % idx)
+
+ def _init_vocab_from_file(self, filename):
+ """Load vocab from a file.
+
+ Args:
+ filename: The file to load vocabulary from.
+ """
+ with open(filename) as f:
+ tokens = [token.strip() for token in f.readlines()]
+
+ def token_gen():
+ for token in tokens:
+ yield token
+
+ self._init_vocab(token_gen(), add_reserved_tokens=False)
+
+ def _init_vocab_from_list(self, vocab_list):
+ """Initialize tokens from a list of tokens.
+
+ It is ok if reserved tokens appear in the vocab list. They will be
+ removed. The set of tokens in vocab_list should be unique.
+
+ Args:
+ vocab_list: A list of tokens.
+ """
+ def token_gen():
+ for token in vocab_list:
+ if token not in RESERVED_TOKENS:
+ yield token
+
+ self._init_vocab(token_gen())
+
+ def _init_vocab(self, token_generator, add_reserved_tokens=True):
+ """Initialize vocabulary with tokens from token_generator."""
+
+ self._id_to_token = {}
+ non_reserved_start_index = 0
+
+ if add_reserved_tokens:
+ self._id_to_token.update(enumerate(RESERVED_TOKENS))
+ non_reserved_start_index = len(RESERVED_TOKENS)
+
+ self._id_to_token.update(
+ enumerate(token_generator, start=non_reserved_start_index)
+ )
+
+ # _token_to_id is the reverse of _id_to_token
+ self._token_to_id = dict(
+ (v, k) for k, v in self._id_to_token.items()
+ )
+
+ def pad(self):
+ return self.pad_index
+
+ def eos(self):
+ return self.eos_index
+
+ def unk(self):
+ return self.unk_index
+
+ def seg(self):
+ return self.seg_index
+
+ def store_to_file(self, filename):
+ """Write vocab file to disk.
+
+ Vocab files have one token per line. The file ends in a newline. Reserved
+ tokens are written to the vocab file as well.
+
+ Args:
+ filename: Full path of the file to store the vocab to.
+ """
+ with open(filename, "w") as f:
+ for i in range(len(self._id_to_token)):
+ f.write(self._id_to_token[i] + "\n")
+
+ def sil_phonemes(self):
+ return [p for p in self._id_to_token.values() if not p[0].isalpha()]
+
+
+class TextGrid(object):
+ def __init__(self, text):
+ text = remove_empty_lines(text)
+ self.text = text
+ self.line_count = 0
+ self._get_type()
+ self._get_time_intval()
+ self._get_size()
+ self.tier_list = []
+ self._get_item_list()
+
+ def _extract_pattern(self, pattern, inc):
+ """
+ Parameters
+ ----------
+ pattern : regex to extract pattern
+ inc : increment of line count after extraction
+ Returns
+ -------
+ group : extracted info
+ """
+ try:
+ group = re.match(pattern, self.text[self.line_count]).group(1)
+ self.line_count += inc
+ except AttributeError:
+ raise ValueError(
+ "File format error at line %d:%s" %
+ (self.line_count, self.text[self.line_count])
+ )
+ return group
+
+ def _get_type(self):
+ self.file_type = self._extract_pattern(r"File type = \"(.*)\"", 2)
+
+ def _get_time_intval(self):
+ self.xmin = self._extract_pattern(r"xmin = (.*)", 1)
+ self.xmax = self._extract_pattern(r"xmax = (.*)", 2)
+
+ def _get_size(self):
+ self.size = int(self._extract_pattern(r"size = (.*)", 2))
+
+ def _get_item_list(self):
+ """Only supports IntervalTier currently"""
+ for itemIdx in range(1, self.size + 1):
+ tier = OrderedDict()
+ item_list = []
+ tier_idx = self._extract_pattern(r"item \[(.*)\]:", 1)
+ tier_class = self._extract_pattern(r"class = \"(.*)\"", 1)
+ if tier_class != "IntervalTier":
+ raise NotImplementedError(
+ "Only IntervalTier class is supported currently"
+ )
+ tier_name = self._extract_pattern(r"name = \"(.*)\"", 1)
+ tier_xmin = self._extract_pattern(r"xmin = (.*)", 1)
+ tier_xmax = self._extract_pattern(r"xmax = (.*)", 1)
+ tier_size = self._extract_pattern(r"intervals: size = (.*)", 1)
+ for i in range(int(tier_size)):
+ item = OrderedDict()
+ item["idx"] = self._extract_pattern(r"intervals \[(.*)\]", 1)
+ item["xmin"] = self._extract_pattern(r"xmin = (.*)", 1)
+ item["xmax"] = self._extract_pattern(r"xmax = (.*)", 1)
+ item["text"] = self._extract_pattern(r"text = \"(.*)\"", 1)
+ item_list.append(item)
+ tier["idx"] = tier_idx
+ tier["class"] = tier_class
+ tier["name"] = tier_name
+ tier["xmin"] = tier_xmin
+ tier["xmax"] = tier_xmax
+ tier["size"] = tier_size
+ tier["items"] = item_list
+ self.tier_list.append(tier)
+
+ def toJson(self):
+ _json = OrderedDict()
+ _json["file_type"] = self.file_type
+ _json["xmin"] = self.xmin
+ _json["xmax"] = self.xmax
+ _json["size"] = self.size
+ _json["tiers"] = self.tier_list
+ return json.dumps(_json, ensure_ascii=False, indent=2)
+
+
+def read_duration_from_textgrid(
+ textgrid_path: Union[str, Path],
+ phoneme: str,
+ utterance_duration: float,
+):
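+ """
+ Align a space-separated phoneme sequence with the last tier of a
+ TextGrid (e.g. from a forced aligner) and return per-phoneme durations in
+ seconds as the diff of the recovered boundary times; consecutive
+ silence-like intervals are merged before alignment.
+ """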
+ ph_list = phoneme.split(" ")
+ with open(textgrid_path, "r") as f:
+ textgrid = f.readlines()
+ textgrid = remove_empty_lines(textgrid)
+ textgrid = TextGrid(textgrid)
+ textgrid = json.loads(textgrid.toJson())
+
+ split = np.ones(len(ph_list) + 1, np.float32) * -1
+ tg_idx = 0
+ ph_idx = 0
+ tg_align = [x for x in textgrid['tiers'][-1]['items']]
+ tg_align_ = []
+ for x in tg_align:
+ x['xmin'] = float(x['xmin'])
+ x['xmax'] = float(x['xmax'])
+ if x['text'] in ['sil', 'sp', '', 'SIL', 'PUNC']:
+ x['text'] = ''
+ if len(tg_align_) > 0 and tg_align_[-1]['text'] == '':
+ tg_align_[-1]['xmax'] = x['xmax']
+ continue
+ tg_align_.append(x)
+ tg_align = tg_align_
+ tg_len = len([x for x in tg_align if x['text'] != ''])
+ ph_len = len([x for x in ph_list if not is_sil_phoneme(x)])
+ assert tg_len == ph_len, (tg_len, ph_len, tg_align, ph_list, textgrid_path)
+ while tg_idx < len(tg_align) or ph_idx < len(ph_list):
+ if tg_idx == len(tg_align) and is_sil_phoneme(ph_list[ph_idx]):
+ split[ph_idx] = 1e8
+ ph_idx += 1
+ continue
+ x = tg_align[tg_idx]
+ if x['text'] == '' and ph_idx == len(ph_list):
+ tg_idx += 1
+ continue
+ assert ph_idx < len(ph_list), (
+ tg_len, ph_len, tg_align, ph_list, textgrid_path
+ )
+
+ ph = ph_list[ph_idx]
+ if x['text'] == '' and not is_sil_phoneme(ph):
+ assert False, (ph_list, tg_align)
+ if x['text'] != '' and is_sil_phoneme(ph):
+ ph_idx += 1
+ else:
+ assert (x['text'] == '' and is_sil_phoneme(ph)) \
+ or x['text'].lower() == ph.lower() \
+ or x['text'].lower() == 'sil', (x['text'], ph)
+ split[ph_idx] = x['xmin']
+ if ph_idx > 0 and split[ph_idx - 1] == -1 and is_sil_phoneme(
+ ph_list[ph_idx - 1]
+ ):
+ split[ph_idx - 1] = split[ph_idx]
+ ph_idx += 1
+ tg_idx += 1
+ assert tg_idx == len(tg_align), (tg_idx, [x['text'] for x in tg_align])
+ assert ph_idx >= len(ph_list) - 1, (
+ ph_idx, ph_list, len(ph_list), [x['text']
+ for x in tg_align], textgrid_path
+ )
+
+ split[0] = 0
+ split[-1] = utterance_duration
+ duration = np.diff(split)
+ return duration
diff --git a/utils/general.py b/utils/general.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d86f1e9e9990fb416d67392b5dd65be12ce4e6b
--- /dev/null
+++ b/utils/general.py
@@ -0,0 +1,73 @@
+import json
+import re
+from typing import Union, Dict
+from pathlib import Path
+import os
+
+MAX_FILE_NAME_LENGTH = 100
+
+
+def read_jsonl_to_mapping(
+ jsonl_file: Union[str, Path],
+ key_col: str,
+ value_col: str,
+ base_path=None,
+ overwrite=True,
+) -> Dict[str, str]:
+ """
+ Read two columns, indicated by `key_col` and `value_col`, from the
+ given jsonl file to return the mapping dict
+ TODO handle duplicate keys
+ """
+ mapping = {}
+ with open(jsonl_file, 'r') as file:
+ for line in file.readlines():
+ data = json.loads(line.strip())
+ key = data[key_col]
+ value = data[value_col]
+ if base_path:
+ value = os.path.join(base_path, value)
+ if overwrite or key not in mapping:
+ mapping[key] = value
+ return mapping
+
+
+def sanitize_filename(name: str, max_len: int = MAX_FILE_NAME_LENGTH) -> str:
+ """
+ Clean and truncate a string to make it a valid and safe filename.
+ """
+ # replace characters that are invalid in file names ('/' is already covered)
+ name = re.sub(r'[\\/*?:"<>|]', '_', name)
+ return name[:max_len]
+
+
+def transform_gen_fn_to_id(audio_file: Path, task: str) -> str:
+ if task == "svs":
+ audio_id = audio_file.stem.split("_")[0]
+ elif task == "sr":
+ audio_id = audio_file.stem
+ elif task == "tta":
+ audio_id = audio_file.stem[:12] + '.wav'
+ elif task == "ttm":
+ audio_id = audio_file.stem[:11]
+ # audio_id = audio_file.stem[:12] + '.wav'
+ elif task == "v2a":
+ audio_id = audio_file.stem.rsplit("_", 1)[0] + ".mp4"
+ elif task == "sta_test" or task == "tta_test":
+ audio_id = audio_file.stem[:12] + '.wav'
+ elif task == "sta_base":
+ audio_id = 'Y' + audio_file.stem[:11] + '.wav'
+ else:
+ audio_id = audio_file.stem
+ return audio_id
+
+
+def audio_dir_to_mapping(audio_dir: str | Path, task: str) -> dict:
+ mapping = {}
+ audio_dir = Path(audio_dir)
+ audio_files = sorted(audio_dir.iterdir())
+ for audio_file in audio_files:
+ audio_id = transform_gen_fn_to_id(audio_file, task)
+ mapping[audio_id] = str(audio_file.resolve())
+ return mapping
diff --git a/utils/logging.py b/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcb011c02dca7c42c9b387c4f126ba3e9fe7efb4
--- /dev/null
+++ b/utils/logging.py
@@ -0,0 +1,23 @@
+from pathlib import Path
+from dataclasses import dataclass
+import logging
+
+
+@dataclass
+class LoggingLogger:
+
+ filename: str | Path
+ level: str = "INFO"
+
+ def create_instance(self):
+ filename = str(self.filename)
+ formatter = logging.Formatter("[%(asctime)s] - %(message)s")
+
+ logger = logging.getLogger(__name__ + "." + filename)
+ logger.setLevel(getattr(logging, self.level))
+
+ file_handler = logging.FileHandler(filename)
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
+
+ return logger
diff --git a/utils/lr_scheduler_utilities.py b/utils/lr_scheduler_utilities.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bd0333a6c51bbc1f307213eced5d34b4b906d22
--- /dev/null
+++ b/utils/lr_scheduler_utilities.py
@@ -0,0 +1,154 @@
+from typing import Any
+import math
+import copy
+from torch.utils.data import DataLoader
+
+
+def get_warmup_steps(
+ dataloader_one_pass_outside_steps: int,
+ warmup_steps: int | None = None,
+ warmup_epochs: float | None = None,
+ epoch_length: int | None = None,
+) -> int:
+ """
+ Derive warmup steps according to step number or epoch number.
+ If `warmup_steps` is provided, then just return it. Otherwise, derive
+ the warmup steps by epoch length and warmup epoch number.
+ """
+ if warmup_steps is not None:
+ return warmup_steps
+ else:
+ if epoch_length is None:
+ epoch_length = dataloader_one_pass_outside_steps
+ assert warmup_epochs is not None, "warmup_steps and warmup_epochs cannot be both None"
+ return int(epoch_length * warmup_epochs)
+
+
+def get_dataloader_one_pass_outside_steps(
+ train_dataloader: DataLoader,
+ num_processes: int = 1,
+):
+ """
+ dataloader length after DDP, close to `original_length / gpu_number`
+ """
+ return math.ceil(len(train_dataloader) / num_processes)
+
+
+def get_total_training_steps(
+ train_dataloader: DataLoader,
+ epochs: int,
+ num_processes: int = 1,
+ epoch_length: int | None = None
+):
+ """
+ Calculate the total number of "visible" training steps.
+
+ If `epoch_length` is provided, it is used as the fixed length for each epoch.
+ Otherwise, the function will determine the epoch length from `train_dataloader`.
+
+ Args:
+ train_dataloader:
+ Training dataloader object.
+ epochs:
+ The total number of epochs to run.
+ num_processes:
+ The number of parallel processes used for distributed training.
+ epoch_length:
+ A fixed number of training steps for each epoch. Defaults to None.
+
+ Returns:
+ int: The total number of training steps (i.e., `epochs * epoch_length`).
+ """
+ if epoch_length is None:
+ # `epoch_length` not given: use the length of the DDP-wrapped `train_dataloader`
+ epoch_length = get_dataloader_one_pass_outside_steps(
+ train_dataloader, num_processes
+ )
+ return epochs * epoch_length
+
+
+def get_dataloader_one_pass_steps_inside_accelerator(
+ dataloader_one_pass_steps: int, gradient_accumulation_steps: int,
+ num_processes: int
+):
+ """
+ Calculate the number of "visible" training steps for a single pass over the dataloader
+ inside an accelerator, accounting for gradient accumulation and distributed training.
+
+ Args:
+ dataloader_one_pass_steps:
+ The number of steps (batches) in one pass over the dataset.
+ gradient_accumulation_steps:
+ The number of steps to accumulate gradients before performing a parameter update.
+ num_processes:
+ The number of parallel processes used for distributed training.
+
+ Returns:
+ int: The total number of "visible" training steps for one pass over the dataset,
+ multiplied by the number of processes.
+ """
+ return math.ceil(
+ dataloader_one_pass_steps / gradient_accumulation_steps
+ ) * num_processes
+
+
+def get_steps_inside_accelerator_from_outside_steps(
+ outside_steps: int, dataloader_one_pass_outside_steps: int,
+ dataloader_one_pass_steps_inside_accelerator: int,
+ gradient_accumulation_steps: int, num_processes: int
+):
+ """
+ Convert "outside" steps (as observed in wandb logger or similar context)
+ to the corresponding number of "inside" steps (for accelerate lr scheduler).
+
+ Specifically, accelerate lr scheduler call `step()` `num_processes` times for
+ every `gradient_accumulation_steps` outside steps.
+
+ Args:
+ outside_steps:
+ The total number of steps counted outside accelerate context.
+ dataloader_one_pass_outside_steps:
+ The number of steps (batches) to complete one pass of the dataloader
+ outside accelerate.
+ dataloader_one_pass_steps_inside_accelerator:
+ The number of `lr_scheduler.step()` calls inside accelerate, calculated via
+ `get_dataloader_one_pass_steps_inside_accelerator`.
+ gradient_accumulation_steps:
+ The number of steps to accumulate gradients.
+ num_processes:
+ The number of parallel processes (GPUs) used in distributed training.
+
+ Returns:
+ int: The total number of `lr_scheduler.step()` calls inside accelerate that
+ correspond to the given `outside_steps`.
+ """
+ num_dataloader_epochs_passed = outside_steps // dataloader_one_pass_outside_steps
+ remaining_outside_steps = outside_steps % dataloader_one_pass_outside_steps
+ remaining_inside_accelerator_steps = (
+ remaining_outside_steps // gradient_accumulation_steps * num_processes
+ )
+ # accelerate scheduler call `step()` `num_processes` times every
+ # `gradient_accumulation_steps` steps:
+ # https://github.com/huggingface/accelerate/blob/main/src/accelerate/scheduler.py#L76
+ total_steps = (
+ num_dataloader_epochs_passed *
+ dataloader_one_pass_steps_inside_accelerator +
+ remaining_inside_accelerator_steps
+ )
+ return total_steps
+
+
+def lr_scheduler_param_adapter(
+ config_dict: dict[str, Any], num_training_steps: int, num_warmup_steps: int
+) -> dict[str, Any]:
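+ """
+ Fill scheduler-specific arguments into a hydra-style config dict: for
+ `transformers.get_scheduler` targets, inject the computed
+ `num_training_steps` and `num_warmup_steps`; other targets pass through
+ unchanged.
+ """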
+ target_class = config_dict["_target_"]
+ return_dict = copy.deepcopy(config_dict)
+ if target_class == "transformers.get_scheduler":
+ return_dict.update({
+ "num_training_steps": num_training_steps,
+ "num_warmup_steps": num_warmup_steps
+ })
+
+ return return_dict
diff --git a/utils/tests/test_logging.py b/utils/tests/test_logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ce9d080a7dd63c013ea7382e160d8b4ecfa6adb
--- /dev/null
+++ b/utils/tests/test_logging.py
@@ -0,0 +1,19 @@
+import unittest
+from pathlib import Path
+
+from utils.logging import LoggingLogger
+
+
+class TestLoggingLogger(unittest.TestCase):
+ def setUp(self):
+ self.tmp_log_path = Path("./tmp_logging.txt")
+
+ def test_logging_info(self):
+ logger = LoggingLogger(filename=self.tmp_log_path,
+ level="INFO").create_instance()
+ logger.info("logging information")
+ self.assertTrue(self.tmp_log_path.exists())
+
+ def tearDown(self):
+ if self.tmp_log_path.exists():
+ self.tmp_log_path.unlink()
diff --git a/utils/torch_utilities.py b/utils/torch_utilities.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4c46e2e2013479d608e6c1aded2126165443a9a
--- /dev/null
+++ b/utils/torch_utilities.py
@@ -0,0 +1,288 @@
+import logging
+import math
+from typing import Callable
+from pathlib import Path
+import numpy as np
+import torch
+import torch.nn as nn
+
+logger = logging.getLogger(__name__)
+
+
+def remove_key_prefix_factory(prefix: str = "module."):
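+ """
+ Build a state-dict processing function that keeps only keys starting
+ with `prefix` and strips it, e.g. for loading DDP-saved
+ ("module."-prefixed) checkpoints into an unwrapped model.
+ """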
+ def func(
+ model_dict: dict[str, torch.Tensor], state_dict: dict[str,
+ torch.Tensor]
+ ) -> dict[str, torch.Tensor]:
+
+ state_dict = {
+ key[len(prefix):]: value
+ for key, value in state_dict.items() if key.startswith(prefix)
+ }
+ return state_dict
+
+ return func
+
+
+def merge_matched_keys(
+ model_dict: dict[str, torch.Tensor], state_dict: dict[str, torch.Tensor]
+) -> dict[str, torch.Tensor]:
+    """
+    Update entries of `model_dict` with values from `state_dict` whose keys
+    and shapes match; mismatched keys are logged and skipped.
+
+    Args:
+        model_dict:
+            The state dict of the current model, which is going to load
+            pretrained parameters.
+ state_dict:
+ A dictionary of parameters from a pre-trained model.
+
+ Returns:
+ dict[str, torch.Tensor]:
+ The updated state dict, where parameters with matched keys and shape are
+ updated with values in `state_dict`.
+ """
+ pretrained_dict = {}
+ mismatch_keys = []
+ for key, value in state_dict.items():
+ if key in model_dict and model_dict[key].shape == value.shape:
+ pretrained_dict[key] = value
+ else:
+ mismatch_keys.append(key)
+ logger.info(
+ f"Loading pre-trained model, with mismatched keys {mismatch_keys}"
+ )
+ model_dict.update(pretrained_dict)
+ return model_dict
+
+
+def load_pretrained_model(
+ model: nn.Module,
+ ckpt_or_state_dict: str | Path | dict[str, torch.Tensor],
+ state_dict_process_fn: Callable = merge_matched_keys
+) -> None:
+ state_dict = ckpt_or_state_dict
+ if not isinstance(state_dict, dict):
+ state_dict = torch.load(ckpt_or_state_dict, "cpu")
+
+ model_dict = model.state_dict()
+ state_dict = state_dict_process_fn(model_dict, state_dict)
+ model.load_state_dict(state_dict)
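+
+# Usage sketch (hedged; "ckpt.pt" is a hypothetical checkpoint path): load a
+# DistributedDataParallel checkpoint, stripping the "module." prefix first:
+#   load_pretrained_model(
+#       model, "ckpt.pt",
+#       state_dict_process_fn=remove_key_prefix_factory("module.")
+#   )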
+
+
+def create_mask_from_length(
+ lengths: torch.Tensor, max_length: int | None = None
+):
+    if max_length is None:
+        # lengths.max() is a 0-dim tensor; convert to a plain int for arange
+        max_length = int(lengths.max())
+ idxs = torch.arange(max_length).reshape(1, -1) # (1, max_length)
+ mask = idxs.to(lengths.device) < lengths.view(-1, 1)
+ # (1, max_length) < (batch_size, 1) -> (batch_size, max_length)
+ return mask
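+
+# Example: create_mask_from_length(torch.tensor([2, 4]), max_length=4) gives
+#   tensor([[True, True, False, False],
+#           [True, True, True,  True]])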
+
+
+def loss_with_mask(
+ loss: torch.Tensor,
+ mask: torch.Tensor,
+ reduce: bool = True
+) -> torch.Tensor:
+ """
+ Apply a mask to the loss tensor and optionally reduce it.
+
+ Args:
+ loss: Tensor of shape (b, t, ...) representing the loss values.
+ mask: Tensor of shape (b, t) where 1 indicates valid positions and 0 indicates masked positions.
+ reduce: If True, return a single scalar value; otherwise, return a tensor of shape (b,).
+
+ Returns:
+ torch.Tensor: A scalar if reduce is True, otherwise a tensor of shape (b,).
+ """
+ expanded_mask = mask[(..., ) + (None, ) * (loss.ndim - mask.ndim)]
+ expanded_mask = expanded_mask.expand_as(loss)
+ masked_loss = loss * expanded_mask
+
+ sum_dims = tuple(range(1, loss.ndim))
+ loss_sum = masked_loss.sum(dim=sum_dims)
+ mask_sum = expanded_mask.sum(dim=sum_dims)
+    # assumes every sample has at least one valid position (mask_sum > 0)
+    loss = loss_sum / mask_sum
+
+ if reduce:
+ return loss.mean()
+ else:
+ return loss
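+
+# Example: loss = [[1., 1., 4.], [2., 2., 2.]] with mask [[1, 1, 0], [1, 1, 1]]
+# gives per-sample means [1.0, 2.0] (the masked value 4. is ignored) and,
+# with reduce=True, the scalar 1.5.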
+
+
+def convert_pad_shape(pad_shape: list[list[int]]):
+    # flatten in reverse order to match `torch.nn.functional.pad`, which
+    # expects padding pairs starting from the last dimension
+    reversed_pairs = pad_shape[::-1]
+    return [item for sublist in reversed_pairs for item in sublist]
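+
+# Example: convert_pad_shape([[0, 0], [1, 0], [0, 0]]) -> [0, 0, 1, 0, 0, 0],
+# i.e. F.pad's flat format (last dimension first): no padding on the first
+# and last dimensions, one leading zero on the middle dimension.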
+
+
+def create_alignment_path(duration: torch.Tensor, mask: torch.Tensor):
+    b, t_x, t_y = mask.shape
+ cum_duration = torch.cumsum(duration, 1)
+
+ cum_duration_flat = cum_duration.view(b * t_x)
+ path = create_mask_from_length(cum_duration_flat, t_y).float()
+ path = path.view(b, t_x, t_y)
+ # take the diff on the `t_x` axis
+ path = path - torch.nn.functional.pad(
+ path, convert_pad_shape([[0, 0], [1, 0], [0, 0]])
+ )[:, :-1]
+ path = path * mask
+ return path
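+
+# Example: duration [[2, 1]] with t_y = 3 yields the monotonic alignment
+#   [[1, 1, 0],
+#    [0, 0, 1]]
+# i.e. token 0 covers frames 0-1 and token 1 covers frame 2.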
+
+
+def trim_or_pad_length(x: torch.Tensor, target_length: int, length_dim: int):
+ """
+ Adjusts the size of the specified dimension of tensor x to match `target_length`.
+
+ Args:
+ x:
+ Input tensor.
+ target_length:
+ Desired size of the specified dimension.
+ length_dim:
+ The dimension to modify.
+
+ Returns:
+ torch.Tensor: The adjusted tensor.
+ """
+ current_length = x.shape[length_dim]
+
+ if current_length > target_length:
+ # Truncate the tensor
+ slices = [slice(None)] * x.ndim
+ slices[length_dim] = slice(0, target_length)
+ return x[tuple(slices)]
+
+ elif current_length < target_length:
+ # Pad the tensor with zeros
+ pad_shape = list(x.shape)
+ pad_length = target_length - current_length
+
+        pad_shape[length_dim] = pad_length  # shape of the trailing padding block
+ padding = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
+
+ return torch.cat([x, padding], dim=length_dim)
+
+ return x
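+
+# Example: for x of shape (2, 5), trim_or_pad_length(x, 3, 1) returns x[:, :3],
+# while trim_or_pad_length(x, 8, 1) right-pads with zeros to shape (2, 8).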
+
+
+def concat_non_padding(
+ seq1: torch.Tensor, mask1: torch.BoolTensor, seq2: torch.Tensor,
+ mask2: torch.BoolTensor
+):
+ """
+ Args
+ seq1 : Tensor (B, L1, E)
+ First sequence.
+ mask1 : BoolTensor (B, L1)
+ True for valid tokens in seq1, False for padding.
+ seq2 : Tensor (B, L2, E)
+ Second sequence.
+ mask2 : BoolTensor (B, L2)
+ True for valid tokens in seq2, False for padding.
+
+ Returns
+ concat_seq : Tensor (B, L1+L2, E)
+ Both sequences concatenated; valid tokens are left-aligned,
+ padding on the right is 0.
+ concat_mask: BoolTensor (B, L1+L2)
+ Mask for the concatenated sequence.
+        perm : LongTensor (B, L1+L2)
+            Permutation used for the reordering: `perm[b, new_idx] = old_idx`.
+            Needed for restoring the original sequences.
+ """
+ mask1, mask2 = mask1.bool(), mask2.bool()
+ B, L1, E = seq1.shape
+ L2 = seq2.size(1)
+ L = L1 + L2
+
+ seq_cat = torch.cat([seq1, seq2], dim=1) # (B, L, E)
+ mask_cat = torch.cat([mask1, mask2], dim=1) # (B, L)
+
+ # ----- Key step: stable sort so that all valid tokens move to the left -----
+    # Padding positions get +L, so they receive the largest scores and are
+    # sorted to the end.
+ positions = torch.arange(L, device=seq_cat.device).unsqueeze(0) # (1, L)
+ sort_score = positions + (~mask_cat) * L
+ perm = sort_score.argsort(dim=1, stable=True) # (B, L)
+
+ # Build concatenated sequence & mask
+ gather_idx = perm.unsqueeze(-1).expand(-1, -1, E) # (B, L, E)
+ concat_seq = seq_cat.gather(1, gather_idx)
+ concat_mask = mask_cat.gather(1, perm)
+
+ # Explicitly zero out the right-hand padding region for safety
+ concat_seq = concat_seq * concat_mask.unsqueeze(-1)
+
+ return concat_seq, concat_mask, perm
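+
+# Example (hedged, E = 1): seq1 = [[a, PAD]], mask1 = [[1, 0]],
+# seq2 = [[b, c, PAD]], mask2 = [[1, 1, 0]] produce
+#   concat_seq  = [[a, b, c, 0, 0]]
+#   concat_mask = [[1, 1, 1, 0, 0]]
+#   perm        = [[0, 2, 3, 1, 4]]  (perm[new_idx] = old_idx)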
+
+
+def restore_from_concat(
+ concat_seq: torch.Tensor, mask1: torch.BoolTensor, mask2: torch.BoolTensor,
+ perm: torch.LongTensor
+):
+ """
+ Restore (seq1, seq2) from the concatenated sequence produced by
+ `concat_non_padding`, using the returned permutation `perm`.
+ Fully vectorised — no Python loops.
+ """
+ mask1, mask2 = mask1.bool(), mask2.bool()
+ B, L1 = mask1.shape
+ L2 = mask2.size(1)
+ E = concat_seq.size(-1)
+
+    # Inverse permutation: `inv_perm[b, old_idx] = new_idx`
+ inv_perm = torch.empty_like(perm)
+ inv_perm.scatter_(
+ 1, perm,
+ torch.arange(L1 + L2, device=perm.device).unsqueeze(0).expand(B, -1)
+ )
+
+ # Bring tokens back to their original order
+ gather_idx = inv_perm.unsqueeze(-1).expand(-1, -1, E)
+ seq_cat_rec = concat_seq.gather(1, gather_idx) # (B, L1+L2, E)
+
+ # Split back into the two sequences and mask out padding positions
+ seq1_restore, seq2_restore = seq_cat_rec.split([L1, L2], dim=1)
+ seq1_restore = seq1_restore * mask1.unsqueeze(-1)
+ seq2_restore = seq2_restore * mask2.unsqueeze(-1)
+
+ return seq1_restore, seq2_restore
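+
+# Round-trip property: for any (seq1, mask1) and (seq2, mask2),
+#   cat, cat_mask, perm = concat_non_padding(seq1, mask1, seq2, mask2)
+#   r1, r2 = restore_from_concat(cat, mask1, mask2, perm)
+# recovers seq1 and seq2 with their padding positions zeroed out.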
+
+
+def contains_nan(data):
+ """check if data contains NaN"""
+ if isinstance(data, torch.Tensor):
+ return torch.isnan(data).any().item()
+ elif isinstance(data, np.ndarray):
+ return np.isnan(data).any()
+ elif isinstance(data, float):
+ return math.isnan(data)
+ elif isinstance(data, (list, tuple)):
+ return any(contains_nan(x) for x in data)
+ elif isinstance(data, dict):
+ return any(contains_nan(v) for v in data.values())
+ return False
+
+
+def check_nan_in_batch(batch):
+    """Check whether a batch contains NaN values; return the affected audio ids."""
+    assert isinstance(batch, dict), "batch type error"
+    nan_audio_ids = []
+    audio_ids = batch["audio_id"]
+    audio_id2content = {}
+    for idx, audio_id in enumerate(audio_ids):
+        content = []
+        for k, v in batch.items():
+            if k == "audio_id":
+                continue
+            content.append(v[idx])
+        audio_id2content[audio_id] = content
+
+    for audio_id, content in audio_id2content.items():
+        if contains_nan(content):
+            nan_audio_ids.append(audio_id)
+            logger.warning(f"{audio_id} contains NaN")
+    return nan_audio_ids
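+
+# Usage sketch (hedged; "waveform" and "length" are assumed field names):
+#   batch = {"audio_id": ["a", "b"], "waveform": waveforms, "length": lengths}
+#   bad_ids = check_nan_in_batch(batch)
+# The returned ids can be used to drop corrupted samples before the forward pass.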
+
diff --git a/utils/video.py b/utils/video.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c02a1ad1e01f00e74c39e1a60f7af533ed8b568
--- /dev/null
+++ b/utils/video.py
@@ -0,0 +1,44 @@
+from pathlib import Path
+import os
+from moviepy import VideoFileClip, AudioFileClip
+
+
+def merge_audio_video(
+ audio_path: str | Path,
+ video_path: str | Path,
+ target_path: str | Path,
+ backend: str = "moviepy",
+ logging: bool = False
+):
+ """
+ Merge audio and video into a single file.
+
+ Args:
+ audio_path (str | Path): Path to the audio file.
+ video_path (str | Path): Path to the video file.
+ target_path (str | Path): Path to the target file.
+        backend (str, optional): The backend to use for merging. Defaults to "moviepy".
+        logging (bool, optional): Whether to display merging progress. Defaults to False.
+ """
+ assert backend in [
+ "moviepy", "ffmpeg"
+ ], "Backend should be moviepy or ffmpeg"
+ if backend == "moviepy":
+        video = VideoFileClip(str(video_path))
+        audio = AudioFileClip(str(audio_path))
+
+ video = video.with_audio(audio)
+
+        video.write_videofile(
+            str(target_path),
+ logger=None if not logging else "bar",
+ threads=8,
+ preset="ultrafast",
+ ffmpeg_params=["-crf", "23"]
+ )
+ else:
+ logging_arg = "" if logging else "-loglevel quiet"
+        command = (
+            f"ffmpeg {logging_arg} -i '{video_path}' -i '{audio_path}' "
+            f"-c:v copy -c:a copy -map 0:v:0 -map 1:a:0 '{target_path}'"
+        )
+        os.system(command)
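+
+# Usage sketch (hedged; file names are hypothetical):
+#   merge_audio_video("output_audio.wav", "input.mp4", "merged.mp4",
+#                     backend="ffmpeg")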
diff --git a/wav/audio.wav b/wav/audio.wav
new file mode 100644
index 0000000000000000000000000000000000000000..a5d1b3c0a3e64ca66bc88605219bb17402867b56
--- /dev/null
+++ b/wav/audio.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e947b99300cc46dde56a1367958802d7736c2e3efb640c44c58d50122e41c2a8
+size 480044
diff --git a/wav/speech.wav b/wav/speech.wav
new file mode 100644
index 0000000000000000000000000000000000000000..933c2776795b9369347efb1103fd9764e6b7b861
--- /dev/null
+++ b/wav/speech.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d23cb7106c66ad61d2b9717daea77385883cf71772836a8c5d18b9496dbb8d5
+size 130604