# Hugging Face Hub upload metadata: uploaded by haoxiangsnr via
# huggingface_hub (commit 50de2e0, verified).
import math
from dataclasses import dataclass
import dac
import torch
import torch.nn as nn
from Amphion.models.codec.ns3_codec import FACodecDecoder, FACodecEncoder
from einops import rearrange, repeat
from huggingface_hub import hf_hub_download
from simple_parsing import Serializable
from torch.nn import CrossEntropyLoss
@dataclass
class ModelArgs(Serializable):
    """Hyper-parameters for `Model` (serializable via simple_parsing)."""

    rnn_num_repeat: int = 3  # number of stacked FusionRNN extractor layers
    num_layers: int = 1  # LSTM depth inside each ResRNN
    num_channels: int = 6  # microphone channels in the input mixture
    feat_dim: int = 512  # internal feature (hidden) size per channel
    sr: int = 16000  # sample rate in Hz; not referenced elsewhere in this file
    dropout: float = 0.0  # NOTE(review): not referenced elsewhere in this file
    cb_size: int = 1024  # DAC codebook size (per-codebook logit vocabulary)
    num_codebooks: int = 12  # number of DAC codebooks predicted per frame
class ResRNN(nn.Module):
    """Normalized (bi)LSTM block with a linear projection and optional skip.

    Input and output are channel-first sequences of shape ``[batch, features,
    time]``. The block applies GroupNorm, runs a batch-first LSTM over the
    time axis, projects back to ``output_size`` features, and (optionally)
    adds the input as a residual connection.

    Args:
        input_size: feature size of the input (and of the residual path).
        hidden_size: LSTM hidden size (doubled at the output if bidirectional).
        output_size: projected output feature size; defaults to ``input_size``.
        num_layers: number of stacked LSTM layers.
        bidirectional: run the LSTM in both directions.
        use_residual: add the input to the projected output (requires
            ``output_size == input_size``).
    """

    def __init__(self, input_size, hidden_size, output_size=None, num_layers=1, bidirectional=True, use_residual=True):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=1, num_channels=input_size)
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
        # A bidirectional LSTM emits 2 * hidden_size features per step.
        rnn_out_dim = hidden_size * 2 if bidirectional else hidden_size
        self.proj = nn.Linear(rnn_out_dim, input_size if output_size is None else output_size)
        self.use_residual = use_residual

    def forward(self, input):
        # input: [b, f, t]
        normed = self.norm(input)
        # [b, t, f] for the batch-first LSTM (plain transpose; no copy needed).
        seq = normed.transpose(1, 2)
        seq, _ = self.rnn(seq)
        seq = self.proj(seq)
        out = seq.transpose(1, 2)  # back to [b, f, t]
        return input + out if self.use_residual else out
class FusionRNN(nn.Module):
    """One extractor layer: point-wise fusion of mixture and enrollment
    features per microphone channel, followed by a joint temporal model
    (ResRNN) over all channels and a point-wise decoder.

    Intermediate layers keep the ``[b, num_channels, feat_dim, t]`` layout;
    the last layer widens the channel axis to ``num_codebooks`` and drops
    the residual connection in its sequence model.
    """

    def __init__(self, feat_dim, num_channels, num_layers=1, num_codebooks=12, is_last_layer=False):
        super().__init__()
        # Point-wise fusion of the concatenated (mixture, enrollment) features.
        self.mix_enroll_fusion_projector = nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=feat_dim * 2),
            nn.Conv1d(feat_dim * 2, feat_dim, kernel_size=1),
            nn.ReLU(),
            nn.GroupNorm(num_groups=1, num_channels=feat_dim),
            nn.Conv1d(feat_dim, feat_dim, kernel_size=1),
            nn.ReLU(),
        )
        # Temporal path: the last layer maps channel-features to
        # codebook-features (no residual); earlier layers are shape-preserving.
        if is_last_layer:
            out_width = feat_dim * num_codebooks
            self.sequence_model = ResRNN(
                input_size=feat_dim * num_channels,
                hidden_size=out_width,
                output_size=out_width,
                num_layers=num_layers,
                bidirectional=True,
                use_residual=False,
            )
        else:
            out_width = feat_dim * num_channels
            self.sequence_model = ResRNN(
                input_size=out_width,
                hidden_size=out_width,
                output_size=out_width,
                num_layers=num_layers,
                bidirectional=True,
            )
        # Point-wise post-processing at the sequence model's output width.
        self.decoder = nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=out_width),
            nn.Conv1d(out_width, out_width, kernel_size=1),
            nn.ReLU(),
            nn.GroupNorm(num_groups=1, num_channels=out_width),
            nn.Conv1d(out_width, out_width, kernel_size=1),
            nn.ReLU(),
        )
        self.output_hidden_dim = num_codebooks if is_last_layer else num_channels

    def forward(self, input, enroll_feat, current_layer=0):
        # input / enroll_feat: [b, c, h, t].
        # `current_layer` is accepted for caller compatibility but not used.
        batch_size, num_channels, _, num_frames = input.shape
        # Concatenate enrollment features onto the mixture along the feature axis.
        fused = torch.cat([input, enroll_feat], dim=2)  # [b, c, 2h, t]
        fused = fused.flatten(0, 1)  # [(b c), 2h, t]
        fused = self.mix_enroll_fusion_projector(fused)  # [(b c), h, t]
        fused = fused.unflatten(0, (batch_size, num_channels))  # [b, c, h, t]
        # Temporal path over all channels jointly.
        fused = fused.flatten(1, 2)  # [b, (c h), t]
        fused = self.sequence_model(fused)
        fused = self.decoder(fused)
        # Split the (channel, feature) axis back apart; feature size inferred.
        return fused.unflatten(1, (self.output_hidden_dim, -1))  # [b, c', h, t]
class Model(nn.Module):
    """Multi-channel target-speaker extraction in the discrete DAC-codec domain.

    A frozen 16 kHz DAC codec encodes both the multi-channel mixture and the
    mono enrollment utterance; stacked `FusionRNN` extractors fuse the two
    streams and the model predicts the clean reference-channel speech as DAC
    codebook indices (`num_codebooks` parallel `cb_size`-way classifications
    per frame), trained with cross-entropy against the codec-encoded clean
    waveform.
    """

    def __init__(self, args: ModelArgs):
        super(Model, self).__init__()
        self.args = args
        # Mixture encoder (DAC codec): downloaded once, put in eval mode and
        # frozen -- used purely as a fixed feature extractor / tokenizer.
        dac_codec_ckpt_path = dac.utils.download(model_type="16khz")
        self.dac_codec = dac.DAC.load(dac_codec_ckpt_path)
        self.dac_codec.eval()
        for param in self.dac_codec.parameters():
            param.requires_grad = False
        # Projects the codec latents down to `feat_dim`.
        # NOTE(review): the input width reuses `cb_size` (1024), which happens
        # to equal the DAC latent width used below (see the clue projector's
        # hard-coded 1024); a dedicated latent-dim field would be clearer --
        # confirm the two are meant to be coupled.
        self.mix_encoder_projector = nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=args.cb_size),
            nn.Conv1d(args.cb_size, self.args.feat_dim, kernel_size=1),
            nn.ReLU(),
            nn.GroupNorm(num_groups=1, num_channels=self.args.feat_dim),
            nn.Conv1d(self.args.feat_dim, self.args.feat_dim, kernel_size=1),
            nn.ReLU(),
        )
        # Clue encoder: same projection pattern, applied to the codec latents
        # of the enrollment waveform (latent width hard-coded to 1024 here).
        self.clue_encoder_projector = nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=1024),
            nn.Conv1d(1024, args.feat_dim, kernel_size=1),
            nn.ReLU(),
            nn.GroupNorm(num_groups=1, num_channels=args.feat_dim),
            nn.Conv1d(args.feat_dim, args.feat_dim, kernel_size=1),
            nn.ReLU(),
        )
        # Target extractor: `rnn_num_repeat` FusionRNN layers; only the last
        # one switches the channel axis from mic channels to codebooks.
        self.extractors = nn.ModuleList([])
        for i in range(args.rnn_num_repeat):
            self.extractors.append(
                FusionRNN(
                    feat_dim=args.feat_dim,
                    num_channels=args.num_channels,
                    num_layers=args.num_layers,
                    num_codebooks=args.num_codebooks,
                    is_last_layer=(i == args.rnn_num_repeat - 1),
                )
            )
        # Predictor: classification head shared across codebooks and frames.
        self.lm_head = nn.Linear(args.feat_dim, args.cb_size, bias=False)

    def _load_enroll_encoder(self):
        """Build the frozen FACodec encoder/decoder pair from Amphion.

        Returns:
            Tuple `(fa_encoder, fa_decoder)`, both in eval mode with all
            parameters frozen.

        NOTE(review): this helper is not called anywhere in this file --
        presumably a leftover from an earlier enrollment-encoder design.
        """
        fa_encoder = FACodecEncoder(ngf=32, up_ratios=[2, 4, 5, 5], out_channels=256)
        fa_encoder.load_state_dict(
            torch.load(hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder.bin"))
        )
        fa_encoder.eval()
        for param in fa_encoder.parameters():
            param.requires_grad = False
        fa_decoder = FACodecDecoder(
            in_channels=256,
            upsample_initial_channel=1024,
            ngf=32,
            up_ratios=[5, 5, 4, 2],
            vq_num_q_c=2,
            vq_num_q_p=1,
            vq_num_q_r=3,
            vq_dim=256,
            codebook_dim=8,
            codebook_size_prosody=10,
            codebook_size_content=10,
            codebook_size_residual=10,
            use_gr_x_timbre=True,
            use_gr_residual_f0=True,
            use_gr_residual_phone=True,
        )
        fa_decoder.load_state_dict(
            torch.load(hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder.bin"))
        )
        fa_decoder.eval()
        for param in fa_decoder.parameters():
            param.requires_grad = False
        return fa_encoder, fa_decoder

    def pad(self, audio_data: torch.Tensor) -> torch.Tensor:
        """Add padding to the input audio data. Adopted from DAC's `preprocess` method."""
        length = audio_data.shape[-1]
        # Round the length up to a multiple of the codec hop length, plus one
        # extra hop, so the codec sees whole frames.
        right_pad = (math.ceil(length / self.dac_codec.hop_length) + 1) * self.dac_codec.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))
        return audio_data

    def forward(self, mix_y: torch.Tensor, enroll_y: torch.Tensor, clean_y: torch.Tensor = None):
        """
        Args:
            mix_y (`torch.Tensor` of shape `(batch_size, num_channels, num_samples)`):
                The multi-channel mixture waveform.
            enroll_y (`torch.Tensor` of shape `(batch_size, num_samples)`):
                The mono-channel enrollment waveform.
            clean_y: (`torch.Tensor` of shape `(batch_size, num_samples)`):
                The reference-channel clean waveform.

        Returns:
            `(logits, loss)` where `logits` has shape
            `(batch_size, cb_size, num_codebooks, num_frames)` and `loss`
            is `None` when `clean_y` is not given.

        Note:
            `n`: number of codebooks
            `c`: number of microphone channels
            `h`: hidden size
        """
        batch_size, num_channels, num_samples = mix_y.shape
        mix_y = self.pad(mix_y)
        # Mixture encoder: each mic channel is encoded independently as a
        # mono signal by folding the channel axis into the batch.
        mix_y = rearrange(mix_y, "b c t -> (b c) 1 t")  # [b * c, 1, t]
        # codes = [b * c, n, t], where n=12; z = [b * c, 1024, t]
        z, codes, latents, _, _ = self.dac_codec.encode(mix_y)
        *_, num_frames = z.shape
        mix_feat = self.mix_encoder_projector(z)  # [b * c, h, t]
        mix_feat = rearrange(mix_feat, "(b c) h t -> b c h t", b=batch_size)  # [b, c, h, t]
        # Clue encoder: time-averaged speaker embedding, broadcast back over
        # channels and frames so it can be concatenated with `mix_feat`.
        clue_z, *_ = self.dac_codec.encode(rearrange(enroll_y, "b t -> b 1 t"))  # [b, 1024, t]
        spk_emb = self.clue_encoder_projector(clue_z)  # [b, h, t]
        spk_emb = torch.mean(spk_emb, dim=-1)  # [b, h]
        spk_emb = repeat(spk_emb, "b h -> b c h t", c=num_channels, t=num_frames)  # [b, c, h, t]
        # Target extractor
        # NOTE(review): `mix_feat` is never updated inside this loop, so every
        # extractor except the last receives the ORIGINAL mixture features and
        # has its output discarded; stacking likely intended
        # `mix_feat = extractor(...)` for the non-last layers -- confirm
        # against the training code before changing (shipped weights may
        # have been trained with this exact flow).
        for layer_idx, extractor in enumerate(self.extractors):
            sub_band_feat = extractor(mix_feat, spk_emb, layer_idx)  # [b, n, h, t]
        # LLM head
        sub_band_feat = rearrange(sub_band_feat, "b n h t -> b n t h")
        logits = self.lm_head(sub_band_feat)
        logits = rearrange(logits, "b n t cb -> b cb n t")  # [b, 1024, 12, t]
        # Decode the predicted code
        loss = None
        if clean_y is not None:
            clean_y = rearrange(clean_y, "b t -> b 1 t")
            clean_y = self.pad(clean_y)
            clean_z, clean_codes, _, _, _ = self.dac_codec.encode(clean_y)  # [b, 12, t]
            # Cross-entropy over the class axis (dim 1 of `logits`); with the
            # default "mean" reduction this yields a scalar loss.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, clean_codes.long())
        # logits: [b, 1024, 12, t]
        # loss: scalar tensor, or None
        # NOTE(review): both branches of this conditional evaluate to
        # `logits`, so this always returns the tuple `(logits, loss)` --
        # the expression is equivalent to `return logits, loss`.
        return logits if clean_y is None else logits, loss

    def decode(self, logits: torch.Tensor, num_samples: int) -> torch.Tensor:
        """Greedy-decode predicted code logits back to a waveform, trimmed
        to `num_samples` samples."""
        # logits: [b, 1024, 12, t]
        logits = rearrange(logits, "b fc n t -> b n t fc")
        logits = logits.argmax(dim=-1)  # [b, n, t]
        # Decode the predicted code
        z = self.dac_codec.quantizer.from_codes(logits)[0]  # [b, 1024, 50]
        enh_y = self.dac_codec.decode(z)
        enh_y = rearrange(enh_y, "b () t -> b t")
        # Drop the samples introduced by `pad` before encoding.
        enh_y = enh_y[:, :num_samples]
        return enh_y
if __name__ == "__main__":
    # Smoke test with random audio: batch of 2, six mics, 4 s at 16 kHz.
    model = Model(ModelArgs())
    mixture = torch.rand(2, 6, 16000 * 4)
    clean = torch.rand(2, 16000 * 4)
    enroll_y = torch.rand(2, 16000)
    # forward() returns a (logits, loss) pair when a clean reference is given.
    logits, loss = model(mixture, enroll_y, clean)
    print(logits.shape)
    print(loss)
    enhanced = model.decode(logits, 16000)
    print(enhanced.shape)