from dataclasses import dataclass
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from Amphion.models.codec.ns3_codec import FACodecDecoder, FACodecEncoder
from einops import rearrange, repeat
from huggingface_hub import hf_hub_download
from simple_parsing import Serializable, list_field
from audiozen.acoustics.audio_feature import istft, stft
BAND_WIDTH_16K = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 16, 16, 16, 16, 16, 16, 16, 16, 3] # fmt: skip
BAND_WIDTH_32K = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8, 8, 8, 8, 8, 8, 8, 16, 16, 16, 16, 16, 16, 16, 16, 7] # fmt: skip
@dataclass
class ModelArgs(Serializable):
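    """Model configuration (defaults assume 16 kHz audio and a 512-point STFT with a 128-sample hop)."""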
band_width_list: list[float] = list_field(*BAND_WIDTH_16K)
freq_feat_dim: int = 128
freq_axis_hidden_size: int = 128
temp_axis_hidden_size: int = 128
triple_path_rnn_num_repeat: int = 3
num_layers: int = 1
num_channels: int = 6
sr: int = 16000
dropout: float = 0.0
n_fft: int = 512
hop_length: int = 128
win_length: int = 512
# For ERB
erb_subband_1: int = 24
erb_subband_2: int = 16
sfe_kernel_size: int = 3
sfe_stride: int = 1
class SubbandFeatureExtractor(nn.Module):
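    """Stack each time-frequency bin with its `kernel_size` neighbors along the frequency axis via unfolding."""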
def __init__(self, kernel_size=3, stride=1):
super().__init__()
self.kernel_size = kernel_size
self.unfold = nn.Unfold(kernel_size=(1, kernel_size), stride=(1, stride), padding=(0, (kernel_size - 1) // 2))
def forward(self, x):
"""
x: [b, c, t, f]
"""
b, c, t, f = x.shape
x = self.unfold(x)
x = x.reshape(b, c * self.kernel_size, t, f)
return x
class ERB(nn.Module):
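    """Equivalent rectangular bandwidth (ERB) band merging and splitting.

    The lowest `erb_subband_1` STFT bins are kept as-is; the remaining bins are compressed into
    `erb_subband_2` ERB bands by `band_merge` and expanded back to full resolution by `band_split`.
    """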
def __init__(self, erb_subband_1, erb_subband_2, n_fft=512, high_lim=8000, sr=16000):
super().__init__()
erb_filter_banks = self.build_erb_filter_banks(erb_subband_1, erb_subband_2, n_fft, high_lim, sr)
num_freqs = n_fft // 2 + 1
self.erb_fc = nn.Linear(num_freqs - erb_subband_1, erb_subband_2, bias=False)
self.inverse_erb_fc = nn.Linear(erb_subband_2, num_freqs - erb_subband_1, bias=False)
        self.erb_fc.weight = nn.Parameter(erb_filter_banks, requires_grad=False)  # [erb_subband_2, num_freqs - erb_subband_1]
self.inverse_erb_fc.weight = nn.Parameter(erb_filter_banks.T, requires_grad=False)
self.erb_subband_1 = erb_subband_1
def hz2erb(self, freq_hz):
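        """Map a frequency in Hz onto the ERB-like scale used for band spacing (inverse of `erb2hz`)."""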
erb_f = 24.7 * np.log10(0.00437 * freq_hz + 1)
return erb_f
def erb2hz(self, erb_f):
freq_hz = (10 ** (erb_f / 24.7) - 1) / 0.00437
return freq_hz
def build_erb_filter_banks(self, erb_subband_1, erb_subband_2, n_fft=512, high_lim=8000, sr=16000):
low_lim = erb_subband_1 / n_fft * sr
erb_low = self.hz2erb(low_lim)
erb_high = self.hz2erb(high_lim)
erb_points = np.linspace(erb_low, erb_high, erb_subband_2)
bins = np.round(self.erb2hz(erb_points) / sr * n_fft).astype(np.int32)
erb_filters = np.zeros([erb_subband_2, n_fft // 2 + 1], dtype=np.float32)
erb_filters[0, bins[0] : bins[1]] = (bins[1] - np.arange(bins[0], bins[1]) + 1e-12) / (
bins[1] - bins[0] + 1e-12
)
for i in range(erb_subband_2 - 2):
erb_filters[i + 1, bins[i] : bins[i + 1]] = (np.arange(bins[i], bins[i + 1]) - bins[i] + 1e-12) / (
bins[i + 1] - bins[i] + 1e-12
)
erb_filters[i + 1, bins[i + 1] : bins[i + 2]] = (
bins[i + 2] - np.arange(bins[i + 1], bins[i + 2]) + 1e-12
) / (bins[i + 2] - bins[i + 1] + 1e-12)
erb_filters[-1, bins[-2] : bins[-1] + 1] = 1 - erb_filters[-2, bins[-2] : bins[-1] + 1]
erb_filters = erb_filters[:, erb_subband_1:]
        # [erb_subband_2, num_freqs - erb_subband_1]
return torch.from_numpy(np.abs(erb_filters))
    def band_merge(self, x):
        # x: [..., F]
        x_low = x[..., : self.erb_subband_1]  # [..., erb_subband_1]
        x_high = self.erb_fc(x[..., self.erb_subband_1 :])  # [..., F - erb_subband_1] => [..., erb_subband_2]
        return torch.cat([x_low, x_high], dim=-1)  # [..., erb_subband_1 + erb_subband_2]
    def band_split(self, x_erb):
        # x_erb: [..., F_erb]
        x_erb_low = x_erb[..., : self.erb_subband_1]  # [..., erb_subband_1]
        x_erb_high = self.inverse_erb_fc(x_erb[..., self.erb_subband_1 :])  # [..., erb_subband_2] => [..., F - erb_subband_1]
        return torch.cat([x_erb_low, x_erb_high], dim=-1)  # [..., F]
class ResRNN(nn.Module):
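    """Residual RNN block: GroupNorm -> (Bi)LSTM -> linear projection back to `input_size`, plus a skip connection."""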
def __init__(self, input_size, hidden_size, num_layers=1, bidirectional=True):
super().__init__()
self.norm = nn.GroupNorm(num_groups=1, num_channels=input_size)
self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
self.proj = nn.Linear(hidden_size * 2, input_size) if bidirectional else nn.Linear(hidden_size, input_size)
def forward(self, input):
# input: [b, f, t]
o = self.norm(input)
o = rearrange(o, "b f t -> b t f")
o, _ = self.rnn(o)
o = self.proj(o)
o = rearrange(o, "b t f -> b f t")
return input + o
class TriplePathRNN(nn.Module):
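    """Extractor block that applies residual RNNs along the temporal axis and the frequency (sub-band) axis,
    optionally fusing the enrollment clue with the mixture feature beforehand.
    """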
def __init__(self, freq_feat_dim, freq_rnn_hidden_size, temp_rnn_hidden_size, num_layers=1):
super().__init__()
# Frequency path
self.freq_rnn = ResRNN(
input_size=freq_feat_dim,
hidden_size=freq_rnn_hidden_size * 2,
num_layers=num_layers,
bidirectional=True,
)
# Temporal path
self.temp_rnn = ResRNN(
input_size=freq_feat_dim,
hidden_size=temp_rnn_hidden_size * 2,
num_layers=num_layers,
bidirectional=True,
)
self.mix_enroll_fusion_layer = nn.Linear(freq_feat_dim * 2, freq_feat_dim)
# Smooth initialization
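        # (zero weights on the mixture half and identity on the enrollment half of the concatenated feature,
        # so the fusion initially passes the enrollment feature through unchanged)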
self.mix_enroll_fusion_layer.weight = nn.Parameter(
torch.cat([torch.zeros(freq_feat_dim, freq_feat_dim), torch.eye(freq_feat_dim)], -1)
)
self.mix_enroll_fusion_layer.bias = nn.Parameter(torch.zeros_like(self.mix_enroll_fusion_layer.bias))
def forward(self, input, enroll_feat, current_layer=0):
# input: [b, n, ff, t]
# enroll_feat: [b, n, ff, t]
batch_size, num_sub_bands, freq_feat_dim, num_frames = input.shape
if current_layer in [2, 3]:
# Prepare the input by concatenating the enrollment feature to the input
input = torch.cat([input, enroll_feat], dim=-2) # [b, n, ff * 2, t]
input = rearrange(input, "b n ffx2 t -> (b n) t ffx2")
input = self.mix_enroll_fusion_layer(input) # [b * n, t, ff]
input = rearrange(input, "(b n) t ff -> b n ff t", b=batch_size, n=num_sub_bands)
# Temporal path
input = rearrange(input, "b n ff t -> (b n) ff t", b=batch_size, n=num_sub_bands)
input = self.temp_rnn(input) # [b * n, ff, t]
# Frequency path
input = rearrange(input, "(b n) ff t -> (b t) ff n", b=batch_size, n=num_sub_bands)
input = self.freq_rnn(input) # [b * t, ff, n]
return rearrange(input, "(b t) ff n -> b n ff t", b=batch_size)
class Model(nn.Module):
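    """Target-speaker extraction model.

    Pipeline: multi-channel STFT -> ERB band merging and sub-band feature extraction (mixture encoder),
    speaker clue from a frozen FACodec encoder/decoder, stacked TriplePathRNN extractors, and a
    complex-mask decoder whose output is applied to the reference channel before the iSTFT.
    """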
def __init__(self, args: ModelArgs):
        super().__init__()
self.args = args
self.stft = partial(stft, n_fft=args.n_fft, hop_length=args.hop_length, win_length=args.win_length)
self.istft = partial(istft, n_fft=args.n_fft, hop_length=args.hop_length, win_length=args.win_length)
ri_dim = 2 # real and imaginary
# Clue encoder
self.ts_vad_proj = nn.Linear(256, args.freq_feat_dim)
self.fa_encoder, self.fa_decoder = self._load_codec()
# Mixture encoder
self.erb = ERB(args.erb_subband_1, args.erb_subband_2, args.n_fft, sr=args.sr)
self.sfe = SubbandFeatureExtractor(args.sfe_kernel_size, args.sfe_stride)
self.mix_encoder = nn.Sequential(
nn.GroupNorm(1, args.num_channels * (ri_dim + 1) * args.sfe_kernel_size),
nn.Conv1d(args.num_channels * (ri_dim + 1) * args.sfe_kernel_size, args.freq_feat_dim, 1),
)
# Target extractor
self.extractors = nn.ModuleList([])
for i in range(args.triple_path_rnn_num_repeat):
self.extractors.append(
TriplePathRNN(
freq_feat_dim=args.freq_feat_dim,
freq_rnn_hidden_size=args.freq_axis_hidden_size,
temp_rnn_hidden_size=args.temp_axis_hidden_size,
num_layers=args.num_layers,
)
)
# Masking
self.enh_decoder = nn.Sequential(
nn.GroupNorm(1, args.freq_feat_dim),
nn.Conv1d(args.freq_feat_dim, args.freq_feat_dim * 2, 1),
nn.Tanh(),
nn.Conv1d(args.freq_feat_dim * 2, args.freq_feat_dim * 2, 1),
nn.Tanh(),
nn.Conv1d(args.freq_feat_dim * 2, ri_dim * 2, 1),
)
def _load_codec(self):
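        """Load the pretrained FACodec encoder/decoder from `amphion/naturalspeech3_facodec` and freeze their parameters."""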
fa_encoder = FACodecEncoder(ngf=32, up_ratios=[2, 4, 5, 5], out_channels=256)
fa_encoder.load_state_dict(
torch.load(hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder.bin"))
)
fa_encoder.eval()
for param in fa_encoder.parameters():
param.requires_grad = False
fa_decoder = FACodecDecoder(
in_channels=256,
upsample_initial_channel=1024,
ngf=32,
up_ratios=[5, 5, 4, 2],
vq_num_q_c=2,
vq_num_q_p=1,
vq_num_q_r=3,
vq_dim=256,
codebook_dim=8,
codebook_size_prosody=10,
codebook_size_content=10,
codebook_size_residual=10,
use_gr_x_timbre=True,
use_gr_residual_f0=True,
use_gr_residual_phone=True,
)
fa_decoder.load_state_dict(
torch.load(hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder.bin"))
)
fa_decoder.eval()
for param in fa_decoder.parameters():
param.requires_grad = False
return fa_encoder, fa_decoder
def forward(self, mix_y: torch.Tensor, enroll_y: torch.Tensor):
        # mix_y: multi-channel mixture waveform, [b, c, t]
        # enroll_y: enrollment waveform of the target speaker, [b, t]
        # mix_cc_spec: mixture complex spectrogram, [b, c, f, t]
        # Shape notation: c = channels, f = STFT bins, t = frames, n = sub-bands, ff = feature dim
batch_size, num_channels, num_samples = mix_y.shape
ref_ch = 0
# Normalize the input
std = torch.std(mix_y, dim=(1, 2), keepdim=True) # [b, 1, 1]
mix_y = mix_y / std
# STFT
mix_mag_spec, mix_phase_spec, mix_real_spec, mix_imag_spec = self.stft(mix_y) # [b, c, f, t]
*_, num_freqs, num_frames = mix_mag_spec.shape
# Clue encoder
enc_out = self.fa_encoder(rearrange(enroll_y, "b t -> b 1 t"))
*_, spk_embs = self.fa_decoder(enc_out, eval_vq=False, vq=True)
ts_vad_emb = self.ts_vad_proj(spk_embs) # [b, E]
ts_vad_emb = repeat(
ts_vad_emb, "b e -> b n e t", t=num_frames, n=self.args.erb_subband_1 + self.args.erb_subband_2
)
# ts_vad_emb = self.ts_vad_model.forward_emb(mix_spec_mfcc, enroll_spec_mfcc) # [B, E, T], e.g., [2, 512, 20]
# ts_vad_emb = self.ts_vad_proj(ts_vad_emb) # [B, E, T], e.g., [2, 128, 20]
# ts_vad_emb = F.interpolate(ts_vad_emb, num_frames) # [B, E, T], e.g., [2, 128, 1251]
# ts_vad_emb = repeat(ts_vad_emb, "b e t -> b n e t", n=self.args.erb_subband_1 + self.args.erb_subband_2)
# Mixture encoder
mix_cc_spec = torch.complex(mix_real_spec, mix_imag_spec) # [b, c, f, t], complex-valued
mix_cc_mono_spec = mix_cc_spec[:, ref_ch] # [b, f, t]
        mix_input_spec = torch.cat([mix_mag_spec, mix_real_spec, mix_imag_spec], dim=1)  # [b, 3c, f, t]: magnitude, real, imaginary
mix_input_spec = rearrange(mix_input_spec, "b c f t -> b c t f") # [b, 3c, t, f]
mix_erb_spec = self.erb.band_merge(mix_input_spec)
mix_erb_spec = self.sfe(mix_erb_spec) # [b, 3c * 3, t, n]
mix_erb_spec = rearrange(mix_erb_spec, "b ff t n -> (b n) ff t")
        mix_erb_spec = self.mix_encoder(mix_erb_spec)  # [b * n, ff, t]
sub_band_feat = rearrange(mix_erb_spec, "(b n) ff t -> b n ff t", b=batch_size) # [b, n, ff, t]
# Target extractor
for layer_idx, extractor in enumerate(self.extractors):
sub_band_feat = extractor(sub_band_feat, ts_vad_emb, layer_idx) # [b, n, ff, t]
# Decode the mask
sub_band_feat = rearrange(sub_band_feat, "b n ff t -> (b n) ff t")
enh_mask = self.enh_decoder(sub_band_feat) # [b * n, 4, t]
enh_mask = rearrange(enh_mask, "(b n) (c ri) t -> b c ri t n", b=batch_size, ri=2, c=2)
        mask = enh_mask[:, 0] * torch.sigmoid(enh_mask[:, 1])  # [b, ri, t, n]
mask = self.erb.band_split(mask) # [b, ri, t, 257]
mask = rearrange(mask, "b ri t f -> b ri f t", b=batch_size)
# Apply mask
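        # Complex masking: interpret mask[:, 0] / mask[:, 1] as the real / imaginary parts of a complex mask
        # and multiply it with the reference-channel complex spectrum.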
enh_spec_r = mix_cc_mono_spec.real * mask[:, 0] - mix_cc_mono_spec.imag * mask[:, 1]
enh_spec_i = mix_cc_mono_spec.real * mask[:, 1] + mix_cc_mono_spec.imag * mask[:, 0]
enh_cc_spec = torch.complex(enh_spec_r, enh_spec_i)
enh_y = self.istft(enh_cc_spec, input_type="complex", length=num_samples) # [b, t]
return enh_y * std.squeeze(1)
if __name__ == "__main__":
    from torchinfo import summary
    model = Model(ModelArgs())
    mixture = torch.rand(2, 6, 16000)  # [b, c, t]
    enroll_y = torch.rand(2, 16000)  # [b, t]
    summary(model)
    output = model(mixture, enroll_y)
    print(output.shape)  # enhanced single-channel waveform