diff --git a/.gitattributes b/.gitattributes index 586fa25ac6b19cc985c2bbc9ffb07b3de1b97881..9f1b7bf83ea39df24426e4e8e5314b67a3b33e8a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -37,3 +37,5 @@ brouhaha/brouhaha.gif filter=lfs diff=lfs merge=lfs -text separation-ami-1.0/model.png filter=lfs diff=lfs merge=lfs -text speaker-diarization/technical_report_2.1.pdf filter=lfs diff=lfs merge=lfs -text speech-separation-ami-1.0/pipeline.png filter=lfs diff=lfs merge=lfs -text +ailia-models/code/data/sample.wav filter=lfs diff=lfs merge=lfs -text +speaker-diarization-community-1/diarization.gif filter=lfs diff=lfs merge=lfs -text diff --git a/ailia-models/code/LICENSE b/ailia-models/code/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..9712e8a75f3b007849124dfb5653863925fdaaf9 --- /dev/null +++ b/ailia-models/code/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 CNRS + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/ailia-models/code/README.md b/ailia-models/code/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ac75830eea6583124e9bf03e02b126a7f7d7c2f9 --- /dev/null +++ b/ailia-models/code/README.md @@ -0,0 +1,135 @@ +# Pyannote-audio : Speaker Diarization + +## Input + +Audio file (.wav format). +``` +Example +input: data/demo.wav +``` +(Wav file from https://github.com/pyannote/pyannote-audio/tree/develop/pyannote/audio/sample) + +## Output + +When and who spoke. +![Output](output.png) + +``` +[ 00:00:06.714 --> 00:00:07.003] A speaker91 +[ 00:00:07.003 --> 00:00:07.173] B speaker90 +[ 00:00:07.580 --> 00:00:08.310] C speaker91 +[ 00:00:08.310 --> 00:00:09.923] D speaker90 +[ 00:00:09.923 --> 00:00:10.976] E speaker91 +[ 00:00:10.466 --> 00:00:14.745] F speaker90 +[ 00:00:14.303 --> 00:00:17.886] G speaker91 +[ 00:00:18.022 --> 00:00:21.502] H speaker90 +[ 00:00:18.157 --> 00:00:18.446] I speaker91 +[ 00:00:21.774 --> 00:00:28.531] J speaker91 +[ 00:00:27.886 --> 00:00:29.991] K speaker90 +``` + +## Requirements + +This model recommends additional module. +```bash +$ pip3 install -r requirements.txt +``` + +## Usage + +Automatically downloads the onnx and prototxt files on the first run. +It is necessary to be connected to the Internet while downloading. 
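+
+The options described below can be combined. For example, the following run (a usage sketch based only on the options defined in `pyannote-audio.py`) diarizes the bundled sample, reports the diarization error rate against the reference RTTM, and saves the plots:
+```bash
+$ python pyannote-audio.py -i ./data/sample.wav --ig ./data/sample.rttm --e --plt
+```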
+
+For the sample
+```bash
+$ python pyannote-audio.py -i ./data/sample.wav
+```
+
+For the sample with plot
+```bash
+$ python pyannote-audio.py -i ./data/sample.wav --plt
+```
+
+For the sample with verification
+```bash
+$ python pyannote-audio.py -i ./data/sample.wav --ig ./data/sample.rttm
+```
+
+If you want to specify the audio, put the file path after the `-i` or `--input` option.
+
+```bash
+$ python pyannote-audio.py -i FILE_PATH
+```
+
+If you want to specify the ground truth, put the file path after the `--ig` or `-ground` option.
+
+```bash
+$ python pyannote-audio.py --ig FILE_PATH
+```
+
+If you want to specify the output file, put the file path after the `--o` or `-output` option.
+
+```bash
+$ python pyannote-audio.py --o FILE_PATH
+```
+
+If you want to specify the output ground truth file, put the file path after the `--og` or `-output_ground` option.
+
+```bash
+$ python pyannote-audio.py --og FILE_PATH
+```
+
+If you know the number of speakers, pass it with the `--num` or `-num_speaker` option.
+```bash
+$ python pyannote-audio.py --num 2
+```
+
+If you know the maximum number of speakers, pass it with the `--max` or `-max_speaker` option.
+```bash
+$ python pyannote-audio.py --max 4
+```
+
+If you know the minimum number of speakers, pass it with the `--min` or `-min_speaker` option.
+```bash
+$ python pyannote-audio.py --min 2
+```
+
+By giving the `--e` or `-error` option together with a ground truth file (`--ig`), you can get the diarization error rate.
+```bash
+$ python pyannote-audio.py --ig ./data/sample.rttm --e
+```
+
+By giving the `--plt` option, you can visualize the results.
+```bash
+$ python pyannote-audio.py --plt
+```
+
+By giving the `--onnx` option, you can run inference with onnxruntime.
+```bash
+$ python pyannote-audio.py --onnx
+```
+
+By giving the `--embed` option, you can get the embedding vector of each speaker in the input file.
+```bash +$ python pyannote-audio.py --embed +``` + +## Reference + +- [Pyannote-audio](https://github.com/pyannote/pyannote-audio) +- [Hugging Face - pyannote in speaker-diariazation](https://huggingface.co/pyannote/speaker-diarization-3.1) +- [Hugging Face - hdbrain in wespeaker-voxceleb-resnet34-LM](https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/tree/main) +- [KaldiFeat](https://github.com/yuyq96/kaldifeat) + +## Framework + +Pytorch + +## Model Format + +ONNX opset=14,17 + +## Netron + +- [segmentation.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/pyannote-audio/segmentation.onnx.prototxt) +- [speaker-embedding.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/pyannote-audio/speaker-embedding.onnx.prototxt) diff --git a/ailia-models/code/config.yaml b/ailia-models/code/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dcf9752d17a120cf697b15f7bfbc245c0b2d7c6e --- /dev/null +++ b/ailia-models/code/config.yaml @@ -0,0 +1,17 @@ +params: + clustering: + method: centroid + min_cluster_size: 12 + threshold: 0.7045654963945799 + segmentation: + min_duration_off: 0.0 +pipeline: + name: pyannote.audio.pipelines.SpeakerDiarization + params: + clustering: AgglomerativeClustering + embedding: speaker-embedding.onnx + embedding_batch_size: 32 + embedding_exclude_overlap: true + segmentation: segmentation.onnx + segmentation_batch_size: 32 +version: 3.1.0 diff --git a/ailia-models/code/data/sample.rttm b/ailia-models/code/data/sample.rttm new file mode 100644 index 0000000000000000000000000000000000000000..7c6b378febff6a4a0bbd4578669b62c2028ad0a2 --- /dev/null +++ b/ailia-models/code/data/sample.rttm @@ -0,0 +1,10 @@ +SPEAKER sample 1 6.690 0.430 speaker90 +SPEAKER sample 1 7.550 0.800 speaker91 +SPEAKER sample 1 8.320 1.700 speaker90 +SPEAKER sample 1 9.920 1.110 speaker91 +SPEAKER sample 1 10.570 4.130 speaker90 +SPEAKER sample 1 14.490 3.430 speaker91 +SPEAKER sample 1 18.050 3.440 speaker90 +SPEAKER sample 1 18.150 0.440 speaker91 +SPEAKER sample 1 21.780 6.720 speaker91 +SPEAKER sample 1 27.850 2.150 speaker90 diff --git a/ailia-models/code/data/sample.wav b/ailia-models/code/data/sample.wav new file mode 100644 index 0000000000000000000000000000000000000000..ffc5ebbf5dff03a587e044cac5443cf5de2a04b3 --- /dev/null +++ b/ailia-models/code/data/sample.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c319b4abca767b124e41432d364fd7df006cb26bb79d09326c487d606a134e6e +size 960104 diff --git a/ailia-models/code/output.png b/ailia-models/code/output.png new file mode 100644 index 0000000000000000000000000000000000000000..92a0b3639936b2c06bbeab129b290bb6987b3759 Binary files /dev/null and b/ailia-models/code/output.png differ diff --git a/ailia-models/code/output_ground.png b/ailia-models/code/output_ground.png new file mode 100644 index 0000000000000000000000000000000000000000..5f848b362cf8a4ab381468c21584d0e983e2f797 Binary files /dev/null and b/ailia-models/code/output_ground.png differ diff --git a/ailia-models/code/pyannote-audio.py b/ailia-models/code/pyannote-audio.py new file mode 100644 index 0000000000000000000000000000000000000000..eee42fc9fb79f7f60d419f78208af3c777729009 --- /dev/null +++ b/ailia-models/code/pyannote-audio.py @@ -0,0 +1,181 @@ +import yaml +import sys +import matplotlib.pyplot as plt +import time + +from pyannote_audio_utils.audio.pipelines.speaker_diarization import SpeakerDiarization +from pyannote_audio_utils.core import Segment, 
Annotation +from pyannote_audio_utils.core.notebook import Notebook +from pyannote_audio_utils.database.util import load_rttm +from pyannote_audio_utils.metrics.diarization import DiarizationErrorRate + +sys.path.append('../../util') +from arg_utils import get_base_parser, update_parser # noqa: E402 +from model_utils import check_and_download_models # noqa: E402 +from logging import getLogger # noqa: E402 +logger = getLogger(__name__) + +WEIGHT_SEGMENTATION_PATH = 'segmentation.onnx' +MODEL_SEGMENTATION_PATH = 'segmentation.onnx.prototxt' +WEIGHT_EMBEDDING_PATH = 'speaker-embedding.onnx' +MODEL_EMBEDDING_PATH = 'speaker-embedding.onnx.prototxt' +REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/pyannote-audio/' +YAML_PATH = 'config.yaml' +OUT_PATH = 'output.png' + +parser = get_base_parser( + 'Pyannote-audio', './data/sample.wav', None, input_ftype='audio' +) + +parser.add_argument( + '--num', '-num_speaker', default=0, type=int, + help='If the number of speakers is fixed', +) +parser.add_argument( + '--max', '-max_speaker', default=0, type=int, + help='If the maximum number of speakers is fixed', +) +parser.add_argument( + '--min', '-min_speaker', default=0, type=int, + help='If the minimum number of speakers is fixed', +) +parser.add_argument( + '--ig', '-ground', default=None, + help='Specify a wav file as ground truth. If you need diarization error rate, you need this file' +) +parser.add_argument( + '--o', '-output', default='output.png', + help='Specify an output file' +) +parser.add_argument( + '--og', '-output_ground', default='output_ground.png', + help='Specify an output ground truth file' +) +parser.add_argument( + '--e', '-error', + action='store_true', + help='If you need diarization error rate' +) +parser.add_argument( + '--plt', + action='store_true', + help='If you want to visualize result' +) +parser.add_argument( + '--embed', + action='store_true', + help='If you need embedding vector', +) +parser.add_argument( + '--onnx', + action='store_true', + help='execute onnxruntime version' +) + +args = update_parser(parser) + +def repr_annotation(args, annotation: Annotation, notebook:Notebook, ground:bool = False): + """Get `png` data for `annotation`""" + figsize = plt.rcParams["figure.figsize"] + plt.rcParams["figure.figsize"] = (notebook.width, 2) + fig, ax = plt.subplots() + notebook.plot_annotation(annotation, ax=ax) + if ground: + plt.savefig(args.og) + else: + plt.savefig(args.o) + plt.close(fig) + plt.rcParams["figure.figsize"] = figsize + return + +def main(args): + check_and_download_models(WEIGHT_SEGMENTATION_PATH, MODEL_SEGMENTATION_PATH, remote_path=REMOTE_PATH) + check_and_download_models(WEIGHT_EMBEDDING_PATH, MODEL_EMBEDDING_PATH, remote_path=REMOTE_PATH) + + if args.benchmark: + start = int(round(time.time() * 1000)) + + with open(YAML_PATH, 'r') as yml: + config = yaml.safe_load(yml) + + config["pipeline"]["params"]["segmentation"] = WEIGHT_SEGMENTATION_PATH + config["pipeline"]["params"]["embedding"] = WEIGHT_EMBEDDING_PATH + with open(YAML_PATH, 'w') as f: + yaml.dump(config, f) + + audio_file = args.input[0] + checkpoint_path = YAML_PATH + config_yml = checkpoint_path + + with open(config_yml, "r") as fp: + config = yaml.load(fp, Loader=yaml.SafeLoader) + + params = config["pipeline"].get("params", {}) + pipeline = SpeakerDiarization( + **params, + args=args, + seg_path=MODEL_SEGMENTATION_PATH, + emb_path=MODEL_EMBEDDING_PATH, + ) + + if "params" in config: + pipeline.instantiate(config["params"]) + + if args.embed: + if args.num > 0: + diarization, 
embeddings = pipeline(audio_file, return_embeddings=True, num_speakers=args.num) + for s, speaker in enumerate(diarization.labels()): + print(speaker, embeddings[s].shape) + elif args.max > 0 or args.min > 0: + diarization, embeddings = pipeline(audio_file, return_embeddings=True, min_speakers=args.min, max_speaker=args.max) + for s, speaker in enumerate(diarization.labels()): + print(speaker, embeddings[s].shape) + else: + diarization, embeddings = pipeline(audio_file, return_embeddings=True) + for s, speaker in enumerate(diarization.labels()): + print(speaker, embeddings[s].shape) + else: + if args.num > 0: + diarization = pipeline(audio_file, num_speakers=args.num) + elif args.max > 0 or args.min > 0: + diarization = pipeline(audio_file, min_speakers=args.min, max_speaker=args.max) + else: + diarization = pipeline(audio_file) + + if args.benchmark: + end = int(round(time.time() * 1000)) + print(f'\tailia processing time {end - start} ms') + + if args.ig: + _, groundtruth = load_rttm(args.ig).popitem() + metric = DiarizationErrorRate() + result = metric(groundtruth, diarization, detailed=False) + + mapping = metric.optimal_mapping(groundtruth, diarization) + diarization = diarization.rename_labels(mapping=mapping) + + print(diarization) + if args.e: + print(f'diarization error rate = {100 * result:.1f}%') + + if args.plt: + EXCERPT = Segment(0, 30) + notebook = Notebook() + notebook.crop = EXCERPT + repr_annotation(args, diarization, notebook) + repr_annotation(args, groundtruth, notebook, ground=True) + return + + else: + print(diarization) + + if args.plt: + EXCERPT = Segment(0, 30) + notebook = Notebook() + notebook.crop = EXCERPT + repr_annotation(args, diarization, notebook) + return + + +if __name__ == "__main__": + main(args) diff --git a/ailia-models/code/pyannote_audio_utils/__init__.py b/ailia-models/code/pyannote_audio_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0898ad3b6a02c823145a9c1eefc1c3b0ab7bc3ad --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/__init__.py @@ -0,0 +1,23 @@ +# MIT License +# +# Copyright (c) 2020 CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +__import__("pkg_resources").declare_namespace(__name__) diff --git a/ailia-models/code/pyannote_audio_utils/audio/__init__.py b/ailia-models/code/pyannote_audio_utils/audio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..078e3d7d3a4a5d3462f9133cfb2ea1497fb93c4f --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/audio/__init__.py @@ -0,0 +1,33 @@ +# MIT License +# +# Copyright (c) 2020-2021 CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +try: + from .version import __version__, git_version # noqa: F401 +except ImportError: + pass + + +from .core.inference import Inference +from .core.io import Audio +from .core.pipeline import Pipeline + +__all__ = ["Audio", "Inference", "Pipeline"] diff --git a/ailia-models/code/pyannote_audio_utils/audio/core/inference.py b/ailia-models/code/pyannote_audio_utils/audio/core/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..2a408ed0dc4af29d0ca3e1527e8d7e301aef691a --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/audio/core/inference.py @@ -0,0 +1,596 @@ +# MIT License +# +# Copyright (c) 2020- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
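+
+# Note: this module is a stripped-down port of pyannote.audio's Inference class.
+# Instead of running the segmentation model with PyTorch, it feeds audio chunks to
+# the ONNX model through ailia (default) or onnxruntime (with the --onnx option),
+# and keeps the sliding-window / overlap-add aggregation logic in NumPy.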
+ + +import ailia +import warnings +from pathlib import Path +from typing import Callable, List, Optional, Text, Tuple, Union +from functools import cached_property +from dataclasses import dataclass +import numpy as np + +from pyannote_audio_utils.core import Segment, SlidingWindow, SlidingWindowFeature +from pyannote_audio_utils.audio.core.io import AudioFile, Audio +from pyannote_audio_utils.audio.core.task import Resolution, Specifications, Problem +from pyannote_audio_utils.audio.utils.multi_task import map_with_specifications +from pyannote_audio_utils.audio.utils.powerset import Powerset + + +class BaseInference: + pass + +@dataclass +class Output: + num_frames: int + dimension: int + frames: SlidingWindow + +class Inference(BaseInference): + """Inference + + Parameters + ---------- + model : Model + Model. Will be automatically set to eval() mode and moved to `device` when provided. + window : {"sliding", "whole"}, optional + Use a "sliding" window and aggregate the corresponding outputs (default) + or just one (potentially long) window covering the "whole" file or chunk. + duration : float, optional + Chunk duration, in seconds. Defaults to duration used for training the model. + Has no effect when `window` is "whole". + step : float, optional + Step between consecutive chunks, in seconds. Defaults to warm-up duration when + greater than 0s, otherwise 10% of duration. Has no effect when `window` is "whole". + pre_aggregation_hook : callable, optional + When a callable is provided, it is applied to the model output, just before aggregation. + Takes a (num_chunks, num_frames, dimension) numpy array as input and returns a modified + (num_chunks, num_frames, other_dimension) numpy array passed to overlap-add aggregation. + skip_aggregation : bool, optional + Do not aggregate outputs when using "sliding" window. Defaults to False. + skip_conversion: bool, optional + In case a task has been trained with `powerset` mode, output is automatically + converted to `multi-label`, unless `skip_conversion` is set to True. + batch_size : int, optional + Batch size. Larger values (should) make inference faster. Defaults to 32. + device : torch.device, optional + Device used for inference. Defaults to `model.device`. + In case `device` and `model.device` are different, model is sent to device. 
+ use_auth_token : str, optional + When loading a private huggingface.co model, set `use_auth_token` + to True or to a string containing your hugginface.co authentication + token that can be obtained by running `huggingface-cli login` + """ + + def __init__( + self, + model: Union[Text, Path], + window: Text = "sliding", + duration: float = None, + step: float = None, + pre_aggregation_hook: Callable[[np.ndarray], np.ndarray] = None, + skip_aggregation: bool = False, + skip_conversion: bool = False, + batch_size: int = 32, + use_auth_token: Union[Text, None] = None, + args = None, + seg_path = None, + ): + # ~~~~ model ~~~~~ + + + if args.onnx: + #print("use onnx runtime") + import onnxruntime + model = onnxruntime.InferenceSession(model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + else: + #print("use ailia") + model = ailia.Net(seg_path, weight=model, env_id=args.env_id) + + self.model = model + self.args = args + + specifications = Specifications(problem=Problem.MONO_LABEL_CLASSIFICATION, resolution = Resolution.FRAME, classes=['speaker#1', 'speaker#2', 'speaker#3']) + + self.specifications = specifications + self.audio = Audio(sample_rate=16000, mono="downmix") + # ~~~~ sliding window ~~~~~ + + if window not in ["sliding", "whole"]: + raise ValueError('`window` must be "sliding" or "whole".') + + if window == "whole" and any( + s.resolution == Resolution.FRAME for s in specifications + ): + warnings.warn( + 'Using "whole" `window` inference with a frame-based model might lead to bad results ' + 'and huge memory consumption: it is recommended to set `window` to "sliding".' + ) + self.window = window + + training_duration = next(iter(specifications)).duration + duration = duration or training_duration + if training_duration != duration: + warnings.warn( + f"Model was trained with {training_duration:g}s chunks, and you requested " + f"{duration:g}s chunks for inference: this might lead to suboptimal results." + ) + self.duration = duration + + # ~~~~ powerset to multilabel conversion ~~~~ + + self.skip_conversion = skip_conversion + + conversion = list() + for s in specifications: + if s.powerset and not skip_conversion: + c = Powerset(len(s.classes), s.powerset_max_classes) + conversion.append(c) + + + if isinstance(specifications, Specifications): + self.conversion = conversion[0] + + # ~~~~ overlap-add aggregation ~~~~~ + + self.skip_aggregation = skip_aggregation + self.pre_aggregation_hook = pre_aggregation_hook + + + self.warm_up = next(iter(specifications)).warm_up + # Use that many seconds on the left- and rightmost parts of each chunk + # to warm up the model. While the model does process those left- and right-most + # parts, only the remaining central part of each chunk is used for aggregating + # scores during inference. + + # step between consecutive chunks + step = step or ( + 0.1 * self.duration if self.warm_up[0] == 0.0 else self.warm_up[0] + ) + if step > self.duration: + raise ValueError( + f"Step between consecutive chunks is set to {step:g}s, while chunks are " + f"only {self.duration:g}s long, leading to gaps between consecutive chunks. " + f"Either decrease step or increase duration." + ) + self.step = step + + self.batch_size = batch_size + + def infer(self, chunks: np.ndarray) -> Union[np.ndarray, Tuple[np.ndarray]]: + """Forward pass + + Takes care of sending chunks to right device and outputs back to CPU + + Parameters + ---------- + chunks : (batch_size, num_channels, num_samples) torch.Tensor + Batch of audio chunks. 
+ + Returns + ------- + outputs : (tuple of) (batch_size, ...) np.ndarray + Model output. + """ + chunks = chunks.astype(np.float32) + if self.args.onnx: + outputs = self.model.run(None, {"input": chunks})[0] + else: + outputs = self.model.predict([chunks])[0] + + + def __convert(output: np.ndarray, conversion, **kwargs): + return conversion(output) + + return map_with_specifications(self.specifications, __convert, outputs, self.conversion) + + @cached_property + def example_output(self) -> Union[Output, Tuple[Output]]: + """Example output""" + example_input_array = np.random.randn(1, 1, self.audio.get_num_samples(self.specifications.duration)).astype(np.float32) + + example_outputs = self.infer(example_input_array) + + def __example_output( + example_output: np.ndarray, + specifications: Specifications = None, + ) -> Output: + _, num_frames, dimension = example_output.shape + + if specifications.resolution == Resolution.FRAME: + frame_duration = specifications.duration / num_frames + frames = SlidingWindow(step=frame_duration, duration=frame_duration) + else: + frames = None + + return Output( + num_frames=num_frames, + dimension=dimension, + frames=frames, + ) + + return map_with_specifications( + self.specifications, __example_output, example_outputs + ) + + def slide( + self, + waveform: np.ndarray, + sample_rate: int, + hook: Optional[Callable], + ) -> Union[SlidingWindowFeature, Tuple[SlidingWindowFeature]]: + """Slide model on a waveform + + Parameters + ---------- + waveform: (num_channels, num_samples) torch.Tensor + Waveform. + sample_rate : int + Sample rate. + hook: Optional[Callable] + When a callable is provided, it is called everytime a batch is + processed with two keyword arguments: + - `completed`: the number of chunks that have been processed so far + - `total`: the total number of chunks + + Returns + ------- + output : (tuple of) SlidingWindowFeature + Model output. Shape is (num_chunks, dimension) for chunk-level tasks, + and (num_frames, dimension) for frame-level tasks. 
+ """ + + window_size: int = self.audio.get_num_samples(self.duration) + step_size: int = round(self.step * sample_rate) + + _, num_samples = waveform.shape + + def __frames( + example_output, specifications: Optional[Specifications] = None + ) -> SlidingWindow: + if specifications.resolution == Resolution.CHUNK: + return SlidingWindow(start=0.0, duration=self.duration, step=self.step) + + return example_output.frames + + frames: Union[SlidingWindow, Tuple[SlidingWindow]] = map_with_specifications( + self.specifications, + __frames, + self.example_output, + ) + + + # prepare complete chunks + def unfold_numpy(waveform, window_size, step_size): + batch_size, waveform_size = waveform.shape + num_windows = (waveform_size - window_size) // step_size + 1 + shape = (batch_size, num_windows, window_size) + strides = ( + waveform.strides[0], + step_size * waveform.strides[1], + waveform.strides[1], + ) + + return np.lib.stride_tricks.as_strided(waveform, shape=shape, strides=strides) + + if num_samples >= window_size: + chunks: np.ndarray = (unfold_numpy(waveform, window_size, step_size)).transpose(1, 0, 2) + num_chunks = chunks.shape[0] + + else: + num_chunks = 0 + + # prepare last incomplete chunk + + has_last_chunk = (num_samples < window_size) or (num_samples - window_size) % step_size > 0 + + if has_last_chunk: + # pad last chunk with zeros + last_chunk: np.ndarray = waveform[:, num_chunks * step_size :] + _, last_window_size = last_chunk.shape + last_pad = window_size - last_window_size + last_chunk = np.pad(last_chunk, ((0, 0), (0, last_pad))) + + def __empty_list(**kwargs): + return list() + + outputs: Union[ + List[np.ndarray], Tuple[List[np.ndarray]] + ] = map_with_specifications(self.specifications, __empty_list) + + if hook is not None: + hook(completed=0, total=num_chunks + has_last_chunk) + + def __append_batch(output, batch_output, **kwargs) -> None: + output.append(batch_output) + return + + + # slide over audio chunks in batch + for c in np.arange(0, num_chunks, self.batch_size): + batch: np.ndarray = chunks[c : c + self.batch_size] + batch_outputs: Union[np.ndarray, Tuple[np.ndarray]] = self.infer(batch) + + _ = map_with_specifications( + self.specifications, __append_batch, outputs, batch_outputs + ) + + if hook is not None: + hook(completed=c + self.batch_size, total=num_chunks + has_last_chunk) + + + # process orphan last chunk + if has_last_chunk: + last_outputs = self.infer(last_chunk[None]) + + _ = map_with_specifications( + self.specifications, __append_batch, outputs, last_outputs + ) + + if hook is not None: + hook( + completed=num_chunks + has_last_chunk, + total=num_chunks + has_last_chunk, + ) + + def __vstack(output: List[np.ndarray], **kwargs) -> np.ndarray: + return np.vstack(output) + + outputs: Union[np.ndarray, Tuple[np.ndarray]] = map_with_specifications( + self.specifications, __vstack, outputs + ) + + def __aggregate( + outputs: np.ndarray, + frames: SlidingWindow, + specifications: Optional[Specifications] = None, + ) -> SlidingWindowFeature: + # skip aggregation when requested, + # or when model outputs just one vector per chunk + # or when model is permutation-invariant (and not post-processed) + + if ( + self.skip_aggregation + or specifications.resolution == Resolution.CHUNK + or ( + specifications.permutation_invariant + and self.pre_aggregation_hook is None + ) + ): + frames = SlidingWindow( + start=0.0, duration=self.duration, step=self.step + ) + + return SlidingWindowFeature(outputs, frames) + + return map_with_specifications( + self.specifications, 
__aggregate, outputs, frames + ) + + def __call__( + self, file: AudioFile, hook: Optional[Callable] = None + ) -> Union[ + Tuple[Union[SlidingWindowFeature, np.ndarray]], + Union[SlidingWindowFeature, np.ndarray], + ]: + """Run inference on a whole file + + Parameters + ---------- + file : AudioFile + Audio file. + hook : callable, optional + When a callable is provided, it is called everytime a batch is processed + with two keyword arguments: + - `completed`: the number of chunks that have been processed so far + - `total`: the total number of chunks + + Returns + ------- + output : (tuple of) SlidingWindowFeature or np.ndarray + Model output, as `SlidingWindowFeature` if `window` is set to "sliding" + and `np.ndarray` if is set to "whole". + + """ + + waveform, sample_rate = self.audio(file) + + if self.window == "sliding": + return self.slide(waveform, sample_rate, hook=hook) + + + @staticmethod + def aggregate( + scores: SlidingWindowFeature, + frames: SlidingWindow = None, + warm_up: Tuple[float, float] = (0.0, 0.0), + epsilon: float = 1e-12, + hamming: bool = False, + missing: float = np.NaN, + skip_average: bool = False, + ) -> SlidingWindowFeature: + """Aggregation + + Parameters + ---------- + scores : SlidingWindowFeature + Raw (unaggregated) scores. Shape is (num_chunks, num_frames_per_chunk, num_classes). + frames : SlidingWindow, optional + Frames resolution. Defaults to estimate it automatically based on `scores` shape + and chunk size. Providing the exact frame resolution (when known) leads to better + temporal precision. + warm_up : (float, float) tuple, optional + Left/right warm up duration (in seconds). + missing : float, optional + Value used to replace missing (ie all NaNs) values. + skip_average : bool, optional + Skip final averaging step. + + Returns + ------- + aggregated_scores : SlidingWindowFeature + Aggregated scores. 
Shape is (num_frames, num_classes) + """ + + num_chunks, num_frames_per_chunk, num_classes = scores.data.shape + + chunks = scores.sliding_window + if frames is None: + duration = step = chunks.duration / num_frames_per_chunk + frames = SlidingWindow(start=chunks.start, duration=duration, step=step) + else: + frames = SlidingWindow( + start=chunks.start, + duration=frames.duration, + step=frames.step, + ) + + masks = 1 - np.isnan(scores) + scores.data = np.nan_to_num(scores.data, copy=True, nan=0.0) + + # Hamming window used for overlap-add aggregation + hamming_window = ( + np.hamming(num_frames_per_chunk).reshape(-1, 1) + if hamming + else np.ones((num_frames_per_chunk, 1)) + ) + + # anything before warm_up_left (and after num_frames_per_chunk - warm_up_right) + # will not be used in the final aggregation + + # warm-up windows used for overlap-add aggregation + warm_up_window = np.ones((num_frames_per_chunk, 1)) + # anything before warm_up_left will not contribute to aggregation + warm_up_left = round( + warm_up[0] / scores.sliding_window.duration * num_frames_per_chunk + ) + warm_up_window[:warm_up_left] = epsilon + # anything after num_frames_per_chunk - warm_up_right either + warm_up_right = round( + warm_up[1] / scores.sliding_window.duration * num_frames_per_chunk + ) + warm_up_window[num_frames_per_chunk - warm_up_right :] = epsilon + + # aggregated_output[i] will be used to store the sum of all predictions + # for frame #i + num_frames = ( + frames.closest_frame( + scores.sliding_window.start + + scores.sliding_window.duration + + (num_chunks - 1) * scores.sliding_window.step + ) + + 1 + ) + aggregated_output: np.ndarray = np.zeros( + (num_frames, num_classes), dtype=np.float32 + ) + + # overlapping_chunk_count[i] will be used to store the number of chunks + # that contributed to frame #i + overlapping_chunk_count: np.ndarray = np.zeros( + (num_frames, num_classes), dtype=np.float32 + ) + + # aggregated_mask[i] will be used to indicate whether + # at least one non-NAN frame contributed to frame #i + aggregated_mask: np.ndarray = np.zeros( + (num_frames, num_classes), dtype=np.float32 + ) + + # loop on the scores of sliding chunks + for (chunk, score), (_, mask) in zip(scores, masks): + # chunk ~ Segment + # score ~ (num_frames_per_chunk, num_classes)-shaped np.ndarray + # mask ~ (num_frames_per_chunk, num_classes)-shaped np.ndarray + + start_frame = frames.closest_frame(chunk.start) + aggregated_output[start_frame : start_frame + num_frames_per_chunk] += ( + score * mask * hamming_window * warm_up_window + ) + + overlapping_chunk_count[ + start_frame : start_frame + num_frames_per_chunk + ] += (mask * hamming_window * warm_up_window) + + aggregated_mask[ + start_frame : start_frame + num_frames_per_chunk + ] = np.maximum( + aggregated_mask[start_frame : start_frame + num_frames_per_chunk], + mask, + ) + + if skip_average: + average = aggregated_output + else: + average = aggregated_output / np.maximum(overlapping_chunk_count, epsilon) + + average[aggregated_mask == 0.0] = missing + + return SlidingWindowFeature(average, frames) + + @staticmethod + def trim( + scores: SlidingWindowFeature, + warm_up: Tuple[float, float] = (0.1, 0.1), + ) -> SlidingWindowFeature: + """Trim left and right warm-up regions + + Parameters + ---------- + scores : SlidingWindowFeature + (num_chunks, num_frames, num_classes)-shaped scores. + warm_up : (float, float) tuple + Left/right warm up ratio of chunk duration. + Defaults to (0.1, 0.1), i.e. 10% on both sides. 
+ + Returns + ------- + trimmed : SlidingWindowFeature + (num_chunks, trimmed_num_frames, num_speakers)-shaped scores + """ + + + assert ( + scores.data.ndim == 3 + ), "Inference.trim expects (num_chunks, num_frames, num_classes)-shaped `scores`" + _, num_frames, _ = scores.data.shape + + chunks = scores.sliding_window + + num_frames_left = round(num_frames * warm_up[0]) + num_frames_right = round(num_frames * warm_up[1]) + + num_frames_step = round(num_frames * chunks.step / chunks.duration) + if num_frames - num_frames_left - num_frames_right < num_frames_step: + warnings.warn( + f"Total `warm_up` is so large ({sum(warm_up) * 100:g}% of each chunk) " + f"that resulting trimmed scores does not cover a whole step ({chunks.step:g}s)" + ) + new_data = scores.data[:, num_frames_left : num_frames - num_frames_right] + + new_chunks = SlidingWindow( + start=chunks.start + warm_up[0] * chunks.duration, + step=chunks.step, + duration=(1 - warm_up[0] - warm_up[1]) * chunks.duration, + ) + + return SlidingWindowFeature(new_data, new_chunks) diff --git a/ailia-models/code/pyannote_audio_utils/audio/core/io.py b/ailia-models/code/pyannote_audio_utils/audio/core/io.py new file mode 100644 index 0000000000000000000000000000000000000000..8dc1f660e479b8b1971041f071f3d3ca747b4478 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/audio/core/io.py @@ -0,0 +1,352 @@ +# MIT License +# +# Copyright (c) 2020- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +""" +# Audio IO + +pyannote.audio relies on torchaudio for reading and resampling. 
+ +""" + +import math +import random +import warnings + +from pathlib import Path +from typing import Mapping, Optional, Text, Tuple, Union +import numpy as np + +from pyannote_audio_utils.core import Segment + +import ailia.audio +import soundfile + +AudioFile = Union[Text, Path, Mapping] + +AudioFileDocString = """ +Audio files can be provided to the Audio class using different types: + - a "str" or "Path" instance: "audio.wav" or Path("audio.wav") + - a "IOBase" instance with "read" and "seek" support: open("audio.wav", "rb") + - a "Mapping" with any of the above as "audio" key: {"audio": ...} + - a "Mapping" with both "waveform" and "sample_rate" key: + {"waveform": (channel, time) numpy.ndarray or torch.Tensor, "sample_rate": 44100} + +For last two options, an additional "channel" key can be provided as a zero-indexed +integer to load a specific channel: {"audio": "stereo.wav", "channel": 0} +""" + + +class Audio: + """Audio IO + + Parameters + ---------- + sample_rate: int, optional + Target sampling rate. Defaults to using native sampling rate. + mono : {'random', 'downmix'}, optional + In case of multi-channel audio, convert to single-channel audio + using one of the following strategies: select one channel at + 'random' or 'downmix' by averaging all channels. + + Usage + ----- + >>> audio = Audio(sample_rate=16000, mono='downmix') + >>> waveform, sample_rate = audio({"audio": "/path/to/audio.wav"}) + >>> assert sample_rate == 16000 + >>> sample_rate = 44100 + >>> two_seconds_stereo = torch.rand(2, 2 * sample_rate) + >>> waveform, sample_rate = audio({"waveform": two_seconds_stereo, "sample_rate": sample_rate}) + >>> assert sample_rate == 16000 + >>> assert waveform.shape[0] == 1 + """ + + PRECISION = 0.001 + + + @staticmethod + def validate_file(file: AudioFile) -> Mapping: + """Validate file for use with the other Audio methods + + Parameter + --------- + file: AudioFile + + Returns + ------- + validated_file : Mapping + {"audio": str, "uri": str, ...} + {"waveform": array or tensor, "sample_rate": int, "uri": str, ...} + {"audio": file, "uri": "stream"} if `file` is an IOBase instance + + Raises + ------ + ValueError if file format is not valid or file does not exist. + + """ + + + if isinstance(file, Mapping): + pass + elif isinstance(file, (str, Path)): + file = {"audio": str(file), "uri": Path(file).stem} + + # elif isinstance(file, IOBase): + # return {"audio": file, "uri": "stream"} + + # else: + # raise ValueError(AudioFileDocString) + + if "waveform" in file: + waveform: np.ndarray = file["waveform"] + if len(waveform.shape) != 2 or waveform.shape[0] > waveform.shape[1]: + raise ValueError( + "'waveform' must be provided as a (channel, time) torch Tensor." + ) + + sample_rate: int = file.get("sample_rate", None) + if sample_rate is None: + raise ValueError( + "'waveform' must be provided with their 'sample_rate'." + ) + + file.setdefault("uri", "waveform") + + elif "audio" in file: + # if isinstance(file["audio"], IOBase): + # return file + + path = Path(file["audio"]) + if not path.is_file(): + raise ValueError(f"File {path} does not exist") + + file.setdefault("uri", path.stem) + + else: + raise ValueError( + "Neither 'waveform' nor 'audio' is available for this file." 
+ ) + + return file + + def __init__(self, sample_rate=None, mono=None): + super().__init__() + self.sample_rate = sample_rate + self.mono = mono + + def downmix_and_resample(self, waveform: np.ndarray, sample_rate: int) -> np.ndarray: + """Downmix and resample + + Parameters + ---------- + waveform : (channel, time) Tensor + Waveform. + sample_rate : int + Sample rate. + + Returns + ------- + waveform : (channel, time) Tensor + Remixed and resampled waveform + sample_rate : int + New sample rate + """ + + # downmix to mono + + num_channels = waveform.shape[0] + if num_channels > 1: + if self.mono == "random": + channel = random.randint(0, num_channels - 1) + waveform = waveform[channel : channel + 1] + elif self.mono == "downmix": + waveform = np.mean(waveform, axis=0, keepdims=True) + + + + ######## ここでずれる ########## + if (self.sample_rate is not None) and (self.sample_rate != sample_rate): + waveform = ailia.audio.resample( + waveform, org_sr=sample_rate, target_sr=self.sample_rate) + + sample_rate = self.sample_rate + + return waveform, sample_rate + + + def get_num_samples( + self, duration: float, sample_rate: Optional[int] = None + ) -> int: + """Deterministic number of samples from duration and sample rate""" + + sample_rate = sample_rate or self.sample_rate + + if sample_rate is None: + raise ValueError( + "`sample_rate` must be provided to compute number of samples." + ) + + return math.floor(duration * sample_rate) + + def __call__(self, file: AudioFile) -> Tuple[np.ndarray, int]: + """Obtain waveform + + Parameters + ---------- + file : AudioFile + + Returns + ------- + waveform : (channel, time) torch.Tensor + Waveform + sample_rate : int + Sample rate + + See also + -------- + AudioFile + """ + + file = self.validate_file(file) + + if "waveform" in file: + waveform = file["waveform"] + sample_rate = file["sample_rate"] + + waveform, sample_rate = soundfile.read(file["audio"]) + + if waveform.ndim == 1: + waveform = np.expand_dims(waveform,axis=0) + else: + waveform = waveform.T + + channel = file.get("channel", None) + + if channel is not None: + waveform = waveform[channel : channel + 1] + + return self.downmix_and_resample(waveform, sample_rate) + + def crop( + self, + file: AudioFile, + segment: Segment, + duration: Optional[float] = None, + mode="raise", + ) -> Tuple[np.ndarray, int]: + """Fast version of self(file).crop(segment, **kwargs) + + Parameters + ---------- + file : AudioFile + Audio file. + segment : `pyannote.core.Segment` + Temporal segment to load. + duration : float, optional + Overrides `Segment` 'focus' duration and ensures that the number of + returned frames is fixed (which might otherwise not be the case + because of rounding errors). + mode : {'raise', 'pad'}, optional + Specifies how out-of-bounds segments will behave. 
+ * 'raise' -- raise an error (default) + * 'pad' -- zero pad + + Returns + ------- + waveform : (channel, time) torch.Tensor + Waveform + sample_rate : int + Sample rate + + """ + + file = self.validate_file(file) + + if "waveform" in file: + waveform = file["waveform"] + frames = waveform.shape[1] + sample_rate = file["sample_rate"] + + elif "torchaudio.info" in file: + info = file["torchaudio.info"] + frames = info.num_frames + sample_rate = info.sample_rate + + else: + info = soundfile.read(file["audio"]) + frames = info[0].shape[0] + sample_rate = info[1] + + channel = file.get("channel", None) + + # infer which samples to load from sample rate and requested chunk + start_frame = math.floor(segment.start * sample_rate) + + if duration: + num_frames = math.floor(duration * sample_rate) + end_frame = start_frame + num_frames + + else: + end_frame = math.floor(segment.end * sample_rate) + num_frames = end_frame - start_frame + + if mode == "pad": + pad_start = -min(0, start_frame) + pad_end = max(end_frame, frames) - frames + start_frame = max(0, start_frame) + end_frame = min(end_frame, frames) + num_frames = end_frame - start_frame + + if "waveform" in file: + data = file["waveform"][:, start_frame:end_frame] + + else: + try: + data, _ = soundfile.read(file["audio"], start=start_frame, frames=num_frames) + if data.ndim == 1: + data = np.expand_dims(data, axis=0) + else: + data = data.T + + except RuntimeError: + msg = ( + f"torchaudio failed to seek-and-read in {file['audio']}: " + f"loading the whole file instead." + ) + + warnings.warn(msg) + waveform, sample_rate = self.__call__(file) + data = waveform[:, start_frame:end_frame] + + # storing waveform and sample_rate for next time + # as it is very likely that seek-and-read will + # fail again for this particular file + file["waveform"] = waveform + file["sample_rate"] = sample_rate + + if channel is not None: + data = data[channel : channel + 1, :] + + if mode == "pad": + data = np.pad(data, ((0, 0), (pad_start, pad_end))) + + + return self.downmix_and_resample(data, sample_rate) diff --git a/ailia-models/code/pyannote_audio_utils/audio/core/pipeline.py b/ailia-models/code/pyannote_audio_utils/audio/core/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..70262acfaded805b2d89831970cc080f48c29d4c --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/audio/core/pipeline.py @@ -0,0 +1,218 @@ +# MIT License +# +# Copyright (c) 2021 CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
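+
+# Note: trimmed-down port of pyannote.audio's Pipeline wrapper. Only loading a
+# local YAML config (from_pretrained) and the bookkeeping of Inference attributes
+# are kept; the torch/device handling of the upstream class is not used here.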
+ +import os +import warnings +from collections import OrderedDict +from collections.abc import Iterator +from functools import partial +from pathlib import Path +from typing import Callable, Dict, List, Optional, Text, Union + + +import yaml +from importlib import import_module +from pyannote_audio_utils.database import ProtocolFile +from pyannote_audio_utils.pipeline import Pipeline as _Pipeline + +from pyannote_audio_utils.audio import Audio, __version__ +from pyannote_audio_utils.audio.core.inference import BaseInference +from pyannote_audio_utils.audio.core.io import AudioFile + +PIPELINE_PARAMS_NAME = "config.yaml" + + +class Pipeline(_Pipeline): + @classmethod + def from_pretrained( + cls, + checkpoint_path: Union[Text, Path], + hparams_file: Union[Text, Path] = None, + use_auth_token: Union[Text, None] = None, + ) -> "Pipeline": + """Load pretrained pipeline + + Parameters + ---------- + checkpoint_path : Path or str + Path to pipeline checkpoint, or a remote URL, + or a pipeline identifier from the huggingface.co model hub. + hparams_file: Path or str, optional + use_auth_token : str, optional + When loading a private huggingface.co pipeline, set `use_auth_token` + to True or to a string containing your hugginface.co authentication + token that can be obtained by running `huggingface-cli login` + cache_dir: Path or str, optional + Path to model cache directory. Defauorch/pyannote_audio_utils" when unset. + """ + + checkpoint_path = str(checkpoint_path) + config_yml = checkpoint_path + + with open(config_yml, "r") as fp: + config = yaml.load(fp, Loader=yaml.SafeLoader) + + # initialize pipeline + pipeline_name = config["pipeline"]["name"] + tokens = pipeline_name.split('.') + module_name = '.'.join(tokens[:-1]) + class_name = tokens[-1] + Klass = getattr(import_module(module_name), class_name) + params = config["pipeline"].get("params", {}) + pipeline = Klass(**params) + + # freeze parameters + if "params" in config: + pipeline.instantiate(config["params"]) + + return pipeline + + def __init__(self): + super().__init__() + self._models: Dict[str] = OrderedDict() + self._inferences: Dict[str, BaseInference] = OrderedDict() + + def __getattr__(self, name): + """(Advanced) attribute getter + + Adds support for Model and Inference attributes, + which are iterated over by Pipeline.to() method. + + See pyannote_audio_utils.pipeline.Pipeline.__getattr__. + """ + + if "_models" in self.__dict__: + _models = self.__dict__["_models"] + if name in _models: + return _models[name] + + if "_inferences" in self.__dict__: + _inferences = self.__dict__["_inferences"] + if name in _inferences: + return _inferences[name] + + return super().__getattr__(name) + + def __setattr__(self, name, value): + """(Advanced) attribute setter + + Adds support for Model and Inference attributes, + which are iterated over by Pipeline.to() method. + + See pyannote_audio_utils.pipeline.Pipeline.__setattr__. 
+ """ + + def remove_from(*dicts): + for d in dicts: + if name in d: + del d[name] + + _parameters = self.__dict__.get("_parameters") + _instantiated = self.__dict__.get("_instantiated") + _pipelines = self.__dict__.get("_pipelines") + _models = self.__dict__.get("_models") + _inferences = self.__dict__.get("_inferences") + + + + if isinstance(value, BaseInference): + if _inferences is None: + msg = "cannot assign inferences before Pipeline.__init__() call" + raise AttributeError(msg) + remove_from(self.__dict__, _models, _parameters, _instantiated, _pipelines) + _inferences[name] = value + return + + super().__setattr__(name, value) + + def __delattr__(self, name): + if name in self._models: + del self._models[name] + + elif name in self._inferences: + del self._inferences[name] + + else: + super().__delattr__(name) + + @staticmethod + def setup_hook(file: AudioFile, hook: Optional[Callable] = None) -> Callable: + def noop(*args, **kwargs): + return + + return partial(hook or noop, file=file) + + def default_parameters(self): + raise NotImplementedError() + + def classes(self) -> Union[List, Iterator]: + """Classes returned by the pipeline + + Returns + ------- + classes : list of string or string iterator + Finite list of strings when classes are known in advance + (e.g. ["MALE", "FEMALE"] for gender classification), or + infinite string iterator when they depend on the file + (e.g. "SPEAKER_00", "SPEAKER_01", ... for speaker diarization) + + Usage + ----- + >>> from collections.abc import Iterator + >>> classes = pipeline.classes() + >>> if isinstance(classes, Iterator): # classes depend on the input file + >>> if isinstance(classes, list): # classes are known in advance + + """ + raise NotImplementedError() + + def __call__(self, file: AudioFile, **kwargs): + # breakpoint() + # fix_reproducibility(getattr(self, "device", torch.device("cpu"))) + + if not self.instantiated: + # instantiate with default parameters when available + try: + default_parameters = self.default_parameters() + except NotImplementedError: + raise RuntimeError( + "A pipeline must be instantiated with `pipeline.instantiate(parameters)` before it can be applied." + ) + + try: + self.instantiate(default_parameters) + except ValueError: + raise RuntimeError( + "A pipeline must be instantiated with `pipeline.instantiate(paramaters)` before it can be applied. " + "Tried to use parameters provided by `pipeline.default_parameters()` but those are not compatible. " + ) + + warnings.warn( + f"The pipeline has been automatically instantiated with {default_parameters}." 
+ ) + + file = Audio.validate_file(file) + + if hasattr(self, "preprocessors"): + file = ProtocolFile(file, lazy=self.preprocessors) + + return self.apply(file, **kwargs) diff --git a/ailia-models/code/pyannote_audio_utils/audio/core/task.py b/ailia-models/code/pyannote_audio_utils/audio/core/task.py new file mode 100644 index 0000000000000000000000000000000000000000..bea954f4a39dfc7c219919d5096377b74efd9314 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/audio/core/task.py @@ -0,0 +1,125 @@ +# MIT License +# +# Copyright (c) 2020- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from functools import cached_property, partial +from typing import Dict, List, Literal, Optional, Sequence, Text, Tuple, Union + +import scipy.special + +from pyannote_audio_utils.database.protocol.protocol import Scope, Subset + + +Subsets = list(Subset.__args__) +Scopes = list(Scope.__args__) + + +# Type of machine learning problem +class Problem(Enum): + BINARY_CLASSIFICATION = 0 + MONO_LABEL_CLASSIFICATION = 1 + MULTI_LABEL_CLASSIFICATION = 2 + REPRESENTATION = 3 + REGRESSION = 4 + # any other we could think of? + + +# A task takes an audio chunk as input and returns +# either a temporal sequence of predictions +# or just one prediction for the whole audio chunk +class Resolution(Enum): + FRAME = 1 # model outputs a sequence of frames + CHUNK = 2 # model outputs just one vector for the whole chunk + + +class UnknownSpecificationsError(Exception): + pass + + +@dataclass +class Specifications: + problem: Problem + resolution: Resolution + + # (maximum) chunk duration in seconds + # duration: float + duration: float = 10.0 + + # (for variable-duration tasks only) minimum chunk duration in seconds + min_duration: Optional[float] = None + + # use that many seconds on the left- and rightmost parts of each chunk + # to warm up the model. This is mostly useful for segmentation tasks. + # While the model does process those left- and right-most parts, only + # the remaining central part of each chunk is used for computing the + # loss during training, and for aggregating scores during inference. + # Defaults to 0. (i.e. no warm-up). 
+ warm_up: Optional[Tuple[float, float]] = (0.0, 0.0) + + # (for classification tasks only) list of classes + classes: Optional[List[Text]] = None + # classes: Optional[List[Text]] = ['speaker#1', 'speaker#2', 'speaker#3'] + + # (for powerset only) max number of simultaneous classes + # (n choose k with k <= powerset_max_classes) + # powerset_max_classes: Optional[int] = None + powerset_max_classes: Optional[int] = 2 + + # whether classes are permutation-invariant (e.g. diarization) + # permutation_invariant: bool = False + permutation_invariant: bool = True + + + @cached_property + def powerset(self) -> bool: + if self.powerset_max_classes is None: + return False + + if self.problem != Problem.MONO_LABEL_CLASSIFICATION: + raise ValueError( + "`powerset_max_classes` only makes sense with multi-class classification problems." + ) + + return True + + @cached_property + def num_powerset_classes(self) -> int: + # compute number of subsets of size at most "powerset_max_classes" + # e.g. with len(classes) = 3 and powerset_max_classes = 2: + # {}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2} + return int( + sum( + scipy.special.binom(len(self.classes), i) + for i in range(0, self.powerset_max_classes + 1) + ) + ) + + def __len__(self): + return 1 + + def __iter__(self): + yield self + diff --git a/ailia-models/code/pyannote_audio_utils/audio/pipelines/__init__.py b/ailia-models/code/pyannote_audio_utils/audio/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0b79e28071beece8d0e8098f335f493da2a19969 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/audio/pipelines/__init__.py @@ -0,0 +1,35 @@ +# MIT License +# +# Copyright (c) 2020-2022 CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# from .multilabel import MultiLabelSegmentation +# from .overlapped_speech_detection import OverlappedSpeechDetection +# from .resegmentation import Resegmentation +from .speaker_diarization import SpeakerDiarization +# from .voice_activity_detection import VoiceActivityDetection + +__all__ = [ + # "VoiceActivityDetection", + # "OverlappedSpeechDetection", + "SpeakerDiarization", + # "Resegmentation", + # "MultiLabelSegmentation", +] diff --git a/ailia-models/code/pyannote_audio_utils/audio/pipelines/clustering.py b/ailia-models/code/pyannote_audio_utils/audio/pipelines/clustering.py new file mode 100644 index 0000000000000000000000000000000000000000..3ec54bc99b8264c25afd69b97f58c7eeae9e1578 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/audio/pipelines/clustering.py @@ -0,0 +1,468 @@ +# The MIT License (MIT) +# +# Copyright (c) 2021- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Clustering pipelines""" + + +import random +from enum import Enum +from typing import Optional, Tuple + +import numpy as np +from pyannote_audio_utils.core import SlidingWindow, SlidingWindowFeature +from pyannote_audio_utils.pipeline import Pipeline +from pyannote_audio_utils.pipeline.parameter import Categorical, Integer, Uniform +from scipy.cluster.hierarchy import fcluster, linkage +from scipy.optimize import linear_sum_assignment +from scipy.spatial.distance import cdist + + +class BaseClustering(Pipeline): + def __init__( + self, + metric: str = "cosine", + max_num_embeddings: int = 1000, + constrained_assignment: bool = False, + ): + super().__init__() + self.metric = metric + self.max_num_embeddings = max_num_embeddings + self.constrained_assignment = constrained_assignment + + def set_num_clusters( + self, + num_embeddings: int, + num_clusters: Optional[int] = None, + min_clusters: Optional[int] = None, + max_clusters: Optional[int] = None, + ): + min_clusters = num_clusters or min_clusters or 1 + min_clusters = max(1, min(num_embeddings, min_clusters)) + max_clusters = num_clusters or max_clusters or num_embeddings + max_clusters = max(1, min(num_embeddings, max_clusters)) + + if min_clusters > max_clusters: + raise ValueError( + f"min_clusters must be smaller than (or equal to) max_clusters " + f"(here: min_clusters={min_clusters:g} and max_clusters={max_clusters:g})." 
+ ) + + if min_clusters == max_clusters: + num_clusters = min_clusters + + return num_clusters, min_clusters, max_clusters + + def filter_embeddings( + self, + embeddings: np.ndarray, + segmentations: Optional[SlidingWindowFeature] = None, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Filter NaN embeddings and downsample embeddings + + Parameters + ---------- + embeddings : (num_chunks, num_speakers, dimension) array + Sequence of embeddings. + segmentations : (num_chunks, num_frames, num_speakers) array + Binary segmentations. + + Returns + ------- + filtered_embeddings : (num_embeddings, dimension) array + chunk_idx : (num_embeddings, ) array + speaker_idx : (num_embeddings, ) array + """ + + # whether speaker is active + active = np.sum(segmentations.data, axis=1) > 0 + # whether speaker embedding extraction went fine + valid = ~np.any(np.isnan(embeddings), axis=2) + + # indices of embeddings that are both active and valid + chunk_idx, speaker_idx = np.where(active * valid) + + # sample max_num_embeddings embeddings + num_embeddings = len(chunk_idx) + if num_embeddings > self.max_num_embeddings: + indices = list(range(num_embeddings)) + random.shuffle(indices) + indices = sorted(indices[: self.max_num_embeddings]) + chunk_idx = chunk_idx[indices] + speaker_idx = speaker_idx[indices] + + return embeddings[chunk_idx, speaker_idx], chunk_idx, speaker_idx + + def constrained_argmax(self, soft_clusters: np.ndarray) -> np.ndarray: + soft_clusters = np.nan_to_num(soft_clusters, nan=np.nanmin(soft_clusters)) + num_chunks, num_speakers, num_clusters = soft_clusters.shape + # num_chunks, num_speakers, num_clusters + + hard_clusters = -2 * np.ones((num_chunks, num_speakers), dtype=np.int8) + + for c, cost in enumerate(soft_clusters): + speakers, clusters = linear_sum_assignment(cost, maximize=True) + for s, k in zip(speakers, clusters): + hard_clusters[c, s] = k + + return hard_clusters + + def assign_embeddings( + self, + embeddings: np.ndarray, + train_chunk_idx: np.ndarray, + train_speaker_idx: np.ndarray, + train_clusters: np.ndarray, + constrained: bool = False, + ): + """Assign embeddings to the closest centroid + + Cluster centroids are computed as the average of the train embeddings + previously assigned to them. + + Parameters + ---------- + embeddings : (num_chunks, num_speakers, dimension)-shaped array + Complete set of embeddings. + train_chunk_idx : (num_embeddings,)-shaped array + train_speaker_idx : (num_embeddings,)-shaped array + Indices of subset of embeddings used for "training". + train_clusters : (num_embedding,)-shaped array + Clusters of the above subset + constrained : bool, optional + Use constrained_argmax, instead of (default) argmax. 
+ + Returns + ------- + soft_clusters : (num_chunks, num_speakers, num_clusters)-shaped array + hard_clusters : (num_chunks, num_speakers)-shaped array + centroids : (num_clusters, dimension)-shaped array + Clusters centroids + """ + + # TODO: option to add a new (dummy) cluster in case num_clusters < max(frame_speaker_count) + + num_clusters = np.max(train_clusters) + 1 + num_chunks, num_speakers, dimension = embeddings.shape + + train_embeddings = embeddings[train_chunk_idx, train_speaker_idx] + + centroids = np.vstack( + [ + np.mean(train_embeddings[train_clusters == k], axis=0) + for k in range(num_clusters) + ] + ) + + e2k_distance = cdist( + embeddings.reshape([-1, dimension]), + centroids, + metric=self.metric + ).reshape([num_chunks, num_speakers, -1]) + + soft_clusters = 2 - e2k_distance + + # assign each embedding to the cluster with the most similar centroid + if constrained: + hard_clusters = self.constrained_argmax(soft_clusters) + else: + hard_clusters = np.argmax(soft_clusters, axis=2) + + # NOTE: train_embeddings might be reassigned to a different cluster + # in the process. based on experiments, this seems to lead to better + # results than sticking to the original assignment. + + return hard_clusters, soft_clusters, centroids + + def __call__( + self, + embeddings: np.ndarray, + segmentations: Optional[SlidingWindowFeature] = None, + num_clusters: Optional[int] = None, + min_clusters: Optional[int] = None, + max_clusters: Optional[int] = None, + **kwargs, + ) -> np.ndarray: + """Apply clustering + + Parameters + ---------- + embeddings : (num_chunks, num_speakers, dimension) array + Sequence of embeddings. + segmentations : (num_chunks, num_frames, num_speakers) array + Binary segmentations. + num_clusters : int, optional + Number of clusters, when known. Default behavior is to use + internal threshold hyper-parameter to decide on the number + of clusters. + min_clusters : int, optional + Minimum number of clusters. Has no effect when `num_clusters` is provided. + max_clusters : int, optional + Maximum number of clusters. Has no effect when `num_clusters` is provided. 
+ + Returns + ------- + hard_clusters : (num_chunks, num_speakers) array + Hard cluster assignment (hard_clusters[c, s] = k means that sth speaker + of cth chunk is assigned to kth cluster) + soft_clusters : (num_chunks, num_speakers, num_clusters) array + Soft cluster assignment (the higher soft_clusters[c, s, k], the most likely + the sth speaker of cth chunk belongs to kth cluster) + centroids : (num_clusters, dimension) array + Centroid vectors of each cluster + """ + + train_embeddings, train_chunk_idx, train_speaker_idx = self.filter_embeddings( + embeddings, + segmentations=segmentations, + ) + + num_embeddings, _ = train_embeddings.shape + + num_clusters, min_clusters, max_clusters = self.set_num_clusters( + num_embeddings, + num_clusters=num_clusters, + min_clusters=min_clusters, + max_clusters=max_clusters, + ) + + if max_clusters < 2: + # do NOT apply clustering when min_clusters = max_clusters = 1 + num_chunks, num_speakers, _ = embeddings.shape + hard_clusters = np.zeros((num_chunks, num_speakers), dtype=np.int8) + soft_clusters = np.ones((num_chunks, num_speakers, 1)) + centroids = np.mean(train_embeddings, axis=0, keepdims=True) + return hard_clusters, soft_clusters, centroids + + train_clusters = self.cluster( + train_embeddings, + min_clusters, + max_clusters, + num_clusters=num_clusters, + ) + + hard_clusters, soft_clusters, centroids = self.assign_embeddings( + embeddings, + train_chunk_idx, + train_speaker_idx, + train_clusters, + constrained=self.constrained_assignment, + ) + + return hard_clusters, soft_clusters, centroids + + +class AgglomerativeClustering(BaseClustering): + """Agglomerative clustering + + Parameters + ---------- + metric : {"cosine", "euclidean", ...}, optional + Distance metric to use. Defaults to "cosine". + + Hyper-parameters + ---------------- + method : {"average", "centroid", "complete", "median", "single", "ward"} + Linkage method. + threshold : float in range [0.0, 2.0] + Clustering threshold. + min_cluster_size : int in range [1, 20] + Minimum cluster size + """ + + def __init__( + self, + metric: str = "cosine", + max_num_embeddings: int = np.inf, + constrained_assignment: bool = False, + ): + super().__init__( + metric=metric, + max_num_embeddings=max_num_embeddings, + constrained_assignment=constrained_assignment, + ) + + self.threshold = Uniform(0.0, 2.0) # assume unit-normalized embeddings + self.method = Categorical( + ["average", "centroid", "complete", "median", "single", "ward", "weighted"] + ) + + # minimum cluster size + self.min_cluster_size = Integer(1, 20) + + + def cluster( + self, + embeddings: np.ndarray, + min_clusters: int, + max_clusters: int, + num_clusters: Optional[int] = None, + ): + """ + + Parameters + ---------- + embeddings : (num_embeddings, dimension) array + Embeddings + min_clusters : int + Minimum number of clusters + max_clusters : int + Maximum number of clusters + num_clusters : int, optional + Actual number of clusters. Default behavior is to estimate it based + on values provided for `min_clusters`, `max_clusters`, and `threshold`. + + Returns + ------- + clusters : (num_embeddings, ) array + 0-indexed cluster indices. 
+ """ + + num_embeddings, _ = embeddings.shape + + # heuristic to reduce self.min_cluster_size when num_embeddings is very small + # (0.1 value is kind of arbitrary, though) + min_cluster_size = min( + self.min_cluster_size, max(1, round(0.1 * num_embeddings)) + ) + + + # linkage function will complain when there is just one embedding to cluster + if num_embeddings == 1: + return np.zeros((1,), dtype=np.uint8) + + # centroid, median, and Ward method only support "euclidean" metric + # therefore we unit-normalize embeddings to somehow make them "euclidean" + if self.metric == "cosine" and self.method in ["centroid", "median", "ward"]: + with np.errstate(divide="ignore", invalid="ignore"): + embeddings /= np.linalg.norm(embeddings, axis=-1, keepdims=True) + dendrogram: np.ndarray = linkage( + embeddings, method=self.method, metric="euclidean" + ) + + # other methods work just fine with any metric + else: + dendrogram: np.ndarray = linkage( + embeddings, method=self.method, metric=self.metric + ) + + # apply the predefined threshold + clusters = fcluster(dendrogram, self.threshold, criterion="distance") - 1 + + # split clusters into two categories based on their number of items: + # large clusters vs. small clusters + cluster_unique, cluster_counts = np.unique( + clusters, + return_counts=True, + ) + large_clusters = cluster_unique[cluster_counts >= min_cluster_size] + num_large_clusters = len(large_clusters) + + # force num_clusters to min_clusters in case the actual number is too small + if num_large_clusters < min_clusters: + num_clusters = min_clusters + + # force num_clusters to max_clusters in case the actual number is too large + elif num_large_clusters > max_clusters: + num_clusters = max_clusters + + # look for perfect candidate if necessary + if num_clusters is not None and num_large_clusters != num_clusters: + # switch stopping criterion from "inter-cluster distance" stopping to "iteration index" + _dendrogram = np.copy(dendrogram) + _dendrogram[:, 2] = np.arange(num_embeddings - 1) + + best_iteration = num_embeddings - 1 + best_num_large_clusters = 1 + + # traverse the dendrogram by going further and further away + # from the "optimal" threshold + + for iteration in np.argsort(np.abs(dendrogram[:, 2] - self.threshold)): + # only consider iterations that might have resulted + # in changing the number of (large) clusters + new_cluster_size = _dendrogram[iteration, 3] + if new_cluster_size < min_cluster_size: + continue + + # estimate number of large clusters at considered iteration + clusters = fcluster(_dendrogram, iteration, criterion="distance") - 1 + cluster_unique, cluster_counts = np.unique(clusters, return_counts=True) + large_clusters = cluster_unique[cluster_counts >= min_cluster_size] + num_large_clusters = len(large_clusters) + + # keep track of iteration that leads to the number of large clusters + # as close as possible to the target number of clusters. 
+ if abs(num_large_clusters - num_clusters) < abs( + best_num_large_clusters - num_clusters + ): + best_iteration = iteration + best_num_large_clusters = num_large_clusters + + # stop traversing the dendrogram as soon as we found a good candidate + if num_large_clusters == num_clusters: + break + + # re-apply best iteration in case we did not find a perfect candidate + if best_num_large_clusters != num_clusters: + clusters = ( + fcluster(_dendrogram, best_iteration, criterion="distance") - 1 + ) + cluster_unique, cluster_counts = np.unique(clusters, return_counts=True) + large_clusters = cluster_unique[cluster_counts >= min_cluster_size] + num_large_clusters = len(large_clusters) + print( + f"Found only {num_large_clusters} clusters. Using a smaller value than {min_cluster_size} for `min_cluster_size` might help." + ) + + if num_large_clusters == 0: + clusters[:] = 0 + return clusters + + small_clusters = cluster_unique[cluster_counts < min_cluster_size] + if len(small_clusters) == 0: + return clusters + + # re-assign each small cluster to the most similar large cluster based on their respective centroids + large_centroids = np.vstack( + [ + np.mean(embeddings[clusters == large_k], axis=0) + for large_k in large_clusters + ] + ) + small_centroids = np.vstack( + [ + np.mean(embeddings[clusters == small_k], axis=0) + for small_k in small_clusters + ] + ) + centroids_cdist = cdist(large_centroids, small_centroids, metric=self.metric) + for small_k, large_k in enumerate(np.argmin(centroids_cdist, axis=0)): + clusters[clusters == small_clusters[small_k]] = large_clusters[large_k] + + # re-number clusters from 0 to num_large_clusters + _, clusters = np.unique(clusters, return_inverse=True) + return clusters + + +class Clustering(Enum): + AgglomerativeClustering = AgglomerativeClustering + diff --git a/ailia-models/code/pyannote_audio_utils/audio/pipelines/speaker_diarization.py b/ailia-models/code/pyannote_audio_utils/audio/pipelines/speaker_diarization.py new file mode 100644 index 0000000000000000000000000000000000000000..7320a62d2d444bf5e24cfcfa21d89100897fffeb --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/audio/pipelines/speaker_diarization.py @@ -0,0 +1,553 @@ +# The MIT License (MIT) +# +# Copyright (c) 2021- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
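`AgglomerativeClustering.cluster` above builds a dendrogram with `scipy.cluster.hierarchy.linkage` and cuts it at a distance threshold. A minimal sketch of that core step on toy data follows; the threshold and the toy embeddings are illustrative choices, not tuned pipeline values:

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage

rng = np.random.default_rng(0)
# two tight groups of 2-D "embeddings"
emb = np.vstack([rng.normal([1.0, 0.0], 0.05, (5, 2)), rng.normal([0.0, 1.0], 0.05, (5, 2))])
emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # unit-normalize, as done for "centroid"/"median"/"ward"

dendrogram = linkage(emb, method="centroid", metric="euclidean")
clusters = fcluster(dendrogram, 0.7, criterion="distance") - 1  # 0-indexed cluster labels
print(len(np.unique(clusters)))  # 2 clusters for this toy data
```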
+
+"""Speaker diarization pipelines"""
+
+import functools
+import itertools
+import math
+import textwrap
+import warnings
+import numpy as np
+
+from typing import Callable, Optional, Text, Union, Mapping
+from pathlib import Path
+
+from pyannote_audio_utils.core import Annotation, SlidingWindow, SlidingWindowFeature
+from pyannote_audio_utils.pipeline.parameter import ParamDict, Uniform
+from pyannote_audio_utils.audio import Audio, Inference, Pipeline
+from pyannote_audio_utils.audio.core.io import AudioFile
+from pyannote_audio_utils.audio.pipelines.clustering import Clustering
+from pyannote_audio_utils.audio.pipelines.speaker_verification import ONNXWeSpeakerPretrainedSpeakerEmbedding
+from pyannote_audio_utils.audio.pipelines.utils import SpeakerDiarizationMixin
+
+AudioFile = Union[Text, Path, Mapping]
+PipelineModel = Union[Text, Mapping]
+
+def batchify(iterable, batch_size: int = 32, fillvalue=None):
+    """Batchify iterable"""
+    # batchify('ABCDEFG', 3) --> ('A', 'B', 'C') ('D', 'E', 'F') ('G', None, None)
+    args = [iter(iterable)] * batch_size
+    return itertools.zip_longest(*args, fillvalue=fillvalue)
+
+
+class SpeakerDiarization(SpeakerDiarizationMixin, Pipeline):
+    """Speaker diarization pipeline
+
+    Parameters
+    ----------
+    segmentation : Model, str, or dict, optional
+        Pretrained segmentation model. Defaults to "pyannote_audio_utils/segmentation@2022.07".
+        See pyannote_audio_utils.audio.pipelines.utils.get_model for supported format.
+    segmentation_step: float, optional
+        The segmentation model is applied on a window sliding over the whole audio file.
+        `segmentation_step` controls the step of this window, provided as a ratio of its
+        duration. Defaults to 0.1 (i.e. 90% overlap between two consecutive windows).
+    embedding : Model, str, or dict, optional
+        Pretrained embedding model. Defaults to "pyannote_audio_utils/embedding@2022.07".
+        See pyannote_audio_utils.audio.pipelines.utils.get_model for supported format.
+    embedding_exclude_overlap : bool, optional
+        Exclude overlapping speech regions when extracting embeddings.
+        Defaults (False) to use the whole speech.
+    clustering : str, optional
+        Clustering algorithm. See pyannote_audio_utils.audio.pipelines.clustering.Clustering
+        for available options. Defaults to "AgglomerativeClustering".
+    segmentation_batch_size : int, optional
+        Batch size used for speaker segmentation. Defaults to 1.
+    embedding_batch_size : int, optional
+        Batch size used for speaker embedding. Defaults to 1.
+    der_variant : dict, optional
+        Optimize for a variant of diarization error rate.
+        Defaults to {"collar": 0.0, "skip_overlap": False}. This is used in `get_metric`
+        when instantiating the metric: GreedyDiarizationErrorRate(**der_variant).
+ use_auth_token : str, optional + When loading private huggingface.co models, set `use_auth_token` + to True or to a string containing your hugginface.co authentication + token that can be obtained by running `huggingface-cli login` + + Usage + ----- + # perform (unconstrained) diarization + >>> diarization = pipeline("/path/to/audio.wav") + + # perform diarization, targetting exactly 4 speakers + >>> diarization = pipeline("/path/to/audio.wav", num_speakers=4) + + # perform diarization, with at least 2 speakers and at most 10 speakers + >>> diarization = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10) + + # perform diarization and get one representative embedding per speaker + >>> diarization, embeddings = pipeline("/path/to/audio.wav", return_embeddings=True) + >>> for s, speaker in enumerate(diarization.labels()): + ... # embeddings[s] is the embedding of speaker `speaker` + + Hyper-parameters + ---------------- + segmentation.threshold + segmentation.min_duration_off + clustering.??? + """ + + def __init__( + self, + segmentation: PipelineModel = "pyannote_audio_utils/segmentation@2022.07", + segmentation_step: float = 0.1, + embedding: PipelineModel = "speechbrain/spkrec-ecapa-voxceleb@5c0be3875fda05e81f3c004ed8c7c06be308de1e", + embedding_exclude_overlap: bool = False, + clustering: str = "AgglomerativeClustering", + embedding_batch_size: int = 1, + segmentation_batch_size: int = 1, + args = None, + seg_path = None, + emb_path = None, + der_variant: dict = None, + use_auth_token: Union[Text, None] = None, + ): + super().__init__() + + model = segmentation + self.segmentation_step = segmentation_step + self.embedding = embedding + self.embedding_batch_size = embedding_batch_size + self.embedding_exclude_overlap = embedding_exclude_overlap + self.klustering = clustering + self.der_variant = der_variant or {"collar": 0.0, "skip_overlap": False} + + segmentation_duration = 10.0 + + self._segmentation = Inference( + model, + duration=segmentation_duration, + step=self.segmentation_step * segmentation_duration, + skip_aggregation=True, + batch_size=segmentation_batch_size, + args=args, + seg_path=seg_path + ) + + self._frames: SlidingWindow = self._segmentation.example_output.frames + + self.segmentation = ParamDict( + min_duration_off=Uniform(0.0, 1.0), + ) + + self._embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding( + self.embedding, + args=args, + emb_path=emb_path + ) + self._audio = Audio(sample_rate=self._embedding.sample_rate, mono="downmix") + + metric = self._embedding.metric + Klustering = Clustering[clustering] + + self.clustering = Klustering.value(metric=metric) + + + def get_segmentations(self, file, hook=None) -> SlidingWindowFeature: + """Apply segmentation model + + Parameter + --------- + file : AudioFile + hook : Optional[Callable] + + Returns + ------- + segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature + """ + + if hook is not None: + hook = functools.partial(hook, "segmentation", None) + segmentations: SlidingWindowFeature = self._segmentation(file, hook=hook) + + return segmentations + + def get_embeddings( + self, + file, + binary_segmentations: SlidingWindowFeature, + exclude_overlap: bool = False, + hook: Optional[Callable] = None, + ): + """Extract embeddings for each (chunk, speaker) pair + + Parameters + ---------- + file : AudioFile + binary_segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature + Binarized segmentation. 
+ exclude_overlap : bool, optional + Exclude overlapping speech regions when extracting embeddings. + In case non-overlapping speech is too short, use the whole speech. + hook: Optional[Callable] + Called during embeddings after every batch to report the progress + + Returns + ------- + embeddings : (num_chunks, num_speakers, dimension) array + """ + + # when optimizing the hyper-parameters of this pipeline with frozen + # "segmentation.threshold", one can reuse the embeddings from the first trial, + # bringing a massive speed up to the optimization process (and hence allowing to use + # a larger search space). + + duration = binary_segmentations.sliding_window.duration + num_chunks, num_frames, num_speakers = binary_segmentations.data.shape + + if exclude_overlap: + + # minimum number of samples needed to extract an embedding + # (a lower number of samples would result in an error) + min_num_samples = self._embedding.min_num_samples + + # corresponding minimum number of frames + num_samples = duration * self._embedding.sample_rate + min_num_frames = math.ceil(num_frames * min_num_samples / num_samples) + + # zero-out frames with overlapping speech + clean_frames = 1.0 * ( + np.sum(binary_segmentations.data, axis=2, keepdims=True) < 2 + ) + clean_segmentations = SlidingWindowFeature( + binary_segmentations.data * clean_frames, + binary_segmentations.sliding_window, + ) + + else: + min_num_frames = -1 + clean_segmentations = SlidingWindowFeature( + binary_segmentations.data, binary_segmentations.sliding_window + ) + + def iter_waveform_and_mask(): + for (chunk, masks), (_, clean_masks) in zip(binary_segmentations, clean_segmentations): + # chunk: Segment(t, t + duration) + # masks: (num_frames, local_num_speakers) np.ndarray + + waveform, _ = self._audio.crop( + file, + chunk, + duration=duration, + mode="pad", + ) + # waveform: (1, num_samples) torch.Tensor + + # mask may contain NaN (in case of partial stitching) + masks = np.nan_to_num(masks, nan=0.0).astype(np.float32) + clean_masks = np.nan_to_num(clean_masks, nan=0.0).astype(np.float32) + + for mask, clean_mask in zip(masks.T, clean_masks.T): + # mask: (num_frames, ) np.ndarray + + if np.sum(clean_mask) > min_num_frames: + used_mask = clean_mask + else: + used_mask = mask + + # yield waveform[None], torch.from_numpy(used_mask)[None] + yield waveform[None], used_mask[None] + + # w: (1, 1, num_samples) torch.Tensor + # m: (1, num_frames) torch.Tensor + + batches = batchify( + iter_waveform_and_mask(), + batch_size=self.embedding_batch_size, + fillvalue=(None, None), + ) + + + batch_count = math.ceil(num_chunks * num_speakers / self.embedding_batch_size) + + embedding_batches = [] + + if hook is not None: + hook("embeddings", None, total=batch_count, completed=0) + + for i, batch in enumerate(batches, 1): + waveforms, masks = zip(*filter(lambda b: b[0] is not None, batch)) + + waveform_batch = np.vstack(waveforms) + # (batch_size, 1, num_samples) torch.Tensor + + mask_batch = np.vstack(masks) + # (batch_size, num_frames) torch.Tensor + + embedding_batch: np.ndarray = self._embedding( + waveform_batch, masks=mask_batch + ) + # (batch_size, dimension) np.ndarray + + embedding_batches.append(embedding_batch) + + if hook is not None: + hook("embeddings", embedding_batch, total=batch_count, completed=i) + + embedding_batches = np.vstack(embedding_batches) + embeddings = embedding_batches.reshape([num_chunks, -1 , embedding_batches.shape[-1]]) + + return embeddings + + def reconstruct( + self, + segmentations: SlidingWindowFeature, + 
hard_clusters: np.ndarray, + count: SlidingWindowFeature, + ) -> SlidingWindowFeature: + """Build final discrete diarization out of clustered segmentation + + Parameters + ---------- + segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature + Raw speaker segmentation. + hard_clusters : (num_chunks, num_speakers) array + Output of clustering step. + count : (total_num_frames, 1) SlidingWindowFeature + Instantaneous number of active speakers. + + Returns + ------- + discrete_diarization : SlidingWindowFeature + Discrete (0s and 1s) diarization. + """ + + num_chunks, num_frames, local_num_speakers = segmentations.data.shape + + num_clusters = np.max(hard_clusters) + 1 + clustered_segmentations = np.NAN * np.zeros( + (num_chunks, num_frames, num_clusters) + ) + + for c, (cluster, (chunk, segmentation)) in enumerate( + zip(hard_clusters, segmentations) + ): + # cluster is (local_num_speakers, )-shaped + # segmentation is (num_frames, local_num_speakers)-shaped + for k in np.unique(cluster): + if k == -2: + continue + + # TODO: can we do better than this max here? + clustered_segmentations[c, :, k] = np.max( + segmentation[:, cluster == k], axis=1 + ) + + clustered_segmentations = SlidingWindowFeature( + clustered_segmentations, segmentations.sliding_window + ) + + return self.to_diarization(clustered_segmentations, count) + + def apply( + self, + file: AudioFile, + num_speakers: int = None, + min_speakers: int = None, + max_speakers: int = None, + return_embeddings: bool = False, + hook: Optional[Callable] = None, + ) -> Annotation: + """Apply speaker diarization + + Parameters + ---------- + file : AudioFile + Processed file. + num_speakers : int, optional + Number of speakers, when known. + min_speakers : int, optional + Minimum number of speakers. Has no effect when `num_speakers` is provided. + max_speakers : int, optional + Maximum number of speakers. Has no effect when `num_speakers` is provided. + return_embeddings : bool, optional + Return representative speaker embeddings. + hook : callable, optional + Callback called after each major steps of the pipeline as follows: + hook(step_name, # human-readable name of current step + step_artefact, # artifact generated by current step + file=file) # file being processed + Time-consuming steps call `hook` multiple times with the same `step_name` + and additional `completed` and `total` keyword arguments usable to track + progress of current step. + + Returns + ------- + diarization : Annotation + Speaker diarization + embeddings : np.array, optional + Representative speaker embeddings such that `embeddings[i]` is the + speaker embedding for i-th speaker in diarization.labels(). + Only returned when `return_embeddings` is True. + """ + + # setup hook (e.g. 
for debugging purposes) + hook = self.setup_hook(file, hook=hook) + + num_speakers, min_speakers, max_speakers = self.set_num_speakers( + num_speakers=num_speakers, + min_speakers=min_speakers, + max_speakers=max_speakers, + ) + + segmentations = self.get_segmentations(file, hook=hook) + hook("segmentation", segmentations) + # shape: (num_chunks, num_frames, local_num_speakers) + + # binarize segmentation + + binarized_segmentations = segmentations + # estimate frame-level number of instantaneous speakers + count = self.speaker_count( + binarized_segmentations, + frames=self._frames, + warm_up=(0.0, 0.0), + ) + hook("speaker_counting", count) + # shape: (num_frames, 1) + # dtype: int + + # exit early when no speaker is ever active + if np.nanmax(count.data) == 0.0: + diarization = Annotation(uri=file["uri"]) + if return_embeddings: + return diarization, np.zeros((0, self._embedding.dimension)) + + return diarization + + embeddings = self.get_embeddings( + file, + binarized_segmentations, + exclude_overlap=self.embedding_exclude_overlap, + hook=hook, + ) + + hook("embeddings", embeddings) + # shape: (num_chunks, local_num_speakers, dimension) + + hard_clusters, _, centroids = self.clustering( + embeddings=embeddings, + segmentations=binarized_segmentations, + num_clusters=num_speakers, + min_clusters=min_speakers, + max_clusters=max_speakers, + file=file, # <== for oracle clustering + frames=self._frames, # <== for oracle clustering + ) + # hard_clusters: (num_chunks, num_speakers) + # centroids: (num_speakers, dimension) + + # number of detected clusters is the number of different speakers + num_different_speakers = np.max(hard_clusters) + 1 + + # detected number of speakers can still be out of bounds + # (specifically, lower than `min_speakers`), since there could be too few embeddings + # to make enough clusters with a given minimum cluster size. + if num_different_speakers < min_speakers or num_different_speakers > max_speakers: + warnings.warn(textwrap.dedent( + f""" + The detected number of speakers ({num_different_speakers}) is outside + the given bounds [{min_speakers}, {max_speakers}]. This can happen if the + given audio file is too short to contain {min_speakers} or more speakers. + Try to lower the desired minimal number of speakers. + """ + )) + + # during counting, we could possibly overcount the number of instantaneous + # speakers due to segmentation errors, so we cap the maximum instantaneous number + # of speakers by the `max_speakers` value + count.data = np.minimum(count.data, max_speakers).astype(np.int8) + + # reconstruct discrete diarization from raw hard clusters + + # keep track of inactive speakers + inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0 + # shape: (num_chunks, num_speakers) + + hard_clusters[inactive_speakers] = -2 + discrete_diarization = self.reconstruct( + segmentations, + hard_clusters, + count, + ) + hook("discrete_diarization", discrete_diarization) + + # convert to continuous diarization + diarization = self.to_annotation( + discrete_diarization, + min_duration_on=0.0, + min_duration_off=self.segmentation.min_duration_off, + ) + diarization.uri = file["uri"] + + # at this point, `diarization` speaker labels are integers + # from 0 to `num_speakers - 1`, aligned with `centroids` rows. 
+ + if "annotation" in file and file["annotation"]: + # when reference is available, use it to map hypothesized speakers + # to reference speakers (this makes later error analysis easier + # but does not modify the actual output of the diarization pipeline) + _, mapping = self.optimal_mapping( + file["annotation"], diarization, return_mapping=True + ) + + # in case there are more speakers in the hypothesis than in + # the reference, those extra speakers are missing from `mapping`. + # we add them back here + mapping = {key: mapping.get(key, key) for key in diarization.labels()} + + else: + # when reference is not available, rename hypothesized speakers + # to human-readable SPEAKER_00, SPEAKER_01, ... + mapping = { + label: expected_label + for label, expected_label in zip(diarization.labels(), self.classes()) + } + + diarization = diarization.rename_labels(mapping=mapping) + # at this point, `diarization` speaker labels are strings (or mix of + # strings and integers when reference is available and some hypothesis + # speakers are not present in the reference) + if not return_embeddings: + return diarization + + # this can happen when we use OracleClustering + if centroids is None: + return diarization, None + + # The number of centroids may be smaller than the number of speakers + # in the annotation. This can happen if the number of active speakers + # obtained from `speaker_count` for some frames is larger than the number + # of clusters obtained from `clustering`. In this case, we append zero embeddings + # for extra speakers + if len(diarization.labels()) > centroids.shape[0]: + centroids = np.pad(centroids, ((0, len(diarization.labels()) - centroids.shape[0]), (0, 0))) + + # re-order centroids so that they match + # the order given by diarization.labels() + inverse_mapping = {label: index for index, label in mapping.items()} + centroids = centroids[ + [inverse_mapping[label] for label in diarization.labels()] + ] + + return diarization, centroids diff --git a/ailia-models/code/pyannote_audio_utils/audio/pipelines/speaker_verification.py b/ailia-models/code/pyannote_audio_utils/audio/pipelines/speaker_verification.py new file mode 100644 index 0000000000000000000000000000000000000000..35bbcfb2602e6edadaa6a0d365a5f9c4effa8e88 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/audio/pipelines/speaker_verification.py @@ -0,0 +1,249 @@ +# MIT License +# +# Copyright (c) 2021 CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
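The clustering step called from `apply` above returns `hard_clusters` and `centroids`; the scoring behind that assignment (implemented in `assign_embeddings`) boils down to `2 - cosine distance` against each centroid followed by an argmax. A small self-contained sketch of that scoring step, with arbitrary toy array sizes:

```python
import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
embeddings = rng.standard_normal((2, 3, 4))  # (num_chunks, num_speakers, dimension)
centroids = rng.standard_normal((2, 4))      # (num_clusters, dimension)

num_chunks, num_speakers, dimension = embeddings.shape
distances = cdist(embeddings.reshape(-1, dimension), centroids, metric="cosine")
soft_clusters = 2 - distances.reshape(num_chunks, num_speakers, -1)  # higher means more similar
hard_clusters = np.argmax(soft_clusters, axis=2)                     # (num_chunks, num_speakers)
print(hard_clusters)
```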
+ + +from functools import cached_property +from typing import Optional, Text, Union, Mapping + +import numpy as np +import ailia + +from pyannote_audio_utils.audio.pipelines.utils.kaldifeat import compute_fbank_feats +from pyannote_audio_utils.audio.core.inference import BaseInference + +PipelineModel = Union[Text, Mapping] + +class ONNXWeSpeakerPretrainedSpeakerEmbedding(BaseInference): + """Pretrained WeSpeaker speaker embedding + + Parameters + ---------- + embedding : str + Path to WeSpeaker pretrained speaker embedding + device : torch.device, optional + Device + + Usage + ----- + >>> get_embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding("hbredin/wespeaker-voxceleb-resnet34-LM") + >>> assert waveforms.ndim == 3 + >>> batch_size, num_channels, num_samples = waveforms.shape + >>> assert num_channels == 1 + >>> embeddings = get_embedding(waveforms) + >>> assert embeddings.ndim == 2 + >>> assert embeddings.shape[0] == batch_size + + >>> assert binary_masks.ndim == 1 + >>> assert binary_masks.shape[0] == batch_size + >>> embeddings = get_embedding(waveforms, masks=binary_masks) + """ + + def __init__( + self, + embedding: Text = "hbredin/wespeaker-voxceleb-resnet34-LM", + # device: Optional[torch.device] = None, + args = None, + emb_path = None + ): + # if not ONNX_IS_AVAILABLE: + # raise ImportError( + # f"'onnxruntime' must be installed to use '{embedding}' embeddings." + # ) + + super().__init__() + + # if not Path(embedding).exists(): + # try: + # embedding = hf_hub_download( + # repo_id=embedding, + # filename="speaker-embedding.onnx", + # ) + # except RepositoryNotFoundError: + # raise ValueError( + # f"Could not find '{embedding}' on huggingface.co nor on local disk." + # ) + + self.embedding = embedding + + if args.onnx: + import onnxruntime as ort + #print("use onnx runtime") + providers = ["CPUExecutionProvider", ("CUDAExecutionProvider",{"cudnn_conv_algo_search": "DEFAULT"})] + + sess_options = ort.SessionOptions() + sess_options.inter_op_num_threads = 1 + sess_options.intra_op_num_threads = 1 + self.session_ = ort.InferenceSession( + embedding, sess_options=sess_options, providers=providers + ) + else: + #print("use ailia") + + self.session_ = ailia.Net(emb_path, weight=embedding, env_id=args.env_id) + + self.args = args + + @cached_property + def sample_rate(self) -> int: + return 16000 + + @cached_property + def dimension(self) -> int: + dummy_waveforms = np.random.rand(1, 1, 16000) + features = self.compute_fbank(dummy_waveforms) + + if self.args.onnx: + embeddings = self.session_.run(output_names=["embs"], input_feed={"feats": features} + )[0] + else: + embeddings = self.session_.predict([features])[0] + + _, dimension = embeddings.shape + return dimension + + @cached_property + def metric(self) -> str: + return "cosine" + + @cached_property + def min_num_samples(self) -> int: + lower, upper = 2, round(0.5 * self.sample_rate) + middle = (lower + upper) // 2 + while lower + 1 < upper: + try: + features = self.compute_fbank(np.random.randn(1, 1, middle)) + + except AssertionError: + lower = middle + middle = (lower + upper) // 2 + continue + + if self.args.onnx: + embeddings = self.session_.run(output_names=["embs"], input_feed={"feats": features})[0] + else: + embeddings = self.session_.predict([features])[0] + + if np.any(np.isnan(embeddings)): + lower = middle + else: + upper = middle + middle = (lower + upper) // 2 + + return upper + + @cached_property + def min_num_frames(self) -> int: + return self.compute_fbank(np.random.randn(1, 1, self.min_num_samples)).shape[1] + 
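`min_num_samples` above finds, by bisection, the shortest waveform for which the embedding model still returns finite values. A generic sketch of that search, assuming only a monotonic validity predicate (names, bounds, and the toy predicate are illustrative):

```python
def find_min_valid_length(is_valid, lower: int = 2, upper: int = 8000) -> int:
    # invariant assumed: is_valid(lower) is False, is_valid(upper) is True
    while lower + 1 < upper:
        middle = (lower + upper) // 2
        if is_valid(middle):
            upper = middle
        else:
            lower = middle
    return upper

# toy predicate: lengths of at least 571 samples "work"
assert find_min_valid_length(lambda n: n >= 571) == 571
```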
+    def compute_fbank(
+        self,
+        waveforms: np.ndarray,
+        num_mel_bins: int = 80,
+        frame_length: int = 25,
+        frame_shift: int = 10,
+        dither: float = 0.0,
+    ) -> np.ndarray:
+        """Extract fbank features
+
+        Parameters
+        ----------
+        waveforms : (batch_size, num_channels, num_samples)
+
+        Returns
+        -------
+        fbank : (batch_size, num_frames, num_mel_bins)
+
+        Source: https://github.com/wenet-e2e/wespeaker/blob/45941e7cba2c3ea99e232d02bedf617fc71b0dad/wespeaker/bin/infer_onnx.py#L30C1-L50
+        """
+
+        waveforms = waveforms * (1 << 15)
+
+        ### results deviate slightly here ###
+        features_numpy = np.stack([compute_fbank_feats(
+            waveform=waveform[0],
+            num_mel_bins=num_mel_bins,
+            frame_length=frame_length,
+            frame_shift=frame_shift,
+            dither=dither,
+            sample_frequency=self.sample_rate,
+            window_type="hamming",
+            use_energy=False,
+        ) for waveform in waveforms])
+        ### results deviate slightly here ###
+
+        features = features_numpy.astype(np.float32)
+
+        return features - np.mean(features, axis=1, keepdims=True)
+
+    def __call__(
+        self, waveforms: np.ndarray, masks: Optional[np.ndarray] = None
+    ) -> np.ndarray:
+        """
+
+        Parameters
+        ----------
+        waveforms : (batch_size, num_channels, num_samples)
+            Only num_channels == 1 is supported.
+        masks : (batch_size, num_samples), optional
+
+        Returns
+        -------
+        embeddings : (batch_size, dimension)
+
+        """
+
+        batch_size, num_channels, num_samples = waveforms.shape
+        assert num_channels == 1
+
+        features = self.compute_fbank(waveforms)
+        _, num_frames, _ = features.shape
+
+        batch_size_masks, _ = masks.shape
+        assert batch_size == batch_size_masks
+
+        def interpolate_numpy(input_array, size):
+            # nearest-neighbor resampling of each mask to `size` frames
+            output_array = np.zeros((input_array.shape[0], size))
+
+            for i in range(output_array.shape[0]):
+                for j in range(output_array.shape[1]):
+                    ii = int(np.floor(i * input_array.shape[0] / output_array.shape[0]))
+                    jj = int(np.floor(j * input_array.shape[1] / output_array.shape[1]))
+                    output_array[i, j] = input_array[ii, jj]
+            return output_array
+
+        imasks = interpolate_numpy(masks, size=num_frames)
+        imasks = imasks > 0.5
+
+        embeddings = np.NAN * np.zeros((batch_size, self.dimension))
+
+        for f, (feature, imask) in enumerate(zip(features, imasks)):
+            masked_feature = feature[imask]
+            if masked_feature.shape[0] < self.min_num_frames:
+                continue
+
+            if self.args.onnx:
+                embeddings[f] = self.session_.run(output_names=["embs"], input_feed={"feats": masked_feature[None]},)[0][0]
+            else:
+                embeddings[f] = self.session_.predict([masked_feature[None]])[0][0]
+
+        return embeddings
+
diff --git a/ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/__init__.py b/ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5c2456658e58591b90ae17f740659b7d60baade
--- /dev/null
+++ b/ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/__init__.py
@@ -0,0 +1,37 @@
+# MIT License
+#
+# Copyright (c) 2022- CNRS
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from .diarization import SpeakerDiarizationMixin +# from .getter import ( + # PipelineAugmentation, + # PipelineInference, + # PipelineModel, + # get_augmentation, + # get_devices, + # get_inference, + # get_model, +# ) +# from .oracle import oracle_segmentation + +__all__ = [ + "SpeakerDiarizationMixin", +] diff --git a/ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/diarization.py b/ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/diarization.py new file mode 100644 index 0000000000000000000000000000000000000000..565da036483fd799ee82a2ce930d2b6fc8231d6b --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/diarization.py @@ -0,0 +1,248 @@ +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from typing import Dict, Mapping, Optional, Tuple, Union + +import numpy as np +from pyannote_audio_utils.core import Annotation, SlidingWindow, SlidingWindowFeature +from pyannote_audio_utils.core.utils.types import Label +from pyannote_audio_utils.metrics.diarization import DiarizationErrorRate + +from pyannote_audio_utils.audio.core.inference import Inference +from pyannote_audio_utils.audio.utils.signal import Binarize + + +# TODO: move to dedicated module +class SpeakerDiarizationMixin: + """Defines a bunch of methods common to speaker diarization pipelines""" + + @staticmethod + def set_num_speakers( + num_speakers: Optional[int] = None, + min_speakers: Optional[int] = None, + max_speakers: Optional[int] = None, + ): + """Validate number of speakers + + Parameters + ---------- + num_speakers : int, optional + Number of speakers. + min_speakers : int, optional + Minimum number of speakers. + max_speakers : int, optional + Maximum number of speakers. 
+ + Returns + ------- + num_speakers : int or None + min_speakers : int + max_speakers : int or np.inf + """ + + # override {min|max}_num_speakers by num_speakers when available + min_speakers = num_speakers or min_speakers or 1 + max_speakers = num_speakers or max_speakers or np.inf + + if min_speakers > max_speakers: + raise ValueError( + f"min_speakers must be smaller than (or equal to) max_speakers " + f"(here: min_speakers={min_speakers:g} and max_speakers={max_speakers:g})." + ) + if min_speakers == max_speakers: + num_speakers = min_speakers + + return num_speakers, min_speakers, max_speakers + + @staticmethod + def optimal_mapping( + reference: Union[Mapping, Annotation], + hypothesis: Annotation, + return_mapping: bool = False, + ) -> Union[Annotation, Tuple[Annotation, Dict[Label, Label]]]: + """Find the optimal bijective mapping between reference and hypothesis labels + + Parameters + ---------- + reference : Annotation or Mapping + Reference annotation. Can be an Annotation instance or + a mapping with an "annotation" key. + hypothesis : Annotation + Hypothesized annotation. + return_mapping : bool, optional + Return the label mapping itself along with the mapped annotation. Defaults to False. + + Returns + ------- + mapped : Annotation + Hypothesis mapped to reference speakers. + mapping : dict, optional + Mapping between hypothesis (key) and reference (value) labels + Only returned if `return_mapping` is True. + """ + + if isinstance(reference, Mapping): + reference = reference["annotation"] + annotated = reference["annotated"] if "annotated" in reference else None + else: + annotated = None + + mapping = DiarizationErrorRate().optimal_mapping( + reference, hypothesis, uem=annotated + ) + mapped_hypothesis = hypothesis.rename_labels(mapping=mapping) + + if return_mapping: + return mapped_hypothesis, mapping + + else: + return mapped_hypothesis + + # TODO: get rid of warm-up parameter (trimming should be applied before calling speaker_count) + @staticmethod + def speaker_count( + binarized_segmentations: SlidingWindowFeature, + frames: SlidingWindow, + warm_up: Tuple[float, float] = (0.1, 0.1), + ) -> SlidingWindowFeature: + """Estimate frame-level number of instantaneous speakers + + Parameters + ---------- + binarized_segmentations : SlidingWindowFeature + (num_chunks, num_frames, num_classes)-shaped binarized scores. + warm_up : (float, float) tuple, optional + Left/right warm up ratio of chunk duration. + Defaults to (0.1, 0.1), i.e. 10% on both sides. + frames : SlidingWindow + Frames resolution. Defaults to estimate it automatically based on + `segmentations` shape and chunk size. Providing the exact frame + resolution (when known) leads to better temporal precision. + + Returns + ------- + count : SlidingWindowFeature + (num_frames, 1)-shaped instantaneous speaker count + """ + + trimmed = Inference.trim(binarized_segmentations, warm_up=warm_up) + + count = Inference.aggregate( + np.sum(trimmed, axis=-1, keepdims=True), + frames, + hamming=False, + missing=0.0, + skip_average=False, + ) + + count.data = np.rint(count.data).astype(np.uint8) + + return count + + @staticmethod + def to_annotation( + discrete_diarization: SlidingWindowFeature, + min_duration_on: float = 0.0, + min_duration_off: float = 0.0, + ) -> Annotation: + """ + + Parameters + ---------- + discrete_diarization : SlidingWindowFeature + (num_frames, num_speakers)-shaped discrete diarization + min_duration_on : float, optional + Defaults to 0. + min_duration_off : float, optional + Defaults to 0. 
+ + Returns + ------- + continuous_diarization : Annotation + Continuous diarization, with speaker labels as integers, + corresponding to the speaker indices in the discrete diarization. + """ + + binarize = Binarize( + onset=0.5, + offset=0.5, + min_duration_on=min_duration_on, + min_duration_off=min_duration_off, + ) + + return binarize(discrete_diarization).rename_tracks(generator="string") + + @staticmethod + def to_diarization( + segmentations: SlidingWindowFeature, + count: SlidingWindowFeature, + ) -> SlidingWindowFeature: + """Build diarization out of preprocessed segmentation and precomputed speaker count + + Parameters + ---------- + segmentations : SlidingWindowFeature + (num_chunks, num_frames, num_speakers)-shaped segmentations + count : SlidingWindow_feature + (num_frames, 1)-shaped speaker count + + Returns + ------- + discrete_diarization : SlidingWindowFeature + Discrete (0s and 1s) diarization. + """ + + # TODO: investigate alternative aggregation + activations = Inference.aggregate( + segmentations, + count.sliding_window, + hamming=False, + missing=0.0, + skip_average=True, + ) + # shape is (num_frames, num_speakers) + + _, num_speakers = activations.data.shape + max_speakers_per_frame = np.max(count.data) + if num_speakers < max_speakers_per_frame: + activations.data = np.pad( + activations.data, ((0, 0), (0, max_speakers_per_frame - num_speakers)) + ) + + extent = activations.extent & count.extent + activations = activations.crop(extent, return_data=False) + count = count.crop(extent, return_data=False) + + sorted_speakers = np.argsort(-activations, axis=-1) + binary = np.zeros_like(activations.data) + + for t, ((_, c), speakers) in enumerate(zip(count, sorted_speakers)): + for i in range(c.item()): + binary[t, speakers[i]] = 1.0 + + return SlidingWindowFeature(binary, activations.sliding_window) + + def classes(self): + speaker = 0 + while True: + yield f"SPEAKER_{speaker:02d}" + speaker += 1 diff --git a/ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/kaldifeat.py b/ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/kaldifeat.py new file mode 100644 index 0000000000000000000000000000000000000000..7adee97eafa999433b064f4929567cc9c7b247e2 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/kaldifeat.py @@ -0,0 +1,291 @@ +# https://github.com/yuyq96/kaldifeat + +import numpy as np + + +# ---------- feature-window ---------- + +def sliding_window(x, window_size, window_shift): + shape = x.shape[:-1] + (x.shape[-1] - window_size + 1, window_size) + strides = x.strides + (x.strides[-1],) + return np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)[::window_shift] + + +def func_num_frames(num_samples, window_size, window_shift, snip_edges): + if snip_edges: + if num_samples < window_size: + return 0 + else: + return 1 + ((num_samples - window_size) // window_shift) + else: + return (num_samples + (window_shift // 2)) // window_shift + + +def func_dither(waveform, dither_value): + if dither_value == 0.0: + return waveform + waveform += np.random.normal(size=waveform.shape).astype(waveform.dtype) * dither_value + return waveform + + +def func_remove_dc_offset(waveform): + return waveform - np.mean(waveform) + + +def func_log_energy(waveform): + return np.log(np.dot(waveform, waveform).clip(min=np.finfo(waveform.dtype).eps)) + + +def func_preemphasis(waveform, preemph_coeff): + if preemph_coeff == 0.0: + return waveform + assert 0 < preemph_coeff <= 1 + waveform[1:] -= preemph_coeff * waveform[:-1] + waveform[0] 
-= preemph_coeff * waveform[0] + return waveform + + +def sine(M): + if M < 1: + return np.array([]) + if M == 1: + return np.ones(1, float) + n = np.arange(0, M) + return np.sin(np.pi * n / (M - 1)) + + +def povey(M): + if M < 1: + return np.array([]) + if M == 1: + return np.ones(1, float) + n = np.arange(0, M) + return (0.5 - 0.5 * np.cos(2.0 * np.pi * n / (M - 1))) ** 0.85 + + +def feature_window_function(window_type, window_size, blackman_coeff): + assert window_size > 0 + if window_type == 'hanning': + return np.hanning(window_size) + elif window_type == 'sine': + return sine(window_size) + elif window_type == 'hamming': + return np.hamming(window_size) + elif window_type == 'povey': + return povey(window_size) + elif window_type == 'rectangular': + return np.ones(window_size) + elif window_type == 'blackman': + window_func = np.blackman(window_size) + if blackman_coeff == 0.42: + return window_func + else: + return window_func - 0.42 + blackman_coeff + else: + raise ValueError('Invalid window type {}'.format(window_type)) + + +def process_window(window, dither, remove_dc_offset, preemphasis_coefficient, window_function, raw_energy): + if dither != 0.0: + window = func_dither(window, dither) + if remove_dc_offset: + window = func_remove_dc_offset(window) + if raw_energy: + log_energy = func_log_energy(window) + if preemphasis_coefficient != 0.0: + window = func_preemphasis(window, preemphasis_coefficient) + window *= window_function + if not raw_energy: + log_energy = func_log_energy(window) + return window, log_energy + + +def extract_window(waveform, blackman_coeff, dither, window_size, window_shift, + preemphasis_coefficient, raw_energy, remove_dc_offset, + snip_edges, window_type, dtype): + num_samples = len(waveform) + num_frames = func_num_frames(num_samples, window_size, window_shift, snip_edges) + num_samples_ = (num_frames - 1) * window_shift + window_size + if snip_edges: + waveform = waveform[:num_samples_] + else: + offset = window_shift // 2 - window_size // 2 + waveform = np.concatenate([ + waveform[-offset - 1::-1], + waveform, + waveform[:-(offset + num_samples_ - num_samples + 1):-1] + ]) + frames = sliding_window(waveform, window_size=window_size, window_shift=window_shift) + frames = frames.astype(dtype) + log_enery = np.empty(frames.shape[0], dtype=dtype) + for i in range(frames.shape[0]): + frames[i], log_enery[i] = process_window( + window=frames[i], + dither=dither, + remove_dc_offset=remove_dc_offset, + preemphasis_coefficient=preemphasis_coefficient, + window_function=feature_window_function( + window_type=window_type, + window_size=window_size, + blackman_coeff=blackman_coeff + ).astype(dtype), + raw_energy=raw_energy + ) + return frames, log_enery + + +# ---------- feature-window ---------- + + +# ---------- feature-functions ---------- + +def compute_spectrum(frames, n): + complex_spec = np.fft.rfft(frames, n) + return np.absolute(complex_spec) + + +def compute_power_spectrum(frames, n): + return np.square(compute_spectrum(frames, n)) + + +# ---------- feature-functions ---------- + + +# ---------- mel-computations ---------- + + +def mel_scale(freq): + return 1127.0 * np.log(1.0 + freq / 700.0) + + +def compute_mel_banks(num_bins, sample_frequency, low_freq, high_freq, n): + """ Compute Mel banks. 
+ + :param num_bins: Number of triangular mel-frequency bins + :param sample_frequency: Waveform data sample frequency + :param low_freq: Low cutoff frequency for mel bins + :param high_freq: High cutoff frequency for mel bins (if <= 0, offset from Nyquist) + :param n: Window size + :return: Mel banks. + """ + assert num_bins >= 3, 'Must have at least 3 mel bins' + num_fft_bins = n // 2 + + nyquist = 0.5 * sample_frequency + if high_freq <= 0: + high_freq = nyquist + high_freq + assert 0 <= low_freq < high_freq <= nyquist + + fft_bin_width = sample_frequency / n + + mel_low_freq = mel_scale(low_freq) + mel_high_freq = mel_scale(high_freq) + mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1) + + mel_banks = np.zeros([num_bins, num_fft_bins + 1]) + for i in range(num_bins): + left_mel = mel_low_freq + mel_freq_delta * i + center_mel = left_mel + mel_freq_delta + right_mel = center_mel + mel_freq_delta + for j in range(num_fft_bins): + mel = mel_scale(fft_bin_width * j) + if left_mel < mel < right_mel: + if mel <= center_mel: + mel_banks[i, j] = (mel - left_mel) / (center_mel - left_mel) + else: + mel_banks[i, j] = (right_mel - mel) / (right_mel - center_mel) + return mel_banks + + +# ---------- mel-computations ---------- + + +# ---------- compute-fbank-feats ---------- + +def compute_fbank_feats( + waveform, + blackman_coeff=0.42, + dither=1.0, + energy_floor=0.0, + frame_length=25, + frame_shift=10, + high_freq=0, + low_freq=20, + num_mel_bins=23, + preemphasis_coefficient=0.97, + raw_energy=True, + remove_dc_offset=True, + round_to_power_of_two=True, + sample_frequency=16000, + snip_edges=True, + use_energy=False, + use_log_fbank=True, + use_power=True, + window_type='povey', + dtype=np.float32): + """ Compute (log) Mel filter bank energies + + :param waveform: Input waveform. + :param blackman_coeff: Constant coefficient for generalized Blackman window. (float, default = 0.42) + :param dither: Dithering constant (0.0 means no dither). If you turn this off, you should set the --energy-floor option, e.g. to 1.0 or 0.1 (float, default = 1) + :param energy_floor: Floor on energy (absolute, not relative) in FBANK computation. Only makes a difference if --use-energy=true; only necessary if --dither=0.0. Suggested values: 0.1 or 1.0 (float, default = 0) + :param frame_length: Frame length in milliseconds (float, default = 25) + :param frame_shift: Frame shift in milliseconds (float, default = 10) + :param high_freq: High cutoff frequency for mel bins (if <= 0, offset from Nyquist) (float, default = 0) + :param low_freq: Low cutoff frequency for mel bins (float, default = 20) + :param num_mel_bins: Number of triangular mel-frequency bins (int, default = 23) + :param preemphasis_coefficient: Coefficient for use in signal preemphasis (float, default = 0.97) + :param raw_energy: If true, compute energy before preemphasis and windowing (bool, default = true) + :param remove_dc_offset: Subtract mean from waveform on each frame (bool, default = true) + :param round_to_power_of_two: If true, round window size to power of two by zero-padding input to FFT. (bool, default = true) + :param sample_frequency: Waveform data sample frequency (must match the waveform file, if specified there) (float, default = 16000) + :param snip_edges: If true, end effects will be handled by outputting only frames that completely fit in the file, and the number of frames depends on the frame-length. If false, the number of frames depends only on the frame-shift, and we reflect the data at the ends. 
(bool, default = true) + :param use_energy: Add an extra energy output. (bool, default = false) + :param use_log_fbank: If true, produce log-filterbank, else produce linear. (bool, default = true) + :param use_power: If true, use power, else use magnitude. (bool, default = true) + :param window_type: Type of window ("hamming"|"hanning"|"povey"|"rectangular"|"sine"|"blackmann") (string, default = "povey") + :param dtype: Type of array (np.float32|np.float64) (dtype or string, default=np.float32) + :return: (Log) Mel filter bank energies. + """ + window_size = int(frame_length * sample_frequency * 0.001) + window_shift = int(frame_shift * sample_frequency * 0.001) + frames, log_energy = extract_window( + waveform=waveform, + blackman_coeff=blackman_coeff, + dither=dither, + window_size=window_size, + window_shift=window_shift, + preemphasis_coefficient=preemphasis_coefficient, + raw_energy=raw_energy, + remove_dc_offset=remove_dc_offset, + snip_edges=snip_edges, + window_type=window_type, + dtype=dtype + ) + if round_to_power_of_two: + n = 1 + while n < window_size: + n *= 2 + else: + n = window_size + if use_power: + spectrum = compute_power_spectrum(frames, n) + else: + spectrum = compute_spectrum(frames, n) + mel_banks = compute_mel_banks( + num_bins=num_mel_bins, + sample_frequency=sample_frequency, + low_freq=low_freq, + high_freq=high_freq, + n=n + ).astype(dtype) + feat = np.dot(spectrum, mel_banks.T) + if use_log_fbank: + feat = np.log(feat.clip(min=np.finfo(dtype).eps)) + if use_energy: + if energy_floor > 0.0: + log_energy.clip(min=np.math.log(energy_floor)) + return feat, log_energy + return feat + +# ---------- compute-fbank-feats ---------- diff --git a/ailia-models/code/pyannote_audio_utils/audio/utils/multi_task.py b/ailia-models/code/pyannote_audio_utils/audio/utils/multi_task.py new file mode 100644 index 0000000000000000000000000000000000000000..e3a98e574c27c840572a64dbbbcb222df2b9e520 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/audio/utils/multi_task.py @@ -0,0 +1,59 @@ +# MIT License +# +# Copyright (c) 2023- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
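+
+# A minimal usage sketch for `compute_fbank_feats` from the kaldifeat module
+# above (pyannote_audio_utils/audio/pipelines/utils/kaldifeat.py); the
+# waveform below is a hypothetical 1-second, 16 kHz mono signal:
+#
+#     import numpy as np
+#     from pyannote_audio_utils.audio.pipelines.utils.kaldifeat import compute_fbank_feats
+#
+#     waveform = np.random.randn(16000).astype(np.float32)  # hypothetical input
+#     feats = compute_fbank_feats(waveform, sample_frequency=16000, num_mel_bins=23)
+#     # with the default 25 ms / 10 ms framing, `feats` is a
+#     # (num_frames, 23)-shaped array of log mel filter bank energies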
+ + +from typing import Any, Callable, Tuple, Union + +from pyannote_audio_utils.audio.core.task import Specifications + + +def map_with_specifications( + specifications: Union[Specifications, Tuple[Specifications]], + func: Callable, + *iterables, +) -> Union[Any, Tuple[Any]]: + """Compute the function using arguments from each of the iterables + + Returns a tuple if provided `specifications` is a tuple, + otherwise returns the function return value. + + Parameters + ---------- + specifications : (tuple of) Specifications + Specifications or tuple of specifications + func : callable + Function called for each specification with + `func(*iterables[i], specifications=specifications[i])` + *iterables : + List of iterables with same length as `specifications`. + + Returns + ------- + output : (tuple of) `func` return value(s) + """ + + if isinstance(specifications, Specifications): + return func(*iterables, specifications=specifications) + + return tuple( + func(*i, specifications=s) for s, *i in zip(specifications, *iterables) + ) diff --git a/ailia-models/code/pyannote_audio_utils/audio/utils/powerset.py b/ailia-models/code/pyannote_audio_utils/audio/utils/powerset.py new file mode 100644 index 0000000000000000000000000000000000000000..bfd2ef5ee04f9240b23057033687b6c0e85e058a --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/audio/utils/powerset.py @@ -0,0 +1,125 @@ +# MIT License +# +# Copyright (c) 2023- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - https://herve.niderb.fr +# Alexis PLAQUET + +from functools import cached_property +from itertools import combinations + +import scipy.special +import numpy as np + +class Powerset(): + """Powerset to multilabel conversion, and back. + + Parameters + ---------- + num_classes : int + Number of regular classes. + max_set_size : int + Maximum number of classes in each set. + """ + + def __init__(self, num_classes: int, max_set_size: int): + super().__init__() + self.num_classes = num_classes + self.max_set_size = max_set_size + self.mapping = self.build_mapping() + self.cardinality = self.build_cardinality() + + + + @cached_property + def num_powerset_classes(self) -> int: + # compute number of subsets of size at most "max_set_size" + # e.g. 
with num_classes = 3 and max_set_size = 2: + # {}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2} + return int( + sum( + scipy.special.binom(self.num_classes, i) + for i in range(0, self.max_set_size + 1) + ) + ) + + def build_mapping(self) -> np.ndarray: + """Compute powerset to regular mapping + + Returns + ------- + mapping : (num_powerset_classes, num_classes) torch.Tensor + mapping[i, j] == 1 if jth regular class is a member of ith powerset class + mapping[i, j] == 0 otherwise + + Example + ------- + With num_classes == 3 and max_set_size == 2, returns + + [0, 0, 0] # none + [1, 0, 0] # class #1 + [0, 1, 0] # class #2 + [0, 0, 1] # class #3 + [1, 1, 0] # classes #1 and #2 + [1, 0, 1] # classes #1 and #3 + [0, 1, 1] # classes #2 and #3 + + """ + mapping = np.zeros((self.num_powerset_classes, self.num_classes)) + + powerset_k = 0 + for set_size in range(0, self.max_set_size + 1): + for current_set in combinations(range(self.num_classes), set_size): + mapping[powerset_k, current_set] = 1 + powerset_k += 1 + + return mapping + + def build_cardinality(self) -> np.ndarray: + """Compute size of each powerset class""" + return np.sum(self.mapping, axis=1) + + def to_multilabel(self, powerset: np.ndarray, soft: bool = False) -> np.ndarray: + """Convert predictions from powerset to multi-label + + Parameter + --------- + powerset : (batch_size, num_frames, num_powerset_classes) torch.Tensor + Soft predictions in "powerset" space. + soft : bool, optional + Return soft multi-label predictions. Defaults to False (i.e. hard predictions) + Assumes that `powerset` are "logits" (not "probabilities"). + + Returns + ------- + multi_label : (batch_size, num_frames, num_classes) torch.Tensor + Predictions in "multi-label" space. + """ + + powerset_probs = np.identity(self.num_powerset_classes)[np.argmax(powerset, axis=-1)] + return np.matmul(powerset_probs, self.mapping) + + + def __call__(self, powerset: np.ndarray, soft: bool = False) -> np.ndarray: + """Alias for `to_multilabel`""" + + return self.to_multilabel(powerset, soft=soft) diff --git a/ailia-models/code/pyannote_audio_utils/audio/utils/signal.py b/ailia-models/code/pyannote_audio_utils/audio/utils/signal.py new file mode 100644 index 0000000000000000000000000000000000000000..94328140379a1efe2ca3ffee22ef2371c5b518dc --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/audio/utils/signal.py @@ -0,0 +1,369 @@ +#!/usr/bin/env python +# encoding: utf-8 +# +# The MIT License (MIT) +# +# Copyright (c) 2016-2021 CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr + +""" +# Signal processing +""" + +from functools import singledispatch +from itertools import zip_longest +from typing import Optional, Union + +import numpy as np +import scipy.signal +from pyannote_audio_utils.core import Annotation, Segment, SlidingWindowFeature, Timeline +from pyannote_audio_utils.core.utils.generators import pairwise + + +@singledispatch +def binarize( + scores, + onset: float = 0.5, + offset: Optional[float] = None, + initial_state: Optional[Union[bool, np.ndarray]] = None, +): + """(Batch) hysteresis thresholding + + Parameters + ---------- + scores : numpy.ndarray or SlidingWindowFeature + (num_chunks, num_frames, num_classes)- or (num_frames, num_classes)-shaped scores. + onset : float, optional + Onset threshold. Defaults to 0.5. + offset : float, optional + Offset threshold. Defaults to `onset`. + initial_state : np.ndarray or bool, optional + Initial state. + + Returns + ------- + binarized : same as scores + Binarized scores with same shape and type as scores. + + Reference + --------- + https://stackoverflow.com/questions/23289976/how-to-find-zero-crossings-with-hysteresis + """ + raise NotImplementedError( + "scores must be of type numpy.ndarray or SlidingWindowFeatures" + ) + + +@binarize.register +def binarize_ndarray( + scores: np.ndarray, + onset: float = 0.5, + offset: Optional[float] = None, + initial_state: Optional[Union[bool, np.ndarray]] = None, +): + """(Batch) hysteresis thresholding + + Parameters + ---------- + scores : numpy.ndarray + (num_frames, num_classes)-shaped scores. + onset : float, optional + Onset threshold. Defaults to 0.5. + offset : float, optional + Offset threshold. Defaults to `onset`. + initial_state : np.ndarray or bool, optional + Initial state. + + Returns + ------- + binarized : same as scores + Binarized scores with same shape and type as scores. 
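+
+ Examples
+ --------
+ A minimal sketch (made-up scores for a single class, one row of frames):
+ the output switches on when a score rises above `onset` and only switches
+ off once a score drops below `offset`.
+
+ >>> scores = np.array([[0.2, 0.8, 0.6, 0.4, 0.2, 0.7]])
+ >>> binarize_ndarray(scores, onset=0.7, offset=0.3)
+ array([[False,  True,  True,  True, False, False]])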
+ """ + + offset = offset or onset + + batch_size, num_frames = scores.shape + + scores = np.nan_to_num(scores) + + if initial_state is None: + initial_state = scores[:, 0] >= 0.5 * (onset + offset) + + elif isinstance(initial_state, bool): + initial_state = initial_state * np.ones((batch_size,), dtype=bool) + + elif isinstance(initial_state, np.ndarray): + assert initial_state.shape == (batch_size,) + assert initial_state.dtype == bool + + initial_state = np.tile(initial_state, (num_frames, 1)).T + + on = scores > onset + off_or_on = (scores < offset) | on + + # indices of frames for which the on/off state is well-defined + well_defined_idx = np.array( + list(zip_longest(*[np.nonzero(oon)[0] for oon in off_or_on], fillvalue=-1)) + ).T + + # corner case where well_defined_idx is empty + if not well_defined_idx.size: + return np.zeros_like(scores, dtype=bool) | initial_state + + # points to the index of the previous well-defined frame + same_as = np.cumsum(off_or_on, axis=1) + + samples = np.tile(np.arange(batch_size), (num_frames, 1)).T + + return np.where( + same_as, on[samples, well_defined_idx[samples, same_as - 1]], initial_state + ) + + +@binarize.register +def binarize_swf( + scores: SlidingWindowFeature, + onset: float = 0.5, + offset: Optional[float] = None, + initial_state: Optional[bool] = None, +): + """(Batch) hysteresis thresholding + + Parameters + ---------- + scores : SlidingWindowFeature + (num_chunks, num_frames, num_classes)- or (num_frames, num_classes)-shaped scores. + onset : float, optional + Onset threshold. Defaults to 0.5. + offset : float, optional + Offset threshold. Defaults to `onset`. + initial_state : np.ndarray or bool, optional + Initial state. + + Returns + ------- + binarized : same as scores + Binarized scores with same shape and type as scores. + + """ + + offset = offset or onset + + if scores.data.ndim == 2: + num_frames, num_classes = scores.data.shape + data = scores.data.transpose() + binarized = binarize( + data, onset=onset, offset=offset, initial_state=initial_state + ) + return SlidingWindowFeature( + 1.0 + * binarized.transpose(), + scores.sliding_window, + ) + + elif scores.data.ndim == 3: + num_chunks, num_frames, num_classes = scores.data.shape + data = scores.data.reshape([-1, num_classes]) + binarized = binarize( + data, onset=onset, offset=offset, initial_state=initial_state + ) + return SlidingWindowFeature( + 1.0 + * binarized.reshape([num_chunks, num_frames, num_classes]), + scores.sliding_window, + ) + + else: + raise ValueError( + "Shape of scores must be (num_chunks, num_frames, num_classes) or (num_frames, num_classes)." + ) + + +class Binarize: + """Binarize detection scores using hysteresis thresholding + + Parameters + ---------- + onset : float, optional + Onset threshold. Defaults to 0.5. + offset : float, optional + Offset threshold. Defaults to `onset`. + min_duration_on : float, optional + Remove active regions shorter than that many seconds. Defaults to 0s. + min_duration_off : float, optional + Fill inactive regions shorter than that many seconds. Defaults to 0s. + pad_onset : float, optional + Extend active regions by moving their start time by that many seconds. + Defaults to 0s. + pad_offset : float, optional + Extend active regions by moving their end time by that many seconds. + Defaults to 0s. + + Reference + --------- + Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of + RNN-based Voice Activity Detection", InterSpeech 2015. 
+ """ + + def __init__( + self, + onset: float = 0.5, + offset: Optional[float] = None, + min_duration_on: float = 0.0, + min_duration_off: float = 0.0, + pad_onset: float = 0.0, + pad_offset: float = 0.0, + ): + + super().__init__() + + self.onset = onset + self.offset = offset or onset + + self.pad_onset = pad_onset + self.pad_offset = pad_offset + + self.min_duration_on = min_duration_on + self.min_duration_off = min_duration_off + + def __call__(self, scores: SlidingWindowFeature) -> Annotation: + """Binarize detection scores + + Parameters + ---------- + scores : SlidingWindowFeature + Detection scores. + + Returns + ------- + active : Annotation + Binarized scores. + """ + + num_frames, num_classes = scores.data.shape + frames = scores.sliding_window + timestamps = [frames[i].middle for i in range(num_frames)] + + # annotation meant to store 'active' regions + active = Annotation() + + for k, k_scores in enumerate(scores.data.T): + + label = k if scores.labels is None else scores.labels[k] + + # initial state + start = timestamps[0] + is_active = k_scores[0] > self.onset + + for t, y in zip(timestamps[1:], k_scores[1:]): + + # currently active + if is_active: + # switching from active to inactive + if y < self.offset: + region = Segment(start - self.pad_onset, t + self.pad_offset) + active[region, k] = label + start = t + is_active = False + + # currently inactive + else: + # switching from inactive to active + if y > self.onset: + start = t + is_active = True + + # if active at the end, add final region + if is_active: + region = Segment(start - self.pad_onset, t + self.pad_offset) + active[region, k] = label + + # because of padding, some active regions might be overlapping: merge them. + # also: fill same speaker gaps shorter than min_duration_off + if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0: + active = active.support(collar=self.min_duration_off) + + # remove tracks shorter than min_duration_on + if self.min_duration_on > 0: + for segment, track in list(active.itertracks()): + if segment.duration < self.min_duration_on: + del active[segment, track] + + return active + + +class Peak: + """Peak detection + + Parameters + ---------- + alpha : float, optional + Peak threshold. Defaults to 0.5 + min_duration : float, optional + Minimum elapsed time between two consecutive peaks. Defaults to 1 second. + """ + + def __init__( + self, + alpha: float = 0.5, + min_duration: float = 1.0, + ): + super(Peak, self).__init__() + self.alpha = alpha + self.min_duration = min_duration + + def __call__(self, scores: SlidingWindowFeature): + """Peak detection + + Parameter + --------- + scores : SlidingWindowFeature + Detection scores. + + Returns + ------- + segmentation : Timeline + Partition. 
+ """ + + if scores.dimension != 1: + raise ValueError("Peak expects one-dimensional scores.") + + num_frames = len(scores) + frames = scores.sliding_window + + precision = frames.step + order = max(1, int(np.rint(self.min_duration / precision))) + indices = scipy.signal.argrelmax(scores[:], order=order)[0] + + peak_time = np.array( + [frames[i].middle for i in indices if scores[i] > self.alpha] + ) + boundaries = np.hstack([[frames[0].start], peak_time, [frames[num_frames].end]]) + + segmentation = Timeline() + for i, (start, end) in enumerate(pairwise(boundaries)): + segment = Segment(start, end) + segmentation.add(segment) + + return segmentation diff --git a/ailia-models/code/pyannote_audio_utils/audio/version.py b/ailia-models/code/pyannote_audio_utils/audio/version.py new file mode 100644 index 0000000000000000000000000000000000000000..726691bc2df6a60e2ee617793c20a0539b981b72 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/audio/version.py @@ -0,0 +1 @@ +__version__ = '3.1.1' diff --git a/ailia-models/code/pyannote_audio_utils/core/__init__.py b/ailia-models/code/pyannote_audio_utils/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..27058b68137c4692650526c190ce3a63ca671259 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/core/__init__.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2014-2019 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr + + +from ._version import get_versions +__version__ = get_versions()['version'] +del get_versions + +PYANNOTE_URI = 'uri' +PYANNOTE_MODALITY = 'modality' +PYANNOTE_SEGMENT = 'segment' +PYANNOTE_TRACK = 'track' +PYANNOTE_LABEL = 'label' +PYANNOTE_SCORE = 'score' +PYANNOTE_IDENTITY = 'identity' + +from .segment import Segment, SlidingWindow +from .timeline import Timeline +from .annotation import Annotation +from .feature import SlidingWindowFeature + +Segment.set_precision() + diff --git a/ailia-models/code/pyannote_audio_utils/core/_version.py b/ailia-models/code/pyannote_audio_utils/core/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..0a8aa08d7adfef532d3ce569748904b5bab43806 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/core/_version.py @@ -0,0 +1,20 @@ + +# This file was generated by 'versioneer.py' (0.15) from +# revision-control system data, or from the parent directory name of an +# unpacked source archive. 
Distribution tarballs contain a pre-generated copy +# of this file. + +import json + +version_json = ''' +{ + "dirty": false, + "error": null, + "full-revisionid": "4b0fd5302d8fa3ba249b42d3ab7b4cb51ee61ba2", + "version": "5.0.0" +} +''' # END VERSION_JSON + + +def get_versions(): + return json.loads(version_json) diff --git a/ailia-models/code/pyannote_audio_utils/core/annotation.py b/ailia-models/code/pyannote_audio_utils/core/annotation.py new file mode 100644 index 0000000000000000000000000000000000000000..3390a99e9bc04d7f797c0a4118683f6edab9ebc3 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/core/annotation.py @@ -0,0 +1,1551 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2014-2021 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr +# Paul LERNER + +""" +########## +Annotation +########## + +.. plot:: pyplots/annotation.py + +:class:`pyannote.core.Annotation` instances are ordered sets of non-empty +tracks: + + - ordered, because segments are sorted by start time (and end time in case of tie) + - set, because one cannot add twice the same track + - non-empty, because one cannot add empty track + +A track is a (support, name) pair where `support` is a Segment instance, +and `name` is an additional identifier so that it is possible to add multiple +tracks with the same support. + +To define the annotation depicted above: + +.. code-block:: ipython + + In [1]: from pyannote.core import Annotation, Segment + + In [6]: annotation = Annotation() + ...: annotation[Segment(1, 5)] = 'Carol' + ...: annotation[Segment(6, 8)] = 'Bob' + ...: annotation[Segment(12, 18)] = 'Carol' + ...: annotation[Segment(7, 20)] = 'Alice' + ...: + +which is actually a shortcut for + +.. code-block:: ipython + + In [6]: annotation = Annotation() + ...: annotation[Segment(1, 5), '_'] = 'Carol' + ...: annotation[Segment(6, 8), '_'] = 'Bob' + ...: annotation[Segment(12, 18), '_'] = 'Carol' + ...: annotation[Segment(7, 20), '_'] = 'Alice' + ...: + +where all tracks share the same (default) name ``'_'``. + +In case two tracks share the same support, use a different track name: + +.. 
code-block:: ipython + + In [6]: annotation = Annotation(uri='my_video_file', modality='speaker') + ...: annotation[Segment(1, 5), 1] = 'Carol' # track name = 1 + ...: annotation[Segment(1, 5), 2] = 'Bob' # track name = 2 + ...: annotation[Segment(12, 18)] = 'Carol' + ...: + +The track name does not have to be unique over the whole set of tracks. + +.. note:: + + The optional *uri* and *modality* keywords argument can be used to remember + which document and modality (e.g. speaker or face) it describes. + +Several convenient methods are available. Here are a few examples: + +.. code-block:: ipython + + In [9]: annotation.labels() # sorted list of labels + Out[9]: ['Bob', 'Carol'] + + In [10]: annotation.chart() # label duration chart + Out[10]: [('Carol', 10), ('Bob', 4)] + + In [11]: list(annotation.itertracks()) + Out[11]: [(, 1), (, 2), (, u'_')] + + In [12]: annotation.label_timeline('Carol') + Out[12]: , ])> + +See :class:`pyannote.core.Annotation` for the complete reference. +""" +import itertools +import warnings +from collections import defaultdict +from typing import ( + Hashable, + Optional, + Dict, + Union, + Iterable, + List, + Set, + TextIO, + Tuple, + Iterator, + Text, + TYPE_CHECKING, +) + +import numpy as np +from sortedcontainers import SortedDict + +from . import ( + PYANNOTE_SEGMENT, + PYANNOTE_TRACK, + PYANNOTE_LABEL, +) +from .segment import Segment, SlidingWindow +from .timeline import Timeline +from .feature import SlidingWindowFeature +from .utils.generators import string_generator, int_generator +from .utils.types import Label, Key, Support, LabelGenerator, TrackName, CropMode + +if TYPE_CHECKING: + import pandas as pd + + +class Annotation: + """Annotation + + Parameters + ---------- + uri : string, optional + name of annotated resource (e.g. 
audio or video file) + modality : string, optional + name of annotated modality + + Returns + ------- + annotation : Annotation + New annotation + + """ + + @classmethod + def from_df( + cls, + df: "pd.DataFrame", + uri: Optional[str] = None, + modality: Optional[str] = None, + ) -> "Annotation": + + df = df[[PYANNOTE_SEGMENT, PYANNOTE_TRACK, PYANNOTE_LABEL]] + return Annotation.from_records(df.itertuples(index=False), uri, modality) + + def __init__(self, uri: Optional[str] = None, modality: Optional[str] = None): + + self._uri: Optional[str] = uri + self.modality: Optional[str] = modality + + # sorted dictionary + # keys: annotated segments + # values: {track: label} dictionary + self._tracks: Dict[Segment, Dict[TrackName, Label]] = SortedDict() + + # dictionary + # key: label + # value: timeline + self._labels: Dict[Label, Timeline] = {} + self._labelNeedsUpdate: Dict[Label, bool] = {} + + # timeline meant to store all annotated segments + self._timeline: Timeline = None + self._timelineNeedsUpdate: bool = True + + @property + def uri(self): + return self._uri + + @uri.setter + def uri(self, uri: str): + # update uri for all internal timelines + for label in self.labels(): + timeline = self.label_timeline(label, copy=False) + timeline.uri = uri + timeline = self.get_timeline(copy=False) + timeline.uri = uri + self._uri = uri + + def _updateLabels(self): + + # list of labels that needs to be updated + update = set( + label for label, update in self._labelNeedsUpdate.items() if update + ) + + # accumulate segments for updated labels + _segments = {label: [] for label in update} + for segment, track, label in self.itertracks(yield_label=True): + if label in update: + _segments[label].append(segment) + + # create timeline with accumulated segments for updated labels + for label in update: + if _segments[label]: + self._labels[label] = Timeline(segments=_segments[label], uri=self.uri) + self._labelNeedsUpdate[label] = False + else: + self._labels.pop(label, None) + self._labelNeedsUpdate.pop(label, None) + + def __len__(self): + """Number of segments + + >>> len(annotation) # annotation contains three segments + 3 + """ + return len(self._tracks) + + def __nonzero__(self): + return self.__bool__() + + def __bool__(self): + """Emptiness + + >>> if annotation: + ... # annotation is not empty + ... else: + ... # annotation is empty + """ + return len(self._tracks) > 0 + + def itersegments(self): + """Iterate over segments (in chronological order) + + >>> for segment in annotation.itersegments(): + ... # do something with the segment + + See also + -------- + :class:`pyannote.core.Segment` describes how segments are sorted. + """ + return iter(self._tracks) + + def itertracks( + self, yield_label: bool = False + ) -> Iterator[Union[Tuple[Segment, TrackName], Tuple[Segment, TrackName, Label]]]: + """Iterate over tracks (in chronological order) + + Parameters + ---------- + yield_label : bool, optional + When True, yield (segment, track, label) tuples, such that + annotation[segment, track] == label. Defaults to yielding + (segment, track) tuple. + + Examples + -------- + + >>> for segment, track in annotation.itertracks(): + ... # do something with the track + + >>> for segment, track, label in annotation.itertracks(yield_label=True): + ... 
# do something with the track and its label + """ + + for segment, tracks in self._tracks.items(): + for track, lbl in sorted( + tracks.items(), key=lambda tl: (str(tl[0]), str(tl[1])) + ): + if yield_label: + yield segment, track, lbl + else: + yield segment, track + + def _updateTimeline(self): + self._timeline = Timeline(segments=self._tracks, uri=self.uri) + self._timelineNeedsUpdate = False + + def get_timeline(self, copy: bool = True) -> Timeline: + """Get timeline made of all annotated segments + + Parameters + ---------- + copy : bool, optional + Defaults (True) to returning a copy of the internal timeline. + Set to False to return the actual internal timeline (faster). + + Returns + ------- + timeline : Timeline + Timeline made of all annotated segments. + + Note + ---- + In case copy is set to False, be careful **not** to modify the returned + timeline, as it may lead to weird subsequent behavior of the annotation + instance. + + """ + if self._timelineNeedsUpdate: + self._updateTimeline() + if copy: + return self._timeline.copy() + return self._timeline + + def __eq__(self, other: "Annotation"): + """Equality + + >>> annotation == other + + Two annotations are equal if and only if their tracks and associated + labels are equal. + """ + pairOfTracks = itertools.zip_longest( + self.itertracks(yield_label=True), other.itertracks(yield_label=True) + ) + return all(t1 == t2 for t1, t2 in pairOfTracks) + + def __ne__(self, other: "Annotation"): + """Inequality""" + pairOfTracks = itertools.zip_longest( + self.itertracks(yield_label=True), other.itertracks(yield_label=True) + ) + + return any(t1 != t2 for t1, t2 in pairOfTracks) + + def __contains__(self, included: Union[Segment, Timeline]): + """Inclusion + + Check whether every segment of `included` does exist in annotation. + + Parameters + ---------- + included : Segment or Timeline + Segment or timeline being checked for inclusion + + Returns + ------- + contains : bool + True if every segment in `included` exists in timeline, + False otherwise + + """ + return included in self.get_timeline(copy=False) + + def _iter_rttm(self) -> Iterator[Text]: + """Generate lines for an RTTM file for this annotation + + Returns + ------- + iterator: Iterator[str] + An iterator over RTTM text lines + """ + uri = self.uri if self.uri else "" + if isinstance(uri, Text) and " " in uri: + msg = ( + f"Space-separated RTTM file format does not allow file URIs " + f'containing spaces (got: "{uri}").' + ) + raise ValueError(msg) + for segment, _, label in self.itertracks(yield_label=True): + if isinstance(label, Text) and " " in label: + msg = ( + f"Space-separated RTTM file format does not allow labels " + f'containing spaces (got: "{label}").' + ) + raise ValueError(msg) + yield ( + f"SPEAKER {uri} 1 {segment.start:.3f} {segment.duration:.3f} " + f" {label} \n" + ) + + def to_rttm(self) -> Text: + """Serialize annotation as a string using RTTM format + + Returns + ------- + serialized: str + RTTM string + """ + return "".join([line for line in self._iter_rttm()]) + + def write_rttm(self, file: TextIO): + """Dump annotation to file using RTTM format + + Parameters + ---------- + file : file object + + Usage + ----- + >>> with open('file.rttm', 'w') as file: + ... 
annotation.write_rttm(file) + """ + for line in self._iter_rttm(): + file.write(line) + + def _iter_lab(self) -> Iterator[Text]: + """Generate lines for a LAB file for this annotation + + Returns + ------- + iterator: Iterator[str] + An iterator over LAB text lines + """ + for segment, _, label in self.itertracks(yield_label=True): + if isinstance(label, Text) and " " in label: + msg = ( + f"Space-separated LAB file format does not allow labels " + f'containing spaces (got: "{label}").' + ) + raise ValueError(msg) + yield f"{segment.start:.3f} {segment.start + segment.duration:.3f} {label}\n" + + def to_lab(self) -> Text: + """Serialize annotation as a string using LAB format + + Returns + ------- + serialized: str + LAB string + """ + return "".join([line for line in self._iter_lab()]) + + def write_lab(self, file: TextIO): + """Dump annotation to file using LAB format + + Parameters + ---------- + file : file object + + Usage + ----- + >>> with open('file.lab', 'w') as file: + ... annotation.write_lab(file) + """ + for line in self._iter_lab(): + file.write(line) + + def crop(self, support: Support, mode: CropMode = "intersection") -> "Annotation": + """Crop annotation to new support + + Parameters + ---------- + support : Segment or Timeline + If `support` is a `Timeline`, its support is used. + mode : {'strict', 'loose', 'intersection'}, optional + Controls how segments that are not fully included in `support` are + handled. 'strict' mode only keeps fully included segments. 'loose' + mode keeps any intersecting segment. 'intersection' mode keeps any + intersecting segment but replace them by their actual intersection. + + Returns + ------- + cropped : Annotation + Cropped annotation + + Note + ---- + In 'intersection' mode, the best is done to keep the track names + unchanged. However, in some cases where two original segments are + cropped into the same resulting segments, conflicting track names are + modified to make sure no track is lost. 
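+
+ Examples
+ --------
+ A minimal sketch (segments chosen only for illustration):
+
+ >>> annotation = Annotation()
+ >>> annotation[Segment(0, 10)] = 'Alice'
+ >>> cropped = annotation.crop(Segment(5, 20), mode='intersection')
+ >>> # only the overlapping part, Segment(5, 10), is kept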
+ + """ + + # TODO speed things up by working directly with annotation internals + + if isinstance(support, Segment): + support = Timeline(segments=[support], uri=self.uri) + return self.crop(support, mode=mode) + + elif isinstance(support, Timeline): + + # if 'support' is a `Timeline`, we use its support + support = support.support() + cropped = self.__class__(uri=self.uri, modality=self.modality) + + if mode == "loose": + + _tracks = {} + _labels = set([]) + + for segment, _ in self.get_timeline(copy=False).co_iter(support): + tracks = dict(self._tracks[segment]) + _tracks[segment] = tracks + _labels.update(tracks.values()) + + cropped._tracks = SortedDict(_tracks) + + cropped._labelNeedsUpdate = {label: True for label in _labels} + cropped._labels = {label: None for label in _labels} + + cropped._timelineNeedsUpdate = True + cropped._timeline = None + + return cropped + + elif mode == "strict": + + _tracks = {} + _labels = set([]) + + for segment, other_segment in self.get_timeline(copy=False).co_iter( + support + ): + + if segment not in other_segment: + continue + + tracks = dict(self._tracks[segment]) + _tracks[segment] = tracks + _labels.update(tracks.values()) + + cropped._tracks = SortedDict(_tracks) + + cropped._labelNeedsUpdate = {label: True for label in _labels} + cropped._labels = {label: None for label in _labels} + + cropped._timelineNeedsUpdate = True + cropped._timeline = None + + return cropped + + elif mode == "intersection": + + for segment, other_segment in self.get_timeline(copy=False).co_iter( + support + ): + + intersection = segment & other_segment + for track, label in self._tracks[segment].items(): + track = cropped.new_track(intersection, candidate=track) + cropped[intersection, track] = label + + return cropped + + else: + raise NotImplementedError("unsupported mode: '%s'" % mode) + + def extrude( + self, removed: Support, mode: CropMode = "intersection" + ) -> "Annotation": + """Remove segments that overlap `removed` support. + + A simple illustration: + + annotation + A |------| |------| + B |----------| + C |--------------| |------| + + removed `Timeline` + |-------| |-----------| + + extruded Annotation with mode="intersection" + B |---| + C |--| |------| + + extruded Annotation with mode="loose" + C |------| + + extruded Annotation with mode="strict" + A |------| + B |----------| + C |--------------| |------| + + Parameters + ---------- + removed : Segment or Timeline + If `support` is a `Timeline`, its support is used. + mode : {'strict', 'loose', 'intersection'}, optional + Controls how segments that are not fully included in `removed` are + handled. 'strict' mode only removes fully included segments. 'loose' + mode removes any intersecting segment. 'intersection' mode removes + the overlapping part of any intersecting segment. + + Returns + ------- + extruded : Annotation + Extruded annotation + + Note + ---- + In 'intersection' mode, the best is done to keep the track names + unchanged. However, in some cases where two original segments are + cropped into the same resulting segments, conflicting track names are + modified to make sure no track is lost. 
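+
+ Examples
+ --------
+ A minimal sketch (segments chosen only for illustration):
+
+ >>> annotation = Annotation()
+ >>> annotation[Segment(0, 10)] = 'Alice'
+ >>> remaining = annotation.extrude(Segment(4, 6))  # keeps [0, 4] and [6, 10]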
+ + """ + if isinstance(removed, Segment): + removed = Timeline([removed]) + + extent_tl = Timeline([self.get_timeline().extent()], uri=self.uri) + truncating_support = removed.gaps(support=extent_tl) + # loose for truncate means strict for crop and vice-versa + if mode == "loose": + mode = "strict" + elif mode == "strict": + mode = "loose" + return self.crop(truncating_support, mode=mode) + + def get_overlap(self, labels: Optional[Iterable[Label]] = None) -> "Timeline": + """Get overlapping parts of the annotation. + + A simple illustration: + + annotation + A |------| |------| |----| + B |--| |-----| |----------| + C |--------------| |------| + + annotation.get_overlap() + |------| |-----| |--------| + + annotation.get_overlap(for_labels=["A", "B"]) + |--| |--| |----| + + Parameters + ---------- + labels : optional list of labels + Labels for which to consider the overlap + + Returns + ------- + overlap : `pyannote.core.Timeline` + Timeline of the overlaps. + """ + if labels: + annotation = self.subset(labels) + else: + annotation = self + + overlaps_tl = Timeline(uri=annotation.uri) + for (s1, t1), (s2, t2) in annotation.co_iter(annotation): + # if labels are the same for the two segments, skipping + if self[s1, t1] == self[s2, t2]: + continue + overlaps_tl.add(s1 & s2) + return overlaps_tl.support() + + def get_tracks(self, segment: Segment) -> Set[TrackName]: + """Query tracks by segment + + Parameters + ---------- + segment : Segment + Query + + Returns + ------- + tracks : set + Set of tracks + + Note + ---- + This will return an empty set if segment does not exist. + """ + return set(self._tracks.get(segment, {}).keys()) + + def has_track(self, segment: Segment, track: TrackName) -> bool: + """Check whether a given track exists + + Parameters + ---------- + segment : Segment + Query segment + track : + Query track + + Returns + ------- + exists : bool + True if track exists for segment + """ + return track in self._tracks.get(segment, {}) + + def copy(self) -> "Annotation": + """Get a copy of the annotation + + Returns + ------- + annotation : Annotation + Copy of the annotation + """ + + # create new empty annotation + copied = self.__class__(uri=self.uri, modality=self.modality) + + # deep copy internal track dictionary + _tracks, _labels = [], set([]) + for key, value in self._tracks.items(): + _labels.update(value.values()) + _tracks.append((key, dict(value))) + + copied._tracks = SortedDict(_tracks) + + copied._labels = {label: None for label in _labels} + copied._labelNeedsUpdate = {label: True for label in _labels} + + copied._timeline = None + copied._timelineNeedsUpdate = True + + return copied + + def new_track( + self, + segment: Segment, + candidate: Optional[TrackName] = None, + prefix: Optional[str] = None, + ) -> TrackName: + """Generate a new track name for given segment + + Ensures that the returned track name does not already + exist for the given segment. + + Parameters + ---------- + segment : Segment + Segment for which a new track name is generated. + candidate : any valid track name, optional + When provided, try this candidate name first. + prefix : str, optional + Track name prefix. Defaults to the empty string ''. 
+ + Returns + ------- + name : str + New track name + """ + + # obtain list of existing tracks for segment + existing_tracks = set(self._tracks.get(segment, {})) + + # if candidate is provided, check whether it already exists + # in case it does not, use it + if (candidate is not None) and (candidate not in existing_tracks): + return candidate + + # no candidate was provided or the provided candidate already exists + # we need to create a brand new one + + # by default (if prefix is not provided), use '' + if prefix is None: + prefix = "" + + # find first non-existing track name for segment + # eg. if '0' exists, try '1', then '2', ... + count = 0 + while ("%s%d" % (prefix, count)) in existing_tracks: + count += 1 + + # return first non-existing track name + return "%s%d" % (prefix, count) + + def __str__(self): + """Human-friendly representation""" + # TODO: use pandas.DataFrame + return "\n".join( + ["%s %s %s" % (s, t, l) for s, t, l in self.itertracks(yield_label=True)] + ) + + def __delitem__(self, key: Key): + """Delete one track + + >>> del annotation[segment, track] + + Delete all tracks of a segment + + >>> del annotation[segment] + """ + + # del annotation[segment] + if isinstance(key, Segment): + + # Pop segment out of dictionary + # and get corresponding tracks + # Raises KeyError if segment does not exist + tracks = self._tracks.pop(key) + + # mark timeline as modified + self._timelineNeedsUpdate = True + + # mark every label in tracks as modified + for track, label in tracks.items(): + self._labelNeedsUpdate[label] = True + + # del annotation[segment, track] + elif isinstance(key, tuple) and len(key) == 2: + + # get segment tracks as dictionary + # if segment does not exist, get empty dictionary + # Raises KeyError if segment does not exist + tracks = self._tracks[key[0]] + + # pop track out of tracks dictionary + # and get corresponding label + # Raises KeyError if track does not exist + label = tracks.pop(key[1]) + + # mark label as modified + self._labelNeedsUpdate[label] = True + + # if tracks dictionary is now empty, + # remove segment as well + if not tracks: + self._tracks.pop(key[0]) + self._timelineNeedsUpdate = True + + else: + raise NotImplementedError( + "Deletion only works with Segment or (Segment, track) keys." + ) + + # label = annotation[segment, track] + def __getitem__(self, key: Key) -> Label: + """Get track label + + >>> label = annotation[segment, track] + + Note + ---- + ``annotation[segment]`` is equivalent to ``annotation[segment, '_']`` + + """ + + if isinstance(key, Segment): + key = (key, "_") + + return self._tracks[key[0]][key[1]] + + # annotation[segment, track] = label + def __setitem__(self, key: Key, label: Label): + """Add new or update existing track + + >>> annotation[segment, track] = label + + If (segment, track) does not exist, it is added. + If (segment, track) already exists, it is updated. + + Note + ---- + ``annotation[segment] = label`` is equivalent to ``annotation[segment, '_'] = label`` + + Note + ---- + If `segment` is empty, it does nothing. 
+ """ + + if isinstance(key, Segment): + key = (key, "_") + + segment, track = key + + # do not add empty track + if not segment: + return + + # in case we create a new segment + # mark timeline as modified + if segment not in self._tracks: + self._tracks[segment] = {} + self._timelineNeedsUpdate = True + + # in case we modify an existing track + # mark old label as modified + if track in self._tracks[segment]: + old_label = self._tracks[segment][track] + self._labelNeedsUpdate[old_label] = True + + # mark new label as modified + self._tracks[segment][track] = label + self._labelNeedsUpdate[label] = True + + def empty(self) -> "Annotation": + """Return an empty copy + + Returns + ------- + empty : Annotation + Empty annotation using the same 'uri' and 'modality' attributes. + + """ + return self.__class__(uri=self.uri, modality=self.modality) + + def labels(self) -> List[Label]: + """Get sorted list of labels + + Returns + ------- + labels : list + Sorted list of labels + """ + if any([lnu for lnu in self._labelNeedsUpdate.values()]): + self._updateLabels() + return sorted(self._labels, key=str) + + def get_labels( + self, segment: Segment, unique: bool = True + ) -> Union[Set[Label], List[Label]]: + """Query labels by segment + + Parameters + ---------- + segment : Segment + Query + unique : bool, optional + When False, return the list of (possibly repeated) labels. + Defaults to returning the set of labels. + + Returns + ------- + labels : set or list + Set (resp. list) of labels for `segment` if it exists, empty set (resp. list) otherwise + if unique (resp. if not unique). + + Examples + -------- + >>> annotation = Annotation() + >>> segment = Segment(0, 2) + >>> annotation[segment, 'speaker1'] = 'Bernard' + >>> annotation[segment, 'speaker2'] = 'John' + >>> print sorted(annotation.get_labels(segment)) + set(['Bernard', 'John']) + >>> print annotation.get_labels(Segment(1, 2)) + set([]) + + """ + + labels = self._tracks.get(segment, {}).values() + + if unique: + return set(labels) + + return list(labels) + + def subset(self, labels: Iterable[Label], invert: bool = False) -> "Annotation": + """Filter annotation by labels + + Parameters + ---------- + labels : iterable + List of filtered labels + invert : bool, optional + If invert is True, extract all but requested labels + + Returns + ------- + filtered : Annotation + Filtered annotation + """ + + labels = set(labels) + + if invert: + labels = set(self.labels()) - labels + else: + labels = labels & set(self.labels()) + + sub = self.__class__(uri=self.uri, modality=self.modality) + + _tracks, _labels = {}, set([]) + for segment, tracks in self._tracks.items(): + sub_tracks = { + track: label for track, label in tracks.items() if label in labels + } + if sub_tracks: + _tracks[segment] = sub_tracks + _labels.update(sub_tracks.values()) + + sub._tracks = SortedDict(_tracks) + + sub._labelNeedsUpdate = {label: True for label in _labels} + sub._labels = {label: None for label in _labels} + + sub._timelineNeedsUpdate = True + sub._timeline = None + + return sub + + def update(self, annotation: "Annotation", copy: bool = False) -> "Annotation": + """Add every track of an existing annotation (in place) + + Parameters + ---------- + annotation : Annotation + Annotation whose tracks are being added + copy : bool, optional + Return a copy of the annotation. Defaults to updating the + annotation in-place. + + Returns + ------- + self : Annotation + Updated annotation + + Note + ---- + Existing tracks are updated with the new label. 
+ """ + + result = self.copy() if copy else self + + # TODO speed things up by working directly with annotation internals + for segment, track, label in annotation.itertracks(yield_label=True): + result[segment, track] = label + + return result + + def label_timeline(self, label: Label, copy: bool = True) -> Timeline: + """Query segments by label + + Parameters + ---------- + label : object + Query + copy : bool, optional + Defaults (True) to returning a copy of the internal timeline. + Set to False to return the actual internal timeline (faster). + + Returns + ------- + timeline : Timeline + Timeline made of all segments for which at least one track is + annotated as label + + Note + ---- + If label does not exist, this will return an empty timeline. + + Note + ---- + In case copy is set to False, be careful **not** to modify the returned + timeline, as it may lead to weird subsequent behavior of the annotation + instance. + + """ + if label not in self.labels(): + return Timeline(uri=self.uri) + + if self._labelNeedsUpdate[label]: + self._updateLabels() + + if copy: + return self._labels[label].copy() + + return self._labels[label] + + def label_support(self, label: Label) -> Timeline: + """Label support + + Equivalent to ``Annotation.label_timeline(label).support()`` + + Parameters + ---------- + label : object + Query + + Returns + ------- + support : Timeline + Label support + + See also + -------- + :func:`~pyannote.core.Annotation.label_timeline` + :func:`~pyannote.core.Timeline.support` + + """ + return self.label_timeline(label, copy=False).support() + + def label_duration(self, label: Label) -> float: + """Label duration + + Equivalent to ``Annotation.label_timeline(label).duration()`` + + Parameters + ---------- + label : object + Query + + Returns + ------- + duration : float + Duration, in seconds. + + See also + -------- + :func:`~pyannote.core.Annotation.label_timeline` + :func:`~pyannote.core.Timeline.duration` + + """ + + return self.label_timeline(label, copy=False).duration() + + def chart(self, percent: bool = False) -> List[Tuple[Label, float]]: + """Get labels chart (from longest to shortest duration) + + Parameters + ---------- + percent : bool, optional + Return list of (label, percentage) tuples. + Defaults to returning list of (label, duration) tuples. + + Returns + ------- + chart : list + List of (label, duration), sorted by duration in decreasing order. + """ + + chart = sorted( + ((L, self.label_duration(L)) for L in self.labels()), + key=lambda x: x[1], + reverse=True, + ) + + if percent: + total = np.sum([duration for _, duration in chart]) + chart = [(label, duration / total) for (label, duration) in chart] + + return chart + + def argmax(self, support: Optional[Support] = None) -> Optional[Label]: + """Get label with longest duration + + Parameters + ---------- + support : Segment or Timeline, optional + Find label with longest duration within provided support. + Defaults to whole extent. + + Returns + ------- + label : any existing label or None + Label with longest intersection + + Examples + -------- + >>> annotation = Annotation(modality='speaker') + >>> annotation[Segment(0, 10), 'speaker1'] = 'Alice' + >>> annotation[Segment(8, 20), 'speaker1'] = 'Bob' + >>> print "%s is such a talker!" % annotation.argmax() + Bob is such a talker! + >>> segment = Segment(22, 23) + >>> if not annotation.argmax(support): + ... 
print "No label intersecting %s" % segment + No label intersection [22 --> 23] + + """ + + cropped = self + if support is not None: + cropped = cropped.crop(support, mode="intersection") + + if not cropped: + return None + + return max( + ((_, cropped.label_duration(_)) for _ in cropped.labels()), + key=lambda x: x[1], + )[0] + + def rename_tracks(self, generator: LabelGenerator = "string") -> "Annotation": + """Rename all tracks + + Parameters + ---------- + generator : 'string', 'int', or iterable, optional + If 'string' (default) rename tracks to 'A', 'B', 'C', etc. + If 'int', rename tracks to 0, 1, 2, etc. + If iterable, use it to generate track names. + + Returns + ------- + renamed : Annotation + Copy of the original annotation where tracks are renamed. + + Example + ------- + >>> annotation = Annotation() + >>> annotation[Segment(0, 1), 'a'] = 'a' + >>> annotation[Segment(0, 1), 'b'] = 'b' + >>> annotation[Segment(1, 2), 'a'] = 'a' + >>> annotation[Segment(1, 3), 'c'] = 'c' + >>> print(annotation) + [ 00:00:00.000 --> 00:00:01.000] a a + [ 00:00:00.000 --> 00:00:01.000] b b + [ 00:00:01.000 --> 00:00:02.000] a a + [ 00:00:01.000 --> 00:00:03.000] c c + >>> print(annotation.rename_tracks(generator='int')) + [ 00:00:00.000 --> 00:00:01.000] 0 a + [ 00:00:00.000 --> 00:00:01.000] 1 b + [ 00:00:01.000 --> 00:00:02.000] 2 a + [ 00:00:01.000 --> 00:00:03.000] 3 c + """ + + renamed = self.__class__(uri=self.uri, modality=self.modality) + + if generator == "string": + generator = string_generator() + elif generator == "int": + generator = int_generator() + + # TODO speed things up by working directly with annotation internals + for s, _, label in self.itertracks(yield_label=True): + renamed[s, next(generator)] = label + return renamed + + def rename_labels( + self, + mapping: Optional[Dict] = None, + generator: LabelGenerator = "string", + copy: bool = True, + ) -> "Annotation": + """Rename labels + + Parameters + ---------- + mapping : dict, optional + {old_name: new_name} mapping dictionary. + generator : 'string', 'int' or iterable, optional + If 'string' (default) rename label to 'A', 'B', 'C', ... If 'int', + rename to 0, 1, 2, etc. If iterable, use it to generate labels. + copy : bool, optional + Set to True to return a copy of the annotation. Set to False to + update the annotation in-place. Defaults to True. + + Returns + ------- + renamed : Annotation + Annotation where labels have been renamed + + Note + ---- + Unmapped labels are kept unchanged. + + Note + ---- + Parameter `generator` has no effect when `mapping` is provided. + + """ + + if mapping is None: + if generator == "string": + generator = string_generator() + elif generator == "int": + generator = int_generator() + # generate mapping + mapping = {label: next(generator) for label in self.labels()} + + renamed = self.copy() if copy else self + + for old_label, new_label in mapping.items(): + renamed._labelNeedsUpdate[old_label] = True + renamed._labelNeedsUpdate[new_label] = True + + for segment, tracks in self._tracks.items(): + new_tracks = { + track: mapping.get(label, label) for track, label in tracks.items() + } + renamed._tracks[segment] = new_tracks + + return renamed + + def relabel_tracks(self, generator: LabelGenerator = "string") -> "Annotation": + """Relabel tracks + + Create a new annotation where each track has a unique label. + + Parameters + ---------- + generator : 'string', 'int' or iterable, optional + If 'string' (default) relabel tracks to 'A', 'B', 'C', ... If 'int' + relabel to 0, 1, 2, ... 
If iterable, use it to generate labels. + + Returns + ------- + renamed : Annotation + New annotation with relabeled tracks. + """ + + if generator == "string": + generator = string_generator() + elif generator == "int": + generator = int_generator() + + relabeled = self.empty() + for s, t, _ in self.itertracks(yield_label=True): + relabeled[s, t] = next(generator) + + return relabeled + + def support(self, collar: float = 0.0) -> "Annotation": + """Annotation support + + The support of an annotation is an annotation where contiguous tracks + with same label are merged into one unique covering track. + + A picture is worth a thousand words:: + + collar + |---| + + annotation + |--A--| |--A--| |-B-| + |-B-| |--C--| |----B-----| + + annotation.support(collar) + |------A------| |------B------| + |-B-| |--C--| + + Parameters + ---------- + collar : float, optional + Merge tracks with same label and separated by less than `collar` + seconds. This is why 'A' tracks are merged in above figure. + Defaults to 0. + + Returns + ------- + support : Annotation + Annotation support + + Note + ---- + Track names are lost in the process. + """ + + generator = string_generator() + + # initialize an empty annotation + # with same uri and modality as original + support = self.empty() + for label in self.labels(): + + # get timeline for current label + timeline = self.label_timeline(label, copy=True) + + # fill the gaps shorter than collar + timeline = timeline.support(collar) + + # reconstruct annotation with merged tracks + for segment in timeline.support(): + support[segment, next(generator)] = label + + return support + + def co_iter( + self, other: "Annotation" + ) -> Iterator[Tuple[Tuple[Segment, TrackName], Tuple[Segment, TrackName]]]: + """Iterate over pairs of intersecting tracks + + Parameters + ---------- + other : Annotation + Second annotation + + Returns + ------- + iterable : (Segment, object), (Segment, object) iterable + Yields pairs of intersecting tracks, in chronological (then + alphabetical) order. + + See also + -------- + :func:`~pyannote.core.Timeline.co_iter` + + """ + timeline = self.get_timeline(copy=False) + other_timeline = other.get_timeline(copy=False) + for s, S in timeline.co_iter(other_timeline): + tracks = sorted(self.get_tracks(s), key=str) + other_tracks = sorted(other.get_tracks(S), key=str) + for t, T in itertools.product(tracks, other_tracks): + yield (s, t), (S, T) + + def __mul__(self, other: "Annotation") -> np.ndarray: + """Cooccurrence (or confusion) matrix + + >>> matrix = annotation * other + + Parameters + ---------- + other : Annotation + Second annotation + + Returns + ------- + cooccurrence : (n_self, n_other) np.ndarray + Cooccurrence matrix where `n_self` (resp. `n_other`) is the number + of labels in `self` (resp. `other`). + """ + + if not isinstance(other, Annotation): + raise TypeError( + "computing cooccurrence matrix only works with Annotation " "instances." 
+ ) + + i_labels = self.labels() + j_labels = other.labels() + + I = {label: i for i, label in enumerate(i_labels)} + J = {label: j for j, label in enumerate(j_labels)} + + matrix = np.zeros((len(I), len(J))) + + # iterate over intersecting tracks and accumulate durations + for (segment, track), (other_segment, other_track) in self.co_iter(other): + i = I[self[segment, track]] + j = J[other[other_segment, other_track]] + duration = (segment & other_segment).duration + matrix[i, j] += duration + + return matrix + + def discretize( + self, + support: Optional[Segment] = None, + resolution: Union[float, SlidingWindow] = 0.01, + labels: Optional[List[Hashable]] = None, + duration: Optional[float] = None, + ): + """Discretize + + Parameters + ---------- + support : Segment, optional + Part of annotation to discretize. + Defaults to annotation full extent. + resolution : float or SlidingWindow, optional + Defaults to 10ms frames. + labels : list of labels, optional + Defaults to self.labels() + duration : float, optional + Overrides support duration and ensures that the number of + returned frames is fixed (which might otherwise not be the case + because of rounding errors). + + Returns + ------- + discretized : SlidingWindowFeature + (num_frames, num_labels)-shaped binary features. + """ + + if support is None: + support = self.get_timeline().extent() + start_time, end_time = support + + cropped = self.crop(support, mode="intersection") + + if labels is None: + labels = cropped.labels() + + if isinstance(resolution, SlidingWindow): + resolution = SlidingWindow( + start=start_time, step=resolution.step, duration=resolution.duration + ) + else: + resolution = SlidingWindow( + start=start_time, step=resolution, duration=resolution + ) + + start_frame = resolution.closest_frame(start_time) + if duration is None: + end_frame = resolution.closest_frame(end_time) + num_frames = end_frame - start_frame + else: + num_frames = int(round(duration / resolution.step)) + + data = np.zeros((num_frames, len(labels)), dtype=np.uint8) + for k, label in enumerate(labels): + segments = cropped.label_timeline(label) + for start, stop in resolution.crop( + segments, mode="center", return_ranges=True + ): + data[max(0, start) : min(stop, num_frames), k] += 1 + data = np.minimum(data, 1, out=data) + + return SlidingWindowFeature(data, resolution, labels=labels) + + @classmethod + def from_records( + cls, + records: Iterator[Tuple[Segment, TrackName, Label]], + uri: Optional[str] = None, + modality: Optional[str] = None, + ) -> "Annotation": + """Annotation + + Parameters + ---------- + records : iterator of tuples + (segment, track, label) tuples + uri : string, optional + name of annotated resource (e.g. 
audio or video file) + modality : string, optional + name of annotated modality + + Returns + ------- + annotation : Annotation + New annotation + + """ + annotation = cls(uri=uri, modality=modality) + tracks = defaultdict(dict) + labels = set() + for segment, track, label in records: + tracks[segment][track] = label + labels.add(label) + annotation._tracks = SortedDict(tracks) + annotation._labels = {label: None for label in labels} + annotation._labelNeedsUpdate = {label: True for label in annotation._labels} + annotation._timeline = None + annotation._timelineNeedsUpdate = True + + return annotation + + def _repr_png_(self): + """IPython notebook support + + See also + -------- + :mod:`pyannote.core.notebook` + """ + from .notebook import MATPLOTLIB_IS_AVAILABLE, MATPLOTLIB_WARNING + + if not MATPLOTLIB_IS_AVAILABLE: + warnings.warn(MATPLOTLIB_WARNING.format(klass=self.__class__.__name__)) + return None + + from .notebook import repr_annotation + return repr_annotation(self) diff --git a/ailia-models/code/pyannote_audio_utils/core/feature.py b/ailia-models/code/pyannote_audio_utils/core/feature.py new file mode 100644 index 0000000000000000000000000000000000000000..253a87ea7b394131a9dd0891d3bc8bf459e27de0 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/core/feature.py @@ -0,0 +1,329 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2014-2019 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr + + +""" +######## +Features +######## + +See :class:`pyannote_audio_utils.core.SlidingWindowFeature` for the complete reference. +""" +import numbers +import warnings +from typing import Tuple, Optional, Union, Iterator, List, Text + +import numpy as np + +from pyannote_audio_utils.core.utils.types import Alignment +from .segment import Segment +from .segment import SlidingWindow +from .timeline import Timeline + + +class SlidingWindowFeature(np.lib.mixins.NDArrayOperatorsMixin): + """Periodic feature vectors + + Parameters + ---------- + data : (n_frames, n_features) numpy array + sliding_window : SlidingWindow + labels : list, optional + Textual description of each dimension. 
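+
+ Example
+ -------
+ A minimal usage sketch (illustrative, not part of the original
+ documentation): wrap a (n_frames, n_features) array together with the
+ sliding window describing its temporal support.
+
+ >>> import numpy as np
+ >>> from pyannote_audio_utils.core.segment import SlidingWindow
+ >>> from pyannote_audio_utils.core.feature import SlidingWindowFeature
+ >>> window = SlidingWindow(duration=0.025, step=0.010, start=0.0)
+ >>> features = SlidingWindowFeature(np.zeros((100, 59)), window)
+ >>> len(features)        # number of frames
+ 100
+ >>> features.dimension   # number of features per frame
+ 59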
+ """ + + def __init__( + self, data: np.ndarray, sliding_window: SlidingWindow, labels: List[Text] = None + ): + self.sliding_window: SlidingWindow = sliding_window + self.data = data + self.labels = labels + self.__i: int = -1 + + def __len__(self): + """Number of feature vectors""" + return self.data.shape[0] + + @property + def extent(self): + return self.sliding_window.range_to_segment(0, len(self)) + + @property + def dimension(self): + """Dimension of feature vectors""" + return self.data.shape[1] + + def getNumber(self): + warnings.warn("This is deprecated in favor of `__len__`", DeprecationWarning) + return self.data.shape[0] + + def getDimension(self): + warnings.warn( + "This is deprecated in favor of `dimension` property", DeprecationWarning + ) + return self.dimension + + def getExtent(self): + warnings.warn( + "This is deprecated in favor of `extent` property", DeprecationWarning + ) + return self.extent + + def __getitem__(self, i: int) -> np.ndarray: + """Get ith feature vector""" + return self.data[i] + + def __iter__(self): + self.__i = -1 + return self + + def __next__(self) -> Tuple[Segment, np.ndarray]: + self.__i += 1 + try: + return self.sliding_window[self.__i], self.data[self.__i] + except IndexError as e: + raise StopIteration() + + def next(self): + return self.__next__() + + def iterfeatures( + self, window: Optional[bool] = False + ) -> Iterator[Union[Tuple[np.ndarray, Segment], np.ndarray]]: + """Feature vector iterator + + Parameters + ---------- + window : bool, optional + When True, yield both feature vector and corresponding window. + Default is to only yield feature vector + + """ + n_samples = self.data.shape[0] + for i in range(n_samples): + if window: + yield self.data[i], self.sliding_window[i] + else: + yield self.data[i] + + def crop( + self, + focus: Union[Segment, Timeline], + mode: Alignment = "loose", + fixed: Optional[float] = None, + return_data: bool = True, + ) -> Union[np.ndarray, "SlidingWindowFeature"]: + """Extract frames + + Parameters + ---------- + focus : Segment or Timeline + mode : {'loose', 'strict', 'center'}, optional + In 'strict' mode, only frames fully included in 'focus' support are + returned. In 'loose' mode, any intersecting frames are returned. In + 'center' mode, first and last frames are chosen to be the ones + whose centers are the closest to 'focus' start and end times. + Defaults to 'loose'. + fixed : float, optional + Overrides `Segment` 'focus' duration and ensures that the number of + returned frames is fixed (which might otherwise not be the case + because of rounding errors). + return_data : bool, optional + Return a numpy array (default). For `Segment` 'focus', setting it + to False will return a `SlidingWindowFeature` instance. + + Returns + ------- + data : `numpy.ndarray` or `SlidingWindowFeature` + Frame features. + + See also + -------- + SlidingWindow.crop + + """ + + if (not return_data) and (not isinstance(focus, Segment)): + msg = ( + '"focus" must be a "Segment" instance when "return_data"' + "is set to False." + ) + raise ValueError(msg) + + if (not return_data) and (fixed is not None): + msg = '"fixed" cannot be set when "return_data" is set to False.' + raise ValueError(msg) + + ranges = self.sliding_window.crop( + focus, mode=mode, fixed=fixed, return_ranges=True + ) + + # total number of samples in features + n_samples = self.data.shape[0] + + # 1 for vector features (e.g. MFCC in pyannote_audio_utils.audio) + # 2 for matrix features (e.g. 
grey-level frames in pyannote_audio_utils.video) + # 3 for 3rd order tensor (e.g. RBG frames in pyannote_audio_utils.video) + n_dimensions = len(self.data.shape) - 1 + + # clip ranges + clipped_ranges, repeat_first, repeat_last = [], 0, 0 + for start, end in ranges: + # count number of requested samples before first sample + repeat_first += min(end, 0) - min(start, 0) + # count number of requested samples after last sample + repeat_last += max(end, n_samples) - max(start, n_samples) + # if all requested samples are out of bounds, skip + if end < 0 or start >= n_samples: + continue + else: + # keep track of non-empty clipped ranges + clipped_ranges += [[max(start, 0), min(end, n_samples)]] + + if clipped_ranges: + data = np.vstack([self.data[start:end, :] for start, end in clipped_ranges]) + else: + # if all ranges are out of bounds, just return empty data + shape = (0,) + self.data.shape[1:] + data = np.empty(shape) + + # corner case when "fixed" duration cropping is requested: + # correct number of samples even with out-of-bounds indices + if fixed is not None: + data = np.vstack( + [ + # repeat first sample as many times as needed + np.tile(self.data[0], (repeat_first,) + (1,) * n_dimensions), + data, + # repeat last sample as many times as needed + np.tile( + self.data[n_samples - 1], (repeat_last,) + (1,) * n_dimensions + ), + ] + ) + + # return data + if return_data: + return data + + # wrap data in a SlidingWindowFeature and return + sliding_window = SlidingWindow( + start=self.sliding_window[clipped_ranges[0][0]].start, + duration=self.sliding_window.duration, + step=self.sliding_window.step, + ) + + return SlidingWindowFeature(data, sliding_window, labels=self.labels) + + def _repr_png_(self): + from .notebook import MATPLOTLIB_IS_AVAILABLE, MATPLOTLIB_WARNING + + if not MATPLOTLIB_IS_AVAILABLE: + warnings.warn(MATPLOTLIB_WARNING.format(klass=self.__class__.__name__)) + return None + + from .notebook import repr_feature + + return repr_feature(self) + + _HANDLED_TYPES = (np.ndarray, numbers.Number) + + def __array__(self) -> np.ndarray: + return self.data + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + out = kwargs.get("out", ()) + for x in inputs + out: + # Only support operations with instances of _HANDLED_TYPES. + # Use SlidingWindowFeature instead of type(self) for isinstance to + # allow subclasses that don't override __array_ufunc__ to + # handle SlidingWindowFeature objects. + if not isinstance(x, self._HANDLED_TYPES + (SlidingWindowFeature,)): + return NotImplemented + + # Defer to the implementation of the ufunc on unwrapped values. + inputs = tuple( + x.data if isinstance(x, SlidingWindowFeature) else x for x in inputs + ) + if out: + kwargs["out"] = tuple( + x.data if isinstance(x, SlidingWindowFeature) else x for x in out + ) + data = getattr(ufunc, method)(*inputs, **kwargs) + + if type(data) is tuple: + # multiple return values + return tuple( + type(self)(x, self.sliding_window, labels=self.labels) for x in data + ) + elif method == "at": + # no return value + return None + else: + # one return value + return type(self)(data, self.sliding_window, labels=self.labels) + + def align(self, to: "SlidingWindowFeature") -> "SlidingWindowFeature": + """Align features by linear temporal interpolation + + Parameters + ---------- + to : SlidingWindowFeature + Features to align with. 
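+ Only its sliding window (start, step and duration) and its number of
+ frames are used; its actual feature values are ignored.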
+ + Returns + ------- + aligned : SlidingWindowFeature + Aligned features + """ + + old_start = self.sliding_window.start + old_step = self.sliding_window.step + old_duration = self.sliding_window.duration + old_samples = len(self) + old_t = old_start + 0.5 * old_duration + np.arange(old_samples) * old_step + + new_start = to.sliding_window.start + new_step = to.sliding_window.step + new_duration = to.sliding_window.duration + new_samples = len(to) + new_t = new_start + 0.5 * new_duration + np.arange(new_samples) * new_step + + new_data = np.hstack( + [ + np.interp(new_t, old_t, old_data)[:, np.newaxis] + for old_data in self.data.T + ] + ) + return SlidingWindowFeature(new_data, to.sliding_window, labels=self.labels) + + +if __name__ == "__main__": + import doctest + + doctest.testmod() diff --git a/ailia-models/code/pyannote_audio_utils/core/notebook.py b/ailia-models/code/pyannote_audio_utils/core/notebook.py new file mode 100644 index 0000000000000000000000000000000000000000..96a0a55bdba2e3a3b4f361b8618534daa6038f18 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/core/notebook.py @@ -0,0 +1,468 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2014-2019 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr + +""" +############# +Visualization +############# + +:class:`pyannote.core.Segment`, :class:`pyannote.core.Timeline`, +:class:`pyannote.core.Annotation` and :class:`pyannote.core.SlidingWindowFeature` +instances can be directly visualized in notebooks. + +You will however need to install ``pytannote.core``'s additional dependencies +for notebook representations (namely, matplotlib): + + +.. code-block:: bash + + pip install pyannote.core[notebook] + + +Segments +-------- + +.. code-block:: ipython + + In [1]: from pyannote.core import Segment + + In [2]: segment = Segment(start=5, end=15) + ....: segment + +.. plot:: pyplots/segment.py + + +Timelines +--------- + +.. code-block:: ipython + + In [25]: from pyannote.core import Timeline, Segment + + In [26]: timeline = Timeline() + ....: timeline.add(Segment(1, 5)) + ....: timeline.add(Segment(6, 8)) + ....: timeline.add(Segment(12, 18)) + ....: timeline.add(Segment(7, 20)) + ....: timeline + +.. plot:: pyplots/timeline.py + + +Annotations +----------- + + +.. 
code-block:: ipython + + In [1]: from pyannote.core import Annotation, Segment + + In [6]: annotation = Annotation() + ...: annotation[Segment(1, 5)] = 'Carol' + ...: annotation[Segment(6, 8)] = 'Bob' + ...: annotation[Segment(12, 18)] = 'Carol' + ...: annotation[Segment(7, 20)] = 'Alice' + ...: annotation + +.. plot:: pyplots/annotation.py + +""" +from typing import Iterable, Dict, Optional + +from .utils.types import Label, LabelStyle, Resource + +# try: + # from IPython.core.pylabtools import print_figure +# except Exception as e: + # pass +import numpy as np +from itertools import cycle, product, groupby +from .segment import Segment, SlidingWindow +from .timeline import Timeline +from .annotation import Annotation +from .feature import SlidingWindowFeature + +try: + import matplotlib +except ImportError: + MATPLOTLIB_IS_AVAILABLE = False +else: + MATPLOTLIB_IS_AVAILABLE = True + +MATPLOTLIB_WARNING = ( + "Couldn't import matplotlib to render the vizualization " + "for object {klass}. To enable, install the required dependencies " + "with 'pip install pyannore.core[notebook]'" +) + + +class Notebook: + def __init__(self): + self.reset() + + def reset(self): + from matplotlib.cm import get_cmap + + linewidth = [3, 1] + linestyle = ["solid", "dashed", "dotted"] + + cm = get_cmap("Set1") + colors = [cm(1.0 * i / 8) for i in range(9)] + + self._style_generator = cycle(product(linestyle, linewidth, colors)) + self._style: Dict[Optional[Label], LabelStyle] = { + None: ("solid", 1, (0.0, 0.0, 0.0)) + } + del self.crop + del self.width + + @property + def crop(self): + """The crop property.""" + return self._crop + + @crop.setter + def crop(self, segment: Segment): + self._crop = segment + + @crop.deleter + def crop(self): + self._crop = None + + @property + def width(self): + """The width property""" + return self._width + + @width.setter + def width(self, value: int): + self._width = value + + @width.deleter + def width(self): + self._width = 20 + + def __getitem__(self, label: Label) -> LabelStyle: + """Get line style for a given label""" + if label not in self._style: + self._style[label] = next(self._style_generator) + return self._style[label] + + def setup(self, ax=None, ylim=(0, 1), yaxis=False, time=True): + import matplotlib.pyplot as plt + + if ax is None: + ax = plt.gca() + ax.set_xlim(self.crop) + if time: + ax.set_xlabel("Time") + else: + ax.set_xticklabels([]) + ax.set_ylim(ylim) + ax.axes.get_yaxis().set_visible(yaxis) + return ax + + def draw_segment(self, ax, segment: Segment, y, label=None, boundaries=True): + + # do nothing if segment is empty + if not segment: + return + + linestyle, linewidth, color = self[label] + + # draw segment + ax.hlines( + y, + segment.start, + segment.end, + color, + linewidth=linewidth, + linestyle=linestyle, + label=label, + ) + if boundaries: + ax.vlines( + segment.start, y + 0.05, y - 0.05, color, linewidth=1, linestyle="solid" + ) + ax.vlines( + segment.end, y + 0.05, y - 0.05, color, linewidth=1, linestyle="solid" + ) + + if label is None: + return + + def get_y(self, segments: Iterable[Segment]) -> np.ndarray: + """ + + Parameters + ---------- + segments : Iterable + `Segment` iterable (sorted) + + Returns + ------- + y : np.array + y coordinates of each segment + + """ + + # up_to stores the largest end time + # displayed in each line (at the current iteration) + # (at the beginning, there is only one empty line) + up_to = [-np.inf] + + # y[k] indicates on which line to display kth segment + y = [] + + for segment in segments: + # so far, 
we do not know which line to use + found = False + # try each line until we find one that is ok + for i, u in enumerate(up_to): + # if segment starts after the previous one + # on the same line, then we add it to the line + if segment.start >= u: + found = True + y.append(i) + up_to[i] = segment.end + break + # in case we went out of lines, create a new one + if not found: + y.append(len(up_to)) + up_to.append(segment.end) + + # from line numbers to actual y coordinates + y = 1.0 - 1.0 / (len(up_to) + 1) * (1 + np.array(y)) + + return y + + def __call__(self, resource: Resource, time: bool = True, legend: bool = True): + + if isinstance(resource, Segment): + self.plot_segment(resource, time=time) + + elif isinstance(resource, Timeline): + self.plot_timeline(resource, time=time) + + elif isinstance(resource, Annotation): + self.plot_annotation(resource, time=time, legend=legend) + + elif isinstance(resource, SlidingWindowFeature): + self.plot_feature(resource, time=time) + + def plot_segment(self, segment, ax=None, time=True): + + if not self.crop: + self.crop = segment + + ax = self.setup(ax=ax, time=time) + self.draw_segment(ax, segment, 0.5) + + def plot_timeline(self, timeline: Timeline, ax=None, time=True): + + if not self.crop and timeline: + self.crop = timeline.extent() + + cropped = timeline.crop(self.crop, mode="loose") + + ax = self.setup(ax=ax, time=time) + + for segment, y in zip(cropped, self.get_y(cropped)): + self.draw_segment(ax, segment, y) + + # ax.set_aspect(3. / self.crop.duration) + + def plot_annotation(self, annotation: Annotation, ax=None, time=True, legend=True): + + if not self.crop: + self.crop = annotation.get_timeline(copy=False).extent() + + cropped = annotation.crop(self.crop, mode="intersection") + labels = cropped.labels() + segments = [s for s, _ in cropped.itertracks()] + + ax = self.setup(ax=ax, time=time) + + for (segment, track, label), y in zip( + cropped.itertracks(yield_label=True), self.get_y(segments) + ): + self.draw_segment(ax, segment, y, label=label) + + if legend: + H, L = ax.get_legend_handles_labels() + + # corner case when no segment is visible + if not H: + return + + # this gets exactly one legend handle and one legend label per label + # (avoids repeated legends for repeated tracks with same label) + HL = groupby( + sorted(zip(H, L), key=lambda h_l: h_l[1]), key=lambda h_l: h_l[1] + ) + H, L = zip(*list((next(h_l)[0], l) for l, h_l in HL)) + ax.legend( + H, + L, + bbox_to_anchor=(0, 1), + loc=3, + ncol=5, + borderaxespad=0.0, + frameon=False, + ) + + def plot_feature( + self, feature: SlidingWindowFeature, ax=None, time=True, ylim=None + ): + + if not self.crop: + self.crop = feature.getExtent() + + window = feature.sliding_window + n, dimension = feature.data.shape + ((start, stop),) = window.crop(self.crop, mode="loose", return_ranges=True) + xlim = (window[start].middle, window[stop].middle) + + start = max(0, start) + stop = min(stop, n) + t = window[0].middle + window.step * np.arange(start, stop) + data = feature[start:stop] + + if ylim is None: + m = np.nanmin(data) + M = np.nanmax(data) + ylim = (m - 0.1 * (M - m), M + 0.1 * (M - m)) + + ax = self.setup(ax=ax, yaxis=False, ylim=ylim, time=time) + ax.plot(t, data) + ax.set_xlim(xlim) + + +notebook = Notebook() + +def repr_segment(segment: Segment): + """Get `png` data for `segment`""" + import matplotlib.pyplot as plt + + figsize = plt.rcParams["figure.figsize"] + plt.rcParams["figure.figsize"] = (notebook.width, 1) + fig, ax = plt.subplots() + notebook.plot_segment(segment, 
ax=ax) + # data = print_figure(fig, "png") + plt.savefig('./output') + plt.close(fig) + plt.rcParams["figure.figsize"] = figsize + return + + +def repr_timeline(timeline: Timeline): + """Get `png` data for `timeline`""" + import matplotlib.pyplot as plt + breakpoint() + figsize = plt.rcParams["figure.figsize"] + plt.rcParams["figure.figsize"] = (notebook.width, 1) + fig, ax = plt.subplots() + notebook.plot_timeline(timeline, ax=ax) + # data = print_figure(fig, "png") + plt.savefig('./output') + plt.cla(fig) + plt.rcParams["figure.figsize"] = figsize + return + + +def repr_annotation(annotation: Annotation): + """Get `png` data for `annotation`""" + import matplotlib.pyplot as plt + + figsize = plt.rcParams["figure.figsize"] + plt.rcParams["figure.figsize"] = (notebook.width, 2) + fig, ax = plt.subplots() + notebook.plot_annotation(annotation, ax=ax) + # data = print_figure(fig, "png") + plt.savefig('./output') + plt.close(fig) + plt.rcParams["figure.figsize"] = figsize + return + + +def repr_feature(feature: SlidingWindowFeature): + """Get `png` data for `feature`""" + import matplotlib.pyplot as plt + + figsize = plt.rcParams["figure.figsize"] + + if feature.data.ndim == 2: + + plt.rcParams["figure.figsize"] = (notebook.width, 2) + fig, ax = plt.subplots() + notebook.plot_feature(feature, ax=ax) + # data = print_figure(fig, "png") + plt.savefig('./output') + plt.close(fig) + + elif feature.data.ndim == 3: + + num_chunks = len(feature) + + if notebook.crop is None: + notebook.crop = Segment( + start=feature.sliding_window.start, + end=feature.sliding_window[num_chunks - 1].end, + ) + else: + feature = feature.crop(notebook.crop, mode="loose", return_data=False) + + num_overlap = ( + round(feature.sliding_window.duration // feature.sliding_window.step) + 1 + ) + + num_overlap = min(num_chunks, num_overlap) + + plt.rcParams["figure.figsize"] = (notebook.width, 1.5 * num_overlap) + + fig, axes = plt.subplots(nrows=num_overlap, ncols=1,) + mini, maxi = np.nanmin(feature.data), np.nanmax(feature.data) + ylim = (mini - 0.2 * (maxi - mini), maxi + 0.2 * (maxi - mini)) + for c, (window, data) in enumerate(feature): + ax = axes[c % num_overlap] + step = duration = window.duration / len(data) + frames = SlidingWindow(start=window.start, step=step, duration=duration) + window_feature = SlidingWindowFeature(data, frames, labels=feature.labels) + notebook.plot_feature( + window_feature, + ax=ax, + time=c % num_overlap == (num_overlap - 1), + ylim=ylim, + ) + ax.set_prop_cycle(None) + # data = print_figure(fig, "png") + plt.savefig('./output') + plt.close(fig) + + plt.rcParams["figure.figsize"] = figsize + return diff --git a/ailia-models/code/pyannote_audio_utils/core/segment.py b/ailia-models/code/pyannote_audio_utils/core/segment.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb5fdf98b1fa1fb9d08e137a5b59202212059e5 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/core/segment.py @@ -0,0 +1,910 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2014-2021 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above 
copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr + +""" +####### +Segment +####### + +.. plot:: pyplots/segment.py + +:class:`pyannote.core.Segment` instances describe temporal fragments (*e.g.* of an audio file). The segment depicted above can be defined like that: + +.. code-block:: ipython + + In [1]: from pyannote.core import Segment + + In [2]: segment = Segment(start=5, end=15) + + In [3]: print(segment) + +It is nothing more than 2-tuples augmented with several useful methods and properties: + +.. code-block:: ipython + + In [4]: start, end = segment + + In [5]: start + + In [6]: segment.end + + In [7]: segment.duration # duration (read-only) + + In [8]: segment.middle # middle (read-only) + + In [9]: segment & Segment(3, 12) # intersection + + In [10]: segment | Segment(3, 12) # union + + In [11]: segment.overlaps(3) # does segment overlap time t=3? + + +Use `Segment.set_precision(ndigits)` to automatically round start and end timestamps to `ndigits` precision after the decimal point. +To ensure consistency between `Segment` instances, it is recommended to call this method only once, right after importing `pyannote.core.Segment`. + +.. code-block:: ipython + + In [12]: Segment(1/1000, 330/1000) == Segment(1/1000, 90/1000+240/1000) + Out[12]: False + + In [13]: Segment.set_precision(ndigits=4) + + In [14]: Segment(1/1000, 330/1000) == Segment(1/1000, 90/1000+240/1000) + Out[14]: True + +See :class:`pyannote.core.Segment` for the complete reference. +""" + +import warnings +from typing import Union, Optional, Tuple, List, Iterator, Iterable + +from .utils.types import Alignment + +import numpy as np +from dataclasses import dataclass + + +# setting 'frozen' to True makes it hashable and immutable +@dataclass(frozen=True, order=True) +class Segment: + """ + Time interval + + Parameters + ---------- + start : float + interval start time, in seconds. + end : float + interval end time, in seconds. + + + Segments can be compared and sorted using the standard operators: + + >>> Segment(0, 1) == Segment(0, 1.) + True + >>> Segment(0, 1) != Segment(3, 4) + True + >>> Segment(0, 1) < Segment(2, 3) + True + >>> Segment(0, 1) < Segment(0, 2) + True + >>> Segment(1, 2) < Segment(0, 3) + False + + Note + ---- + A segment is smaller than another segment if one of these two conditions is verified: + + - `segment.start < other_segment.start` + - `segment.start == other_segment.start` and `segment.end < other_segment.end` + + """ + start: float = 0.0 + end: float = 0.0 + + @staticmethod + def set_precision(ndigits: Optional[int] = None): + """Automatically round start and end timestamps to `ndigits` precision after the decimal point + + To ensure consistency between `Segment` instances, it is recommended to call this method only + once, right after importing `pyannote.core.Segment`. 
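+ Calling it again with `ndigits=None` disables automatic rounding and
+ restores the default 1 microsecond precision used for emptiness tests.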
+ + Usage + ----- + >>> from pyannote.core import Segment + >>> Segment.set_precision(2) + >>> Segment(1/3, 2/3) + + """ + global AUTO_ROUND_TIME + global SEGMENT_PRECISION + + if ndigits is None: + # backward compatibility + AUTO_ROUND_TIME = False + # 1 μs (one microsecond) + SEGMENT_PRECISION = 1e-6 + else: + AUTO_ROUND_TIME = True + SEGMENT_PRECISION = 10 ** (-ndigits) + + def __bool__(self): + """Emptiness + + >>> if segment: + ... # segment is not empty. + ... else: + ... # segment is empty. + + Note + ---- + A segment is considered empty if its end time is smaller than its + start time, or its duration is smaller than 1μs. + """ + return bool((self.end - self.start) > SEGMENT_PRECISION) + + def __post_init__(self): + """Round start and end up to SEGMENT_PRECISION precision (when required)""" + if AUTO_ROUND_TIME: + object.__setattr__(self, 'start', int(self.start / SEGMENT_PRECISION + 0.5) * SEGMENT_PRECISION) + object.__setattr__(self, 'end', int(self.end / SEGMENT_PRECISION + 0.5) * SEGMENT_PRECISION) + + @property + def duration(self) -> float: + """Segment duration (read-only)""" + return self.end - self.start if self else 0. + + @property + def middle(self) -> float: + """Segment mid-time (read-only)""" + return .5 * (self.start + self.end) + + def __iter__(self) -> Iterator[float]: + """Unpack segment boundaries + >>> segment = Segment(start, end) + >>> start, end = segment + """ + yield self.start + yield self.end + + def copy(self) -> 'Segment': + """Get a copy of the segment + + Returns + ------- + copy : Segment + Copy of the segment. + """ + return Segment(start=self.start, end=self.end) + + # ------------------------------------------------------- # + # Inclusion (in), intersection (&), union (|) and gap (^) # + # ------------------------------------------------------- # + + def __contains__(self, other: 'Segment'): + """Inclusion + + >>> segment = Segment(start=0, end=10) + >>> Segment(start=3, end=10) in segment: + True + >>> Segment(start=5, end=15) in segment: + False + """ + return (self.start <= other.start) and (self.end >= other.end) + + def __and__(self, other): + """Intersection + + >>> segment = Segment(0, 10) + >>> other_segment = Segment(5, 15) + >>> segment & other_segment + + + Note + ---- + When the intersection is empty, an empty segment is returned: + + >>> segment = Segment(0, 10) + >>> other_segment = Segment(15, 20) + >>> intersection = segment & other_segment + >>> if not intersection: + ... # intersection is empty. + """ + start = max(self.start, other.start) + end = min(self.end, other.end) + return Segment(start=start, end=end) + + def intersects(self, other: 'Segment') -> bool: + """Check whether two segments intersect each other + + Parameters + ---------- + other : Segment + Other segment + + Returns + ------- + intersect : bool + True if segments intersect, False otherwise + """ + + return (self.start < other.start and + other.start < self.end - SEGMENT_PRECISION) or \ + (self.start > other.start and + self.start < other.end - SEGMENT_PRECISION) or \ + (self.start == other.start) + + def overlaps(self, t: float) -> bool: + """Check if segment overlaps a given time + + Parameters + ---------- + t : float + Time, in seconds. + + Returns + ------- + overlap: bool + True if segment overlaps time t, False otherwise. 
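+
+ Example
+ -------
+ A short illustrative check (added for clarity, not part of the
+ original docstring):
+
+ >>> Segment(0, 10).overlaps(5)
+ True
+ >>> Segment(0, 10).overlaps(12)
+ False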
+ """ + return self.start <= t and self.end >= t + + def __or__(self, other: 'Segment') -> 'Segment': + """Union + + >>> segment = Segment(0, 10) + >>> other_segment = Segment(5, 15) + >>> segment | other_segment + + + Note + ---- + When a gap exists between the segment, their union covers the gap as well: + + >>> segment = Segment(0, 10) + >>> other_segment = Segment(15, 20) + >>> segment | other_segment + 'Segment': + """Gap + + >>> segment = Segment(0, 10) + >>> other_segment = Segment(15, 20) + >>> segment ^ other_segment + >> segment = Segment(0, 10) + >>> empty_segment = Segment(11, 11) + >>> segment ^ empty_segment + ValueError: The gap between a segment and an empty segment is not defined. + """ + + # if segment is empty, xor is not defined + if (not self) or (not other): + raise ValueError( + 'The gap between a segment and an empty segment ' + 'is not defined.') + + start = min(self.end, other.end) + end = max(self.start, other.start) + return Segment(start=start, end=end) + + def _str_helper(self, seconds: float) -> str: + from datetime import timedelta + negative = seconds < 0 + seconds = abs(seconds) + td = timedelta(seconds=seconds) + seconds = td.seconds + 86400 * td.days + microseconds = td.microseconds + hours, remainder = divmod(seconds, 3600) + minutes, seconds = divmod(remainder, 60) + return '%s%02d:%02d:%02d.%03d' % ( + '-' if negative else ' ', hours, minutes, + seconds, microseconds / 1000) + + def __str__(self): + """Human-readable representation + + >>> print(Segment(1337, 1337 + 0.42)) + [ 00:22:17.000 --> 00:22:17.420] + + Note + ---- + Empty segments are printed as "[]" + """ + if self: + return '[%s --> %s]' % (self._str_helper(self.start), + self._str_helper(self.end)) + return '[]' + + def __repr__(self): + """Computer-readable representation + + >>> Segment(1337, 1337 + 0.42) + + """ + return '' % (self.start, self.end) + + def _repr_png_(self): + """IPython notebook support + + See also + -------- + :mod:`pyannote.core.notebook` + """ + from .notebook import MATPLOTLIB_IS_AVAILABLE, MATPLOTLIB_WARNING + if not MATPLOTLIB_IS_AVAILABLE: + warnings.warn(MATPLOTLIB_WARNING.format(klass=self.__class__.__name__)) + return None + + from .notebook import repr_segment + try: + return repr_segment(self) + except ImportError: + warnings.warn( + f"Couldn't import matplotlib to render the vizualization for object {self}. To enable, install the required dependencies with 'pip install pyannore.core[notebook]'") + return None + + +class SlidingWindow: + """Sliding window + + Parameters + ---------- + duration : float > 0, optional + Window duration, in seconds. Default is 30 ms. + step : float > 0, optional + Step between two consecutive position, in seconds. Default is 10 ms. + start : float, optional + First start position of window, in seconds. Default is 0. + end : float > `start`, optional + Default is infinity (ie. window keeps sliding forever) + + Examples + -------- + + >>> sw = SlidingWindow(duration, step, start) + >>> frame_range = (a, b) + >>> frame_range == sw.toFrameRange(sw.toSegment(*frame_range)) + ... 
True + + >>> segment = Segment(A, B) + >>> new_segment = sw.toSegment(*sw.toFrameRange(segment)) + >>> abs(segment) - abs(segment & new_segment) < .5 * sw.step + + >>> sw = SlidingWindow(end=0.1) + >>> print(next(sw)) + [ 00:00:00.000 --> 00:00:00.030] + >>> print(next(sw)) + [ 00:00:00.010 --> 00:00:00.040] + """ + + def __init__(self, duration=0.030, step=0.010, start=0.000, end=None): + + # duration must be a float > 0 + if duration <= 0: + raise ValueError("'duration' must be a float > 0.") + self.__duration = duration + + # step must be a float > 0 + if step <= 0: + raise ValueError("'step' must be a float > 0.") + self.__step: float = step + + # start must be a float. + self.__start: float = start + + # if end is not provided, set it to infinity + if end is None: + self.__end: float = np.inf + else: + # end must be greater than start + if end <= start: + raise ValueError("'end' must be greater than 'start'.") + self.__end: float = end + + # current index of iterator + self.__i: int = -1 + + @property + def start(self) -> float: + """Sliding window start time in seconds.""" + return self.__start + + @property + def end(self) -> float: + """Sliding window end time in seconds.""" + return self.__end + + @property + def step(self) -> float: + """Sliding window step in seconds.""" + return self.__step + + @property + def duration(self) -> float: + """Sliding window duration in seconds.""" + return self.__duration + + def closest_frame(self, t: float) -> int: + """Closest frame to timestamp. + + Parameters + ---------- + t : float + Timestamp, in seconds. + + Returns + ------- + index : int + Index of frame whose middle is the closest to `timestamp` + + """ + return int(np.rint( + (t - self.__start - .5 * self.__duration) / self.__step + )) + + def samples(self, from_duration: float, mode: Alignment = 'strict') -> int: + """Number of frames + + Parameters + ---------- + from_duration : float + Duration in seconds. + mode : {'strict', 'loose', 'center'} + In 'strict' mode, computes the maximum number of consecutive frames + that can be fitted into a segment with duration `from_duration`. + In 'loose' mode, computes the maximum number of consecutive frames + intersecting a segment with duration `from_duration`. + In 'center' mode, computes the average number of consecutive frames + where the first one is centered on the start time and the last one + is centered on the end time of a segment with duration + `from_duration`. + + """ + if mode == 'strict': + return int(np.floor((from_duration - self.duration) / self.step)) + 1 + + elif mode == 'loose': + return int(np.floor((from_duration + self.duration) / self.step)) + + elif mode == 'center': + return int(np.rint((from_duration / self.step))) + + def crop(self, focus: Union[Segment, 'Timeline'], + mode: Alignment = 'loose', + fixed: Optional[float] = None, + return_ranges: Optional[bool] = False) -> \ + Union[np.ndarray, List[List[int]]]: + """Crop sliding window + + Parameters + ---------- + focus : `Segment` or `Timeline` + mode : {'strict', 'loose', 'center'}, optional + In 'strict' mode, only indices of segments fully included in + 'focus' support are returned. In 'loose' mode, indices of any + intersecting segments are returned. In 'center' mode, first and + last positions are chosen to be the positions whose centers are the + closest to 'focus' start and end times. Defaults to 'loose'. 
+ fixed : float, optional + Overrides `Segment` 'focus' duration and ensures that the number of + returned frames is fixed (which might otherwise not be the case + because of rounding erros). + return_ranges : bool, optional + Return as list of ranges. Defaults to indices numpy array. + + Returns + ------- + indices : np.array (or list of ranges) + Array of unique indices of matching segments + """ + + from .timeline import Timeline + + if not isinstance(focus, (Segment, Timeline)): + msg = '"focus" must be a `Segment` or `Timeline` instance.' + raise TypeError(msg) + + if isinstance(focus, Timeline): + + if fixed is not None: + msg = "'fixed' is not supported with `Timeline` 'focus'." + raise ValueError(msg) + + if return_ranges: + ranges = [] + + for i, s in enumerate(focus.support()): + rng = self.crop(s, mode=mode, fixed=fixed, + return_ranges=True) + + # if first or disjoint segment, add it + if i == 0 or rng[0][0] > ranges[-1][1]: + ranges += rng + + # if overlapping segment, update last range + else: + ranges[-1][1] = rng[0][1] + + return ranges + + # concatenate all indices + indices = np.hstack([ + self.crop(s, mode=mode, fixed=fixed, return_ranges=False) + for s in focus.support()]) + + # remove duplicate indices + return np.unique(indices) + + # 'focus' is a `Segment` instance + + if mode == 'loose': + + # find smallest integer i such that + # self.start + i x self.step + self.duration >= focus.start + i_ = (focus.start - self.duration - self.start) / self.step + i = int(np.ceil(i_)) + + if fixed is None: + # find largest integer j such that + # self.start + j x self.step <= focus.end + j_ = (focus.end - self.start) / self.step + j = int(np.floor(j_)) + rng = (i, j + 1) + + else: + n = self.samples(fixed, mode='loose') + rng = (i, i + n) + + elif mode == 'strict': + + # find smallest integer i such that + # self.start + i x self.step >= focus.start + i_ = (focus.start - self.start) / self.step + i = int(np.ceil(i_)) + + if fixed is None: + + # find largest integer j such that + # self.start + j x self.step + self.duration <= focus.end + j_ = (focus.end - self.duration - self.start) / self.step + j = int(np.floor(j_)) + rng = (i, j + 1) + + else: + n = self.samples(fixed, mode='strict') + rng = (i, i + n) + + elif mode == 'center': + + # find window position whose center is the closest to focus.start + i = self.closest_frame(focus.start) + + if fixed is None: + # find window position whose center is the closest to focus.end + j = self.closest_frame(focus.end) + rng = (i, j + 1) + else: + n = self.samples(fixed, mode='center') + rng = (i, i + n) + + else: + msg = "'mode' must be one of {'loose', 'strict', 'center'}." 
+ raise ValueError(msg) + + if return_ranges: + return [list(rng)] + + return np.array(range(*rng), dtype=np.int64) + + def segmentToRange(self, segment: Segment) -> Tuple[int, int]: + warnings.warn("Deprecated in favor of `segment_to_range`", + DeprecationWarning) + return self.segment_to_range(segment) + + def segment_to_range(self, segment: Segment) -> Tuple[int, int]: + """Convert segment to 0-indexed frame range + + Parameters + ---------- + segment : Segment + + Returns + ------- + i0 : int + Index of first frame + n : int + Number of frames + + Examples + -------- + + >>> window = SlidingWindow() + >>> print window.segment_to_range(Segment(10, 15)) + i0, n + + """ + # find closest frame to segment start + i0 = self.closest_frame(segment.start) + + # number of steps to cover segment duration + n = int(segment.duration / self.step) + 1 + + return i0, n + + def rangeToSegment(self, i0: int, n: int) -> Segment: + warnings.warn("This is deprecated in favor of `range_to_segment`", + DeprecationWarning) + return self.range_to_segment(i0, n) + + def range_to_segment(self, i0: int, n: int) -> Segment: + """Convert 0-indexed frame range to segment + + Each frame represents a unique segment of duration 'step', centered on + the middle of the frame. + + The very first frame (i0 = 0) is the exception. It is extended to the + sliding window start time. + + Parameters + ---------- + i0 : int + Index of first frame + n : int + Number of frames + + Returns + ------- + segment : Segment + + Examples + -------- + + >>> window = SlidingWindow() + >>> print window.range_to_segment(3, 2) + [ --> ] + + """ + + # frame start time + # start = self.start + i0 * self.step + # frame middle time + # start += .5 * self.duration + # subframe start time + # start -= .5 * self.step + start = self.__start + (i0 - .5) * self.__step + .5 * self.__duration + duration = n * self.__step + end = start + duration + + # extend segment to the beginning of the timeline + if i0 == 0: + start = self.start + + return Segment(start, end) + + def samplesToDuration(self, nSamples: int) -> float: + warnings.warn("This is deprecated in favor of `samples_to_duration`", + DeprecationWarning) + return self.samples_to_duration(nSamples) + + def samples_to_duration(self, n_samples: int) -> float: + """Returns duration of samples""" + return self.range_to_segment(0, n_samples).duration + + def durationToSamples(self, duration: float) -> int: + warnings.warn("This is deprecated in favor of `duration_to_samples`", + DeprecationWarning) + return self.duration_to_samples(duration) + + def duration_to_samples(self, duration: float) -> int: + """Returns samples in duration""" + return self.segment_to_range(Segment(0, duration))[1] + + def __getitem__(self, i: int) -> Segment: + """ + Parameters + ---------- + i : int + Index of sliding window position + + Returns + ------- + segment : :class:`Segment` + Sliding window at ith position + + """ + + # window start time at ith position + start = self.__start + i * self.__step + + # in case segment starts after the end, + # return an empty segment + if start >= self.__end: + return None + + return Segment(start=start, end=start + self.__duration) + + def next(self) -> Segment: + return self.__next__() + + def __next__(self) -> Segment: + self.__i += 1 + window = self[self.__i] + + if window: + return window + else: + raise StopIteration() + + def __iter__(self) -> 'SlidingWindow': + """Sliding window iterator + + Use expression 'for segment in sliding_window' + + Examples + -------- + + >>> window = 
SlidingWindow(end=0.1) + >>> for segment in window: + ... print(segment) + [ 00:00:00.000 --> 00:00:00.030] + [ 00:00:00.010 --> 00:00:00.040] + [ 00:00:00.020 --> 00:00:00.050] + [ 00:00:00.030 --> 00:00:00.060] + [ 00:00:00.040 --> 00:00:00.070] + [ 00:00:00.050 --> 00:00:00.080] + [ 00:00:00.060 --> 00:00:00.090] + [ 00:00:00.070 --> 00:00:00.100] + [ 00:00:00.080 --> 00:00:00.110] + [ 00:00:00.090 --> 00:00:00.120] + """ + + # reset iterator index + self.__i = -1 + return self + + def __len__(self) -> int: + """Number of positions + + Equivalent to len([segment for segment in window]) + + Returns + ------- + length : int + Number of positions taken by the sliding window + (from start times to end times) + + """ + if np.isinf(self.__end): + raise ValueError('infinite sliding window.') + + # start looking for last position + # based on frame closest to the end + i = self.closest_frame(self.__end) + + while (self[i]): + i += 1 + length = i + + return length + + def copy(self) -> 'SlidingWindow': + """Duplicate sliding window""" + duration = self.duration + step = self.step + start = self.start + end = self.end + sliding_window = self.__class__( + duration=duration, step=step, start=start, end=end + ) + return sliding_window + + def __call__(self, + support: Union[Segment, 'Timeline'], + align_last: bool = False) -> Iterable[Segment]: + """Slide window over support + + Parameter + --------- + support : Segment or Timeline + Support on which to slide the window. + align_last : bool, optional + Yield a final segment so that it aligns exactly with end of support. + + Yields + ------ + chunk : Segment + + Example + ------- + >>> window = SlidingWindow(duration=2., step=1.) + >>> for chunk in window(Segment(3, 7.5)): + ... print(tuple(chunk)) + (3.0, 5.0) + (4.0, 6.0) + (5.0, 7.0) + >>> for chunk in window(Segment(3, 7.5), align_last=True): + ... 
print(tuple(chunk)) + (3.0, 5.0) + (4.0, 6.0) + (5.0, 7.0) + (5.5, 7.5) + """ + + from pyannote.core import Timeline + if isinstance(support, Timeline): + segments = support + + elif isinstance(support, Segment): + segments = Timeline(segments=[support]) + + else: + msg = ( + f'"support" must be either a Segment or a Timeline ' + f'instance (is {type(support)})' + ) + raise TypeError(msg) + + for segment in segments: + + if segment.duration < self.duration: + continue + + window = SlidingWindow(duration=self.duration, + step=self.step, + start=segment.start, + end=segment.end) + + for s in window: + # ugly hack to account for floating point imprecision + if s in segment: + yield s + last = s + + if align_last and last.end < segment.end: + yield Segment(start=segment.end - self.duration, + end=segment.end) diff --git a/ailia-models/code/pyannote_audio_utils/core/timeline.py b/ailia-models/code/pyannote_audio_utils/core/timeline.py new file mode 100644 index 0000000000000000000000000000000000000000..ed566c6d090c43368a8a91381674532e340f36bf --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/core/timeline.py @@ -0,0 +1,1126 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2014-2020 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr +# Grant JENKS - http://www.grantjenks.com/ +# Paul LERNER + +""" +######## +Timeline +######## + +.. plot:: pyplots/timeline.py + +:class:`pyannote.core.Timeline` instances are ordered sets of non-empty +segments: + + - ordered, because segments are sorted by start time (and end time in case of tie) + - set, because one cannot add twice the same segment + - non-empty, because one cannot add empty segments (*i.e.* start >= end) + +There are two ways to define the timeline depicted above: + +.. code-block:: ipython + + In [25]: from pyannote.core import Timeline, Segment + + In [26]: timeline = Timeline() + ....: timeline.add(Segment(1, 5)) + ....: timeline.add(Segment(6, 8)) + ....: timeline.add(Segment(12, 18)) + ....: timeline.add(Segment(7, 20)) + ....: + + In [27]: segments = [Segment(1, 5), Segment(6, 8), Segment(12, 18), Segment(7, 20)] + ....: timeline = Timeline(segments=segments, uri='my_audio_file') # faster + ....: + + In [9]: for segment in timeline: + ...: print(segment) + ...: + [ 00:00:01.000 --> 00:00:05.000] + [ 00:00:06.000 --> 00:00:08.000] + [ 00:00:07.000 --> 00:00:20.000] + [ 00:00:12.000 --> 00:00:18.000] + + +.. 
note:: + + The optional *uri* keyword argument can be used to remember which document it describes. + +Several convenient methods are available. Here are a few examples: + +.. code-block:: ipython + + In [3]: timeline.extent() # extent + Out[3]: + + In [5]: timeline.support() # support + Out[5]: , ])> + + In [6]: timeline.duration() # support duration + Out[6]: 18 + + +See :class:`pyannote.core.Timeline` for the complete reference. +""" +import warnings +from typing import (Optional, Iterable, List, Union, Callable, + TextIO, Tuple, TYPE_CHECKING, Iterator, Dict, Text) + +from sortedcontainers import SortedList + +from . import PYANNOTE_SEGMENT +from .segment import Segment +from .utils.types import Support, Label, CropMode + + +# this is a moderately ugly way to import `Annotation` to the namespace +# without causing some circular imports : +# https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports +if TYPE_CHECKING: + from .annotation import Annotation + import pandas as pd + + +# ===================================================================== +# Timeline class +# ===================================================================== + + +class Timeline: + """ + Ordered set of segments. + + A timeline can be seen as an ordered set of non-empty segments (Segment). + Segments can overlap -- though adding an already exisiting segment to a + timeline does nothing. + + Parameters + ---------- + segments : Segment iterator, optional + initial set of (non-empty) segments + uri : string, optional + name of segmented resource + + Returns + ------- + timeline : Timeline + New timeline + """ + + @classmethod + def from_df(cls, df: 'pd.DataFrame', uri: Optional[str] = None) -> 'Timeline': + segments = list(df[PYANNOTE_SEGMENT]) + timeline = cls(segments=segments, uri=uri) + return timeline + + def __init__(self, + segments: Optional[Iterable[Segment]] = None, + uri: str = None): + if segments is None: + segments = () + + # set of segments (used for checking inclusion) + # Store only non-empty Segments. + segments_set = set([segment for segment in segments if segment]) + + self.segments_set_ = segments_set + + # sorted list of segments (used for sorted iteration) + self.segments_list_ = SortedList(segments_set) + + # sorted list of (possibly redundant) segment boundaries + boundaries = (boundary for segment in segments_set for boundary in segment) + self.segments_boundaries_ = SortedList(boundaries) + + # path to (or any identifier of) segmented resource + self.uri: str = uri + + def __len__(self): + """Number of segments + + >>> len(timeline) # timeline contains three segments + 3 + """ + return len(self.segments_set_) + + def __nonzero__(self): + return self.__bool__() + + def __bool__(self): + """Emptiness + + >>> if timeline: + ... # timeline is not empty + ... else: + ... # timeline is empty + """ + return len(self.segments_set_) > 0 + + def __iter__(self) -> Iterable[Segment]: + """Iterate over segments (in chronological order) + + >>> for segment in timeline: + ... # do something with the segment + + See also + -------- + :class:`pyannote.core.Segment` describes how segments are sorted. + """ + return iter(self.segments_list_) + + def __getitem__(self, k: int) -> Segment: + """Get segment by index (in chronological order) + + >>> first_segment = timeline[0] + >>> penultimate_segment = timeline[-2] + """ + return self.segments_list_[k] + + def __eq__(self, other: 'Timeline'): + """Equality + + Two timelines are equal if and only if their segments are equal. 
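+ Order of insertion and the `uri` attribute are not taken into account.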
+ + >>> timeline1 = Timeline([Segment(0, 1), Segment(2, 3)]) + >>> timeline2 = Timeline([Segment(2, 3), Segment(0, 1)]) + >>> timeline3 = Timeline([Segment(2, 3)]) + >>> timeline1 == timeline2 + True + >>> timeline1 == timeline3 + False + """ + return self.segments_set_ == other.segments_set_ + + def __ne__(self, other: 'Timeline'): + """Inequality""" + return self.segments_set_ != other.segments_set_ + + def index(self, segment: Segment) -> int: + """Get index of (existing) segment + + Parameters + ---------- + segment : Segment + Segment that is being looked for. + + Returns + ------- + position : int + Index of `segment` in timeline + + Raises + ------ + ValueError if `segment` is not present. + """ + return self.segments_list_.index(segment) + + def add(self, segment: Segment) -> 'Timeline': + """Add a segment (in place) + + Parameters + ---------- + segment : Segment + Segment that is being added + + Returns + ------- + self : Timeline + Updated timeline. + + Note + ---- + If the timeline already contains this segment, it will not be added + again, as a timeline is meant to be a **set** of segments (not a list). + + If the segment is empty, it will not be added either, as a timeline + only contains non-empty segments. + """ + + segments_set_ = self.segments_set_ + if segment in segments_set_ or not segment: + return self + + segments_set_.add(segment) + + self.segments_list_.add(segment) + + segments_boundaries_ = self.segments_boundaries_ + segments_boundaries_.add(segment.start) + segments_boundaries_.add(segment.end) + + return self + + def remove(self, segment: Segment) -> 'Timeline': + """Remove a segment (in place) + + Parameters + ---------- + segment : Segment + Segment that is being removed + + Returns + ------- + self : Timeline + Updated timeline. + + Note + ---- + If the timeline does not contain this segment, this does nothing + """ + + segments_set_ = self.segments_set_ + if segment not in segments_set_: + return self + + segments_set_.remove(segment) + + self.segments_list_.remove(segment) + + segments_boundaries_ = self.segments_boundaries_ + segments_boundaries_.remove(segment.start) + segments_boundaries_.remove(segment.end) + + return self + + def discard(self, segment: Segment) -> 'Timeline': + """Same as `remove` + + See also + -------- + :func:`pyannote.core.Timeline.remove` + """ + return self.remove(segment) + + def __ior__(self, timeline: 'Timeline') -> 'Timeline': + return self.update(timeline) + + def update(self, timeline: Segment) -> 'Timeline': + """Add every segments of an existing timeline (in place) + + Parameters + ---------- + timeline : Timeline + Timeline whose segments are being added + + Returns + ------- + self : Timeline + Updated timeline + + Note + ---- + Only segments that do not already exist will be added, as a timeline is + meant to be a **set** of segments (not a list). 
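Because a timeline is a **set** of segments, `add`, `update` and `union` all follow set semantics. A minimal sketch of that behaviour, assuming the vendored package is importable as `pyannote_audio_utils.core` (the path used elsewhere in this diff):

```python
from pyannote_audio_utils.core import Segment, Timeline

timeline = Timeline([Segment(0, 1), Segment(2, 3)])

# adding an already existing segment or an empty segment is a no-op
timeline.add(Segment(0, 1))
timeline.add(Segment(5, 5))      # empty segment, ignored
print(len(timeline))             # 2

# update() adds the segments of another timeline in place ...
timeline.update(Timeline([Segment(2, 3), Segment(4, 6)]))
print(len(timeline))             # 3

# ... while union() leaves both operands untouched and returns a new timeline
merged = timeline.union(Timeline([Segment(7, 8)]))
print(len(timeline), len(merged))   # 3 4
```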
+ + """ + + segments_set = self.segments_set_ + + segments_set |= timeline.segments_set_ + + # sorted list of segments (used for sorted iteration) + self.segments_list_ = SortedList(segments_set) + + # sorted list of (possibly redundant) segment boundaries + boundaries = (boundary for segment in segments_set for boundary in segment) + self.segments_boundaries_ = SortedList(boundaries) + + return self + + def __or__(self, timeline: 'Timeline') -> 'Timeline': + return self.union(timeline) + + def union(self, timeline: 'Timeline') -> 'Timeline': + """Create new timeline made of union of segments + + Parameters + ---------- + timeline : Timeline + Timeline whose segments are being added + + Returns + ------- + union : Timeline + New timeline containing the union of both timelines. + + Note + ---- + This does the same as timeline.update(...) except it returns a new + timeline, and the original one is not modified. + """ + segments = self.segments_set_ | timeline.segments_set_ + return Timeline(segments=segments, uri=self.uri) + + def co_iter(self, other: 'Timeline') -> Iterator[Tuple[Segment, Segment]]: + """Iterate over pairs of intersecting segments + + >>> timeline1 = Timeline([Segment(0, 2), Segment(1, 2), Segment(3, 4)]) + >>> timeline2 = Timeline([Segment(1, 3), Segment(3, 5)]) + >>> for segment1, segment2 in timeline1.co_iter(timeline2): + ... print(segment1, segment2) + (, ) + (, ) + (, ) + + Parameters + ---------- + other : Timeline + Second timeline + + Returns + ------- + iterable : (Segment, Segment) iterable + Yields pairs of intersecting segments in chronological order. + """ + + for segment in self.segments_list_: + + # iterate over segments that starts before 'segment' ends + temp = Segment(start=segment.end, end=segment.end) + for other_segment in other.segments_list_.irange(maximum=temp): + if segment.intersects(other_segment): + yield segment, other_segment + + def crop_iter(self, + support: Support, + mode: CropMode = 'intersection', + returns_mapping: bool = False) \ + -> Iterator[Union[Tuple[Segment, Segment], Segment]]: + """Like `crop` but returns a segment iterator instead + + See also + -------- + :func:`pyannote.core.Timeline.crop` + """ + + if mode not in {'loose', 'strict', 'intersection'}: + raise ValueError("Mode must be one of 'loose', 'strict', or " + "'intersection'.") + + if not isinstance(support, (Segment, Timeline)): + raise TypeError("Support must be a Segment or a Timeline.") + + if isinstance(support, Segment): + # corner case where "support" is empty + if support: + segments = [support] + else: + segments = [] + + support = Timeline(segments=segments, uri=self.uri) + for yielded in self.crop_iter(support, mode=mode, + returns_mapping=returns_mapping): + yield yielded + return + + # if 'support' is a `Timeline`, we use its support + support = support.support() + + # loose mode + if mode == 'loose': + for segment, _ in self.co_iter(support): + yield segment + return + + # strict mode + if mode == 'strict': + for segment, other_segment in self.co_iter(support): + if segment in other_segment: + yield segment + return + + # intersection mode + for segment, other_segment in self.co_iter(support): + mapped_to = segment & other_segment + if not mapped_to: + continue + if returns_mapping: + yield segment, mapped_to + else: + yield mapped_to + + def crop(self, + support: Support, + mode: CropMode = 'intersection', + returns_mapping: bool = False) \ + -> Union['Timeline', Tuple['Timeline', Dict[Segment, Segment]]]: + """Crop timeline to new support + + Parameters 
+ ---------- + support : Segment or Timeline + If `support` is a `Timeline`, its support is used. + mode : {'strict', 'loose', 'intersection'}, optional + Controls how segments that are not fully included in `support` are + handled. 'strict' mode only keeps fully included segments. 'loose' + mode keeps any intersecting segment. 'intersection' mode keeps any + intersecting segment but replace them by their actual intersection. + returns_mapping : bool, optional + In 'intersection' mode, return a dictionary whose keys are segments + of the cropped timeline, and values are list of the original + segments that were cropped. Defaults to False. + + Returns + ------- + cropped : Timeline + Cropped timeline + mapping : dict + When 'returns_mapping' is True, dictionary whose keys are segments + of 'cropped', and values are lists of corresponding original + segments. + + Examples + -------- + + >>> timeline = Timeline([Segment(0, 2), Segment(1, 2), Segment(3, 4)]) + >>> timeline.crop(Segment(1, 3)) + ])> + + >>> timeline.crop(Segment(1, 3), mode='loose') + , ])> + + >>> timeline.crop(Segment(1, 3), mode='strict') + ])> + + >>> cropped, mapping = timeline.crop(Segment(1, 3), returns_mapping=True) + >>> print(mapping) + {: [, ]} + + """ + + if mode == 'intersection' and returns_mapping: + segments, mapping = [], {} + for segment, mapped_to in self.crop_iter(support, + mode='intersection', + returns_mapping=True): + segments.append(mapped_to) + mapping[mapped_to] = mapping.get(mapped_to, list()) + [segment] + return Timeline(segments=segments, uri=self.uri), mapping + + return Timeline(segments=self.crop_iter(support, mode=mode), + uri=self.uri) + + def overlapping(self, t: float) -> List[Segment]: + """Get list of segments overlapping `t` + + Parameters + ---------- + t : float + Timestamp, in seconds. + + Returns + ------- + segments : list + List of all segments of timeline containing time t + """ + return list(self.overlapping_iter(t)) + + def overlapping_iter(self, t: float) -> Iterator[Segment]: + """Like `overlapping` but returns a segment iterator instead + + See also + -------- + :func:`pyannote.core.Timeline.overlapping` + """ + segment = Segment(start=t, end=t) + for segment in self.segments_list_.irange(maximum=segment): + if segment.overlaps(t): + yield segment + + def get_overlap(self) -> 'Timeline': + """Get overlapping parts of the timeline. + + A simple illustration: + + timeline + |------| |------| |----| + |--| |-----| |----------| + + timeline.get_overlap() + |--| |---| |----| + + + Returns + ------- + overlap : `pyannote.core.Timeline` + Timeline of the overlaps. + """ + overlaps_tl = Timeline(uri=self.uri) + for s1, s2 in self.co_iter(self): + if s1 == s2: + continue + overlaps_tl.add(s1 & s2) + return overlaps_tl.support() + + def extrude(self, + removed: Support, + mode: CropMode = 'intersection') -> 'Timeline': + """Remove segments that overlap `removed` support. + + Parameters + ---------- + removed : Segment or Timeline + If `support` is a `Timeline`, its support is used. + mode : {'strict', 'loose', 'intersection'}, optional + Controls how segments that are not fully included in `removed` are + handled. 'strict' mode only removes fully included segments. 'loose' + mode removes any intersecting segment. 'intersection' mode removes + the overlapping part of any intersecting segment. 
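The three `crop` modes (also used, with 'loose' and 'strict' swapped, by `extrude`) differ only in how partially overlapping segments are handled. A minimal sketch, under the same import-path assumption as above:

```python
from pyannote_audio_utils.core import Segment, Timeline

timeline = Timeline([Segment(0, 2), Segment(1, 2), Segment(3, 4)])
support = Segment(1, 3)

# 'intersection' (default): intersecting segments are clipped to the support
print(list(timeline.crop(support)))                  # only Segment(1, 2)

# 'loose': any intersecting segment is kept unmodified
print(list(timeline.crop(support, mode='loose')))    # Segment(0, 2) and Segment(1, 2)

# 'strict': only segments fully included in the support are kept
print(list(timeline.crop(support, mode='strict')))   # only Segment(1, 2)
```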
+ + Returns + ------- + extruded : Timeline + Extruded timeline + + Examples + -------- + + >>> timeline = Timeline([Segment(0, 2), Segment(1, 2), Segment(3, 5)]) + >>> timeline.extrude(Segment(1, 2)) + , ])> + + >>> timeline.extrude(Segment(1, 3), mode='loose') + ])> + + >>> timeline.extrude(Segment(1, 3), mode='strict') + , ])> + + """ + if isinstance(removed, Segment): + removed = Timeline([removed]) + + extent_tl = Timeline([self.extent()], uri=self.uri) + truncating_support = removed.gaps(support=extent_tl) + # loose for truncate means strict for crop and vice-versa + if mode == "loose": + mode = "strict" + elif mode == "strict": + mode = "loose" + return self.crop(truncating_support, mode=mode) + + def __str__(self): + """Human-readable representation + + >>> timeline = Timeline(segments=[Segment(0, 10), Segment(1, 13.37)]) + >>> print(timeline) + [[ 00:00:00.000 --> 00:00:10.000] + [ 00:00:01.000 --> 00:00:13.370]] + + """ + + n = len(self.segments_list_) + string = "[" + for i, segment in enumerate(self.segments_list_): + string += str(segment) + string += "\n " if i + 1 < n else "" + string += "]" + return string + + def __repr__(self): + """Computer-readable representation + + >>> Timeline(segments=[Segment(0, 10), Segment(1, 13.37)]) + , ])> + + """ + + return "" % (self.uri, + list(self.segments_list_)) + + def __contains__(self, included: Union[Segment, 'Timeline']): + """Inclusion + + Check whether every segment of `included` does exist in timeline. + + Parameters + ---------- + included : Segment or Timeline + Segment or timeline being checked for inclusion + + Returns + ------- + contains : bool + True if every segment in `included` exists in timeline, + False otherwise + + Examples + -------- + >>> timeline1 = Timeline(segments=[Segment(0, 10), Segment(1, 13.37)]) + >>> timeline2 = Timeline(segments=[Segment(0, 10)]) + >>> timeline1 in timeline2 + False + >>> timeline2 in timeline1 + >>> Segment(1, 13.37) in timeline1 + True + + """ + + if isinstance(included, Segment): + return included in self.segments_set_ + + elif isinstance(included, Timeline): + return self.segments_set_.issuperset(included.segments_set_) + + else: + raise TypeError( + 'Checking for inclusion only supports Segment and ' + 'Timeline instances') + + def empty(self) -> 'Timeline': + """Return an empty copy + + Returns + ------- + empty : Timeline + Empty timeline using the same 'uri' attribute. + + """ + return Timeline(uri=self.uri) + + def covers(self, other: 'Timeline') -> bool: + """Check whether other timeline is fully covered by the timeline + + Parameter + --------- + other : Timeline + Second timeline + + Returns + ------- + covers : bool + True if timeline covers "other" timeline entirely. False if at least + one segment of "other" is not fully covered by timeline + """ + + # compute gaps within "other" extent + # this is where we should look for possible faulty segments + gaps = self.gaps(support=other.extent()) + + # if at least one gap intersects with a segment from "other", + # "self" does not cover "other" entirely --> return False + for _ in gaps.co_iter(other): + return False + + # if no gap intersects with a segment from "other", + # "self" covers "other" entirely --> return True + return True + + def copy(self, segment_func: Optional[Callable[[Segment], Segment]] = None) \ + -> 'Timeline': + """Get a copy of the timeline + + If `segment_func` is provided, it is applied to each segment first. 
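A small sketch of `covers`, which returns True only when every segment of the other timeline is fully covered by this one (same import-path assumption as above):

```python
from pyannote_audio_utils.core import Segment, Timeline

whole = Timeline([Segment(0, 10)])
parts = Timeline([Segment(1, 2), Segment(4, 9)])

print(whole.covers(parts))   # True: every segment of `parts` lies inside `whole`
print(parts.covers(whole))   # False: `whole` extends into the gaps of `parts`
```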
+ + Parameters + ---------- + segment_func : callable, optional + Callable that takes a segment as input, and returns a segment. + Defaults to identity function (segment_func(segment) = segment) + + Returns + ------- + timeline : Timeline + Copy of the timeline + + """ + + # if segment_func is not provided + # just add every segment + if segment_func is None: + return Timeline(segments=self.segments_list_, uri=self.uri) + + # if is provided + # apply it to each segment before adding them + return Timeline(segments=[segment_func(s) for s in self.segments_list_], + uri=self.uri) + + def extent(self) -> Segment: + """Extent + + The extent of a timeline is the segment of minimum duration that + contains every segments of the timeline. It is unique, by definition. + The extent of an empty timeline is an empty segment. + + A picture is worth a thousand words:: + + timeline + |------| |------| |----| + |--| |-----| |----------| + + timeline.extent() + |--------------------------------| + + Returns + ------- + extent : Segment + Timeline extent + + Examples + -------- + >>> timeline = Timeline(segments=[Segment(0, 1), Segment(9, 10)]) + >>> timeline.extent() + + + """ + if self.segments_set_: + segments_boundaries_ = self.segments_boundaries_ + start = segments_boundaries_[0] + end = segments_boundaries_[-1] + return Segment(start=start, end=end) + + return Segment(start=0.0, end=0.0) + + def support_iter(self, collar: float = 0.0) -> Iterator[Segment]: + """Like `support` but returns a segment generator instead + + See also + -------- + :func:`pyannote.core.Timeline.support` + """ + + # The support of an empty timeline is an empty timeline. + if not self: + return + + # Principle: + # * gather all segments with no gap between them + # * add one segment per resulting group (their union |) + # Note: + # Since segments are kept sorted internally, + # there is no need to perform an exhaustive segment clustering. + # We just have to consider them in their natural order. + + # Initialize new support segment + # as very first segment of the timeline + new_segment = self.segments_list_[0] + + for segment in self: + + # If there is no gap between new support segment and next segment + # OR there is a gap with duration < collar seconds, + possible_gap = segment ^ new_segment + if not possible_gap or possible_gap.duration < collar: + # Extend new support segment using next segment + new_segment |= segment + + # If there actually is a gap and the gap duration >= collar + # seconds, + else: + yield new_segment + + # Initialize new support segment as next segment + # (right after the gap) + new_segment = segment + + # Add new segment to the timeline support + yield new_segment + + def support(self, collar: float = 0.) -> 'Timeline': + """Timeline support + + The support of a timeline is the timeline with the minimum number of + segments with exactly the same time span as the original timeline. It + is (by definition) unique and does not contain any overlapping + segments. + + A picture is worth a thousand words:: + + collar + |---| + + timeline + |------| |------| |----| + |--| |-----| |----------| + + timeline.support() + |------| |--------| |----------| + + timeline.support(collar) + |------------------| |----------| + + Parameters + ---------- + collar : float, optional + Merge separated by less than `collar` seconds. This is why there + are only two segments in the final timeline in the above figure. + Defaults to 0. 
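A short sketch of the effect of `collar` on `support`, under the same import-path assumption as above:

```python
from pyannote_audio_utils.core import Segment, Timeline

timeline = Timeline([Segment(0, 1), Segment(1.2, 2), Segment(5, 6)])

# with the default collar (0), the 0.2 s gap keeps the first two segments apart
print(len(timeline.support()))      # 3

# with a 0.5 s collar, the 0.2 s gap is bridged but the 3 s gap is not
print(len(timeline.support(0.5)))   # 2 -- roughly [0, 2] and [5, 6]
```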
+ + Returns + ------- + support : Timeline + Timeline support + """ + return Timeline(segments=self.support_iter(collar), uri=self.uri) + + def duration(self) -> float: + """Timeline duration + + The timeline duration is the sum of the durations of the segments + in the timeline support. + + Returns + ------- + duration : float + Duration of timeline support, in seconds. + """ + + # The timeline duration is the sum of the durations + # of the segments in the timeline support. + return sum(s.duration for s in self.support_iter()) + + def gaps_iter(self, support: Optional[Support] = None) -> Iterator[Segment]: + """Like `gaps` but returns a segment generator instead + + See also + -------- + :func:`pyannote.core.Timeline.gaps` + + """ + + if support is None: + support = self.extent() + + if not isinstance(support, (Segment, Timeline)): + raise TypeError("unsupported operand type(s) for -':" + "%s and Timeline." % type(support).__name__) + + # segment support + if isinstance(support, Segment): + + # `end` is meant to store the end time of former segment + # initialize it with beginning of provided segment `support` + end = support.start + + # support on the intersection of timeline and provided segment + for segment in self.crop(support, mode='intersection').support(): + + # add gap between each pair of consecutive segments + # if there is no gap, segment is empty, therefore not added + gap = Segment(start=end, end=segment.start) + if gap: + yield gap + + # keep track of the end of former segment + end = segment.end + + # add final gap (if not empty) + gap = Segment(start=end, end=support.end) + if gap: + yield gap + + # timeline support + elif isinstance(support, Timeline): + + # yield gaps for every segment in support of provided timeline + for segment in support.support(): + for gap in self.gaps_iter(support=segment): + yield gap + + def gaps(self, support: Optional[Support] = None) \ + -> 'Timeline': + """Gaps + + A picture is worth a thousand words:: + + timeline + |------| |------| |----| + |--| |-----| |----------| + + timeline.gaps() + |--| |--| + + Parameters + ---------- + support : None, Segment or Timeline + Support in which gaps are looked for. Defaults to timeline extent + + Returns + ------- + gaps : Timeline + Timeline made of all gaps from original timeline, and delimited + by provided support + + See also + -------- + :func:`pyannote.core.Timeline.extent` + + """ + return Timeline(segments=self.gaps_iter(support=support), + uri=self.uri) + + def segmentation(self) -> 'Timeline': + """Segmentation + + Create the unique timeline with same support and same set of segment + boundaries as original timeline, but with no overlapping segments. + + A picture is worth a thousand words:: + + timeline + |------| |------| |----| + |--| |-----| |----------| + + timeline.segmentation() + |-|--|-| |-|---|--| |--|----|--| + + Returns + ------- + timeline : Timeline + (unique) timeline with same support and same set of segment + boundaries as original timeline, but with no overlapping segments. 
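A short sketch of `segmentation`, `duration` and `gaps` on a timeline with overlapping segments, under the same import-path assumption as above:

```python
from pyannote_audio_utils.core import Segment, Timeline

timeline = Timeline([Segment(0, 4), Segment(2, 6)])

# overlapping segments are cut at every boundary, without changing the support
print(list(timeline.segmentation()))   # Segment(0, 2), Segment(2, 4), Segment(4, 6)

# duration of the support: the overlapping part is only counted once
print(timeline.duration())             # 6

# gaps within an explicit support segment
print(list(timeline.gaps(support=Segment(0, 8))))   # Segment(6, 8)
```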
+ """ + # COMPLEXITY: O(n) + support = self.support() + + # COMPLEXITY: O(n.log n) + # get all boundaries (sorted) + # |------| |------| |----| + # |--| |-----| |----------| + # becomes + # | | | | | | | | | | | | + timestamps = set([]) + for (start, end) in self: + timestamps.add(start) + timestamps.add(end) + timestamps = sorted(timestamps) + + # create new partition timeline + # | | | | | | | | | | | | + # becomes + # |-|--|-| |-|---|--| |--|----|--| + + # start with an empty copy + timeline = Timeline(uri=self.uri) + + if len(timestamps) == 0: + return Timeline(uri=self.uri) + + segments = [] + start = timestamps[0] + for end in timestamps[1:]: + # only add segments that are covered by original timeline + segment = Segment(start=start, end=end) + if segment and support.overlapping(segment.middle): + segments.append(segment) + # next segment... + start = end + + return Timeline(segments=segments, uri=self.uri) + + def to_annotation(self, + generator: Union[str, Iterable[Label], None, None] = 'string', + modality: Optional[str] = None) \ + -> 'Annotation': + """Turn timeline into an annotation + + Each segment is labeled by a unique label. + + Parameters + ---------- + generator : 'string', 'int', or iterable, optional + If 'string' (default) generate string labels. If 'int', generate + integer labels. If iterable, use it to generate labels. + modality : str, optional + + Returns + ------- + annotation : Annotation + Annotation + """ + + from .annotation import Annotation + annotation = Annotation(uri=self.uri, modality=modality) + if generator == 'string': + from .utils.generators import string_generator + generator = string_generator() + elif generator == 'int': + from .utils.generators import int_generator + generator = int_generator() + + for segment in self: + annotation[segment] = next(generator) + + return annotation + + def _iter_uem(self) -> Iterator[Text]: + """Generate lines for a UEM file for this timeline + + Returns + ------- + iterator: Iterator[str] + An iterator over UEM text lines + """ + uri = self.uri if self.uri else "" + if isinstance(uri, Text) and ' ' in uri: + msg = (f'Space-separated UEM file format does not allow file URIs ' + f'containing spaces (got: "{uri}").') + raise ValueError(msg) + for segment in self: + yield f"{uri} 1 {segment.start:.3f} {segment.end:.3f}\n" + + def to_uem(self) -> Text: + """Serialize timeline as a string using UEM format + + Returns + ------- + serialized: str + UEM string + """ + return "".join([line for line in self._iter_uem()]) + + def write_uem(self, file: TextIO): + """Dump timeline to file using UEM format + + Parameters + ---------- + file : file object + + Usage + ----- + >>> with open('file.uem', 'w') as file: + ... 
timeline.write_uem(file) + """ + for line in self._iter_uem(): + file.write(line) + + def _repr_png_(self): + """IPython notebook support + + See also + -------- + :mod:`pyannote.core.notebook` + """ + + from .notebook import MATPLOTLIB_IS_AVAILABLE, MATPLOTLIB_WARNING + if not MATPLOTLIB_IS_AVAILABLE: + warnings.warn(MATPLOTLIB_WARNING.format(klass=self.__class__.__name__)) + return None + + from .notebook import repr_timeline + return repr_timeline(self) diff --git a/ailia-models/code/pyannote_audio_utils/core/utils/generators.py b/ailia-models/code/pyannote_audio_utils/core/utils/generators.py new file mode 100644 index 0000000000000000000000000000000000000000..00a9600fb77a8d2fa8f915cc9289acf3d4814b90 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/core/utils/generators.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2014-2018 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr + +import itertools +from string import ascii_uppercase +from typing import Iterable, Union, List, Set, Optional, Iterator + + +def pairwise(iterable: Iterable): + """s -> (s0,s1), (s1,s2), (s2, s3), ...""" + a, b = itertools.tee(iterable) + next(b, None) + return zip(a, b) + + +def string_generator(skip: Optional[Union[List, Set]] = None) \ + -> Iterator[str]: + """Label generator + + Parameters + ---------- + skip : list or set + List of labels that must be skipped. + This option is useful in case you want to make sure generated labels + are different from a pre-existing set of labels. + + Usage + ----- + t = string_generator() + next(t) -> 'A' # start with 1-letter labels + ... # from A to Z + next(t) -> 'Z' + next(t) -> 'AA' # then 2-letters labels + next(t) -> 'AB' # from AA to ZZ + ... + next(t) -> 'ZY' + next(t) -> 'ZZ' + next(t) -> 'AAA' # then 3-letters labels + ... 
# (you get the idea) + """ + if skip is None: + skip = list() + + # label length + r = 1 + + # infinite loop + while True: + + # generate labels with current length + for c in itertools.product(ascii_uppercase, repeat=r): + if c in skip: + continue + yield ''.join(c) + + # increment label length when all possibilities are exhausted + r = r + 1 + + +def int_generator() -> Iterator[int]: + i = 0 + while True: + yield i + i = i + 1 diff --git a/ailia-models/code/pyannote_audio_utils/core/utils/types.py b/ailia-models/code/pyannote_audio_utils/core/utils/types.py new file mode 100644 index 0000000000000000000000000000000000000000..5fe2616572c6479d6aada052f66d159e542a2c4f --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/core/utils/types.py @@ -0,0 +1,13 @@ +from typing import Hashable, Union, Tuple, Iterator, Literal + +Label = Hashable +Support = Union['Segment', 'Timeline'] +LabelGeneratorMode = Literal['int', 'string'] +LabelGenerator = Union[LabelGeneratorMode, Iterator[Label]] +TrackName = Union[str, int] +Key = Union['Segment', Tuple['Segment', TrackName]] +Resource = Union['Segment', 'Timeline', 'SlidingWindowFeature', + 'Annotation'] +CropMode = Literal['intersection', 'loose', 'strict'] +Alignment = Literal['center', 'loose', 'strict'] +LabelStyle = Tuple[str, int, Tuple[float, float, float]] diff --git a/ailia-models/code/pyannote_audio_utils/database/__init__.py b/ailia-models/code/pyannote_audio_utils/database/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dc2ebc608dfea23411b8de109c90d3aefebb60e7 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/database/__init__.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2016- CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr +# Alexis PLAQUET + +"""pyannote.database""" + + +from typing import Optional +import warnings + +# from .registry import registry, LoadingMode + +# from .database import Database + +from .protocol.protocol import Protocol +from .protocol.protocol import ProtocolFile +from .protocol.protocol import Subset +from .protocol.protocol import Preprocessors + +# from .file_finder import FileFinder +# from .util import get_annotated +# from .util import get_unique_identifier +# from .util import get_label_identifier + +# from ._version import get_versions +# + +# __version__ = get_versions()["version"] +# del get_versions + + +def get_protocol(name, preprocessors: Optional[Preprocessors] = None) -> Protocol: + """Get protocol by full name + + name : str + Protocol full name (e.g. "Etape.SpeakerDiarization.TV") + preprocessors : dict or (key, preprocessor) iterable + When provided, each protocol item (dictionary) are preprocessed, such + that item[key] = preprocessor(item). In case 'preprocessor' is not + callable, it should be a string containing placeholder for item keys + (e.g. {'audio': '/path/to/{uri}.wav'}) + + Returns + ------- + protocol : Protocol + Protocol instance + """ + warnings.warn( + "`get_protocol` has been deprecated in favor of `pyannote.database.registry.get_protocol`.", + DeprecationWarning) + return registry.get_protocol(name, preprocessors=preprocessors) + + +__all__ = [ + "registry", + "get_protocol", + "LoadingMode", + "Database", + "Protocol", + "ProtocolFile", + "Subset", + "FileFinder", + "get_annotated", + "get_unique_identifier", + "get_label_identifier", +] diff --git a/ailia-models/code/pyannote_audio_utils/database/protocol/__init__.py b/ailia-models/code/pyannote_audio_utils/database/protocol/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f92e690647b5a7454abcb2315c7b3e8ad58e1fda --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/database/protocol/__init__.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2016- CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr + +from .protocol import Protocol + +__all__ = [ + "Protocol", +] + diff --git a/ailia-models/code/pyannote_audio_utils/database/protocol/protocol.py b/ailia-models/code/pyannote_audio_utils/database/protocol/protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..717846a18746d342c67b03db761a21e1799874de --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/database/protocol/protocol.py @@ -0,0 +1,434 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2016-2020 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr + +""" +######### +Protocols +######### + +""" + +import warnings +import collections +import threading +import itertools +from typing import Union, Dict, Iterator, Callable, Any, Text, Optional + +# try: +from typing import Literal +# except ImportError: + # from typing_extensions import Literal + +Subset = Literal["train", "development", "test"] +LEGACY_SUBSET_MAPPING = {"train": "trn", "development": "dev", "test": "tst"} +Scope = Literal["file", "database", "global"] + +Preprocessor = Callable[["ProtocolFile"], Any] +Preprocessors = Dict[Text, Preprocessor] + + +class ProtocolFile(collections.abc.MutableMapping): + """Protocol file with lazy preprocessors + + This is a dict-like data structure where some values may depend on other + values, and are only computed if/when requested. Once computed, they are + cached and never recomputed again. + + Parameters + ---------- + precomputed : dict + Regular dictionary with precomputed values + lazy : dict, optional + Dictionary describing how lazy value needs to be computed. + Values are callable expecting a dictionary as input and returning the + computed value. + + """ + + def __init__(self, precomputed: Union[Dict, "ProtocolFile"], lazy: Dict = None): + + if lazy is None: + lazy = dict() + + if isinstance(precomputed, ProtocolFile): + # when 'precomputed' is a ProtocolFile, it may already contain lazy keys. + + # we use 'precomputed' precomputed keys as precomputed keys + self._store: Dict = abs(precomputed) + + # we handle the corner case where the intersection of 'precomputed' lazy keys + # and 'lazy' keys is not empty. this is currently achieved by "unlazying" the + # 'precomputed' one (which is probably not the most efficient solution). 
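+            # concretely: any key that is lazy on both sides is computed from
+            # 'precomputed' and stored right below, while the preprocessor given
+            # in 'lazy' stays lazy and takes precedence on first access.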
+ for key in set(precomputed.lazy) & set(lazy): + self._store[key] = precomputed[key] + + # we use the union of 'precomputed' lazy keys and provided 'lazy' keys as lazy keys + compound_lazy = dict(precomputed.lazy) + compound_lazy.update(lazy) + self.lazy: Dict = compound_lazy + + else: + # when 'precomputed' is a Dict, we use it directly as precomputed keys + # and 'lazy' as lazy keys. + self._store = dict(precomputed) + self.lazy = dict(lazy) + + # re-entrant lock used below to make ProtocolFile thread-safe + self.lock_ = threading.RLock() + + # this is needed to avoid infinite recursion + # when a key is both in precomputed and lazy. + # keys with evaluating_ > 0 are currently being evaluated + # and therefore should be taken from precomputed + self.evaluating_ = collections.Counter() + + # since RLock is not pickable, remove it before pickling... + def __getstate__(self): + d = dict(self.__dict__) + del d["lock_"] + return d + + # ... and add it back when unpickling + def __setstate__(self, d): + self.__dict__.update(d) + self.lock_ = threading.RLock() + + def __abs__(self): + with self.lock_: + return dict(self._store) + + def __getitem__(self, key): + with self.lock_: + + if key in self.lazy and self.evaluating_[key] == 0: + + # mark lazy key as being evaluated + self.evaluating_.update([key]) + + # apply preprocessor once and remove it + value = self.lazy[key](self) + del self.lazy[key] + + # warn the user when a precomputed key is modified + if key in self._store and value != self._store[key]: + msg = 'Existing precomputed key "{key}" has been modified by a preprocessor.' + warnings.warn(msg.format(key=key)) + + # store the output of the lazy computation + # so that it is available for future access + self._store[key] = value + + # lazy evaluation is finished for key + self.evaluating_.subtract([key]) + + return self._store[key] + + def __setitem__(self, key, value): + with self.lock_: + + if key in self.lazy: + del self.lazy[key] + + self._store[key] = value + + def __delitem__(self, key): + with self.lock_: + + if key in self.lazy: + del self.lazy[key] + + del self._store[key] + + def __iter__(self): + with self.lock_: + + store_keys = list(self._store) + for key in store_keys: + yield key + + lazy_keys = list(self.lazy) + for key in lazy_keys: + if key in self._store: + continue + yield key + + def __len__(self): + with self.lock_: + return len(set(self._store) | set(self.lazy)) + + def files(self) -> Iterator["ProtocolFile"]: + """Iterate over all files + + When `current_file` refers to only one file, + yield it and return. + When `current_file` refers to a list of file (i.e. 'uri' is a list), + yield each file separately. + + Examples + -------- + >>> current_file = ProtocolFile({ + ... 'uri': 'my_uri', + ... 'database': 'my_database'}) + >>> for file in current_file.files(): + ... print(file['uri'], file['database']) + my_uri my_database + + >>> current_file = { + ... 'uri': ['my_uri1', 'my_uri2', 'my_uri3'], + ... 'database': 'my_database'} + >>> for file in current_file.files(): + ... 
print(file['uri'], file['database']) + my_uri1 my_database + my_uri2 my_database + my_uri3 my_database + + """ + + uris = self["uri"] + if not isinstance(uris, list): + yield self + return + + n_uris = len(uris) + + # iterate over precomputed keys and make sure + + precomputed = {"uri": uris} + for key, value in abs(self).items(): + + if key == "uri": + continue + + if not isinstance(value, list): + precomputed[key] = itertools.repeat(value) + + else: + if len(value) != n_uris: + msg = ( + f'Mismatch between number of "uris" ({n_uris}) ' + f'and number of "{key}" ({len(value)}).' + ) + raise ValueError(msg) + precomputed[key] = value + + keys = list(precomputed.keys()) + for values in zip(*precomputed.values()): + precomputed_one = dict(zip(keys, values)) + yield ProtocolFile(precomputed_one, self.lazy) + + +class Protocol: + """Experimental protocol + + An experimental protocol usually defines three subsets: a training subset, + a development subset, and a test subset. + + An experimental protocol can be defined programmatically by creating a + class that inherits from Protocol and implements at least + one of `train_iter`, `development_iter` and `test_iter` methods: + + >>> class MyProtocol(Protocol): + ... def train_iter(self) -> Iterator[Dict]: + ... yield {"uri": "filename1", "any_other_key": "..."} + ... yield {"uri": "filename2", "any_other_key": "..."} + + `{subset}_iter` should return an iterator of dictionnaries with + - "uri" key (mandatory) that provides a unique file identifier (usually + the filename), + - any other key that the protocol may provide. + + It can then be used in Python like this: + + >>> protocol = MyProtocol() + >>> for file in protocol.train(): + ... print(file["uri"]) + filename1 + filename2 + + An experimental protocol can also be defined using `pyannote_audio_utils.database` + configuration file, whose (configurable) path defaults to "~/database.yml". + + ~~~ Content of ~/database.yml ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + Protocols: + MyDatabase: + Protocol: + MyProtocol: + train: + uri: /path/to/collection.lst + any_other_key: ... # see custom loader documentation + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + where "/path/to/collection.lst" contains the list of identifiers of the + files in the collection: + + ~~~ Content of "/path/to/collection.lst ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + filename1 + filename2 + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + It can then be used in Python like this: + + >>> from pyannote_audio_utils.database import registry + >>> protocol = registry.get_protocol('MyDatabase.Protocol.MyProtocol') + >>> for file in protocol.train(): + ... print(file["uri"]) + filename1 + filename2 + + This class is usually inherited from, but can be used directly. + + Parameters + ---------- + preprocessors : dict + Preprocess protocol files so that `file[key] = preprocessors[key](file)` + for each key in `preprocessors`. In case `preprocessors[key]` is not + callable, it should be a string containing placeholders for `file` keys + (e.g. 
{'audio': '/path/to/{uri}.wav'}) + """ + + def __init__(self, preprocessors: Optional[Preprocessors] = None): + super().__init__() + + if preprocessors is None: + preprocessors = dict() + + self.preprocessors = dict() + for key, preprocessor in preprocessors.items(): + + if callable(preprocessor): + self.preprocessors[key] = preprocessor + + # when `preprocessor` is not callable, it should be a string + # containing placeholder for item key (e.g. '/path/to/{uri}.wav') + elif isinstance(preprocessor, str): + preprocessor_copy = str(preprocessor) + + def func(current_file): + return preprocessor_copy.format(**current_file) + + self.preprocessors[key] = func + + else: + msg = f'"{key}" preprocessor is neither a callable nor a string.' + raise ValueError(msg) + + def preprocess(self, current_file: Union[Dict, ProtocolFile]) -> ProtocolFile: + return ProtocolFile(current_file, lazy=self.preprocessors) + + def __str__(self): + return self.__doc__ + + def train_iter(self) -> Iterator[Union[Dict, ProtocolFile]]: + """Iterate over files in the training subset""" + raise NotImplementedError() + + def development_iter(self) -> Iterator[Union[Dict, ProtocolFile]]: + """Iterate over files in the development subset""" + raise NotImplementedError() + + def test_iter(self) -> Iterator[Union[Dict, ProtocolFile]]: + """Iterate over files in the test subset""" + raise NotImplementedError() + + def subset_helper(self, subset: Subset) -> Iterator[ProtocolFile]: + + try: + files = getattr(self, f"{subset}_iter")() + except (AttributeError, NotImplementedError): + # previous pyannote_audio_utils.database versions used `trn_iter` instead of + # `train_iter`, `dev_iter` instead of `development_iter`, and + # `tst_iter` instead of `test_iter`. therefore, we use the legacy + # version when it is available (and the new one is not). + subset_legacy = LEGACY_SUBSET_MAPPING[subset] + try: + files = getattr(self, f"{subset_legacy}_iter")() + except AttributeError: + msg = f"Protocol does not implement a {subset} subset." + raise NotImplementedError(msg) + + for file in files: + yield self.preprocess(file) + + def train(self) -> Iterator[ProtocolFile]: + return self.subset_helper("train") + + def development(self) -> Iterator[ProtocolFile]: + return self.subset_helper("development") + + def test(self) -> Iterator[ProtocolFile]: + return self.subset_helper("test") + + def files(self) -> Iterator[ProtocolFile]: + """Iterate over all files in `protocol`""" + + # imported here to avoid circular imports + from pyannote_audio_utils.database.util import get_unique_identifier + + yielded_uris = set() + + for method in [ + "development", + "development_enrolment", + "development_trial", + "test", + "test_enrolment", + "test_trial", + "train", + "train_enrolment", + "train_trial", + ]: + + if not hasattr(self, method): + continue + + def iterate(): + try: + for file in getattr(self, method)(): + yield file + except (AttributeError, NotImplementedError): + return + + for current_file in iterate(): + + # skip "files" that do not contain a "uri" entry. 
+ # this happens for speaker verification trials that contain + # two nested files "file1" and "file2" + # see https://github.com/pyannote_audio_utils/pyannote_audio_utils-db-voxceleb/issues/4 + if "uri" not in current_file: + continue + + for current_file_ in current_file.files(): + + # corner case when the same file is yielded several times + uri = get_unique_identifier(current_file_) + if uri in yielded_uris: + continue + + yield current_file_ + + yielded_uris.add(uri) diff --git a/ailia-models/code/pyannote_audio_utils/database/util.py b/ailia-models/code/pyannote_audio_utils/database/util.py new file mode 100644 index 0000000000000000000000000000000000000000..452961feb9bfb7645d0ad98f52b4bd9f854ddb77 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/database/util.py @@ -0,0 +1,400 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2016-2020 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr + + +import warnings +import pandas as pd +from pyannote_audio_utils.core import Segment, Timeline, Annotation + +from typing import Text + +DatabaseName = Text +PathTemplate = Text + + +def get_unique_identifier(item): + """Return unique item identifier + + The complete format is {database}/{uri}_{channel}: + * prefixed by "{database}/" only when `item` has a 'database' key. + * suffixed by "_{channel}" only when `item` has a 'channel' key. + + Parameters + ---------- + item : dict + Item as yielded by pyannote_audio_utils.database protocols + + Returns + ------- + identifier : str + Unique item identifier + """ + + IDENTIFIER = "" + + # {database}/{uri}_{channel} + database = item.get("database", None) + if database is not None: + IDENTIFIER += f"{database}/" + IDENTIFIER += item["uri"] + channel = item.get("channel", None) + if channel is not None: + IDENTIFIER += f"_{channel:d}" + + return IDENTIFIER + + +# This function is used in custom.py +def get_annotated(current_file): + """Get part of the file that is annotated. + + Parameters + ---------- + current_file : `dict` + File generated by a `pyannote_audio_utils.database` protocol. + + Returns + ------- + annotated : `pyannote_audio_utils.core.Timeline` + Part of the file that is annotated. Defaults to + `current_file["annotated"]`. When it does not exist, try to use the + full audio extent. When that fails, use "annotation" extent. 
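A minimal sketch of the fallbacks in `get_annotated`, assuming the vendored package layout used in this diff (`pyannote_audio_utils.core` and `pyannote_audio_utils.database.util`); the file dictionaries are hypothetical:

```python
from pyannote_audio_utils.core import Segment, Timeline
from pyannote_audio_utils.database.util import get_annotated

# the "annotated" part is provided, so it is returned as-is
current_file = {"uri": "sample", "annotated": Timeline([Segment(0.0, 30.0)])}
print(get_annotated(current_file))

# without "annotated", a "duration" key is used instead (with a warning)
print(get_annotated({"uri": "sample", "duration": 30.0}))
```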
+ """ + + # if protocol provides 'annotated' key, use it + if "annotated" in current_file: + annotated = current_file["annotated"] + return annotated + + # if it does not, but does provide 'audio' key + # try and use wav duration + + if "duration" in current_file: + try: + duration = current_file["duration"] + except ImportError: + pass + else: + annotated = Timeline([Segment(0, duration)]) + msg = '"annotated" was approximated by [0, audio duration].' + warnings.warn(msg) + return annotated + + extent = current_file["annotation"].get_timeline().extent() + annotated = Timeline([extent]) + + msg = ( + '"annotated" was approximated by "annotation" extent. ' + 'Please provide "annotated" directly, or at the very ' + 'least, use a "duration" preprocessor.' + ) + warnings.warn(msg) + + return annotated + + +def get_label_identifier(label, current_file): + """Return unique label identifier + + Parameters + ---------- + label : str + Database-internal label + current_file + Yielded by pyannote_audio_utils.database protocols + + Returns + ------- + unique_label : str + Global label + """ + + # TODO. when the "true" name of a person is used, + # do not preprend database name. + database = current_file["database"] + return database + "|" + label + + +def load_rttm(file_rttm, keep_type="SPEAKER"): + """Load RTTM file + + Parameter + --------- + file_rttm : `str` + Path to RTTM file. + keep_type : str, optional + Only keep lines with this type (field #1 in RTTM specs). + Defaults to "SPEAKER". + + Returns + ------- + annotations : `dict` + Speaker diarization as a {uri: pyannote_audio_utils.core.Annotation} dictionary. + """ + + names = [ + "type", + "uri", + "NA2", + "start", + "duration", + "NA3", + "NA4", + "speaker", + "NA5", + "NA6", + ] + dtype = {"uri": str, "start": float, "duration": float, "speaker": str} + data = pd.read_csv( + file_rttm, + names=names, + dtype=dtype, + # delim_whitespace=True, + sep='\s+', + keep_default_na=True, + ) + + annotations = dict() + for uri, turns in data.groupby("uri"): + annotation = Annotation(uri=uri) + for i, turn in turns.iterrows(): + if turn.type != keep_type: + continue + segment = Segment(turn.start, turn.start + turn.duration) + annotation[segment, i] = turn.speaker + annotations[uri] = annotation + + return annotations + + +def load_stm(file_stm): + """Load STM file (speaker-info only) + + Parameter + --------- + file_stm : str + Path to STM file + + Returns + ------- + annotations : `dict` + Speaker diarization as a {uri: pyannote_audio_utils.core.Annotation} dictionary. + """ + + dtype = {"uri": str, "speaker": str, "start": float, "end": float} + data = pd.read_csv( + file_stm, + # delim_whitespace=True, + sep='\s+', + usecols=[0, 2, 3, 4], + dtype=dtype, + names=list(dtype), + ) + + annotations = dict() + for uri, turns in data.groupby("uri"): + annotation = Annotation(uri=uri) + for i, turn in turns.iterrows(): + segment = Segment(turn.start, turn.end) + annotation[segment, i] = turn.speaker + annotations[uri] = annotation + + return annotations + + +def load_mdtm(file_mdtm): + """Load MDTM file + + Parameter + --------- + file_mdtm : `str` + Path to MDTM file. + + Returns + ------- + annotations : `dict` + Speaker diarization as a {uri: pyannote_audio_utils.core.Annotation} dictionary. 
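The RTTM, STM and MDTM loaders all return a `{uri: Annotation}` dictionary, one entry per recording. A minimal sketch for `load_rttm`, where "reference.rttm" is a placeholder path:

```python
from pyannote_audio_utils.database.util import load_rttm

annotations = load_rttm("reference.rttm")
for uri, annotation in annotations.items():
    # labels() lists the speaker names found for this recording
    print(uri, annotation.labels())
```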
+ """ + + names = ["uri", "NA1", "start", "duration", "NA2", "NA3", "NA4", "speaker"] + dtype = {"uri": str, "start": float, "duration": float, "speaker": str} + data = pd.read_csv( + file_mdtm, + names=names, + dtype=dtype, + # delim_whitespace=True, + sep='\s+', + keep_default_na=False, + ) + + annotations = dict() + for uri, turns in data.groupby("uri"): + annotation = Annotation(uri=uri) + for i, turn in turns.iterrows(): + segment = Segment(turn.start, turn.start + turn.duration) + annotation[segment, i] = turn.speaker + annotations[uri] = annotation + + return annotations + + +def load_uem(file_uem): + """Load UEM file + + Parameter + --------- + file_uem : `str` + Path to UEM file. + + Returns + ------- + timelines : `dict` + Evaluation map as a {uri: pyannote_audio_utils.core.Timeline} dictionary. + """ + + names = ["uri", "NA1", "start", "end"] + dtype = {"uri": str, "start": float, "end": float} + data = pd.read_csv(file_uem, names=names, dtype=dtype, sep='\s+',) + + timelines = dict() + for uri, parts in data.groupby("uri"): + segments = [Segment(part.start, part.end) for i, part in parts.iterrows()] + timelines[uri] = Timeline(segments=segments, uri=uri) + + return timelines + + +def load_lab(path, uri: str = None) -> Annotation: + """Load LAB file + + Parameter + --------- + file_lab : `str` + Path to LAB file + + Returns + ------- + data : `pyannote_audio_utils.core.Annotation` + """ + + names = ["start", "end", "label"] + dtype = {"start": float, "end": float, "label": str} + data = pd.read_csv(path, names=names, dtype=dtype, sep='\s+',) + + annotation = Annotation(uri=uri) + for i, turn in data.iterrows(): + segment = Segment(turn.start, turn.end) + annotation[segment, i] = turn.label + + return annotation + + +def load_lst(file_lst): + """Load LST file + + LST files provide a list of URIs (one line per URI) + + Parameter + --------- + file_lst : `str` + Path to LST file. + + Returns + ------- + uris : `list` + List or uris + """ + + with open(file_lst, mode="r") as fp: + lines = fp.readlines() + return [line.strip() for line in lines] + + +def load_mapping(mapping_txt): + """Load mapping file + + Parameter + --------- + mapping_txt : `str` + Path to mapping file + + Returns + ------- + mapping : `dict` + {1st field: 2nd field} dictionary + """ + + with open(mapping_txt, mode="r") as fp: + lines = fp.readlines() + + mapping = dict() + for line in lines: + key, value, *left = line.strip().split() + mapping[key] = value + + return mapping + + +class LabelMapper(object): + """Label mapper for use as pyannote_audio_utils.database preprocessor + + Parameters + ---------- + mapping : `dict` + Mapping dictionary as used in `Annotation.rename_labels()`. + keep_missing : `bool`, optional + In case a label has no mapping, a `ValueError` will be raised. + Set "keep_missing" to True to keep those labels unchanged instead. + + Usage + ----- + >>> mapping = {'Hadrien': 'MAL', 'Marvin': 'MAL', + ... 'Wassim': 'CHI', 'Herve': 'GOD'} + >>> preprocessors = {'annotation': LabelMapper(mapping=mapping)} + >>> protocol = registry.get_protocol('AMI.SpeakerDiarization.MixHeadset', + preprocessors=preprocessors) + + """ + + def __init__(self, mapping, keep_missing=False): + self.mapping = mapping + self.keep_missing = keep_missing + + def __call__(self, current_file): + + if not self.keep_missing: + missing = set(current_file["annotation"].labels()) - set(self.mapping) + if missing and not self.keep_missing: + label = missing.pop() + msg = ( + f'No mapping found for label "{label}". 
Set "keep_missing" ' + f"to True to keep labels with no mapping." + ) + raise ValueError(msg) + + return current_file["annotation"].rename_labels(mapping=self.mapping) diff --git a/ailia-models/code/pyannote_audio_utils/metrics/__init__.py b/ailia-models/code/pyannote_audio_utils/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cdeaf1860454350724769eb42c199d27ff8f6872 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/metrics/__init__.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2012-2021 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr + +from ._version import get_versions +from .base import f_measure + +__version__ = get_versions()["version"] +del get_versions + + +__all__ = ["f_measure"] diff --git a/ailia-models/code/pyannote_audio_utils/metrics/_version.py b/ailia-models/code/pyannote_audio_utils/metrics/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..11087d384aa7380f3f57fe689e83a1a00bab23df --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/metrics/_version.py @@ -0,0 +1,21 @@ + +# This file was generated by 'versioneer.py' (0.15) from +# revision-control system data, or from the parent directory name of an +# unpacked source archive. Distribution tarballs contain a pre-generated copy +# of this file. 
+ +import json +import sys + +version_json = ''' +{ + "dirty": false, + "error": null, + "full-revisionid": "babbd1c68adc50c0e2199676c7ae741194c520da", + "version": "3.2.1" +} +''' # END VERSION_JSON + + +def get_versions(): + return json.loads(version_json) diff --git a/ailia-models/code/pyannote_audio_utils/metrics/base.py b/ailia-models/code/pyannote_audio_utils/metrics/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a1bc5fa4fcae948dece8116ec79daaa83c09ef19 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/metrics/base.py @@ -0,0 +1,419 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2012- CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr +from typing import List, Union, Optional, Set, Tuple + +import warnings +import numpy as np +import pandas as pd +import scipy.stats +from pyannote_audio_utils.core import Annotation, Timeline + +from pyannote_audio_utils.metrics.types import Details, MetricComponents + + +class BaseMetric: + """ + :class:`BaseMetric` is the base class for most pyannote_audio_utils evaluation metrics. + + Attributes + ---------- + name : str + Human-readable name of the metric (eg. 'diarization error rate') + """ + + @classmethod + def metric_name(cls) -> str: + raise NotImplementedError( + cls.__name__ + " is missing a 'metric_name' class method. " + "It should return the name of the metric as string." + ) + + @classmethod + def metric_components(cls) -> MetricComponents: + raise NotImplementedError( + cls.__name__ + " is missing a 'metric_components' class method. " + "It should return the list of names of metric components." + ) + + def __init__(self, **kwargs): + super(BaseMetric, self).__init__() + self.metric_name_ = self.__class__.metric_name() + self.components_: Set[str] = set(self.__class__.metric_components()) + self.reset() + + def init_components(self): + return {value: 0.0 for value in self.components_} + + def reset(self): + """Reset accumulated components and metric values""" + self.accumulated_: Details = dict() + self.results_: List = list() + for value in self.components_: + self.accumulated_[value] = 0.0 + + @property + def name(self): + """Metric name.""" + return self.metric_name() + + # TODO: use joblib/locky to allow parallel processing? + # TODO: signature could be something like __call__(self, reference_iterator, hypothesis_iterator, ...) 
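+    # NOTE (usage sketch, not part of the original code): a BaseMetric
+    # subclass is meant to be called once per evaluated file; components
+    # accumulate internally, so the corpus-level value is obtained with
+    # abs(metric) and a per-file breakdown with metric.report(). The names
+    # below (metric, files, reference, hypothesis) are placeholders:
+    #
+    #     metric = DiarizationErrorRate()
+    #     for reference, hypothesis in files:
+    #         metric(reference, hypothesis)        # accumulates components
+    #     corpus_value = abs(metric)               # metric over all files
+    #     mean, (lower, upper) = metric.confidence_interval(alpha=0.9)
+    #     report = metric.report(display=True)     # pandas.DataFrame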
+ + def __call__(self, reference: Union[Timeline, Annotation], + hypothesis: Union[Timeline, Annotation], + detailed: bool = False, uri: Optional[str] = None, **kwargs): + """Compute metric value and accumulate components + + Parameters + ---------- + reference : type depends on the metric + Manual `reference` + hypothesis : type depends on the metric + Evaluated `hypothesis` + uri : optional + Override uri. + detailed : bool, optional + By default (False), return metric value only. + Set `detailed` to True to return dictionary where keys are + components names and values are component values + + Returns + ------- + value : float (if `detailed` is False) + Metric value + components : dict (if `detailed` is True) + `components` updated with metric value + """ + + # compute metric components + components = self.compute_components(reference, hypothesis, **kwargs) + + # compute rate based on components + components[self.metric_name_] = self.compute_metric(components) + + # keep track of this computation + uri = uri or getattr(reference, "uri", "NA") + self.results_.append((uri, components)) + + # accumulate components + for name in self.components_: + self.accumulated_[name] += components[name] + + if detailed: + return components + + return components[self.metric_name_] + + def report(self, display: bool = False) -> pd.DataFrame: + """Evaluation report + + Parameters + ---------- + display : bool, optional + Set to True to print the report to stdout. + + Returns + ------- + report : pandas.DataFrame + Dataframe with one column per metric component, one row per + evaluated item, and one final row for accumulated results. + """ + + report = [] + uris = [] + + percent = "total" in self.metric_components() + + for uri, components in self.results_: + row = {} + if percent: + total = components["total"] + for key, value in components.items(): + if key == self.name: + row[key, "%"] = 100 * value + elif key == "total": + row[key, ""] = value + else: + row[key, ""] = value + if percent: + if total > 0: + row[key, "%"] = 100 * value / total + else: + row[key, "%"] = np.NaN + + report.append(row) + uris.append(uri) + + row = {} + components = self.accumulated_ + + if percent: + total = components["total"] + + for key, value in components.items(): + if key == self.name: + row[key, "%"] = 100 * value + elif key == "total": + row[key, ""] = value + else: + row[key, ""] = value + if percent: + if total > 0: + row[key, "%"] = 100 * value / total + else: + row[key, "%"] = np.NaN + + row[self.name, "%"] = 100 * abs(self) + report.append(row) + uris.append("TOTAL") + + df = pd.DataFrame(report) + + df["item"] = uris + df = df.set_index("item") + + df.columns = pd.MultiIndex.from_tuples(df.columns) + + df = df[[self.name] + self.metric_components()] + + if display: + print( + df.to_string( + index=True, + sparsify=False, + justify="right", + float_format=lambda f: "{0:.2f}".format(f), + ) + ) + + return df + + def __str__(self): + report = self.report(display=False) + return report.to_string( + sparsify=False, float_format=lambda f: "{0:.2f}".format(f) + ) + + def __abs__(self): + """Compute metric value from accumulated components""" + return self.compute_metric(self.accumulated_) + + def __getitem__(self, component: str) -> Union[float, Details]: + """Get value of accumulated `component`. 
+ + Parameters + ---------- + component : str + Name of `component` + + Returns + ------- + value : type depends on the metric + Value of accumulated `component` + + """ + if component == slice(None, None, None): + return dict(self.accumulated_) + else: + return self.accumulated_[component] + + def __iter__(self): + """Iterator over the accumulated (uri, value)""" + for uri, component in self.results_: + yield uri, component + + def compute_components(self, + reference: Union[Timeline, Annotation], + hypothesis: Union[Timeline, Annotation], + **kwargs) -> Details: + """Compute metric components + + Parameters + ---------- + reference : type depends on the metric + Manual `reference` + hypothesis : same as `reference` + Evaluated `hypothesis` + + Returns + ------- + components : dict + Dictionary where keys are component names and values are component + values + + """ + raise NotImplementedError( + self.__class__.__name__ + " is missing a 'compute_components' method." + "It should return a dictionary where keys are component names " + "and values are component values." + ) + + def compute_metric(self, components: Details): + """Compute metric value from computed `components` + + Parameters + ---------- + components : dict + Dictionary where keys are components names and values are component + values + + Returns + ------- + value : type depends on the metric + Metric value + """ + raise NotImplementedError( + self.__class__.__name__ + " is missing a 'compute_metric' method. " + "It should return the actual value of the metric based " + "on the precomputed component dictionary given as input." + ) + + def confidence_interval(self, alpha: float = 0.9) \ + -> Tuple[float, Tuple[float, float]]: + """Compute confidence interval on accumulated metric values + + Parameters + ---------- + alpha : float, optional + Probability that the returned confidence interval contains + the true metric value. + + Returns + ------- + (center, (lower, upper)) + with center the mean of the conditional pdf of the metric value + and (lower, upper) is a confidence interval centered on the median, + containing the estimate to a probability alpha. + + See Also: + --------- + scipy.stats.bayes_mvs + + """ + + values = [r[self.metric_name_] for _, r in self.results_] + + if len(values) == 0: + raise ValueError("Please evaluate a bunch of files before computing confidence interval.") + + elif len(values) == 1: + warnings.warn("Cannot compute a reliable confidence interval out of just one file.") + center = lower = upper = values[0] + return center, (lower, upper) + + else: + return scipy.stats.bayes_mvs(values, alpha=alpha)[0] + + +PRECISION_NAME = "precision" +PRECISION_RETRIEVED = "# retrieved" +PRECISION_RELEVANT_RETRIEVED = "# relevant retrieved" + + +class Precision(BaseMetric): + """ + :class:`Precision` is a base class for precision-like evaluation metrics. + + It defines two components '# retrieved' and '# relevant retrieved' and the + compute_metric() method to compute the actual precision: + + Precision = # retrieved / # relevant retrieved + + Inheriting classes must implement compute_components(). 
+ """ + + @classmethod + def metric_name(cls): + return PRECISION_NAME + + @classmethod + def metric_components(cls) -> MetricComponents: + return [PRECISION_RETRIEVED, PRECISION_RELEVANT_RETRIEVED] + + def compute_metric(self, components: Details) -> float: + """Compute precision from `components`""" + numerator = components[PRECISION_RELEVANT_RETRIEVED] + denominator = components[PRECISION_RETRIEVED] + if denominator == 0.0: + if numerator == 0: + return 1.0 + else: + raise ValueError("") + else: + return numerator / denominator + + +RECALL_NAME = "recall" +RECALL_RELEVANT = "# relevant" +RECALL_RELEVANT_RETRIEVED = "# relevant retrieved" + + +class Recall(BaseMetric): + """ + :class:`Recall` is a base class for recall-like evaluation metrics. + + It defines two components '# relevant' and '# relevant retrieved' and the + compute_metric() method to compute the actual recall: + + Recall = # relevant retrieved / # relevant + + Inheriting classes must implement compute_components(). + """ + + @classmethod + def metric_name(cls): + return RECALL_NAME + + @classmethod + def metric_components(cls) -> MetricComponents: + return [RECALL_RELEVANT, RECALL_RELEVANT_RETRIEVED] + + def compute_metric(self, components: Details) -> float: + """Compute recall from `components`""" + numerator = components[RECALL_RELEVANT_RETRIEVED] + denominator = components[RECALL_RELEVANT] + if denominator == 0.0: + if numerator == 0: + return 1.0 + else: + raise ValueError("") + else: + return numerator / denominator + + +def f_measure(precision: float, recall: float, beta=1.0) -> float: + """Compute f-measure + + f-measure is defined as follows: + F(P, R, b) = (1+b²).P.R / (b².P + R) + + where P is `precision`, R is `recall` and b is `beta` + """ + if precision + recall == 0.0: + return 0 + return (1 + beta * beta) * precision * recall / (beta * beta * precision + recall) diff --git a/ailia-models/code/pyannote_audio_utils/metrics/diarization.py b/ailia-models/code/pyannote_audio_utils/metrics/diarization.py new file mode 100644 index 0000000000000000000000000000000000000000..6935a412f0279cc5c48769dfb4393fa53c2100ba --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/metrics/diarization.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2012-2019 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
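+# NOTE (explanatory sketch, not part of the original code): with the default
+# weights, the diarization error rate computed below reduces to
+#
+#     DER = (false alarm + missed detection + confusion) / total
+#
+# e.g. 100 s of reference speech with 3 s of false alarm, 5 s of missed
+# detection and 2 s of speaker confusion gives DER = (3 + 5 + 2) / 100 = 10 %.
+# These durations are made-up figures, for illustration only.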
+ +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr + +"""Metrics for diarization""" +from typing import Optional, Dict, TYPE_CHECKING + +from pyannote_audio_utils.core import Annotation, Timeline +from pyannote_audio_utils.core.utils.types import Label + +from .identification import IdentificationErrorRate +from .matcher import HungarianMapper +from .types import Details, MetricComponents + +if TYPE_CHECKING: + pass + +# TODO: can't we put these as class attributes? +DER_NAME = 'diarization error rate' + + +class DiarizationErrorRate(IdentificationErrorRate): + """Diarization error rate + + First, the optimal mapping between reference and hypothesis labels + is obtained using the Hungarian algorithm. Then, the actual diarization + error rate is computed as the identification error rate with each hypothesis + label translated into the corresponding reference label. + + Parameters + ---------- + collar : float, optional + Duration (in seconds) of collars removed from evaluation around + boundaries of reference segments. + skip_overlap : bool, optional + Set to True to not evaluate overlap regions. + Defaults to False (i.e. keep overlap regions). + + Usage + ----- + + * Diarization error rate between `reference` and `hypothesis` annotations + + >>> metric = DiarizationErrorRate() + >>> reference = Annotation(...) # doctest: +SKIP + >>> hypothesis = Annotation(...) # doctest: +SKIP + >>> value = metric(reference, hypothesis) # doctest: +SKIP + + * Compute global diarization error rate and confidence interval + over multiple documents + + >>> for reference, hypothesis in ... # doctest: +SKIP + ... metric(reference, hypothesis) # doctest: +SKIP + >>> global_value = abs(metric) # doctest: +SKIP + >>> mean, (lower, upper) = metric.confidence_interval() # doctest: +SKIP + + * Get diarization error rate detailed components + + >>> components = metric(reference, hypothesis, detailed=True) #doctest +SKIP + + * Get accumulated components + + >>> components = metric[:] # doctest: +SKIP + >>> metric['confusion'] # doctest: +SKIP + + See Also + -------- + :class:`pyannote_audio_utils.metric.base.BaseMetric`: details on accumulation + :class:`pyannote_audio_utils.metric.identification.IdentificationErrorRate`: identification error rate + + """ + + @classmethod + def metric_name(cls) -> str: + return DER_NAME + + def __init__(self, collar: float = 0.0, skip_overlap: bool = False, + **kwargs): + super().__init__(collar=collar, skip_overlap=skip_overlap, **kwargs) + self.mapper_ = HungarianMapper() + + def optimal_mapping(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None) -> Dict[Label, Label]: + """Optimal label mapping + + Parameters + ---------- + reference : Annotation + hypothesis : Annotation + Reference and hypothesis diarization + uem : Timeline + Evaluation map + + Returns + ------- + mapping : dict + Mapping between hypothesis (key) and reference (value) labels + """ + + # NOTE that this 'uemification' will not be called when + # 'optimal_mapping' is called from 'compute_components' as it + # has already been done in 'compute_components' + if uem: + reference, hypothesis = self.uemify(reference, hypothesis, uem=uem) + + # call hungarian mapper + return self.mapper_(hypothesis, reference) + + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: + # crop reference and hypothesis to evaluated regions (uem) + # remove collars around reference segment boundaries + # remove 
overlap regions (if requested) + reference, hypothesis, uem = self.uemify( + reference, hypothesis, uem=uem, + collar=self.collar, skip_overlap=self.skip_overlap, + returns_uem=True) + # NOTE that this 'uemification' must be done here because it + # might have an impact on the search for the optimal mapping. + + # make sure reference only contains string labels ('A', 'B', ...) + reference = reference.rename_labels(generator='string') + + # make sure hypothesis only contains integer labels (1, 2, ...) + hypothesis = hypothesis.rename_labels(generator='int') + + # optimal (int --> str) mapping + mapping = self.optimal_mapping(reference, hypothesis) + + # compute identification error rate based on mapped hypothesis + # NOTE that collar is set to 0.0 because 'uemify' has already + # been applied (same reason for setting skip_overlap to False) + mapped = hypothesis.rename_labels(mapping=mapping) + return super(DiarizationErrorRate, self) \ + .compute_components(reference, mapped, uem=uem, + collar=0.0, skip_overlap=False, + **kwargs) + diff --git a/ailia-models/code/pyannote_audio_utils/metrics/identification.py b/ailia-models/code/pyannote_audio_utils/metrics/identification.py new file mode 100644 index 0000000000000000000000000000000000000000..7db8e34ffc730f7662d11223ea1177a728db8bf2 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/metrics/identification.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2012-2019 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr +from typing import Optional + +from pyannote_audio_utils.core import Annotation, Timeline + +from .base import BaseMetric +from .base import Precision, PRECISION_RETRIEVED, PRECISION_RELEVANT_RETRIEVED +from .base import Recall, RECALL_RELEVANT, RECALL_RELEVANT_RETRIEVED +from .matcher import LabelMatcher, \ + MATCH_TOTAL, MATCH_CORRECT, MATCH_CONFUSION, \ + MATCH_MISSED_DETECTION, MATCH_FALSE_ALARM +from .types import MetricComponents, Details +from .utils import UEMSupportMixin + +# TODO: can't we put these as class attributes? 
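+# NOTE (usage sketch, not part of the original code): reference and hypothesis
+# are pyannote_audio_utils.core.Annotation instances; the metric is computed
+# on their common timeline, optionally ignoring a collar around reference
+# segment boundaries. Labels and timestamps below are made up for illustration:
+#
+#     from pyannote_audio_utils.core import Annotation, Segment
+#
+#     reference = Annotation(uri="demo")
+#     reference[Segment(0.0, 10.0), 0] = "alice"
+#     reference[Segment(12.0, 20.0), 1] = "bob"
+#
+#     hypothesis = Annotation(uri="demo")
+#     hypothesis[Segment(0.0, 9.0), 0] = "alice"
+#     hypothesis[Segment(11.0, 20.0), 1] = "bob"
+#
+#     ier = IdentificationErrorRate(collar=0.25)
+#     value = ier(reference, hypothesis)                  # single float
+#     detail = ier(reference, hypothesis, detailed=True)  # per-component dict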
+IER_TOTAL = MATCH_TOTAL +IER_CORRECT = MATCH_CORRECT +IER_CONFUSION = MATCH_CONFUSION +IER_FALSE_ALARM = MATCH_FALSE_ALARM +IER_MISS = MATCH_MISSED_DETECTION +IER_NAME = 'identification error rate' + + +class IdentificationErrorRate(UEMSupportMixin, BaseMetric): + """Identification error rate + + ``ier = (wc x confusion + wf x false_alarm + wm x miss) / total`` + + where + - `confusion` is the total confusion duration in seconds + - `false_alarm` is the total hypothesis duration where there are + - `miss` is + - `total` is the total duration of all tracks + - wc, wf and wm are optional weights (default to 1) + + Parameters + ---------- + collar : float, optional + Duration (in seconds) of collars removed from evaluation around + boundaries of reference segments. + skip_overlap : bool, optional + Set to True to not evaluate overlap regions. + Defaults to False (i.e. keep overlap regions). + confusion, miss, false_alarm: float, optional + Optional weights for confusion, miss and false alarm respectively. + Default to 1. (no weight) + """ + + @classmethod + def metric_name(cls) -> str: + return IER_NAME + + @classmethod + def metric_components(cls) -> MetricComponents: + return [ + IER_TOTAL, + IER_CORRECT, + IER_FALSE_ALARM, IER_MISS, + IER_CONFUSION] + + def __init__(self, + confusion: float = 1., + miss: float = 1., + false_alarm: float = 1., + collar: float = 0., + skip_overlap: bool = False, + **kwargs): + + super().__init__(**kwargs) + self.matcher_ = LabelMatcher() + self.confusion = confusion + self.miss = miss + self.false_alarm = false_alarm + self.collar = collar + self.skip_overlap = skip_overlap + + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + collar: Optional[float] = None, + skip_overlap: Optional[float] = None, + **kwargs) -> Details: + """ + + Parameters + ---------- + collar : float, optional + Override self.collar + skip_overlap : bool, optional + Override self.skip_overlap + + See also + -------- + :class:`pyannote_audio_utils.metric.diarization.DiarizationErrorRate` uses these + two options in its `compute_components` method. + + """ + + detail = self.init_components() + + if collar is None: + collar = self.collar + if skip_overlap is None: + skip_overlap = self.skip_overlap + + R, H, common_timeline = self.uemify( + reference, hypothesis, uem=uem, + collar=collar, skip_overlap=skip_overlap, + returns_timeline=True) + + # loop on all segments + for segment in common_timeline: + # segment duration + duration = segment.duration + + # list of IDs in reference segment + r = R.get_labels(segment, unique=False) + + # list of IDs in hypothesis segment + h = H.get_labels(segment, unique=False) + + counts, _ = self.matcher_(r, h) + + detail[IER_TOTAL] += duration * counts[IER_TOTAL] + detail[IER_CORRECT] += duration * counts[IER_CORRECT] + detail[IER_CONFUSION] += duration * counts[IER_CONFUSION] + detail[IER_MISS] += duration * counts[IER_MISS] + detail[IER_FALSE_ALARM] += duration * counts[IER_FALSE_ALARM] + + return detail + + def compute_metric(self, detail: Details) -> float: + + numerator = 1. * ( + self.confusion * detail[IER_CONFUSION] + + self.false_alarm * detail[IER_FALSE_ALARM] + + self.miss * detail[IER_MISS] + ) + denominator = 1. * detail[IER_TOTAL] + if denominator == 0.: + if numerator == 0: + return 0. + else: + return 1. 
+ else: + return numerator / denominator + + +class IdentificationPrecision(UEMSupportMixin, Precision): + """Identification Precision + + Parameters + ---------- + collar : float, optional + Duration (in seconds) of collars removed from evaluation around + boundaries of reference segments. + skip_overlap : bool, optional + Set to True to not evaluate overlap regions. + Defaults to False (i.e. keep overlap regions). + """ + + def __init__(self, collar: float = 0., skip_overlap: bool = False, **kwargs): + super().__init__(**kwargs) + self.collar = collar + self.skip_overlap = skip_overlap + self.matcher_ = LabelMatcher() + + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: + detail = self.init_components() + + R, H, common_timeline = self.uemify( + reference, hypothesis, uem=uem, + collar=self.collar, skip_overlap=self.skip_overlap, + returns_timeline=True) + + # loop on all segments + for segment in common_timeline: + # segment duration + duration = segment.duration + + # list of IDs in reference segment + r = R.get_labels(segment, unique=False) + + # list of IDs in hypothesis segment + h = H.get_labels(segment, unique=False) + + counts, _ = self.matcher_(r, h) + + detail[PRECISION_RETRIEVED] += duration * len(h) + detail[PRECISION_RELEVANT_RETRIEVED] += \ + duration * counts[IER_CORRECT] + + return detail + + +class IdentificationRecall(UEMSupportMixin, Recall): + """Identification Recall + + Parameters + ---------- + collar : float, optional + Duration (in seconds) of collars removed from evaluation around + boundaries of reference segments. + skip_overlap : bool, optional + Set to True to not evaluate overlap regions. + Defaults to False (i.e. keep overlap regions). 
+ """ + + def __init__(self, collar: float = 0., skip_overlap: bool = False, **kwargs): + super().__init__(**kwargs) + self.collar = collar + self.skip_overlap = skip_overlap + self.matcher_ = LabelMatcher() + + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: + detail = self.init_components() + + R, H, common_timeline = self.uemify( + reference, hypothesis, uem=uem, + collar=self.collar, skip_overlap=self.skip_overlap, + returns_timeline=True) + + # loop on all segments + for segment in common_timeline: + # segment duration + duration = segment.duration + + # list of IDs in reference segment + r = R.get_labels(segment, unique=False) + + # list of IDs in hypothesis segment + h = H.get_labels(segment, unique=False) + + counts, _ = self.matcher_(r, h) + + detail[RECALL_RELEVANT] += duration * counts[IER_TOTAL] + detail[RECALL_RELEVANT_RETRIEVED] += duration * counts[IER_CORRECT] + + return detail diff --git a/ailia-models/code/pyannote_audio_utils/metrics/matcher.py b/ailia-models/code/pyannote_audio_utils/metrics/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..1b71039016025872777d212adc369de87878348e --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/metrics/matcher.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2012-2019 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr +from typing import Dict, Tuple, Iterable, List, TYPE_CHECKING + +import numpy as np +from pyannote_audio_utils.core import Annotation +from scipy.optimize import linear_sum_assignment + +if TYPE_CHECKING: + from pyannote_audio_utils.core.utils.types import Label + +MATCH_CORRECT = 'correct' +MATCH_CONFUSION = 'confusion' +MATCH_MISSED_DETECTION = 'missed detection' +MATCH_FALSE_ALARM = 'false alarm' +MATCH_TOTAL = 'total' + + +class LabelMatcher: + """ + ID matcher base class mixin. + + All ID matcher classes must inherit from this class and implement + .match() -- ie return True if two IDs match and False + otherwise. + """ + + def match(self, rlabel: 'Label', hlabel: 'Label') -> bool: + """ + Parameters + ---------- + rlabel : + Reference label + hlabel : + Hypothesis label + + Returns + ------- + match : bool + True if labels match, False otherwise. 
+ + """ + # Two IDs match if they are equal to each other + return rlabel == hlabel + + def __call__(self, rlabels: Iterable['Label'], hlabels: Iterable['Label']) \ + -> Tuple[Dict[str, int], + Dict[str, List['Label']]]: + """ + + Parameters + ---------- + rlabels, hlabels : iterable + Reference and hypothesis labels + + Returns + ------- + counts : dict + details : dict + + """ + + # counts and details + counts = { + MATCH_CORRECT: 0, + MATCH_CONFUSION: 0, + MATCH_MISSED_DETECTION: 0, + MATCH_FALSE_ALARM: 0, + MATCH_TOTAL: 0 + } + + details = { + MATCH_CORRECT: [], + MATCH_CONFUSION: [], + MATCH_MISSED_DETECTION: [], + MATCH_FALSE_ALARM: [] + } + # this is to make sure rlabels and hlabels are lists + # as we will access them later by index + rlabels = list(rlabels) + hlabels = list(hlabels) + + NR = len(rlabels) + NH = len(hlabels) + N = max(NR, NH) + + # corner case + if N == 0: + return counts, details + + # initialize match matrix + # with True if labels match and False otherwise + match = np.zeros((N, N), dtype=bool) + for r, rlabel in enumerate(rlabels): + for h, hlabel in enumerate(hlabels): + match[r, h] = self.match(rlabel, hlabel) + + # find one-to-one mapping that maximize total number of matches + # using the Hungarian algorithm and computes error accordingly + for r, h in zip(*linear_sum_assignment(~match)): + + # hypothesis label is matched with unexisting reference label + # ==> this is a false alarm + if r >= NR: + counts[MATCH_FALSE_ALARM] += 1 + details[MATCH_FALSE_ALARM].append(hlabels[h]) + + # reference label is matched with unexisting hypothesis label + # ==> this is a missed detection + elif h >= NH: + counts[MATCH_MISSED_DETECTION] += 1 + details[MATCH_MISSED_DETECTION].append(rlabels[r]) + + # reference and hypothesis labels match + # ==> this is a correct detection + elif match[r, h]: + counts[MATCH_CORRECT] += 1 + details[MATCH_CORRECT].append((rlabels[r], hlabels[h])) + + # reference and hypothesis do not match + # ==> this is a confusion + else: + counts[MATCH_CONFUSION] += 1 + details[MATCH_CONFUSION].append((rlabels[r], hlabels[h])) + + counts[MATCH_TOTAL] += NR + + # returns counts and details + return counts, details + + +class HungarianMapper: + + def __call__(self, A: Annotation, B: Annotation) -> Dict['Label', 'Label']: + mapping = {} + + cooccurrence = A * B + a_labels, b_labels = A.labels(), B.labels() + + for a, b in zip(*linear_sum_assignment(-cooccurrence)): + if cooccurrence[a, b] > 0: + mapping[a_labels[a]] = b_labels[b] + + return mapping + + +class GreedyMapper: + + def __call__(self, A: Annotation, B: Annotation) -> Dict['Label', 'Label']: + mapping = {} + + cooccurrence = A * B + Na, Nb = cooccurrence.shape + a_labels, b_labels = A.labels(), B.labels() + + for i in range(min(Na, Nb)): + a, b = np.unravel_index(np.argmax(cooccurrence), (Na, Nb)) + + if cooccurrence[a, b] > 0: + mapping[a_labels[a]] = b_labels[b] + cooccurrence[a, :] = 0. + cooccurrence[:, b] = 0. 
+ continue + + break + + return mapping diff --git a/ailia-models/code/pyannote_audio_utils/metrics/types.py b/ailia-models/code/pyannote_audio_utils/metrics/types.py new file mode 100644 index 0000000000000000000000000000000000000000..f5498a412cc776ba512603b387ab9c6836ded623 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/metrics/types.py @@ -0,0 +1,7 @@ +from typing import Dict, List, Literal + + +MetricComponent = str +CalibrationMethod = Literal["isotonic", "sigmoid"] +MetricComponents = List[MetricComponent] +Details = Dict[MetricComponent, float] \ No newline at end of file diff --git a/ailia-models/code/pyannote_audio_utils/metrics/utils.py b/ailia-models/code/pyannote_audio_utils/metrics/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e902d388aa9c7424f700a3eda55ca86336514d1c --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/metrics/utils.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2012-2019 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr + +import warnings +from typing import Optional, Tuple, Union + +from pyannote_audio_utils.core import Timeline, Segment, Annotation + + +class UEMSupportMixin: + """Provides 'uemify' method with optional (à la NIST) collar""" + + def extrude(self, + uem: Timeline, + reference: Annotation, + collar: float = 0.0, + skip_overlap: bool = False) -> Timeline: + """Extrude reference boundary collars from uem + + reference |----| |--------------| |-------------| + uem |---------------------| |-------------------------------| + extruded |--| |--| |---| |-----| |-| |-----| |-----------| |-----| + + Parameters + ---------- + uem : Timeline + Evaluation map. + reference : Annotation + Reference annotation. + collar : float, optional + When provided, set the duration of collars centered around + reference segment boundaries that are extruded from both reference + and hypothesis. Defaults to 0. (i.e. no collar). + skip_overlap : bool, optional + Set to True to not evaluate overlap regions. + Defaults to False (i.e. keep overlap regions). + + Returns + ------- + extruded_uem : Timeline + """ + + if collar == 0. 
and not skip_overlap: + return uem + + collars, overlap_regions = [], [] + + # build list of collars if needed + if collar > 0.: + # iterate over all segments in reference + for segment in reference.itersegments(): + # add collar centered on start time + t = segment.start + collars.append(Segment(t - .5 * collar, t + .5 * collar)) + + # add collar centered on end time + t = segment.end + collars.append(Segment(t - .5 * collar, t + .5 * collar)) + + # build list of overlap regions if needed + if skip_overlap: + # iterate over pair of intersecting segments + for (segment1, track1), (segment2, track2) in reference.co_iter(reference): + if segment1 == segment2 and track1 == track2: + continue + # add their intersection + overlap_regions.append(segment1 & segment2) + + segments = collars + overlap_regions + + return Timeline(segments=segments).support().gaps(support=uem) + + def common_timeline(self, reference: Annotation, hypothesis: Annotation) \ + -> Timeline: + """Return timeline common to both reference and hypothesis + + reference |--------| |------------| |---------| |----| + hypothesis |--------------| |------| |----------------| + timeline |--|-----|----|---|-|------| |-|---------|----| |----| + + Parameters + ---------- + reference : Annotation + hypothesis : Annotation + + Returns + ------- + timeline : Timeline + """ + timeline = reference.get_timeline(copy=True) + timeline.update(hypothesis.get_timeline(copy=False)) + return timeline.segmentation() + + def project(self, annotation: Annotation, timeline: Timeline) -> Annotation: + """Project annotation onto timeline segments + + reference |__A__| |__B__| + |____C____| + + timeline |---|---|---| |---| + + projection |_A_|_A_|_C_| |_B_| + |_C_| + + Parameters + ---------- + annotation : Annotation + timeline : Timeline + + Returns + ------- + projection : Annotation + """ + projection = annotation.empty() + timeline_ = annotation.get_timeline(copy=False) + for segment_, segment in timeline_.co_iter(timeline): + for track_ in annotation.get_tracks(segment_): + track = projection.new_track(segment, candidate=track_) + projection[segment, track] = annotation[segment_, track_] + return projection + + def uemify(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + collar: float = 0., + skip_overlap: bool = False, + returns_uem: bool = False, + returns_timeline: bool = False) \ + -> Union[ + Tuple[Annotation, Annotation], + Tuple[Annotation, Annotation, Timeline], + Tuple[Annotation, Annotation, Timeline, Timeline], + ]: + """Crop 'reference' and 'hypothesis' to 'uem' support + + Parameters + ---------- + reference, hypothesis : Annotation + Reference and hypothesis annotations. + uem : Timeline, optional + Evaluation map. + collar : float, optional + When provided, set the duration of collars centered around + reference segment boundaries that are extruded from both reference + and hypothesis. Defaults to 0. (i.e. no collar). + skip_overlap : bool, optional + Set to True to not evaluate overlap regions. + Defaults to False (i.e. keep overlap regions). + returns_uem : bool, optional + Set to True to return extruded uem as well. + Defaults to False (i.e. only return reference and hypothesis) + returns_timeline : bool, optional + Set to True to oversegment reference and hypothesis so that they + share the same internal timeline. 
+ + Returns + ------- + reference, hypothesis : Annotation + Extruded reference and hypothesis annotations + uem : Timeline + Extruded uem (returned only when 'returns_uem' is True) + timeline : Timeline: + Common timeline (returned only when 'returns_timeline' is True) + """ + + # when uem is not provided, use the union of reference and hypothesis + # extents -- and warn the user about that. + if uem is None: + r_extent = reference.get_timeline().extent() + h_extent = hypothesis.get_timeline().extent() + extent = r_extent | h_extent + uem = Timeline(segments=[extent] if extent else [], + uri=reference.uri) + warnings.warn( + "'uem' was approximated by the union of 'reference' " + "and 'hypothesis' extents.") + + # extrude collars (and overlap regions) from uem + uem = self.extrude(uem, reference, collar=collar, + skip_overlap=skip_overlap) + + # extrude regions outside of uem + reference = reference.crop(uem, mode='intersection') + hypothesis = hypothesis.crop(uem, mode='intersection') + + # project reference and hypothesis on common timeline + if returns_timeline: + timeline = self.common_timeline(reference, hypothesis) + reference = self.project(reference, timeline) + hypothesis = self.project(hypothesis, timeline) + + result = (reference, hypothesis) + if returns_uem: + result += (uem,) + + if returns_timeline: + result += (timeline,) + + return result diff --git a/ailia-models/code/pyannote_audio_utils/pipeline/__init__.py b/ailia-models/code/pyannote_audio_utils/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03bd8c139b9e1b648022933761fe6159f827166d --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/pipeline/__init__.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2018-2020 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
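+# NOTE (usage sketch, not part of the original code): a Pipeline subclass
+# declares its hyper-parameters as Parameter attributes and becomes applicable
+# once instantiate() has been called. The names below (MyPipeline, threshold,
+# the thresholding logic) are made up for illustration:
+#
+#     from pyannote_audio_utils.pipeline import Pipeline
+#     from pyannote_audio_utils.pipeline.parameter import Uniform
+#
+#     class MyPipeline(Pipeline):
+#         def __init__(self):
+#             super().__init__()
+#             self.threshold = Uniform(0.0, 1.0)   # tunable hyper-parameter
+#
+#         def __call__(self, scores):
+#             # once instantiated, self.threshold resolves to a plain float
+#             return [s > self.threshold for s in scores]
+#
+#     pipeline = MyPipeline().instantiate({"threshold": 0.5})
+#     assert pipeline.instantiated
+#     pipeline([0.2, 0.7])   # -> [False, True]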
+ +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr + + +# from ._version import get_versions + +# __version__ = get_versions()["version"] +# del get_versions + + +from .pipeline import Pipeline +# from .optimizer import Optimizer diff --git a/ailia-models/code/pyannote_audio_utils/pipeline/parameter.py b/ailia-models/code/pyannote_audio_utils/pipeline/parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..f3534d8d434313368412c33e3eae0b2657c23e79 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/pipeline/parameter.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2018-2020 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr +# Hadrien TITEUX - https://github.com/hadware + + +from typing import Iterable, Any + +from .pipeline import Pipeline +from collections.abc import Mapping + + +class Parameter: + """Base hyper-parameter""" + + pass + + +class Categorical(Parameter): + """Categorical hyper-parameter + + The value is sampled from `choices`. + + Parameters + ---------- + choices : iterable + Candidates of hyper-parameter value. + """ + + def __init__(self, choices: Iterable): + super().__init__() + self.choices = list(choices) + + # def __call__(self, name: str, trial: Trial): + # return trial.suggest_categorical(name, self.choices) + + +class DiscreteUniform(Parameter): + """Discrete uniform hyper-parameter + + The value is sampled from the range [low, high], + and the step of discretization is `q`. + + Parameters + ---------- + low : `float` + Lower endpoint of the range of suggested values. + `low` is included in the range. + high : `float` + Upper endpoint of the range of suggested values. + `high` is included in the range. + q : `float` + A step of discretization. + """ + + def __init__(self, low: float, high: float, q: float): + super().__init__() + self.low = float(low) + self.high = float(high) + self.q = float(q) + + # def __call__(self, name: str, trial: Trial): + # return trial.suggest_discrete_uniform(name, self.low, self.high, self.q) + + +class Integer(Parameter): + """Integer hyper-parameter + + The value is sampled from the integers in [low, high]. + + Parameters + ---------- + low : `int` + Lower endpoint of the range of suggested values. + `low` is included in the range. + high : `int` + Upper endpoint of the range of suggested values. + `high` is included in the range. 
+ """ + + def __init__(self, low: int, high: int): + super().__init__() + self.low = int(low) + self.high = int(high) + + # def __call__(self, name: str, trial: Trial): + # return trial.suggest_int(name, self.low, self.high) + + +class LogUniform(Parameter): + """Log-uniform hyper-parameter + + The value is sampled from the range [low, high) in the log domain. + + Parameters + ---------- + low : `float` + Lower endpoint of the range of suggested values. + `low` is included in the range. + high : `float` + Upper endpoint of the range of suggested values. + `high` is excluded from the range. + """ + + def __init__(self, low: float, high: float): + super().__init__() + self.low = float(low) + self.high = float(high) + + # def __call__(self, name: str, trial: Trial): + # return trial.suggest_loguniform(name, self.low, self.high) + + +class Uniform(Parameter): + """Uniform hyper-parameter + + The value is sampled from the range [low, high) in the linear domain. + + Parameters + ---------- + low : `float` + Lower endpoint of the range of suggested values. + `low` is included in the range. + high : `float` + Upper endpoint of the range of suggested values. + `high` is excluded from the range. + """ + + def __init__(self, low: float, high: float): + super().__init__() + self.low = float(low) + self.high = float(high) + + # def __call__(self, name: str, trial: Trial): + # return trial.suggest_uniform(name, self.low, self.high) + + +class Frozen(Parameter): + """Frozen hyper-parameter + + The value is fixed a priori + + Parameters + ---------- + value : + Fixed value. + """ + + def __init__(self, value: Any): + super().__init__() + self.value = value + + # def __call__(self, name: str, trial: Trial): + # return self.value + + +class ParamDict(Pipeline, Mapping): + """Dict-like structured hyper-parameter + + Usage + ----- + >>> params = ParamDict(param1=Uniform(0.0, 1.0), param2=Uniform(-1.0, 1.0)) + >>> params = ParamDict(**{"param1": Uniform(0.0, 1.0), "param2": Uniform(-1.0, 1.0)}) + """ + + def __init__(self, **params): + super().__init__() + self.__params = params + for param_name, param_value in params.items(): + setattr(self, param_name, param_value) + + def __len__(self): + return len(self.__params) + + def __iter__(self): + return iter(self.__params) + + def __getitem__(self, param_name): + return getattr(self, param_name) diff --git a/ailia-models/code/pyannote_audio_utils/pipeline/pipeline.py b/ailia-models/code/pyannote_audio_utils/pipeline/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..e189fa1d7a8e34e35feb7fb1bf212dfd205089e1 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/pipeline/pipeline.py @@ -0,0 +1,614 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2018-2022 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+ +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr + +from typing import Optional, TextIO, Union, Dict, Any + +from collections import OrderedDict +from .typing import PipelineInput +from .typing import PipelineOutput +import warnings + + +class Pipeline: + """Base tunable pipeline""" + + def __init__(self): + + # un-instantiated parameters (= `Parameter` instances) + self._parameters: Dict[str, Parameter] = OrderedDict() + + # instantiated parameters + self._instantiated: Dict[str, Any] = OrderedDict() + + # sub-pipelines + self._pipelines: Dict[str, Pipeline] = OrderedDict() + + # whether pipeline is currently being optimized + self.training = False + + @property + def training(self): + return self._training + + @training.setter + def training(self, training): + self._training = training + # recursively set sub-pipeline training attribute + for _, pipeline in self._pipelines.items(): + pipeline.training = training + + def __hash__(self): + # FIXME -- also keep track of (sub)pipeline attributes + frozen = self.parameters(frozen=True) + return hash(tuple(sorted(self._flatten(frozen).items()))) + + def __getattr__(self, name): + """(Advanced) attribute getter""" + + # in case `name` corresponds to an instantiated parameter value, returns it + if "_instantiated" in self.__dict__: + _instantiated = self.__dict__["_instantiated"] + if name in _instantiated: + return _instantiated[name] + + # in case `name` corresponds to a parameter, returns it + if "_parameters" in self.__dict__: + _parameters = self.__dict__["_parameters"] + if name in _parameters: + return _parameters[name] + + # in case `name` corresponds to a sub-pipeline, returns it + if "_pipelines" in self.__dict__: + _pipelines = self.__dict__["_pipelines"] + if name in _pipelines: + return _pipelines[name] + + msg = "'{}' object has no attribute '{}'".format(type(self).__name__, name) + raise AttributeError(msg) + + def __setattr__(self, name, value): + """(Advanced) attribute setter + + If `value` is an instance of `Parameter`, store it in `_parameters`. + elif `value` is an instance of `Pipeline`, store it in `_pipelines`. + elif `value` isn't an instance of `Parameter` and `name` is in `_parameters`, + store `value` in `_instantiated`. 
+ """ + + # imported here to avoid circular import + from .parameter import Parameter + + def remove_from(*dicts): + for d in dicts: + if name in d: + del d[name] + + _parameters = self.__dict__.get("_parameters") + _instantiated = self.__dict__.get("_instantiated") + _pipelines = self.__dict__.get("_pipelines") + + # if `value` is an instance of `Parameter`, store it in `_parameters` + + if isinstance(value, Parameter): + if _parameters is None: + msg = ( + "cannot assign hyper-parameters " "before Pipeline.__init__() call" + ) + raise AttributeError(msg) + remove_from(self.__dict__, _instantiated, _pipelines) + _parameters[name] = value + return + + # add/update one sub-pipeline + if isinstance(value, Pipeline): + if _pipelines is None: + msg = "cannot assign sub-pipelines " "before Pipeline.__init__() call" + raise AttributeError(msg) + remove_from(self.__dict__, _parameters, _instantiated) + _pipelines[name] = value + return + + # store instantiated parameter value + if _parameters is not None and name in _parameters: + _instantiated[name] = value + return + + object.__setattr__(self, name, value) + + def __delattr__(self, name): + + if name in self._parameters: + del self._parameters[name] + + elif name in self._instantiated: + del self._instantiated[name] + + elif name in self._pipelines: + del self._pipelines[name] + + else: + object.__delattr__(self, name) + + def _flattened_parameters( + self, frozen: Optional[bool] = False, instantiated: Optional[bool] = False + ) -> dict: + """Get flattened dictionary of parameters + + Parameters + ---------- + frozen : `bool`, optional + Only return value of frozen parameters. + instantiated : `bool`, optional + Only return value of instantiated parameters. + + Returns + ------- + params : `dict` + Flattened dictionary of parameters. + """ + + # imported here to avoid circular imports + from .parameter import Frozen + + if frozen and instantiated: + msg = "one must choose between `frozen` and `instantiated`." 
+ raise ValueError(msg) + + # initialize dictionary with root parameters + if instantiated: + params = dict(self._instantiated) + + elif frozen: + params = { + n: p.value for n, p in self._parameters.items() if isinstance(p, Frozen) + } + + else: + params = dict(self._parameters) + + # recursively add sub-pipeline parameters + for pipeline_name, pipeline in self._pipelines.items(): + pipeline_params = pipeline._flattened_parameters( + frozen=frozen, instantiated=instantiated + ) + for name, value in pipeline_params.items(): + params[f"{pipeline_name}>{name}"] = value + + return params + + def _flatten(self, nested_params: dict) -> dict: + """Convert nested dictionary to flattened dictionary + + For instance, a nested dictionary like this one: + + ~~~~~~~~~~~~~~~~~~~~~ + param: value1 + pipeline: + param: value2 + subpipeline: + param: value3 + ~~~~~~~~~~~~~~~~~~~~~ + + becomes the following flattened dictionary: + + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + param : value1 + pipeline>param : value2 + pipeline>subpipeline>param : value3 + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Parameter + --------- + nested_params : `dict` + + Returns + ------- + flattened_params : `dict` + """ + flattened_params = dict() + for name, value in nested_params.items(): + if isinstance(value, dict): + for subname, subvalue in self._flatten(value).items(): + flattened_params[f"{name}>{subname}"] = subvalue + else: + flattened_params[name] = value + return flattened_params + + def _unflatten(self, flattened_params: dict) -> dict: + """Convert flattened dictionary to nested dictionary + + For instance, a flattened dictionary like this one: + + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + param : value1 + pipeline>param : value2 + pipeline>subpipeline>param : value3 + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + becomes the following nested dictionary: + + ~~~~~~~~~~~~~~~~~~~~~ + param: value1 + pipeline: + param: value2 + subpipeline: + param: value3 + ~~~~~~~~~~~~~~~~~~~~~ + + Parameter + --------- + flattened_params : `dict` + + Returns + ------- + nested_params : `dict` + """ + + nested_params = {} + + pipeline_params = {name: {} for name in self._pipelines} + for name, value in flattened_params.items(): + # if name contains has multipe ">"-separated tokens + # it means that it is a sub-pipeline parameter + tokens = name.split(">") + if len(tokens) > 1: + # read sub-pipeline name + pipeline_name = tokens[0] + # read parameter name + param_name = ">".join(tokens[1:]) + # update sub-pipeline flattened dictionary + pipeline_params[pipeline_name][param_name] = value + + # otherwise, it is an actual parameter of this pipeline + else: + # store it as such + nested_params[name] = value + + # recursively unflatten sub-pipeline flattened dictionary + for name, pipeline in self._pipelines.items(): + nested_params[name] = pipeline._unflatten(pipeline_params[name]) + + return nested_params + + def parameters( + self, + trial = None, + frozen: Optional[bool] = False, + instantiated: Optional[bool] = False, + ) -> dict: + """Returns nested dictionary of (optionnaly instantiated) parameters. + + For a pipeline with one `param`, one sub-pipeline with its own param + and its own sub-pipeline, it will returns something like: + + ~~~~~~~~~~~~~~~~~~~~~ + param: value1 + pipeline: + param: value2 + subpipeline: + param: value3 + ~~~~~~~~~~~~~~~~~~~~~ + + Parameter + --------- + trial : `Trial`, optional + When provided, use trial to suggest new parameter values + and return them. 
+ frozen : `bool`, optional + Return frozen parameter value + instantiated : `bool`, optional + Return instantiated parameter values. + + Returns + ------- + params : `dict` + Nested dictionary of parameters. See above for the actual format. + """ + + if (instantiated or frozen) and trial is not None: + msg = "One must choose between `trial`, `instantiated`, or `frozen`" + raise ValueError(msg) + + # get flattened dictionary of uninstantiated parameters + params = self._flattened_parameters(frozen=frozen, instantiated=instantiated) + + if trial is not None: + # use provided `trial` to suggest values for parameters + params = {name: param(name, trial) for name, param in params.items()} + + # un-flatten flattened dictionary + return self._unflatten(params) + + def initialize(self): + """Instantiate root pipeline with current set of parameters""" + pass + + # def freeze(self, params: dict) -> "Pipeline": + # """Recursively freeze pipeline parameters + + # Parameters + # ---------- + # params : `dict` + # Nested dictionary of parameters. + + # Returns + # ------- + # self : `Pipeline` + # Pipeline. + # """ + + # # imported here to avoid circular imports + # from .parameter import Frozen + + # for name, value in params.items(): + + # # recursively freeze sub-pipelines parameters + # if name in self._pipelines: + # if not isinstance(value, dict): + # msg = ( + # f"only parameters of '{name}' pipeline can " + # f"be frozen (not the whole pipeline)" + # ) + # raise ValueError(msg) + # self._pipelines[name].freeze(value) + # continue + + # # instantiate parameter value + # if name in self._parameters: + # setattr(self, name, Frozen(value)) + # continue + + # msg = f"parameter '{name}' does not exist" + # raise ValueError(msg) + + # return self + + def instantiate(self, params: dict) -> "Pipeline": + """Recursively instantiate all pipelines + + Parameters + ---------- + params : `dict` + Nested dictionary of parameters. + + Returns + ------- + self : `Pipeline` + Instantiated pipeline. + """ + + # imported here to avoid circular imports + from .parameter import Frozen + + for name, value in params.items(): + + # recursively call `instantiate` with sub-pipelines + if name in self._pipelines: + if not isinstance(value, dict): + msg = ( + f"only parameters of '{name}' pipeline can " + f"be instantiated (not the whole pipeline)" + ) + raise ValueError(msg) + self._pipelines[name].instantiate(value) + continue + + # instantiate parameter value + if name in self._parameters: + param = getattr(self, name) + # overwrite provided value of frozen parameters + if isinstance(param, Frozen) and param.value != value: + msg = ( + f"Parameter '{name}' is frozen: using its frozen value " + f"({param.value}) instead of the one provided ({value})." + ) + warnings.warn(msg) + value = param.value + setattr(self, name, value) + continue + + msg = f"parameter '{name}' does not exist" + raise ValueError(msg) + + self.initialize() + + return self + + @property + def instantiated(self): + """Whether pipeline has been instantiated (and therefore can be applied)""" + parameters = set(self._flatten(self.parameters())) + instantiated = set(self._flatten(self.parameters(instantiated=True))) + return parameters == instantiated + + # def dump_params( + # self, + # params_yml: Path, + # params: Optional[dict] = None, + # loss: Optional[float] = None, + # ) -> str: + # """Dump parameters to disk + + # Parameters + # ---------- + # params_yml : `Path` + # Path to YAML file. + # params : `dict`, optional + # Nested Parameters. 
Defaults to pipeline current parameters. + # loss : `float`, optional + # Loss value. Defaults to not write loss to file. + + # Returns + # ------- + # content : `str` + # Content written in `param_yml`. + # """ + # # use instantiated parameters when `params` is not provided + # if params is None: + # params = self.parameters(instantiated=True) + + # content = {"params": params} + # if loss is not None: + # content["loss"] = loss + + # # format as valid YAML + # content_yml = yaml.dump(content, default_flow_style=False) + + # # (safely) dump YAML content + # with FileLock(params_yml.with_suffix(".lock")): + # with open(params_yml, mode="w") as fp: + # fp.write(content_yml) + + # return content_yml + + # def load_params(self, params_yml: Path) -> "Pipeline": + # """Instantiate pipeline using parameters from disk + + # Parameters + # ---------- + # param_yml : `Path` + # Path to YAML file. + + # Returns + # ------- + # self : `Pipeline` + # Instantiated pipeline + + # """ + + # with open(params_yml, mode="r") as fp: + # params = yaml.load(fp, Loader=yaml.SafeLoader) + # return self.instantiate(params["params"]) + + def __call__(self, input: PipelineInput) -> PipelineOutput: + """Apply pipeline on input and return its output""" + raise NotImplementedError + + # def get_metric(self) -> "pyannote.metrics.base.BaseMetric": + # """Return new metric (from pyannote.metrics) + + # When this method is implemented, the returned metric is used as a + # replacement for the loss method below. + + # Returns + # ------- + # metric : `pyannote.metrics.base.BaseMetric` + # """ + # raise NotImplementedError() + + # def get_direction(self) -> Direction: + # return "minimize" + + # def loss(self, input: PipelineInput, output: PipelineOutput) -> float: + # """Compute loss for given input/output pair + + # Parameters + # ---------- + # input : object + # Pipeline input. + # output : object + # Pipeline output + + # Returns + # ------- + # loss : `float` + # Loss value + # """ + # raise NotImplementedError() + + # @property + # def write_format(self): + # return "rttm" + + # def write(self, file: TextIO, output: PipelineOutput): + # """Write pipeline output to file + + # Parameters + # ---------- + # file : file object + # output : object + # Pipeline output + # """ + + # return getattr(self, f"write_{self.write_format}")(file, output) + + # def write_rttm(self, file: TextIO, output: Union[Timeline, Annotation]): + # """Write pipeline output to "rttm" file + + # Parameters + # ---------- + # file : file object + # output : `pyannote.core.Timeline` or `pyannote.core.Annotation` + # Pipeline output + # """ + + # if isinstance(output, Timeline): + # output = output.to_annotation(generator="string") + + # if isinstance(output, Annotation): + # for s, t, l in output.itertracks(yield_label=True): + # line = ( + # f"SPEAKER {output.uri} 1 {s.start:.3f} {s.duration:.3f} " + # f" {l} \n" + # ) + # file.write(line) + # return + + # msg = ( + # f'Dumping {output.__class__.__name__} instances to "rttm" files ' + # f"is not supported." 
+ # ) + # raise NotImplementedError(msg) + + # def write_txt(self, file: TextIO, output: Union[Timeline, Annotation]): + # """Write pipeline output to "txt" file + + # Parameters + # ---------- + # file : file object + # output : `pyannote.core.Timeline` or `pyannote.core.Annotation` + # Pipeline output + # """ + + # if isinstance(output, Timeline): + # for s in output: + # line = f"{output.uri} {s.start:.3f} {s.end:.3f}\n" + # file.write(line) + # return + + # if isinstance(output, Annotation): + # for s, t, l in output.itertracks(yield_label=True): + # line = f"{output.uri} {s.start:.3f} {s.end:.3f} {t} {l}\n" + # file.write(line) + # return + + # msg = ( + # f'Dumping {output.__class__.__name__} instances to "txt" files ' + # f"is not supported." + # ) + # raise NotImplementedError(msg) diff --git a/ailia-models/code/pyannote_audio_utils/pipeline/typing.py b/ailia-models/code/pyannote_audio_utils/pipeline/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..99411c30819740ef0c8f848090ab9ca2ac5fc965 --- /dev/null +++ b/ailia-models/code/pyannote_audio_utils/pipeline/typing.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# The MIT License (MIT) + +# Copyright (c) 2018-2020 CNRS + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +# AUTHORS +# Hervé BREDIN - http://herve.niderb.fr + +from typing import TypeVar + +PipelineInput = TypeVar("PipelineInput") +PipelineOutput = TypeVar("PipelineOutput") diff --git a/ailia-models/code/requirements.txt b/ailia-models/code/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a17300fb8ffd7f10eb4e259fb43210e8316c26d9 --- /dev/null +++ b/ailia-models/code/requirements.txt @@ -0,0 +1,6 @@ +pyyaml +sortedcontainers +pandas +soundfile + + diff --git a/ailia-models/segmentation.onnx b/ailia-models/segmentation.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5f445989c256526dd9f1d88bfd49c731cacf8ee1 --- /dev/null +++ b/ailia-models/segmentation.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b78fc48113bb46fd247ae6a9aea737079550c647638db961df7e0e1e9f4ba62e +size 5983836 diff --git a/ailia-models/segmentation.onnx.prototxt b/ailia-models/segmentation.onnx.prototxt new file mode 100644 index 0000000000000000000000000000000000000000..29055ad05871cd840f47366d69b0e3944417959d --- /dev/null +++ b/ailia-models/segmentation.onnx.prototxt @@ -0,0 +1,1414 @@ +ir_version: 8 +producer_name: "pytorch" +producer_version: "2.2.1" +model_version: 0 +graph { + name: "main_graph" + node { + input: "input" + input: "ortshared_1_1_1_1_token_110" + input: "ortshared_1_1_1_0_token_107" + output: "/sincnet/wav_norm1d/InstanceNormalization_output_0" + name: "/sincnet/wav_norm1d/InstanceNormalization" + op_type: "InstanceNormalization" + attribute { + name: "epsilon" + f: 9.999999747378752e-06 + type: FLOAT + } + } + node { + input: "/sincnet/wav_norm1d/InstanceNormalization_output_0" + output: "onnx::Gather_115" + name: "Shape_111" + op_type: "Shape" + attribute { + name: "start" + i: 0 + type: INT + } + } + node { + input: "onnx::Gather_115" + input: "ortshared_7_0_1_1_token_113" + output: "onnx::Equal_117" + name: "Gather_113" + op_type: "Gather" + attribute { + name: "axis" + i: 0 + type: INT + } + } + node { + input: "onnx::Equal_117" + input: "ortshared_7_0_1_1_token_113" + output: "onnx::Cast_119" + name: "Equal_115" + op_type: "Equal" + } + node { + input: "onnx::Cast_119" + output: "onnx::Abs_121" + name: "If_117" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "/sincnet/wav_norm1d/InstanceNormalization_output_0" + input: "/sincnet/conv1d.0/Concat_2_output_0" + output: "122" + name: "Conv_118" + op_type: "Conv" + attribute { + name: "strides" + ints: 10 + type: INTS + } + attribute { + name: "kernel_shape" + ints: 251 + type: INTS + } + attribute { + name: "auto_pad" + s: "NOTSET" + type: STRING + } + attribute { + name: "dilations" + ints: 1 + type: INTS + } + attribute { + name: "pads" + ints: 0 + ints: 0 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + } + name: "sub_graph" + doc_string: "" + output { + name: "122" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_param: "Conv122_dim_0" + } + dim { + dim_value: 80 + } + dim { + dim_param: "Conv122_dim_2" + } + } + } + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "onnx::Gather_115" + input: "126" + input: "127" + input: "125" + output: "128" + name: "Slice_124" + op_type: "Slice" + } + node { + input: "128" + input: "129" + output: "130" + name: "Squeeze_126" + op_type: "Squeeze" + } + node { + input: "130" + input: "135" + output: "136" + name: "Unsqueeze_132" + op_type: "Unsqueeze" + } + node { + input: "132" + input: "134" + input: "136" + output: 
"137" + name: "Concat_133" + op_type: "Concat" + attribute { + name: "axis" + i: 0 + type: INT + } + } + node { + input: "/sincnet/wav_norm1d/InstanceNormalization_output_0" + input: "137" + output: "138" + name: "Reshape_134" + op_type: "Reshape" + attribute { + name: "allowzero" + i: 0 + type: INT + } + } + node { + input: "138" + input: "/sincnet/conv1d.0/Concat_2_output_0" + output: "batched_conv" + name: "Conv_135" + op_type: "Conv" + attribute { + name: "strides" + ints: 10 + type: INTS + } + attribute { + name: "kernel_shape" + ints: 251 + type: INTS + } + attribute { + name: "auto_pad" + s: "NOTSET" + type: STRING + } + attribute { + name: "dilations" + ints: 1 + type: INTS + } + attribute { + name: "pads" + ints: 0 + ints: 0 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + } + node { + input: "batched_conv" + output: "147" + name: "Shape_143" + op_type: "Shape" + attribute { + name: "start" + i: 0 + type: INT + } + } + node { + input: "147" + input: "150" + input: "151" + input: "148" + input: "153" + output: "154" + name: "Slice_150" + op_type: "Slice" + } + node { + input: "onnx::Gather_115" + input: "141" + input: "143" + input: "140" + input: "145" + output: "146" + name: "Slice_142" + op_type: "Slice" + } + node { + input: "146" + input: "154" + output: "155" + name: "Concat_151" + op_type: "Concat" + attribute { + name: "axis" + i: 0 + type: INT + } + } + node { + input: "batched_conv" + input: "155" + output: "156" + name: "Reshape_152" + op_type: "Reshape" + attribute { + name: "allowzero" + i: 0 + type: INT + } + } + name: "sub_graph1" + initializer { + dims: 1 + data_type: 7 + name: "151" + raw_data: "\377\377\377\377\377\377\377\177" + } + initializer { + dims: 1 + data_type: 7 + name: "134" + raw_data: "\001\000\000\000\000\000\000\000" + } + initializer { + dims: 1 + data_type: 7 + name: "125" + raw_data: "\000\000\000\000\000\000\000\000" + } + initializer { + dims: 1 + data_type: 7 + name: "126" + raw_data: "\377\377\377\377\377\377\377\377" + } + initializer { + dims: 1 + data_type: 7 + name: "127" + raw_data: "\377\377\377\377\377\377\377\177" + } + initializer { + dims: 1 + data_type: 7 + name: "129" + raw_data: "\000\000\000\000\000\000\000\000" + } + initializer { + dims: 1 + data_type: 7 + name: "150" + raw_data: "\376\377\377\377\377\377\377\377" + } + initializer { + dims: 1 + data_type: 7 + name: "153" + raw_data: "\001\000\000\000\000\000\000\000" + } + initializer { + dims: 1 + data_type: 7 + name: "135" + raw_data: "\000\000\000\000\000\000\000\000" + } + initializer { + dims: 1 + data_type: 7 + name: "140" + raw_data: "\000\000\000\000\000\000\000\000" + } + initializer { + dims: 1 + data_type: 7 + name: "141" + raw_data: "\000\000\000\000\000\000\000\000" + } + initializer { + dims: 1 + data_type: 7 + name: "132" + raw_data: "\377\377\377\377\377\377\377\377" + } + initializer { + dims: 1 + data_type: 7 + name: "145" + raw_data: "\001\000\000\000\000\000\000\000" + } + initializer { + dims: 1 + data_type: 7 + name: "148" + raw_data: "\000\000\000\000\000\000\000\000" + } + initializer { + dims: 1 + data_type: 7 + name: "143" + raw_data: "\377\377\377\377\377\377\377\377" + } + doc_string: "" + output { + name: "156" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_param: "Reshape156_dim_0" + } + dim { + dim_param: "Reshape156_dim_1" + } + dim { + dim_param: "Reshape156_dim_2" + } + dim { + dim_param: "Reshape156_dim_3" + } + } + } + } + } + } + type: GRAPH + } + } + node { + input: "onnx::Abs_121" + output: 
"/sincnet/Abs_output_0" + name: "/sincnet/Abs" + op_type: "Abs" + } + node { + input: "/sincnet/Abs_output_0" + output: "/sincnet/pool1d.0/MaxPool_output_0" + name: "/sincnet/pool1d.0/MaxPool" + op_type: "MaxPool" + attribute { + name: "storage_order" + i: 0 + type: INT + } + attribute { + name: "pads" + ints: 0 + ints: 0 + type: INTS + } + attribute { + name: "ceil_mode" + i: 0 + type: INT + } + attribute { + name: "strides" + ints: 3 + type: INTS + } + attribute { + name: "kernel_shape" + ints: 3 + type: INTS + } + attribute { + name: "auto_pad" + s: "NOTSET" + type: STRING + } + attribute { + name: "dilations" + ints: 1 + type: INTS + } + } + node { + input: "/sincnet/pool1d.0/MaxPool_output_0" + input: "sincnet.norm1d.0.weight" + input: "sincnet.norm1d.0.bias" + output: "/sincnet/norm1d.0/InstanceNormalization_output_0" + name: "/sincnet/norm1d.0/InstanceNormalization" + op_type: "InstanceNormalization" + attribute { + name: "epsilon" + f: 9.999999747378752e-06 + type: FLOAT + } + } + node { + input: "/sincnet/norm1d.0/InstanceNormalization_output_0" + output: "/sincnet/LeakyRelu_output_0" + name: "/sincnet/LeakyRelu" + op_type: "LeakyRelu" + attribute { + name: "alpha" + f: 0.009999999776482582 + type: FLOAT + } + } + node { + input: "/sincnet/LeakyRelu_output_0" + input: "sincnet.conv1d.1.weight" + input: "sincnet.conv1d.1.bias" + output: "/sincnet/conv1d.1/Conv_output_0" + name: "/sincnet/conv1d.1/Conv" + op_type: "Conv" + attribute { + name: "strides" + ints: 1 + type: INTS + } + attribute { + name: "kernel_shape" + ints: 5 + type: INTS + } + attribute { + name: "auto_pad" + s: "NOTSET" + type: STRING + } + attribute { + name: "dilations" + ints: 1 + type: INTS + } + attribute { + name: "pads" + ints: 0 + ints: 0 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + } + node { + input: "/sincnet/conv1d.1/Conv_output_0" + output: "/sincnet/pool1d.1/MaxPool_output_0" + name: "/sincnet/pool1d.1/MaxPool" + op_type: "MaxPool" + attribute { + name: "storage_order" + i: 0 + type: INT + } + attribute { + name: "pads" + ints: 0 + ints: 0 + type: INTS + } + attribute { + name: "ceil_mode" + i: 0 + type: INT + } + attribute { + name: "strides" + ints: 3 + type: INTS + } + attribute { + name: "kernel_shape" + ints: 3 + type: INTS + } + attribute { + name: "auto_pad" + s: "NOTSET" + type: STRING + } + attribute { + name: "dilations" + ints: 1 + type: INTS + } + } + node { + input: "/sincnet/pool1d.1/MaxPool_output_0" + input: "sincnet.norm1d.1.weight" + input: "sincnet.norm1d.1.bias" + output: "/sincnet/norm1d.1/InstanceNormalization_output_0" + name: "/sincnet/norm1d.1/InstanceNormalization" + op_type: "InstanceNormalization" + attribute { + name: "epsilon" + f: 9.999999747378752e-06 + type: FLOAT + } + } + node { + input: "/sincnet/norm1d.1/InstanceNormalization_output_0" + output: "/sincnet/LeakyRelu_1_output_0" + name: "/sincnet/LeakyRelu_1" + op_type: "LeakyRelu" + attribute { + name: "alpha" + f: 0.009999999776482582 + type: FLOAT + } + } + node { + input: "/sincnet/LeakyRelu_1_output_0" + input: "sincnet.conv1d.2.weight" + input: "sincnet.conv1d.2.bias" + output: "/sincnet/conv1d.2/Conv_output_0" + name: "/sincnet/conv1d.2/Conv" + op_type: "Conv" + attribute { + name: "strides" + ints: 1 + type: INTS + } + attribute { + name: "kernel_shape" + ints: 5 + type: INTS + } + attribute { + name: "auto_pad" + s: "NOTSET" + type: STRING + } + attribute { + name: "dilations" + ints: 1 + type: INTS + } + attribute { + name: "pads" + ints: 0 + ints: 0 + type: INTS + } + attribute { 
+ name: "group" + i: 1 + type: INT + } + } + node { + input: "/sincnet/conv1d.2/Conv_output_0" + output: "/sincnet/pool1d.2/MaxPool_output_0" + name: "/sincnet/pool1d.2/MaxPool" + op_type: "MaxPool" + attribute { + name: "storage_order" + i: 0 + type: INT + } + attribute { + name: "pads" + ints: 0 + ints: 0 + type: INTS + } + attribute { + name: "ceil_mode" + i: 0 + type: INT + } + attribute { + name: "strides" + ints: 3 + type: INTS + } + attribute { + name: "kernel_shape" + ints: 3 + type: INTS + } + attribute { + name: "auto_pad" + s: "NOTSET" + type: STRING + } + attribute { + name: "dilations" + ints: 1 + type: INTS + } + } + node { + input: "/sincnet/pool1d.2/MaxPool_output_0" + input: "sincnet.norm1d.2.weight" + input: "sincnet.norm1d.2.bias" + output: "/sincnet/norm1d.2/InstanceNormalization_output_0" + name: "/sincnet/norm1d.2/InstanceNormalization" + op_type: "InstanceNormalization" + attribute { + name: "epsilon" + f: 9.999999747378752e-06 + type: FLOAT + } + } + node { + input: "/sincnet/norm1d.2/InstanceNormalization_output_0" + output: "/sincnet/LeakyRelu_2_output_0" + name: "/sincnet/LeakyRelu_2" + op_type: "LeakyRelu" + attribute { + name: "alpha" + f: 0.009999999776482582 + type: FLOAT + } + } + node { + input: "/sincnet/LeakyRelu_2_output_0" + output: "/lstm/Shape" + name: "/lstm/Shape" + op_type: "Shape" + attribute { + name: "start" + i: 0 + type: INT + } + } + node { + input: "/lstm/Shape" + input: "ortshared_7_1_3_0_token_104" + output: "/lstm/Shape_output_0" + name: "Gather" + op_type: "Gather" + attribute { + name: "axis" + i: 0 + type: INT + } + doc_string: "Added in transpose optimizer" + } + node { + input: "/lstm/Shape_output_0" + input: "ortshared_7_0_1_0_token_105" + output: "/lstm/Gather_output_0" + name: "/lstm/Gather" + op_type: "Gather" + attribute { + name: "axis" + i: 0 + type: INT + } + } + node { + input: "/lstm/Gather_output_0" + input: "ortshared_7_1_1_2_token_111" + output: "/lstm/Unsqueeze_output_0" + name: "/lstm/Unsqueeze" + op_type: "Unsqueeze" + } + node { + input: "ortshared_7_1_1_0_token_106" + input: "/lstm/Unsqueeze_output_0" + input: "ortshared_7_1_1_3_token_114" + output: "/lstm/Concat_output_0" + name: "/lstm/Concat" + op_type: "Concat" + attribute { + name: "axis" + i: 0 + type: INT + } + } + node { + input: "/lstm/Concat_output_0" + output: "/lstm/ConstantOfShape_output_0" + name: "/lstm/ConstantOfShape" + op_type: "ConstantOfShape" + attribute { + name: "value" + t { + dims: 1 + data_type: 1 + raw_data: "\000\000\000\000" + } + type: TENSOR + } + } + node { + input: "/lstm/ConstantOfShape_output_0" + input: "ortshared_7_1_1_5_token_116" + input: "ortshared_7_1_1_0_token_106" + input: "ortshared_7_1_1_2_token_111" + output: "/lstm/Slice_7_output_0" + name: "/lstm/Slice_7" + op_type: "Slice" + } + node { + input: "/lstm/ConstantOfShape_output_0" + input: "ortshared_7_1_1_1_token_108" + input: "ortshared_7_1_1_5_token_116" + input: "ortshared_7_1_1_2_token_111" + output: "/lstm/Slice_5_output_0" + name: "/lstm/Slice_5" + op_type: "Slice" + } + node { + input: "/lstm/ConstantOfShape_output_0" + input: "ortshared_7_1_1_4_token_115" + input: "ortshared_7_1_1_1_token_108" + input: "ortshared_7_1_1_2_token_111" + output: "/lstm/Slice_3_output_0" + name: "/lstm/Slice_3" + op_type: "Slice" + } + node { + input: "/lstm/ConstantOfShape_output_0" + input: "ortshared_7_1_1_2_token_111" + input: "ortshared_7_1_1_4_token_115" + input: "ortshared_7_1_1_2_token_111" + output: "/lstm/Slice_1_output_0" + name: "/lstm/Slice_1" + op_type: "Slice" + } + 
node { + input: "/sincnet/LeakyRelu_2_output_0" + output: "/lstm/Transpose_output_0" + name: "/lstm/Transpose" + op_type: "Transpose" + attribute { + name: "perm" + ints: 2 + ints: 0 + ints: 1 + type: INTS + } + } + node { + input: "/lstm/Transpose_output_0" + input: "onnx::LSTM_784" + input: "onnx::LSTM_785" + input: "onnx::LSTM_783" + input: "" + input: "/lstm/Slice_1_output_0" + input: "/lstm/Slice_1_output_0" + output: "/lstm/LSTM_output_0" + output: "/lstm/LSTM_output_1" + output: "/lstm/LSTM_output_2" + name: "/lstm/LSTM" + op_type: "LSTM" + attribute { + name: "layout" + i: 0 + type: INT + } + attribute { + name: "input_forget" + i: 0 + type: INT + } + attribute { + name: "direction" + s: "bidirectional" + type: STRING + } + attribute { + name: "hidden_size" + i: 128 + type: INT + } + } + node { + input: "/lstm/LSTM_output_0" + output: "/lstm/Transpose_1_output_0" + name: "/lstm/Transpose_1" + op_type: "Transpose" + attribute { + name: "perm" + ints: 0 + ints: 2 + ints: 1 + ints: 3 + type: INTS + } + } + node { + input: "/lstm/Transpose_1_output_0" + input: "ortshared_7_1_3_1_token_112" + output: "/lstm/Reshape_output_0" + name: "/lstm/Reshape" + op_type: "Reshape" + attribute { + name: "allowzero" + i: 0 + type: INT + } + } + node { + input: "/lstm/Reshape_output_0" + input: "onnx::LSTM_827" + input: "onnx::LSTM_828" + input: "onnx::LSTM_826" + input: "" + input: "/lstm/Slice_3_output_0" + input: "/lstm/Slice_3_output_0" + output: "/lstm/LSTM_1_output_0" + output: "/lstm/LSTM_1_output_1" + output: "/lstm/LSTM_1_output_2" + name: "/lstm/LSTM_1" + op_type: "LSTM" + attribute { + name: "layout" + i: 0 + type: INT + } + attribute { + name: "input_forget" + i: 0 + type: INT + } + attribute { + name: "direction" + s: "bidirectional" + type: STRING + } + attribute { + name: "hidden_size" + i: 128 + type: INT + } + } + node { + input: "/lstm/LSTM_1_output_0" + output: "/lstm/Transpose_2_output_0" + name: "/lstm/Transpose_2" + op_type: "Transpose" + attribute { + name: "perm" + ints: 0 + ints: 2 + ints: 1 + ints: 3 + type: INTS + } + } + node { + input: "/lstm/Transpose_2_output_0" + input: "ortshared_7_1_3_1_token_112" + output: "/lstm/Reshape_1_output_0" + name: "/lstm/Reshape_1" + op_type: "Reshape" + attribute { + name: "allowzero" + i: 0 + type: INT + } + } + node { + input: "/lstm/Reshape_1_output_0" + input: "onnx::LSTM_870" + input: "onnx::LSTM_871" + input: "onnx::LSTM_869" + input: "" + input: "/lstm/Slice_5_output_0" + input: "/lstm/Slice_5_output_0" + output: "/lstm/LSTM_2_output_0" + output: "/lstm/LSTM_2_output_1" + output: "/lstm/LSTM_2_output_2" + name: "/lstm/LSTM_2" + op_type: "LSTM" + attribute { + name: "layout" + i: 0 + type: INT + } + attribute { + name: "input_forget" + i: 0 + type: INT + } + attribute { + name: "direction" + s: "bidirectional" + type: STRING + } + attribute { + name: "hidden_size" + i: 128 + type: INT + } + } + node { + input: "/lstm/LSTM_2_output_0" + output: "/lstm/Transpose_3_output_0" + name: "/lstm/Transpose_3" + op_type: "Transpose" + attribute { + name: "perm" + ints: 0 + ints: 2 + ints: 1 + ints: 3 + type: INTS + } + } + node { + input: "/lstm/Transpose_3_output_0" + input: "ortshared_7_1_3_1_token_112" + output: "/lstm/Reshape_2_output_0" + name: "/lstm/Reshape_2" + op_type: "Reshape" + attribute { + name: "allowzero" + i: 0 + type: INT + } + } + node { + input: "/lstm/Reshape_2_output_0" + input: "onnx::LSTM_913" + input: "onnx::LSTM_914" + input: "onnx::LSTM_912" + input: "" + input: "/lstm/Slice_7_output_0" + input: 
"/lstm/Slice_7_output_0" + output: "/lstm/LSTM_3_output_0" + output: "/lstm/LSTM_3_output_1" + output: "/lstm/LSTM_3_output_2" + name: "/lstm/LSTM_3" + op_type: "LSTM" + attribute { + name: "layout" + i: 0 + type: INT + } + attribute { + name: "input_forget" + i: 0 + type: INT + } + attribute { + name: "direction" + s: "bidirectional" + type: STRING + } + attribute { + name: "hidden_size" + i: 128 + type: INT + } + } + node { + input: "/lstm/LSTM_3_output_0" + output: "/lstm/Transpose_4_output_0" + name: "/lstm/Transpose_4" + op_type: "Transpose" + attribute { + name: "perm" + ints: 0 + ints: 2 + ints: 1 + ints: 3 + type: INTS + } + } + node { + input: "/lstm/Transpose_4_output_0" + input: "ortshared_7_1_3_1_token_112" + output: "/lstm/Reshape_3_output_0" + name: "/lstm/Reshape_3" + op_type: "Reshape" + attribute { + name: "allowzero" + i: 0 + type: INT + } + } + node { + input: "/lstm/Reshape_3_output_0" + output: "/lstm/Transpose_5_output_0" + name: "/lstm/Transpose_5" + op_type: "Transpose" + attribute { + name: "perm" + ints: 1 + ints: 0 + ints: 2 + type: INTS + } + } + node { + input: "/lstm/Transpose_5_output_0" + input: "onnx::MatMul_915" + output: "/linear.0/MatMul_output_0" + name: "/linear.0/MatMul" + op_type: "MatMul" + } + node { + input: "linear.0.bias" + input: "/linear.0/MatMul_output_0" + output: "/linear.0/Add_output_0" + name: "/linear.0/Add" + op_type: "Add" + } + node { + input: "/linear.0/Add_output_0" + output: "/LeakyRelu_output_0" + name: "/LeakyRelu" + op_type: "LeakyRelu" + attribute { + name: "alpha" + f: 0.009999999776482582 + type: FLOAT + } + } + node { + input: "/LeakyRelu_output_0" + input: "onnx::MatMul_916" + output: "/linear.1/MatMul_output_0" + name: "/linear.1/MatMul" + op_type: "MatMul" + } + node { + input: "linear.1.bias" + input: "/linear.1/MatMul_output_0" + output: "/linear.1/Add_output_0" + name: "/linear.1/Add" + op_type: "Add" + } + node { + input: "/linear.1/Add_output_0" + output: "/LeakyRelu_1_output_0" + name: "/LeakyRelu_1" + op_type: "LeakyRelu" + attribute { + name: "alpha" + f: 0.009999999776482582 + type: FLOAT + } + } + node { + input: "/LeakyRelu_1_output_0" + input: "onnx::MatMul_917" + output: "/classifier/MatMul_output_0" + name: "/classifier/MatMul" + op_type: "MatMul" + } + node { + input: "ortshared_1_1_7_0_token_109" + input: "/classifier/MatMul_output_0" + output: "/classifier/Add_output_0" + name: "/classifier/Add" + op_type: "Add" + } + node { + input: "/classifier/Add_output_0" + output: "output" + name: "/activation/LogSoftmax" + op_type: "LogSoftmax" + attribute { + name: "axis" + i: -1 + type: INT + } + } + initializer { + dims: 1 + data_type: 1 + name: "ortshared_1_1_1_1_token_110" + } + initializer { + dims: 1 + data_type: 1 + name: "ortshared_1_1_1_0_token_107" + } + initializer { + dims: 80 + dims: 1 + dims: 251 + data_type: 1 + name: "/sincnet/conv1d.0/Concat_2_output_0" + } + initializer { + data_type: 7 + name: "ortshared_7_0_1_0_token_105" + } + initializer { + dims: 1 + data_type: 7 + name: "ortshared_7_1_1_2_token_111" + } + initializer { + data_type: 7 + name: "ortshared_7_0_1_1_token_113" + } + initializer { + dims: 60 + dims: 80 + dims: 5 + data_type: 1 + name: "sincnet.conv1d.1.weight" + } + initializer { + dims: 60 + data_type: 1 + name: "sincnet.conv1d.1.bias" + } + initializer { + dims: 60 + dims: 60 + dims: 5 + data_type: 1 + name: "sincnet.conv1d.2.weight" + } + initializer { + dims: 60 + data_type: 1 + name: "sincnet.conv1d.2.bias" + } + initializer { + dims: 80 + data_type: 1 + name: 
"sincnet.norm1d.0.weight" + } + initializer { + dims: 80 + data_type: 1 + name: "sincnet.norm1d.0.bias" + } + initializer { + dims: 60 + data_type: 1 + name: "sincnet.norm1d.1.weight" + } + initializer { + dims: 60 + data_type: 1 + name: "sincnet.norm1d.1.bias" + } + initializer { + dims: 60 + data_type: 1 + name: "sincnet.norm1d.2.weight" + } + initializer { + dims: 60 + data_type: 1 + name: "sincnet.norm1d.2.bias" + } + initializer { + dims: 128 + data_type: 1 + name: "linear.0.bias" + } + initializer { + dims: 128 + data_type: 1 + name: "linear.1.bias" + } + initializer { + dims: 7 + data_type: 1 + name: "ortshared_1_1_7_0_token_109" + } + initializer { + dims: 1 + data_type: 7 + name: "ortshared_7_1_1_4_token_115" + } + initializer { + dims: 2 + dims: 1024 + data_type: 1 + name: "onnx::LSTM_783" + } + initializer { + dims: 2 + dims: 512 + dims: 60 + data_type: 1 + name: "onnx::LSTM_784" + } + initializer { + dims: 2 + dims: 512 + dims: 128 + data_type: 1 + name: "onnx::LSTM_785" + } + initializer { + dims: 2 + dims: 1024 + data_type: 1 + name: "onnx::LSTM_826" + } + initializer { + dims: 2 + dims: 512 + dims: 256 + data_type: 1 + name: "onnx::LSTM_827" + } + initializer { + dims: 2 + dims: 512 + dims: 128 + data_type: 1 + name: "onnx::LSTM_828" + } + initializer { + dims: 2 + dims: 1024 + data_type: 1 + name: "onnx::LSTM_869" + } + initializer { + dims: 2 + dims: 512 + dims: 256 + data_type: 1 + name: "onnx::LSTM_870" + } + initializer { + dims: 2 + dims: 512 + dims: 128 + data_type: 1 + name: "onnx::LSTM_871" + } + initializer { + dims: 2 + dims: 1024 + data_type: 1 + name: "onnx::LSTM_912" + } + initializer { + dims: 2 + dims: 512 + dims: 256 + data_type: 1 + name: "onnx::LSTM_913" + } + initializer { + dims: 2 + dims: 512 + dims: 128 + data_type: 1 + name: "onnx::LSTM_914" + } + initializer { + dims: 256 + dims: 128 + data_type: 1 + name: "onnx::MatMul_915" + } + initializer { + dims: 128 + dims: 128 + data_type: 1 + name: "onnx::MatMul_916" + } + initializer { + dims: 128 + dims: 7 + data_type: 1 + name: "onnx::MatMul_917" + } + initializer { + dims: 1 + data_type: 7 + name: "ortshared_7_1_1_1_token_108" + } + initializer { + dims: 1 + data_type: 7 + name: "ortshared_7_1_1_5_token_116" + } + initializer { + dims: 3 + data_type: 7 + name: "ortshared_7_1_3_1_token_112" + } + initializer { + dims: 1 + data_type: 7 + name: "ortshared_7_1_1_3_token_114" + } + initializer { + dims: 1 + data_type: 7 + name: "ortshared_7_1_1_0_token_106" + } + initializer { + dims: 3 + data_type: 7 + name: "ortshared_7_1_3_0_token_104" + } + input { + name: "input" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_param: "B" + } + dim { + dim_param: "C" + } + dim { + dim_param: "T" + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_param: "LogSoftmaxoutput_dim_0" + } + dim { + dim_param: "LogSoftmaxoutput_dim_1" + } + dim { + dim_value: 7 + } + } + } + } + } +} +opset_import { + domain: "" + version: 17 +} +opset_import { + domain: "com.microsoft.experimental" + version: 1 +} +opset_import { + domain: "ai.onnx.ml" + version: 4 +} +opset_import { + domain: "ai.onnx.training" + version: 1 +} +opset_import { + domain: "com.microsoft" + version: 1 +} +opset_import { + domain: "ai.onnx.preview.training" + version: 1 +} +opset_import { + domain: "com.microsoft.nchwc" + version: 1 +} +opset_import { + domain: "org.pytorch.aten" + version: 1 +} diff --git a/ailia-models/source.txt b/ailia-models/source.txt new file mode 100644 index 
0000000000000000000000000000000000000000..17b17bdc6b68e161cd25c472ece0c9216860df75 --- /dev/null +++ b/ailia-models/source.txt @@ -0,0 +1,7 @@ +https://github.com/axinc-ai/ailia-models/tree/master/audio_processing/pyannote-audio + +https://storage.googleapis.com/ailia-models/pyannote-audio/segmentation.onnx +https://storage.googleapis.com/ailia-models/pyannote-audio/segmentation.onnx.prototxt + +https://storage.googleapis.com/ailia-models/pyannote-audio/speaker-embedding.onnx +https://storage.googleapis.com/ailia-models/pyannote-audio/speaker-embedding.onnx.prototxt \ No newline at end of file diff --git a/ailia-models/speaker-embedding.onnx b/ailia-models/speaker-embedding.onnx new file mode 100644 index 0000000000000000000000000000000000000000..81d3cfc3bfc4c32e4f2c3586a2042b010da3d415 --- /dev/null +++ b/ailia-models/speaker-embedding.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7bb2f06e9df17cdf1ef14ee8a15ab08ed28e8d0ef5054ee135741560df2ec068 +size 26530309 diff --git a/ailia-models/speaker-embedding.onnx.prototxt b/ailia-models/speaker-embedding.onnx.prototxt new file mode 100644 index 0000000000000000000000000000000000000000..8b4bb88d85002a88d8188168849b98593fe3bf12 --- /dev/null +++ b/ailia-models/speaker-embedding.onnx.prototxt @@ -0,0 +1,2518 @@ +ir_version: 7 +producer_name: "pytorch" +producer_version: "1.10" +model_version: 0 +graph { + name: "torch-jit-export" + node { + input: "feats" + output: "220" + name: "Transpose_0" + op_type: "Transpose" + attribute { + name: "perm" + ints: 0 + ints: 2 + ints: 1 + type: INTS + } + } + node { + output: "221" + name: "Constant_1" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + data_location: 0 + } + type: TENSOR + } + } + node { + input: "220" + input: "221" + output: "222" + name: "Unsqueeze_2" + op_type: "Unsqueeze" + } + node { + input: "222" + input: "367" + input: "368" + output: "366" + name: "Conv_3" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "366" + output: "225" + name: "Relu_4" + op_type: "Relu" + } + node { + input: "225" + input: "370" + input: "371" + output: "369" + name: "Conv_5" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "369" + output: "228" + name: "Relu_6" + op_type: "Relu" + } + node { + input: "228" + input: "373" + input: "374" + output: "372" + name: "Conv_7" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "372" + input: "225" + output: "231" + name: "Add_8" + op_type: "Add" + } + node { + input: "231" + output: "232" + 
name: "Relu_9" + op_type: "Relu" + } + node { + input: "232" + input: "376" + input: "377" + output: "375" + name: "Conv_10" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "375" + output: "235" + name: "Relu_11" + op_type: "Relu" + } + node { + input: "235" + input: "379" + input: "380" + output: "378" + name: "Conv_12" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "378" + input: "232" + output: "238" + name: "Add_13" + op_type: "Add" + } + node { + input: "238" + output: "239" + name: "Relu_14" + op_type: "Relu" + } + node { + input: "239" + input: "382" + input: "383" + output: "381" + name: "Conv_15" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "381" + output: "242" + name: "Relu_16" + op_type: "Relu" + } + node { + input: "242" + input: "385" + input: "386" + output: "384" + name: "Conv_17" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "384" + input: "239" + output: "245" + name: "Add_18" + op_type: "Add" + } + node { + input: "245" + output: "246" + name: "Relu_19" + op_type: "Relu" + } + node { + input: "246" + input: "388" + input: "389" + output: "387" + name: "Conv_20" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 2 + ints: 2 + type: INTS + } + } + node { + input: "387" + output: "249" + name: "Relu_21" + op_type: "Relu" + } + node { + input: "249" + input: "391" + input: "392" + output: "390" + name: "Conv_22" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "246" + input: "394" + input: "395" + output: "393" + name: "Conv_23" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + 
ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "pads" + ints: 0 + ints: 0 + ints: 0 + ints: 0 + type: INTS + } + attribute { + name: "strides" + ints: 2 + ints: 2 + type: INTS + } + } + node { + input: "390" + input: "393" + output: "254" + name: "Add_24" + op_type: "Add" + } + node { + input: "254" + output: "255" + name: "Relu_25" + op_type: "Relu" + } + node { + input: "255" + input: "397" + input: "398" + output: "396" + name: "Conv_26" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "396" + output: "258" + name: "Relu_27" + op_type: "Relu" + } + node { + input: "258" + input: "400" + input: "401" + output: "399" + name: "Conv_28" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "399" + input: "255" + output: "261" + name: "Add_29" + op_type: "Add" + } + node { + input: "261" + output: "262" + name: "Relu_30" + op_type: "Relu" + } + node { + input: "262" + input: "403" + input: "404" + output: "402" + name: "Conv_31" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "402" + output: "265" + name: "Relu_32" + op_type: "Relu" + } + node { + input: "265" + input: "406" + input: "407" + output: "405" + name: "Conv_33" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "405" + input: "262" + output: "268" + name: "Add_34" + op_type: "Add" + } + node { + input: "268" + output: "269" + name: "Relu_35" + op_type: "Relu" + } + node { + input: "269" + input: "409" + input: "410" + output: "408" + name: "Conv_36" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "408" + output: "272" + name: "Relu_37" + op_type: "Relu" + } + node { + input: "272" + input: "412" + input: "413" + output: "411" + name: "Conv_38" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: 
INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "411" + input: "269" + output: "275" + name: "Add_39" + op_type: "Add" + } + node { + input: "275" + output: "276" + name: "Relu_40" + op_type: "Relu" + } + node { + input: "276" + input: "415" + input: "416" + output: "414" + name: "Conv_41" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 2 + ints: 2 + type: INTS + } + } + node { + input: "414" + output: "279" + name: "Relu_42" + op_type: "Relu" + } + node { + input: "279" + input: "418" + input: "419" + output: "417" + name: "Conv_43" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "276" + input: "421" + input: "422" + output: "420" + name: "Conv_44" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "pads" + ints: 0 + ints: 0 + ints: 0 + ints: 0 + type: INTS + } + attribute { + name: "strides" + ints: 2 + ints: 2 + type: INTS + } + } + node { + input: "417" + input: "420" + output: "284" + name: "Add_45" + op_type: "Add" + } + node { + input: "284" + output: "285" + name: "Relu_46" + op_type: "Relu" + } + node { + input: "285" + input: "424" + input: "425" + output: "423" + name: "Conv_47" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "423" + output: "288" + name: "Relu_48" + op_type: "Relu" + } + node { + input: "288" + input: "427" + input: "428" + output: "426" + name: "Conv_49" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "426" + input: "285" + output: "291" + name: "Add_50" + op_type: "Add" + } + node { + input: "291" + output: "292" + name: "Relu_51" + op_type: "Relu" + } + node { + input: "292" + input: "430" + input: "431" + output: "429" + name: "Conv_52" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: 
"kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "429" + output: "295" + name: "Relu_53" + op_type: "Relu" + } + node { + input: "295" + input: "433" + input: "434" + output: "432" + name: "Conv_54" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "432" + input: "292" + output: "298" + name: "Add_55" + op_type: "Add" + } + node { + input: "298" + output: "299" + name: "Relu_56" + op_type: "Relu" + } + node { + input: "299" + input: "436" + input: "437" + output: "435" + name: "Conv_57" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "435" + output: "302" + name: "Relu_58" + op_type: "Relu" + } + node { + input: "302" + input: "439" + input: "440" + output: "438" + name: "Conv_59" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "438" + input: "299" + output: "305" + name: "Add_60" + op_type: "Add" + } + node { + input: "305" + output: "306" + name: "Relu_61" + op_type: "Relu" + } + node { + input: "306" + input: "442" + input: "443" + output: "441" + name: "Conv_62" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "441" + output: "309" + name: "Relu_63" + op_type: "Relu" + } + node { + input: "309" + input: "445" + input: "446" + output: "444" + name: "Conv_64" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "444" + input: "306" + output: "312" + name: "Add_65" + op_type: "Add" + } + node { + input: "312" + output: "313" + name: "Relu_66" + op_type: "Relu" + } + node { + input: "313" + input: "448" + input: "449" + output: "447" + name: "Conv_67" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + 
ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "447" + output: "316" + name: "Relu_68" + op_type: "Relu" + } + node { + input: "316" + input: "451" + input: "452" + output: "450" + name: "Conv_69" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "450" + input: "313" + output: "319" + name: "Add_70" + op_type: "Add" + } + node { + input: "319" + output: "320" + name: "Relu_71" + op_type: "Relu" + } + node { + input: "320" + input: "454" + input: "455" + output: "453" + name: "Conv_72" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 2 + ints: 2 + type: INTS + } + } + node { + input: "453" + output: "323" + name: "Relu_73" + op_type: "Relu" + } + node { + input: "323" + input: "457" + input: "458" + output: "456" + name: "Conv_74" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "320" + input: "460" + input: "461" + output: "459" + name: "Conv_75" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "pads" + ints: 0 + ints: 0 + ints: 0 + ints: 0 + type: INTS + } + attribute { + name: "strides" + ints: 2 + ints: 2 + type: INTS + } + } + node { + input: "456" + input: "459" + output: "328" + name: "Add_76" + op_type: "Add" + } + node { + input: "328" + output: "329" + name: "Relu_77" + op_type: "Relu" + } + node { + input: "329" + input: "463" + input: "464" + output: "462" + name: "Conv_78" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "462" + output: "332" + name: "Relu_79" + op_type: "Relu" + } + node { + input: "332" + input: "466" + input: "467" + output: "465" + name: "Conv_80" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } 
+ } + node { + input: "465" + input: "329" + output: "335" + name: "Add_81" + op_type: "Add" + } + node { + input: "335" + output: "336" + name: "Relu_82" + op_type: "Relu" + } + node { + input: "336" + input: "469" + input: "470" + output: "468" + name: "Conv_83" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "468" + output: "339" + name: "Relu_84" + op_type: "Relu" + } + node { + input: "339" + input: "472" + input: "473" + output: "471" + name: "Conv_85" + op_type: "Conv" + attribute { + name: "dilations" + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "group" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 1 + ints: 1 + ints: 1 + ints: 1 + type: INTS + } + attribute { + name: "strides" + ints: 1 + ints: 1 + type: INTS + } + } + node { + input: "471" + input: "336" + output: "342" + name: "Add_86" + op_type: "Add" + } + node { + input: "342" + output: "343" + name: "Relu_87" + op_type: "Relu" + } + node { + input: "343" + output: "344" + name: "ReduceMean_88" + op_type: "ReduceMean" + attribute { + name: "axes" + ints: -1 + type: INTS + } + attribute { + name: "keepdims" + i: 0 + type: INT + } + } + node { + input: "343" + output: "345" + name: "ReduceMean_89" + op_type: "ReduceMean" + attribute { + name: "axes" + ints: -1 + type: INTS + } + attribute { + name: "keepdims" + i: 1 + type: INT + } + } + node { + input: "343" + output: "346" + name: "Shape_90" + op_type: "Shape" + } + node { + output: "347" + name: "Constant_91" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + data_location: 0 + } + type: TENSOR + } + } + node { + input: "346" + input: "347" + output: "348" + name: "Gather_92" + op_type: "Gather" + attribute { + name: "axis" + i: 0 + type: INT + } + } + node { + input: "348" + output: "349" + name: "ReduceProd_93" + op_type: "ReduceProd" + attribute { + name: "keepdims" + i: 0 + type: INT + } + } + node { + input: "343" + input: "345" + output: "350" + name: "Sub_94" + op_type: "Sub" + } + node { + input: "350" + input: "350" + output: "351" + name: "Mul_95" + op_type: "Mul" + } + node { + input: "351" + output: "352" + name: "ReduceMean_96" + op_type: "ReduceMean" + attribute { + name: "axes" + ints: -1 + type: INTS + } + attribute { + name: "keepdims" + i: 0 + type: INT + } + } + node { + input: "349" + output: "353" + name: "Cast_97" + op_type: "Cast" + attribute { + name: "to" + i: 1 + type: INT + } + } + node { + input: "352" + input: "353" + output: "354" + name: "Mul_98" + op_type: "Mul" + } + node { + output: "355" + name: "Constant_99" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + data_location: 0 + } + type: TENSOR + } + } + node { + input: "353" + input: "355" + output: "356" + name: "Sub_100" + op_type: "Sub" + } + node { + input: "354" + input: "356" + output: "357" + name: "Div_101" + op_type: "Div" + } + node { + output: "358" + name: "Constant_102" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + data_location: 0 + } + type: TENSOR + } + } + node { + input: "357" + input: "358" + output: "359" + name: "Add_103" + op_type: "Add" + } + 
node { + input: "359" + output: "360" + name: "Sqrt_104" + op_type: "Sqrt" + } + node { + input: "344" + output: "361" + name: "Flatten_105" + op_type: "Flatten" + attribute { + name: "axis" + i: 1 + type: INT + } + } + node { + input: "360" + output: "362" + name: "Flatten_106" + op_type: "Flatten" + attribute { + name: "axis" + i: 1 + type: INT + } + } + node { + input: "361" + input: "362" + output: "363" + name: "Concat_107" + op_type: "Concat" + attribute { + name: "axis" + i: 1 + type: INT + } + } + node { + input: "363" + input: "model.seg_1.weight" + input: "model.seg_1.bias" + output: "364" + name: "Gemm_108" + op_type: "Gemm" + attribute { + name: "alpha" + f: 1.0 + type: FLOAT + } + attribute { + name: "beta" + f: 1.0 + type: FLOAT + } + attribute { + name: "transB" + i: 1 + type: INT + } + } + node { + input: "364" + input: "mean_vec" + output: "embs" + name: "Sub_109" + op_type: "Sub" + } + initializer { + dims: 256 + data_type: 1 + name: "mean_vec" + } + initializer { + dims: 256 + dims: 5120 + data_type: 1 + name: "model.seg_1.weight" + } + initializer { + dims: 256 + data_type: 1 + name: "model.seg_1.bias" + } + initializer { + dims: 32 + dims: 1 + dims: 3 + dims: 3 + data_type: 1 + name: "367" + } + initializer { + dims: 32 + data_type: 1 + name: "368" + } + initializer { + dims: 32 + dims: 32 + dims: 3 + dims: 3 + data_type: 1 + name: "370" + } + initializer { + dims: 32 + data_type: 1 + name: "371" + } + initializer { + dims: 32 + dims: 32 + dims: 3 + dims: 3 + data_type: 1 + name: "373" + } + initializer { + dims: 32 + data_type: 1 + name: "374" + } + initializer { + dims: 32 + dims: 32 + dims: 3 + dims: 3 + data_type: 1 + name: "376" + } + initializer { + dims: 32 + data_type: 1 + name: "377" + } + initializer { + dims: 32 + dims: 32 + dims: 3 + dims: 3 + data_type: 1 + name: "379" + } + initializer { + dims: 32 + data_type: 1 + name: "380" + } + initializer { + dims: 32 + dims: 32 + dims: 3 + dims: 3 + data_type: 1 + name: "382" + } + initializer { + dims: 32 + data_type: 1 + name: "383" + } + initializer { + dims: 32 + dims: 32 + dims: 3 + dims: 3 + data_type: 1 + name: "385" + } + initializer { + dims: 32 + data_type: 1 + name: "386" + } + initializer { + dims: 64 + dims: 32 + dims: 3 + dims: 3 + data_type: 1 + name: "388" + } + initializer { + dims: 64 + data_type: 1 + name: "389" + } + initializer { + dims: 64 + dims: 64 + dims: 3 + dims: 3 + data_type: 1 + name: "391" + } + initializer { + dims: 64 + data_type: 1 + name: "392" + } + initializer { + dims: 64 + dims: 32 + dims: 1 + dims: 1 + data_type: 1 + name: "394" + } + initializer { + dims: 64 + data_type: 1 + name: "395" + } + initializer { + dims: 64 + dims: 64 + dims: 3 + dims: 3 + data_type: 1 + name: "397" + } + initializer { + dims: 64 + data_type: 1 + name: "398" + } + initializer { + dims: 64 + dims: 64 + dims: 3 + dims: 3 + data_type: 1 + name: "400" + } + initializer { + dims: 64 + data_type: 1 + name: "401" + } + initializer { + dims: 64 + dims: 64 + dims: 3 + dims: 3 + data_type: 1 + name: "403" + } + initializer { + dims: 64 + data_type: 1 + name: "404" + } + initializer { + dims: 64 + dims: 64 + dims: 3 + dims: 3 + data_type: 1 + name: "406" + } + initializer { + dims: 64 + data_type: 1 + name: "407" + } + initializer { + dims: 64 + dims: 64 + dims: 3 + dims: 3 + data_type: 1 + name: "409" + } + initializer { + dims: 64 + data_type: 1 + name: "410" + } + initializer { + dims: 64 + dims: 64 + dims: 3 + dims: 3 + data_type: 1 + name: "412" + } + initializer { + dims: 64 + data_type: 1 + name: "413" 
+ } + initializer { + dims: 128 + dims: 64 + dims: 3 + dims: 3 + data_type: 1 + name: "415" + } + initializer { + dims: 128 + data_type: 1 + name: "416" + } + initializer { + dims: 128 + dims: 128 + dims: 3 + dims: 3 + data_type: 1 + name: "418" + } + initializer { + dims: 128 + data_type: 1 + name: "419" + } + initializer { + dims: 128 + dims: 64 + dims: 1 + dims: 1 + data_type: 1 + name: "421" + } + initializer { + dims: 128 + data_type: 1 + name: "422" + } + initializer { + dims: 128 + dims: 128 + dims: 3 + dims: 3 + data_type: 1 + name: "424" + } + initializer { + dims: 128 + data_type: 1 + name: "425" + } + initializer { + dims: 128 + dims: 128 + dims: 3 + dims: 3 + data_type: 1 + name: "427" + } + initializer { + dims: 128 + data_type: 1 + name: "428" + } + initializer { + dims: 128 + dims: 128 + dims: 3 + dims: 3 + data_type: 1 + name: "430" + } + initializer { + dims: 128 + data_type: 1 + name: "431" + } + initializer { + dims: 128 + dims: 128 + dims: 3 + dims: 3 + data_type: 1 + name: "433" + } + initializer { + dims: 128 + data_type: 1 + name: "434" + } + initializer { + dims: 128 + dims: 128 + dims: 3 + dims: 3 + data_type: 1 + name: "436" + } + initializer { + dims: 128 + data_type: 1 + name: "437" + } + initializer { + dims: 128 + dims: 128 + dims: 3 + dims: 3 + data_type: 1 + name: "439" + } + initializer { + dims: 128 + data_type: 1 + name: "440" + } + initializer { + dims: 128 + dims: 128 + dims: 3 + dims: 3 + data_type: 1 + name: "442" + } + initializer { + dims: 128 + data_type: 1 + name: "443" + } + initializer { + dims: 128 + dims: 128 + dims: 3 + dims: 3 + data_type: 1 + name: "445" + } + initializer { + dims: 128 + data_type: 1 + name: "446" + } + initializer { + dims: 128 + dims: 128 + dims: 3 + dims: 3 + data_type: 1 + name: "448" + } + initializer { + dims: 128 + data_type: 1 + name: "449" + } + initializer { + dims: 128 + dims: 128 + dims: 3 + dims: 3 + data_type: 1 + name: "451" + } + initializer { + dims: 128 + data_type: 1 + name: "452" + } + initializer { + dims: 256 + dims: 128 + dims: 3 + dims: 3 + data_type: 1 + name: "454" + } + initializer { + dims: 256 + data_type: 1 + name: "455" + } + initializer { + dims: 256 + dims: 256 + dims: 3 + dims: 3 + data_type: 1 + name: "457" + } + initializer { + dims: 256 + data_type: 1 + name: "458" + } + initializer { + dims: 256 + dims: 128 + dims: 1 + dims: 1 + data_type: 1 + name: "460" + } + initializer { + dims: 256 + data_type: 1 + name: "461" + } + initializer { + dims: 256 + dims: 256 + dims: 3 + dims: 3 + data_type: 1 + name: "463" + } + initializer { + dims: 256 + data_type: 1 + name: "464" + } + initializer { + dims: 256 + dims: 256 + dims: 3 + dims: 3 + data_type: 1 + name: "466" + } + initializer { + dims: 256 + data_type: 1 + name: "467" + } + initializer { + dims: 256 + dims: 256 + dims: 3 + dims: 3 + data_type: 1 + name: "469" + } + initializer { + dims: 256 + data_type: 1 + name: "470" + } + initializer { + dims: 256 + dims: 256 + dims: 3 + dims: 3 + data_type: 1 + name: "472" + } + initializer { + dims: 256 + data_type: 1 + name: "473" + } + input { + name: "feats" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_param: "B" + } + dim { + dim_param: "T" + } + dim { + dim_value: 80 + } + } + } + } + } + output { + name: "embs" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_param: "B" + } + dim { + dim_value: 256 + } + } + } + } + } +} +opset_import { + domain: "" + version: 14 +} diff --git a/embedding/.gitattributes b/embedding/.gitattributes new file mode 100644 index 
0000000000000000000000000000000000000000..07f0db3339ad9053dc95b284c4ae14e014efff89 --- /dev/null +++ b/embedding/.gitattributes @@ -0,0 +1,16 @@ +*.bin.* filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tar.gz filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text diff --git a/embedding/LICENSE b/embedding/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e5e0c2daded4524693e062d3e4fd016bbfb9a308 --- /dev/null +++ b/embedding/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 CNRS + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/embedding/README.md b/embedding/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f4fe9d8c7ed6049c51f5d107d627865f8c049508 --- /dev/null +++ b/embedding/README.md @@ -0,0 +1,121 @@ +--- +tags: +- pyannote +- pyannote-audio +- pyannote-audio-model +- audio +- voice +- speech +- speaker +- speaker-recognition +- speaker-verification +- speaker-identification +- speaker-embedding +datasets: +- voxceleb +license: mit +inference: false +extra_gated_prompt: "The collected information will help acquire a better knowledge of pyannote.audio userbase and help its maintainers apply for grants to improve it further. If you are an academic researcher, please cite the relevant papers in your own publications using the model. If you work for a company, please consider contributing back to pyannote.audio development (e.g. through unrestricted gifts). We also provide scientific consulting services around speaker diarization and machine listening." +extra_gated_fields: + Company/university: text + Website: text + I plan to use this model for (task, type of audio data, etc): text +--- + +Using this open-source model in production? +Consider switching to [pyannoteAI](https://www.pyannote.ai) for better and faster options. + +# 🎹 Speaker embedding + +Relies on pyannote.audio 2.1: see [installation instructions](https://github.com/pyannote/pyannote-audio/). 
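+
+For convenience, a minimal install sketch (the version pin is an assumption; the linked installation instructions remain authoritative):
+
+```bash
+pip install "pyannote.audio>=2.1,<3.0"
+```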
+ +This model is based on the [canonical x-vector TDNN-based architecture](https://ieeexplore.ieee.org/abstract/document/8461375), but with filter banks replaced with [trainable SincNet features](https://ieeexplore.ieee.org/document/8639585). See [`XVectorSincNet`](https://github.com/pyannote/pyannote-audio/blob/3c988c028dc505c64fe776720372f6fe816b585a/pyannote/audio/models/embedding/xvector.py#L104-L169) architecture for implementation details. + + +## Basic usage + +```python +# 1. visit hf.co/pyannote/embedding and accept user conditions +# 2. visit hf.co/settings/tokens to create an access token +# 3. instantiate pretrained model +from pyannote.audio import Model +model = Model.from_pretrained("pyannote/embedding", + use_auth_token="ACCESS_TOKEN_GOES_HERE") +``` + +```python +from pyannote.audio import Inference +inference = Inference(model, window="whole") +embedding1 = inference("speaker1.wav") +embedding2 = inference("speaker2.wav") +# `embeddingX` is (1 x D) numpy array extracted from the file as a whole. + +from scipy.spatial.distance import cdist +distance = cdist(embedding1, embedding2, metric="cosine")[0,0] +# `distance` is a `float` describing how dissimilar speakers 1 and 2 are. +``` + +Using cosine distance directly, this model reaches 2.8% equal error rate (EER) on VoxCeleb 1 test set. +This is without voice activity detection (VAD) nor probabilistic linear discriminant analysis (PLDA). +Expect even better results when adding one of those. + +## Advanced usage + +### Running on GPU + +```python +import torch +inference.to(torch.device("cuda")) +embedding = inference("audio.wav") +``` + +### Extract embedding from an excerpt + +```python +from pyannote.audio import Inference +from pyannote.core import Segment +inference = Inference(model, window="whole") +excerpt = Segment(13.37, 19.81) +embedding = inference.crop("audio.wav", excerpt) +# `embedding` is (1 x D) numpy array extracted from the file excerpt. +``` + +### Extract embeddings using a sliding window + +```python +from pyannote.audio import Inference +inference = Inference(model, window="sliding", + duration=3.0, step=1.0) +embeddings = inference("audio.wav") +# `embeddings` is a (N x D) pyannote.core.SlidingWindowFeature +# `embeddings[i]` is the embedding of the ith position of the +# sliding window, i.e. from [i * step, i * step + duration]. +``` + + +## Citation + +```bibtex +@inproceedings{Bredin2020, + Title = {{pyannote.audio: neural building blocks for speaker diarization}}, + Author = {{Bredin}, Herv{\'e} and {Yin}, Ruiqing and {Coria}, Juan Manuel and {Gelly}, Gregory and {Korshunov}, Pavel and {Lavechin}, Marvin and {Fustes}, Diego and {Titeux}, Hadrien and {Bouaziz}, Wassim and {Gill}, Marie-Philippe}, + Booktitle = {ICASSP 2020, IEEE International Conference on Acoustics, Speech, and Signal Processing}, + Address = {Barcelona, Spain}, + Month = {May}, + Year = {2020}, +} +``` + +```bibtex +@inproceedings{Coria2020, + author="Coria, Juan M. 
and Bredin, Herv{\'e} and Ghannay, Sahar and Rosset, Sophie", + editor="Espinosa-Anke, Luis and Mart{\'i}n-Vide, Carlos and Spasi{\'{c}}, Irena", + title="{A Comparison of Metric Learning Loss Functions for End-To-End Speaker Verification}", + booktitle="Statistical Language and Speech Processing", + year="2020", + publisher="Springer International Publishing", + pages="137--148", + isbn="978-3-030-59430-5" +} +``` + diff --git a/embedding/config.yaml b/embedding/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..703dcb4bd78d9156d978192c7693aab407ee05f4 --- /dev/null +++ b/embedding/config.yaml @@ -0,0 +1,85 @@ +protocol: VoxCeleb.SpeakerVerification.VoxCeleb_X +patience: 5 +task: + _target_: pyannote.audio.tasks.SupervisedRepresentationLearningWithArcFace + min_duration: 2 + duration: 5.0 + num_classes_per_batch: 64 + num_chunks_per_class: 4 + margin: 10.0 + scale: 50.0 + num_workers: 20 + pin_memory: false +model: + _target_: pyannote.audio.models.embedding.XVectorSincNet +optimizer: + _target_: torch.optim.Adam + lr: 0.001 + betas: + - 0.9 + - 0.999 + eps: 1.0e-08 + weight_decay: 0 + amsgrad: false +trainer: + _target_: pytorch_lightning.Trainer + accelerator: null + accumulate_grad_batches: 1 + amp_backend: native + amp_level: O2 + auto_lr_find: false + auto_scale_batch_size: false + auto_select_gpus: true + benchmark: false + check_val_every_n_epoch: 1 + checkpoint_callback: true + deterministic: false + fast_dev_run: false + flush_logs_every_n_steps: 100 + gpus: 1 + gradient_clip_val: 0 + limit_test_batches: 1.0 + limit_train_batches: 1.0 + limit_val_batches: 1.0 + log_every_n_steps: 50 + log_gpu_memory: null + max_epochs: 1000 + max_steps: null + min_epochs: 1 + min_steps: null + num_nodes: 1 + num_processes: 1 + num_sanity_val_steps: 2 + overfit_batches: 0.0 + precision: 32 + prepare_data_per_node: true + process_position: 0 + profiler: null + progress_bar_refresh_rate: 1 + reload_dataloaders_every_epoch: false + replace_sampler_ddp: true + sync_batchnorm: false + terminate_on_nan: false + tpu_cores: null + track_grad_norm: -1 + truncated_bptt_steps: null + val_check_interval: 1.0 + weights_save_path: null + weights_summary: top +augmentation: + transform: Compose + params: + shuffle: false + transforms: + - transform: AddBackgroundNoise + params: + background_paths: /gpfswork/rech/eie/commun/data/background/musan + min_snr_in_db: 5.0 + max_snr_in_db: 15.0 + mode: per_example + p: 0.9 + - transform: ApplyImpulseResponse + params: + ir_paths: /gpfswork/rech/eie/commun/data/rir + mode: per_example + p: 0.5 diff --git a/embedding/hparams.yaml b/embedding/hparams.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9b8ac333f350d51f485ea36180143a63e16f23de --- /dev/null +++ b/embedding/hparams.yaml @@ -0,0 +1,6 @@ +sample_rate: 16000 +num_channels: 1 +sincnet: + stride: 10 + sample_rate: 16000 +dimension: 512 diff --git a/embedding/hydra.yaml b/embedding/hydra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4d1abcac5abe316e5f084f45b1aea5cd27bf7e14 --- /dev/null +++ b/embedding/hydra.yaml @@ -0,0 +1,139 @@ +hydra: + run: + dir: ${protocol}/${task._target_}/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}/${protocol}/${task._target_} + subdir: ${hydra.job.num} + hydra_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][HYDRA] %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + root: + level: 
INFO + handlers: + - console + loggers: + logging_example: + level: DEBUG + disable_existing_loggers: false + job_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + file: + class: logging.FileHandler + formatter: simple + filename: ${hydra.job.name}.log + root: + level: INFO + handlers: + - console + - file + disable_existing_loggers: false + sweeper: + _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper + max_batch_size: null + launcher: + _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher + help: + app_name: pyannote-audio-train + header: == ${hydra.help.app_name} == + footer: 'Powered by Hydra (https://hydra.cc) + + Use --hydra-help to view Hydra specific help' + template: "${hydra.help.header}\n\npyannote-audio-train protocol={protocol_name}\ + \ task={task} model={model}\n\n{task} can be any of the following:\n* vad (default)\ + \ = voice activity detection\n* scd = speaker change detection\n* osd = overlapped\ + \ speech detection\n* xseg = multi-task segmentation\n\n{model} can be any of\ + \ the following:\n* debug (default) = simple segmentation model for debugging\ + \ purposes\n\n{optimizer} can be any of the following\n* adam (default) = Adam\ + \ optimizer\n\n{trainer} can be any of the following\n* fast_dev_run for debugging\n\ + * default (default) for training the model\n\nOptions\n=======\n\nHere, we describe\ + \ the most common options: use \"--cfg job\" option to get a complete list.\n\ + \n* task.duration: audio chunk duration (in seconds)\n* task.batch_size: number\ + \ of audio chunks per batch\n* task.num_workers: number of workers used for\ + \ generating training chunks\n\n* optimizer.lr: learning rate\n* trainer.auto_lr_find:\ + \ use pytorch-lightning AutoLR\n\nHyper-parameter optimization\n============================\n\ + \nBecause it is powered by Hydra (https://hydra.cc), one can run grid search\ + \ using the --multirun option.\n\nFor instance, the following command will run\ + \ the same job three times, with three different learning rates:\n pyannote-audio-train\ + \ --multirun protocol={protocol_name} task={task} optimizer.lr=1e-3,1e-2,1e-1\n\ + \nEven better, one can use Ax (https://ax.dev) sweeper to optimize learning\ + \ rate directly:\n pyannote-audio-train --multirun hydra/sweeper=ax protocol={protocol_name}\ + \ task={task} optimizer.lr=\"interval(1e-3, 1e-1)\"\n\nSee https://hydra.cc/docs/plugins/ax_sweeper\ + \ for more details.\n\nUser-defined task or model\n==========================\n\ + \n1. define your_package.YourTask (or your_package.YourModel) class\n2. create\ + \ file /path/to/your_config/task/your_task.yaml (or /path/to/your_config/model/your_model.yaml)\n\ + \ # @package _group_\n _target_: your_package.YourTask # or YourModel\n\ + \ param1: value1\n param2: value2\n3. call pyannote-audio-train --config-dir\ + \ /path/to/your_config task=your_task task.param1=modified_value1 model=your_model\ + \ ...\n\n${hydra.help.footer}" + hydra_help: + hydra_help: ??? + template: 'Hydra (${hydra.runtime.version}) + + See https://hydra.cc for more info. + + + == Flags == + + $FLAGS_HELP + + + == Configuration groups == + + Compose your configuration from those groups (For example, append hydra/job_logging=disabled + to command line) + + + $HYDRA_CONFIG_GROUPS + + + Use ''--cfg hydra'' to Show the Hydra config. 
+ + ' + output_subdir: '' + overrides: + hydra: [] + task: + - protocol=VoxCeleb.SpeakerVerification.VoxCeleb_X + - task=SpeakerEmbedding + - task.num_workers=20 + - task.min_duration=2 + - task.duration=5. + - task.num_classes_per_batch=64 + - task.num_chunks_per_class=4 + - task.margin=10.0 + - task.scale=50. + - model=XVectorSincNet + - trainer.gpus=1 + - +augmentation=background_then_reverb + job: + name: train + override_dirname: +augmentation=background_then_reverb,model=XVectorSincNet,protocol=VoxCeleb.SpeakerVerification.VoxCeleb_X,task.duration=5.,task.margin=10.0,task.min_duration=2,task.num_chunks_per_class=4,task.num_classes_per_batch=64,task.num_workers=20,task.scale=50.,task=SpeakerEmbedding,trainer.gpus=1 + id: ??? + num: ??? + config_name: config + env_set: {} + env_copy: [] + config: + override_dirname: + kv_sep: '=' + item_sep: ',' + exclude_keys: [] + runtime: + version: 1.0.4 + cwd: /gpfsdswork/projects/rech/eie/uno46kl/xvectors/debug + verbose: false diff --git a/embedding/overrides.yaml b/embedding/overrides.yaml new file mode 100644 index 0000000000000000000000000000000000000000..60bbe62178f276467bb20e95d95e8c92ae929d54 --- /dev/null +++ b/embedding/overrides.yaml @@ -0,0 +1,12 @@ +- protocol=VoxCeleb.SpeakerVerification.VoxCeleb_X +- task=SpeakerEmbedding +- task.num_workers=20 +- task.min_duration=2 +- task.duration=5. +- task.num_classes_per_batch=64 +- task.num_chunks_per_class=4 +- task.margin=10.0 +- task.scale=50. +- model=XVectorSincNet +- trainer.gpus=1 +- +augmentation=background_then_reverb diff --git a/embedding/pytorch_model.bin b/embedding/pytorch_model.bin new file mode 100644 index 0000000000000000000000000000000000000000..779b5abd2bb3378cdd480d258f972fa61d5f4fd8 --- /dev/null +++ b/embedding/pytorch_model.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4bcec986de13da7af7ac88736572692359950df63669989c4f78b294934c9089 +size 96383626 diff --git a/embedding/source.txt b/embedding/source.txt new file mode 100644 index 0000000000000000000000000000000000000000..1956b5ab545beafec3cb201b0d740393de0a8cc5 --- /dev/null +++ b/embedding/source.txt @@ -0,0 +1 @@ +https://huggingface.co/pyannote/embedding \ No newline at end of file diff --git a/embedding/tfevents.bin b/embedding/tfevents.bin new file mode 100644 index 0000000000000000000000000000000000000000..8aece709c7aebfe84c0118129b0b76dc14a7f2cd --- /dev/null +++ b/embedding/tfevents.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3319218e36d416c5400ffbc592acc2e1ab520a187d586be86db7eef30fb65616 +size 5669685 diff --git a/embedding/train.log b/embedding/train.log new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/overlapped-speech-detection/.gitattributes b/overlapped-speech-detection/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..6d34772f5ca361021038b404fb913ec8dc0b1a5a --- /dev/null +++ b/overlapped-speech-detection/.gitattributes @@ -0,0 +1,27 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bin.* filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs 
diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zstandard filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/overlapped-speech-detection/README.md b/overlapped-speech-detection/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5383a38b6a70ed415a67325ee830678f75b0bf7a --- /dev/null +++ b/overlapped-speech-detection/README.md @@ -0,0 +1,73 @@ +--- +tags: +- pyannote +- pyannote-audio +- pyannote-audio-pipeline +- audio +- voice +- speech +- speaker +- overlapped-speech-detection +- automatic-speech-recognition +datasets: +- ami +- dihard +- voxconverse +license: mit +extra_gated_prompt: "The collected information will help acquire a better knowledge of pyannote.audio userbase and help its maintainers apply for grants to improve it further. If you are an academic researcher, please cite the relevant papers in your own publications using the model. If you work for a company, please consider contributing back to pyannote.audio development (e.g. through unrestricted gifts). We also provide scientific consulting services around speaker diarization and machine listening." +extra_gated_fields: + Company/university: text + Website: text + I plan to use this model for (task, type of audio data, etc): text +--- + +Using this open-source model in production? +Consider switching to [pyannoteAI](https://www.pyannote.ai) for better and faster options. + +# 🎹 Overlapped speech detection + +Relies on pyannote.audio 2.1: see [installation instructions](https://github.com/pyannote/pyannote-audio#installation). + +```python +# 1. visit hf.co/pyannote/segmentation and accept user conditions +# 2. visit hf.co/settings/tokens to create an access token +# 3. instantiate pretrained overlapped speech detection pipeline +from pyannote.audio import Pipeline +pipeline = Pipeline.from_pretrained("pyannote/overlapped-speech-detection", + use_auth_token="ACCESS_TOKEN_GOES_HERE") +output = pipeline("audio.wav") + +for speech in output.get_timeline().support(): + # two or more speakers are active between speech.start and speech.end + ... +``` + +## Support + +For commercial enquiries and scientific consulting, please contact [me](mailto:herve@niderb.fr). +For [technical questions](https://github.com/pyannote/pyannote-audio/discussions) and [bug reports](https://github.com/pyannote/pyannote-audio/issues), please check [pyannote.audio](https://github.com/pyannote/pyannote-audio) Github repository. + + +## Citation + +```bibtex +@inproceedings{Bredin2021, + Title = {{End-to-end speaker segmentation for overlap-aware resegmentation}}, + Author = {{Bredin}, Herv{\'e} and {Laurent}, Antoine}, + Booktitle = {Proc. 
Interspeech 2021}, + Address = {Brno, Czech Republic}, + Month = {August}, + Year = {2021}, +} +``` + +```bibtex +@inproceedings{Bredin2020, + Title = {{pyannote.audio: neural building blocks for speaker diarization}}, + Author = {{Bredin}, Herv{\'e} and {Yin}, Ruiqing and {Coria}, Juan Manuel and {Gelly}, Gregory and {Korshunov}, Pavel and {Lavechin}, Marvin and {Fustes}, Diego and {Titeux}, Hadrien and {Bouaziz}, Wassim and {Gill}, Marie-Philippe}, + Booktitle = {ICASSP 2020, IEEE International Conference on Acoustics, Speech, and Signal Processing}, + Address = {Barcelona, Spain}, + Month = {May}, + Year = {2020}, +} +``` diff --git a/overlapped-speech-detection/config.yaml b/overlapped-speech-detection/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0da926d1ac8120182c597663b967bf307fa2d8ab --- /dev/null +++ b/overlapped-speech-detection/config.yaml @@ -0,0 +1,10 @@ +pipeline: + name: pyannote.audio.pipelines.OverlappedSpeechDetection + params: + segmentation: pyannote/segmentation@Interspeech2021 + +params: + min_duration_off: 0.09791355693027545 + min_duration_on: 0.05537587440407595 + offset: 0.4806866463041527 + onset: 0.8104268538848918 diff --git a/overlapped-speech-detection/source.txt b/overlapped-speech-detection/source.txt new file mode 100644 index 0000000000000000000000000000000000000000..98f720fcb96d8513f1928ad412771ff670085bdc --- /dev/null +++ b/overlapped-speech-detection/source.txt @@ -0,0 +1 @@ +https://huggingface.co/pyannote/overlapped-speech-detection \ No newline at end of file diff --git a/speaker-diarization-3.1/.github/workflows/sync_to_hub.yaml b/speaker-diarization-3.1/.github/workflows/sync_to_hub.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b0192054281691ad01d08d6976a03793e73262fa --- /dev/null +++ b/speaker-diarization-3.1/.github/workflows/sync_to_hub.yaml @@ -0,0 +1,20 @@ +name: Sync to Hugging Face hub + +on: + push: + branches: [main] + + # to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + sync-to-hub: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Push to hub + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: git push https://hbredin:$HF_TOKEN@huggingface.co/pyannote/speaker-diarization-3.1 main --force diff --git a/speaker-diarization-3.1/source.txt b/speaker-diarization-3.1/source.txt new file mode 100644 index 0000000000000000000000000000000000000000..9f20782b78ab04d7bacc5b3ee5fac19ce73933f2 --- /dev/null +++ b/speaker-diarization-3.1/source.txt @@ -0,0 +1 @@ +https://huggingface.co/pyannote/speaker-diarization-3.1 \ No newline at end of file diff --git a/speaker-diarization-community-1-cloud/.gitattributes b/speaker-diarization-community-1-cloud/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/speaker-diarization-community-1-cloud/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs 
merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/speaker-diarization-community-1-cloud/README.md b/speaker-diarization-community-1-cloud/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f752d00b67c77cc2eab71c523077c1bfb5d0a50b --- /dev/null +++ b/speaker-diarization-community-1-cloud/README.md @@ -0,0 +1,42 @@ +--- +tags: + - pyannote + - pyannote-audio + - pyannote-audio-pipeline + - audio + - voice + - speech + - speaker + - speaker-diarization + - speaker-change-detection + - voice-activity-detection + - overlapped-speech-detection +--- + +# Hosted `Community-1` speaker diarization + +This pipeline runs [`Community-1`](https://hf.co/pyannote/speaker-diarization-community-1) speaker diarization on [pyannoteAI](https://www.pyannote.ai) cloud. +Read the announcement [blog post](https://www.pyannote.ai/blog/community-1). + +## Setup + +1. `pip install pyannote.audio` +2. 
Create an API key on [`pyannoteAI` dashboard](https://dashboard.pyannote.ai) (free credits included) + +## Usage + +```python +# initialize speaker diarization pipeline +from pyannote.audio import Pipeline +pipeline = Pipeline.from_pretrained( + 'pyannote/speaker-diarization-community-1-cloud', + token="{pyannoteAI-api-key}") + +# run speaker diarization on pyannoteAI cloud +output = pipeline("/path/to/audio.wav") + +# print speaker diarization +for turn, speaker in output.speaker_diarization: + print(f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}") +``` + diff --git a/speaker-diarization-community-1-cloud/config.yaml b/speaker-diarization-community-1-cloud/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..115e65965542dbadaf4de8d850cdb629c57f4025 --- /dev/null +++ b/speaker-diarization-community-1-cloud/config.yaml @@ -0,0 +1,7 @@ +dependencies: + pyannote.audio: 4.0.0 + +pipeline: + name: pyannote.audio.pipelines.pyannoteai.sdk.SDK + params: + model: community-1 \ No newline at end of file diff --git a/speaker-diarization-community-1-cloud/source.txt b/speaker-diarization-community-1-cloud/source.txt new file mode 100644 index 0000000000000000000000000000000000000000..ca309a83b36b83ba9a993df4906a4079a5d923d9 --- /dev/null +++ b/speaker-diarization-community-1-cloud/source.txt @@ -0,0 +1 @@ +https://huggingface.co/pyannote/speaker-diarization-community-1-cloud \ No newline at end of file diff --git a/speaker-diarization-community-1/.gitattributes b/speaker-diarization-community-1/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..7c39301dc01fe65e09b2432a27a18b4ee3e74a37 --- /dev/null +++ b/speaker-diarization-community-1/.gitattributes @@ -0,0 +1,36 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +diarization.gif filter=lfs diff=lfs merge=lfs -text diff --git a/speaker-diarization-community-1/README.md b/speaker-diarization-community-1/README.md new file mode 100644 index 
0000000000000000000000000000000000000000..8356d6634d7b1074581dd36e2225887ec809326e --- /dev/null +++ b/speaker-diarization-community-1/README.md @@ -0,0 +1,227 @@ +--- +tags: + - pyannote + - pyannote-audio + - pyannote-audio-pipeline + - audio + - voice + - speech + - speaker + - speaker-diarization + - speaker-change-detection + - voice-activity-detection + - overlapped-speech-detection + - automatic-speech-recognition +license: cc-by-4.0 +extra_gated_prompt: "Your input helps us strengthen the pyannote community and improve our open-source offerings. This pipeline is released under the CC-BY-4.0 license and will always remain freely accessible. By providing your details, you agree that we may email you occasionally with important news about pyannote models, invitations to try premium pipelines, and information about specific services designed for researchers and professionals like you." +extra_gated_fields: + Company/university: text + Use case: + type: select + options: + - label: Meeting note taker (automated meeting transcription, action item extraction, and speaker identification in recordings) + value: meeting + - label: Conversation AI (chatbots, voice assistants, multi-turn dialogue systems with speaker awareness) + value: conversation + - label: CCaaS and customer experience (call center analytics, customer service optimization, and interaction quality monitoring) + value: ccaas + - label: Voice agents (AI-powered phone systems, automated customer service, voice-based interactions) + value: agent + - label: Media and automated dubbing (content creation, podcast processing, video production, and multilingual media) + value: dubbing + - label: Training and development (educational content analysis, corporate training evaluation, and learning assessment tools) + value: training + - label: Other + value: other +--- + +# `community-1` speaker diarization + +This pipeline ingests mono audio sampled at 16kHz and outputs speaker diarization. + +- stereo or multi-channel audio files are automatically downmixed to mono by averaging the channels. +- audio files sampled at a different rate are resampled to 16kHz automatically upon loading. + +The [main improvements brought by `Community-1`](https://www.pyannote.ai/blog/community-1) are: + +- [improved](#benchmark) speaker assignment and counting +- simpler reconciliation with transcription timestamps with [*exclusive*](#exclusive-speaker-diarization) speaker diarization +- easy [offline use](#offline-use) (i.e. without internet connection) +- (optionally) [hosted](https://hf.co/pyannote/speaker-diarization-community-1-cloud) on pyannoteAI cloud + + +## Setup + +1. `pip install pyannote.audio` +2. Accept user conditions +3. Create access token at [`hf.co/settings/tokens`](https://hf.co/settings/tokens). + +## Quick start + +```python +# download the pipeline from Huggingface +from pyannote.audio import Pipeline +pipeline = Pipeline.from_pretrained( + "pyannote/speaker-diarization-community-1", + token="{huggingface-token}") + +# run the pipeline locally on your computer +output = pipeline("audio.wav") + +# print the predicted speaker diarization +for turn, speaker in output.speaker_diarization: + print(f"{speaker} speaks between t={turn.start:.3f}s and t={turn.end:.3f}s") +``` + +## Benchmark + +Out of the box, `Community-1` is much better than `speaker-diarization-3.1`. 
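+
+The benchmark below relies on the diarization error rate from [pyannote.metrics](http://pyannote.github.io/pyannote-metrics/); here is a minimal sketch of how it is computed, with toy annotations standing in for real ground truth and pipeline output:
+
+```python
+from pyannote.core import Annotation, Segment
+from pyannote.metrics.diarization import DiarizationErrorRate
+
+# toy reference and hypothesis (placeholders for ground truth and pipeline output)
+reference = Annotation()
+reference[Segment(0.0, 10.0)] = "alice"
+reference[Segment(10.0, 20.0)] = "bob"
+
+hypothesis = Annotation()
+hypothesis[Segment(0.0, 11.0)] = "SPEAKER_00"
+hypothesis[Segment(11.0, 20.0)] = "SPEAKER_01"
+
+# default settings: no forgiveness collar, overlapping speech included
+metric = DiarizationErrorRate()
+print(f"DER = {100 * metric(reference, hypothesis):.1f}%")
+```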
+ +We report [diarization error rates](http://pyannote.github.io/pyannote-metrics/reference.html#diarization) (in %) on large collection of academic benchmarks (fully automatic processing, no forgiveness collar, nor skipping overlapping speech). + +| Benchmark (last updated in 2025-09) | `legacy` (3.1)| `community-1` | `precision-2` | +| --------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------ | -------------------------------------------------| ------------------------------------------------ | +| [AISHELL-4](https://arxiv.org/abs/2104.03603) | 12.2 | 11.7 | 11.4 | +| [AliMeeting](https://www.openslr.org/119/) (channel 1) | 24.5 | 20.3 | 15.2 | +| [AMI](https://groups.inf.ed.ac.uk/ami/corpus/) (IHM) | 18.8 | 17.0 | 12.9 | +| [AMI](https://groups.inf.ed.ac.uk/ami/corpus/) (SDM) | 22.7 | 19.9 | 15.6 | +| [AVA-AVD](https://arxiv.org/abs/2111.14448) | 49.7 | 44.6 | 37.1 | +| [CALLHOME](https://catalog.ldc.upenn.edu/LDC2001S97) ([part 2](https://github.com/BUTSpeechFIT/CALLHOME_sublists/issues/1)) | 28.5 | 26.7 | 16.6 | +| [DIHARD 3](https://catalog.ldc.upenn.edu/LDC2022S14) ([full](https://arxiv.org/abs/2012.01477)) | 21.4 | 20.2 | 14.7 | +| [Ego4D](https://arxiv.org/abs/2110.07058) (dev.) | 51.2 | 46.8 | 39.0 | +| [MSDWild](https://github.com/X-LANCE/MSDWILD) | 25.4 | 22.8 | 17.3 | +| [RAMC](https://www.openslr.org/123/) | 22.2 | 20.8 | 10.5 | +| [REPERE](https://www.islrn.org/resources/360-758-359-485-0/) (phase2) | 7.9 | 8.9 | 7.4 | +| [VoxConverse](https://github.com/joonson/voxconverse) (v0.3) | 11.2 | 11.2 | 8.5 | + +`Precision-2` model is even better and can be tested like this: + +1. Create an API key on [pyannoteAI dashboard]((https://dashboard.pyannote.ai)) (free credits included) +2. Change one line of code + +```diff +from pyannote.audio import Pipeline +pipeline = Pipeline.from_pretrained( +- 'pyannote/speaker-diarization-community-1', token="{huggingface-token}") ++ 'pyannote/speaker-diarization-precision-2', token="{pyannoteAI-api-key}") +diarization = pipeline("audio.wav") # runs on pyannoteAI servers +``` + +## Processing on GPU + +`pyannote.audio` pipelines run on CPU by default. +You can send them to GPU with the following lines: + +```python +import torch +pipeline.to(torch.device("cuda")) +``` + +## Processing from memory + +Pre-loading audio files in memory may result in faster processing: + +```python +waveform, sample_rate = torchaudio.load("audio.wav") +output = pipeline({"waveform": waveform, "sample_rate": sample_rate}) +``` + +## Monitoring progress + +Hooks are available to monitor the progress of the pipeline: + +```python +from pyannote.audio.pipelines.utils.hook import ProgressHook +with ProgressHook() as hook: + output = pipeline("audio.wav", hook=hook) +``` + +## Controlling the number of speakers + +In case the number of speakers is known in advance, one can use the `num_speakers` option: + +```python +output = pipeline("audio.wav", num_speakers=2) +``` + +One can also provide lower and/or upper bounds on the number of speakers using `min_speakers` and `max_speakers` options: + +```python +output = pipeline("audio.wav", min_speakers=2, max_speakers=5) +``` + +## Exclusive speaker diarization + +`Community-1` pretrained pipeline returns a new *exclusive* speaker diarization, on top of the regular speaker diarization, available as `output.exclusive_speaker_diarization`. 
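+
+A minimal usage sketch (continuing from the quick start above, and assuming the exclusive annotation exposes the same `(turn, speaker)` iteration as the regular one):
+
+```python
+# `output` comes from the quick start above: output = pipeline("audio.wav")
+for turn, speaker in output.exclusive_speaker_diarization:
+    print(f"{speaker} speaks between t={turn.start:.3f}s and t={turn.end:.3f}s")
+```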
+ +This is a feature which is [backported from our latest commercial model](https://www.pyannote.ai/blog/precision-2) that simplifies the reconciliation between fine-grained speaker diarization timestamps and (sometimes not so precise) transcription timestamps. + +## Offline use + +1. In the terminal, copy the pipeline on disk: + +```bash +# make sure git-lfs is installed (https://git-lfs.com) +git lfs install + +# create a directory on disk +mkdir /path/to/directory + +# when prompted for a password, use an access token with write permissions. +# generate one from your settings: https://huggingface.co/settings/tokens +git clone https://hf.co/pyannote/speaker-diarization-community-1 /path/to/directory/pyannote-speaker-diarization-community-1 +``` + +2. In Python, use the pipeline without internet connection: + +```python +# load pipeline from disk (works without internet connection) +from pyannote.audio import Pipeline +pipeline = Pipeline.from_pretrained('/path/to/directory/pyannote-speaker-diarization-community-1') + +# run the pipeline locally on your computer +output = pipeline("audio.wav") +``` + +## Citations + +1. Speaker segmentation model + +```bibtex +@inproceedings{Plaquet23, + author={Alexis Plaquet and Hervé Bredin}, + title={{Powerset multi-class cross entropy loss for neural speaker diarization}}, + year=2023, + booktitle={Proc. INTERSPEECH 2023}, +} +``` + +2. Speaker embedding model + +```bibtex +@inproceedings{Wang2023, + title={Wespeaker: A research and production oriented speaker embedding learning toolkit}, + author={Wang, Hongji and Liang, Chengdong and Wang, Shuai and Chen, Zhengyang and Zhang, Binbin and Xiang, Xu and Deng, Yanlei and Qian, Yanmin}, + booktitle={ICASSP 2023, IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, + pages={1--5}, + year={2023}, + organization={IEEE} +} +``` + + +3. Speaker clustering + +```bibtex +@article{Landini2022, + author={Landini, Federico and Profant, J{\'a}n and Diez, Mireia and Burget, Luk{\'a}{\v{s}}}, + title={{Bayesian HMM clustering of x-vector sequences (VBx) in speaker diarization: theory, implementation and analysis on standard tasks}}, + year={2022}, + journal={Computer Speech \& Language}, +} +``` + +## Acknowledgment + +Training and tuning made possible thanks to [GENCI](https://www.genci.fr/) on the [**Jean Zay**](http://www.idris.fr/eng/jean-zay/) supercomputer. 
+ diff --git a/speaker-diarization-community-1/config.yaml b/speaker-diarization-community-1/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4022db43960736338378fdb6b5a85cfdae198910 --- /dev/null +++ b/speaker-diarization-community-1/config.yaml @@ -0,0 +1,21 @@ +dependencies: + pyannote.audio: 4.0.0 + +pipeline: + name: pyannote.audio.pipelines.SpeakerDiarization + params: + clustering: VBxClustering + segmentation: $model/segmentation + segmentation_batch_size: 32 + embedding: $model/embedding + embedding_batch_size: 32 + embedding_exclude_overlap: true + plda: $model/plda + +params: + clustering: + threshold: 0.6 + Fa: 0.07 + Fb: 0.8 + segmentation: + min_duration_off: 0.0 diff --git a/speaker-diarization-community-1/diarization.gif b/speaker-diarization-community-1/diarization.gif new file mode 100644 index 0000000000000000000000000000000000000000..114f825e4f8e854ed9ebbf06fd75feaa68d16b88 --- /dev/null +++ b/speaker-diarization-community-1/diarization.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d925ad38995d89009260e493b0ae2e684c3e1397f495265ed841c45c4f73a35 +size 861445 diff --git a/speaker-diarization-community-1/embedding/README.md b/speaker-diarization-community-1/embedding/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d71f943e00016652680bef1e6c49fbc842e0d4c4 --- /dev/null +++ b/speaker-diarization-community-1/embedding/README.md @@ -0,0 +1,20 @@ +Copied from https://huggingface.co/pyannote/wespeaker-voxceleb-resnet34-LM + +## License + +According to [this page](https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md): + +> The pretrained model in WeNet follows the license of it's corresponding dataset. For example, the pretrained model on VoxCeleb follows Creative Commons Attribution 4.0 International License., since it is used as license of the VoxCeleb dataset, see https://mm.kaist.ac.kr/datasets/voxceleb/. + +## Citation + +```bibtex +@inproceedings{Wang2023, + title={Wespeaker: A research and production oriented speaker embedding learning toolkit}, + author={Wang, Hongji and Liang, Chengdong and Wang, Shuai and Chen, Zhengyang and Zhang, Binbin and Xiang, Xu and Deng, Yanlei and Qian, Yanmin}, + booktitle={ICASSP 2023, IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, + pages={1--5}, + year={2023}, + organization={IEEE} +} +``` diff --git a/speaker-diarization-community-1/embedding/pytorch_model.bin b/speaker-diarization-community-1/embedding/pytorch_model.bin new file mode 100644 index 0000000000000000000000000000000000000000..6347697826a1958253dd195fa296d8d97b8f6280 --- /dev/null +++ b/speaker-diarization-community-1/embedding/pytorch_model.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f10ff60898a1d185fa22e1d11e0bfa8a92efec811f11bca48cb8cafebefd929 +size 26646242 diff --git a/speaker-diarization-community-1/plda/README.md b/speaker-diarization-community-1/plda/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5c01ad2e673203fe32a82d7a140eb817db9930b7 --- /dev/null +++ b/speaker-diarization-community-1/plda/README.md @@ -0,0 +1,3 @@ +PLDA model trained by [BUT Speech@FIT](https://speech.fit.vut.cz/) group. + +Thanks to [Jiangyu Han](https://github.com/jyhan03) and [Petr Pálka](https://github.com/Selesnyan) for the integration of VBx in pyannote.audio. 
\ No newline at end of file diff --git a/speaker-diarization-community-1/plda/plda.npz b/speaker-diarization-community-1/plda/plda.npz new file mode 100644 index 0000000000000000000000000000000000000000..e61936ef2108eccdff801c440d5e3f6c6995aba4 --- /dev/null +++ b/speaker-diarization-community-1/plda/plda.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b77bcd840692710dd3496f62ecfeed8d8e5f002fd991b785079b244eab7d255 +size 133852 diff --git a/speaker-diarization-community-1/plda/xvec_transform.npz b/speaker-diarization-community-1/plda/xvec_transform.npz new file mode 100644 index 0000000000000000000000000000000000000000..b8079f30a35a68fe7d996bd7abc2170338ba6767 --- /dev/null +++ b/speaker-diarization-community-1/plda/xvec_transform.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:325f1ce8e48f7e55e9c8aa47e05d2766b7c48c4b25b8de8dd751e7a4cc5fbe8f +size 134376 diff --git a/speaker-diarization-community-1/segmentation/pytorch_model.bin b/speaker-diarization-community-1/segmentation/pytorch_model.bin new file mode 100644 index 0000000000000000000000000000000000000000..e3b91de5a5374fc40b556f9cc51317d280b4ea79 --- /dev/null +++ b/speaker-diarization-community-1/segmentation/pytorch_model.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ad24338d844fb95985486eb1a464e32d229f6d7a03c9abe60f978bacf3f816e +size 5906507 diff --git a/speaker-diarization-community-1/source.txt b/speaker-diarization-community-1/source.txt new file mode 100644 index 0000000000000000000000000000000000000000..28f99c9ddb06d411fca2855f5f35894e02e8eacc --- /dev/null +++ b/speaker-diarization-community-1/source.txt @@ -0,0 +1 @@ +https://huggingface.co/pyannote/speaker-diarization-community-1 \ No newline at end of file diff --git a/speaker-diarization-precision-2/.gitattributes b/speaker-diarization-precision-2/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/speaker-diarization-precision-2/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs 
-text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/speaker-diarization-precision-2/README.md b/speaker-diarization-precision-2/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7f471286db4d2cb63e761773112df1d780eb69cf --- /dev/null +++ b/speaker-diarization-precision-2/README.md @@ -0,0 +1,53 @@ +--- +tags: + - pyannote + - pyannote-audio + - pyannote-audio-pipeline + - audio + - voice + - speech + - speaker + - speaker-diarization + - speaker-change-detection + - voice-activity-detection + - overlapped-speech-detection +--- + +# `Precision-2` speaker diarization + +This pipeline runs `Precision-2` speaker diarization on [pyannoteAI](https://www.pyannote.ai) cloud. +Read the announcement [blog post](https://www.pyannote.ai/blog/precision-2). + +This pipeline is a stripped down version of pyannoteAI SDK that provides [much more features](https://docs.pyannote.ai): +* speaker diarization optimized for speech-to-text +* speaker voiceprinting and identification +* confidence scores +* and more... + +A self-hosted version of `Precision-2` is also available for enterprise customers. + +## Setup + +1. `pip install pyannote.audio` +2. Create an API key on [`pyannoteAI` dashboard](https://dashboard.pyannote.ai) (free credits included) + +## Usage + +```python +# initialize speaker diarization pipeline +from pyannote.audio import Pipeline +pipeline = Pipeline.from_pretrained( + 'pyannote/speaker-diarization-precision-2', + token="{pyannoteAI-api-key}") + +# run speaker diarization on pyannoteAI cloud +output = pipeline("/path/to/audio.wav") + +# enjoy state-of-the-art speaker diarization +for turn, speaker in output.speaker_diarization: + print(f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}") +``` + + + + diff --git a/speaker-diarization-precision-2/config.yaml b/speaker-diarization-precision-2/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..48a160c19d66903d9d4f6849a93d4bc805c5a585 --- /dev/null +++ b/speaker-diarization-precision-2/config.yaml @@ -0,0 +1,7 @@ +dependencies: + pyannote.audio: 4.0.0 + +pipeline: + name: pyannote.audio.pipelines.pyannoteai.sdk.SDK + params: + model: precision-2 \ No newline at end of file diff --git a/speaker-diarization-precision-2/source.txt b/speaker-diarization-precision-2/source.txt new file mode 100644 index 0000000000000000000000000000000000000000..120f9aa656ab45e0fad00450c7ae378cbc4c0b23 --- /dev/null +++ b/speaker-diarization-precision-2/source.txt @@ -0,0 +1 @@ +https://huggingface.co/pyannote/speaker-diarization-precision-2 \ No newline at end of file diff --git a/wespeaker-voxceleb-resnet34-LM/LICENCE.md b/wespeaker-voxceleb-resnet34-LM/LICENCE.md new file mode 100644 index 0000000000000000000000000000000000000000..1b116fbd9e410a379801292073d336707b27c97c --- /dev/null +++ b/wespeaker-voxceleb-resnet34-LM/LICENCE.md @@ -0,0 +1,13 @@ +Copyright 2022 ChengDong Liang + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/wespeaker-voxceleb-resnet34-LM/README.md b/wespeaker-voxceleb-resnet34-LM/README.md index 6b8a7afcd16051587687a13e556a8b0ee587f5ad..d151b72237c3df707fd6cb31e783903171804fa2 100644 --- a/wespeaker-voxceleb-resnet34-LM/README.md +++ b/wespeaker-voxceleb-resnet34-LM/README.md @@ -1,111 +1,37 @@ --- -tags: - - pyannote - - pyannote-audio - - pyannote-audio-model - - wespeaker - - audio - - voice - - speech - - speaker - - speaker-recognition - - speaker-verification - - speaker-identification - - speaker-embedding -datasets: - - voxceleb -license: cc-by-4.0 -inference: false +license: apache-2.0 --- +# WeSpeaker ResNet34 speaker embedding -Using this open-source model in production? -Consider switching to [pyannoteAI](https://www.pyannote.ai) for better and faster options. +This is a copy of `voxceleb_resnet34_LM.onnx` speaker embedding model taken from [ChengDong Liang's repository](https://huggingface.co/chengdongliang/wespeaker). -# 🎹 Wrapper around wespeaker-voxceleb-resnet34-LM - -This model requires `pyannote.audio` version 3.1 or higher. - -This is a wrapper around [WeSpeaker](https://github.com/wenet-e2e/wespeaker) `wespeaker-voxceleb-resnet34-LM` pretrained speaker embedding model, for use in `pyannote.audio`. - -## Basic usage - -```python -# instantiate pretrained model -from pyannote.audio import Model -model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM") ``` +Copyright 2022 ChengDong Liang -```python -from pyannote.audio import Inference -inference = Inference(model, window="whole") -embedding1 = inference("speaker1.wav") -embedding2 = inference("speaker2.wav") -# `embeddingX` is (1 x D) numpy array extracted from the file as a whole. - -from scipy.spatial.distance import cdist -distance = cdist(embedding1, embedding2, metric="cosine")[0,0] -# `distance` is a `float` describing how dissimilar speakers 1 and 2 are. -``` - -## Advanced usage - -### Running on GPU +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at -```python -import torch -inference.to(torch.device("cuda")) -embedding = inference("audio.wav") -``` - -### Extract embedding from an excerpt + http://www.apache.org/licenses/LICENSE-2.0 -```python -from pyannote.audio import Inference -from pyannote.core import Segment -inference = Inference(model, window="whole") -excerpt = Segment(13.37, 19.81) -embedding = inference.crop("audio.wav", excerpt) -# `embedding` is (1 x D) numpy array extracted from the file excerpt. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. ``` -### Extract embeddings using a sliding window +## How to use with [pyannote.audio](https://github.com/pyannote/pyannote-audio) ```python -from pyannote.audio import Inference -inference = Inference(model, window="sliding", - duration=3.0, step=1.0) -embeddings = inference("audio.wav") -# `embeddings` is a (N x D) pyannote.core.SlidingWindowFeature -# `embeddings[i]` is the embedding of the ith position of the -# sliding window, i.e. from [i * step, i * step + duration]. 
-``` +from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding +get_embedding = PretrainedSpeakerEmbedding("hbredin/wespeaker-voxceleb-resnet34-LM") -## License - -According to [this page](https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md): - -> The pretrained model in WeNet follows the license of it's corresponding dataset. For example, the pretrained model on VoxCeleb follows Creative Commons Attribution 4.0 International License., since it is used as license of the VoxCeleb dataset, see https://mm.kaist.ac.kr/datasets/voxceleb/. - -## Citation - -```bibtex -@inproceedings{Wang2023, - title={Wespeaker: A research and production oriented speaker embedding learning toolkit}, - author={Wang, Hongji and Liang, Chengdong and Wang, Shuai and Chen, Zhengyang and Zhang, Binbin and Xiang, Xu and Deng, Yanlei and Qian, Yanmin}, - booktitle={ICASSP 2023, IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, - pages={1--5}, - year={2023}, - organization={IEEE} -} -``` +assert waveforms.ndim == 3 +batch_size, num_channels, num_samples = waveforms.shape +assert num_channels == 1 -```bibtex -@inproceedings{Bredin23, - author={Hervé Bredin}, - title={{pyannote.audio 2.1 speaker diarization pipeline: principle, benchmark, and recipe}}, - year=2023, - booktitle={Proc. INTERSPEECH 2023}, - pages={1983--1987}, - doi={10.21437/Interspeech.2023-105} -} +embeddings = get_embedding(waveforms) +assert embeddings.ndim == 2 +assert embeddings.shape[0] == batch_size ``` diff --git a/wespeaker-voxceleb-resnet34-LM/pytorch_model.bin b/wespeaker-voxceleb-resnet34-LM/pytorch_model.bin index 8ac248f7e8333ec1d22c55c5e2af4ac8d15596e3..fe1c74dab7fbc9c00dc39920e3cbd4ea6660b6cc 100644 --- a/wespeaker-voxceleb-resnet34-LM/pytorch_model.bin +++ b/wespeaker-voxceleb-resnet34-LM/pytorch_model.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:366edf44f4c80889a3eb7a9d7bdf02c4aede3127f7dd15e274dcdb826b143c56 -size 26645418 +oid sha256:26bba0ee9a3461d0f99ca36ead2626ca3610bfd9896dd479583b3c47e94f7f32 +size 26647953 diff --git a/wespeaker-voxceleb-resnet34-LM/source.txt b/wespeaker-voxceleb-resnet34-LM/source.txt new file mode 100644 index 0000000000000000000000000000000000000000..4c10854d3435e4f02498b3a8091cae30f8b3d470 --- /dev/null +++ b/wespeaker-voxceleb-resnet34-LM/source.txt @@ -0,0 +1 @@ +https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM \ No newline at end of file diff --git a/wespeaker-voxceleb-resnet34-LM/speaker-embedding.onnx b/wespeaker-voxceleb-resnet34-LM/speaker-embedding.onnx new file mode 100644 index 0000000000000000000000000000000000000000..81d3cfc3bfc4c32e4f2c3586a2042b010da3d415 --- /dev/null +++ b/wespeaker-voxceleb-resnet34-LM/speaker-embedding.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7bb2f06e9df17cdf1ef14ee8a15ab08ed28e8d0ef5054ee135741560df2ec068 +size 26530309