# infer/lib/predictors/DJCM/DJCM.py — DJCM pitch predictor (commit be43029, by NeoPy)
import os
import sys
import torch
import numpy as np
from scipy.signal import medfilt
sys.path.append(os.getcwd())
from infer.lib.predictors.DJCM.spec import Spectrogram
# Audio/model constants: 16 kHz sample rate, default STFT window length,
# and number of pitch-salience classes (360 bins, 20 cents apart).
SAMPLE_RATE, WINDOW_LENGTH, N_CLASS = 16000, 1024, 360
class DJCM:
    """DJCM fundamental-frequency (F0) predictor.

    Wraps a DJCM model (PyTorch checkpoint or ONNX session), runs it over
    fixed-length half-overlapping audio segments, and decodes the per-frame
    pitch salience into an F0 curve (one value per 10 ms of 16 kHz audio).
    """

    def __init__(
        self,
        model_path,
        device="cpu",
        is_half=False,
        onnx=False,
        svs=False,
        providers=None,
        batch_size=1,
        segment_len=5.12,
        kernel_size=3
    ):
        """Load the model and precompute decoding tables.

        Args:
            model_path: path to a .pt state-dict or an .onnx file.
            device: torch device string for inference ("cpu", "cuda", ...).
            is_half: run the PyTorch model in fp16 (ignored for ONNX).
            onnx: if True, load model_path with onnxruntime instead of torch.
            svs: singing-voice variant; uses a 2048-sample STFT window.
            providers: onnxruntime execution providers; defaults to CPU only.
            batch_size: number of segments per forward pass.
            segment_len: segment length in seconds.
            kernel_size: odd median-filter width applied to the F0 curve,
                or None to disable smoothing.
        """
        # Avoid mutable default argument: resolve the provider list here.
        if providers is None:
            providers = ["CPUExecutionProvider"]
        # BUG FIX: the original did `if svs: WINDOW_LENGTH = 2048`, which made
        # WINDOW_LENGTH a *local* name for the whole method and raised
        # UnboundLocalError at every later read whenever svs was False.
        window_length = 2048 if svs else WINDOW_LENGTH
        self.onnx = onnx
        if self.onnx:
            import onnxruntime as ort

            sess_options = ort.SessionOptions()
            sess_options.log_severity_level = 3  # warnings and above only
            self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
        else:
            # BUG FIX: import from this package ("infer.lib...", matching the
            # Spectrogram import at the top of the file), not "main.library...".
            from infer.lib.predictors.DJCM.model import DJCMM

            model = DJCMM(1, 1, 1, svs=svs, window_length=window_length, n_class=N_CLASS)
            model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
            model.eval()
            if is_half:
                model = model.half()
            self.model = model.to(device)
        self.batch_size = batch_size
        self.seg_len = int(segment_len * SAMPLE_RATE)  # segment length in samples
        # Frames per segment at a hop of SAMPLE_RATE // 100 samples (10 ms).
        self.seg_frames = int(self.seg_len // int(SAMPLE_RATE // 100))
        self.device = device
        self.is_half = is_half
        self.kernel_size = kernel_size
        self.spec_extractor = Spectrogram(int(SAMPLE_RATE // 100), window_length).to(device)
        # Center (in cents) of each of the N_CLASS salience bins, padded by 4
        # on each side so the 9-bin window around any argmax stays in range.
        cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))

    def spec2hidden(self, spec):
        """Run one batch of spectrograms through the model and return the
        per-frame salience ("hidden") tensor on self.device."""
        if self.onnx:
            # ONNX session expects float32 numpy input.
            spec = spec.cpu().numpy().astype(np.float32)
            hidden = torch.as_tensor(
                self.model.run(
                    [self.model.get_outputs()[0].name],
                    {self.model.get_inputs()[0].name: spec}
                )[0],
                device=self.device
            )
        else:
            if self.is_half:
                spec = spec.half()
            hidden = self.model(spec)
        return hidden

    def infer_from_audio(self, audio, thred=0.03):
        """Predict an F0 curve (Hz, one value per 10 ms frame) from audio.

        Args:
            audio: 1-D waveform (numpy array or torch tensor) at SAMPLE_RATE.
            thred: salience threshold; frames at or below it become 0 (unvoiced).
        """
        if torch.is_tensor(audio):
            audio = audio.cpu().numpy()
        if audio.ndim > 1:
            audio = audio.squeeze()
        with torch.no_grad():
            padded_audio = self.pad_audio(audio)
            # Keep exactly one frame per hop of the original (unpadded) audio.
            hidden = self.inference(padded_audio)[:(audio.shape[-1] // int(SAMPLE_RATE // 100) + 1)]
        f0 = self.decode(hidden.squeeze(0).cpu().numpy(), thred)
        if self.kernel_size is not None:
            # Median filter suppresses isolated octave-error spikes.
            f0 = medfilt(f0, kernel_size=self.kernel_size)
        return f0

    def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
        """Like infer_from_audio, but zeroes frames outside [f0_min, f0_max] Hz."""
        f0 = self.infer_from_audio(audio, thred)
        f0[(f0 < f0_min) | (f0 > f0_max)] = 0
        return f0

    def to_local_average_cents(self, salience, thred=0.05):
        """Decode salience (frames x N_CLASS) into cents.

        Takes a salience-weighted average of the bin centers over the 9 bins
        around each frame's argmax; frames whose peak salience is <= thred
        are set to 0 (unvoiced).
        """
        center = np.argmax(salience, axis=1)
        # Pad 4 bins on each side so the 9-bin window never leaves the array;
        # cents_mapping carries the matching padding (see __init__).
        salience = np.pad(salience, ((0, 0), (4, 4)))
        center += 4
        todo_salience, todo_cents_mapping = [], []
        starts = center - 4
        ends = center + 5
        for idx in range(salience.shape[0]):
            todo_salience.append(salience[:, starts[idx]: ends[idx]][idx])
            todo_cents_mapping.append(self.cents_mapping[starts[idx]: ends[idx]])
        todo_salience = np.array(todo_salience)
        # Weighted average of bin centers, weighted by the salience values.
        devided = np.sum(todo_salience * np.array(todo_cents_mapping), 1) / np.sum(todo_salience, 1)
        devided[np.max(salience, axis=1) <= thred] = 0
        return devided

    def decode(self, hidden, thred=0.03):
        """Convert salience to F0 in Hz; unvoiced frames come out as 0."""
        f0 = 10 * (2 ** (self.to_local_average_cents(hidden, thred=thred) / 1200))
        # A cents value of 0 maps to exactly 10 Hz; treat it as unvoiced.
        f0[f0 == 10] = 0
        return f0

    def pad_audio(self, audio):
        """Pad the waveform and cut it into half-overlapping segments.

        Returns a (num_segments, 1, seg_len) float tensor on self.device.
        """
        audio_len = audio.shape[-1]
        seg_nums = int(np.ceil(audio_len / self.seg_len)) + 1
        pad_len = int(seg_nums * self.seg_len - audio_len + self.seg_len // 2)
        left_pad = np.zeros(int(self.seg_len // 4), dtype=np.float32)
        right_pad = np.zeros(int(pad_len - self.seg_len // 4), dtype=np.float32)
        padded_audio = np.concatenate([left_pad, audio, right_pad], axis=-1)
        # Hop of seg_len // 2: consecutive segments overlap by half.
        segments = [
            padded_audio[start: start + int(self.seg_len)]
            for start in range(
                0,
                len(padded_audio) - int(self.seg_len) + 1,
                int(self.seg_len // 2)
            )
        ]
        segments = np.stack(segments, axis=0)
        return torch.from_numpy(segments).unsqueeze(1).to(self.device)

    def inference(self, segments):
        """Run all segments through the model in batches and stitch results.

        Keeps frames [25%, 75%) of each segment — the central half — which is
        the standard overlap-discard stitch for half-overlapping segments.
        """
        hidden_segments = torch.cat([
            self.spec2hidden(self.spec_extractor(segments[i:i + self.batch_size].float()))
            for i in range(0, len(segments), self.batch_size)
        ], dim=0)
        hidden = torch.cat([
            seg[self.seg_frames // 4: int(self.seg_frames * 0.75)]
            for seg in hidden_segments
        ], dim=0)
        return hidden