HQ-SVC / utils /rmvpe /utils.py
shawnpi's picture
Upload 753 files
1cd928a verified
import sys
import numpy as np
import librosa
import torch
from functools import reduce
from .constants import *
from torch.nn.modules.module import _addindent
def cycle(iterable):
while True:
for item in iterable:
yield item
def summary(model, file=sys.stdout):
def repr(model):
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = model.extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split('\n')
child_lines = []
total_params = 0
for key, module in model._modules.items():
mod_str, num_params = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append('(' + key + '): ' + mod_str)
total_params += num_params
lines = extra_lines + child_lines
for name, p in model._parameters.items():
if hasattr(p, 'shape'):
total_params += reduce(lambda x, y: x * y, p.shape)
main_str = model._get_name() + '('
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += '\n ' + '\n '.join(lines) + '\n'
main_str += ')'
if file is sys.stdout:
main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
else:
main_str += ', {:,} params'.format(total_params)
return main_str, total_params
string, count = repr(model)
if file is not None:
if isinstance(file, str):
file = open(file, 'w')
print(string, file=file)
file.flush()
return count
def to_local_average_cents(salience, center=None, thred=0.03):
"""
find the weighted average cents near the argmax bin
"""
if not hasattr(to_local_average_cents, 'cents_mapping'):
# the bin number-to-cents mapping
to_local_average_cents.cents_mapping = (
20 * np.arange(N_CLASS) + CONST)
if salience.ndim == 1:
if center is None:
center = int(np.argmax(salience))
start = max(0, center - 4)
end = min(len(salience), center + 5)
salience = salience[start:end]
product_sum = np.sum(
salience * to_local_average_cents.cents_mapping[start:end])
weight_sum = np.sum(salience)
return product_sum / weight_sum if np.max(salience) > thred else 0
if salience.ndim == 2:
return np.array([to_local_average_cents(salience[i, :], None, thred) for i in
range(salience.shape[0])])
raise Exception("label should be either 1d or 2d ndarray")
def to_viterbi_cents(salience, thred=0.03):
# Create viterbi transition matrix
if not hasattr(to_viterbi_cents, 'transition'):
xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS))
transition = np.maximum(30 - abs(xx - yy), 0)
transition = transition / transition.sum(axis=1, keepdims=True)
to_viterbi_cents.transition = transition
# Convert to probability
prob = salience.T
prob = prob / prob.sum(axis=0)
# Perform viterbi decoding
path = librosa.sequence.viterbi(prob, to_viterbi_cents.transition).astype(np.int64)
return np.array([to_local_average_cents(salience[i, :], path[i], thred) for i in
range(len(path))])
def to_local_average_f0(hidden, center=None, thred=0.03):
idx = torch.arange(N_CLASS, device=hidden.device)[None, None, :] # [B=1, T=1, N]
idx_cents = idx * 20 + CONST # [B=1, N]
if center is None:
center = torch.argmax(hidden, dim=2, keepdim=True) # [B, T, 1]
start = torch.clip(center - 4, min=0) # [B, T, 1]
end = torch.clip(center + 5, max=N_CLASS) # [B, T, 1]
idx_mask = (idx >= start) & (idx < end) # [B, T, N]
weights = hidden * idx_mask # [B, T, N]
product_sum = torch.sum(weights * idx_cents, dim=2) # [B, T]
weight_sum = torch.sum(weights, dim=2) # [B, T]
cents = product_sum / (weight_sum + (weight_sum == 0)) # avoid dividing by zero, [B, T]
f0 = 10 * 2 ** (cents / 1200)
uv = hidden.max(dim=2)[0] < thred # [B, T]
f0 = f0 * ~uv
return f0.squeeze(0).cpu().numpy()
def to_viterbi_f0(hidden, thred=0.03):
# Create viterbi transition matrix
if not hasattr(to_viterbi_cents, 'transition'):
xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS))
transition = np.maximum(30 - abs(xx - yy), 0)
transition = transition / transition.sum(axis=1, keepdims=True)
to_viterbi_cents.transition = transition
# Convert to probability
prob = hidden.squeeze(0).cpu().numpy()
prob = prob.T
prob = prob / prob.sum(axis=0)
# Perform viterbi decoding
path = librosa.sequence.viterbi(prob, to_viterbi_cents.transition).astype(np.int64)
center = torch.from_numpy(path).unsqueeze(0).unsqueeze(-1).to(hidden.device)
return to_local_average_f0(hidden, center=center, thred=thred)