| import torch, torchaudio |
| from .hubert.hubert import HubertSoft |
| from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present |
| import librosa |
|
|
|
|
def get_soft_model(model_path):
    """Build a ``HubertSoft`` model and load its weights from a checkpoint.

    Args:
        model_path: Anything ``torch.load`` accepts (typically a file path).
            The loaded object must be a mapping with a ``"hubert"`` entry
            holding the model's state dict.

    Returns:
        The ``HubertSoft`` instance with the checkpoint weights loaded,
        switched to evaluation mode. The model is not moved to any device
        here; callers that want it on an accelerator must do that themselves.
    """
    hubert = HubertSoft()

    # Deserialize onto the CPU so that a checkpoint saved from CUDA tensors
    # also loads on hosts without a GPU. load_state_dict() copies the values
    # into the model's own parameters, so the result is the same either way.
    # NOTE(security): torch.load unpickles its input; only pass checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(model_path, map_location="cpu")
    state_dict = checkpoint["hubert"]

    # Drop the "module." key prefix (in place) when it is present, so the
    # keys line up with the bare model's parameter names.
    consume_prefix_in_state_dict_if_present(state_dict, "module.")
    hubert.load_state_dict(state_dict)
    hubert.eval()
    return hubert
|
|
|
|
@torch.no_grad()
def get_hubert_soft_content(hmodel, wav_16k_tensor, device='cuda'):
    """Run ``hmodel.units`` on a waveform tensor and return the result on CPU.

    Args:
        hmodel: Object exposing a ``units(tensor)`` method. It is used as
            given and is NOT moved to ``device`` by this function.
        wav_16k_tensor: Input tensor; a singleton dimension is inserted at
            index 1 before it is handed to the model. Presumably a
            (batch, samples) waveform sampled at 16 kHz -- confirm with
            the callers.
        device: Device the input is transferred to before the model call.

    Returns:
        Whatever ``hmodel.units`` produces, copied back to the CPU.
    """
    # Gradient tracking is disabled for the whole call by the decorator.
    model_input = wav_16k_tensor.unsqueeze(1).to(device)
    return hmodel.units(model_input).cpu()