import os

import numpy as np
import yaml
import torch
import torch.nn.functional as F
# import pyworld as pw
# import parselmouth
# import torchcrepe
# import resampy
from transformers import HubertModel, Wav2Vec2FeatureExtractor
from fairseq import checkpoint_utils
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from torchaudio.transforms import Resample

# Unit2Control is required by CombSub / CombSubFast_SingingEnhance below.
from .unit2control import Unit2Control, Unit2ControlFac
from .core import frequency_filter, upsample, remove_above_fmax
# from .core import MaskedAvgPool1d, MedianPool1d
# import time
# import librosa

CREPE_RESAMPLE_KERNEL = {}
F0_KERNEL = {}


class F0_Extractor:
    def __init__(self, f0_extractor, sample_rate=44100, hop_size=512, f0_min=65, f0_max=800):
        self.f0_extractor = f0_extractor
        self.sample_rate = sample_rate
        self.hop_size = hop_size
        self.f0_min = f0_min
        self.f0_max = f0_max
        if f0_extractor == 'crepe':
            key_str = str(sample_rate)
            if key_str not in CREPE_RESAMPLE_KERNEL:
                CREPE_RESAMPLE_KERNEL[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128)
            self.resample_kernel = CREPE_RESAMPLE_KERNEL[key_str]
        if f0_extractor == 'rmvpe':
            if 'rmvpe' not in F0_KERNEL:
                from rmvpe import RMVPE
                F0_KERNEL['rmvpe'] = RMVPE('utils/pretrain/rmvpe/model.pt', hop_length=160)
            self.rmvpe = F0_KERNEL['rmvpe']

    def extract(self, audio, uv_interp=False, device=None, silence_front=0):
        # audio: 1d numpy array
        # NOTE: only the rmvpe path is implemented below; 'crepe' can be
        # configured in __init__ but has no extraction branch here.
        n_frames = int(len(audio) // self.hop_size) + 1

        # skip the leading silence
        start_frame = int(silence_front * self.sample_rate / self.hop_size)
        real_silence_front = start_frame * self.hop_size / self.sample_rate
        audio = audio[int(np.round(real_silence_front * self.sample_rate)):]

        # # extract f0 using parselmouth
        # if self.f0_extractor == 'parselmouth':
        #     f0 = parselmouth.Sound(audio, self.sample_rate).to_pitch_ac(
        #         time_step=self.hop_size / self.sample_rate,
        #         voicing_threshold=0.6,
        #         pitch_floor=self.f0_min,
        #         pitch_ceiling=self.f0_max).selected_array['frequency']
        #     pad_size = start_frame + (int(len(audio) // self.hop_size) - len(f0) + 1) // 2
        #     f0 = np.pad(f0, (pad_size, n_frames - len(f0) - pad_size))

        # extract f0 using rmvpe
        if self.f0_extractor == "rmvpe":
            f0 = self.rmvpe.infer_from_audio(audio, self.sample_rate, device=device, thred=0.03, use_viterbi=False)
            uv = f0 == 0
            if len(f0[~uv]) > 0:
                f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
            # resample the 10 ms RMVPE grid onto the model's hop-size grid
            origin_time = 0.01 * np.arange(len(f0))
            target_time = self.hop_size / self.sample_rate * np.arange(n_frames - start_frame)
            f0 = np.interp(target_time, origin_time, f0)
            uv = np.interp(target_time, origin_time, uv.astype(float)) > 0.5
            f0[uv] = 0
            f0 = np.pad(f0, (start_frame, 0))
        else:
            raise ValueError(f" [x] Unknown f0 extractor: {self.f0_extractor}")

        # interpolate the unvoiced f0
        if uv_interp:
            uv = f0 == 0  # unvoiced frame mask, e.g. [True, False, False, True, False, True]
            if len(f0[~uv]) > 0:  # if there are voiced frames
                f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
            f0[f0 < self.f0_min] = self.f0_min
        return f0

    def batch_extract(self, audios, uv_interp=False, device=None, silence_front=0):
        processed_f0s = []
        for audio in audios:
            # extract f0 using rmvpe
            if self.f0_extractor == "rmvpe":
                f0 = self.rmvpe.infer_from_audio(audio, self.sample_rate, device=device, thred=0.03,
                                                 use_viterbi=False)
                f0 = torch.tensor(f0, dtype=torch.float32, device=device)  # convert to torch tensor
                n_frames = int(len(audio) // self.hop_size) + 1
                start_frame = int(silence_front * self.sample_rate / self.hop_size)
                real_silence_front = start_frame * self.hop_size / self.sample_rate
                audio = audio[int(np.round(real_silence_front * self.sample_rate)):]
                # note: target_time is computed but unused here; F.interpolate
                # resamples by frame count, and uv_interp is not applied in
                # this batched path.
                target_time = self.hop_size / self.sample_rate * torch.arange(n_frames - start_frame, device=device)
                f0 = F.interpolate(f0.unsqueeze(0).unsqueeze(0), size=n_frames - start_frame, mode='linear').squeeze()
            else:
                raise ValueError(f"Unknown f0 extractor: {self.f0_extractor}")
            processed_f0s.append(f0)
        processed_f0s = torch.stack(processed_f0s, 0)  # convert list of tensors to a single tensor
        return processed_f0s
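
# Usage sketch (hypothetical path and values, not part of the pipeline):
# extract a hop-aligned f0 curve from a mono 44.1 kHz clip with the RMVPE
# backend; unvoiced frames are filled by interpolation.
#
#   import librosa
#   audio, sr = librosa.load('input.wav', sr=44100, mono=True)
#   pitch_extractor = F0_Extractor('rmvpe', sample_rate=sr, hop_size=512)
#   f0 = pitch_extractor.extract(audio, uv_interp=True, device='cuda')
#   # f0.shape == (len(audio) // 512 + 1,)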
class Volume_Extractor:
    def __init__(self, hop_size=512):
        self.hop_size = hop_size

    def extract(self, audio):
        # audio: 1d numpy array
        n_frames = int(len(audio) // self.hop_size) + 1
        audio2 = audio ** 2
        audio2 = np.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode='reflect')
        volume = np.array(
            [np.mean(audio2[int(n * self.hop_size): int((n + 1) * self.hop_size)]) for n in range(n_frames)])
        volume = np.sqrt(volume)
        return volume


class DotDict(dict):
    def __getattr__(*args):
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def load_model(model_path, device='cpu'):
    config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
    with open(config_file, "r") as config:
        args = yaml.safe_load(config)
    args = DotDict(args)

    # load model
    # NOTE: the 'Sins' and 'CombSubFast' branches refer to classes that are
    # currently commented out below; they will fail until those definitions
    # are restored.
    model = None
    if args.model.type == 'Sins':
        model = Sins(
            sampling_rate=args.data.sampling_rate,
            block_size=args.data.block_size,
            n_harmonics=args.model.n_harmonics,
            n_mag_allpass=args.model.n_mag_allpass,
            n_mag_noise=args.model.n_mag_noise,
            n_unit=args.data.encoder_out_channels,
            n_spk=args.model.n_spk)

    elif args.model.type == 'CombSub':
        model = CombSub(
            sampling_rate=args.data.sampling_rate,
            block_size=args.data.block_size,
            n_mag_allpass=args.model.n_mag_allpass,
            n_mag_harmonic=args.model.n_mag_harmonic,
            n_mag_noise=args.model.n_mag_noise,
            n_unit=args.data.encoder_out_channels,
            n_spk=args.model.n_spk)

    elif args.model.type == 'CombSubFast':
        model = CombSubFast(
            sampling_rate=args.data.sampling_rate,
            block_size=args.data.block_size,
            n_unit=args.data.encoder_out_channels,
            n_spk=args.model.n_spk)

    else:
        raise ValueError(f" [x] Unknown Model: {args.model.type}")

    print(' [Loading] ' + model_path)
    ckpt = torch.load(model_path, map_location=torch.device(device))
    model.to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()
    return model, args
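
# Usage sketch (hypothetical checkpoint path): load_model expects a
# 'config.yaml' next to the checkpoint and returns (model, args), where args
# is a DotDict view of the YAML config, so keys read as attributes.
#
#   model, args = load_model('exp/combsub/model_300000.pt', device='cuda')
#   print(args.data.sampling_rate, args.data.block_size)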
# class Sins(torch.nn.Module):
#     def __init__(self,
#                  sampling_rate,
#                  block_size,
#                  n_harmonics,
#                  n_mag_allpass,
#                  n_mag_noise,
#                  n_unit=256,
#                  n_spk=1):
#         super().__init__()
#         print(' [DDSP Model] Sinusoids Additive Synthesiser')
#         # params
#         self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
#         self.register_buffer("block_size", torch.tensor(block_size))
#         # Unit2Control
#         split_map = {
#             'amplitudes': n_harmonics,
#             'group_delay': n_mag_allpass,
#             'noise_magnitude': n_mag_noise,
#         }
#         self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)
#
#     def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None,
#                 initial_phase=None, infer=True, max_upsample_dim=32):
#         '''
#             units_frames: B x n_frames x n_unit
#             f0_frames: B x n_frames x 1
#             volume_frames: B x n_frames x 1
#             spk_id: B x 1
#         '''
#         # exciter phase
#         f0 = upsample(f0_frames, self.block_size)
#         if infer:
#             x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
#         else:
#             x = torch.cumsum(f0 / self.sampling_rate, axis=1)
#         if initial_phase is not None:
#             x += initial_phase.to(x) / 2 / np.pi
#         x = x - torch.round(x)
#         x = x.to(f0)
#         phase = 2 * np.pi * x
#         phase_frames = phase[:, ::self.block_size, :]
#
#         # parameter prediction
#         ctrls, hidden = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames,
#                                        spk_id=spk_id, spk_mix_dict=spk_mix_dict)
#         amplitudes_frames = torch.exp(ctrls['amplitudes']) / 128
#         group_delay = np.pi * torch.tanh(ctrls['group_delay'])
#         noise_param = torch.exp(ctrls['noise_magnitude']) / 128
#
#         # sinusoids exciter signal
#         amplitudes_frames = remove_above_fmax(amplitudes_frames, f0_frames, self.sampling_rate / 2, level_start=1)
#         n_harmonic = amplitudes_frames.shape[-1]
#         level_harmonic = torch.arange(1, n_harmonic + 1).to(phase)
#         sinusoids = 0.
#         for n in range((n_harmonic - 1) // max_upsample_dim + 1):
#             start = n * max_upsample_dim
#             end = (n + 1) * max_upsample_dim
#             phases = phase * level_harmonic[start:end]
#             amplitudes = upsample(amplitudes_frames[:, :, start:end], self.block_size)
#             sinusoids += (torch.sin(phases) * amplitudes).sum(-1)
#
#         # harmonic part filter (apply group-delay)
#         harmonic = frequency_filter(
#             sinusoids,
#             torch.exp(1.j * torch.cumsum(group_delay, axis=-1)),
#             hann_window=False)
#
#         # noise part filter
#         noise = torch.rand_like(harmonic) * 2 - 1
#         noise = frequency_filter(
#             noise,
#             torch.complex(noise_param, torch.zeros_like(noise_param)),
#             hann_window=True)
#
#         signal = harmonic + noise
#
#         return signal, hidden, (harmonic, noise)  # , (noise_param, noise_param)


# class CombSubFast(torch.nn.Module):
#     def __init__(self,
#                  sampling_rate,
#                  block_size,
#                  n_unit=256,
#                  n_spk=1,
#                  use_pitch_aug=False,
#                  pcmer_norm=False):
#         super().__init__()
#         print(' [DDSP Model] Combtooth Subtractive Synthesiser')
#         # params
#         self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
#         self.register_buffer("block_size", torch.tensor(block_size))
#         self.register_buffer("window", torch.sqrt(torch.hann_window(2 * block_size)))
#         # Unit2Control
#         split_map = {
#             'harmonic_magnitude': block_size + 1,
#             'harmonic_phase': block_size + 1,
#             'noise_magnitude': block_size + 1
#         }
#         self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map, use_pitch_aug=use_pitch_aug, pcmer_norm=pcmer_norm)
#
#     def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, aug_shift=None,
#                 initial_phase=None, infer=True, **kwargs):
#         '''
#             units_frames: B x n_frames x n_unit
#             f0_frames: B x n_frames x 1
#             volume_frames: B x n_frames x 1
#             spk_id: B x 1
#         '''
#         # exciter phase
#         f0 = upsample(f0_frames, self.block_size)
#         if infer:
#             x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
#         else:
#             x = torch.cumsum(f0 / self.sampling_rate, axis=1)
#         if initial_phase is not None:
#             x += initial_phase.to(x) / 2 / np.pi
#         x = x - torch.round(x)
#         x = x.to(f0)
#         phase_frames = 2 * np.pi * x[:, ::self.block_size, :]
#
#         # parameter prediction
#         ctrls, hidden = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames,
#                                        spk_id=spk_id, spk_mix_dict=spk_mix_dict, aug_shift=aug_shift)
#         src_filter = torch.exp(ctrls['harmonic_magnitude'] + 1.j * np.pi * ctrls['harmonic_phase'])
#         src_filter = torch.cat((src_filter, src_filter[:, -1:, :]), 1)
#         noise_filter = torch.exp(ctrls['noise_magnitude']) / 128
#         noise_filter = torch.cat((noise_filter, noise_filter[:, -1:, :]), 1)
#
#         # combtooth exciter signal
#         combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3))
#         combtooth = combtooth.squeeze(-1)
#         combtooth_frames = F.pad(combtooth, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size)
#         combtooth_frames = combtooth_frames * self.window
#         combtooth_fft = torch.fft.rfft(combtooth_frames, 2 * self.block_size)
#
#         # noise exciter signal
#         noise = torch.rand_like(combtooth) * 2 - 1
#         noise_frames = F.pad(noise, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size)
#         noise_frames = noise_frames * self.window
#         noise_fft = torch.fft.rfft(noise_frames, 2 * self.block_size)
#
#         # apply the filters
#         signal_fft = combtooth_fft * src_filter + noise_fft * noise_filter
#
#         # take the ifft to resynthesize audio
#         signal_frames_out = torch.fft.irfft(signal_fft, 2 * self.block_size) * self.window
#
#         # overlap add
#         fold = torch.nn.Fold(output_size=(1, (signal_frames_out.size(1) + 1) * self.block_size),
#                              kernel_size=(1, 2 * self.block_size), stride=(1, self.block_size))
#         signal = fold(signal_frames_out.transpose(1, 2))[:, 0, 0, self.block_size:-self.block_size]
#
#         return signal, hidden, (signal, signal)
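
# Sanity-check sketch (standalone): the synthesisers below window analysis
# and synthesis frames with sqrt(hann), so the overlapping frames multiply to
# a plain periodic Hann window, which sums to exactly 1 at 50% overlap and
# makes the Fold-based overlap-add transparent.
#
#   block_size = 512
#   w = torch.sqrt(torch.hann_window(2 * block_size))
#   ola = torch.zeros(3 * block_size)
#   for start in (0, block_size):
#       ola[start:start + 2 * block_size] += w * w
#   print(ola[block_size:2 * block_size])  # all ones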
class CombSubFastFac(torch.nn.Module):
    def __init__(self,
                 sampling_rate,
                 block_size,
                 n_unit=256,
                 n_spk=1,
                 use_pitch_aug=False,
                 pcmer_norm=False):
        super().__init__()
        print(' [DDSP Model] Combtooth Subtractive Synthesiser')
        # params
        self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
        self.register_buffer("block_size", torch.tensor(block_size))
        self.register_buffer("window", torch.sqrt(torch.hann_window(2 * block_size)))
        # Unit2Control
        split_map = {
            'harmonic_magnitude': block_size + 1,
            'harmonic_phase': block_size + 1,
            'noise_magnitude': block_size + 1
        }
        self.unit2ctrl = Unit2ControlFac(n_unit, split_map, use_pitch_aug=use_pitch_aug, pcmer_norm=pcmer_norm)

    def forward(self, units_frames, f0_frames, volume_frames, spk, aug_shift=None, initial_phase=None, infer=True,
                **kwargs):
        '''
            units_frames: B x n_frames x n_unit
            f0_frames: B x n_frames
            volume_frames: B x n_frames
            spk: B x 256
        '''
        # reshape to B x n_frames x 1
        f0_frames = f0_frames.unsqueeze(2)
        volume_frames = volume_frames.unsqueeze(2)

        # exciter phase
        f0 = upsample(f0_frames, self.block_size)
        if infer:
            x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
        else:
            x = torch.cumsum(f0 / self.sampling_rate, axis=1)
        if initial_phase is not None:
            x += initial_phase.to(x) / 2 / np.pi
        x = x - torch.round(x)
        x = x.to(f0)
        phase_frames = 2 * np.pi * x[:, ::self.block_size, :]

        # parameter prediction
        ctrls, hidden = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk,
                                       aug_shift=aug_shift)
        src_filter = torch.exp(ctrls['harmonic_magnitude'] + 1.j * np.pi * ctrls['harmonic_phase'])
        src_filter = torch.cat((src_filter, src_filter[:, -1:, :]), 1)
        noise_filter = torch.exp(ctrls['noise_magnitude']) / 128
        noise_filter = torch.cat((noise_filter, noise_filter[:, -1:, :]), 1)

        # combtooth exciter signal
        combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3))
        combtooth = combtooth.squeeze(-1)
        combtooth_frames = F.pad(combtooth, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size,
                                                                                       self.block_size)
        combtooth_frames = combtooth_frames * self.window
        combtooth_fft = torch.fft.rfft(combtooth_frames, 2 * self.block_size)

        # noise exciter signal
        noise = torch.rand_like(combtooth) * 2 - 1
        noise_frames = F.pad(noise, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size,
                                                                               self.block_size)
        noise_frames = noise_frames * self.window
        noise_fft = torch.fft.rfft(noise_frames, 2 * self.block_size)

        # apply the filters
        signal_fft = combtooth_fft * src_filter + noise_fft * noise_filter

        # take the ifft to resynthesize audio
        signal_frames_out = torch.fft.irfft(signal_fft, 2 * self.block_size) * self.window

        # overlap add
        fold = torch.nn.Fold(output_size=(1, (signal_frames_out.size(1) + 1) * self.block_size),
                             kernel_size=(1, 2 * self.block_size), stride=(1, self.block_size))
        signal = fold(signal_frames_out.transpose(1, 2))[:, 0, 0, self.block_size:-self.block_size]

        return signal, hidden, (signal, signal)
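
# Shape sketch (hypothetical sizes): with block_size = 512, the output length
# tracks n_frames * block_size (assuming upsample expands each control frame
# by block_size samples).
#
#   model = CombSubFastFac(sampling_rate=44100, block_size=512, n_unit=768)
#   units = torch.randn(1, 100, 768)     # B x n_frames x n_unit
#   f0 = torch.full((1, 100), 220.0)     # B x n_frames, Hz
#   volume = torch.rand(1, 100)          # B x n_frames
#   spk = torch.randn(1, 256)            # speaker embedding
#   signal, hidden, _ = model(units, f0, volume, spk)
#   # signal.shape == (1, 100 * 512)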
class CombSubFast_SingingEnhance(torch.nn.Module):
    def __init__(self,
                 sampling_rate,
                 block_size,
                 n_unit=256,
                 n_spk=1,
                 use_pitch_aug=False,
                 pcmer_norm=False):
        super().__init__()
        print(' [DDSP Model] Combtooth Subtractive Synthesiser')
        # params
        self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
        self.register_buffer("block_size", torch.tensor(block_size))
        self.register_buffer("window", torch.sqrt(torch.hann_window(2 * block_size)))
        # Unit2Control
        split_map = {
            'harmonic_magnitude': block_size + 1,
            'harmonic_phase': block_size + 1,
            'noise_magnitude': block_size + 1
        }
        self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map, use_pitch_aug=use_pitch_aug, pcmer_norm=pcmer_norm)

    def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, aug_shift=None,
                initial_phase=None, infer=True, **kwargs):
        '''
            units_frames: B x n_frames x n_unit
            f0_frames: B x n_frames x 1
            volume_frames: B x n_frames x 1
            spk_id: B x 1
        '''
        # exciter phase
        f0 = upsample(f0_frames, self.block_size)
        if infer:
            x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
        else:
            x = torch.cumsum(f0 / self.sampling_rate, axis=1)
        if initial_phase is not None:
            x += initial_phase.to(x) / 2 / np.pi
        x = x - torch.round(x)
        x = x.to(f0)
        phase_frames = 2 * np.pi * x[:, ::self.block_size, :]

        # parameter prediction
        ctrls, hidden = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames,
                                       spk_id=spk_id, spk_mix_dict=spk_mix_dict, aug_shift=aug_shift)
        src_filter = torch.exp(ctrls['harmonic_magnitude'] + 1.j * np.pi * ctrls['harmonic_phase'])
        src_filter = torch.cat((src_filter, src_filter[:, -1:, :]), 1)
        noise_filter = torch.exp(ctrls['noise_magnitude']) / 128
        noise_filter = torch.cat((noise_filter, noise_filter[:, -1:, :]), 1)

        # combtooth exciter signal
        combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3))
        combtooth = combtooth.squeeze(-1)
        combtooth_frames = F.pad(combtooth, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size,
                                                                                       self.block_size)
        combtooth_frames = combtooth_frames * self.window
        combtooth_fft = torch.fft.rfft(combtooth_frames, 2 * self.block_size)

        # noise exciter signal
        noise = torch.rand_like(combtooth) * 2 - 1
        noise_frames = F.pad(noise, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size,
                                                                               self.block_size)
        noise_frames = noise_frames * self.window
        noise_fft = torch.fft.rfft(noise_frames, 2 * self.block_size)

        # apply the filters
        signal_fft = combtooth_fft * src_filter + noise_fft * noise_filter

        # take the ifft to resynthesize audio
        signal_frames_out = torch.fft.irfft(signal_fft, 2 * self.block_size) * self.window

        # overlap add
        fold = torch.nn.Fold(output_size=(1, (signal_frames_out.size(1) + 1) * self.block_size),
                             kernel_size=(1, 2 * self.block_size), stride=(1, self.block_size))
        signal = fold(signal_frames_out.transpose(1, 2))[:, 0, 0, self.block_size:-self.block_size]

        return signal, hidden, (signal, signal)
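
# Exciter sketch (standalone): the combtooth source is a sinc pulse train.
# x holds the wrapped phase in [-0.5, 0.5), so sinc(sr * x / f0) peaks once
# per period and approximates a band-limited impulse train.
#
#   sr, f0_hz, n = 44100, 220.0, 2048
#   t = torch.arange(n) / sr
#   x = (t * f0_hz) - torch.round(t * f0_hz)   # wrapped phase
#   combtooth = torch.sinc(sr * x / f0_hz)     # one pulse per period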
class CombSub(torch.nn.Module):
    def __init__(self,
                 sampling_rate,
                 block_size,
                 n_mag_allpass,
                 n_mag_harmonic,
                 n_mag_noise,
                 n_unit=256,
                 n_spk=1):
        super().__init__()
        print(' [DDSP Model] Combtooth Subtractive Synthesiser (Old Version)')
        # params
        self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
        self.register_buffer("block_size", torch.tensor(block_size))
        # Unit2Control
        split_map = {
            'group_delay': n_mag_allpass,
            'harmonic_magnitude': n_mag_harmonic,
            'noise_magnitude': n_mag_noise
        }
        self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)

    def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None,
                initial_phase=None, infer=True, **kwargs):
        '''
            units_frames: B x n_frames x n_unit
            f0_frames: B x n_frames x 1
            volume_frames: B x n_frames x 1
            spk_id: B x 1
        '''
        # exciter phase
        f0 = upsample(f0_frames, self.block_size)
        if infer:
            x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
        else:
            x = torch.cumsum(f0 / self.sampling_rate, axis=1)
        if initial_phase is not None:
            x += initial_phase.to(x) / 2 / np.pi
        x = x - torch.round(x)
        x = x.to(f0)
        phase_frames = 2 * np.pi * x[:, ::self.block_size, :]

        # parameter prediction
        ctrls, hidden = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames,
                                       spk_id=spk_id, spk_mix_dict=spk_mix_dict)
        group_delay = np.pi * torch.tanh(ctrls['group_delay'])
        src_param = torch.exp(ctrls['harmonic_magnitude'])
        noise_param = torch.exp(ctrls['noise_magnitude']) / 128

        # combtooth exciter signal
        combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3))
        combtooth = combtooth.squeeze(-1)

        # harmonic part filter (using dynamic-windowed LTV-FIR, with group-delay prediction)
        harmonic = frequency_filter(
            combtooth,
            torch.exp(1.j * torch.cumsum(group_delay, axis=-1)),
            hann_window=False)
        harmonic = frequency_filter(
            harmonic,
            torch.complex(src_param, torch.zeros_like(src_param)),
            hann_window=True,
            half_width_frames=1.5 * self.sampling_rate / (f0_frames + 1e-3))

        # noise part filter (using constant-windowed LTV-FIR, without group-delay)
        noise = torch.rand_like(harmonic) * 2 - 1
        noise = frequency_filter(
            noise,
            torch.complex(noise_param, torch.zeros_like(noise_param)),
            hann_window=True)

        signal = harmonic + noise

        return signal, hidden, (harmonic, noise)
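
# Usage sketch (hypothetical sizes; CombSub is the legacy LTV-FIR variant,
# and spk_id handling depends on the Unit2Control configuration):
#
#   model = CombSub(sampling_rate=44100, block_size=512,
#                   n_mag_allpass=256, n_mag_harmonic=512, n_mag_noise=256)
#   units = torch.randn(1, 100, 256)
#   f0 = torch.full((1, 100, 1), 220.0)
#   volume = torch.rand(1, 100, 1)
#   signal, hidden, (harmonic, noise) = model(units, f0, volume,
#                                             spk_id=torch.zeros(1, 1).long())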
class Units_Encoder:
    def __init__(self, encoder, encoder_ckpt, encoder_sample_rate=16000, encoder_hop_size=320, device=None,
                 cnhubertsoft_gate=10):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        is_loaded_encoder = False
        if encoder == 'hubertsoft':
            self.model = Audio2HubertSoft(encoder_ckpt).to(device)
            is_loaded_encoder = True
        if encoder == 'hubertbase':
            self.model = Audio2HubertBase(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == 'hubertbase768':
            self.model = Audio2HubertBase768(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == 'hubertbase768l12':
            self.model = Audio2HubertBase768L12(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == 'hubertlarge1024l24':
            self.model = Audio2HubertLarge1024L24(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == 'contentvec':
            self.model = Audio2ContentVec(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == 'contentvec768':
            self.model = Audio2ContentVec768(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == 'spin':
            self.model = Audio2Spin(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == 'contentvec768l12':
            self.model = Audio2ContentVec768L12(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == 'cnhubertsoftfish':
            self.model = CNHubertSoftFish(encoder_ckpt, device=device, gate_size=cnhubertsoft_gate)
            is_loaded_encoder = True
        if not is_loaded_encoder:
            raise ValueError(f" [x] Unknown units encoder: {encoder}")

        self.resample_kernel = {}
        self.encoder_sample_rate = encoder_sample_rate
        self.encoder_hop_size = encoder_hop_size

    def encode(self,
               audio,  # B x T
               sample_rate,
               hop_size):
        # resample
        if sample_rate == self.encoder_sample_rate:
            audio_res = audio
        else:
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel:
                self.resample_kernel[key_str] = Resample(sample_rate, self.encoder_sample_rate,
                                                         lowpass_filter_width=128).to(self.device)
            audio_res = self.resample_kernel[key_str](audio)

        # encode (pad very short inputs; note: pad the resampled audio)
        if audio_res.size(-1) < 400:
            audio_res = torch.nn.functional.pad(audio_res, (0, 400 - audio_res.size(-1)))
        units = self.model(audio_res)

        # alignment: map model frames onto encoder frames via a time-ratio index
        n_frames = audio.size(-1) // hop_size + 1
        ratio = (hop_size / sample_rate) / (self.encoder_hop_size / self.encoder_sample_rate)
        index = torch.clamp(torch.round(ratio * torch.arange(n_frames).to(self.device)).long(),
                            max=units.size(1) - 1)
        units_aligned = torch.gather(units, 1, index.unsqueeze(0).unsqueeze(-1).repeat([1, 1, units.size(-1)]))
        return units_aligned

    def batch_encode(self,
                     audio,  # B x T
                     sample_rate,
                     hop_size):
        units_aligned_batch = []
        for i in range(audio.size(0)):
            # resample
            if sample_rate == self.encoder_sample_rate:
                audio_res = audio[i]
            else:
                key_str = str(sample_rate)
                if key_str not in self.resample_kernel:
                    self.resample_kernel[key_str] = Resample(sample_rate, self.encoder_sample_rate,
                                                             lowpass_filter_width=128).to(self.device)
                audio_res = self.resample_kernel[key_str](audio[i])

            # encode (pad very short inputs; note: pad the resampled audio)
            if audio_res.size(-1) < 400:
                audio_res = torch.nn.functional.pad(audio_res, (0, 400 - audio_res.size(-1)))
            units = self.model(audio_res)

            # alignment
            n_frames = audio.size(-1) // hop_size + 1
            ratio = (hop_size / sample_rate) / (self.encoder_hop_size / self.encoder_sample_rate)
            index = torch.clamp(torch.round(ratio * torch.arange(n_frames).to(self.device)).long(),
                                max=units.size(1) - 1)
            units_aligned = torch.gather(units, 1, index.unsqueeze(0).unsqueeze(-1).repeat([1, 1, units.size(-1)]))
            units_aligned_batch.append(units_aligned.squeeze(0))
        return torch.stack(units_aligned_batch, 0)  # from list of tensors to tensor


class Audio2HubertSoft(torch.nn.Module):
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320):
        super().__init__()
        # local import: only this encoder needs the HubertSoft implementation
        from encoder.hubert.model import HubertSoft
        print(' [Encoder Model] HuBERT Soft')
        self.hubert = HubertSoft()
        print(' [Loading] ' + path)
        checkpoint = torch.load(path)
        consume_prefix_in_state_dict_if_present(checkpoint, "module.")
        self.hubert.load_state_dict(checkpoint)
        self.hubert.eval()

    def forward(self, audio):  # B x T
        with torch.inference_mode():
            units = self.hubert.units(audio.unsqueeze(1))
            return units


class Audio2ContentVec():
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
        self.device = device
        print(' [Encoder Model] Content Vec')
        print(' [Loading] ' + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="")
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert.eval()

    def __call__(self, audio):  # B x T
        # wav_tensor = torch.from_numpy(audio).to(self.device)
        wav_tensor = audio
        feats = wav_tensor.view(1, -1)
        padding_mask = torch.BoolTensor(feats.shape).fill_(False)
        inputs = {
            "source": feats.to(wav_tensor.device),
            "padding_mask": padding_mask.to(wav_tensor.device),
            "output_layer": 9,  # layer 9
        }
        with torch.no_grad():
            logits = self.hubert.extract_features(**inputs)
            feats = self.hubert.final_proj(logits[0])
        units = feats  # .transpose(2, 1)
        return units
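
# Alignment sketch (standalone): encode() maps model frames (hop_size at the
# model sample rate) onto encoder frames (encoder_hop_size at 16 kHz) by
# rounding a time-ratio index and gathering, e.g. 512/44100 vs 320/16000:
#
#   hop_size, sample_rate = 512, 44100
#   encoder_hop_size, encoder_sample_rate = 320, 16000
#   ratio = (hop_size / sample_rate) / (encoder_hop_size / encoder_sample_rate)
#   index = torch.round(ratio * torch.arange(5)).long()
#   # ratio ~ 0.58, index == tensor([0, 1, 1, 2, 2])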
class Audio2ContentVec768():
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
        self.device = device
        print(' [Encoder Model] Content Vec')
        print(' [Loading] ' + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="")
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert.eval()

    def __call__(self, audio):  # B x T
        # wav_tensor = torch.from_numpy(audio).to(self.device)
        wav_tensor = audio
        # print('wav_tensor.shape: ', wav_tensor.shape)  # debug leftover
        feats = wav_tensor.view(1, -1)
        padding_mask = torch.BoolTensor(feats.shape).fill_(False)
        inputs = {
            "source": feats.to(wav_tensor.device),
            "padding_mask": padding_mask.to(wav_tensor.device),
            "output_layer": 9,  # layer 9
        }
        with torch.no_grad():
            logits = self.hubert.extract_features(**inputs)
            feats = logits[0]
        units = feats  # .transpose(2, 1)
        return units


class Audio2ContentVec768L12():
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
        self.device = device
        print(' [Encoder Model] Content Vec')
        print(' [Loading] ' + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="")
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert.eval()

    def __call__(self, audio):  # B x T
        # wav_tensor = torch.from_numpy(audio).to(self.device)
        wav_tensor = audio
        feats = wav_tensor.view(1, -1)
        padding_mask = torch.BoolTensor(feats.shape).fill_(False)
        inputs = {
            "source": feats.to(wav_tensor.device),
            "padding_mask": padding_mask.to(wav_tensor.device),
            "output_layer": 12,  # layer 12
        }
        with torch.no_grad():
            logits = self.hubert.extract_features(**inputs)
            feats = logits[0]
        units = feats  # .transpose(2, 1)
        return units


class Audio2Spin():
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
        self.device = device
        print(' [Encoder Model] Spin')
        print(' [Loading] ' + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="")
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert.eval()

    def __call__(self, audio):  # B x T
        # wav_tensor = torch.from_numpy(audio).to(self.device)
        wav_tensor = audio
        feats = wav_tensor.view(1, -1)
        padding_mask = torch.BoolTensor(feats.shape).fill_(False)
        inputs = {
            "source": feats.to(wav_tensor.device),
            "padding_mask": padding_mask.to(wav_tensor.device),
            "output_layer": 12,  # layer 12
        }
        with torch.no_grad():
            logits = self.hubert.extract_features(**inputs)
            feats = logits[0]
        units = feats  # .transpose(2, 1)
        return units


class CNHubertSoftFish(torch.nn.Module):
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu', gate_size=10):
        super().__init__()
        self.device = device
        self.gate_size = gate_size

        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "./pretrain/TencentGameMate/chinese-hubert-base")
        self.model = HubertModel.from_pretrained("./pretrain/TencentGameMate/chinese-hubert-base")
        self.proj = torch.nn.Sequential(torch.nn.Dropout(0.1), torch.nn.Linear(768, 256))
        # self.label_embedding = nn.Embedding(128, 256)

        state_dict = torch.load(path, map_location=device)
        self.load_state_dict(state_dict)

    @torch.no_grad()
    def forward(self, audio):
        input_values = self.feature_extractor(
            audio, sampling_rate=16000, return_tensors="pt"
        ).input_values
        input_values = input_values.to(self.model.device)
        return self._forward(input_values[0])

    @torch.no_grad()
    def _forward(self, input_values):
        features = self.model(input_values)
        features = self.proj(features.last_hidden_state)

        # top-k gating: keep the gate_size largest channels per frame
        topk, indices = torch.topk(features, self.gate_size, dim=2)
        features = torch.zeros_like(features).scatter(2, indices, topk)
        features = features / features.sum(2, keepdim=True)
        return features.to(self.device)  # .transpose(1, 2)
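
# Gating sketch (standalone): CNHubertSoftFish keeps only the gate_size
# largest channels per frame, then renormalises them to sum to 1.
#
#   feats = torch.rand(1, 4, 256)                        # B x n_frames x dim
#   topk, idx = torch.topk(feats, 10, dim=2)
#   gated = torch.zeros_like(feats).scatter(2, idx, topk)
#   gated = gated / gated.sum(2, keepdim=True)
#   # per frame: 10 non-zero channels, summing to 1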
class Audio2HubertBase():
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
        self.device = device
        print(' [Encoder Model] HuBERT Base')
        print(' [Loading] ' + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="")
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert = self.hubert.float()
        self.hubert.eval()

    def __call__(self, audio):  # B x T
        with torch.no_grad():
            padding_mask = torch.BoolTensor(audio.shape).fill_(False)
            inputs = {
                "source": audio.to(self.device),
                "padding_mask": padding_mask.to(self.device),
                "output_layer": 9,  # layer 9
            }
            logits = self.hubert.extract_features(**inputs)
            units = self.hubert.final_proj(logits[0])
            return units


class Audio2HubertBase768():
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
        self.device = device
        print(' [Encoder Model] HuBERT Base')
        print(' [Loading] ' + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="")
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert = self.hubert.float()
        self.hubert.eval()

    def __call__(self, audio):  # B x T
        with torch.no_grad():
            padding_mask = torch.BoolTensor(audio.shape).fill_(False)
            inputs = {
                "source": audio.to(self.device),
                "padding_mask": padding_mask.to(self.device),
                "output_layer": 9,  # layer 9
            }
            logits = self.hubert.extract_features(**inputs)
            units = logits[0]
            return units


class Audio2HubertBase768L12():
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
        self.device = device
        print(' [Encoder Model] HuBERT Base')
        print(' [Loading] ' + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="")
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert = self.hubert.float()
        self.hubert.eval()

    def __call__(self, audio):  # B x T
        with torch.no_grad():
            padding_mask = torch.BoolTensor(audio.shape).fill_(False)
            inputs = {
                "source": audio.to(self.device),
                "padding_mask": padding_mask.to(self.device),
                "output_layer": 12,  # layer 12
            }
            logits = self.hubert.extract_features(**inputs)
            units = logits[0]
            return units
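
# Usage sketch (hypothetical checkpoint path): the fairseq-based encoders
# share one call convention and differ only in output_layer and in whether
# final_proj is applied (256-d vs 768-d / 1024-d units).
#
#   encoder = Audio2HubertBase768L12('pretrain/hubert/checkpoint.pt', device='cuda')
#   audio = torch.randn(1, 16000).cuda()   # 1 second at 16 kHz
#   units = encoder(audio)                 # B x n_frames x 768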
class Audio2HubertLarge1024L24():
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
        self.device = device
        print(' [Encoder Model] HuBERT Large')
        print(' [Loading] ' + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="")
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert = self.hubert.float()
        self.hubert.eval()

    def __call__(self, audio):  # B x T
        with torch.no_grad():
            padding_mask = torch.BoolTensor(audio.shape).fill_(False)
            inputs = {
                "source": audio.to(self.device),
                "padding_mask": padding_mask.to(self.device),
                "output_layer": 24,  # layer 24
            }
            logits = self.hubert.extract_features(**inputs)
            units = logits[0]
            return units
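
# End-to-end sketch (hypothetical paths and a CombSub-style checkpoint): the
# intended inference pipeline combines the three feature extractors with a
# loaded synthesiser.
#
#   import librosa
#   units_encoder = Units_Encoder('contentvec768l12', 'pretrain/contentvec/checkpoint_best_legacy_500.pt')
#   f0_extractor = F0_Extractor('rmvpe', sample_rate=44100, hop_size=512)
#   volume_extractor = Volume_Extractor(hop_size=512)
#   model, args = load_model('exp/combsub/model_300000.pt', device='cuda')
#
#   audio, sr = librosa.load('input.wav', sr=44100, mono=True)
#   units = units_encoder.encode(torch.from_numpy(audio).unsqueeze(0).cuda(), sr, 512)
#   f0 = torch.from_numpy(f0_extractor.extract(audio, uv_interp=True, device='cuda')).float().cuda()
#   volume = torch.from_numpy(volume_extractor.extract(audio)).float().cuda()
#   signal, _, _ = model(units, f0[None, :, None], volume[None, :, None])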