| | import math |
| | from dataclasses import dataclass |
| |
|
| | import dac |
| | import torch |
| | import torch.nn as nn |
| | from Amphion.models.codec.ns3_codec import FACodecDecoder, FACodecEncoder |
| | from einops import rearrange, repeat |
| | from huggingface_hub import hf_hub_download |
| | from simple_parsing import Serializable |
| | from torch.nn import CrossEntropyLoss |
| |
|
| |
|
@dataclass
class ModelArgs(Serializable):
    """Hyper-parameters for `Model`, serializable via simple_parsing."""

    # Number of stacked FusionRNN extractor stages.
    rnn_num_repeat: int = 3
    # LSTM layers inside each ResRNN.
    num_layers: int = 1
    # Number of microphone channels in the input mixture.
    num_channels: int = 6
    # Internal feature dimension after the encoder projectors.
    feat_dim: int = 512
    # Audio sample rate in Hz (matches the 16 kHz DAC codec).
    sr: int = 16000
    # Dropout rate (currently unused by the visible modules — TODO confirm).
    dropout: float = 0.0
    # DAC codebook size; also the channel width of the codec latents fed to
    # mix_encoder_projector (both are 1024 for the 16 kHz DAC model).
    cb_size: int = 1024
    # Number of DAC codebooks predicted by the head.
    num_codebooks: int = 12
| |
|
| |
|
class ResRNN(nn.Module):
    """GroupNorm-prenormalized (bi)LSTM block with an optional residual add.

    Operates on tensors shaped ``(batch, features, time)``; the LSTM itself
    runs batch-first on ``(batch, time, features)``.
    """

    def __init__(self, input_size, hidden_size, output_size=None, num_layers=1, bidirectional=True, use_residual=True):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=1, num_channels=input_size)
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)

        # Default to a shape-preserving projection so the residual add is valid.
        if output_size is None:
            output_size = input_size

        # A bidirectional LSTM emits both directions concatenated.
        rnn_out_dim = hidden_size * 2 if bidirectional else hidden_size
        self.proj = nn.Linear(rnn_out_dim, output_size)

        self.use_residual = use_residual

    def forward(self, input):
        # (B, F, T) -> normalize over features -> (B, T, F) for batch-first LSTM.
        hidden = self.norm(input)
        hidden = hidden.transpose(1, 2)
        hidden, _ = self.rnn(hidden)
        hidden = self.proj(hidden)
        # Back to (B, F, T).
        hidden = hidden.transpose(1, 2)

        return input + hidden if self.use_residual else hidden
| |
|
| |
|
class FusionRNN(nn.Module):
    """One extraction stage: per-channel fusion of mixture and enrollment
    features followed by a residual BLSTM over the channel-folded sequence.

    Intermediate stages keep the channel axis at ``num_channels``; the final
    stage (``is_last_layer=True``) expands it to ``num_codebooks`` and drops
    the residual connection.
    """

    def __init__(self, feat_dim, num_channels, num_layers=1, num_codebooks=12, is_last_layer=False):
        super().__init__()
        # Fuses [mix_feat; enroll_feat] (2*feat_dim channels) down to feat_dim.
        self.mix_enroll_fusion_projector = self._pointwise_block(feat_dim * 2, feat_dim)

        if is_last_layer:
            # Final stage maps (channels * feat) -> (codebooks * feat); no
            # residual since input/output widths differ.
            out_width = feat_dim * num_codebooks
            self.sequence_model = ResRNN(
                input_size=feat_dim * num_channels,
                hidden_size=out_width,
                output_size=out_width,
                num_layers=num_layers,
                bidirectional=True,
                use_residual=False,
            )
            self.decoder = self._pointwise_block(out_width, out_width)
            self.output_hidden_dim = num_codebooks
        else:
            width = feat_dim * num_channels
            self.sequence_model = ResRNN(
                input_size=width,
                hidden_size=width,
                output_size=width,
                num_layers=num_layers,
                bidirectional=True,
            )
            self.decoder = self._pointwise_block(width, width)
            self.output_hidden_dim = num_channels

    @staticmethod
    def _pointwise_block(in_channels, out_channels):
        """Two GroupNorm + 1x1-Conv + ReLU layers (channel-wise MLP)."""
        return nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=in_channels),
            nn.Conv1d(in_channels, out_channels, kernel_size=1),
            nn.ReLU(),
            nn.GroupNorm(num_groups=1, num_channels=out_channels),
            nn.Conv1d(out_channels, out_channels, kernel_size=1),
            nn.ReLU(),
        )

    def forward(self, input, enroll_feat, current_layer=0):
        # input / enroll_feat: (batch, channels, feat, time).
        # `current_layer` is accepted for caller compatibility but unused here.
        batch_size, num_channels, _, num_frames = input.shape

        # Concatenate along the feature axis and fuse channel-wise.
        fused = torch.cat([input, enroll_feat], dim=-2)
        fused = fused.reshape(batch_size * num_channels, -1, num_frames)
        fused = self.mix_enroll_fusion_projector(fused)

        # Fold channels into the feature axis: (b*c, h, t) -> (b, c*h, t).
        seq = fused.reshape(batch_size, -1, num_frames)
        seq = self.sequence_model(seq)

        out = self.decoder(seq)
        # Split the feature axis back into (channels_or_codebooks, feat).
        return out.reshape(batch_size, self.output_hidden_dim, -1, num_frames)
| |
|
| |
|
class Model(nn.Module):
    """Multi-channel target-speaker extraction in the DAC token domain.

    A frozen, pre-trained DAC codec tokenizes the audio. A stack of
    ``FusionRNN`` extractors fuses mixture features with enrollment-speaker
    features; a linear head predicts, per codebook and frame, a distribution
    over the DAC codebook entries of the clean reference signal.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        # Frozen DAC codec used purely as tokenizer/feature extractor; its
        # weights are never updated.
        dac_codec_ckpt_path = dac.utils.download(model_type="16khz")
        self.dac_codec = dac.DAC.load(dac_codec_ckpt_path)
        self.dac_codec.eval()
        for param in self.dac_codec.parameters():
            param.requires_grad = False

        # Projects codec latents (cb_size channels) down to feat_dim.
        self.mix_encoder_projector = nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=args.cb_size),
            nn.Conv1d(args.cb_size, self.args.feat_dim, kernel_size=1),
            nn.ReLU(),
            nn.GroupNorm(num_groups=1, num_channels=self.args.feat_dim),
            nn.Conv1d(self.args.feat_dim, self.args.feat_dim, kernel_size=1),
            nn.ReLU(),
        )

        # Same projection for the enrollment ("clue") latents.
        # NOTE(review): 1024 is hard-coded here while the mixture path uses
        # args.cb_size (also 1024 by default) — presumably both are the DAC
        # latent width; confirm before changing cb_size.
        self.clue_encoder_projector = nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=1024),
            nn.Conv1d(1024, args.feat_dim, kernel_size=1),
            nn.ReLU(),
            nn.GroupNorm(num_groups=1, num_channels=args.feat_dim),
            nn.Conv1d(args.feat_dim, args.feat_dim, kernel_size=1),
            nn.ReLU(),
        )

        # Stacked extractor stages; only the last one switches the channel
        # axis from microphones to codebooks.
        self.extractors = nn.ModuleList([])
        for i in range(args.rnn_num_repeat):
            self.extractors.append(
                FusionRNN(
                    feat_dim=args.feat_dim,
                    num_channels=args.num_channels,
                    num_layers=args.num_layers,
                    num_codebooks=args.num_codebooks,
                    is_last_layer=(i == args.rnn_num_repeat - 1),
                )
            )

        # Shared classification head over codebook entries.
        self.lm_head = nn.Linear(args.feat_dim, args.cb_size, bias=False)

    def _load_enroll_encoder(self):
        """Load the frozen FACodec encoder/decoder pair from the Hub.

        NOTE(review): not called from any code visible in this file; kept for
        external callers.
        """
        fa_encoder = FACodecEncoder(ngf=32, up_ratios=[2, 4, 5, 5], out_channels=256)
        fa_encoder.load_state_dict(
            torch.load(hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder.bin"))
        )
        fa_encoder.eval()
        for param in fa_encoder.parameters():
            param.requires_grad = False

        fa_decoder = FACodecDecoder(
            in_channels=256,
            upsample_initial_channel=1024,
            ngf=32,
            up_ratios=[5, 5, 4, 2],
            vq_num_q_c=2,
            vq_num_q_p=1,
            vq_num_q_r=3,
            vq_dim=256,
            codebook_dim=8,
            codebook_size_prosody=10,
            codebook_size_content=10,
            codebook_size_residual=10,
            use_gr_x_timbre=True,
            use_gr_residual_f0=True,
            use_gr_residual_phone=True,
        )
        fa_decoder.load_state_dict(
            torch.load(hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder.bin"))
        )
        fa_decoder.eval()
        for param in fa_decoder.parameters():
            param.requires_grad = False

        return fa_encoder, fa_decoder

    def pad(self, audio_data):
        """Right-pad audio to a whole number of codec hops (plus one extra
        hop). Adopted from DAC's `preprocess` method."""
        length = audio_data.shape[-1]
        right_pad = (math.ceil(length / self.dac_codec.hop_length) + 1) * self.dac_codec.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))
        return audio_data

    def forward(self, mix_y: torch.Tensor, enroll_y: torch.Tensor, clean_y: torch.Tensor = None):
        """
        Args:
            mix_y (`torch.Tensor` of shape `(batch_size, num_channels, num_samples)`):
                The multi-channel mixture waveform.
            enroll_y (`torch.Tensor` of shape `(batch_size, num_samples)`):
                The mono-channel enrollment waveform.
            clean_y: (`torch.Tensor` of shape `(batch_size, num_samples)`):
                The reference-channel clean waveform.

        Returns:
            Tuple of `(logits, loss)`. `logits` has shape
            `(batch_size, cb_size, num_codebooks, num_frames)`; `loss` is
            `None` when `clean_y` is not given.

        Note:
            `n`: number of codebooks
            `c`: number of microphone channels
            `h`: hidden size
        """
        batch_size, num_channels, num_samples = mix_y.shape
        mix_y = self.pad(mix_y)

        # Encode every microphone channel independently with the codec.
        mix_y = rearrange(mix_y, "b c t -> (b c) 1 t")
        z, *_ = self.dac_codec.encode(mix_y)
        *_, num_frames = z.shape
        mix_feat = self.mix_encoder_projector(z)
        mix_feat = rearrange(mix_feat, "(b c) h t -> b c h t", b=batch_size)

        # Time-averaged speaker embedding from the enrollment, broadcast to
        # every channel and frame.
        clue_z, *_ = self.dac_codec.encode(rearrange(enroll_y, "b t -> b 1 t"))
        spk_emb = self.clue_encoder_projector(clue_z)
        spk_emb = torch.mean(spk_emb, dim=-1)
        spk_emb = repeat(spk_emb, "b h -> b c h t", c=num_channels, t=num_frames)

        # Each stage re-fuses its output with the same speaker embedding.
        for layer_idx, extractor in enumerate(self.extractors):
            sub_band_feat = extractor(mix_feat, spk_emb, layer_idx)

        # Per-codebook logits over the DAC codebook entries.
        sub_band_feat = rearrange(sub_band_feat, "b n h t -> b n t h")
        logits = self.lm_head(sub_band_feat)
        logits = rearrange(logits, "b n t cb -> b cb n t")

        loss = None
        if clean_y is not None:
            clean_y = rearrange(clean_y, "b t -> b 1 t")
            clean_y = self.pad(clean_y)
            _, clean_codes, *_ = self.dac_codec.encode(clean_y)

            # CE expects logits (b, classes, ...) vs integer targets (b, ...).
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, clean_codes.long())

        # FIX: the original `return logits if clean_y is None else logits, loss`
        # parsed as `return (logits), loss` — the ternary was a no-op and the
        # method always returned the pair. Return it explicitly.
        return logits, loss

    def decode(self, logits: torch.Tensor, num_samples: int):
        """Greedy-decode token logits back to a waveform, trimmed to
        `num_samples` samples."""
        logits = rearrange(logits, "b fc n t -> b n t fc")
        codes = logits.argmax(dim=-1)

        # Reconstruct the latent from the predicted codes and vocode.
        z = self.dac_codec.quantizer.from_codes(codes)[0]
        enh_y = self.dac_codec.decode(z)

        enh_y = rearrange(enh_y, "b () t -> b t")
        # Drop the padding introduced by `pad` during encoding.
        enh_y = enh_y[:, :num_samples]

        return enh_y
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: 2 batch items, 6 mic channels, 4 s of 16 kHz audio.
    model = Model(ModelArgs())

    mixture = torch.rand(2, 6, 16000 * 4)
    clean = torch.rand(2, 16000 * 4)
    enroll_y = torch.rand(2, 16000)

    logits, loss = model(mixture, enroll_y, clean)
    print(logits.shape)
    print(loss)

    enhanced = model.decode(logits, 16000)
    print(enhanced.shape)
| |
|