import math
from dataclasses import dataclass

import dac
import torch
import torch.nn as nn
from Amphion.models.codec.ns3_codec import FACodecDecoder, FACodecEncoder
from einops import rearrange, repeat
from huggingface_hub import hf_hub_download
from simple_parsing import Serializable
from torch.nn import CrossEntropyLoss


@dataclass
class ModelArgs(Serializable):
    """Hyper-parameters for the multi-channel target-speaker extraction model."""

    rnn_num_repeat: int = 3  # number of stacked FusionRNN extractor layers
    num_layers: int = 1  # LSTM layers inside each ResRNN
    num_channels: int = 6  # microphone channels in the input mixture
    feat_dim: int = 512  # hidden feature size per channel
    sr: int = 16000  # sample rate of all waveforms
    dropout: float = 0.0
    cb_size: int = 1024  # codebook (vocabulary) size of the DAC codec
    num_codebooks: int = 12  # number of residual DAC codebooks to predict


class ResRNN(nn.Module):
    """GroupNorm -> (Bi)LSTM -> Linear projection, with an optional residual add.

    Operates on channel-first sequences of shape ``[batch, features, time]``.
    """

    def __init__(self, input_size, hidden_size, output_size=None, num_layers=1,
                 bidirectional=True, use_residual=True):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=1, num_channels=input_size)
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers,
                           batch_first=True, bidirectional=bidirectional)
        if output_size is None:
            output_size = input_size
        # A bidirectional LSTM doubles the feature dimension of its output.
        self.proj = (nn.Linear(hidden_size * 2, output_size)
                     if bidirectional else nn.Linear(hidden_size, output_size))
        self.use_residual = use_residual

    def forward(self, input):
        """Run the normalized sequence through the LSTM.

        Args:
            input: tensor of shape ``[b, f, t]``.

        Returns:
            Tensor of shape ``[b, output_size, t]``; equals ``input + rnn(input)``
            when ``use_residual`` is set (requires ``output_size == input_size``).
        """
        o = self.norm(input)
        o = rearrange(o, "b f t -> b t f")  # LSTM expects batch-first [b, t, f]
        o, _ = self.rnn(o)
        o = self.proj(o)
        o = rearrange(o, "b t f -> b f t")
        if self.use_residual:
            return input + o
        return o


class FusionRNN(nn.Module):
    """One extractor layer: fuse mixture features with the speaker clue,
    then model time with a ResRNN and refine with a 1x1-conv decoder.

    Intermediate layers keep the shape ``[b, num_channels, feat_dim, t]``;
    the last layer (``is_last_layer=True``) expands the channel axis to
    ``num_codebooks`` so each output channel predicts one DAC codebook.
    """

    def __init__(self, feat_dim, num_channels, num_layers=1, num_codebooks=12,
                 is_last_layer=False):
        super().__init__()
        # Projects concat(mix, enroll) [2h] back down to [h] per channel.
        self.mix_enroll_fusion_projector = nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=feat_dim * 2),
            nn.Conv1d(feat_dim * 2, feat_dim, kernel_size=1),
            nn.ReLU(),
            nn.GroupNorm(num_groups=1, num_channels=feat_dim),
            nn.Conv1d(feat_dim, feat_dim, kernel_size=1),
            nn.ReLU(),
        )

        # Temporal path
        if is_last_layer:
            # Last layer maps c microphone channels -> n codebook channels,
            # so the ResRNN cannot be residual (sizes differ).
            self.sequence_model = ResRNN(
                input_size=feat_dim * num_channels,
                hidden_size=feat_dim * num_codebooks,
                output_size=feat_dim * num_codebooks,
                num_layers=num_layers,
                bidirectional=True,
                use_residual=False,
            )
            self.decoder = nn.Sequential(
                nn.GroupNorm(num_groups=1, num_channels=feat_dim * num_codebooks),
                nn.Conv1d(feat_dim * num_codebooks, feat_dim * num_codebooks, kernel_size=1),
                nn.ReLU(),
                nn.GroupNorm(num_groups=1, num_channels=feat_dim * num_codebooks),
                nn.Conv1d(feat_dim * num_codebooks, feat_dim * num_codebooks, kernel_size=1),
                nn.ReLU(),
            )
            self.output_hidden_dim = num_codebooks
        else:
            self.sequence_model = ResRNN(
                input_size=feat_dim * num_channels,
                hidden_size=feat_dim * num_channels,
                output_size=feat_dim * num_channels,
                num_layers=num_layers,
                bidirectional=True,
            )
            self.decoder = nn.Sequential(
                nn.GroupNorm(num_groups=1, num_channels=feat_dim * num_channels),
                nn.Conv1d(feat_dim * num_channels, feat_dim * num_channels, kernel_size=1),
                nn.ReLU(),
                nn.GroupNorm(num_groups=1, num_channels=feat_dim * num_channels),
                nn.Conv1d(feat_dim * num_channels, feat_dim * num_channels, kernel_size=1),
                nn.ReLU(),
            )
            self.output_hidden_dim = num_channels

    def forward(self, input, enroll_feat, current_layer=0):
        """Fuse the speaker clue into each channel and model the time axis.

        Args:
            input: mixture features ``[b, c, h, t]``.
            enroll_feat: broadcast speaker embedding ``[b, c, h, t]``.
            current_layer: layer index (kept for interface compatibility;
                currently unused in the computation).

        Returns:
            Tensor of shape ``[b, output_hidden_dim, h, t]``.
        """
        batch_size, num_channels, _, num_frames = input.shape

        # Concatenate the enrollment feature onto the hidden axis, then
        # project back to feat_dim independently per (batch, channel).
        input = torch.cat([input, enroll_feat], dim=-2)  # [b, c, h * 2, t]
        input = rearrange(input, "b c hx2 t -> (b c) hx2 t")
        input = self.mix_enroll_fusion_projector(input)  # [b * c, h, t]
        input = rearrange(input, "(b c) h t -> b c h t", b=batch_size, c=num_channels)

        # Temporal path: fold channels into the feature axis for the RNN.
        input = rearrange(input, "b c h t -> b (c h) t")
        input = self.sequence_model(input)  # [b, c * h, t]

        # Decoder
        input = self.decoder(input)  # [b, (c * h), t]
        input = rearrange(input, "b (c h) t -> b c h t", c=self.output_hidden_dim)

        return input  # [b, c, h, t]


class Model(nn.Module):
    """Predicts the DAC codec tokens of the target speaker from a
    multi-channel mixture plus a mono enrollment (speaker clue) waveform."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        # Mixture encoder (DAC codec) — frozen feature extractor.
        dac_codec_ckpt_path = dac.utils.download(model_type="16khz")
        self.dac_codec = dac.DAC.load(dac_codec_ckpt_path)
        self.dac_codec.eval()
        for param in self.dac_codec.parameters():
            param.requires_grad = False

        # Projects the DAC continuous latent (1024-dim for the 16 kHz model)
        # down to the working feature size.
        self.mix_encoder_projector = nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=args.cb_size),
            nn.Conv1d(args.cb_size, self.args.feat_dim, kernel_size=1),
            nn.ReLU(),
            nn.GroupNorm(num_groups=1, num_channels=self.args.feat_dim),
            nn.Conv1d(self.args.feat_dim, self.args.feat_dim, kernel_size=1),
            nn.ReLU(),
        )

        # Clue encoder — same latent size as the mixture encoder.
        self.clue_encoder_projector = nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=1024),
            nn.Conv1d(1024, args.feat_dim, kernel_size=1),
            nn.ReLU(),
            nn.GroupNorm(num_groups=1, num_channels=args.feat_dim),
            nn.Conv1d(args.feat_dim, args.feat_dim, kernel_size=1),
            nn.ReLU(),
        )

        # Target extractor: a stack of FusionRNNs; only the final one widens
        # the channel axis from num_channels to num_codebooks.
        self.extractors = nn.ModuleList([])
        for i in range(args.rnn_num_repeat):
            self.extractors.append(
                FusionRNN(
                    feat_dim=args.feat_dim,
                    num_channels=args.num_channels,
                    num_layers=args.num_layers,
                    num_codebooks=args.num_codebooks,
                    is_last_layer=(i == args.rnn_num_repeat - 1),
                )
            )

        # Predictor: shared classification head over the codebook vocabulary.
        self.lm_head = nn.Linear(args.feat_dim, args.cb_size, bias=False)

    def _load_enroll_encoder(self):
        """Download and return the frozen FACodec (encoder, decoder) pair."""
        fa_encoder = FACodecEncoder(ngf=32, up_ratios=[2, 4, 5, 5], out_channels=256)
        fa_encoder.load_state_dict(
            torch.load(
                hf_hub_download(repo_id="amphion/naturalspeech3_facodec",
                                filename="ns3_facodec_encoder.bin"),
                map_location="cpu",  # checkpoint may have been saved on GPU
            )
        )
        fa_encoder.eval()
        for param in fa_encoder.parameters():
            param.requires_grad = False

        fa_decoder = FACodecDecoder(
            in_channels=256,
            upsample_initial_channel=1024,
            ngf=32,
            up_ratios=[5, 5, 4, 2],
            vq_num_q_c=2,
            vq_num_q_p=1,
            vq_num_q_r=3,
            vq_dim=256,
            codebook_dim=8,
            codebook_size_prosody=10,
            codebook_size_content=10,
            codebook_size_residual=10,
            use_gr_x_timbre=True,
            use_gr_residual_f0=True,
            use_gr_residual_phone=True,
        )
        fa_decoder.load_state_dict(
            torch.load(
                hf_hub_download(repo_id="amphion/naturalspeech3_facodec",
                                filename="ns3_facodec_decoder.bin"),
                map_location="cpu",
            )
        )
        fa_decoder.eval()
        for param in fa_decoder.parameters():
            param.requires_grad = False

        return fa_encoder, fa_decoder

    def pad(self, audio_data):
        """Right-pad audio to a whole number of DAC frames.

        Adopted from DAC's `preprocess` method."""
        length = audio_data.shape[-1]
        right_pad = (math.ceil(length / self.dac_codec.hop_length) + 1) * self.dac_codec.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))
        return audio_data

    def forward(self, mix_y: torch.Tensor, enroll_y: torch.Tensor, clean_y: torch.Tensor = None):
        """
        Args:
            mix_y (`torch.Tensor` of shape `(batch_size, num_channels, num_samples)`):
                The multi-channel mixture waveform.
            enroll_y (`torch.Tensor` of shape `(batch_size, num_samples)`):
                The mono-channel enrollment waveform.
            clean_y: (`torch.Tensor` of shape `(batch_size, num_samples)`):
                The reference-channel clean waveform (training target).

        Returns:
            `logits` of shape `[b, cb_size, num_codebooks, t]` when `clean_y`
            is None, otherwise the tuple `(logits, loss)` where `loss` is a
            scalar cross-entropy over the predicted codebook tokens.

        Note:
            `n`: number of codebooks
            `c`: number of microphone channels
            `h`: hidden size
        """
        batch_size, num_channels, num_samples = mix_y.shape
        mix_y = self.pad(mix_y)

        # Mixture encoder: run each microphone channel through DAC separately.
        mix_y = rearrange(mix_y, "b c t -> (b c) 1 t")  # [b * c, 1, t]
        # codes = [b * c, n, t] with n=12; z = [b * c, 1024, t]
        z, codes, latents, _, _ = self.dac_codec.encode(mix_y)
        *_, num_frames = z.shape
        mix_feat = self.mix_encoder_projector(z)  # [b * c, h, t]
        mix_feat = rearrange(mix_feat, "(b c) h t -> b c h t", b=batch_size)  # [b, c, h, t]

        # Clue encoder: time-average the enrollment features into one
        # speaker embedding, then broadcast it over channels and frames.
        clue_z, *_ = self.dac_codec.encode(rearrange(enroll_y, "b t -> b 1 t"))  # [b, 1024, t]
        spk_emb = self.clue_encoder_projector(clue_z)  # [b, h, t]
        spk_emb = torch.mean(spk_emb, dim=-1)  # [b, h]
        spk_emb = repeat(spk_emb, "b h -> b c h t", c=num_channels, t=num_frames)  # [b, c, h, t]

        # Target extractor: chain the layers, feeding each layer's output to
        # the next. FIX(review): the original passed the raw mix_feat to every
        # extractor and kept only the last call's result, which left all but
        # the final FusionRNN layer unused.
        feat = mix_feat
        for layer_idx, extractor in enumerate(self.extractors):
            feat = extractor(feat, spk_emb, layer_idx)
        sub_band_feat = feat  # [b, n, h, t] after the last layer

        # LM head: classify each (codebook, frame) position over the vocabulary.
        sub_band_feat = rearrange(sub_band_feat, "b n h t -> b n t h")
        logits = self.lm_head(sub_band_feat)
        logits = rearrange(logits, "b n t cb -> b cb n t")  # [b, 1024, 12, t]

        # FIX(review): the original `return logits if clean_y is None else
        # logits, loss` parsed as `return (logits, loss)` in *both* cases due
        # to operator precedence; the branches are now explicit.
        if clean_y is None:
            return logits

        # Training loss: cross-entropy against the clean signal's DAC codes.
        clean_y = rearrange(clean_y, "b t -> b 1 t")
        clean_y = self.pad(clean_y)
        clean_z, clean_codes, _, _, _ = self.dac_codec.encode(clean_y)  # codes: [b, 12, t]
        loss_fct = CrossEntropyLoss()
        # CrossEntropyLoss consumes class dim 1: logits [b, cb, n, t] vs
        # targets [b, n, t]; default 'mean' reduction yields a scalar loss.
        loss = loss_fct(logits, clean_codes.long())

        return logits, loss

    def decode(self, logits: torch.Tensor, num_samples: int):
        """Greedy-decode predicted logits back to a waveform.

        Args:
            logits: `[b, cb_size, num_codebooks, t]` token logits.
            num_samples: length to trim the synthesized waveform to
                (removes the padding added by `pad`).

        Returns:
            Enhanced waveform of shape `[b, num_samples]`.
        """
        logits = rearrange(logits, "b fc n t -> b n t fc")
        codes = logits.argmax(dim=-1)  # [b, n, t] token indices

        # Re-quantize the codes and run the DAC decoder.
        z = self.dac_codec.quantizer.from_codes(codes)[0]  # [b, 1024, t]
        enh_y = self.dac_codec.decode(z)
        enh_y = rearrange(enh_y, "b () t -> b t")
        enh_y = enh_y[:, :num_samples]
        return enh_y


if __name__ == "__main__":
    model = Model(ModelArgs())
    mixture = torch.rand(2, 6, 16000 * 4)
    clean = torch.rand(2, 16000 * 4)
    enroll_y = torch.rand(2, 16000)
    output = model(mixture, enroll_y, clean)
    print(output[0].shape)
    print(output[1])
    out = model.decode(output[0], 16000)
    print(out.shape)