|
|
import os |
|
|
import numpy as np |
|
|
import yaml |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import HubertModel, Wav2Vec2FeatureExtractor |
|
|
from fairseq import checkpoint_utils |
|
|
|
|
|
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present |
|
|
from torchaudio.transforms import Resample |
|
|
from .unit2control import Unit2ControlFac |
|
|
from .core import frequency_filter, upsample, remove_above_fmax |
|
|
|
|
|
|
|
|
|
|
|
CREPE_RESAMPLE_KERNEL = {} |
|
|
F0_KERNEL = {} |
|
|
|
|
|
class F0_Extractor: |
|
|
def __init__(self, f0_extractor, sample_rate = 44100, hop_size = 512, f0_min = 65, f0_max = 800): |
|
|
self.f0_extractor = f0_extractor |
|
|
self.sample_rate = sample_rate |
|
|
self.hop_size = hop_size |
|
|
self.f0_min = f0_min |
|
|
self.f0_max = f0_max |
|
|
if f0_extractor == 'crepe': |
|
|
key_str = str(sample_rate) |
|
|
if key_str not in CREPE_RESAMPLE_KERNEL: |
|
|
CREPE_RESAMPLE_KERNEL[key_str] = Resample(sample_rate, 16000, lowpass_filter_width = 128) |
|
|
self.resample_kernel = CREPE_RESAMPLE_KERNEL[key_str] |
|
|
if f0_extractor == 'rmvpe': |
|
|
if 'rmvpe' not in F0_KERNEL : |
|
|
from rmvpe import RMVPE |
|
|
F0_KERNEL['rmvpe'] = RMVPE('utils/pretrain/rmvpe/model.pt', hop_length=160) |
|
|
self.rmvpe = F0_KERNEL['rmvpe'] |
|
|
|
|
|
def extract(self, audio, uv_interp = False, device = None, silence_front = 0): |
|
|
|
|
|
n_frames = int(len(audio) // self.hop_size) + 1 |
|
|
|
|
|
start_frame = int(silence_front * self.sample_rate / self.hop_size) |
|
|
real_silence_front = start_frame * self.hop_size / self.sample_rate |
|
|
audio = audio[int(np.round(real_silence_front * self.sample_rate)) : ] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.f0_extractor == "rmvpe": |
|
|
f0 = self.rmvpe.infer_from_audio(audio, self.sample_rate, device=device, thred=0.03, use_viterbi=False) |
|
|
uv = f0 == 0 |
|
|
if len(f0[~uv]) > 0: |
|
|
f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) |
|
|
origin_time = 0.01 * np.arange(len(f0)) |
|
|
target_time = self.hop_size / self.sample_rate * np.arange(n_frames - start_frame) |
|
|
f0 = np.interp(target_time, origin_time, f0) |
|
|
uv = np.interp(target_time, origin_time, uv.astype(float)) > 0.5 |
|
|
f0[uv] = 0 |
|
|
f0 = np.pad(f0, (start_frame, 0)) |
|
|
|
|
|
else: |
|
|
raise ValueError(f" [x] Unknown f0 extractor: {self.f0_extractor}") |
|
|
|
|
|
|
|
|
if uv_interp: |
|
|
uv = f0 == 0 |
|
|
if len(f0[~uv]) > 0: |
|
|
f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) |
|
|
f0[f0 < self.f0_min] = self.f0_min |
|
|
return f0 |
|
|
|
|
|
def batch_extract(self, audios, uv_interp=False, device=None, silence_front=0): |
|
|
processed_f0s = [] |
|
|
for audio in audios: |
|
|
|
|
|
if self.f0_extractor == "rmvpe": |
|
|
f0 = self.rmvpe.infer_from_audio(audio, self.sample_rate, device=device, thred=0.03, use_viterbi=False) |
|
|
f0 = torch.tensor(f0, dtype=torch.float32, device=device) |
|
|
n_frames = int(len(audio) // self.hop_size) + 1 |
|
|
start_frame = int(silence_front * self.sample_rate / self.hop_size) |
|
|
real_silence_front = start_frame * self.hop_size / self.sample_rate |
|
|
audio = audio[int(np.round(real_silence_front * self.sample_rate)):] |
|
|
|
|
|
target_time = self.hop_size / self.sample_rate * torch.arange(n_frames - start_frame, device=device) |
|
|
f0 = F.interpolate(f0.unsqueeze(0).unsqueeze(0), size=n_frames - start_frame, mode='linear').squeeze() |
|
|
|
|
|
else: |
|
|
raise ValueError(f"Unknown f0 extractor: {self.f0_extractor}") |
|
|
|
|
|
processed_f0s.append(f0) |
|
|
|
|
|
processed_f0s = torch.stack(processed_f0s, 0) |
|
|
return processed_f0s |
|
|
|
|
|
class Volume_Extractor: |
|
|
def __init__(self, hop_size = 512): |
|
|
self.hop_size = hop_size |
|
|
|
|
|
def extract(self, audio): |
|
|
n_frames = int(len(audio) // self.hop_size) + 1 |
|
|
audio2 = audio ** 2 |
|
|
audio2 = np.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode = 'reflect') |
|
|
volume = np.array([np.mean(audio2[int(n * self.hop_size) : int((n + 1) * self.hop_size)]) for n in range(n_frames)]) |
|
|
volume = np.sqrt(volume) |
|
|
return volume |
|
|
|
|
|
|
|
|
class DotDict(dict): |
|
|
def __getattr__(*args): |
|
|
val = dict.get(*args) |
|
|
return DotDict(val) if type(val) is dict else val |
|
|
|
|
|
__setattr__ = dict.__setitem__ |
|
|
__delattr__ = dict.__delitem__ |
|
|
|
|
|
def load_model( |
|
|
model_path, |
|
|
device='cpu'): |
|
|
config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml') |
|
|
with open(config_file, "r") as config: |
|
|
args = yaml.safe_load(config) |
|
|
args = DotDict(args) |
|
|
|
|
|
|
|
|
model = None |
|
|
|
|
|
if args.model.type == 'Sins': |
|
|
model = Sins( |
|
|
sampling_rate=args.data.sampling_rate, |
|
|
block_size=args.data.block_size, |
|
|
n_harmonics=args.model.n_harmonics, |
|
|
n_mag_allpass=args.model.n_mag_allpass, |
|
|
n_mag_noise=args.model.n_mag_noise, |
|
|
n_unit=args.data.encoder_out_channels, |
|
|
n_spk=args.model.n_spk) |
|
|
|
|
|
elif args.model.type == 'CombSub': |
|
|
model = CombSub( |
|
|
sampling_rate=args.data.sampling_rate, |
|
|
block_size=args.data.block_size, |
|
|
n_mag_allpass=args.model.n_mag_allpass, |
|
|
n_mag_harmonic=args.model.n_mag_harmonic, |
|
|
n_mag_noise=args.model.n_mag_noise, |
|
|
n_unit=args.data.encoder_out_channels, |
|
|
n_spk=args.model.n_spk) |
|
|
|
|
|
elif args.model.type == 'CombSubFast': |
|
|
model = CombSubFast( |
|
|
sampling_rate=args.data.sampling_rate, |
|
|
block_size=args.data.block_size, |
|
|
n_unit=args.data.encoder_out_channels, |
|
|
n_spk=args.model.n_spk) |
|
|
|
|
|
else: |
|
|
raise ValueError(f" [x] Unknown Model: {args.model.type}") |
|
|
|
|
|
print(' [Loading] ' + model_path) |
|
|
ckpt = torch.load(model_path, map_location=torch.device(device)) |
|
|
model.to(device) |
|
|
model.load_state_dict(ckpt['model']) |
|
|
model.eval() |
|
|
return model, args |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CombSubFastFac(torch.nn.Module): |
|
|
def __init__(self, |
|
|
sampling_rate, |
|
|
block_size, |
|
|
n_unit=256, |
|
|
n_spk=1, |
|
|
use_pitch_aug=False, |
|
|
pcmer_norm=False): |
|
|
super().__init__() |
|
|
|
|
|
print(' [DDSP Model] Combtooth Subtractive Synthesiser') |
|
|
|
|
|
self.register_buffer("sampling_rate", torch.tensor(sampling_rate)) |
|
|
self.register_buffer("block_size", torch.tensor(block_size)) |
|
|
self.register_buffer("window", torch.sqrt(torch.hann_window(2 * block_size))) |
|
|
|
|
|
split_map = { |
|
|
'harmonic_magnitude': block_size + 1, |
|
|
'harmonic_phase': block_size + 1, |
|
|
'noise_magnitude': block_size + 1 |
|
|
} |
|
|
self.unit2ctrl = Unit2ControlFac(n_unit, split_map, use_pitch_aug=use_pitch_aug, pcmer_norm=pcmer_norm) |
|
|
|
|
|
def forward(self, units_frames, f0_frames, volume_frames, spk, aug_shift=None, initial_phase=None, infer=True, **kwargs): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
f0_frames = f0_frames.unsqueeze(2) |
|
|
volume_frames = volume_frames.unsqueeze(2) |
|
|
|
|
|
f0 = upsample(f0_frames, self.block_size) |
|
|
if infer: |
|
|
x = torch.cumsum(f0.double() / self.sampling_rate, axis=1) |
|
|
else: |
|
|
x = torch.cumsum(f0 / self.sampling_rate, axis=1) |
|
|
if initial_phase is not None: |
|
|
x += initial_phase.to(x) / 2 / np.pi |
|
|
x = x - torch.round(x) |
|
|
x = x.to(f0) |
|
|
|
|
|
phase_frames = 2 * np.pi * x[:, ::self.block_size, :] |
|
|
|
|
|
|
|
|
ctrls, hidden = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk, aug_shift=aug_shift) |
|
|
|
|
|
src_filter = torch.exp(ctrls['harmonic_magnitude'] + 1.j * np.pi * ctrls['harmonic_phase']) |
|
|
src_filter = torch.cat((src_filter, src_filter[:,-1:,:]), 1) |
|
|
noise_filter= torch.exp(ctrls['noise_magnitude']) / 128 |
|
|
noise_filter = torch.cat((noise_filter, noise_filter[:,-1:,:]), 1) |
|
|
|
|
|
|
|
|
combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3)) |
|
|
combtooth = combtooth.squeeze(-1) |
|
|
combtooth_frames = F.pad(combtooth, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size) |
|
|
combtooth_frames = combtooth_frames * self.window |
|
|
combtooth_fft = torch.fft.rfft(combtooth_frames, 2 * self.block_size) |
|
|
|
|
|
|
|
|
noise = torch.rand_like(combtooth) * 2 - 1 |
|
|
noise_frames = F.pad(noise, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size) |
|
|
noise_frames = noise_frames * self.window |
|
|
noise_fft = torch.fft.rfft(noise_frames, 2 * self.block_size) |
|
|
|
|
|
|
|
|
signal_fft = combtooth_fft * src_filter + noise_fft * noise_filter |
|
|
|
|
|
|
|
|
signal_frames_out = torch.fft.irfft(signal_fft, 2 * self.block_size) * self.window |
|
|
|
|
|
|
|
|
fold = torch.nn.Fold(output_size=(1, (signal_frames_out.size(1) + 1) * self.block_size), kernel_size=(1, 2 * self.block_size), stride=(1, self.block_size)) |
|
|
signal = fold(signal_frames_out.transpose(1, 2))[:, 0, 0, self.block_size : -self.block_size] |
|
|
|
|
|
return signal, hidden, (signal, signal) |
|
|
|
|
|
class CombSubFast_SingingEnhance(torch.nn.Module): |
|
|
def __init__(self, |
|
|
sampling_rate, |
|
|
block_size, |
|
|
n_unit=256, |
|
|
n_spk=1, |
|
|
use_pitch_aug=False, |
|
|
pcmer_norm=False): |
|
|
super().__init__() |
|
|
|
|
|
print(' [DDSP Model] Combtooth Subtractive Synthesiser') |
|
|
|
|
|
self.register_buffer("sampling_rate", torch.tensor(sampling_rate)) |
|
|
self.register_buffer("block_size", torch.tensor(block_size)) |
|
|
self.register_buffer("window", torch.sqrt(torch.hann_window(2 * block_size))) |
|
|
|
|
|
split_map = { |
|
|
'harmonic_magnitude': block_size + 1, |
|
|
'harmonic_phase': block_size + 1, |
|
|
'noise_magnitude': block_size + 1 |
|
|
} |
|
|
self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map, use_pitch_aug=use_pitch_aug, pcmer_norm=pcmer_norm) |
|
|
|
|
|
def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, aug_shift=None, initial_phase=None, infer=True, **kwargs): |
|
|
''' |
|
|
units_frames: B x n_frames x n_unit |
|
|
f0_frames: B x n_frames x 1 |
|
|
volume_frames: B x n_frames x 1 |
|
|
spk_id: B x 1 |
|
|
''' |
|
|
|
|
|
f0 = upsample(f0_frames, self.block_size) |
|
|
if infer: |
|
|
x = torch.cumsum(f0.double() / self.sampling_rate, axis=1) |
|
|
else: |
|
|
x = torch.cumsum(f0 / self.sampling_rate, axis=1) |
|
|
if initial_phase is not None: |
|
|
x += initial_phase.to(x) / 2 / np.pi |
|
|
x = x - torch.round(x) |
|
|
x = x.to(f0) |
|
|
|
|
|
phase_frames = 2 * np.pi * x[:, ::self.block_size, :] |
|
|
|
|
|
|
|
|
ctrls, hidden = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict, aug_shift=aug_shift) |
|
|
|
|
|
src_filter = torch.exp(ctrls['harmonic_magnitude'] + 1.j * np.pi * ctrls['harmonic_phase']) |
|
|
src_filter = torch.cat((src_filter, src_filter[:,-1:,:]), 1) |
|
|
noise_filter= torch.exp(ctrls['noise_magnitude']) / 128 |
|
|
noise_filter = torch.cat((noise_filter, noise_filter[:,-1:,:]), 1) |
|
|
|
|
|
|
|
|
combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3)) |
|
|
combtooth = combtooth.squeeze(-1) |
|
|
combtooth_frames = F.pad(combtooth, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size) |
|
|
combtooth_frames = combtooth_frames * self.window |
|
|
combtooth_fft = torch.fft.rfft(combtooth_frames, 2 * self.block_size) |
|
|
|
|
|
|
|
|
noise = torch.rand_like(combtooth) * 2 - 1 |
|
|
noise_frames = F.pad(noise, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size) |
|
|
noise_frames = noise_frames * self.window |
|
|
noise_fft = torch.fft.rfft(noise_frames, 2 * self.block_size) |
|
|
|
|
|
|
|
|
signal_fft = combtooth_fft * src_filter + noise_fft * noise_filter |
|
|
|
|
|
|
|
|
signal_frames_out = torch.fft.irfft(signal_fft, 2 * self.block_size) * self.window |
|
|
|
|
|
|
|
|
fold = torch.nn.Fold(output_size=(1, (signal_frames_out.size(1) + 1) * self.block_size), kernel_size=(1, 2 * self.block_size), stride=(1, self.block_size)) |
|
|
signal = fold(signal_frames_out.transpose(1, 2))[:, 0, 0, self.block_size : -self.block_size] |
|
|
|
|
|
return signal, hidden, (signal, signal) |
|
|
|
|
|
class CombSub(torch.nn.Module): |
|
|
def __init__(self, |
|
|
sampling_rate, |
|
|
block_size, |
|
|
n_mag_allpass, |
|
|
n_mag_harmonic, |
|
|
n_mag_noise, |
|
|
n_unit=256, |
|
|
n_spk=1): |
|
|
super().__init__() |
|
|
|
|
|
print(' [DDSP Model] Combtooth Subtractive Synthesiser (Old Version)') |
|
|
|
|
|
self.register_buffer("sampling_rate", torch.tensor(sampling_rate)) |
|
|
self.register_buffer("block_size", torch.tensor(block_size)) |
|
|
|
|
|
split_map = { |
|
|
'group_delay': n_mag_allpass, |
|
|
'harmonic_magnitude': n_mag_harmonic, |
|
|
'noise_magnitude': n_mag_noise |
|
|
} |
|
|
self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map) |
|
|
|
|
|
def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, **kwargs): |
|
|
''' |
|
|
units_frames: B x n_frames x n_unit |
|
|
f0_frames: B x n_frames x 1 |
|
|
volume_frames: B x n_frames x 1 |
|
|
spk_id: B x 1 |
|
|
''' |
|
|
|
|
|
f0 = upsample(f0_frames, self.block_size) |
|
|
if infer: |
|
|
x = torch.cumsum(f0.double() / self.sampling_rate, axis=1) |
|
|
else: |
|
|
x = torch.cumsum(f0 / self.sampling_rate, axis=1) |
|
|
if initial_phase is not None: |
|
|
x += initial_phase.to(x) / 2 / np.pi |
|
|
x = x - torch.round(x) |
|
|
x = x.to(f0) |
|
|
|
|
|
phase_frames = 2 * np.pi * x[:, ::self.block_size, :] |
|
|
|
|
|
|
|
|
ctrls, hidden = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict) |
|
|
|
|
|
group_delay = np.pi * torch.tanh(ctrls['group_delay']) |
|
|
src_param = torch.exp(ctrls['harmonic_magnitude']) |
|
|
noise_param = torch.exp(ctrls['noise_magnitude']) / 128 |
|
|
|
|
|
|
|
|
combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3)) |
|
|
combtooth = combtooth.squeeze(-1) |
|
|
|
|
|
|
|
|
harmonic = frequency_filter( |
|
|
combtooth, |
|
|
torch.exp(1.j * torch.cumsum(group_delay, axis = -1)), |
|
|
hann_window = False) |
|
|
harmonic = frequency_filter( |
|
|
harmonic, |
|
|
torch.complex(src_param, torch.zeros_like(src_param)), |
|
|
hann_window = True, |
|
|
half_width_frames = 1.5 * self.sampling_rate / (f0_frames + 1e-3)) |
|
|
|
|
|
|
|
|
noise = torch.rand_like(harmonic) * 2 - 1 |
|
|
noise = frequency_filter( |
|
|
noise, |
|
|
torch.complex(noise_param, torch.zeros_like(noise_param)), |
|
|
hann_window = True) |
|
|
|
|
|
signal = harmonic + noise |
|
|
|
|
|
return signal, hidden, (harmonic, noise) |
|
|
|
|
|
|
|
|
class Units_Encoder: |
|
|
def __init__(self, encoder, encoder_ckpt, encoder_sample_rate = 16000, encoder_hop_size = 320, device = None, |
|
|
cnhubertsoft_gate=10): |
|
|
if device is None: |
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
self.device = device |
|
|
|
|
|
is_loaded_encoder = False |
|
|
if encoder == 'hubertsoft': |
|
|
self.model = Audio2HubertSoft(encoder_ckpt).to(device) |
|
|
is_loaded_encoder = True |
|
|
if encoder == 'hubertbase': |
|
|
self.model = Audio2HubertBase(encoder_ckpt, device=device) |
|
|
is_loaded_encoder = True |
|
|
if encoder == 'hubertbase768': |
|
|
self.model = Audio2HubertBase768(encoder_ckpt, device=device) |
|
|
is_loaded_encoder = True |
|
|
if encoder == 'hubertbase768l12': |
|
|
self.model = Audio2HubertBase768L12(encoder_ckpt, device=device) |
|
|
is_loaded_encoder = True |
|
|
if encoder == 'hubertlarge1024l24': |
|
|
self.model = Audio2HubertLarge1024L24(encoder_ckpt, device=device) |
|
|
is_loaded_encoder = True |
|
|
if encoder == 'contentvec': |
|
|
self.model = Audio2ContentVec(encoder_ckpt, device=device) |
|
|
is_loaded_encoder = True |
|
|
if encoder == 'contentvec768': |
|
|
self.model = Audio2ContentVec768(encoder_ckpt, device=device) |
|
|
is_loaded_encoder = True |
|
|
if encoder == 'spin': |
|
|
self.model = Audio2Spin(encoder_ckpt, device=device) |
|
|
is_loaded_encoder = True |
|
|
if encoder == 'contentvec768l12': |
|
|
self.model = Audio2ContentVec768L12(encoder_ckpt, device=device) |
|
|
is_loaded_encoder = True |
|
|
if encoder == 'cnhubertsoftfish': |
|
|
self.model = CNHubertSoftFish(encoder_ckpt, device=device, gate_size=cnhubertsoft_gate) |
|
|
is_loaded_encoder = True |
|
|
if not is_loaded_encoder: |
|
|
raise ValueError(f" [x] Unknown units encoder: {encoder}") |
|
|
|
|
|
self.resample_kernel = {} |
|
|
self.encoder_sample_rate = encoder_sample_rate |
|
|
self.encoder_hop_size = encoder_hop_size |
|
|
|
|
|
def encode(self, |
|
|
audio, |
|
|
sample_rate, |
|
|
hop_size): |
|
|
|
|
|
|
|
|
if sample_rate == self.encoder_sample_rate: |
|
|
audio_res = audio |
|
|
else: |
|
|
key_str = str(sample_rate) |
|
|
if key_str not in self.resample_kernel: |
|
|
self.resample_kernel[key_str] = Resample(sample_rate, self.encoder_sample_rate, lowpass_filter_width = 128).to(self.device) |
|
|
audio_res = self.resample_kernel[key_str](audio) |
|
|
|
|
|
|
|
|
if audio_res.size(-1) < 400: |
|
|
audio_res = torch.nn.functional.pad(audio, (0, 400 - audio_res.size(-1))) |
|
|
units = self.model(audio_res) |
|
|
|
|
|
|
|
|
n_frames = audio.size(-1) // hop_size + 1 |
|
|
ratio = (hop_size / sample_rate) / (self.encoder_hop_size / self.encoder_sample_rate) |
|
|
index = torch.clamp(torch.round(ratio * torch.arange(n_frames).to(self.device)).long(), max = units.size(1) - 1) |
|
|
units_aligned = torch.gather(units, 1, index.unsqueeze(0).unsqueeze(-1).repeat([1, 1, units.size(-1)])) |
|
|
return units_aligned |
|
|
|
|
|
def batch_encode(self, |
|
|
audio, |
|
|
sample_rate, |
|
|
hop_size): |
|
|
units_aligned_batch = [] |
|
|
for i in range(audio.size(0)): |
|
|
audio |
|
|
|
|
|
if sample_rate == self.encoder_sample_rate: |
|
|
audio_res = audio[i] |
|
|
else: |
|
|
key_str = str(sample_rate) |
|
|
if key_str not in self.resample_kernel: |
|
|
self.resample_kernel[key_str] = Resample(sample_rate, self.encoder_sample_rate, lowpass_filter_width = 128).to(self.device) |
|
|
audio_res = self.resample_kernel[key_str](audio[i]) |
|
|
|
|
|
|
|
|
if audio_res.size(-1) < 400: |
|
|
audio_res = torch.nn.functional.pad(audio[i], (0, 400 - audio_res.size(-1))) |
|
|
units = self.model(audio_res) |
|
|
|
|
|
|
|
|
n_frames = audio.size(-1) // hop_size + 1 |
|
|
ratio = (hop_size / sample_rate) / (self.encoder_hop_size / self.encoder_sample_rate) |
|
|
index = torch.clamp(torch.round(ratio * torch.arange(n_frames).to(self.device)).long(), max = units.size(1) - 1) |
|
|
units_aligned = torch.gather(units, 1, index.unsqueeze(0).unsqueeze(-1).repeat([1, 1, units.size(-1)])) |
|
|
units_aligned_batch.append(units_aligned.squeeze(0)) |
|
|
return torch.stack(units_aligned_batch, 0) |
|
|
|
|
|
|
|
|
class Audio2HubertSoft(torch.nn.Module): |
|
|
def __init__(self, path, h_sample_rate = 16000, h_hop_size = 320): |
|
|
super().__init__() |
|
|
print(' [Encoder Model] HuBERT Soft') |
|
|
self.hubert = HubertSoft() |
|
|
print(' [Loading] ' + path) |
|
|
checkpoint = torch.load(path) |
|
|
consume_prefix_in_state_dict_if_present(checkpoint, "module.") |
|
|
self.hubert.load_state_dict(checkpoint) |
|
|
self.hubert.eval() |
|
|
|
|
|
def forward(self, |
|
|
audio): |
|
|
with torch.inference_mode(): |
|
|
units = self.hubert.units(audio.unsqueeze(1)) |
|
|
return units |
|
|
|
|
|
|
|
|
class Audio2ContentVec(): |
|
|
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): |
|
|
self.device = device |
|
|
print(' [Encoder Model] Content Vec') |
|
|
print(' [Loading] ' + path) |
|
|
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", ) |
|
|
self.hubert = self.models[0] |
|
|
self.hubert = self.hubert.to(self.device) |
|
|
self.hubert.eval() |
|
|
|
|
|
def __call__(self, |
|
|
audio): |
|
|
|
|
|
wav_tensor = audio |
|
|
feats = wav_tensor.view(1, -1) |
|
|
padding_mask = torch.BoolTensor(feats.shape).fill_(False) |
|
|
inputs = { |
|
|
"source": feats.to(wav_tensor.device), |
|
|
"padding_mask": padding_mask.to(wav_tensor.device), |
|
|
"output_layer": 9, |
|
|
} |
|
|
with torch.no_grad(): |
|
|
logits = self.hubert.extract_features(**inputs) |
|
|
feats = self.hubert.final_proj(logits[0]) |
|
|
units = feats |
|
|
return units |
|
|
|
|
|
|
|
|
class Audio2ContentVec768(): |
|
|
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): |
|
|
self.device = device |
|
|
print(' [Encoder Model] Content Vec') |
|
|
print(' [Loading] ' + path) |
|
|
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", ) |
|
|
self.hubert = self.models[0] |
|
|
self.hubert = self.hubert.to(self.device) |
|
|
self.hubert.eval() |
|
|
|
|
|
def __call__(self, |
|
|
audio): |
|
|
|
|
|
wav_tensor = audio |
|
|
print('wav_tensor.shape: ', wav_tensor.shape) |
|
|
feats = wav_tensor.view(1, -1) |
|
|
padding_mask = torch.BoolTensor(feats.shape).fill_(False) |
|
|
inputs = { |
|
|
"source": feats.to(wav_tensor.device), |
|
|
"padding_mask": padding_mask.to(wav_tensor.device), |
|
|
"output_layer": 9, |
|
|
} |
|
|
with torch.no_grad(): |
|
|
logits = self.hubert.extract_features(**inputs) |
|
|
feats = logits[0] |
|
|
units = feats |
|
|
return units |
|
|
|
|
|
|
|
|
class Audio2ContentVec768L12(): |
|
|
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): |
|
|
self.device = device |
|
|
print(' [Encoder Model] Content Vec') |
|
|
print(' [Loading] ' + path) |
|
|
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", ) |
|
|
self.hubert = self.models[0] |
|
|
self.hubert = self.hubert.to(self.device) |
|
|
self.hubert.eval() |
|
|
|
|
|
def __call__(self, |
|
|
audio): |
|
|
|
|
|
wav_tensor = audio |
|
|
feats = wav_tensor.view(1, -1) |
|
|
padding_mask = torch.BoolTensor(feats.shape).fill_(False) |
|
|
inputs = { |
|
|
"source": feats.to(wav_tensor.device), |
|
|
"padding_mask": padding_mask.to(wav_tensor.device), |
|
|
"output_layer": 12, |
|
|
} |
|
|
with torch.no_grad(): |
|
|
logits = self.hubert.extract_features(**inputs) |
|
|
feats = logits[0] |
|
|
units = feats |
|
|
return units |
|
|
|
|
|
class Audio2Spin(): |
|
|
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): |
|
|
self.device = device |
|
|
print(' [Encoder Model] Spin') |
|
|
print(' [Loading] ' + path) |
|
|
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", ) |
|
|
self.hubert = self.models[0] |
|
|
self.hubert = self.hubert.to(self.device) |
|
|
self.hubert.eval() |
|
|
|
|
|
def __call__(self, |
|
|
audio): |
|
|
|
|
|
wav_tensor = audio |
|
|
feats = wav_tensor.view(1, -1) |
|
|
padding_mask = torch.BoolTensor(feats.shape).fill_(False) |
|
|
inputs = { |
|
|
"source": feats.to(wav_tensor.device), |
|
|
"padding_mask": padding_mask.to(wav_tensor.device), |
|
|
"output_layer": 12, |
|
|
} |
|
|
with torch.no_grad(): |
|
|
logits = self.hubert.extract_features(**inputs) |
|
|
feats = logits[0] |
|
|
units = feats |
|
|
return units |
|
|
|
|
|
class CNHubertSoftFish(torch.nn.Module): |
|
|
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu', gate_size=10): |
|
|
super().__init__() |
|
|
self.device = device |
|
|
self.gate_size = gate_size |
|
|
|
|
|
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( |
|
|
"./pretrain/TencentGameMate/chinese-hubert-base") |
|
|
self.model = HubertModel.from_pretrained("./pretrain/TencentGameMate/chinese-hubert-base") |
|
|
self.proj = torch.nn.Sequential(torch.nn.Dropout(0.1), torch.nn.Linear(768, 256)) |
|
|
|
|
|
|
|
|
state_dict = torch.load(path, map_location=device) |
|
|
self.load_state_dict(state_dict) |
|
|
|
|
|
@torch.no_grad() |
|
|
def forward(self, audio): |
|
|
input_values = self.feature_extractor( |
|
|
audio, sampling_rate=16000, return_tensors="pt" |
|
|
).input_values |
|
|
input_values = input_values.to(self.model.device) |
|
|
|
|
|
return self._forward(input_values[0]) |
|
|
|
|
|
@torch.no_grad() |
|
|
def _forward(self, input_values): |
|
|
features = self.model(input_values) |
|
|
features = self.proj(features.last_hidden_state) |
|
|
|
|
|
|
|
|
topk, indices = torch.topk(features, self.gate_size, dim=2) |
|
|
features = torch.zeros_like(features).scatter(2, indices, topk) |
|
|
features = features / features.sum(2, keepdim=True) |
|
|
|
|
|
return features.to(self.device) |
|
|
|
|
|
|
|
|
class Audio2HubertBase(): |
|
|
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): |
|
|
self.device = device |
|
|
print(' [Encoder Model] HuBERT Base') |
|
|
print(' [Loading] ' + path) |
|
|
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", ) |
|
|
self.hubert = self.models[0] |
|
|
self.hubert = self.hubert.to(self.device) |
|
|
self.hubert = self.hubert.float() |
|
|
self.hubert.eval() |
|
|
|
|
|
def __call__(self, |
|
|
audio): |
|
|
with torch.no_grad(): |
|
|
padding_mask = torch.BoolTensor(audio.shape).fill_(False) |
|
|
inputs = { |
|
|
"source": audio.to(self.device), |
|
|
"padding_mask": padding_mask.to(self.device), |
|
|
"output_layer": 9, |
|
|
} |
|
|
logits = self.hubert.extract_features(**inputs) |
|
|
units = self.hubert.final_proj(logits[0]) |
|
|
return units |
|
|
|
|
|
|
|
|
class Audio2HubertBase768(): |
|
|
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): |
|
|
self.device = device |
|
|
print(' [Encoder Model] HuBERT Base') |
|
|
print(' [Loading] ' + path) |
|
|
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", ) |
|
|
self.hubert = self.models[0] |
|
|
self.hubert = self.hubert.to(self.device) |
|
|
self.hubert = self.hubert.float() |
|
|
self.hubert.eval() |
|
|
|
|
|
def __call__(self, |
|
|
audio): |
|
|
with torch.no_grad(): |
|
|
padding_mask = torch.BoolTensor(audio.shape).fill_(False) |
|
|
inputs = { |
|
|
"source": audio.to(self.device), |
|
|
"padding_mask": padding_mask.to(self.device), |
|
|
"output_layer": 9, |
|
|
} |
|
|
logits = self.hubert.extract_features(**inputs) |
|
|
units = logits[0] |
|
|
return units |
|
|
|
|
|
|
|
|
class Audio2HubertBase768L12(): |
|
|
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): |
|
|
self.device = device |
|
|
print(' [Encoder Model] HuBERT Base') |
|
|
print(' [Loading] ' + path) |
|
|
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", ) |
|
|
self.hubert = self.models[0] |
|
|
self.hubert = self.hubert.to(self.device) |
|
|
self.hubert = self.hubert.float() |
|
|
self.hubert.eval() |
|
|
|
|
|
def __call__(self, |
|
|
audio): |
|
|
with torch.no_grad(): |
|
|
padding_mask = torch.BoolTensor(audio.shape).fill_(False) |
|
|
inputs = { |
|
|
"source": audio.to(self.device), |
|
|
"padding_mask": padding_mask.to(self.device), |
|
|
"output_layer": 12, |
|
|
} |
|
|
logits = self.hubert.extract_features(**inputs) |
|
|
units = logits[0] |
|
|
return units |
|
|
|
|
|
|
|
|
class Audio2HubertLarge1024L24(): |
|
|
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): |
|
|
self.device = device |
|
|
print(' [Encoder Model] HuBERT Base') |
|
|
print(' [Loading] ' + path) |
|
|
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", ) |
|
|
self.hubert = self.models[0] |
|
|
self.hubert = self.hubert.to(self.device) |
|
|
self.hubert = self.hubert.float() |
|
|
self.hubert.eval() |
|
|
|
|
|
def __call__(self, |
|
|
audio): |
|
|
with torch.no_grad(): |
|
|
padding_mask = torch.BoolTensor(audio.shape).fill_(False) |
|
|
inputs = { |
|
|
"source": audio.to(self.device), |
|
|
"padding_mask": padding_mask.to(self.device), |
|
|
"output_layer": 24, |
|
|
} |
|
|
logits = self.hubert.extract_features(**inputs) |
|
|
units = logits[0] |
|
|
return units |
|
|
|
|
|
|
|
|
|