Spaces: Running on Zero
| import os | |
| import torch | |
| from .hubert import HuBERT | |
| from .semantic_tokenizer import SemanticVQVAE | |
class SemanticTokenizer:
    """Turn raw waveforms into discrete semantic tokens plus speaker embeddings.

    Wraps a pretrained ``SemanticVQVAE`` codec whose SSL feature extractor is a
    frozen ``HuBERT`` model, both loaded from checkpoint files on disk.
    """

    def __init__(self, config, path):
        """Load the codec and its frozen HuBERT extractor.

        Args:
            config: Keyword arguments forwarded to ``SemanticVQVAE(**config)``.
            path: Directory containing ``codec.bin`` (VQ-VAE state dict) and
                ``hubert.pt`` (HuBERT checkpoint).
        """
        self.model = SemanticVQVAE(**config)
        # Load weights on CPU so this works on CUDA-less hosts; the whole
        # model is moved to GPU below when one is available.
        self.model.load_state_dict(
            torch.load(os.path.join(path, "codec.bin"), map_location="cpu"),
            strict=True,
        )
        hubert = HuBERT(os.path.join(path, "hubert.pt"))
        # The SSL extractor is a frozen backbone — disable gradients on all
        # of its parameters. (Original iterated named_parameters() but never
        # used the names; parameters() is the direct call.)
        for param in hubert.parameters():
            param.requires_grad = False
        self.model.ssl_extractor = hubert
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.model.eval()

    def __call__(self, wavs, wav_lengths):
        """Alias for :meth:`extract`; returns the same 3-tuple."""
        return self.extract(wavs, wav_lengths)

    def extract(self, wavs, wav_lengths):
        """Extract semantic tokens from waveforms.

        Args:
            wavs: Batch of input waveforms (passed through to the model;
                assumed to be tensors — TODO confirm expected shape/dtype).
            wav_lengths: Per-item valid lengths for ``wavs``.

        Returns:
            Tuple ``(tokens, token_lengths, spk_embeddings)`` taken from the
            ``"token"``, ``"token_length"`` and ``"spk"`` entries of
            ``model.extract_speech_tokens``'s output.
        """
        features = self.model.extract_speech_tokens(wavs, wav_lengths)
        return features["token"], features["token_length"], features["spk"]