Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,140 Bytes
faadabf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import os
import torch
from .hubert import HuBERT
from .semantic_tokenizer import SemanticVQVAE
class SemanticTokenizer:
    """Inference wrapper around a ``SemanticVQVAE`` speech tokenizer.

    The wrapped model is restored from ``codec.bin`` and given a frozen
    ``HuBERT`` instance (loaded from ``hubert.pt``) as its ``ssl_extractor``.
    Instances are callable; calling one is the same as calling ``extract``.
    """

    def __init__(self, config, path):
        """Build the model, restore its weights and attach the SSL extractor.

        Args:
            config: Mapping of keyword arguments forwarded to
                ``SemanticVQVAE``.
            path: Directory containing ``codec.bin`` and ``hubert.pt``.
        """
        codec_file = os.path.join(path, "codec.bin")
        hubert_file = os.path.join(path, "hubert.pt")

        model = SemanticVQVAE(**config)
        # The strict load runs before HuBERT is attached, so the checkpoint
        # is matched against the model as it is without the extractor.
        codec_state = torch.load(codec_file, map_location="cpu")
        model.load_state_dict(codec_state, strict=True)

        hubert = HuBERT(hubert_file)
        # Freeze every HuBERT parameter.
        for _, weight in hubert.named_parameters():
            weight.requires_grad = False
        model.ssl_extractor = hubert

        # Move to GPU only when one is available; otherwise stay on CPU.
        if torch.cuda.is_available():
            model = model.cuda()

        self.model = model
        self.model.eval()

    def __call__(self, wavs, wav_lengths):
        """Return ``extract(wavs, wav_lengths)``."""
        return self.extract(wavs, wav_lengths)

    def extract(self, wavs, wav_lengths):
        """Run the model's speech-token extraction on a batch.

        Args:
            wavs: Waveform input passed through unchanged to
                ``self.model.extract_speech_tokens``.
            wav_lengths: Lengths accompanying ``wavs``, passed through
                unchanged as well.

        Returns:
            A ``(tokens, token_lengths, spk_embeddings)`` tuple read from the
            ``"token"``, ``"token_length"`` and ``"spk"`` entries of the
            model's output.
        """
        features = self.model.extract_speech_tokens(wavs, wav_lengths)
        return features["token"], features["token_length"], features["spk"]
|