import numpy as np
import torch
import torch.nn as nn
from torchaudio.models import Conformer

from models.svc.transformer.transformer import PositionalEncoding
from utils.f0 import f0_to_coarse
class ContentEncoder(nn.Module):
    def __init__(self, cfg, input_dim, output_dim):
        super().__init__()
        self.cfg = cfg

        assert input_dim != 0
        self.nn = nn.Linear(input_dim, output_dim)

        # Optionally refine the content features with a Conformer before the
        # linear projection.
        if (
            "use_conformer_for_content_features" in cfg
            and cfg.use_conformer_for_content_features
        ):
            self.pos_encoder = PositionalEncoding(input_dim)
            self.conformer = Conformer(
                input_dim=input_dim,
                num_heads=2,
                ffn_dim=256,
                num_layers=6,
                depthwise_conv_kernel_size=3,
            )
        else:
            self.conformer = None

    def forward(self, x, length=None):
        # x: (B, frame_len, input_dim) -> (B, frame_len, output_dim)
        if self.conformer:
            x = self.pos_encoder(x)
            x, _ = self.conformer(x, length)
        return self.nn(x)
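# A minimal usage sketch (illustrative, not from the original source); the
# dims below are assumptions chosen only for the example:
#
#   encoder = ContentEncoder(cfg, input_dim=1024, output_dim=384)
#   feats = torch.randn(2, 100, 1024)               # (B, frame_len, input_dim)
#   out = encoder(feats, length=torch.tensor([100, 100]))
#   print(out.shape)                                # torch.Size([2, 100, 384])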
class MelodyEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = self.cfg.input_melody_dim
        self.output_dim = self.cfg.output_melody_dim
        self.n_bins = self.cfg.n_bins_melody
        self.pitch_min = self.cfg.pitch_min
        self.pitch_max = self.cfg.pitch_max

        if self.input_dim != 0:
            if self.n_bins == 0:
                # No quantization: project the raw F0 value directly.
                self.nn = nn.Linear(self.input_dim, self.output_dim)
            else:
                # Quantization: map F0 to coarse bins and embed them.
                self.f0_min = cfg.f0_min
                self.f0_max = cfg.f0_max

                self.nn = nn.Embedding(
                    num_embeddings=self.n_bins,
                    embedding_dim=self.output_dim,
                    padding_idx=None,
                )
                # Embedding for the unvoiced (uv) flag. It is only created on
                # this quantized path, so passing `uv` requires n_bins > 0.
                self.uv_embedding = nn.Embedding(2, self.output_dim)

    def forward(self, x, uv=None, length=None):
        # x: (B, frame_len)
        if self.n_bins == 0:
            x = x.unsqueeze(-1)
        else:
            x = f0_to_coarse(x, self.n_bins, self.f0_min, self.f0_max)
        x = self.nn(x)
        if uv is not None:
            # Add the unvoiced-flag embedding so unvoiced frames are
            # distinguishable from voiced frames with low F0.
            uv = self.uv_embedding(uv)
            x = x + uv
        return x
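# Hedged note on the quantized path: `f0_to_coarse` (from utils.f0) is assumed
# to map F0 values in Hz onto integer bins in [0, n_bins), which then index
# the embedding table. Illustrative usage with made-up shapes:
#
#   melody_encoder = MelodyEncoder(cfg)             # cfg.n_bins_melody > 0
#   f0 = torch.full((2, 100), 220.0)                # (B, frame_len), Hz
#   uv = torch.ones(2, 100, dtype=torch.long)       # 1 = voiced, 0 = unvoiced
#   out = melody_encoder(f0, uv=uv)                 # (B, frame_len, output_melody_dim)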
class LoudnessEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = self.cfg.input_loudness_dim
        self.output_dim = self.cfg.output_loudness_dim
        self.n_bins = self.cfg.n_bins_loudness

        if self.input_dim != 0:
            if self.n_bins == 0:
                # No quantization: project the raw loudness value directly.
                self.nn = nn.Linear(self.input_dim, self.output_dim)
            else:
                # Empirically chosen loudness range for bucketizing.
                self.loudness_min = 1e-30
                self.loudness_max = 1.5

                if cfg.use_log_loudness:
                    # n_bins - 1 boundaries, uniformly spaced in the log
                    # domain. Note that `energy_bins` is only created when
                    # `use_log_loudness` is set, and `forward` relies on it.
                    self.energy_bins = nn.Parameter(
                        torch.exp(
                            torch.linspace(
                                np.log(self.loudness_min),
                                np.log(self.loudness_max),
                                self.n_bins - 1,
                            )
                        ),
                        requires_grad=False,
                    )

                self.nn = nn.Embedding(
                    num_embeddings=self.n_bins,
                    embedding_dim=self.output_dim,
                    padding_idx=None,
                )

    def forward(self, x):
        # x: (B, frame_len)
        if self.n_bins == 0:
            x = x.unsqueeze(-1)
        else:
            # Map each loudness value to its bucket index in [0, n_bins - 1].
            x = torch.bucketize(x, self.energy_bins)
        return self.nn(x)
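# Worked example of the bucketizing above (values are illustrative): with
# n_bins = 4 there are n_bins - 1 = 3 boundaries, so bucket indices land in
# [0, 3] and stay valid for an embedding table of size n_bins:
#
#   bins = torch.tensor([0.1, 0.5, 1.0])
#   torch.bucketize(torch.tensor([0.05, 0.3, 0.7, 2.0]), bins)
#   # -> tensor([0, 1, 2, 3])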
class SingerEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = 1
        self.output_dim = self.cfg.output_singer_dim

        # Lookup table over all known singers.
        self.nn = nn.Embedding(
            num_embeddings=cfg.singer_table_size,
            embedding_dim=self.output_dim,
            padding_idx=None,
        )

    def forward(self, x):
        # x: (B, 1) singer ids -> (B, 1, output_singer_dim)
        return self.nn(x)
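# Sketch with made-up ids: each integer singer id selects one learned row,
# giving a (B, 1, output_singer_dim) embedding that ConditionEncoder later
# broadcasts over time:
#
#   singer_encoder = SingerEncoder(cfg)
#   spk_id = torch.tensor([[3], [7]])               # (B, 1), ids < singer_table_size
#   print(singer_encoder(spk_id).shape)             # (B, 1, output_singer_dim)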
class ConditionEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.merge_mode = cfg.merge_mode

        if cfg.use_whisper:
            self.whisper_encoder = ContentEncoder(
                self.cfg, self.cfg.whisper_dim, self.cfg.content_encoder_dim
            )

        if cfg.use_contentvec:
            self.contentvec_encoder = ContentEncoder(
                self.cfg, self.cfg.contentvec_dim, self.cfg.content_encoder_dim
            )

        if cfg.use_mert:
            self.mert_encoder = ContentEncoder(
                self.cfg, self.cfg.mert_dim, self.cfg.content_encoder_dim
            )

        if cfg.use_wenet:
            self.wenet_encoder = ContentEncoder(
                self.cfg, self.cfg.wenet_dim, self.cfg.content_encoder_dim
            )

        self.melody_encoder = MelodyEncoder(self.cfg)
        self.loudness_encoder = LoudnessEncoder(self.cfg)
        if cfg.use_spkid:
            self.singer_encoder = SingerEncoder(self.cfg)

    def forward(self, x):
        outputs = []

        if "frame_pitch" in x.keys():
            if "frame_uv" not in x.keys():
                x["frame_uv"] = None
            pitch_enc_out = self.melody_encoder(
                x["frame_pitch"], uv=x["frame_uv"], length=x["target_len"]
            )
            outputs.append(pitch_enc_out)

        if "frame_energy" in x.keys():
            loudness_enc_out = self.loudness_encoder(x["frame_energy"])
            outputs.append(loudness_enc_out)

        if "whisper_feat" in x.keys():
            # whisper_feat: (B, frame_len, whisper_dim)
            whisper_enc_out = self.whisper_encoder(
                x["whisper_feat"], length=x["target_len"]
            )
            outputs.append(whisper_enc_out)
            seq_len = whisper_enc_out.shape[1]

        if "contentvec_feat" in x.keys():
            contentvec_enc_out = self.contentvec_encoder(
                x["contentvec_feat"], length=x["target_len"]
            )
            outputs.append(contentvec_enc_out)
            seq_len = contentvec_enc_out.shape[1]

        if "mert_feat" in x.keys():
            mert_enc_out = self.mert_encoder(x["mert_feat"], length=x["target_len"])
            outputs.append(mert_enc_out)
            seq_len = mert_enc_out.shape[1]

        if "wenet_feat" in x.keys():
            wenet_enc_out = self.wenet_encoder(x["wenet_feat"], length=x["target_len"])
            outputs.append(wenet_enc_out)
            seq_len = wenet_enc_out.shape[1]

        if "spk_id" in x.keys():
            speaker_enc_out = self.singer_encoder(x["spk_id"])  # (B, 1, D)
            # Broadcasting the speaker embedding over time needs a known frame
            # length, so at least one content feature must be present.
            assert (
                "whisper_feat" in x.keys()
                or "contentvec_feat" in x.keys()
                or "mert_feat" in x.keys()
                or "wenet_feat" in x.keys()
            )
            singer_info = speaker_enc_out.expand(-1, seq_len, -1)
            outputs.append(singer_info)

        encoder_output = None
        if self.merge_mode == "concat":
            encoder_output = torch.cat(outputs, dim=-1)
        if self.merge_mode == "add":
            # "add" assumes every encoder shares the same output dimension.
            # (num_features, B, frame_len, output_dim)
            outputs = torch.cat([out[None, :, :, :] for out in outputs], dim=0)
            # (B, frame_len, output_dim)
            encoder_output = torch.sum(outputs, dim=0)

        return encoder_output
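if __name__ == "__main__":
    # Illustrative smoke test, not part of the original training setup: the
    # config class and every field value below are assumptions chosen only to
    # exercise the forward pass, and running it requires the repo's utils.f0.
    class _Cfg(dict):
        # Minimal attr-style dict so both `cfg.key` and `"key" in cfg` work.
        __getattr__ = dict.__getitem__

    cfg = _Cfg(
        merge_mode="concat",
        use_whisper=True,
        use_contentvec=False,
        use_mert=False,
        use_wenet=False,
        use_spkid=True,
        whisper_dim=1024,
        content_encoder_dim=384,
        input_melody_dim=1,
        output_melody_dim=384,
        n_bins_melody=256,
        pitch_min=50,
        pitch_max=1100,
        f0_min=50,
        f0_max=1100,
        input_loudness_dim=1,
        output_loudness_dim=384,
        n_bins_loudness=256,
        use_log_loudness=True,
        output_singer_dim=384,
        singer_table_size=512,
    )

    model = ConditionEncoder(cfg)
    batch = {
        "frame_pitch": torch.full((2, 100), 220.0),
        "frame_uv": torch.ones(2, 100, dtype=torch.long),
        "frame_energy": torch.rand(2, 100),
        "whisper_feat": torch.randn(2, 100, cfg.whisper_dim),
        "spk_id": torch.randint(0, cfg.singer_table_size, (2, 1)),
        "target_len": torch.tensor([100, 100]),
    }
    # With "concat" and four active conditions of width 384 each, we expect
    # an output of shape (2, 100, 1536).
    print(model(batch).shape)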