| from dataclasses import dataclass |
| from functools import partial |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from Amphion.models.codec.ns3_codec import FACodecDecoder, FACodecEncoder |
| from einops import rearrange, repeat |
| from huggingface_hub import hf_hub_download |
| from simple_parsing import Serializable, list_field |
|
|
| from audiozen.acoustics.audio_feature import istft, stft |
|
|
|
|
# Per-subband widths (in STFT bins) for band-split processing, narrow at low
# frequencies and wide at high frequencies. Both lists sum to 257 bins, i.e.
# n_fft // 2 + 1 for n_fft = 512 — presumably per-sample-rate layouts for the
# same FFT size; confirm against the training configs.
BAND_WIDTH_16K = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 16, 16, 16, 16, 16, 16, 16, 16, 3]
BAND_WIDTH_32K = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8, 8, 8, 8, 8, 8, 8, 16, 16, 16, 16, 16, 16, 16, 16, 7]
|
|
|
|
@dataclass
class ModelArgs(Serializable):
    """Hyperparameters for `Model` (Serializable for CLI/config parsing)."""

    # Subband widths in STFT bins; defaults to the 16 kHz layout.
    # NOTE(review): the annotation says list[float] but the defaults are ints —
    # confirm the intended element type.
    band_width_list: list[float] = list_field(*BAND_WIDTH_16K)
    freq_feat_dim: int = 128  # per-subband embedding size
    freq_axis_hidden_size: int = 128  # hidden size of the frequency-axis RNN
    temp_axis_hidden_size: int = 128  # hidden size of the temporal-axis RNN
    triple_path_rnn_num_repeat: int = 3  # number of stacked TriplePathRNN blocks
    num_layers: int = 1  # LSTM layers inside each ResRNN
    num_channels: int = 6  # microphone channels in the mixture input
    sr: int = 16000  # sample rate in Hz
    # NOTE(review): `num_layer` duplicates `num_layers` and is not referenced
    # in this file — confirm it is unused elsewhere before removing.
    num_layer: int = 1
    dropout: float = 0.0
    # STFT analysis/synthesis settings.
    n_fft: int = 512
    hop_length: int = 128
    win_length: int = 512

    # ERB band merge: keep the lowest `erb_subband_1` bins untouched and
    # compress the remaining bins into `erb_subband_2` ERB-spaced bands.
    erb_subband_1: int = 24
    erb_subband_2: int = 16
    sfe_kernel_size: int = 3  # frequency context window of SubbandFeatureExtractor
    sfe_stride: int = 1
|
|
|
|
class SubbandFeatureExtractor(nn.Module):
    """Gather, for every frequency bin, the features of its neighboring bins.

    The spectrogram is unfolded along the frequency axis only, so each output
    position carries a window of `kernel_size` adjacent bins (zero-padded at
    the band edges) stacked into the channel dimension.
    """

    def __init__(self, kernel_size=3, stride=1):
        super().__init__()
        self.kernel_size = kernel_size
        # A 1-by-k window slides along frequency only; the padding keeps the
        # frequency length unchanged when stride == 1.
        pad = (kernel_size - 1) // 2
        self.unfold = nn.Unfold(kernel_size=(1, kernel_size), stride=(1, stride), padding=(0, pad))

    def forward(self, x):
        """Map [b, c, t, f] to [b, c * kernel_size, t, f]."""
        batch, channels, frames, freqs = x.shape
        windows = self.unfold(x)  # [b, c * kernel_size, t * f]
        return windows.reshape(batch, channels * self.kernel_size, frames, freqs)
|
|
|
|
class ERB(nn.Module):
    """Equivalent Rectangular Bandwidth (ERB) band merging and splitting.

    The lowest `erb_subband_1` STFT bins pass through untouched; the remaining
    high-frequency bins are compressed into `erb_subband_2` ERB-spaced bands
    by a fixed (non-trainable) triangular filter bank, and expanded back with
    the filter bank's transpose.
    """

    def __init__(self, erb_subband_1, erb_subband_2, n_fft=512, high_lim=8000, sr=16000):
        super().__init__()
        erb_filter_banks = self.build_erb_filter_banks(erb_subband_1, erb_subband_2, n_fft, high_lim, sr)
        num_freqs = n_fft // 2 + 1

        # Linear layers act as fixed matrix multiplies by the filter bank
        # (analysis) and its transpose (synthesis); weights are frozen below.
        self.erb_fc = nn.Linear(num_freqs - erb_subband_1, erb_subband_2, bias=False)
        self.inverse_erb_fc = nn.Linear(erb_subband_2, num_freqs - erb_subband_1, bias=False)

        self.erb_fc.weight = nn.Parameter(erb_filter_banks, requires_grad=False)
        self.inverse_erb_fc.weight = nn.Parameter(erb_filter_banks.T, requires_grad=False)

        self.erb_subband_1 = erb_subband_1

    def hz2erb(self, freq_hz):
        """Convert a frequency in Hz to this module's ERB scale.

        NOTE(review): uses 24.7 * log10(0.00437 f + 1); the classic
        Glasberg–Moore ERBS formula uses 21.4 — presumably an intentional
        variant, confirm against the reference implementation.
        """
        erb_f = 24.7 * np.log10(0.00437 * freq_hz + 1)
        return erb_f

    def erb2hz(self, erb_f):
        """Inverse of `hz2erb`: ERB-scale value back to Hz."""
        freq_hz = (10 ** (erb_f / 24.7) - 1) / 0.00437
        return freq_hz

    def build_erb_filter_banks(self, erb_subband_1, erb_subband_2, n_fft=512, high_lim=8000, sr=16000):
        """Build an [erb_subband_2, num_freqs - erb_subband_1] triangular
        filter bank whose band edges are equally spaced on the ERB scale
        between the `erb_subband_1`-th bin's frequency and `high_lim`."""
        low_lim = erb_subband_1 / n_fft * sr
        erb_low = self.hz2erb(low_lim)
        erb_high = self.hz2erb(high_lim)
        erb_points = np.linspace(erb_low, erb_high, erb_subband_2)

        # Band-edge FFT bin indices for each ERB point.
        bins = np.round(self.erb2hz(erb_points) / sr * n_fft).astype(np.int32)
        erb_filters = np.zeros([erb_subband_2, n_fft // 2 + 1], dtype=np.float32)

        # First band: descending ramp only (no left neighbor). The 1e-12
        # terms guard against division by zero when adjacent edges coincide.
        erb_filters[0, bins[0] : bins[1]] = (bins[1] - np.arange(bins[0], bins[1]) + 1e-12) / (
            bins[1] - bins[0] + 1e-12
        )
        # Interior bands: ascending ramp up to the center, descending after.
        for i in range(erb_subband_2 - 2):
            erb_filters[i + 1, bins[i] : bins[i + 1]] = (np.arange(bins[i], bins[i + 1]) - bins[i] + 1e-12) / (
                bins[i + 1] - bins[i] + 1e-12
            )
            erb_filters[i + 1, bins[i + 1] : bins[i + 2]] = (
                bins[i + 2] - np.arange(bins[i + 1], bins[i + 2]) + 1e-12
            ) / (bins[i + 2] - bins[i + 1] + 1e-12)

        # Last band: complement of the previous band so the pair sums to 1.
        erb_filters[-1, bins[-2] : bins[-1] + 1] = 1 - erb_filters[-2, bins[-2] : bins[-1] + 1]

        # Drop the low bins that bypass the filter bank entirely.
        erb_filters = erb_filters[:, erb_subband_1:]

        return torch.from_numpy(np.abs(erb_filters))

    def band_merge(self, x):
        """Compress the last (frequency) axis from full bins to
        erb_subband_1 + erb_subband_2 values: low bins pass through, high
        bins are projected onto the ERB bands."""
        x_low = x[..., : self.erb_subband_1]
        x_hight = self.erb_fc(x[..., self.erb_subband_1 :])
        return torch.cat([x_low, x_hight], dim=-1)

    def band_split(self, x_erb):
        """Inverse of `band_merge`: expand the ERB bands back to the full
        linear-frequency resolution via the transposed filter bank."""
        x_erb_low = x_erb[..., : self.erb_subband_1]
        x_erb_hight = self.inverse_erb_fc(x_erb[..., self.erb_subband_1 :])
        return torch.cat([x_erb_low, x_erb_hight], dim=-1)
|
|
|
class ResRNN(nn.Module):
    """LSTM over the last axis with a residual connection.

    Input and output share the [b, feature, steps] layout: the feature axis
    is group-normalized, the LSTM runs along `steps`, and a projection maps
    the hidden states back to `input_size` so the result can be added to the
    original input.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bidirectional=True):
        super().__init__()
        # Keep the construction order (norm, rnn, proj): parameter init draws
        # from the global RNG, so reordering would change initial weights.
        self.norm = nn.GroupNorm(num_groups=1, num_channels=input_size)
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
        rnn_out_size = hidden_size * 2 if bidirectional else hidden_size
        self.proj = nn.Linear(rnn_out_size, input_size)

    def forward(self, input):
        """Return input + proj(LSTM(norm(input))), preserving [b, f, t]."""
        normed = self.norm(input)
        seq_first = normed.transpose(1, 2)  # [b, t, f] for the batch_first LSTM
        rnn_out, _ = self.rnn(seq_first)
        projected = self.proj(rnn_out)
        return input + projected.transpose(1, 2)
|
|
|
|
class TriplePathRNN(nn.Module):
    """One path block: a temporal RNN over frames followed by a frequency RNN
    over subbands, with optional fusion of enrollment (speaker) features.

    NOTE(review): the hidden sizes are multiplied by 2 here on top of the
    bidirectional doubling inside ResRNN's projection — confirm this double
    widening is intentional.
    """

    def __init__(self, freq_feat_dim, freq_rnn_hidden_size, temp_rnn_hidden_size, num_layers=1):
        super().__init__()

        # RNN along the subband (frequency) axis.
        self.freq_rnn = ResRNN(
            input_size=freq_feat_dim,
            hidden_size=freq_rnn_hidden_size * 2,
            num_layers=num_layers,
            bidirectional=True,
        )

        # RNN along the time (frame) axis.
        self.temp_rnn = ResRNN(
            input_size=freq_feat_dim,
            hidden_size=temp_rnn_hidden_size * 2,
            num_layers=num_layers,
            bidirectional=True,
        )

        self.mix_enroll_fusion_layer = nn.Linear(freq_feat_dim * 2, freq_feat_dim)
        # Initialize the fusion weight to [0 | I]: with the inputs concatenated
        # as [mixture; enrollment] along the feature axis, the layer initially
        # outputs exactly the enrollment features (bias starts at zero) and
        # learns to blend in the mixture during training.
        self.mix_enroll_fusion_layer.weight = nn.Parameter(
            torch.cat([torch.zeros(freq_feat_dim, freq_feat_dim), torch.eye(freq_feat_dim)], -1)
        )
        self.mix_enroll_fusion_layer.bias = nn.Parameter(torch.zeros_like(self.mix_enroll_fusion_layer.bias))

    def forward(self, input, enroll_feat, current_layer=0):
        """Refine subband features.

        Args:
            input: [b, num_subbands, freq_feat_dim, num_frames].
            enroll_feat: enrollment features with matching [b, n, ff, t] shape
                (concatenated along the feature axis when fused).
            current_layer: index of this block in the stack; fusion only runs
                for selected indices.

        Returns:
            Tensor of the same shape as `input`.
        """
        batch_size, num_sub_bands, freq_feat_dim, num_frames = input.shape

        # Fuse enrollment features only in selected repeats.
        # NOTE(review): with the default of 3 repeats the layer indices are
        # 0..2, so index 3 never occurs — confirm whether [1, 2] (or a
        # 4-repeat stack) was intended.
        if current_layer in [2, 3]:
            # Concatenate mixture and enrollment along the feature axis, then
            # project back to freq_feat_dim per (batch, subband, frame).
            input = torch.cat([input, enroll_feat], dim=-2)
            input = rearrange(input, "b n ffx2 t -> (b n) t ffx2")
            input = self.mix_enroll_fusion_layer(input)
            input = rearrange(input, "(b n) t ff -> b n ff t", b=batch_size, n=num_sub_bands)

        # Temporal path: each subband is an independent sequence over frames.
        input = rearrange(input, "b n ff t -> (b n) ff t", b=batch_size, n=num_sub_bands)
        input = self.temp_rnn(input)

        # Frequency path: each frame is an independent sequence over subbands.
        input = rearrange(input, "(b n) ff t -> (b t) ff n", b=batch_size, n=num_sub_bands)
        input = self.freq_rnn(input)

        return rearrange(input, "(b t) ff n -> b n ff t", b=batch_size)
|
|
|
|
class Model(nn.Module):
    """Multi-channel target-speaker extraction network.

    The mixture STFT is compressed into ERB subbands, encoded, refined by a
    stack of TriplePathRNN blocks conditioned on a frozen-FACodec speaker
    embedding of the enrollment utterance, and decoded into a complex ratio
    mask applied to the reference-channel spectrum.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.stft = partial(stft, n_fft=args.n_fft, hop_length=args.hop_length, win_length=args.win_length)
        self.istft = partial(istft, n_fft=args.n_fft, hop_length=args.hop_length, win_length=args.win_length)

        ri_dim = 2  # real + imaginary parts of the complex mask

        # Project the 256-d FACodec speaker embedding to the subband feature size.
        self.ts_vad_proj = nn.Linear(256, args.freq_feat_dim)
        self.fa_encoder, self.fa_decoder = self._load_codec()

        # ERB band merge + per-bin frequency context window.
        self.erb = ERB(args.erb_subband_1, args.erb_subband_2, args.n_fft, sr=args.sr)
        self.sfe = SubbandFeatureExtractor(args.sfe_kernel_size, args.sfe_stride)
        # Encoder input channels: num_channels mics x (magnitude + real + imag)
        # x sfe context window. This must match the spectra concatenated in
        # `forward` (see the mix_input_spec construction there).
        self.mix_encoder = nn.Sequential(
            nn.GroupNorm(1, args.num_channels * (ri_dim + 1) * args.sfe_kernel_size),
            nn.Conv1d(args.num_channels * (ri_dim + 1) * args.sfe_kernel_size, args.freq_feat_dim, 1),
        )

        # Stack of speaker-conditioned triple-path blocks.
        self.extractors = nn.ModuleList([])
        for _ in range(args.triple_path_rnn_num_repeat):
            self.extractors.append(
                TriplePathRNN(
                    freq_feat_dim=args.freq_feat_dim,
                    freq_rnn_hidden_size=args.freq_axis_hidden_size,
                    temp_rnn_hidden_size=args.temp_axis_hidden_size,
                    num_layers=args.num_layers,
                )
            )

        # Decode subband features into ri_dim * 2 values per bin: a mask and
        # a sigmoid gate for each of the real/imaginary components.
        self.enh_decoder = nn.Sequential(
            nn.GroupNorm(1, args.freq_feat_dim),
            nn.Conv1d(args.freq_feat_dim, args.freq_feat_dim * 2, 1),
            nn.Tanh(),
            nn.Conv1d(args.freq_feat_dim * 2, args.freq_feat_dim * 2, 1),
            nn.Tanh(),
            nn.Conv1d(args.freq_feat_dim * 2, ri_dim * 2, 1),
        )

    def _load_codec(self):
        """Download and load the pretrained FACodec encoder/decoder, frozen
        (eval mode, requires_grad=False) — they only provide the speaker
        embedding and are never trained here.
        """
        # NOTE(review): torch.load without weights_only — fine for trusted
        # HF checkpoints, but consider weights_only=True where supported.
        fa_encoder = FACodecEncoder(ngf=32, up_ratios=[2, 4, 5, 5], out_channels=256)
        fa_encoder.load_state_dict(
            torch.load(hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder.bin"))
        )
        fa_encoder.eval()
        for param in fa_encoder.parameters():
            param.requires_grad = False

        fa_decoder = FACodecDecoder(
            in_channels=256,
            upsample_initial_channel=1024,
            ngf=32,
            up_ratios=[5, 5, 4, 2],
            vq_num_q_c=2,
            vq_num_q_p=1,
            vq_num_q_r=3,
            vq_dim=256,
            codebook_dim=8,
            codebook_size_prosody=10,
            codebook_size_content=10,
            codebook_size_residual=10,
            use_gr_x_timbre=True,
            use_gr_residual_f0=True,
            use_gr_residual_phone=True,
        )
        fa_decoder.load_state_dict(
            torch.load(hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder.bin"))
        )
        fa_decoder.eval()
        for param in fa_decoder.parameters():
            param.requires_grad = False

        return fa_encoder, fa_decoder

    def forward(self, mix_y: torch.Tensor, enroll_y: torch.Tensor):
        """Extract the enrolled speaker from a multi-channel mixture.

        Args:
            mix_y: mixture waveform, [batch, num_channels, num_samples].
            enroll_y: mono enrollment waveform of the target speaker,
                [batch, num_enroll_samples].

        Returns:
            Enhanced mono waveform, [batch, num_samples].
        """
        batch_size, num_channels, num_samples = mix_y.shape
        ref_ch = 0  # reference microphone the mask is applied to

        # Normalize by the per-utterance std; the scale is restored at the end.
        # Clamp so an all-zero (silent) mixture cannot produce NaNs.
        std = torch.std(mix_y, dim=(1, 2), keepdim=True).clamp(min=1e-10)
        mix_y = mix_y / std

        mix_mag_spec, mix_phase_spec, mix_real_spec, mix_imag_spec = self.stft(mix_y)
        *_, num_freqs, num_frames = mix_mag_spec.shape

        # Frozen FACodec encoder/decoder produce a speaker embedding from the
        # enrollment; project it and broadcast over subbands and frames.
        enc_out = self.fa_encoder(rearrange(enroll_y, "b t -> b 1 t"))
        *_, spk_embs = self.fa_decoder(enc_out, eval_vq=False, vq=True)
        ts_vad_emb = self.ts_vad_proj(spk_embs)
        ts_vad_emb = repeat(
            ts_vad_emb, "b e -> b n e t", t=num_frames, n=self.args.erb_subband_1 + self.args.erb_subband_2
        )

        mix_cc_spec = torch.complex(mix_real_spec, mix_imag_spec)
        mix_cc_mono_spec = mix_cc_spec[:, ref_ch]
        # BUGFIX: the third tensor was `mix_mag_spec` again, duplicating the
        # magnitude and dropping the imaginary part. `mix_encoder` expects
        # (magnitude, real, imag) per channel — ri_dim + 1 = 3 in __init__.
        mix_input_spec = torch.cat([mix_mag_spec, mix_real_spec, mix_imag_spec], dim=1)
        mix_input_spec = rearrange(mix_input_spec, "b c f t -> b c t f")

        # ERB band merge -> frequency context window -> 1x1-conv encoding.
        mix_erb_spec = self.erb.band_merge(mix_input_spec)
        mix_erb_spec = self.sfe(mix_erb_spec)
        mix_erb_spec = rearrange(mix_erb_spec, "b ff t n -> (b n) ff t")
        mix_erb_spec = self.mix_encoder(mix_erb_spec)
        sub_band_feat = rearrange(mix_erb_spec, "(b n) ff t -> b n ff t", b=batch_size)

        # Speaker-conditioned triple-path refinement; the layer index controls
        # when enrollment fusion happens inside TriplePathRNN.
        for layer_idx, extractor in enumerate(self.extractors):
            sub_band_feat = extractor(sub_band_feat, ts_vad_emb, layer_idx)

        # Decode a gated complex mask: mask = m1 * sigmoid(m2) per component.
        sub_band_feat = rearrange(sub_band_feat, "b n ff t -> (b n) ff t")
        enh_mask = self.enh_decoder(sub_band_feat)
        enh_mask = rearrange(enh_mask, "(b n) (c ri) t -> b c ri t n", b=batch_size, ri=2, c=2)
        mask = enh_mask[:, 0] * torch.sigmoid(enh_mask[:, 1])

        # Expand the mask from ERB bands back to linear-frequency bins.
        mask = self.erb.band_split(mask)
        mask = rearrange(mask, "b ri t f -> b ri f t", b=batch_size)

        # Complex multiplication of the reference-channel spectrum by the mask.
        enh_spec_r = mix_cc_mono_spec.real * mask[:, 0] - mix_cc_mono_spec.imag * mask[:, 1]
        enh_spec_i = mix_cc_mono_spec.real * mask[:, 1] + mix_cc_mono_spec.imag * mask[:, 0]
        enh_cc_spec = torch.complex(enh_spec_r, enh_spec_i)

        enh_y = self.istft(enh_cc_spec, input_type="complex", length=num_samples)

        # Undo the input normalization; std.squeeze(1) is [b, 1], which
        # broadcasts over the [b, num_samples] waveform.
        return enh_y * std.squeeze(1)
|
|
|
|
if __name__ == "__main__":
    from torchinfo import summary

    model = Model(ModelArgs())

    # Dummy inputs: 1 s of 6-channel mixture audio and a mono enrollment clip.
    mixture = torch.rand(2, 6, 16000)
    enroll_y = torch.rand(2, 16000)

    summary(model)

    # BUGFIX: Model.forward takes (mix_y, enroll_y); the previous call passed
    # four arguments (including unused precomputed features) and raised a
    # TypeError.
    output = model(mixture, enroll_y)
|
|