|
|
import os |
|
|
import sys |
|
|
import torch |
|
|
|
|
|
import numpy as np |
|
|
import torch.nn as nn |
|
|
import onnxruntime as ort |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from einops import rearrange |
|
|
from torch.nn.utils.parametrizations import weight_norm |
|
|
|
|
|
# Make project-local packages (infer.lib....) importable when running from the repo root.
sys.path.append(os.getcwd())

# NOTE(review): presumably caps libtorch's internal LRU allocator cache to limit
# memory growth — confirm against PyTorch CPU allocator env-var docs.
os.environ["LRU_CACHE_CAPACITY"] = "3"
|
|
|
|
|
from infer.lib.predictors.FCPE.wav2mel import Wav2Mel |
|
|
from infer.lib.predictors.FCPE.encoder import EncoderLayer, ConformerNaiveEncoder |
|
|
from infer.lib.predictors.FCPE.utils import batch_interp_with_replacement_detach, decrypt_model, DotDict |
|
|
|
|
|
@torch.no_grad()
def cent_to_f0(cent):
    """Convert cents (relative to a 10 Hz reference) to frequency in Hz."""
    octaves = cent / 1200
    return 10 * (2 ** octaves)
|
|
|
|
|
@torch.no_grad()
def f0_to_cent(f0):
    """Convert frequency in Hz to cents relative to a 10 Hz reference."""
    ratio = f0 / 10
    return ratio.log2() * 1200
|
|
|
|
|
@torch.no_grad()
def latent2cents_decoder(cent_table, y, threshold = 0.05, mask = True):
    """Decode a latent pitch distribution to cents via a full weighted average.

    Args:
        cent_table: 1-D tensor of cent values, one per output bin.
        y: (B, N, out_dims) latent activations.
        threshold: frames whose peak activation is at or below this value are
            marked unvoiced when ``mask`` is on.
        mask: multiply unvoiced frames by -inf when True.

    Returns:
        (B, N, 1) tensor of cents (-inf on masked frames).
    """
    # DirectML ("privateuseone") devices miss some ops used below -> fall back to CPU.
    if str(y.device).startswith("privateuseone"):
        cent_table = cent_table.cpu()
        y = y.cpu()

    batch, frames, _ = y.size()
    table = cent_table[None, None, :].expand(batch, frames, -1)

    # Weighted mean of the cent table, using the activations as weights.
    cents = (table * y).sum(dim=-1, keepdim=True) / y.sum(dim=-1, keepdim=True)

    if mask:
        peak = y.max(dim=-1, keepdim=True)[0]
        gate = torch.where(
            peak <= threshold,
            torch.full_like(peak, float("-inf")),
            torch.ones_like(peak),
        )
        cents = cents * gate

    return cents
|
|
|
|
|
@torch.no_grad()
def latent2cents_local_decoder(cent_table, out_dims, y, threshold = 0.05, mask = True):
    """Decode a latent pitch distribution to cents using a 9-bin window
    centred on the per-frame argmax.

    Args:
        cent_table: 1-D tensor of cent values, one per output bin.
        out_dims: number of bins (used to clamp the window at the edges).
        y: (B, N, out_dims) latent activations.
        threshold: voicing threshold on the per-frame peak activation.
        mask: multiply unvoiced frames by -inf when True.

    Returns:
        (B, N, 1) tensor of cents (-inf on masked frames).
    """
    # DirectML ("privateuseone") devices miss some ops used below -> fall back to CPU.
    if str(y.device).startswith("privateuseone"):
        cent_table = cent_table.cpu()
        y = y.cpu()

    batch, frames, _ = y.size()
    table = cent_table[None, None, :].expand(batch, frames, -1)
    peak, peak_index = y.max(dim=-1, keepdim=True)

    # Indices of the 9 bins centred on the argmax, clamped to the valid range
    # (edge windows repeat the boundary bin, same as the sibling decoders).
    window = (torch.arange(0, 9).to(peak_index.device) + (peak_index - 4)).clamp(0, out_dims - 1)

    weights = y.gather(-1, window)
    cents = (table.gather(-1, window) * weights).sum(dim=-1, keepdim=True) / weights.sum(dim=-1, keepdim=True)

    if mask:
        gate = torch.where(
            peak <= threshold,
            torch.full_like(peak, float("-inf")),
            torch.ones_like(peak),
        )
        cents = cents * gate

    return cents
|
|
|
|
|
def cents_decoder(cent_table, y, confidence, threshold = 0.05, mask=True):
    """Decode a latent pitch distribution to cents via a full weighted average.

    Args:
        cent_table: 1-D tensor of cent values, one per output bin.
        y: (B, N, out_dims) latent activations.
        confidence: when True, also return the per-frame peak activation.
        threshold: voicing threshold on the peak activation.
        mask: multiply unvoiced frames by -inf when True.

    Returns:
        (B, N, 1) cents tensor, or ``(cents, confident)`` when ``confidence``.
    """
    # DirectML ("privateuseone") devices miss some ops used below -> fall back to CPU.
    if str(y.device).startswith("privateuseone"):
        cent_table = cent_table.cpu()
        y = y.cpu()

    B, N, _ = y.size()
    rtn = (cent_table[None, None, :].expand(B, N, -1) * y).sum(dim=-1, keepdim=True) / y.sum(dim=-1, keepdim=True)

    # Compute the peak unconditionally: it is also the ``confidence`` return
    # value. (Previously it was only assigned inside the ``mask`` branch, so
    # confidence=True with mask=False raised NameError.)
    confident = y.max(dim=-1, keepdim=True)[0]

    if mask:
        confident_mask = torch.ones_like(confident)
        confident_mask[confident <= threshold] = float("-INF")
        rtn = rtn * confident_mask

    return (rtn, confident) if confidence else rtn
|
|
|
|
|
def cents_local_decoder(cent_table, y, n_out, confidence, threshold = 0.05, mask=True):
    """Decode a latent pitch distribution to cents using a 9-bin window
    centred on the per-frame argmax.

    Args:
        cent_table: 1-D tensor of cent values, one per output bin.
        y: (B, N, n_out) latent activations.
        n_out: number of bins (used to clamp the window at the edges).
        confidence: when True, also return the per-frame peak activation.
        threshold: voicing threshold on the peak activation.
        mask: multiply unvoiced frames by -inf when True.

    Returns:
        (B, N, 1) cents tensor, or ``(cents, confident)`` when ``confidence``.
    """
    # DirectML ("privateuseone") devices miss some ops used below -> fall back to CPU.
    if str(y.device).startswith("privateuseone"):
        cent_table = cent_table.cpu()
        y = y.cpu()

    batch, frames, _ = y.size()
    confident, peak_index = y.max(dim=-1, keepdim=True)

    # 9-bin window around the argmax, clamped to the valid index range.
    window = (torch.arange(0, 9).to(peak_index.device) + (peak_index - 4)).clamp(0, n_out - 1)
    weights = y.gather(-1, window)
    table = cent_table[None, None, :].expand(batch, frames, -1).gather(-1, window)
    rtn = (table * weights).sum(dim=-1, keepdim=True) / weights.sum(dim=-1, keepdim=True)

    if mask:
        gate = torch.where(
            confident <= threshold,
            torch.full_like(confident, float("-inf")),
            torch.ones_like(confident),
        )
        rtn = rtn * gate

    return (rtn, confident) if confidence else rtn
|
|
|
|
|
class PCmer(nn.Module):
    """A stack of ``EncoderLayer`` blocks sharing one configuration.

    The constructor arguments are stored as attributes because each
    ``EncoderLayer`` reads its hyper-parameters directly from this parent.
    """

    def __init__(
        self,
        num_layers,
        num_heads,
        dim_model,
        dim_keys,
        dim_values,
        residual_dropout,
        attention_dropout
    ):
        super().__init__()

        # Hyper-parameters consumed by the child EncoderLayer modules.
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_keys = dim_keys
        self.dim_values = dim_values
        self.residual_dropout = residual_dropout
        self.attention_dropout = attention_dropout

        self._layers = nn.ModuleList(EncoderLayer(self) for _ in range(num_layers))

    def forward(self, phone, mask=None):
        """Pass ``phone`` through every encoder layer in order."""
        output = phone
        for encoder_layer in self._layers:
            output = encoder_layer(output, mask)
        return output
|
|
|
|
|
class CFNaiveMelPE(nn.Module):
    """Conformer-based mel-spectrogram pitch estimator (FCPE backbone).

    Maps a mel spectrogram to a per-frame probability distribution over
    ``out_dims`` pitch bins (in cents); ``infer`` decodes that to f0 in Hz.
    """

    def __init__(
        self,
        input_channels,            # number of mel bins in the input
        out_dims,                  # number of pitch (cent) bins in the output
        hidden_dims = 512,
        n_layers = 6,
        n_heads = 8,
        f0_max = 1975.5,           # Hz, upper end of the cent table
        f0_min = 32.70,            # Hz, lower end of the cent table
        use_fa_norm = False,
        conv_only = False,
        conv_dropout = 0,
        atten_dropout = 0,
        use_harmonic_emb = False   # enable the optional harmonic embedding
    ):
        super().__init__()

        self.input_channels = input_channels
        self.out_dims = out_dims
        self.hidden_dims = hidden_dims
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.f0_max = f0_max
        self.f0_min = f0_min
        self.use_fa_norm = use_fa_norm

        # Optional embedding added to the features when a harmonic index is given.
        self.harmonic_emb = nn.Embedding(9, hidden_dims) if use_harmonic_emb else None
        # Two-layer convolutional stem: input mel channels -> hidden_dims.
        self.input_stack = nn.Sequential(
            nn.Conv1d(
                input_channels,
                hidden_dims,
                3,
                1,
                1
            ),
            nn.GroupNorm(
                4,
                hidden_dims
            ),
            nn.LeakyReLU(),
            nn.Conv1d(
                hidden_dims,
                hidden_dims,
                3,
                1,
                1
            )
        )
        self.net = ConformerNaiveEncoder(
            num_layers=n_layers,
            num_heads=n_heads,
            dim_model=hidden_dims,
            use_norm=use_fa_norm,
            conv_only=conv_only,
            conv_dropout=conv_dropout,
            atten_dropout=atten_dropout
        )
        self.norm = nn.LayerNorm(hidden_dims)
        self.output_proj = weight_norm(
            nn.Linear(
                hidden_dims,
                out_dims
            )
        )

        # Cent value of each output bin, evenly spaced between f0_min and
        # f0_max in the cent domain (cents are relative to 10 Hz).
        self.cent_table_b = torch.linspace(
            f0_to_cent(torch.Tensor([f0_min]))[0],
            f0_to_cent(torch.Tensor([f0_max]))[0],
            out_dims
        ).detach()
        # Upper cent bound (cent value of f0_max) kept as a buffer.
        self.gaussian_blurred_cent_mask_b = (
            1200 * torch.Tensor([self.f0_max / 10.]).log2()
        )[0].detach()

        self.register_buffer("cent_table", self.cent_table_b)
        self.register_buffer("gaussian_blurred_cent_mask", self.gaussian_blurred_cent_mask_b)

    def forward(self, x, _h_emb=None):
        """Return the (B, N, out_dims) latent pitch distribution.

        ``x`` is the mel input (transposed to channels-first for the conv stem);
        ``_h_emb`` optionally selects one of the 9 harmonic embeddings
        (index 0 is used when None and the embedding is enabled).
        """
        # Conv1d wants channels-first, the encoder wants channels-last.
        x = self.input_stack(x.transpose(-1, -2)).transpose(-1, -2)

        if self.harmonic_emb is not None:
            if _h_emb is None:
                x += self.harmonic_emb(torch.LongTensor([0]).to(x.device))
            else:
                x += self.harmonic_emb(torch.LongTensor([int(_h_emb)]).to(x.device))

        # Sigmoid keeps each bin activation in (0, 1).
        return self.output_proj(self.norm(self.net(x))).sigmoid()

    @torch.no_grad()
    def infer(self, mel, decoder = "local_argmax", threshold = 0.05):
        """Estimate f0 (Hz) from a mel spectrogram.

        ``decoder`` selects "argmax" (full weighted mean over all bins) or the
        windowed "local_argmax" decoder; ``threshold`` gates unvoiced frames.
        """
        latent = self.forward(mel)

        return cent_to_f0(
            (
                latent2cents_decoder(
                    self.cent_table,
                    latent,
                    threshold=threshold
                )
            ) if decoder == "argmax" else (
                latent2cents_local_decoder(
                    self.cent_table,
                    self.out_dims,
                    latent,
                    threshold=threshold
                )
            )
        )
|
|
|
|
|
class FCPE_LEGACY(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
input_channel=128, |
|
|
out_dims=360, |
|
|
n_layers=12, |
|
|
n_chans=512, |
|
|
f0_max=1975.5, |
|
|
f0_min=32.70, |
|
|
confidence=False, |
|
|
threshold=0.05, |
|
|
use_input_conv=True |
|
|
): |
|
|
super().__init__() |
|
|
self.n_out = out_dims |
|
|
self.f0_max = f0_max |
|
|
self.f0_min = f0_min |
|
|
self.confidence = confidence |
|
|
self.threshold = threshold |
|
|
self.use_input_conv = use_input_conv |
|
|
|
|
|
self.cent_table_b = torch.Tensor( |
|
|
np.linspace( |
|
|
f0_to_cent(torch.Tensor([f0_min]))[0], |
|
|
f0_to_cent(torch.Tensor([f0_max]))[0], |
|
|
out_dims |
|
|
) |
|
|
) |
|
|
self.register_buffer("cent_table", self.cent_table_b) |
|
|
|
|
|
self.stack = nn.Sequential( |
|
|
nn.Conv1d( |
|
|
input_channel, |
|
|
n_chans, |
|
|
3, |
|
|
1, |
|
|
1 |
|
|
), |
|
|
nn.GroupNorm( |
|
|
4, |
|
|
n_chans |
|
|
), |
|
|
nn.LeakyReLU(), |
|
|
nn.Conv1d( |
|
|
n_chans, |
|
|
n_chans, |
|
|
3, |
|
|
1, |
|
|
1 |
|
|
) |
|
|
) |
|
|
self.decoder = PCmer( |
|
|
num_layers=n_layers, |
|
|
num_heads=8, |
|
|
dim_model=n_chans, |
|
|
dim_keys=n_chans, |
|
|
dim_values=n_chans, |
|
|
residual_dropout=0.1, |
|
|
attention_dropout=0.1 |
|
|
) |
|
|
self.norm = nn.LayerNorm(n_chans) |
|
|
self.dense_out = weight_norm( |
|
|
nn.Linear( |
|
|
n_chans, |
|
|
self.n_out |
|
|
) |
|
|
) |
|
|
|
|
|
def forward(self, mel, return_hz_f0=False, cdecoder="local_argmax", output_interp_target_length=None): |
|
|
x = self.decoder(self.stack(mel.transpose(1, 2)).transpose(1, 2) if self.use_input_conv else mel) |
|
|
x = self.dense_out(self.norm(x)).sigmoid() |
|
|
|
|
|
x = cent_to_f0( |
|
|
( |
|
|
cents_decoder( |
|
|
self.cent_table, |
|
|
x, |
|
|
self.confidence, |
|
|
threshold=self.threshold, |
|
|
mask=True |
|
|
) |
|
|
) if cdecoder == "argmax" else ( |
|
|
cents_local_decoder( |
|
|
self.cent_table, |
|
|
x, |
|
|
self.n_out, |
|
|
self.confidence, |
|
|
threshold=self.threshold, |
|
|
mask=True |
|
|
) |
|
|
) |
|
|
) |
|
|
|
|
|
x = (1 + x / 700).log() if not return_hz_f0 else x |
|
|
|
|
|
if output_interp_target_length is not None: |
|
|
x = F.interpolate( |
|
|
torch.where(x == 0, float("nan"), x).transpose(1, 2), |
|
|
size=int(output_interp_target_length), |
|
|
mode="linear" |
|
|
).transpose(1, 2) |
|
|
|
|
|
x = torch.where(x.isnan(), float(0.0), x) |
|
|
|
|
|
return x |
|
|
|
|
|
def gaussian_blurred_cent(self, cents): |
|
|
B, N, _ = cents.size() |
|
|
|
|
|
return ( |
|
|
-(self.cent_table[None, None, :].expand(B, N, -1) - cents).square() / 1250 |
|
|
).exp() * (cents > 0.1) & ( |
|
|
cents < (1200.0 * np.log2(self.f0_max / 10.0)) |
|
|
).float() |
|
|
|
|
|
class InferCFNaiveMelPE(torch.nn.Module):
    """Inference wrapper around ``CFNaiveMelPE`` built from checkpoint data."""

    def __init__(
        self,
        args,        # checkpoint config (DotDict) with .mel and .model sections
        state_dict   # backbone weights
    ):
        super().__init__()

        self.model = CFNaiveMelPE(
            input_channels=args.mel.num_mels,
            out_dims=args.model.out_dims,
            hidden_dims=args.model.hidden_dims,
            n_layers=args.model.n_layers,
            n_heads=args.model.n_heads,
            f0_max=args.model.f0_max,
            f0_min=args.model.f0_min,
            use_fa_norm=args.model.use_fa_norm,
            conv_only=args.model.conv_only,
            conv_dropout=args.model.conv_dropout,
            atten_dropout=args.model.atten_dropout,
            use_harmonic_emb=False
        )
        self.model.load_state_dict(state_dict)
        self.model.eval()

        # Non-persistent marker buffer so callers can query device/dtype.
        self.register_buffer("tensor_device_marker", torch.tensor(1.0).float(), persistent=False)

    def forward(self, mel, decoder_mode = "local_argmax", threshold = 0.006):
        """Return raw f0 in Hz, shape (B, T, 1), for mel input (B, T, C).

        The previous implementation stacked a singleton harmonic axis (K=1)
        and immediately flattened it again with einops; both reshapes were
        identities and have been removed.
        """
        with torch.no_grad():
            return self.model.infer(mel, decoder=decoder_mode, threshold=threshold)

    def infer(
        self,
        mel,
        decoder_mode = "local_argmax",
        threshold = 0.006,
        f0_min = None,                        # voicing floor in Hz (defaults to the backbone's f0_min)
        f0_max = None,                        # optional upper clip in Hz
        interp_uv = False,                    # interpolate across unvoiced regions
        output_interp_target_length = None,   # optional output length in frames
        return_uv = False                     # also return the unvoiced mask
    ):
        """Estimate f0 with optional voicing mask, interpolation and clipping.

        Returns (B, T', 1) f0 in Hz, or ``(f0, uv)`` when ``return_uv``.
        """
        f0 = self.__call__(mel, decoder_mode, threshold)
        f0_for_uv = f0

        # Fall back to the backbone's configured minimum so callers may omit
        # f0_min (previously a None comparison raised here).
        if f0_min is None: f0_min = self.model.f0_min

        uv = (f0_for_uv < f0_min).type(f0_for_uv.dtype)
        f0 = f0 * (1 - uv)

        if interp_uv:
            f0 = batch_interp_with_replacement_detach(
                uv.squeeze(-1).bool(),
                f0.squeeze(-1)
            ).unsqueeze(-1)

        if f0_max is not None: f0[f0 > f0_max] = f0_max

        if output_interp_target_length is not None:
            # Zeros (unvoiced) become NaN for the linear resize and are zeroed
            # below, so they are not averaged into neighbouring voiced values.
            f0 = F.interpolate(
                torch.where(f0 == 0, float("nan"), f0).transpose(1, 2),
                size=int(output_interp_target_length),
                mode="linear"
            ).transpose(1, 2)

        f0 = torch.where(f0.isnan(), float(0.0), f0)

        if return_uv:
            # Resize the mask only when a target length was requested
            # (previously int(None) raised a TypeError here).
            if output_interp_target_length is not None:
                uv = F.interpolate(uv.transpose(1, 2), size=int(output_interp_target_length), mode="nearest").transpose(1, 2)
            return f0, uv
        else:
            return f0
|
|
|
|
|
class FCPEInfer_LEGACY:
    """Loads and runs the legacy FCPE model (PyTorch checkpoint or ONNX)."""

    def __init__(
        self,
        configs,
        model_path,
        device=None,          # falls back to cuda when available, else cpu
        dtype=torch.float32,
        providers=None,       # ONNX Runtime execution providers
        onnx=False,           # run via onnxruntime instead of PyTorch
        f0_min=50,
        f0_max=1100
    ):
        if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        self.dtype = dtype
        self.onnx = onnx
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)

        if self.onnx:
            sess_options = ort.SessionOptions()
            sess_options.log_severity_level = 3
            self.model = ort.InferenceSession(decrypt_model(configs, model_path), sess_options=sess_options, providers=providers)
        else:
            ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
            self.args = DotDict(ckpt["config"])
            model = FCPE_LEGACY(
                input_channel=self.args.model.input_channel,
                out_dims=self.args.model.out_dims,
                n_layers=self.args.model.n_layers,
                n_chans=self.args.model.n_chans,
                f0_max=self.f0_max,
                f0_min=self.f0_min,
                confidence=self.args.model.confidence
            )
            model.to(self.device).to(self.dtype)
            model.load_state_dict(ckpt["model"])
            model.eval()
            self.model = model

    @torch.no_grad()
    def __call__(self, audio, sr, threshold=0.05, p_len=None):
        """Compute per-frame f0 (Hz) from raw ``audio`` at sample rate ``sr``.

        ``threshold`` gates unvoiced frames; ``p_len`` optionally resamples
        the output to that many frames.
        """
        if self.onnx:
            # Refresh every call so a changed threshold takes effect — it was
            # previously cached on the first call only, while the PyTorch
            # branch updated the model threshold on every call.
            self.numpy_threshold = np.array(threshold, dtype=np.float32)
        else:
            self.model.threshold = threshold

        mel = self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype)

        if self.onnx:
            return torch.as_tensor(
                self.model.run(
                    [self.model.get_outputs()[0].name],
                    {
                        self.model.get_inputs()[0].name: mel.detach().cpu().numpy(),
                        self.model.get_inputs()[1].name: self.numpy_threshold
                    }
                )[0],
                dtype=self.dtype,
                device=self.device
            )
        else:
            return self.model(
                mel=mel,
                return_hz_f0=True,
                output_interp_target_length=p_len
            )
|
|
|
|
|
class FCPEInfer:
    """Loads and runs the current FCPE model (PyTorch checkpoint or ONNX)."""

    def __init__(
        self,
        configs,
        model_path,
        device=None,          # falls back to cuda when available, else cpu
        dtype=torch.float32,
        providers=None,       # ONNX Runtime execution providers
        onnx=False,           # run via onnxruntime instead of PyTorch
        f0_min=50,
        f0_max=1100
    ):
        if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        self.dtype = dtype
        self.onnx = onnx
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)

        if self.onnx:
            sess_options = ort.SessionOptions()
            sess_options.log_severity_level = 3
            self.model = ort.InferenceSession(decrypt_model(configs, model_path), sess_options=sess_options, providers=providers)
        else:
            ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
            # Dropout is irrelevant at inference time; zero it before building.
            ckpt["config_dict"]["model"]["conv_dropout"] = ckpt["config_dict"]["model"]["atten_dropout"] = 0.0
            self.args = DotDict(ckpt["config_dict"])
            model = InferCFNaiveMelPE(self.args, ckpt["model"])
            self.model = model.to(device).to(self.dtype).eval()

    @torch.no_grad()
    def __call__(self, audio, sr, threshold=0.05, p_len=None):
        """Compute per-frame f0 (Hz) from raw ``audio`` at sample rate ``sr``.

        ``threshold`` gates unvoiced frames; ``p_len`` optionally resamples
        the output to that many frames.
        """
        if self.onnx:
            # Refresh every call so a changed threshold takes effect — it was
            # previously cached on the first call only and later values were
            # silently ignored.
            self.numpy_threshold = np.array(threshold, dtype=np.float32)

        mel = self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype)

        if self.onnx:
            return torch.as_tensor(
                self.model.run(
                    [self.model.get_outputs()[0].name],
                    {
                        self.model.get_inputs()[0].name: mel.detach().cpu().numpy(),
                        self.model.get_inputs()[1].name: self.numpy_threshold
                    }
                )[0],
                dtype=self.dtype,
                device=self.device
            )
        else:
            return self.model.infer(
                mel,
                threshold=threshold,
                f0_min=self.f0_min,
                f0_max=self.f0_max,
                output_interp_target_length=p_len
            )
|
|
|
|
|
class FCPE:
    """Front-end that selects the legacy or current FCPE inferencer and
    exposes ``compute_f0`` over raw waveforms."""

    def __init__(
        self,
        configs,
        model_path,
        hop_length=512,
        f0_min=50,
        f0_max=1100,
        dtype=torch.float32,
        device=None,
        sample_rate=16000,
        threshold=0.05,
        providers=None,
        onnx=False,
        legacy=False
    ):
        # Pick the inferencer class, then build it.
        self.model = FCPEInfer_LEGACY if legacy else FCPEInfer
        self.fcpe = self.model(configs, model_path, device=device, dtype=dtype, providers=providers, onnx=onnx, f0_min=f0_min, f0_max=f0_max)

        self.hop_length = hop_length
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.threshold = threshold
        self.sample_rate = sample_rate
        self.dtype = dtype
        self.legacy = legacy

    def compute_f0(self, wav, p_len=None):
        """Return the f0 contour of ``wav`` as a numpy array of ``p_len`` frames."""
        signal = torch.FloatTensor(wav).to(self.dtype).to(self.device)
        if p_len is None:
            p_len = signal.shape[0] // self.hop_length

        f0 = self.fcpe(signal, sr=self.sample_rate, threshold=self.threshold, p_len=p_len)
        # Squeeze (1, T, 1) outputs down to a 1-D contour.
        f0 = f0[0, :, 0] if f0.dim() != 1 else f0[:]

        # All-zero output means no voiced frames were detected.
        if torch.all(f0 == 0):
            return f0.cpu().numpy() if p_len is None else np.zeros(p_len)
        return f0.cpu().numpy()