# The MIT License (MIT) # # Copyright (c) 2021- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
"""Speaker diarization pipelines"""

import functools
import itertools
import math
import textwrap
import warnings
from pathlib import Path
from typing import Callable, Mapping, Optional, Text, Union

import numpy as np

from pyannote_audio_utils.core import Annotation, SlidingWindow, SlidingWindowFeature
from pyannote_audio_utils.pipeline.parameter import ParamDict, Uniform
from pyannote_audio_utils.audio import Audio, Inference, Pipeline
from pyannote_audio_utils.audio.core.io import AudioFile
from pyannote_audio_utils.audio.pipelines.clustering import Clustering
from pyannote_audio_utils.audio.pipelines.speaker_verification import (
    ONNXWeSpeakerPretrainedSpeakerEmbedding,
)
from pyannote_audio_utils.audio.pipelines.utils import SpeakerDiarizationMixin

# NOTE(review): this local alias shadows the `AudioFile` imported above from
# `pyannote_audio_utils.audio.core.io`; kept as-is for backward compatibility.
AudioFile = Union[Text, Path, Mapping]
PipelineModel = Union[Text, Mapping]


def batchify(iterable, batch_size: int = 32, fillvalue=None):
    """Group `iterable` into consecutive tuples of `batch_size` items.

    The last tuple is padded with `fillvalue` when the number of items is not
    a multiple of `batch_size`.

    >>> list(batchify("ABCDEFG", 3))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', None, None)]

    Parameters
    ----------
    iterable : iterable
        Items to group.
    batch_size : int, optional
        Number of items per tuple. Defaults to 32.
    fillvalue : optional
        Value used to pad the last tuple. Defaults to None.

    Returns
    -------
    batches : iterator of tuples
        Lazy iterator (from `itertools.zip_longest`) of `batch_size`-long tuples.
    """
    # The *same* iterator object is repeated `batch_size` times, so that
    # zip_longest pulls `batch_size` consecutive items into each output tuple.
    args = [iter(iterable)] * batch_size
    return itertools.zip_longest(*args, fillvalue=fillvalue)


class SpeakerDiarization(SpeakerDiarizationMixin, Pipeline):
    """Speaker diarization pipeline

    Parameters
    ----------
    segmentation : Model, str, or dict, optional
        Pretrained segmentation model, passed as-is to `Inference`.
        Defaults to "pyannote_audio_utils/segmentation@2022.07".
    segmentation_step: float, optional
        The segmentation model is applied on a 10s window sliding over the
        whole audio file. `segmentation_step` controls the step of this window,
        provided as a ratio of its duration. Defaults to 0.1 (i.e. 90% overlap
        between two consecutive windows).
    embedding : Model, str, or dict, optional
        Pretrained embedding model, passed as-is to
        `ONNXWeSpeakerPretrainedSpeakerEmbedding`. Defaults to
        "speechbrain/spkrec-ecapa-voxceleb@5c0be3875fda05e81f3c004ed8c7c06be308de1e".
    embedding_exclude_overlap : bool, optional
        Exclude overlapping speech regions when extracting embeddings.
        Defaults (False) to use the whole speech.
    clustering : str, optional
        Clustering algorithm. See pyannote_audio_utils.audio.pipelines.clustering.Clustering
        for available options. Defaults to "AgglomerativeClustering".
    embedding_batch_size : int, optional
        Batch size used for speaker embedding. Defaults to 1.
    segmentation_batch_size : int, optional
        Batch size used for speaker segmentation. Defaults to 1.
    args : optional
        Forwarded unchanged to both `Inference` and
        `ONNXWeSpeakerPretrainedSpeakerEmbedding`.
    seg_path : optional
        Forwarded to `Inference` (presumably the location of the segmentation
        model file -- TODO confirm).
    emb_path : optional
        Forwarded to `ONNXWeSpeakerPretrainedSpeakerEmbedding` (presumably the
        location of the ONNX embedding model file -- TODO confirm).
    der_variant : dict, optional
        Optimize for a variant of diarization error rate.
        Defaults to {"collar": 0.0, "skip_overlap": False}. Stored as
        `self.der_variant`; presumably consumed when instantiating the
        evaluation metric, e.g. GreedyDiarizationErrorRate(**der_variant).
    use_auth_token : str, optional
        Accepted for backward compatibility; not used by this implementation.

    Usage
    -----
    # perform (unconstrained) diarization
    >>> diarization = pipeline("/path/to/audio.wav")

    # perform diarization, targeting exactly 4 speakers
    >>> diarization = pipeline("/path/to/audio.wav", num_speakers=4)

    # perform diarization, with at least 2 speakers and at most 10 speakers
    >>> diarization = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10)

    # perform diarization and get one representative embedding per speaker
    >>> diarization, embeddings = pipeline("/path/to/audio.wav", return_embeddings=True)
    >>> for s, speaker in enumerate(diarization.labels()):
    ...     # embeddings[s] is the embedding of speaker `speaker`

    Hyper-parameters
    ----------------
    segmentation.min_duration_off
    clustering.* (depend on the selected clustering algorithm)
    """

    def __init__(
        self,
        segmentation: PipelineModel = "pyannote_audio_utils/segmentation@2022.07",
        segmentation_step: float = 0.1,
        embedding: PipelineModel = "speechbrain/spkrec-ecapa-voxceleb@5c0be3875fda05e81f3c004ed8c7c06be308de1e",
        embedding_exclude_overlap: bool = False,
        clustering: str = "AgglomerativeClustering",
        embedding_batch_size: int = 1,
        segmentation_batch_size: int = 1,
        args=None,
        seg_path=None,
        emb_path=None,
        der_variant: dict = None,
        use_auth_token: Union[Text, None] = None,
    ):
        super().__init__()

        model = segmentation
        self.segmentation_step = segmentation_step
        self.embedding = embedding
        self.embedding_batch_size = embedding_batch_size
        self.embedding_exclude_overlap = embedding_exclude_overlap
        self.klustering = clustering
        self.der_variant = der_variant or {"collar": 0.0, "skip_overlap": False}

        # segmentation runs on fixed 10s windows; the hop is
        # `segmentation_step * 10s` and overlapping chunks are kept separate
        # (skip_aggregation=True) so they can be clustered individually.
        segmentation_duration = 10.0
        self._segmentation = Inference(
            model,
            duration=segmentation_duration,
            step=self.segmentation_step * segmentation_duration,
            skip_aggregation=True,
            batch_size=segmentation_batch_size,
            args=args,
            seg_path=seg_path,
        )
        self._frames: SlidingWindow = self._segmentation.example_output.frames

        # tunable hyper-parameter(s) of the segmentation step
        self.segmentation = ParamDict(
            min_duration_off=Uniform(0.0, 1.0),
        )

        self._embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding(
            self.embedding, args=args, emb_path=emb_path
        )
        # audio is loaded at the embedding model's sample rate, downmixed to mono
        self._audio = Audio(sample_rate=self._embedding.sample_rate, mono="downmix")
        metric = self._embedding.metric

        Klustering = Clustering[clustering]
        self.clustering = Klustering.value(metric=metric)

    def get_segmentations(self, file, hook=None) -> SlidingWindowFeature:
        """Apply segmentation model

        Parameters
        ----------
        file : AudioFile
        hook : Optional[Callable]
            Progress callback; bound to the "segmentation" step before being
            forwarded to the segmentation inference.

        Returns
        -------
        segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
        """
        if hook is not None:
            hook = functools.partial(hook, "segmentation", None)

        segmentations: SlidingWindowFeature = self._segmentation(file, hook=hook)
        return segmentations

    def get_embeddings(
        self,
        file,
        binary_segmentations: SlidingWindowFeature,
        exclude_overlap: bool = False,
        hook: Optional[Callable] = None,
    ):
        """Extract embeddings for each (chunk, speaker) pair

        Parameters
        ----------
        file : AudioFile
        binary_segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
            Binarized segmentation.
        exclude_overlap : bool, optional
            Exclude overlapping speech regions when extracting embeddings.
            In case non-overlapping speech is too short, use the whole speech.
        hook: Optional[Callable]
            Called as hook("embeddings", batch_or_None, total=..., completed=...)
            once before the first batch and after every batch.

        Returns
        -------
        embeddings : (num_chunks, num_speakers, dimension) array
        """
        duration = binary_segmentations.sliding_window.duration
        num_chunks, num_frames, num_speakers = binary_segmentations.data.shape

        if exclude_overlap:
            # minimum number of samples needed to extract an embedding
            # (a lower number of samples would result in an error)
            min_num_samples = self._embedding.min_num_samples

            # corresponding minimum number of frames
            num_samples = duration * self._embedding.sample_rate
            min_num_frames = math.ceil(num_frames * min_num_samples / num_samples)

            # zero-out frames with overlapping speech (two or more active speakers)
            clean_frames = 1.0 * (
                np.sum(binary_segmentations.data, axis=2, keepdims=True) < 2
            )
            clean_segmentations = SlidingWindowFeature(
                binary_segmentations.data * clean_frames,
                binary_segmentations.sliding_window,
            )
        else:
            # -1 makes the "enough clean frames" test below always true, and
            # clean masks are identical to the raw masks in this case.
            min_num_frames = -1
            clean_segmentations = SlidingWindowFeature(
                binary_segmentations.data, binary_segmentations.sliding_window
            )

        def iter_waveform_and_mask():
            """Yield one (waveform, mask) pair per (chunk, local speaker)."""
            for (chunk, masks), (_, clean_masks) in zip(
                binary_segmentations, clean_segmentations
            ):
                # chunk: Segment(t, t + duration)
                # masks: (num_frames, local_num_speakers) np.ndarray
                waveform, _ = self._audio.crop(
                    file,
                    chunk,
                    duration=duration,
                    mode="pad",
                )

                # mask may contain NaN (in case of partial stitching)
                masks = np.nan_to_num(masks, nan=0.0).astype(np.float32)
                clean_masks = np.nan_to_num(clean_masks, nan=0.0).astype(np.float32)

                for mask, clean_mask in zip(masks.T, clean_masks.T):
                    # mask: (num_frames, ) np.ndarray
                    # fall back to the full mask when too little clean speech remains
                    if np.sum(clean_mask) > min_num_frames:
                        used_mask = clean_mask
                    else:
                        used_mask = mask

                    # add a leading batch axis to both items
                    yield waveform[None], used_mask[None]

        batches = batchify(
            iter_waveform_and_mask(),
            batch_size=self.embedding_batch_size,
            fillvalue=(None, None),
        )

        batch_count = math.ceil(num_chunks * num_speakers / self.embedding_batch_size)

        embedding_batches = []

        if hook is not None:
            hook("embeddings", None, total=batch_count, completed=0)

        for i, batch in enumerate(batches, 1):
            # drop the (None, None) padding added by batchify to the last batch
            waveforms, masks = zip(*[item for item in batch if item[0] is not None])

            waveform_batch = np.vstack(waveforms)
            # (batch_size, 1, num_samples)

            mask_batch = np.vstack(masks)
            # (batch_size, num_frames)

            embedding_batch: np.ndarray = self._embedding(
                waveform_batch, masks=mask_batch
            )
            # (batch_size, dimension) np.ndarray

            embedding_batches.append(embedding_batch)

            if hook is not None:
                hook("embeddings", embedding_batch, total=batch_count, completed=i)

        embedding_batches = np.vstack(embedding_batches)

        # one embedding was produced per (chunk, local speaker), in that order
        embeddings = embedding_batches.reshape(
            [num_chunks, -1, embedding_batches.shape[-1]]
        )
        return embeddings

    def reconstruct(
        self,
        segmentations: SlidingWindowFeature,
        hard_clusters: np.ndarray,
        count: SlidingWindowFeature,
    ) -> SlidingWindowFeature:
        """Build final discrete diarization out of clustered segmentation

        Parameters
        ----------
        segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
            Raw speaker segmentation.
        hard_clusters : (num_chunks, num_speakers) array
            Output of clustering step. Local speakers assigned to cluster -2
            (inactive speakers) are skipped.
        count : (total_num_frames, 1) SlidingWindowFeature
            Instantaneous number of active speakers.

        Returns
        -------
        discrete_diarization : SlidingWindowFeature
            Discrete (0s and 1s) diarization.
        """
        num_chunks, num_frames, local_num_speakers = segmentations.data.shape

        num_clusters = np.max(hard_clusters) + 1
        # NaN marks (chunk, cluster) pairs with no contributing local speaker.
        # np.full(..., np.nan) replaces `np.NAN * np.zeros(...)`: the `np.NAN`
        # alias no longer exists in NumPy >= 2.0.
        clustered_segmentations = np.full(
            (num_chunks, num_frames, num_clusters), np.nan
        )

        for c, (cluster, (chunk, segmentation)) in enumerate(
            zip(hard_clusters, segmentations)
        ):
            # cluster is (local_num_speakers, )-shaped
            # segmentation is (num_frames, local_num_speakers)-shaped
            for k in np.unique(cluster):
                if k == -2:
                    continue

                # TODO: can we do better than this max here?
                clustered_segmentations[c, :, k] = np.max(
                    segmentation[:, cluster == k], axis=1
                )

        clustered_segmentations = SlidingWindowFeature(
            clustered_segmentations, segmentations.sliding_window
        )

        return self.to_diarization(clustered_segmentations, count)

    def apply(
        self,
        file: AudioFile,
        num_speakers: int = None,
        min_speakers: int = None,
        max_speakers: int = None,
        return_embeddings: bool = False,
        hook: Optional[Callable] = None,
    ) -> Annotation:
        """Apply speaker diarization

        Parameters
        ----------
        file : AudioFile
            Processed file.
        num_speakers : int, optional
            Number of speakers, when known.
        min_speakers : int, optional
            Minimum number of speakers. Has no effect when `num_speakers` is provided.
        max_speakers : int, optional
            Maximum number of speakers. Has no effect when `num_speakers` is provided.
        return_embeddings : bool, optional
            Return representative speaker embeddings.
        hook : callable, optional
            Callback called after each major steps of the pipeline as follows:
                hook(step_name,      # human-readable name of current step
                     step_artefact,  # artifact generated by current step
                     file=file)      # file being processed
            Time-consuming steps call `hook` multiple times with the same `step_name`
            and additional `completed` and `total` keyword arguments usable to track
            progress of current step.

        Returns
        -------
        diarization : Annotation
            Speaker diarization
        embeddings : np.array, optional
            Representative speaker embeddings such that `embeddings[i]` is the
            speaker embedding for i-th speaker in diarization.labels().
            Only returned when `return_embeddings` is True. May be None when the
            clustering step returns no centroids.
        """
        # setup hook (e.g. for debugging purposes)
        hook = self.setup_hook(file, hook=hook)

        num_speakers, min_speakers, max_speakers = self.set_num_speakers(
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
        )

        segmentations = self.get_segmentations(file, hook=hook)
        hook("segmentation", segmentations)
        # shape: (num_chunks, num_frames, local_num_speakers)

        # No thresholding happens here: segmentations are used as-is.
        # NOTE(review): presumably the segmentation model already outputs
        # binary activations -- TODO confirm.
        binarized_segmentations = segmentations

        # estimate frame-level number of instantaneous speakers
        count = self.speaker_count(
            binarized_segmentations,
            frames=self._frames,
            warm_up=(0.0, 0.0),
        )
        hook("speaker_counting", count)
        # shape: (num_frames, 1)
        # dtype: int

        # exit early when no speaker is ever active
        if np.nanmax(count.data) == 0.0:
            diarization = Annotation(uri=file["uri"])
            if return_embeddings:
                return diarization, np.zeros((0, self._embedding.dimension))

            return diarization

        embeddings = self.get_embeddings(
            file,
            binarized_segmentations,
            exclude_overlap=self.embedding_exclude_overlap,
            hook=hook,
        )
        hook("embeddings", embeddings)
        # shape: (num_chunks, local_num_speakers, dimension)

        hard_clusters, _, centroids = self.clustering(
            embeddings=embeddings,
            segmentations=binarized_segmentations,
            num_clusters=num_speakers,
            min_clusters=min_speakers,
            max_clusters=max_speakers,
            file=file,  # <== for oracle clustering
            frames=self._frames,  # <== for oracle clustering
        )
        # hard_clusters: (num_chunks, num_speakers)
        # centroids: (num_speakers, dimension)

        # number of detected clusters is the number of different speakers
        num_different_speakers = np.max(hard_clusters) + 1

        # detected number of speakers can still be out of bounds
        # (specifically, lower than `min_speakers`), since there could be too few embeddings
        # to make enough clusters with a given minimum cluster size.
        if (
            num_different_speakers < min_speakers
            or num_different_speakers > max_speakers
        ):
            warnings.warn(
                textwrap.dedent(
                    f"""
                The detected number of speakers ({num_different_speakers}) is outside
                the given bounds [{min_speakers}, {max_speakers}]. This can happen if the
                given audio file is too short to contain {min_speakers} or more speakers.
                Try to lower the desired minimal number of speakers.
                """
                )
            )

        # during counting, we could possibly overcount the number of instantaneous
        # speakers due to segmentation errors, so we cap the maximum instantaneous number
        # of speakers by the `max_speakers` value
        count.data = np.minimum(count.data, max_speakers).astype(np.int8)

        # reconstruct discrete diarization from raw hard clusters

        # keep track of inactive speakers (never active within a chunk) and
        # flag them with -2 so that `reconstruct` skips them
        inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0
        # shape: (num_chunks, num_speakers)

        hard_clusters[inactive_speakers] = -2
        discrete_diarization = self.reconstruct(
            segmentations,
            hard_clusters,
            count,
        )
        hook("discrete_diarization", discrete_diarization)

        # convert to continuous diarization
        diarization = self.to_annotation(
            discrete_diarization,
            min_duration_on=0.0,
            min_duration_off=self.segmentation.min_duration_off,
        )
        diarization.uri = file["uri"]

        # at this point, `diarization` speaker labels are integers
        # from 0 to `num_speakers - 1`, aligned with `centroids` rows.

        if "annotation" in file and file["annotation"]:
            # when reference is available, use it to map hypothesized speakers
            # to reference speakers (this makes later error analysis easier
            # but does not modify the actual output of the diarization pipeline)
            _, mapping = self.optimal_mapping(
                file["annotation"], diarization, return_mapping=True
            )

            # in case there are more speakers in the hypothesis than in
            # the reference, those extra speakers are missing from `mapping`.
            # we add them back here
            mapping = {key: mapping.get(key, key) for key in diarization.labels()}

        else:
            # when reference is not available, rename hypothesized speakers
            # to human-readable labels taken from self.classes()
            mapping = {
                label: expected_label
                for label, expected_label in zip(diarization.labels(), self.classes())
            }

        diarization = diarization.rename_labels(mapping=mapping)

        # at this point, `diarization` speaker labels are strings (or mix of
        # strings and integers when reference is available and some hypothesis
        # speakers are not present in the reference)

        if not return_embeddings:
            return diarization

        # this can happen when we use OracleClustering
        if centroids is None:
            return diarization, None

        # The number of centroids may be smaller than the number of speakers
        # in the annotation. This can happen if the number of active speakers
        # obtained from `speaker_count` for some frames is larger than the number
        # of clusters obtained from `clustering`. In this case, we append zero embeddings
        # for extra speakers
        if len(diarization.labels()) > centroids.shape[0]:
            centroids = np.pad(
                centroids,
                ((0, len(diarization.labels()) - centroids.shape[0]), (0, 0)),
            )

        # re-order centroids so that they match
        # the order given by diarization.labels()
        inverse_mapping = {label: index for index, label in mapping.items()}
        centroids = centroids[
            [inverse_mapping[label] for label in diarization.labels()]
        ]

        return diarization, centroids