NeoPy's picture
Update infer/lib/predictors/FCPE/FCPE.py
19c48ba verified
import os
import sys
import torch
import numpy as np
import torch.nn as nn
import onnxruntime as ort
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils.parametrizations import weight_norm
sys.path.append(os.getcwd())
os.environ["LRU_CACHE_CAPACITY"] = "3"
from infer.lib.predictors.FCPE.wav2mel import Wav2Mel
from infer.lib.predictors.FCPE.encoder import EncoderLayer, ConformerNaiveEncoder
from infer.lib.predictors.FCPE.utils import batch_interp_with_replacement_detach, decrypt_model, DotDict
@torch.no_grad()
def cent_to_f0(cent):
    """Convert a pitch expressed in cents above 10 Hz into a frequency in Hz.

    Computes ``10 * 2 ** (cent / 1200)``, i.e. 1200 cents per octave with a
    10 Hz reference. A value of ``-inf`` cents maps to 0 Hz.
    """
    octaves_above_reference = cent / 1200
    return 10 * 2 ** octaves_above_reference
@torch.no_grad()
def f0_to_cent(f0):
    """Convert a frequency tensor in Hz into cents above a 10 Hz reference.

    Computes ``1200 * log2(f0 / 10)``; ``f0`` must be a tensor because the
    tensor method ``log2`` is called on it.
    """
    ratio_to_reference = f0 / 10
    return 1200 * ratio_to_reference.log2()
@torch.no_grad()
def latent2cents_decoder(cent_table, y, threshold = 0.05, mask = True):
    """Decode per-bin activations into cents using a weighted mean over all bins.

    Args:
        cent_table: 1-D tensor holding the cent value of every output bin.
        y: activations of shape (B, N, bins); the last axis must match
            ``cent_table``.
        threshold: frames whose peak activation is ``<= threshold`` are
            treated as unconfident.
        mask: when true, unconfident frames are multiplied by ``-inf``.

    Returns:
        Tensor of shape (B, N, 1) with the decoded cents.
    """
    # Tensors living on a "privateuseone" device are processed on the CPU.
    if str(y.device).startswith("privateuseone"):
        cent_table, y = cent_table.cpu(), y.cpu()

    batch, frames, _ = y.size()
    bin_cents = cent_table[None, None, :].expand(batch, frames, -1)
    weighted_sum = (bin_cents * y).sum(dim=-1, keepdim=True)
    activation_sum = y.sum(dim=-1, keepdim=True)
    cents = weighted_sum / activation_sum

    if not mask:
        return cents

    peak = y.max(dim=-1, keepdim=True)[0]
    gate = torch.ones_like(peak)
    gate[peak <= threshold] = float("-INF")
    return cents * gate
@torch.no_grad()
def latent2cents_local_decoder(cent_table, out_dims, y, threshold = 0.05, mask = True):
    """Decode activations into cents using only a 9-bin window around the peak.

    Args:
        cent_table: 1-D tensor holding the cent value of every output bin.
        out_dims: number of output bins; window indices are limited to
            ``[0, out_dims - 1]``.
        y: activations of shape (B, N, bins).
        threshold: frames whose peak activation is ``<= threshold`` are
            treated as unconfident.
        mask: when true, unconfident frames are multiplied by ``-inf``.

    Returns:
        Tensor of shape (B, N, 1) with the decoded cents.
    """
    # Tensors living on a "privateuseone" device are processed on the CPU.
    if str(y.device).startswith("privateuseone"):
        cent_table, y = cent_table.cpu(), y.cpu()

    batch, frames, _ = y.size()
    bin_cents = cent_table[None, None, :].expand(batch, frames, -1)
    peak, peak_index = y.max(dim=-1, keepdim=True)

    # Window of peak-4 .. peak+4; indices falling off either end are pinned
    # to the first/last bin (so edge bins may appear several times).
    offsets = torch.arange(0, 9).to(peak_index.device)
    window = (offsets + (peak_index - 4)).clamp(0, out_dims - 1)

    local_y = y.gather(-1, window)
    local_cents = bin_cents.gather(-1, window)
    cents = (local_cents * local_y).sum(dim=-1, keepdim=True) / local_y.sum(dim=-1, keepdim=True)

    if not mask:
        return cents

    gate = torch.ones_like(peak)
    gate[peak <= threshold] = float("-INF")
    return cents * gate
def cents_decoder(cent_table, y, confidence, threshold = 0.05, mask=True):
    """Decode per-bin activations into cents using a weighted mean over all bins.

    Args:
        cent_table: 1-D tensor holding the cent value of every output bin.
        y: activations of shape (B, N, bins).
        confidence: when truthy, the per-frame peak activation is returned
            alongside the decoded cents.
        threshold: frames whose peak activation is ``<= threshold`` are
            treated as unconfident.
        mask: when true, unconfident frames are multiplied by ``-inf``.

    Returns:
        ``rtn`` of shape (B, N, 1), or ``(rtn, confident)`` when
        ``confidence`` is truthy.
    """
    # Tensors living on a "privateuseone" device are processed on the CPU.
    if str(y.device).startswith("privateuseone"):
        cent_table = cent_table.cpu()
        y = y.cpu()

    B, N, _ = y.size()
    ci = cent_table[None, None, :].expand(B, N, -1)
    rtn = (ci * y).sum(dim=-1, keepdim=True) / y.sum(dim=-1, keepdim=True)

    # Computed unconditionally: it is needed for the return value even when
    # masking is disabled (previously this raised UnboundLocalError for
    # confidence=True, mask=False).
    confident = y.max(dim=-1, keepdim=True)[0]

    if mask:
        confident_mask = torch.ones_like(confident)
        confident_mask[confident <= threshold] = float("-INF")
        rtn = rtn * confident_mask

    return (rtn, confident) if confidence else rtn
def cents_local_decoder(cent_table, y, n_out, confidence, threshold = 0.05, mask=True):
    """Decode activations into cents using only a 9-bin window around the peak.

    Args:
        cent_table: 1-D tensor holding the cent value of every output bin.
        y: activations of shape (B, N, bins).
        n_out: number of output bins; window indices are clamped to
            ``[0, n_out - 1]``.
        confidence: when truthy, the per-frame peak activation is returned
            alongside the decoded cents.
        threshold: frames whose peak activation is ``<= threshold`` are
            treated as unconfident.
        mask: when true, unconfident frames are multiplied by ``-inf``.

    Returns:
        Decoded cents of shape (B, N, 1), or ``(cents, peak)`` when
        ``confidence`` is truthy.
    """
    # Tensors living on a "privateuseone" device are processed on the CPU.
    if str(y.device).startswith("privateuseone"):
        cent_table, y = cent_table.cpu(), y.cpu()

    batch, frames, _ = y.size()
    peak, peak_index = y.max(dim=-1, keepdim=True)

    # Window of peak-4 .. peak+4, pinned to the valid bin range.
    offsets = torch.arange(0, 9).to(peak_index.device)
    window = torch.clamp(offsets + (peak_index - 4), 0, n_out - 1)

    local_y = y.gather(-1, window)
    local_cents = cent_table[None, None, :].expand(batch, frames, -1).gather(-1, window)
    cents = (local_cents * local_y).sum(dim=-1, keepdim=True) / local_y.sum(dim=-1, keepdim=True)

    if mask:
        gate = torch.ones_like(peak)
        gate[peak <= threshold] = float("-INF")
        cents = cents * gate

    if confidence:
        return cents, peak
    return cents
class PCmer(nn.Module):
    """Stack of ``EncoderLayer`` modules applied one after another.

    The hyper-parameters are stored as attributes on this module and the
    module itself is handed to every ``EncoderLayer`` constructor.
    """

    def __init__(
        self,
        num_layers,
        num_heads,
        dim_model,
        dim_keys,
        dim_values,
        residual_dropout,
        attention_dropout
    ):
        super().__init__()
        # Store the configuration first: each EncoderLayer receives `self`.
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_values = dim_values
        self.dim_keys = dim_keys
        self.residual_dropout = residual_dropout
        self.attention_dropout = attention_dropout

        layers = []
        for _ in range(num_layers):
            layers.append(EncoderLayer(self))
        self._layers = nn.ModuleList(layers)

    def forward(self, phone, mask=None):
        """Pass ``phone`` through every layer in order, forwarding ``mask`` to each."""
        hidden = phone
        for encoder_layer in self._layers:
            hidden = encoder_layer(hidden, mask)
        return hidden
class CFNaiveMelPE(nn.Module):
    """Pitch estimator: conv front-end, ``ConformerNaiveEncoder``, per-bin sigmoid head.

    ``forward`` returns per-bin activations; ``infer`` additionally decodes
    them into a frequency in Hz through the cent table.
    """

    def __init__(
        self,
        input_channels,
        out_dims,
        hidden_dims = 512,
        n_layers = 6,
        n_heads = 8,
        f0_max = 1975.5,
        f0_min = 32.70,
        use_fa_norm = False,
        conv_only = False,
        conv_dropout = 0,
        atten_dropout = 0,
        use_harmonic_emb = False
    ):
        super().__init__()
        self.input_channels = input_channels
        self.out_dims = out_dims
        self.hidden_dims = hidden_dims
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.f0_max = f0_max
        self.f0_min = f0_min
        self.use_fa_norm = use_fa_norm

        # Optional embedding (9 entries) added to the features in forward().
        if use_harmonic_emb:
            self.harmonic_emb = nn.Embedding(9, hidden_dims)
        else:
            self.harmonic_emb = None

        # Two kernel-3 / stride-1 / padding-1 convolutions, so the frame
        # count is unchanged.
        self.input_stack = nn.Sequential(
            nn.Conv1d(input_channels, hidden_dims, 3, 1, 1),
            nn.GroupNorm(4, hidden_dims),
            nn.LeakyReLU(),
            nn.Conv1d(hidden_dims, hidden_dims, 3, 1, 1)
        )
        self.net = ConformerNaiveEncoder(
            num_layers=n_layers,
            num_heads=n_heads,
            dim_model=hidden_dims,
            use_norm=use_fa_norm,
            conv_only=conv_only,
            conv_dropout=conv_dropout,
            atten_dropout=atten_dropout
        )
        self.norm = nn.LayerNorm(hidden_dims)
        self.output_proj = weight_norm(nn.Linear(hidden_dims, out_dims))

        # Cent value of every output bin, evenly spaced between f0_min and f0_max.
        lowest_cent = f0_to_cent(torch.Tensor([f0_min]))[0]
        highest_cent = f0_to_cent(torch.Tensor([f0_max]))[0]
        self.cent_table_b = torch.linspace(lowest_cent, highest_cent, out_dims).detach()

        # Cent value of f0_max, kept as a scalar buffer.
        max_cent = 1200 * torch.Tensor([self.f0_max / 10.]).log2()
        self.gaussian_blurred_cent_mask_b = max_cent[0].detach()

        self.register_buffer("cent_table", self.cent_table_b)
        self.register_buffer("gaussian_blurred_cent_mask", self.gaussian_blurred_cent_mask_b)

    def forward(self, x, _h_emb=None):
        """Return sigmoid activations over the output bins.

        ``x`` is transposed before the Conv1d stack, so its last axis must
        hold the ``input_channels`` features. ``_h_emb`` selects the harmonic
        embedding row (row 0 when ``None``) and is only used when the
        embedding was enabled at construction.
        """
        x = self.input_stack(x.transpose(-1, -2)).transpose(-1, -2)

        if self.harmonic_emb is not None:
            emb_index = 0 if _h_emb is None else int(_h_emb)
            x += self.harmonic_emb(torch.LongTensor([emb_index]).to(x.device))

        encoded = self.net(x)
        logits = self.output_proj(self.norm(encoded))
        return logits.sigmoid()

    @torch.no_grad()
    def infer(self, mel, decoder = "local_argmax", threshold = 0.05):
        """Run ``forward`` and decode the activations into a frequency in Hz.

        ``decoder == "argmax"`` uses the weighted mean over all bins; any
        other value uses the 9-bin local window decoder.
        """
        latent = self.forward(mel)

        if decoder == "argmax":
            cents = latent2cents_decoder(
                self.cent_table,
                latent,
                threshold=threshold
            )
        else:
            cents = latent2cents_local_decoder(
                self.cent_table,
                self.out_dims,
                latent,
                threshold=threshold
            )

        return cent_to_f0(cents)
class FCPE_LEGACY(nn.Module):
def __init__(
self,
input_channel=128,
out_dims=360,
n_layers=12,
n_chans=512,
f0_max=1975.5,
f0_min=32.70,
confidence=False,
threshold=0.05,
use_input_conv=True
):
super().__init__()
self.n_out = out_dims
self.f0_max = f0_max
self.f0_min = f0_min
self.confidence = confidence
self.threshold = threshold
self.use_input_conv = use_input_conv
self.cent_table_b = torch.Tensor(
np.linspace(
f0_to_cent(torch.Tensor([f0_min]))[0],
f0_to_cent(torch.Tensor([f0_max]))[0],
out_dims
)
)
self.register_buffer("cent_table", self.cent_table_b)
self.stack = nn.Sequential(
nn.Conv1d(
input_channel,
n_chans,
3,
1,
1
),
nn.GroupNorm(
4,
n_chans
),
nn.LeakyReLU(),
nn.Conv1d(
n_chans,
n_chans,
3,
1,
1
)
)
self.decoder = PCmer(
num_layers=n_layers,
num_heads=8,
dim_model=n_chans,
dim_keys=n_chans,
dim_values=n_chans,
residual_dropout=0.1,
attention_dropout=0.1
)
self.norm = nn.LayerNorm(n_chans)
self.dense_out = weight_norm(
nn.Linear(
n_chans,
self.n_out
)
)
def forward(self, mel, return_hz_f0=False, cdecoder="local_argmax", output_interp_target_length=None):
x = self.decoder(self.stack(mel.transpose(1, 2)).transpose(1, 2) if self.use_input_conv else mel)
x = self.dense_out(self.norm(x)).sigmoid()
x = cent_to_f0(
(
cents_decoder(
self.cent_table,
x,
self.confidence,
threshold=self.threshold,
mask=True
)
) if cdecoder == "argmax" else (
cents_local_decoder(
self.cent_table,
x,
self.n_out,
self.confidence,
threshold=self.threshold,
mask=True
)
)
)
x = (1 + x / 700).log() if not return_hz_f0 else x
if output_interp_target_length is not None:
x = F.interpolate(
torch.where(x == 0, float("nan"), x).transpose(1, 2),
size=int(output_interp_target_length),
mode="linear"
).transpose(1, 2)
x = torch.where(x.isnan(), float(0.0), x)
return x
def gaussian_blurred_cent(self, cents):
B, N, _ = cents.size()
return (
-(self.cent_table[None, None, :].expand(B, N, -1) - cents).square() / 1250
).exp() * (cents > 0.1) & (
cents < (1200.0 * np.log2(self.f0_max / 10.0))
).float()
class InferCFNaiveMelPE(torch.nn.Module):
    """Inference wrapper that builds a ``CFNaiveMelPE`` from a config and a state dict."""

    def __init__(
        self,
        args,
        state_dict
    ):
        """Build the model from ``args.mel`` / ``args.model`` and load ``state_dict``."""
        super().__init__()
        self.model = CFNaiveMelPE(
            input_channels=args.mel.num_mels,
            out_dims=args.model.out_dims,
            hidden_dims=args.model.hidden_dims,
            n_layers=args.model.n_layers,
            n_heads=args.model.n_heads,
            f0_max=args.model.f0_max,
            f0_min=args.model.f0_min,
            use_fa_norm=args.model.use_fa_norm,
            conv_only=args.model.conv_only,
            conv_dropout=args.model.conv_dropout,
            atten_dropout=args.model.atten_dropout,
            use_harmonic_emb=False
        )
        self.model.load_state_dict(state_dict)
        self.model.eval()
        # Non-persistent buffer: follows .to(...) but is not saved in the state dict.
        self.register_buffer("tensor_device_marker", torch.tensor(1.0).float(), persistent=False)

    def forward(self, mel, decoder_mode = "local_argmax", threshold = 0.006):
        """Return the decoded f0 for ``mel`` without tracking gradients."""
        with torch.no_grad():
            # A trailing singleton axis K is added and folded into the batch
            # axis, then unfolded again after decoding.
            mels = rearrange(torch.stack([mel], -1), "B T C K -> (B K) T C")
            f0s = rearrange(self.model.infer(mels, decoder=decoder_mode, threshold=threshold), "(B K) T 1 -> B T (K 1)", K=1)
        return f0s

    def infer(
        self,
        mel,
        decoder_mode = "local_argmax",
        threshold = 0.006,
        f0_min = None,
        f0_max = None,
        interp_uv = False,
        output_interp_target_length = None,
        return_uv = False
    ):
        """Estimate f0 and post-process it.

        Args:
            mel: input features passed to ``forward``.
            decoder_mode: decoder name forwarded to the model.
            threshold: confidence threshold forwarded to the model.
            f0_min: frames below this frequency are marked unvoiced and set
                to 0; defaults to the wrapped model's ``f0_min`` when ``None``.
            f0_max: optional upper clamp applied to f0.
            interp_uv: when true, unvoiced frames are filled through
                ``batch_interp_with_replacement_detach``.
            output_interp_target_length: when given, f0 (and the returned uv
                mask) are resampled to this many frames.
            return_uv: when true, ``(f0, uv)`` is returned instead of ``f0``.

        Returns:
            ``f0``, or ``(f0, uv)`` when ``return_uv`` is true.
        """
        f0 = self.__call__(mel, decoder_mode, threshold)

        # The signature advertises None as the default; comparing a tensor
        # with None raises, so fall back to the model's own lower bound.
        if f0_min is None:
            f0_min = self.model.f0_min

        uv = (f0 < f0_min).type(f0.dtype)
        f0 = f0 * (1 - uv)

        if interp_uv:
            f0 = batch_interp_with_replacement_detach(
                uv.squeeze(-1).bool(),
                f0.squeeze(-1)
            ).unsqueeze(-1)

        if f0_max is not None:
            f0[f0 > f0_max] = f0_max

        if output_interp_target_length is not None:
            # Zeros become NaN before interpolation and NaN becomes zero
            # afterwards, so zero-valued frames are not blended into
            # their neighbours.
            f0 = F.interpolate(
                torch.where(f0 == 0, float("nan"), f0).transpose(1, 2),
                size=int(output_interp_target_length),
                mode="linear"
            ).transpose(1, 2)
            f0 = torch.where(f0.isnan(), float(0.0), f0)

        if not return_uv:
            return f0

        # Without a target length there is nothing to resample to; return
        # the mask at its native frame rate instead of failing on int(None).
        if output_interp_target_length is not None:
            uv = F.interpolate(uv.transpose(1, 2), size=int(output_interp_target_length), mode="nearest").transpose(1, 2)
        return f0, uv
class FCPEInfer_LEGACY:
    """Runs the legacy FCPE model through either ONNX Runtime or PyTorch."""

    def __init__(
        self,
        configs,
        model_path,
        device=None,
        dtype=torch.float32,
        providers=None,
        onnx=False,
        f0_min=50,
        f0_max=1100
    ):
        """Load the model.

        Args:
            configs: passed to ``decrypt_model`` in ONNX mode.
            model_path: path of the ONNX model or torch checkpoint.
            device: torch device; defaults to CUDA when available, else CPU.
            dtype: dtype used for the mel features and the torch model.
            providers: ONNX Runtime execution providers.
            onnx: select the ONNX Runtime backend.
            f0_min, f0_max: frequency range handed to ``FCPE_LEGACY``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        self.dtype = dtype
        self.onnx = onnx
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)

        if self.onnx:
            sess_options = ort.SessionOptions()
            sess_options.log_severity_level = 3
            self.model = ort.InferenceSession(decrypt_model(configs, model_path), sess_options=sess_options, providers=providers)
        else:
            ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
            self.args = DotDict(ckpt["config"])
            model = FCPE_LEGACY(
                input_channel=self.args.model.input_channel,
                out_dims=self.args.model.out_dims,
                n_layers=self.args.model.n_layers,
                n_chans=self.args.model.n_chans,
                f0_max=self.f0_max,
                f0_min=self.f0_min,
                confidence=self.args.model.confidence
            )
            model.to(self.device).to(self.dtype)
            model.load_state_dict(ckpt["model"])
            model.eval()
            self.model = model

    @torch.no_grad()
    def __call__(self, audio, sr, threshold=0.05, p_len=None):
        """Estimate f0 for a 1-D ``audio`` tensor sampled at ``sr``.

        ``threshold`` is honoured on every call for both backends. ``p_len``
        is only used by the torch backend, as the interpolation target length.
        """
        mel = self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype)

        if self.onnx:
            # Refresh on every call: caching the first value made later calls
            # with a different threshold silently reuse the stale one.
            self.numpy_threshold = np.array(threshold, dtype=np.float32)
            inputs = self.model.get_inputs()
            return torch.as_tensor(
                self.model.run(
                    [self.model.get_outputs()[0].name],
                    {
                        inputs[0].name: mel.detach().cpu().numpy(),
                        inputs[1].name: self.numpy_threshold
                    }
                )[0],
                dtype=self.dtype,
                device=self.device
            )

        self.model.threshold = threshold
        return self.model(
            mel=mel,
            return_hz_f0=True,
            output_interp_target_length=p_len
        )
class FCPEInfer:
    """Runs the current FCPE model through either ONNX Runtime or PyTorch."""

    def __init__(
        self,
        configs,
        model_path,
        device=None,
        dtype=torch.float32,
        providers=None,
        onnx=False,
        f0_min=50,
        f0_max=1100
    ):
        """Load the model.

        Args:
            configs: passed to ``decrypt_model`` in ONNX mode.
            model_path: path of the ONNX model or torch checkpoint.
            device: torch device; defaults to CUDA when available, else CPU.
            dtype: dtype used for the mel features and the torch model.
            providers: ONNX Runtime execution providers.
            onnx: select the ONNX Runtime backend.
            f0_min, f0_max: post-processing bounds used by the torch backend.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        self.dtype = dtype
        self.onnx = onnx
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)

        if self.onnx:
            sess_options = ort.SessionOptions()
            sess_options.log_severity_level = 3
            self.model = ort.InferenceSession(decrypt_model(configs, model_path), sess_options=sess_options, providers=providers)
        else:
            ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
            # Both dropout rates are forced to zero in the loaded config.
            ckpt["config_dict"]["model"]["conv_dropout"] = ckpt["config_dict"]["model"]["atten_dropout"] = 0.0
            self.args = DotDict(ckpt["config_dict"])
            model = InferCFNaiveMelPE(self.args, ckpt["model"])
            self.model = model.to(device).to(self.dtype).eval()

    @torch.no_grad()
    def __call__(self, audio, sr, threshold=0.05, p_len=None):
        """Estimate f0 for a 1-D ``audio`` tensor sampled at ``sr``.

        ``threshold`` is honoured on every call for both backends. ``p_len``
        and the f0 bounds are only used by the torch backend.
        """
        mel = self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype)

        if self.onnx:
            # Refresh on every call: caching the first value made later calls
            # with a different threshold silently reuse the stale one.
            self.numpy_threshold = np.array(threshold, dtype=np.float32)
            inputs = self.model.get_inputs()
            return torch.as_tensor(
                self.model.run(
                    [self.model.get_outputs()[0].name],
                    {
                        inputs[0].name: mel.detach().cpu().numpy(),
                        inputs[1].name: self.numpy_threshold
                    }
                )[0],
                dtype=self.dtype,
                device=self.device
            )

        return self.model.infer(
            mel,
            threshold=threshold,
            f0_min=self.f0_min,
            f0_max=self.f0_max,
            output_interp_target_length=p_len
        )
class FCPE:
    """Front-end that turns a raw waveform into a NumPy f0 track."""

    def __init__(
        self,
        configs,
        model_path,
        hop_length=512,
        f0_min=50,
        f0_max=1100,
        dtype=torch.float32,
        device=None,
        sample_rate=16000,
        threshold=0.05,
        providers=None,
        onnx=False,
        legacy=False
    ):
        """Create the underlying inferencer (legacy or current) and store the settings."""
        if legacy:
            self.model = FCPEInfer_LEGACY
        else:
            self.model = FCPEInfer

        self.fcpe = self.model(
            configs,
            model_path,
            device=device,
            dtype=dtype,
            providers=providers,
            onnx=onnx,
            f0_min=f0_min,
            f0_max=f0_max
        )
        self.hop_length = hop_length
        # Same fallback the inferencer applies when no device is given.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.threshold = threshold
        self.sample_rate = sample_rate
        self.dtype = dtype
        self.legacy = legacy

    def compute_f0(self, wav, p_len=None):
        """Return the f0 track of ``wav`` as a NumPy array.

        ``p_len`` defaults to ``len(wav) // hop_length``. When every
        estimated frame is zero, a zero array of length ``p_len`` is returned.
        """
        signal = torch.FloatTensor(wav).to(self.dtype).to(self.device)
        if p_len is None:
            p_len = signal.shape[0] // self.hop_length

        pitch = self.fcpe(signal, sr=self.sample_rate, threshold=self.threshold, p_len=p_len)
        # A 1-D result is used as is; otherwise take batch 0, channel 0.
        if pitch.dim() != 1:
            pitch = pitch[0, :, 0]

        if torch.all(pitch == 0):
            return np.zeros(p_len)
        return pitch.cpu().numpy()