|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Speaker diarization pipelines""" |
|
|
|
|
|
import functools |
|
|
import itertools |
|
|
import math |
|
|
import textwrap |
|
|
import warnings |
|
|
import numpy as np |
|
|
|
|
|
from typing import Callable, Optional, Text, Union, Mapping |
|
|
from pathlib import Path |
|
|
|
|
|
from pyannote_audio_utils.core import Annotation, SlidingWindow, SlidingWindowFeature |
|
|
from pyannote_audio_utils.pipeline.parameter import ParamDict, Uniform |
|
|
from pyannote_audio_utils.audio import Audio, Inference, Pipeline |
|
|
from pyannote_audio_utils.audio.core.io import AudioFile |
|
|
from pyannote_audio_utils.audio.pipelines.clustering import Clustering |
|
|
from pyannote_audio_utils.audio.pipelines.speaker_verification import ONNXWeSpeakerPretrainedSpeakerEmbedding |
|
|
from pyannote_audio_utils.audio.pipelines.utils import SpeakerDiarizationMixin |
|
|
|
|
|
# Public type aliases used in signatures below.
# NOTE(review): this alias shadows the `AudioFile` imported from
# pyannote_audio_utils.audio.core.io above — confirm which definition
# callers actually rely on before removing either one.
AudioFile = Union[Text, Path, Mapping]

# A pretrained model reference: either a checkpoint name/path or a config mapping.
PipelineModel = Union[Text, Mapping]
|
|
|
|
|
def batchify(iterable, batch_size: int = 32, fillvalue=None):
    """Group an iterable into fixed-size batches.

    The same underlying iterator is passed `batch_size` times to
    `zip_longest`, which therefore consumes `batch_size` consecutive
    items per output tuple. The final batch is padded with `fillvalue`
    when the iterable's length is not a multiple of `batch_size`.
    """
    shared_iterator = iter(iterable)
    return itertools.zip_longest(
        *((shared_iterator,) * batch_size), fillvalue=fillvalue
    )
|
|
|
|
|
|
|
|
class SpeakerDiarization(SpeakerDiarizationMixin, Pipeline): |
|
|
"""Speaker diarization pipeline |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
segmentation : Model, str, or dict, optional |
|
|
Pretrained segmentation model. Defaults to "pyannote_audio_utils/segmentation@2022.07". |
|
|
See pyannote_audio_utils.audio.pipelines.utils.get_model for supported format. |
|
|
segmentation_step: float, optional |
|
|
The segmentation model is applied on a window sliding over the whole audio file. |
|
|
`segmentation_step` controls the step of this window, provided as a ratio of its |
|
|
duration. Defaults to 0.1 (i.e. 90% overlap between two consecutive windows).
|
|
embedding : Model, str, or dict, optional |
|
|
Pretrained embedding model. Defaults to "pyannote_audio_utils/embedding@2022.07". |
|
|
See pyannote_audio_utils.audio.pipelines.utils.get_model for supported format. |
|
|
embedding_exclude_overlap : bool, optional |
|
|
Exclude overlapping speech regions when extracting embeddings. |
|
|
Defaults (False) to use the whole speech. |
|
|
clustering : str, optional |
|
|
Clustering algorithm. See pyannote_audio_utils.audio.pipelines.clustering.Clustering |
|
|
for available options. Defaults to "AgglomerativeClustering". |
|
|
segmentation_batch_size : int, optional |
|
|
Batch size used for speaker segmentation. Defaults to 1. |
|
|
embedding_batch_size : int, optional |
|
|
Batch size used for speaker embedding. Defaults to 1. |
|
|
der_variant : dict, optional |
|
|
Optimize for a variant of diarization error rate. |
|
|
Defaults to {"collar": 0.0, "skip_overlap": False}. This is used in `get_metric` |
|
|
when instantiating the metric: GreedyDiarizationErrorRate(**der_variant). |
|
|
use_auth_token : str, optional |
|
|
When loading private huggingface.co models, set `use_auth_token` |
|
|
to True or to a string containing your huggingface.co authentication
|
|
token that can be obtained by running `huggingface-cli login` |
|
|
|
|
|
Usage |
|
|
----- |
|
|
# perform (unconstrained) diarization |
|
|
>>> diarization = pipeline("/path/to/audio.wav") |
|
|
|
|
|
# perform diarization, targeting exactly 4 speakers
|
|
>>> diarization = pipeline("/path/to/audio.wav", num_speakers=4) |
|
|
|
|
|
# perform diarization, with at least 2 speakers and at most 10 speakers |
|
|
>>> diarization = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10) |
|
|
|
|
|
# perform diarization and get one representative embedding per speaker |
|
|
>>> diarization, embeddings = pipeline("/path/to/audio.wav", return_embeddings=True) |
|
|
>>> for s, speaker in enumerate(diarization.labels()): |
|
|
... # embeddings[s] is the embedding of speaker `speaker` |
|
|
|
|
|
Hyper-parameters |
|
|
---------------- |
|
|
segmentation.threshold |
|
|
segmentation.min_duration_off |
|
|
clustering.??? (hyper-parameters depend on the selected `clustering` algorithm)
|
|
""" |
|
|
|
|
|
    def __init__(
        self,
        segmentation: PipelineModel = "pyannote_audio_utils/segmentation@2022.07",
        segmentation_step: float = 0.1,
        embedding: PipelineModel = "speechbrain/spkrec-ecapa-voxceleb@5c0be3875fda05e81f3c004ed8c7c06be308de1e",
        embedding_exclude_overlap: bool = False,
        clustering: str = "AgglomerativeClustering",
        embedding_batch_size: int = 1,
        segmentation_batch_size: int = 1,
        args = None,  # forwarded verbatim to Inference and the embedding model; schema not visible here — TODO confirm
        seg_path = None,  # presumably a local path/override for the segmentation model — verify against Inference
        emb_path = None,  # presumably a local path/override for the embedding model — verify against the embedding class
        der_variant: Optional[dict] = None,
        use_auth_token: Union[Text, None] = None,  # NOTE(review): accepted but never used in this constructor
    ):
        """Instantiate the segmentation, embedding and clustering sub-pipelines.

        See the class docstring for the semantics of each parameter.
        """
        super().__init__()

        model = segmentation
        self.segmentation_step = segmentation_step
        self.embedding = embedding
        self.embedding_batch_size = embedding_batch_size
        self.embedding_exclude_overlap = embedding_exclude_overlap
        # Stored as `klustering` (sic) because `self.clustering` is later
        # re-assigned to the *instantiated* clustering object below.
        self.klustering = clustering
        self.der_variant = der_variant or {"collar": 0.0, "skip_overlap": False}

        # The segmentation model is always applied on 10-second windows.
        segmentation_duration = 10.0

        self._segmentation = Inference(
            model,
            duration=segmentation_duration,
            # `segmentation_step` is a ratio of the window duration.
            step=self.segmentation_step * segmentation_duration,
            # Keep per-chunk outputs: aggregation happens later in the pipeline.
            skip_aggregation=True,
            batch_size=segmentation_batch_size,
            args=args,
            seg_path=seg_path
        )

        # Frame resolution (sliding window) of the segmentation model output.
        self._frames: SlidingWindow = self._segmentation.example_output.frames

        # Tunable hyper-parameter, optimized externally (see "Hyper-parameters"
        # section of the class docstring).
        self.segmentation = ParamDict(
            min_duration_off=Uniform(0.0, 1.0),
        )

        self._embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding(
            self.embedding,
            args=args,
            emb_path=emb_path
        )
        # Loader that resamples/downmixes audio to what the embedding model expects.
        self._audio = Audio(sample_rate=self._embedding.sample_rate, mono="downmix")

        # Distance metric used to compare embeddings during clustering.
        metric = self._embedding.metric
        Klustering = Clustering[clustering]

        self.clustering = Klustering.value(metric=metric)
|
|
|
|
|
|
|
|
def get_segmentations(self, file, hook=None) -> SlidingWindowFeature: |
|
|
"""Apply segmentation model |
|
|
|
|
|
Parameter |
|
|
--------- |
|
|
file : AudioFile |
|
|
hook : Optional[Callable] |
|
|
|
|
|
Returns |
|
|
------- |
|
|
segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature |
|
|
""" |
|
|
|
|
|
if hook is not None: |
|
|
hook = functools.partial(hook, "segmentation", None) |
|
|
segmentations: SlidingWindowFeature = self._segmentation(file, hook=hook) |
|
|
|
|
|
return segmentations |
|
|
|
|
|
    def get_embeddings(
        self,
        file,
        binary_segmentations: SlidingWindowFeature,
        exclude_overlap: bool = False,
        hook: Optional[Callable] = None,
    ):
        """Extract embeddings for each (chunk, speaker) pair

        Parameters
        ----------
        file : AudioFile
        binary_segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
            Binarized segmentation.
        exclude_overlap : bool, optional
            Exclude overlapping speech regions when extracting embeddings.
            In case non-overlapping speech is too short, use the whole speech.
        hook: Optional[Callable]
            Called during embeddings after every batch to report the progress

        Returns
        -------
        embeddings : (num_chunks, num_speakers, dimension) array
        """

        duration = binary_segmentations.sliding_window.duration
        num_chunks, num_frames, num_speakers = binary_segmentations.data.shape

        if exclude_overlap:
            # Shortest waveform (in samples) the embedding model accepts.
            min_num_samples = self._embedding.min_num_samples

            # Equivalent minimum in segmentation frames: a speaker whose
            # overlap-free speech is shorter than this falls back to using
            # all of their speech (overlap included) in the generator below.
            num_samples = duration * self._embedding.sample_rate
            min_num_frames = math.ceil(num_frames * min_num_samples / num_samples)

            # 1.0 on frames where at most one speaker is active, 0.0 elsewhere;
            # multiplying masks by this zeroes out overlapped speech regions.
            clean_frames = 1.0 * (
                np.sum(binary_segmentations.data, axis=2, keepdims=True) < 2
            )
            clean_segmentations = SlidingWindowFeature(
                binary_segmentations.data * clean_frames,
                binary_segmentations.sliding_window,
            )

        else:
            # min_num_frames == -1 makes the "enough clean speech" test below
            # always succeed, so the (identical) masks are used unchanged.
            min_num_frames = -1
            clean_segmentations = SlidingWindowFeature(
                binary_segmentations.data, binary_segmentations.sliding_window
            )

        def iter_waveform_and_mask():
            # Yields one (waveform, mask) pair per (chunk, speaker), in
            # chunk-major order — the reshape at the end relies on this order.
            for (chunk, masks), (_, clean_masks) in zip(binary_segmentations, clean_segmentations):

                # Chunk waveform, zero-padded to `duration` at file boundaries.
                waveform, _ = self._audio.crop(
                    file,
                    chunk,
                    duration=duration,
                    mode="pad",
                )

                # NaNs mark padded/undefined frames; treat them as inactive.
                masks = np.nan_to_num(masks, nan=0.0).astype(np.float32)
                clean_masks = np.nan_to_num(clean_masks, nan=0.0).astype(np.float32)

                # Iterate over speakers (masks are frames x speakers).
                for mask, clean_mask in zip(masks.T, clean_masks.T):

                    # Prefer overlap-free speech when there is enough of it.
                    if np.sum(clean_mask) > min_num_frames:
                        used_mask = clean_mask
                    else:
                        used_mask = mask

                    # Leading batch axis added for vstack-ing below.
                    yield waveform[None], used_mask[None]

        # Last batch is padded with (None, None) fill values.
        batches = batchify(
            iter_waveform_and_mask(),
            batch_size=self.embedding_batch_size,
            fillvalue=(None, None),
        )

        # Total number of batches, used only for progress reporting.
        batch_count = math.ceil(num_chunks * num_speakers / self.embedding_batch_size)

        embedding_batches = []

        if hook is not None:
            hook("embeddings", None, total=batch_count, completed=0)

        for i, batch in enumerate(batches, 1):
            # Drop the (None, None) padding of the final batch.
            waveforms, masks = zip(*filter(lambda b: b[0] is not None, batch))

            waveform_batch = np.vstack(waveforms)

            mask_batch = np.vstack(masks)

            embedding_batch: np.ndarray = self._embedding(
                waveform_batch, masks=mask_batch
            )

            embedding_batches.append(embedding_batch)

            if hook is not None:
                hook("embeddings", embedding_batch, total=batch_count, completed=i)

        embedding_batches = np.vstack(embedding_batches)
        # -1 recovers the per-chunk speaker count from the flat batch axis.
        embeddings = embedding_batches.reshape([num_chunks, -1 , embedding_batches.shape[-1]])

        return embeddings
|
|
|
|
|
def reconstruct( |
|
|
self, |
|
|
segmentations: SlidingWindowFeature, |
|
|
hard_clusters: np.ndarray, |
|
|
count: SlidingWindowFeature, |
|
|
) -> SlidingWindowFeature: |
|
|
"""Build final discrete diarization out of clustered segmentation |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature |
|
|
Raw speaker segmentation. |
|
|
hard_clusters : (num_chunks, num_speakers) array |
|
|
Output of clustering step. |
|
|
count : (total_num_frames, 1) SlidingWindowFeature |
|
|
Instantaneous number of active speakers. |
|
|
|
|
|
Returns |
|
|
------- |
|
|
discrete_diarization : SlidingWindowFeature |
|
|
Discrete (0s and 1s) diarization. |
|
|
""" |
|
|
|
|
|
num_chunks, num_frames, local_num_speakers = segmentations.data.shape |
|
|
|
|
|
num_clusters = np.max(hard_clusters) + 1 |
|
|
clustered_segmentations = np.NAN * np.zeros( |
|
|
(num_chunks, num_frames, num_clusters) |
|
|
) |
|
|
|
|
|
for c, (cluster, (chunk, segmentation)) in enumerate( |
|
|
zip(hard_clusters, segmentations) |
|
|
): |
|
|
|
|
|
|
|
|
for k in np.unique(cluster): |
|
|
if k == -2: |
|
|
continue |
|
|
|
|
|
|
|
|
clustered_segmentations[c, :, k] = np.max( |
|
|
segmentation[:, cluster == k], axis=1 |
|
|
) |
|
|
|
|
|
clustered_segmentations = SlidingWindowFeature( |
|
|
clustered_segmentations, segmentations.sliding_window |
|
|
) |
|
|
|
|
|
return self.to_diarization(clustered_segmentations, count) |
|
|
|
|
|
    def apply(
        self,
        file: AudioFile,
        num_speakers: Optional[int] = None,
        min_speakers: Optional[int] = None,
        max_speakers: Optional[int] = None,
        return_embeddings: bool = False,
        hook: Optional[Callable] = None,
    ) -> Annotation:
        """Apply speaker diarization

        Parameters
        ----------
        file : AudioFile
            Processed file.
        num_speakers : int, optional
            Number of speakers, when known.
        min_speakers : int, optional
            Minimum number of speakers. Has no effect when `num_speakers` is provided.
        max_speakers : int, optional
            Maximum number of speakers. Has no effect when `num_speakers` is provided.
        return_embeddings : bool, optional
            Return representative speaker embeddings.
        hook : callable, optional
            Callback called after each major steps of the pipeline as follows:
                hook(step_name,      # human-readable name of current step
                     step_artefact,  # artifact generated by current step
                     file=file)      # file being processed
            Time-consuming steps call `hook` multiple times with the same `step_name`
            and additional `completed` and `total` keyword arguments usable to track
            progress of current step.

        Returns
        -------
        diarization : Annotation
            Speaker diarization
        embeddings : np.array, optional
            Representative speaker embeddings such that `embeddings[i]` is the
            speaker embedding for i-th speaker in diarization.labels().
            Only returned when `return_embeddings` is True.
        """

        # Normalize the user hook (bind `file`, handle hook=None).
        hook = self.setup_hook(file, hook=hook)

        # Resolve the three constraints into a consistent triple
        # (num_speakers takes precedence over min/max when provided).
        num_speakers, min_speakers, max_speakers = self.set_num_speakers(
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
        )

        segmentations = self.get_segmentations(file, hook=hook)
        hook("segmentation", segmentations)

        # NOTE(review): no explicit binarization is applied here — raw
        # segmentation scores are used directly downstream; presumably the
        # segmentation model already emits (near-)binary activations — confirm.
        binarized_segmentations = segmentations

        # Estimate the instantaneous number of active speakers per frame.
        count = self.speaker_count(
            binarized_segmentations,
            frames=self._frames,
            warm_up=(0.0, 0.0),
        )
        hook("speaker_counting", count)

        # Early exit when no speaker is ever active anywhere in the file.
        if np.nanmax(count.data) == 0.0:
            diarization = Annotation(uri=file["uri"])
            if return_embeddings:
                # Empty embedding array keeps the return contract consistent.
                return diarization, np.zeros((0, self._embedding.dimension))

            return diarization

        embeddings = self.get_embeddings(
            file,
            binarized_segmentations,
            exclude_overlap=self.embedding_exclude_overlap,
            hook=hook,
        )

        hook("embeddings", embeddings)

        # Assign each (chunk, speaker) pair to a global speaker cluster.
        hard_clusters, _, centroids = self.clustering(
            embeddings=embeddings,
            segmentations=binarized_segmentations,
            num_clusters=num_speakers,
            min_clusters=min_speakers,
            max_clusters=max_speakers,
            file=file,
            frames=self._frames,
        )

        # Clusters are 0-indexed, hence max + 1 is the detected speaker count.
        num_different_speakers = np.max(hard_clusters) + 1

        # Warn (but do not fail) when clustering could not honor the bounds.
        if num_different_speakers < min_speakers or num_different_speakers > max_speakers:
            warnings.warn(textwrap.dedent(
                f"""
                The detected number of speakers ({num_different_speakers}) is outside
                the given bounds [{min_speakers}, {max_speakers}]. This can happen if the
                given audio file is too short to contain {min_speakers} or more speakers.
                Try to lower the desired minimal number of speakers.
                """
            ))

        # Cap the per-frame speaker count at the allowed maximum.
        count.data = np.minimum(count.data, max_speakers).astype(np.int8)

        # (chunk, speaker) pairs with no active frame at all...
        inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0

        # ...are marked with the sentinel -2, which `reconstruct` skips.
        hard_clusters[inactive_speakers] = -2
        discrete_diarization = self.reconstruct(
            segmentations,
            hard_clusters,
            count,
        )
        hook("discrete_diarization", discrete_diarization)

        # Convert frame-level 0/1 activations into a timeline Annotation,
        # filling gaps shorter than the tuned `min_duration_off`.
        diarization = self.to_annotation(
            discrete_diarization,
            min_duration_on=0.0,
            min_duration_off=self.segmentation.min_duration_off,
        )
        diarization.uri = file["uri"]

        # When a reference annotation is available (e.g. during evaluation),
        # rename speakers to best match the reference labels.
        if "annotation" in file and file["annotation"]:

            _, mapping = self.optimal_mapping(
                file["annotation"], diarization, return_mapping=True
            )

            # Keep hypothesized speakers absent from the reference unchanged.
            mapping = {key: mapping.get(key, key) for key in diarization.labels()}

        else:
            # Otherwise rename speakers using the pipeline's canonical classes
            # (e.g. SPEAKER_00, SPEAKER_01, ...).
            mapping = {
                label: expected_label
                for label, expected_label in zip(diarization.labels(), self.classes())
            }

        diarization = diarization.rename_labels(mapping=mapping)

        if not return_embeddings:
            return diarization

        # Some clustering back-ends do not expose centroids.
        if centroids is None:
            return diarization, None

        # Pad with zero rows if clustering produced fewer centroids than
        # there are diarization labels (can happen with forced num_speakers).
        if len(diarization.labels()) > centroids.shape[0]:
            centroids = np.pad(centroids, ((0, len(diarization.labels()) - centroids.shape[0]), (0, 0)))

        # Re-order centroids so that centroids[i] corresponds to the i-th
        # label of diarization.labels() after renaming.
        # NOTE(review): assumes pre-rename labels index into `centroids` —
        # verify against `to_annotation` label conventions.
        inverse_mapping = {label: index for index, label in mapping.items()}
        centroids = centroids[
            [inverse_mapping[label] for label in diarization.labels()]
        ]

        return diarization, centroids
|
|
|