Spaces:
Running
Running
| import os | |
| import sys | |
| import torch | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from functools import cached_property | |
| from torch.nn.utils.rnn import pad_sequence | |
| sys.path.append(os.getcwd()) | |
| from main.library.speaker_diarization.speechbrain import EncoderClassifier | |
class BaseInference:
    """Empty base class; it defines no attributes or methods of its own."""
class SpeechBrainPretrainedSpeakerEmbedding(BaseInference):
    """Speaker-embedding extractor wrapping a SpeechBrain ``EncoderClassifier``.

    The classifier is loaded with ``EncoderClassifier.from_hparams`` from the
    ``embedding`` source and is called on batches of mono waveforms, with an
    optional per-item mask selecting which samples are embedded.

    ``sample_rate``, ``dimension``, ``metric`` and ``min_num_samples`` are
    cached properties: read them as attributes (``model.dimension``), not as
    method calls. Each is computed at most once per instance.
    """

    def __init__(self, embedding = "assets/models/speaker_diarization/models/speechbrain", device = None):
        """Load the classifier.

        Args:
            embedding: Source passed to ``EncoderClassifier.from_hparams``.
            device: ``torch.device`` to run on; defaults to CPU when falsy.
        """
        super().__init__()
        self.embedding = embedding
        self.device = device or torch.device("cpu")
        self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": self.device})

    def to(self, device):
        """Reload the classifier on ``device`` and return ``self``.

        Args:
            device: Target ``torch.device``.

        Raises:
            TypeError: If ``device`` is not a ``torch.device`` instance.
        """
        if not isinstance(device, torch.device):
            raise TypeError(f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`")

        # The classifier is rebuilt from its source rather than moved in place.
        self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": device})
        self.device = device
        return self

    @cached_property
    def sample_rate(self):
        """Sample rate reported by the classifier's audio normalizer."""
        return self.classifier_.audio_normalizer.sample_rate

    @cached_property
    def dimension(self):
        """Size of the last axis of the classifier output.

        Measured by encoding one random 16000-sample waveform.
        """
        *_, dimension = self.classifier_.encode_batch(torch.rand(1, 16000).to(self.device)).shape
        return dimension

    @cached_property
    def metric(self):
        """Name of the distance metric associated with the embeddings."""
        return "cosine"

    @cached_property
    def min_num_samples(self):
        """Smallest waveform length (in samples) the classifier accepts.

        Found by bisection between 2 samples and half a second of audio: a
        length is treated as too short when ``encode_batch`` raises
        ``RuntimeError``. The search assumes that once a length is accepted,
        every longer length is accepted too.
        """
        with torch.inference_mode():
            lower, upper = 2, round(0.5 * self.sample_rate)
            middle = (lower + upper) // 2

            while lower + 1 < upper:
                try:
                    _ = self.classifier_.encode_batch(torch.randn(1, middle).to(self.device))
                    upper = middle
                except RuntimeError:
                    lower = middle

                middle = (lower + upper) // 2

        return upper

    def __call__(self, waveforms, masks = None):
        """Compute one embedding per waveform in the batch.

        Args:
            waveforms: 3-D tensor unpacked as (batch, channel, sample); the
                channel axis must have size 1.
            masks: Optional 2-D tensor whose first axis matches the batch
                size. It is resampled to the waveform length with nearest
                interpolation and samples with a value above 0.5 are kept.
                Presumably a float tensor, as ``F.interpolate`` is applied
                to it directly (confirm against callers).

        Returns:
            A NumPy array of embeddings, presumably (batch, dimension) after
            the singleton axis is squeezed. Rows whose kept length is below
            ``min_num_samples`` are NaN; if every item is too short, an
            all-NaN array of shape (batch, dimension) is returned without
            calling the classifier.
        """
        batch_size, num_channels, num_samples = waveforms.shape
        assert num_channels == 1

        waveforms = waveforms.squeeze(dim=1)

        if masks is None:
            # The channel axis is already gone; squeezing dim=1 again would
            # drop the sample axis for one-sample inputs and break shape[1].
            signals = waveforms
            wav_lens = signals.shape[1] * torch.ones(batch_size)
        else:
            batch_size_masks, _ = masks.shape
            assert batch_size == batch_size_masks

            imasks = F.interpolate(masks.unsqueeze(dim=1), size=num_samples, mode="nearest").squeeze(dim=1) > 0.5
            # Keep only the masked-in samples, then right-pad to a common length.
            signals = pad_sequence([waveform[imask].contiguous() for waveform, imask in zip(waveforms, imasks)], batch_first=True)
            wav_lens = imasks.sum(dim=1)

        max_len = wav_lens.max()
        if max_len < self.min_num_samples:
            return np.nan * np.zeros((batch_size, self.dimension))

        too_short = wav_lens < self.min_num_samples
        # The classifier receives lengths relative to the longest item.
        wav_lens = wav_lens / max_len
        # Too-short items are encoded at full relative length; their rows are
        # overwritten with NaN below.
        wav_lens[too_short] = 1.0

        embeddings = (self.classifier_.encode_batch(signals, wav_lens=wav_lens).squeeze(dim=1).cpu().numpy())
        embeddings[too_short.cpu().numpy()] = np.nan

        return embeddings