# Upload metadata (Hugging Face page residue, kept for provenance):
#   uploader: NeoPy — commit message: "EXP" — revision: 2dcbf9e (verified)
import os
import sys
import torch
import librosa
import scipy.stats
import numpy as np

# Make project-local packages importable (e.g. main.library.predictors.*)
# when the process is launched from the repository root.
# NOTE(review): this assumes the CWD is the repo root — confirm against launcher.
sys.path.append(os.getcwd())

# CREPE model constants: cents covered by one pitch bin, number of pitch
# bins in the classifier output, the sample rate the network was trained
# at (Hz), and the analysis window size in samples.
CENTS_PER_BIN, PITCH_BINS, SAMPLE_RATE, WINDOW_SIZE = 20, 360, 16000, 1024
class CREPE:
    """CREPE fundamental-frequency (f0) estimator.

    Wraps either a PyTorch CREPE network or an ONNX export of it, and
    decodes frame-wise pitch with Viterbi smoothing, optionally returning
    a per-frame periodicity (confidence) signal.

    NOTE(review): the pasted source had all indentation stripped; the
    structure restored here follows the only syntactically valid nesting.
    """

    def __init__(
        self,
        model_path,
        model_size="full",
        hop_length=512,
        batch_size=None,
        f0_min=50,
        f0_max=1100,
        device=None,
        sample_rate=16000,
        providers=None,
        onnx=False,
        return_periodicity=False
    ):
        """Load the model and record decoding parameters.

        Args:
            model_path: path to a torch state dict (onnx=False) or an
                .onnx file (onnx=True).
            model_size: capacity variant forwarded to the torch model.
            hop_length: analysis hop in samples at ``sample_rate``;
                ``None`` falls back to 10 ms (``sample_rate // 100``).
            batch_size: frames per forward pass; ``None`` = all at once.
            f0_min, f0_max: decoding range in Hz.
            device: torch device string; defaults to CUDA when available.
            sample_rate: sample rate of audio passed to ``compute_f0``.
            providers: onnxruntime execution providers (onnx=True only).
            onnx: run inference through onnxruntime instead of PyTorch.
            return_periodicity: also return per-frame periodicity.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.hop_length = hop_length
        self.batch_size = batch_size
        self.sample_rate = sample_rate
        self.onnx = onnx
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.return_periodicity = return_periodicity
        if self.onnx:
            import onnxruntime as ort

            sess_options = ort.SessionOptions()
            sess_options.log_severity_level = 3  # errors and fatals only
            self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
        else:
            from main.library.predictors.CREPE.model import CREPEE

            model = CREPEE(model_size)
            model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
            model.eval()
            # BUG FIX: was `model.to(device)` — a no-op when callers rely on
            # the `device=None` default, leaving the model on CPU while
            # compute_f0 moves frames to self.device (e.g. "cuda").
            self.model = model.to(self.device)

    def bins_to_frequency(self, bins):
        """Convert pitch-bin indices to frequencies in Hz.

        Triangular dither of up to +/-CENTS_PER_BIN cents is added to
        decorrelate the quantization error, so the result is stochastic.
        """
        # OpenCL / DirectML backends lack some integer kernels used below.
        if str(bins.device).startswith(("ocl", "privateuseone")):
            bins = bins.to(torch.float32)
        # 1997.379... cents is the offset of bin 0 relative to 10 Hz.
        cents = CENTS_PER_BIN * bins + 1997.3794084376191
        dither = cents.new_tensor(
            scipy.stats.triang.rvs(
                c=0.5,
                loc=-CENTS_PER_BIN,
                scale=2 * CENTS_PER_BIN,
                size=cents.size()
            )
        )
        return 10 * 2 ** ((cents + dither) / 1200)

    def frequency_to_bins(self, frequency, quantize_fn=torch.floor):
        """Inverse of bins_to_frequency (without dither); returns int bins."""
        return quantize_fn(((1200 * (frequency / 10).log2()) - 1997.3794084376191) / CENTS_PER_BIN).int()

    def viterbi(self, logits):
        """Viterbi-decode frame-wise logits into (bins, frequencies)."""
        if not hasattr(self, 'transition'):
            # Lazily build and cache a band-diagonal transition matrix:
            # jumps of 12 or more bins between frames get zero probability.
            # CONSISTENCY FIX: use PITCH_BINS instead of a hard-coded 360.
            xx, yy = np.meshgrid(range(PITCH_BINS), range(PITCH_BINS))
            transition = np.maximum(12 - abs(xx - yy), 0)
            self.transition = transition / transition.sum(axis=1, keepdims=True)
        with torch.no_grad():
            probs = torch.nn.functional.softmax(logits, dim=1)
            bins = torch.tensor(
                np.array([
                    librosa.sequence.viterbi(sequence, self.transition).astype(np.int64)
                    for sequence in probs.cpu().numpy()
                ]),
                device=probs.device
            )
        return bins, self.bins_to_frequency(bins)

    def preprocess(self, audio, pad=True):
        """Yield batches of normalized analysis frames, shape (F, WINDOW_SIZE).

        Resamples to the model rate when needed and (optionally) pads so
        frames are centered on the original samples. `audio` is expected
        to be a (1, samples) tensor — TODO confirm against callers.
        """
        hop_length = (self.sample_rate // 100) if self.hop_length is None else self.hop_length
        if self.sample_rate != SAMPLE_RATE:
            audio = torch.tensor(
                librosa.resample(
                    audio.detach().cpu().numpy().squeeze(0),
                    orig_sr=self.sample_rate,
                    target_sr=SAMPLE_RATE,
                    res_type="soxr_vhq"
                ),
                device=audio.device
            ).unsqueeze(0)
            # Rescale the hop so frame times stay aligned with the input rate.
            hop_length = int(hop_length * SAMPLE_RATE / self.sample_rate)
        if pad:
            total_frames = 1 + int(audio.size(1) // hop_length)
            # Half-window reflect-free zero padding centers each frame.
            audio = torch.nn.functional.pad(audio, (WINDOW_SIZE // 2, WINDOW_SIZE // 2))
        else:
            total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)
        batch_size = total_frames if self.batch_size is None else self.batch_size
        for i in range(0, total_frames, batch_size):
            # Slice just the samples this batch of frames needs, then unfold
            # into overlapping windows of WINDOW_SIZE with stride hop_length.
            frames = torch.nn.functional.unfold(
                audio[:, None, None, max(0, i * hop_length):min(audio.size(1), (i + batch_size - 1) * hop_length + WINDOW_SIZE)],
                kernel_size=(1, WINDOW_SIZE),
                stride=(1, hop_length)
            )
            if self.device.startswith(("ocl", "privateuseone")):
                # These backends require contiguous memory before reshape.
                frames = frames.transpose(1, 2).contiguous().reshape(-1, WINDOW_SIZE).to(self.device)
            else:
                frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE).to(self.device)
            # Per-frame mean/variance normalization (1e-10 guards against
            # division by zero on silent frames).
            frames -= frames.mean(dim=1, keepdim=True)
            frames /= torch.tensor(1e-10, device=frames.device).max(frames.std(dim=1, keepdim=True))
            yield frames

    def periodicity(self, probabilities, bins):
        """Gather the probability of the decoded bin per frame.

        probabilities: (batch, PITCH_BINS, time); bins: (batch, time).
        Returns a (batch, time) periodicity/confidence tensor.
        """
        probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS)
        periodicity = probs_stacked.gather(1, bins.reshape(-1, 1).to(torch.int64))
        return periodicity.reshape(probabilities.size(0), probabilities.size(2))

    def postprocess(self, probabilities):
        """Mask bins outside [f0_min, f0_max], then Viterbi-decode.

        Returns pitch, or (pitch, periodicity) when return_periodicity.
        """
        probabilities = probabilities.detach()
        # -inf logits make out-of-range bins unreachable after softmax.
        probabilities[:, :self.frequency_to_bins(torch.tensor(self.f0_min))] = -float('inf')
        probabilities[:, self.frequency_to_bins(torch.tensor(self.f0_max), torch.ceil):] = -float('inf')
        bins, pitch = self.viterbi(probabilities)
        if not self.return_periodicity:
            return pitch
        return pitch, self.periodicity(probabilities, bins)

    def compute_f0(self, audio, pad=True):
        """Estimate f0 for `audio`; concatenates per-batch results on dim 1.

        Returns pitch, or (pitch, periodicity) when return_periodicity.
        """
        results = []
        for frames in self.preprocess(audio, pad):
            if self.onnx:
                # ONNX output is (frames, PITCH_BINS); transpose and add a
                # batch axis. NOTE(review): this assumes batch size 1 for
                # the ONNX path — confirm before batching audio.
                model = torch.tensor(
                    self.model.run(
                        [self.model.get_outputs()[0].name],
                        {
                            self.model.get_inputs()[0].name: frames.cpu().numpy()
                        }
                    )[0].transpose(1, 0)[None],
                    device=self.device
                )
            else:
                with torch.no_grad():
                    model = self.model(
                        frames,
                        embed=False
                    ).reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2)
            result = self.postprocess(model)
            results.append(
                (result[0].to(audio.device), result[1].to(audio.device)) if isinstance(result, tuple) else result.to(audio.device)
            )
        if self.return_periodicity:
            pitch, periodicity = zip(*results)
            return torch.cat(pitch, 1), torch.cat(periodicity, 1)
        return torch.cat(results, 1)