Pyannote (models, models_onnx)
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- ailia-models/code/LICENSE +21 -0
- ailia-models/code/README.md +135 -0
- ailia-models/code/config.yaml +17 -0
- ailia-models/code/data/sample.rttm +10 -0
- ailia-models/code/data/sample.wav +3 -0
- ailia-models/code/output.png +0 -0
- ailia-models/code/output_ground.png +0 -0
- ailia-models/code/pyannote-audio.py +181 -0
- ailia-models/code/pyannote_audio_utils/__init__.py +23 -0
- ailia-models/code/pyannote_audio_utils/audio/__init__.py +33 -0
- ailia-models/code/pyannote_audio_utils/audio/core/inference.py +596 -0
- ailia-models/code/pyannote_audio_utils/audio/core/io.py +352 -0
- ailia-models/code/pyannote_audio_utils/audio/core/pipeline.py +218 -0
- ailia-models/code/pyannote_audio_utils/audio/core/task.py +125 -0
- ailia-models/code/pyannote_audio_utils/audio/pipelines/__init__.py +35 -0
- ailia-models/code/pyannote_audio_utils/audio/pipelines/clustering.py +468 -0
- ailia-models/code/pyannote_audio_utils/audio/pipelines/speaker_diarization.py +553 -0
- ailia-models/code/pyannote_audio_utils/audio/pipelines/speaker_verification.py +249 -0
- ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/__init__.py +37 -0
- ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/diarization.py +248 -0
- ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/kaldifeat.py +291 -0
- ailia-models/code/pyannote_audio_utils/audio/utils/multi_task.py +59 -0
- ailia-models/code/pyannote_audio_utils/audio/utils/powerset.py +125 -0
- ailia-models/code/pyannote_audio_utils/audio/utils/signal.py +369 -0
- ailia-models/code/pyannote_audio_utils/audio/version.py +1 -0
- ailia-models/code/pyannote_audio_utils/core/__init__.py +48 -0
- ailia-models/code/pyannote_audio_utils/core/_version.py +20 -0
- ailia-models/code/pyannote_audio_utils/core/annotation.py +1551 -0
- ailia-models/code/pyannote_audio_utils/core/feature.py +329 -0
- ailia-models/code/pyannote_audio_utils/core/notebook.py +468 -0
- ailia-models/code/pyannote_audio_utils/core/segment.py +910 -0
- ailia-models/code/pyannote_audio_utils/core/timeline.py +1126 -0
- ailia-models/code/pyannote_audio_utils/core/utils/generators.py +89 -0
- ailia-models/code/pyannote_audio_utils/core/utils/types.py +13 -0
- ailia-models/code/pyannote_audio_utils/database/__init__.py +91 -0
- ailia-models/code/pyannote_audio_utils/database/protocol/__init__.py +34 -0
- ailia-models/code/pyannote_audio_utils/database/protocol/protocol.py +434 -0
- ailia-models/code/pyannote_audio_utils/database/util.py +400 -0
- ailia-models/code/pyannote_audio_utils/metrics/__init__.py +36 -0
- ailia-models/code/pyannote_audio_utils/metrics/_version.py +21 -0
- ailia-models/code/pyannote_audio_utils/metrics/base.py +419 -0
- ailia-models/code/pyannote_audio_utils/metrics/diarization.py +167 -0
- ailia-models/code/pyannote_audio_utils/metrics/identification.py +274 -0
- ailia-models/code/pyannote_audio_utils/metrics/matcher.py +192 -0
- ailia-models/code/pyannote_audio_utils/metrics/types.py +7 -0
- ailia-models/code/pyannote_audio_utils/metrics/utils.py +225 -0
- ailia-models/code/pyannote_audio_utils/pipeline/__init__.py +37 -0
- ailia-models/code/pyannote_audio_utils/pipeline/parameter.py +203 -0
- ailia-models/code/pyannote_audio_utils/pipeline/pipeline.py +614 -0
.gitattributes
CHANGED
|
@@ -37,3 +37,5 @@ brouhaha/brouhaha.gif filter=lfs diff=lfs merge=lfs -text
|
|
| 37 |
separation-ami-1.0/model.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
speaker-diarization/technical_report_2.1.pdf filter=lfs diff=lfs merge=lfs -text
|
| 39 |
speech-separation-ami-1.0/pipeline.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 37 |
separation-ami-1.0/model.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
speaker-diarization/technical_report_2.1.pdf filter=lfs diff=lfs merge=lfs -text
|
| 39 |
speech-separation-ami-1.0/pipeline.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
ailia-models/code/data/sample.wav filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
speaker-diarization-community-1/diarization.gif filter=lfs diff=lfs merge=lfs -text
|
ailia-models/code/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2020 CNRS
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
ailia-models/code/README.md
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Pyannote-audio : Speaker Diarization
|
| 2 |
+
|
| 3 |
+
## Input
|
| 4 |
+
|
| 5 |
+
Audio file (.wav format).
|
| 6 |
+
```
|
| 7 |
+
Example
|
| 8 |
+
input: data/demo.wav
|
| 9 |
+
```
|
| 10 |
+
(Wav file from https://github.com/pyannote/pyannote-audio/tree/develop/pyannote/audio/sample)
|
| 11 |
+
|
| 12 |
+
## Output
|
| 13 |
+
|
| 14 |
+
When and who spoke.
|
| 15 |
+

|
| 16 |
+
|
| 17 |
+
```
|
| 18 |
+
[ 00:00:06.714 --> 00:00:07.003] A speaker91
|
| 19 |
+
[ 00:00:07.003 --> 00:00:07.173] B speaker90
|
| 20 |
+
[ 00:00:07.580 --> 00:00:08.310] C speaker91
|
| 21 |
+
[ 00:00:08.310 --> 00:00:09.923] D speaker90
|
| 22 |
+
[ 00:00:09.923 --> 00:00:10.976] E speaker91
|
| 23 |
+
[ 00:00:10.466 --> 00:00:14.745] F speaker90
|
| 24 |
+
[ 00:00:14.303 --> 00:00:17.886] G speaker91
|
| 25 |
+
[ 00:00:18.022 --> 00:00:21.502] H speaker90
|
| 26 |
+
[ 00:00:18.157 --> 00:00:18.446] I speaker91
|
| 27 |
+
[ 00:00:21.774 --> 00:00:28.531] J speaker91
|
| 28 |
+
[ 00:00:27.886 --> 00:00:29.991] K speaker90
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Requirements
|
| 32 |
+
|
| 33 |
+
This model recommends additional module.
|
| 34 |
+
```bash
|
| 35 |
+
$ pip3 install -r requirements.txt
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
## Usage
|
| 39 |
+
|
| 40 |
+
Automatically downloads the onnx and prototxt files on the first run.
|
| 41 |
+
It is necessary to be connected to the Internet while downloading.
|
| 42 |
+
|
| 43 |
+
For the sample
|
| 44 |
+
```bash
|
| 45 |
+
$ python pyannote-audio.py -i ./data/sample.wav
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
For the sample with plot
|
| 49 |
+
```bash
|
| 50 |
+
$ python pyannote-audio.py -i ./data/sample.wav --plt
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
For the sample with verification
|
| 54 |
+
```bash
|
| 55 |
+
$ python pyannote-audio.py -i ./data/sample.wav -g ./data/sample.rttm
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
If you want to specify the audio, put the file path after the `--i` or `-input` option.
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
$ python pyannote-audio.py --i FILE_PATH
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
If you want to specify the ground truth, put the file path after the `--ig` or `-input_ground` option.
|
| 65 |
+
|
| 66 |
+
```bash
|
| 67 |
+
$ python pyannote-audio.py --ig FILE_PATH
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
If you want to specify the output file, put the file path after the `--o` or `-output` option.
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
$ python pyannote-audio.py --o FILE_PATH
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
If you want to specify the output ground truth file, put the file path after the `--og` or `-output_ground` option.
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
+
$ python pyannote-audio.py --og FILE_PATH
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
If you know the number of speakers, put the number after the `--num` or `-num_speaker` option.
|
| 83 |
+
```bash
|
| 84 |
+
$ python pyannote-audio.py --num 2
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
If you know the maximum number of speakers, put the number after the `--max` or `-max_speaker` option.
|
| 88 |
+
```bash
|
| 89 |
+
$ python pyannote-audio.py --max 4
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
If you know the minimum number of speakers, put the number after the `--min` or `-min_speaker` option.
|
| 93 |
+
```bash
|
| 94 |
+
$ python pyannote-audio.py --min 2
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
By giving the `--e` or `-error` option, you can get diarization error rate.
|
| 98 |
+
```bash
|
| 99 |
+
$ python pyannote-audio.py --e
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
By giving the `--plt` option, you can visualize results.
|
| 103 |
+
```bash
|
| 104 |
+
$ python pyannote-audio.py --plt
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
By giving the `--use_onnx` option, you can use onnx.
|
| 108 |
+
```bash
|
| 109 |
+
$ python pyannote-audio.py --use_onnx
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
By giving the `--embed` option, you can get embedding vector in the input file.
|
| 113 |
+
```bash
|
| 114 |
+
$ python pyannote-audio.py --embed
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
## Reference
|
| 118 |
+
|
| 119 |
+
- [Pyannote-audio](https://github.com/pyannote/pyannote-audio)
|
| 120 |
+
- [Hugging Face - pyannote in speaker-diariazation](https://huggingface.co/pyannote/speaker-diarization-3.1)
|
| 121 |
+
- [Hugging Face - hdbrain in wespeaker-voxceleb-resnet34-LM](https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/tree/main)
|
| 122 |
+
- [KaldiFeat](https://github.com/yuyq96/kaldifeat)
|
| 123 |
+
|
| 124 |
+
## Framework
|
| 125 |
+
|
| 126 |
+
Pytorch
|
| 127 |
+
|
| 128 |
+
## Model Format
|
| 129 |
+
|
| 130 |
+
ONNX opset=14,17
|
| 131 |
+
|
| 132 |
+
## Netron
|
| 133 |
+
|
| 134 |
+
- [segmentation.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/pyannote-audio/segmentation.onnx.prototxt)
|
| 135 |
+
- [speaker-embedding.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/pyannote-audio/speaker-embedding.onnx.prototxt)
|
ailia-models/code/config.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
params:
|
| 2 |
+
clustering:
|
| 3 |
+
method: centroid
|
| 4 |
+
min_cluster_size: 12
|
| 5 |
+
threshold: 0.7045654963945799
|
| 6 |
+
segmentation:
|
| 7 |
+
min_duration_off: 0.0
|
| 8 |
+
pipeline:
|
| 9 |
+
name: pyannote.audio.pipelines.SpeakerDiarization
|
| 10 |
+
params:
|
| 11 |
+
clustering: AgglomerativeClustering
|
| 12 |
+
embedding: speaker-embedding.onnx
|
| 13 |
+
embedding_batch_size: 32
|
| 14 |
+
embedding_exclude_overlap: true
|
| 15 |
+
segmentation: segmentation.onnx
|
| 16 |
+
segmentation_batch_size: 32
|
| 17 |
+
version: 3.1.0
|
ailia-models/code/data/sample.rttm
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
SPEAKER sample 1 6.690 0.430 <NA> <NA> speaker90 <NA> <NA>
|
| 2 |
+
SPEAKER sample 1 7.550 0.800 <NA> <NA> speaker91 <NA> <NA>
|
| 3 |
+
SPEAKER sample 1 8.320 1.700 <NA> <NA> speaker90 <NA> <NA>
|
| 4 |
+
SPEAKER sample 1 9.920 1.110 <NA> <NA> speaker91 <NA> <NA>
|
| 5 |
+
SPEAKER sample 1 10.570 4.130 <NA> <NA> speaker90 <NA> <NA>
|
| 6 |
+
SPEAKER sample 1 14.490 3.430 <NA> <NA> speaker91 <NA> <NA>
|
| 7 |
+
SPEAKER sample 1 18.050 3.440 <NA> <NA> speaker90 <NA> <NA>
|
| 8 |
+
SPEAKER sample 1 18.150 0.440 <NA> <NA> speaker91 <NA> <NA>
|
| 9 |
+
SPEAKER sample 1 21.780 6.720 <NA> <NA> speaker91 <NA> <NA>
|
| 10 |
+
SPEAKER sample 1 27.850 2.150 <NA> <NA> speaker90 <NA> <NA>
|
ailia-models/code/data/sample.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c319b4abca767b124e41432d364fd7df006cb26bb79d09326c487d606a134e6e
|
| 3 |
+
size 960104
|
ailia-models/code/output.png
ADDED
|
ailia-models/code/output_ground.png
ADDED
|
ailia-models/code/pyannote-audio.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
import sys
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
from pyannote_audio_utils.audio.pipelines.speaker_diarization import SpeakerDiarization
|
| 7 |
+
from pyannote_audio_utils.core import Segment, Annotation
|
| 8 |
+
from pyannote_audio_utils.core.notebook import Notebook
|
| 9 |
+
from pyannote_audio_utils.database.util import load_rttm
|
| 10 |
+
from pyannote_audio_utils.metrics.diarization import DiarizationErrorRate
|
| 11 |
+
|
| 12 |
+
sys.path.append('../../util')
|
| 13 |
+
from arg_utils import get_base_parser, update_parser # noqa: E402
|
| 14 |
+
from model_utils import check_and_download_models # noqa: E402
|
| 15 |
+
from logging import getLogger # noqa: E402
|
| 16 |
+
logger = getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
# Model weight / network-definition files and their download location.
WEIGHT_SEGMENTATION_PATH = 'segmentation.onnx'
MODEL_SEGMENTATION_PATH = 'segmentation.onnx.prototxt'
WEIGHT_EMBEDDING_PATH = 'speaker-embedding.onnx'
MODEL_EMBEDDING_PATH = 'speaker-embedding.onnx.prototxt'
REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/pyannote-audio/'
# Pipeline configuration consumed (and rewritten) by main().
YAML_PATH = 'config.yaml'
OUT_PATH = 'output.png'

# Command-line arguments: extends the shared ailia-models base parser
# (which already provides -i/--input, --benchmark, etc.).
parser = get_base_parser(
    'Pyannote-audio', './data/sample.wav', None, input_ftype='audio'
)

# Speaker-count constraints; 0 means "not fixed" (see main()).
parser.add_argument(
    '--num', '-num_speaker', default=0, type=int,
    help='If the number of speakers is fixed',
)
parser.add_argument(
    '--max', '-max_speaker', default=0, type=int,
    help='If the maximum number of speakers is fixed',
)
parser.add_argument(
    '--min', '-min_speaker', default=0, type=int,
    help='If the minimum number of speakers is fixed',
)
# Ground-truth RTTM file; required for the diarization error rate (--e).
parser.add_argument(
    '--ig', '-ground', default=None,
    help='Specify a wav file as ground truth. If you need diarization error rate, you need this file'
)
# Output image paths for the prediction / ground-truth plots (--plt).
parser.add_argument(
    '--o', '-output', default='output.png',
    help='Specify an output file'
)
parser.add_argument(
    '--og', '-output_ground', default='output_ground.png',
    help='Specify an output ground truth file'
)
parser.add_argument(
    '--e', '-error',
    action='store_true',
    help='If you need diarization error rate'
)
parser.add_argument(
    '--plt',
    action='store_true',
    help='If you want to visualize result'
)
parser.add_argument(
    '--embed',
    action='store_true',
    help='If you need embedding vector',
)
parser.add_argument(
    '--onnx',
    action='store_true',
    help='execute onnxruntime version'
)

args = update_parser(parser)
|
| 76 |
+
|
| 77 |
+
def repr_annotation(args, annotation: Annotation, notebook: Notebook, ground: bool = False):
    """Plot `annotation` and save it as a PNG file.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments; ``args.o`` and ``args.og`` hold the
        output paths for the prediction and ground-truth plots.
    annotation : Annotation
        Diarization result (or ground truth) to render.
    notebook : Notebook
        Plotting helper; its ``width`` sets the figure width and its
        ``crop`` (set by the caller) limits the rendered excerpt.
    ground : bool, optional
        When True, save to ``args.og`` (ground-truth plot);
        otherwise to ``args.o``.
    """
    figsize = plt.rcParams["figure.figsize"]
    plt.rcParams["figure.figsize"] = (notebook.width, 2)
    fig = None
    try:
        fig, ax = plt.subplots()
        notebook.plot_annotation(annotation, ax=ax)
        plt.savefig(args.og if ground else args.o)
    finally:
        # Close the figure and restore the global figure size even if
        # plotting/saving raised, so later plots are not affected.
        if fig is not None:
            plt.close(fig)
        plt.rcParams["figure.figsize"] = figsize
    return
|
| 90 |
+
|
| 91 |
+
def main(args):
    """Run the speaker-diarization pipeline on ``args.input[0]``.

    Downloads the ONNX models if needed, patches the model paths into
    ``config.yaml``, builds the SpeakerDiarization pipeline, runs it
    (optionally constrained by --num / --min / --max, optionally returning
    per-speaker embeddings with --embed), and finally reports the result:
    diarization error rate (--e, needs --ig ground truth) and/or plots
    (--plt, saved to args.o / args.og).
    """
    check_and_download_models(WEIGHT_SEGMENTATION_PATH, MODEL_SEGMENTATION_PATH, remote_path=REMOTE_PATH)
    check_and_download_models(WEIGHT_EMBEDDING_PATH, MODEL_EMBEDDING_PATH, remote_path=REMOTE_PATH)

    if args.benchmark:
        start = int(round(time.time() * 1000))

    # Point the pipeline config at the (downloaded) local model files.
    with open(YAML_PATH, 'r') as yml:
        config = yaml.safe_load(yml)

    config["pipeline"]["params"]["segmentation"] = WEIGHT_SEGMENTATION_PATH
    config["pipeline"]["params"]["embedding"] = WEIGHT_EMBEDDING_PATH
    with open(YAML_PATH, 'w') as f:
        yaml.dump(config, f)

    audio_file = args.input[0]

    # Re-read the (just rewritten) config to build the pipeline from it.
    with open(YAML_PATH, "r") as fp:
        config = yaml.load(fp, Loader=yaml.SafeLoader)

    params = config["pipeline"].get("params", {})
    pipeline = SpeakerDiarization(
        **params,
        args=args,
        seg_path=MODEL_SEGMENTATION_PATH,
        emb_path=MODEL_EMBEDDING_PATH,
    )

    if "params" in config:
        pipeline.instantiate(config["params"])

    # Translate --num / --min / --max into pipeline keyword arguments.
    # Fix: the keyword is `max_speakers` (plural); passing `max_speaker`
    # raised TypeError whenever --max/--min was used.
    speaker_kwargs = {}
    if args.num > 0:
        speaker_kwargs["num_speakers"] = args.num
    elif args.max > 0 or args.min > 0:
        speaker_kwargs["min_speakers"] = args.min
        speaker_kwargs["max_speakers"] = args.max

    if args.embed:
        diarization, embeddings = pipeline(audio_file, return_embeddings=True, **speaker_kwargs)
        for s, speaker in enumerate(diarization.labels()):
            print(speaker, embeddings[s].shape)
    else:
        diarization = pipeline(audio_file, **speaker_kwargs)

    if args.benchmark:
        end = int(round(time.time() * 1000))
        print(f'\tailia processing time {end - start} ms')

    if args.ig:
        # Ground truth available: compute DER and align speaker labels
        # before printing/plotting so both annotations use the same names.
        _, groundtruth = load_rttm(args.ig).popitem()
        metric = DiarizationErrorRate()
        result = metric(groundtruth, diarization, detailed=False)

        mapping = metric.optimal_mapping(groundtruth, diarization)
        diarization = diarization.rename_labels(mapping=mapping)

        print(diarization)
        if args.e:
            print(f'diarization error rate = {100 * result:.1f}%')

        if args.plt:
            EXCERPT = Segment(0, 30)
            notebook = Notebook()
            notebook.crop = EXCERPT
            repr_annotation(args, diarization, notebook)
            repr_annotation(args, groundtruth, notebook, ground=True)
        return

    else:
        print(diarization)

        if args.plt:
            EXCERPT = Segment(0, 30)
            notebook = Notebook()
            notebook.crop = EXCERPT
            repr_annotation(args, diarization, notebook)
        return
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# Script entry point: `args` was parsed at import time above.
if __name__ == "__main__":
    main(args)
|
ailia-models/code/pyannote_audio_utils/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2020 CNRS
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
# copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
__import__("pkg_resources").declare_namespace(__name__)
|
ailia-models/code/pyannote_audio_utils/audio/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2020-2021 CNRS
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
# copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
from .version import __version__, git_version # noqa: F401
|
| 25 |
+
except ImportError:
|
| 26 |
+
pass
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
from .core.inference import Inference
|
| 30 |
+
from .core.io import Audio
|
| 31 |
+
from .core.pipeline import Pipeline
|
| 32 |
+
|
| 33 |
+
__all__ = ["Audio", "Inference", "Pipeline"]
|
ailia-models/code/pyannote_audio_utils/audio/core/inference.py
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2020- CNRS
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
# copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
import ailia
|
| 25 |
+
import warnings
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Callable, List, Optional, Text, Tuple, Union
|
| 28 |
+
from functools import cached_property
|
| 29 |
+
from dataclasses import dataclass
|
| 30 |
+
import numpy as np
|
| 31 |
+
|
| 32 |
+
from pyannote_audio_utils.core import Segment, SlidingWindow, SlidingWindowFeature
|
| 33 |
+
from pyannote_audio_utils.audio.core.io import AudioFile, Audio
|
| 34 |
+
from pyannote_audio_utils.audio.core.task import Resolution, Specifications, Problem
|
| 35 |
+
from pyannote_audio_utils.audio.utils.multi_task import map_with_specifications
|
| 36 |
+
from pyannote_audio_utils.audio.utils.powerset import Powerset
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class BaseInference:
|
| 40 |
+
pass
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class Output:
|
| 44 |
+
num_frames: int
|
| 45 |
+
dimension: int
|
| 46 |
+
frames: SlidingWindow
|
| 47 |
+
|
| 48 |
+
class Inference(BaseInference):
|
| 49 |
+
"""Inference
|
| 50 |
+
|
| 51 |
+
Parameters
|
| 52 |
+
----------
|
| 53 |
+
model : Model
|
| 54 |
+
Model. Will be automatically set to eval() mode and moved to `device` when provided.
|
| 55 |
+
window : {"sliding", "whole"}, optional
|
| 56 |
+
Use a "sliding" window and aggregate the corresponding outputs (default)
|
| 57 |
+
or just one (potentially long) window covering the "whole" file or chunk.
|
| 58 |
+
duration : float, optional
|
| 59 |
+
Chunk duration, in seconds. Defaults to duration used for training the model.
|
| 60 |
+
Has no effect when `window` is "whole".
|
| 61 |
+
step : float, optional
|
| 62 |
+
Step between consecutive chunks, in seconds. Defaults to warm-up duration when
|
| 63 |
+
greater than 0s, otherwise 10% of duration. Has no effect when `window` is "whole".
|
| 64 |
+
pre_aggregation_hook : callable, optional
|
| 65 |
+
When a callable is provided, it is applied to the model output, just before aggregation.
|
| 66 |
+
Takes a (num_chunks, num_frames, dimension) numpy array as input and returns a modified
|
| 67 |
+
(num_chunks, num_frames, other_dimension) numpy array passed to overlap-add aggregation.
|
| 68 |
+
skip_aggregation : bool, optional
|
| 69 |
+
Do not aggregate outputs when using "sliding" window. Defaults to False.
|
| 70 |
+
skip_conversion: bool, optional
|
| 71 |
+
In case a task has been trained with `powerset` mode, output is automatically
|
| 72 |
+
converted to `multi-label`, unless `skip_conversion` is set to True.
|
| 73 |
+
batch_size : int, optional
|
| 74 |
+
Batch size. Larger values (should) make inference faster. Defaults to 32.
|
| 75 |
+
device : torch.device, optional
|
| 76 |
+
Device used for inference. Defaults to `model.device`.
|
| 77 |
+
In case `device` and `model.device` are different, model is sent to device.
|
| 78 |
+
use_auth_token : str, optional
|
| 79 |
+
When loading a private huggingface.co model, set `use_auth_token`
|
| 80 |
+
to True or to a string containing your huggingface.co authentication
|
| 81 |
+
token that can be obtained by running `huggingface-cli login`
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
def __init__(
    self,
    model: Union[Text, Path],
    window: Text = "sliding",
    duration: float = None,
    step: float = None,
    pre_aggregation_hook: Callable[[np.ndarray], np.ndarray] = None,
    skip_aggregation: bool = False,
    skip_conversion: bool = False,
    batch_size: int = 32,
    use_auth_token: Union[Text, None] = None,
    args = None,
    seg_path = None,
):
    """Load the segmentation model and configure sliding-window inference.

    Parameters
    ----------
    model : str or Path
        Path to the model weights (ONNX file when ``args.onnx`` is set,
        otherwise the ailia weight file).
    window : {"sliding", "whole"}, optional
        Inference mode. Defaults to "sliding".
    duration : float, optional
        Chunk duration in seconds. Defaults to the training duration.
    step : float, optional
        Step between consecutive chunks. Defaults to the left warm-up
        duration when > 0, otherwise 10% of `duration`.
    pre_aggregation_hook : callable, optional
        Applied to model output just before aggregation.
    skip_aggregation : bool, optional
        Do not aggregate outputs when using a "sliding" window.
    skip_conversion : bool, optional
        Skip powerset-to-multilabel conversion.
    batch_size : int, optional
        Number of chunks per forward pass. Defaults to 32.
    use_auth_token : str, optional
        Kept for API compatibility with upstream pyannote; unused here.
    args : argparse.Namespace
        Runtime options; ``args.onnx`` and ``args.env_id`` are read.
        NOTE(review): despite the ``None`` default, `args` is effectively
        required — ``args.onnx`` is dereferenced unconditionally below.
    seg_path : str, optional
        Model prototxt/description path used by the ailia backend.
    """
    # ~~~~ model ~~~~~


    if args.onnx:
        #print("use onnx runtime")
        import onnxruntime
        model = onnxruntime.InferenceSession(model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    else:
        #print("use ailia")
        model = ailia.Net(seg_path, weight=model, env_id=args.env_id)

    self.model = model
    self.args = args

    # hard-coded specifications matching the bundled 3-speaker segmentation model
    specifications = Specifications(problem=Problem.MONO_LABEL_CLASSIFICATION, resolution = Resolution.FRAME, classes=['speaker#1', 'speaker#2', 'speaker#3'])

    self.specifications = specifications
    self.audio = Audio(sample_rate=16000, mono="downmix")
    # ~~~~ sliding window ~~~~~

    if window not in ["sliding", "whole"]:
        raise ValueError('`window` must be "sliding" or "whole".')

    if window == "whole" and any(
        s.resolution == Resolution.FRAME for s in specifications
    ):
        warnings.warn(
            'Using "whole" `window` inference with a frame-based model might lead to bad results '
            'and huge memory consumption: it is recommended to set `window` to "sliding".'
        )
    self.window = window

    training_duration = next(iter(specifications)).duration
    duration = duration or training_duration
    if training_duration != duration:
        warnings.warn(
            f"Model was trained with {training_duration:g}s chunks, and you requested "
            f"{duration:g}s chunks for inference: this might lead to suboptimal results."
        )
    self.duration = duration

    # ~~~~ powerset to multilabel conversion ~~~~

    self.skip_conversion = skip_conversion

    conversion = list()
    for s in specifications:
        if s.powerset and not skip_conversion:
            c = Powerset(len(s.classes), s.powerset_max_classes)
            conversion.append(c)


    # NOTE(review): assumes at least one powerset specification; otherwise
    # `conversion` is empty and this raises IndexError — confirm Specifications
    # always reports powerset=True for this model.
    if isinstance(specifications, Specifications):
        self.conversion = conversion[0]

    # ~~~~ overlap-add aggregation ~~~~~

    self.skip_aggregation = skip_aggregation
    self.pre_aggregation_hook = pre_aggregation_hook


    self.warm_up = next(iter(specifications)).warm_up
    # Use that many seconds on the left- and rightmost parts of each chunk
    # to warm up the model. While the model does process those left- and right-most
    # parts, only the remaining central part of each chunk is used for aggregating
    # scores during inference.

    # step between consecutive chunks
    step = step or (
        0.1 * self.duration if self.warm_up[0] == 0.0 else self.warm_up[0]
    )
    if step > self.duration:
        raise ValueError(
            f"Step between consecutive chunks is set to {step:g}s, while chunks are "
            f"only {self.duration:g}s long, leading to gaps between consecutive chunks. "
            f"Either decrease step or increase duration."
        )
    self.step = step

    self.batch_size = batch_size
|
| 178 |
+
|
| 179 |
+
def infer(self, chunks: np.ndarray) -> Union[np.ndarray, Tuple[np.ndarray]]:
    """Forward pass through the segmentation model.

    Runs the batch through either the onnxruntime session or the ailia
    network (depending on ``self.args.onnx``) and converts the raw scores
    with the powerset-to-multilabel conversion configured at construction.

    Parameters
    ----------
    chunks : (batch_size, num_channels, num_samples) np.ndarray
        Batch of audio chunks.

    Returns
    -------
    outputs : (tuple of) (batch_size, ...) np.ndarray
        Converted model output.
    """
    batch = chunks.astype(np.float32)

    if self.args.onnx:
        raw_outputs = self.model.run(None, {"input": batch})[0]
    else:
        raw_outputs = self.model.predict([batch])[0]

    # applied once per specification by map_with_specifications
    def _apply_conversion(output: np.ndarray, conversion, **kwargs):
        return conversion(output)

    return map_with_specifications(
        self.specifications, _apply_conversion, raw_outputs, self.conversion
    )
|
| 205 |
+
|
| 206 |
+
@cached_property
def example_output(self) -> Union[Output, Tuple[Output]]:
    """Output shape metadata, measured by one forward pass on random audio."""
    num_samples = self.audio.get_num_samples(self.specifications.duration)
    dummy_chunk = np.random.randn(1, 1, num_samples).astype(np.float32)

    inferred = self.infer(dummy_chunk)

    def _describe(
        example_output: np.ndarray,
        specifications: Specifications = None,
    ) -> Output:
        _, num_frames, dimension = example_output.shape

        # frame-level outputs expose their temporal resolution;
        # chunk-level outputs carry no frame window
        if specifications.resolution == Resolution.FRAME:
            resolution = specifications.duration / num_frames
            frames = SlidingWindow(step=resolution, duration=resolution)
        else:
            frames = None

        return Output(num_frames=num_frames, dimension=dimension, frames=frames)

    return map_with_specifications(self.specifications, _describe, inferred)
|
| 234 |
+
|
| 235 |
+
def slide(
    self,
    waveform: np.ndarray,
    sample_rate: int,
    hook: Optional[Callable],
) -> Union[SlidingWindowFeature, Tuple[SlidingWindowFeature]]:
    """Slide model on a waveform

    Parameters
    ----------
    waveform: (num_channels, num_samples) np.ndarray
        Waveform.
    sample_rate : int
        Sample rate.
    hook: Optional[Callable]
        When a callable is provided, it is called every time a batch is
        processed with two keyword arguments:
        - `completed`: the number of chunks that have been processed so far
        - `total`: the total number of chunks

    Returns
    -------
    output : (tuple of) SlidingWindowFeature
        Model output. Shape is (num_chunks, dimension) for chunk-level tasks,
        and (num_frames, dimension) for frame-level tasks.
    """

    # chunk geometry, in samples
    window_size: int = self.audio.get_num_samples(self.duration)
    step_size: int = round(self.step * sample_rate)

    _, num_samples = waveform.shape

    # per-specification frame resolution: chunk-level tasks get one window
    # per chunk, frame-level tasks reuse the resolution measured on the
    # example output
    def __frames(
        example_output, specifications: Optional[Specifications] = None
    ) -> SlidingWindow:
        if specifications.resolution == Resolution.CHUNK:
            return SlidingWindow(start=0.0, duration=self.duration, step=self.step)
        return example_output.frames

    frames: Union[SlidingWindow, Tuple[SlidingWindow]] = map_with_specifications(
        self.specifications,
        __frames,
        self.example_output,
    )


    # prepare complete chunks
    # zero-copy sliding-window view over the waveform, shaped
    # (num_channels, num_windows, window_size); NOTE: views from
    # as_strided share memory with `waveform` and must not be written to
    def unfold_numpy(waveform, window_size, step_size):
        batch_size, waveform_size = waveform.shape
        num_windows = (waveform_size - window_size) // step_size + 1
        shape = (batch_size, num_windows, window_size)
        strides = (
            waveform.strides[0],
            step_size * waveform.strides[1],
            waveform.strides[1],
        )

        return np.lib.stride_tricks.as_strided(waveform, shape=shape, strides=strides)

    if num_samples >= window_size:
        # reorder to (num_chunks, num_channels, window_size)
        chunks: np.ndarray = (unfold_numpy(waveform, window_size, step_size)).transpose(1, 0, 2)
        num_chunks = chunks.shape[0]

    else:
        num_chunks = 0

    # prepare last incomplete chunk

    # a trailing chunk exists when the file is shorter than one window,
    # or when the last window does not land exactly on the end of file
    has_last_chunk = (num_samples < window_size) or (num_samples - window_size) % step_size > 0

    if has_last_chunk:
        # pad last chunk with zeros
        last_chunk: np.ndarray = waveform[:, num_chunks * step_size :]
        _, last_window_size = last_chunk.shape
        last_pad = window_size - last_window_size
        last_chunk = np.pad(last_chunk, ((0, 0), (0, last_pad)))

    def __empty_list(**kwargs):
        return list()

    # one accumulator list per specification
    outputs: Union[
        List[np.ndarray], Tuple[List[np.ndarray]]
    ] = map_with_specifications(self.specifications, __empty_list)

    if hook is not None:
        hook(completed=0, total=num_chunks + has_last_chunk)

    def __append_batch(output, batch_output, **kwargs) -> None:
        output.append(batch_output)
        return


    # slide over audio chunks in batch
    for c in np.arange(0, num_chunks, self.batch_size):
        batch: np.ndarray = chunks[c : c + self.batch_size]
        batch_outputs: Union[np.ndarray, Tuple[np.ndarray]] = self.infer(batch)

        _ = map_with_specifications(
            self.specifications, __append_batch, outputs, batch_outputs
        )

        if hook is not None:
            hook(completed=c + self.batch_size, total=num_chunks + has_last_chunk)


    # process orphan last chunk
    if has_last_chunk:
        last_outputs = self.infer(last_chunk[None])

        _ = map_with_specifications(
            self.specifications, __append_batch, outputs, last_outputs
        )

        if hook is not None:
            hook(
                completed=num_chunks + has_last_chunk,
                total=num_chunks + has_last_chunk,
            )

    def __vstack(output: List[np.ndarray], **kwargs) -> np.ndarray:
        return np.vstack(output)

    # (num_chunks, num_frames_per_chunk, dimension) per specification
    outputs: Union[np.ndarray, Tuple[np.ndarray]] = map_with_specifications(
        self.specifications, __vstack, outputs
    )

    def __aggregate(
        outputs: np.ndarray,
        frames: SlidingWindow,
        specifications: Optional[Specifications] = None,
    ) -> SlidingWindowFeature:
        # skip aggregation when requested,
        # or when model outputs just one vector per chunk
        # or when model is permutation-invariant (and not post-processed)

        if (
            self.skip_aggregation
            or specifications.resolution == Resolution.CHUNK
            or (
                specifications.permutation_invariant
                and self.pre_aggregation_hook is None
            )
        ):
            frames = SlidingWindow(
                start=0.0, duration=self.duration, step=self.step
            )

        # NOTE(review): unlike upstream pyannote, non-skipped outputs are
        # returned un-aggregated here (chunk-major data with frame-level
        # window); downstream callers presumably call Inference.aggregate
        # themselves — confirm against the pipeline code.
        return SlidingWindowFeature(outputs, frames)

    return map_with_specifications(
        self.specifications, __aggregate, outputs, frames
    )
|
| 388 |
+
|
| 389 |
+
def __call__(
    self, file: AudioFile, hook: Optional[Callable] = None
) -> Union[
    Tuple[Union[SlidingWindowFeature, np.ndarray]],
    Union[SlidingWindowFeature, np.ndarray],
]:
    """Run inference on a whole file

    Parameters
    ----------
    file : AudioFile
        Audio file.
    hook : callable, optional
        When a callable is provided, it is called every time a batch is processed
        with two keyword arguments:
        - `completed`: the number of chunks that have been processed so far
        - `total`: the total number of chunks

    Returns
    -------
    output : (tuple of) SlidingWindowFeature
        Model output for the "sliding" window mode.

    Raises
    ------
    NotImplementedError
        When `window` was set to "whole": this port only implements
        sliding-window inference.
    """

    waveform, sample_rate = self.audio(file)

    if self.window == "sliding":
        return self.slide(waveform, sample_rate, hook=hook)

    # FIX: the original fell through and silently returned None for
    # window="whole", even though the constructor accepts that value and
    # the docstring promised an np.ndarray. Fail fast instead of handing
    # callers a None that crashes later.
    raise NotImplementedError(
        'Inference with window="whole" is not implemented in this port; '
        'use window="sliding".'
    )
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
@staticmethod
|
| 422 |
+
def aggregate(
|
| 423 |
+
scores: SlidingWindowFeature,
|
| 424 |
+
frames: SlidingWindow = None,
|
| 425 |
+
warm_up: Tuple[float, float] = (0.0, 0.0),
|
| 426 |
+
epsilon: float = 1e-12,
|
| 427 |
+
hamming: bool = False,
|
| 428 |
+
missing: float = np.NaN,
|
| 429 |
+
skip_average: bool = False,
|
| 430 |
+
) -> SlidingWindowFeature:
|
| 431 |
+
"""Aggregation
|
| 432 |
+
|
| 433 |
+
Parameters
|
| 434 |
+
----------
|
| 435 |
+
scores : SlidingWindowFeature
|
| 436 |
+
Raw (unaggregated) scores. Shape is (num_chunks, num_frames_per_chunk, num_classes).
|
| 437 |
+
frames : SlidingWindow, optional
|
| 438 |
+
Frames resolution. Defaults to estimate it automatically based on `scores` shape
|
| 439 |
+
and chunk size. Providing the exact frame resolution (when known) leads to better
|
| 440 |
+
temporal precision.
|
| 441 |
+
warm_up : (float, float) tuple, optional
|
| 442 |
+
Left/right warm up duration (in seconds).
|
| 443 |
+
missing : float, optional
|
| 444 |
+
Value used to replace missing (ie all NaNs) values.
|
| 445 |
+
skip_average : bool, optional
|
| 446 |
+
Skip final averaging step.
|
| 447 |
+
|
| 448 |
+
Returns
|
| 449 |
+
-------
|
| 450 |
+
aggregated_scores : SlidingWindowFeature
|
| 451 |
+
Aggregated scores. Shape is (num_frames, num_classes)
|
| 452 |
+
"""
|
| 453 |
+
|
| 454 |
+
num_chunks, num_frames_per_chunk, num_classes = scores.data.shape
|
| 455 |
+
|
| 456 |
+
chunks = scores.sliding_window
|
| 457 |
+
if frames is None:
|
| 458 |
+
duration = step = chunks.duration / num_frames_per_chunk
|
| 459 |
+
frames = SlidingWindow(start=chunks.start, duration=duration, step=step)
|
| 460 |
+
else:
|
| 461 |
+
frames = SlidingWindow(
|
| 462 |
+
start=chunks.start,
|
| 463 |
+
duration=frames.duration,
|
| 464 |
+
step=frames.step,
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
masks = 1 - np.isnan(scores)
|
| 468 |
+
scores.data = np.nan_to_num(scores.data, copy=True, nan=0.0)
|
| 469 |
+
|
| 470 |
+
# Hamming window used for overlap-add aggregation
|
| 471 |
+
hamming_window = (
|
| 472 |
+
np.hamming(num_frames_per_chunk).reshape(-1, 1)
|
| 473 |
+
if hamming
|
| 474 |
+
else np.ones((num_frames_per_chunk, 1))
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
# anything before warm_up_left (and after num_frames_per_chunk - warm_up_right)
|
| 478 |
+
# will not be used in the final aggregation
|
| 479 |
+
|
| 480 |
+
# warm-up windows used for overlap-add aggregation
|
| 481 |
+
warm_up_window = np.ones((num_frames_per_chunk, 1))
|
| 482 |
+
# anything before warm_up_left will not contribute to aggregation
|
| 483 |
+
warm_up_left = round(
|
| 484 |
+
warm_up[0] / scores.sliding_window.duration * num_frames_per_chunk
|
| 485 |
+
)
|
| 486 |
+
warm_up_window[:warm_up_left] = epsilon
|
| 487 |
+
# anything after num_frames_per_chunk - warm_up_right either
|
| 488 |
+
warm_up_right = round(
|
| 489 |
+
warm_up[1] / scores.sliding_window.duration * num_frames_per_chunk
|
| 490 |
+
)
|
| 491 |
+
warm_up_window[num_frames_per_chunk - warm_up_right :] = epsilon
|
| 492 |
+
|
| 493 |
+
# aggregated_output[i] will be used to store the sum of all predictions
|
| 494 |
+
# for frame #i
|
| 495 |
+
num_frames = (
|
| 496 |
+
frames.closest_frame(
|
| 497 |
+
scores.sliding_window.start
|
| 498 |
+
+ scores.sliding_window.duration
|
| 499 |
+
+ (num_chunks - 1) * scores.sliding_window.step
|
| 500 |
+
)
|
| 501 |
+
+ 1
|
| 502 |
+
)
|
| 503 |
+
aggregated_output: np.ndarray = np.zeros(
|
| 504 |
+
(num_frames, num_classes), dtype=np.float32
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
# overlapping_chunk_count[i] will be used to store the number of chunks
|
| 508 |
+
# that contributed to frame #i
|
| 509 |
+
overlapping_chunk_count: np.ndarray = np.zeros(
|
| 510 |
+
(num_frames, num_classes), dtype=np.float32
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
# aggregated_mask[i] will be used to indicate whether
|
| 514 |
+
# at least one non-NAN frame contributed to frame #i
|
| 515 |
+
aggregated_mask: np.ndarray = np.zeros(
|
| 516 |
+
(num_frames, num_classes), dtype=np.float32
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
# loop on the scores of sliding chunks
|
| 520 |
+
for (chunk, score), (_, mask) in zip(scores, masks):
|
| 521 |
+
# chunk ~ Segment
|
| 522 |
+
# score ~ (num_frames_per_chunk, num_classes)-shaped np.ndarray
|
| 523 |
+
# mask ~ (num_frames_per_chunk, num_classes)-shaped np.ndarray
|
| 524 |
+
|
| 525 |
+
start_frame = frames.closest_frame(chunk.start)
|
| 526 |
+
aggregated_output[start_frame : start_frame + num_frames_per_chunk] += (
|
| 527 |
+
score * mask * hamming_window * warm_up_window
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
overlapping_chunk_count[
|
| 531 |
+
start_frame : start_frame + num_frames_per_chunk
|
| 532 |
+
] += (mask * hamming_window * warm_up_window)
|
| 533 |
+
|
| 534 |
+
aggregated_mask[
|
| 535 |
+
start_frame : start_frame + num_frames_per_chunk
|
| 536 |
+
] = np.maximum(
|
| 537 |
+
aggregated_mask[start_frame : start_frame + num_frames_per_chunk],
|
| 538 |
+
mask,
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
if skip_average:
|
| 542 |
+
average = aggregated_output
|
| 543 |
+
else:
|
| 544 |
+
average = aggregated_output / np.maximum(overlapping_chunk_count, epsilon)
|
| 545 |
+
|
| 546 |
+
average[aggregated_mask == 0.0] = missing
|
| 547 |
+
|
| 548 |
+
return SlidingWindowFeature(average, frames)
|
| 549 |
+
|
| 550 |
+
@staticmethod
def trim(
    scores: SlidingWindowFeature,
    warm_up: Tuple[float, float] = (0.1, 0.1),
) -> SlidingWindowFeature:
    """Drop left and right warm-up regions from chunk-level scores.

    Parameters
    ----------
    scores : SlidingWindowFeature
        (num_chunks, num_frames, num_classes)-shaped scores.
    warm_up : (float, float) tuple
        Left/right warm up ratio of chunk duration.
        Defaults to (0.1, 0.1), i.e. 10% on both sides.

    Returns
    -------
    trimmed : SlidingWindowFeature
        (num_chunks, trimmed_num_frames, num_speakers)-shaped scores
    """

    assert (
        scores.data.ndim == 3
    ), "Inference.trim expects (num_chunks, num_frames, num_classes)-shaped `scores`"

    _, frames_per_chunk, _ = scores.data.shape
    chunk_window = scores.sliding_window

    # number of frames removed on each side
    left_trim = round(frames_per_chunk * warm_up[0])
    right_trim = round(frames_per_chunk * warm_up[1])

    # frames spanned by one chunk step; trimming below this leaves gaps
    # between consecutive trimmed chunks
    frames_per_step = round(frames_per_chunk * chunk_window.step / chunk_window.duration)
    if frames_per_chunk - left_trim - right_trim < frames_per_step:
        warnings.warn(
            f"Total `warm_up` is so large ({sum(warm_up) * 100:g}% of each chunk) "
            f"that resulting trimmed scores does not cover a whole step ({chunk_window.step:g}s)"
        )

    trimmed_data = scores.data[:, left_trim : frames_per_chunk - right_trim]

    # shift/shrink the chunk window to describe only the kept region
    trimmed_window = SlidingWindow(
        start=chunk_window.start + warm_up[0] * chunk_window.duration,
        step=chunk_window.step,
        duration=(1 - warm_up[0] - warm_up[1]) * chunk_window.duration,
    )

    return SlidingWindowFeature(trimmed_data, trimmed_window)
|
ailia-models/code/pyannote_audio_utils/audio/core/io.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2020- CNRS
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
# copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
"""
|
| 24 |
+
# Audio IO
|
| 25 |
+
|
| 26 |
+
pyannote.audio relies on torchaudio for reading and resampling.
|
| 27 |
+
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
import math
|
| 31 |
+
import random
|
| 32 |
+
import warnings
|
| 33 |
+
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
from typing import Mapping, Optional, Text, Tuple, Union
|
| 36 |
+
import numpy as np
|
| 37 |
+
|
| 38 |
+
from pyannote_audio_utils.core import Segment
|
| 39 |
+
|
| 40 |
+
import ailia.audio
|
| 41 |
+
import soundfile
|
| 42 |
+
|
| 43 |
+
AudioFile = Union[Text, Path, Mapping]
|
| 44 |
+
|
| 45 |
+
AudioFileDocString = """
|
| 46 |
+
Audio files can be provided to the Audio class using different types:
|
| 47 |
+
- a "str" or "Path" instance: "audio.wav" or Path("audio.wav")
|
| 48 |
+
- a "IOBase" instance with "read" and "seek" support: open("audio.wav", "rb")
|
| 49 |
+
- a "Mapping" with any of the above as "audio" key: {"audio": ...}
|
| 50 |
+
- a "Mapping" with both "waveform" and "sample_rate" key:
|
| 51 |
+
{"waveform": (channel, time) numpy.ndarray or torch.Tensor, "sample_rate": 44100}
|
| 52 |
+
|
| 53 |
+
For last two options, an additional "channel" key can be provided as a zero-indexed
|
| 54 |
+
integer to load a specific channel: {"audio": "stereo.wav", "channel": 0}
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class Audio:
|
| 59 |
+
"""Audio IO
|
| 60 |
+
|
| 61 |
+
Parameters
|
| 62 |
+
----------
|
| 63 |
+
sample_rate: int, optional
|
| 64 |
+
Target sampling rate. Defaults to using native sampling rate.
|
| 65 |
+
mono : {'random', 'downmix'}, optional
|
| 66 |
+
In case of multi-channel audio, convert to single-channel audio
|
| 67 |
+
using one of the following strategies: select one channel at
|
| 68 |
+
'random' or 'downmix' by averaging all channels.
|
| 69 |
+
|
| 70 |
+
Usage
|
| 71 |
+
-----
|
| 72 |
+
>>> audio = Audio(sample_rate=16000, mono='downmix')
|
| 73 |
+
>>> waveform, sample_rate = audio({"audio": "/path/to/audio.wav"})
|
| 74 |
+
>>> assert sample_rate == 16000
|
| 75 |
+
>>> sample_rate = 44100
|
| 76 |
+
>>> two_seconds_stereo = torch.rand(2, 2 * sample_rate)
|
| 77 |
+
>>> waveform, sample_rate = audio({"waveform": two_seconds_stereo, "sample_rate": sample_rate})
|
| 78 |
+
>>> assert sample_rate == 16000
|
| 79 |
+
>>> assert waveform.shape[0] == 1
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
PRECISION = 0.001
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@staticmethod
|
| 86 |
+
def validate_file(file: AudioFile) -> Mapping:
|
| 87 |
+
"""Validate file for use with the other Audio methods
|
| 88 |
+
|
| 89 |
+
Parameter
|
| 90 |
+
---------
|
| 91 |
+
file: AudioFile
|
| 92 |
+
|
| 93 |
+
Returns
|
| 94 |
+
-------
|
| 95 |
+
validated_file : Mapping
|
| 96 |
+
{"audio": str, "uri": str, ...}
|
| 97 |
+
{"waveform": array or tensor, "sample_rate": int, "uri": str, ...}
|
| 98 |
+
{"audio": file, "uri": "stream"} if `file` is an IOBase instance
|
| 99 |
+
|
| 100 |
+
Raises
|
| 101 |
+
------
|
| 102 |
+
ValueError if file format is not valid or file does not exist.
|
| 103 |
+
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
if isinstance(file, Mapping):
|
| 108 |
+
pass
|
| 109 |
+
elif isinstance(file, (str, Path)):
|
| 110 |
+
file = {"audio": str(file), "uri": Path(file).stem}
|
| 111 |
+
|
| 112 |
+
# elif isinstance(file, IOBase):
|
| 113 |
+
# return {"audio": file, "uri": "stream"}
|
| 114 |
+
|
| 115 |
+
# else:
|
| 116 |
+
# raise ValueError(AudioFileDocString)
|
| 117 |
+
|
| 118 |
+
if "waveform" in file:
|
| 119 |
+
waveform: np.ndarray = file["waveform"]
|
| 120 |
+
if len(waveform.shape) != 2 or waveform.shape[0] > waveform.shape[1]:
|
| 121 |
+
raise ValueError(
|
| 122 |
+
"'waveform' must be provided as a (channel, time) torch Tensor."
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
sample_rate: int = file.get("sample_rate", None)
|
| 126 |
+
if sample_rate is None:
|
| 127 |
+
raise ValueError(
|
| 128 |
+
"'waveform' must be provided with their 'sample_rate'."
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
file.setdefault("uri", "waveform")
|
| 132 |
+
|
| 133 |
+
elif "audio" in file:
|
| 134 |
+
# if isinstance(file["audio"], IOBase):
|
| 135 |
+
# return file
|
| 136 |
+
|
| 137 |
+
path = Path(file["audio"])
|
| 138 |
+
if not path.is_file():
|
| 139 |
+
raise ValueError(f"File {path} does not exist")
|
| 140 |
+
|
| 141 |
+
file.setdefault("uri", path.stem)
|
| 142 |
+
|
| 143 |
+
else:
|
| 144 |
+
raise ValueError(
|
| 145 |
+
"Neither 'waveform' nor 'audio' is available for this file."
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
return file
|
| 149 |
+
|
| 150 |
+
def __init__(self, sample_rate=None, mono=None):
|
| 151 |
+
super().__init__()
|
| 152 |
+
self.sample_rate = sample_rate
|
| 153 |
+
self.mono = mono
|
| 154 |
+
|
| 155 |
+
def downmix_and_resample(self, waveform: np.ndarray, sample_rate: int) -> np.ndarray:
|
| 156 |
+
"""Downmix and resample
|
| 157 |
+
|
| 158 |
+
Parameters
|
| 159 |
+
----------
|
| 160 |
+
waveform : (channel, time) Tensor
|
| 161 |
+
Waveform.
|
| 162 |
+
sample_rate : int
|
| 163 |
+
Sample rate.
|
| 164 |
+
|
| 165 |
+
Returns
|
| 166 |
+
-------
|
| 167 |
+
waveform : (channel, time) Tensor
|
| 168 |
+
Remixed and resampled waveform
|
| 169 |
+
sample_rate : int
|
| 170 |
+
New sample rate
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
# downmix to mono
|
| 174 |
+
|
| 175 |
+
num_channels = waveform.shape[0]
|
| 176 |
+
if num_channels > 1:
|
| 177 |
+
if self.mono == "random":
|
| 178 |
+
channel = random.randint(0, num_channels - 1)
|
| 179 |
+
waveform = waveform[channel : channel + 1]
|
| 180 |
+
elif self.mono == "downmix":
|
| 181 |
+
waveform = np.mean(waveform, axis=0, keepdims=True)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
######## ここでずれる ##########
|
| 186 |
+
if (self.sample_rate is not None) and (self.sample_rate != sample_rate):
|
| 187 |
+
waveform = ailia.audio.resample(
|
| 188 |
+
waveform, org_sr=sample_rate, target_sr=self.sample_rate)
|
| 189 |
+
|
| 190 |
+
sample_rate = self.sample_rate
|
| 191 |
+
|
| 192 |
+
return waveform, sample_rate
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def get_num_samples(
    self, duration: float, sample_rate: Optional[int] = None
) -> int:
    """Deterministic number of samples for a given duration.

    Falls back to ``self.sample_rate`` when ``sample_rate`` is omitted.

    Raises
    ------
    ValueError
        When no sample rate is available from either source.
    """
    rate = sample_rate or self.sample_rate
    if rate is None:
        raise ValueError(
            "`sample_rate` must be provided to compute number of samples."
        )
    return math.floor(duration * rate)
|
| 208 |
+
|
| 209 |
+
def __call__(self, file: AudioFile) -> Tuple[np.ndarray, int]:
    """Obtain the full waveform of `file`.

    Parameters
    ----------
    file : AudioFile

    Returns
    -------
    waveform : (channel, time) np.ndarray
        Waveform
    sample_rate : int
        Sample rate

    See also
    --------
    AudioFile
    """
    file = self.validate_file(file)

    if "waveform" in file:
        waveform = file["waveform"]
        sample_rate = file["sample_rate"]

    else:
        # FIX: this read previously ran unconditionally, clobbering an
        # in-memory "waveform" and raising KeyError for files providing
        # "waveform" without "audio" (which validate_file allows).
        waveform, sample_rate = soundfile.read(file["audio"])

        # soundfile returns (time,) for mono and (time, channel) otherwise;
        # normalize to the (channel, time) layout used everywhere else.
        if waveform.ndim == 1:
            waveform = np.expand_dims(waveform, axis=0)
        else:
            waveform = waveform.T

    channel = file.get("channel", None)

    if channel is not None:
        waveform = waveform[channel : channel + 1]

    return self.downmix_and_resample(waveform, sample_rate)
|
| 247 |
+
|
| 248 |
+
def crop(
    self,
    file: AudioFile,
    segment: Segment,
    duration: Optional[float] = None,
    mode="raise",
) -> Tuple[np.ndarray, int]:
    """Fast excerpt loader: equivalent to ``self(file)`` followed by
    cropping, but only decodes the requested region when possible.

    Parameters
    ----------
    file : AudioFile
        Audio file.
    segment : `pyannote.core.Segment`
        Temporal segment to load.
    duration : float, optional
        Overrides `Segment` 'focus' duration and ensures that the number
        of returned frames is fixed (which might otherwise not be the
        case because of rounding errors).
    mode : {'raise', 'pad'}, optional
        Out-of-bounds behavior: 'raise' an error (default) or zero-'pad'.

    Returns
    -------
    waveform : (channel, time) np.ndarray
        Waveform
    sample_rate : int
        Sample rate
    """
    file = self.validate_file(file)

    # Determine the total frame count and sample rate, decoding samples
    # only when no cheaper metadata source is available.
    if "waveform" in file:
        waveform = file["waveform"]
        frames = waveform.shape[1]
        sample_rate = file["sample_rate"]

    elif "torchaudio.info" in file:
        meta = file["torchaudio.info"]
        frames = meta.num_frames
        sample_rate = meta.sample_rate

    else:
        decoded = soundfile.read(file["audio"])
        frames = decoded[0].shape[0]
        sample_rate = decoded[1]

    channel = file.get("channel", None)

    # Translate the requested chunk into frame indices.
    start_frame = math.floor(segment.start * sample_rate)

    if duration:
        num_frames = math.floor(duration * sample_rate)
        end_frame = start_frame + num_frames
    else:
        end_frame = math.floor(segment.end * sample_rate)
        num_frames = end_frame - start_frame

    if mode == "pad":
        # clamp to the valid range and remember how much padding is owed
        pad_start = -min(0, start_frame)
        pad_end = max(end_frame, frames) - frames
        start_frame = max(0, start_frame)
        end_frame = min(end_frame, frames)
        num_frames = end_frame - start_frame

    if "waveform" in file:
        data = file["waveform"][:, start_frame:end_frame]

    else:
        try:
            data, _ = soundfile.read(
                file["audio"], start=start_frame, frames=num_frames
            )
            # soundfile yields (time,) or (time, channel); go to (channel, time)
            data = np.expand_dims(data, axis=0) if data.ndim == 1 else data.T

        except RuntimeError:
            warnings.warn(
                f"torchaudio failed to seek-and-read in {file['audio']}: "
                f"loading the whole file instead."
            )
            waveform, sample_rate = self.__call__(file)
            data = waveform[:, start_frame:end_frame]

            # Cache the decoded waveform for next time: seek-and-read will
            # very likely fail again for this particular file.
            file["waveform"] = waveform
            file["sample_rate"] = sample_rate

    if channel is not None:
        data = data[channel : channel + 1, :]

    if mode == "pad":
        data = np.pad(data, ((0, 0), (pad_start, pad_end)))

    return self.downmix_and_resample(data, sample_rate)
|
ailia-models/code/pyannote_audio_utils/audio/core/pipeline.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2021 CNRS
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
# copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
import os
|
| 24 |
+
import warnings
|
| 25 |
+
from collections import OrderedDict
|
| 26 |
+
from collections.abc import Iterator
|
| 27 |
+
from functools import partial
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
from typing import Callable, Dict, List, Optional, Text, Union
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
import yaml
|
| 33 |
+
from importlib import import_module
|
| 34 |
+
from pyannote_audio_utils.database import ProtocolFile
|
| 35 |
+
from pyannote_audio_utils.pipeline import Pipeline as _Pipeline
|
| 36 |
+
|
| 37 |
+
from pyannote_audio_utils.audio import Audio, __version__
|
| 38 |
+
from pyannote_audio_utils.audio.core.inference import BaseInference
|
| 39 |
+
from pyannote_audio_utils.audio.core.io import AudioFile
|
| 40 |
+
|
| 41 |
+
PIPELINE_PARAMS_NAME = "config.yaml"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Pipeline(_Pipeline):
    """Base class for pretrained pyannote audio pipelines (ailia port)."""

    @classmethod
    def from_pretrained(
        cls,
        checkpoint_path: Union[Text, Path],
        hparams_file: Union[Text, Path] = None,
        use_auth_token: Union[Text, None] = None,
    ) -> "Pipeline":
        """Load a pipeline from a local ``config.yaml`` file.

        Parameters
        ----------
        checkpoint_path : Path or str
            Path to the pipeline configuration file (YAML).
        hparams_file : Path or str, optional
            Unused in this port; kept for interface compatibility.
        use_auth_token : str, optional
            Unused in this port (no hub download); kept for interface
            compatibility.

        Returns
        -------
        Pipeline
            Instantiated pipeline, frozen with ``config["params"]`` when
            present.
        """
        checkpoint_path = str(checkpoint_path)
        config_yml = checkpoint_path

        with open(config_yml, "r") as fp:
            config = yaml.load(fp, Loader=yaml.SafeLoader)

        # initialize pipeline: "pipeline.name" is a dotted path to its class
        pipeline_name = config["pipeline"]["name"]
        tokens = pipeline_name.split('.')
        module_name = '.'.join(tokens[:-1])
        class_name = tokens[-1]
        Klass = getattr(import_module(module_name), class_name)
        params = config["pipeline"].get("params", {})
        pipeline = Klass(**params)

        # freeze parameters
        if "params" in config:
            pipeline.instantiate(config["params"])

        return pipeline

    def __init__(self):
        super().__init__()
        # FIX: `Dict[str]` is an invalid one-parameter `typing.Dict`
        # subscription; annotations on attribute targets ARE evaluated at
        # runtime (PEP 526), so the original raised TypeError here.
        self._models: Dict[str, object] = OrderedDict()
        self._inferences: Dict[str, BaseInference] = OrderedDict()

    def __getattr__(self, name):
        """(Advanced) attribute getter

        Adds support for Model and Inference attributes,
        which are iterated over by Pipeline.to() method.

        See pyannote_audio_utils.pipeline.Pipeline.__getattr__.
        """
        if "_models" in self.__dict__:
            _models = self.__dict__["_models"]
            if name in _models:
                return _models[name]

        if "_inferences" in self.__dict__:
            _inferences = self.__dict__["_inferences"]
            if name in _inferences:
                return _inferences[name]

        return super().__getattr__(name)

    def __setattr__(self, name, value):
        """(Advanced) attribute setter

        Routes BaseInference values to the dedicated `_inferences`
        registry (removing any same-named entry from the other
        registries) so that Pipeline.to() can iterate over them.

        See pyannote_audio_utils.pipeline.Pipeline.__setattr__.
        """

        def remove_from(*dicts):
            # drop `name` from every registry that currently holds it
            for d in dicts:
                if name in d:
                    del d[name]

        _parameters = self.__dict__.get("_parameters")
        _instantiated = self.__dict__.get("_instantiated")
        _pipelines = self.__dict__.get("_pipelines")
        _models = self.__dict__.get("_models")
        _inferences = self.__dict__.get("_inferences")

        if isinstance(value, BaseInference):
            if _inferences is None:
                msg = "cannot assign inferences before Pipeline.__init__() call"
                raise AttributeError(msg)
            remove_from(self.__dict__, _models, _parameters, _instantiated, _pipelines)
            _inferences[name] = value
            return

        super().__setattr__(name, value)

    def __delattr__(self, name):
        # mirror __setattr__: registries take precedence over plain attributes
        if name in self._models:
            del self._models[name]

        elif name in self._inferences:
            del self._inferences[name]

        else:
            super().__delattr__(name)

    @staticmethod
    def setup_hook(file: AudioFile, hook: Optional[Callable] = None) -> Callable:
        """Return `hook` (or a no-op) partially applied to `file`."""

        def noop(*args, **kwargs):
            return

        return partial(hook or noop, file=file)

    def default_parameters(self):
        # subclasses may override to enable automatic instantiation
        raise NotImplementedError()

    def classes(self) -> Union[List, Iterator]:
        """Classes returned by the pipeline

        Returns
        -------
        classes : list of string or string iterator
            Finite list of strings when classes are known in advance
            (e.g. ["MALE", "FEMALE"] for gender classification), or
            infinite string iterator when they depend on the file
            (e.g. "SPEAKER_00", "SPEAKER_01", ... for speaker diarization)

        Usage
        -----
        >>> from collections.abc import Iterator
        >>> classes = pipeline.classes()
        >>> if isinstance(classes, Iterator):  # classes depend on the input file
        >>> if isinstance(classes, list):      # classes are known in advance

        """
        raise NotImplementedError()

    def __call__(self, file: AudioFile, **kwargs):
        """Apply the pipeline to `file`, auto-instantiating it with
        default parameters when possible."""
        if not self.instantiated:
            # instantiate with default parameters when available
            try:
                default_parameters = self.default_parameters()
            except NotImplementedError:
                raise RuntimeError(
                    "A pipeline must be instantiated with `pipeline.instantiate(parameters)` before it can be applied."
                )

            try:
                self.instantiate(default_parameters)
            except ValueError:
                raise RuntimeError(
                    "A pipeline must be instantiated with `pipeline.instantiate(paramaters)` before it can be applied. "
                    "Tried to use parameters provided by `pipeline.default_parameters()` but those are not compatible. "
                )

            warnings.warn(
                f"The pipeline has been automatically instantiated with {default_parameters}."
            )

        file = Audio.validate_file(file)

        if hasattr(self, "preprocessors"):
            file = ProtocolFile(file, lazy=self.preprocessors)

        return self.apply(file, **kwargs)
|
ailia-models/code/pyannote_audio_utils/audio/core/task.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2020- CNRS
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
# copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
from dataclasses import dataclass
|
| 27 |
+
from enum import Enum
|
| 28 |
+
from functools import cached_property, partial
|
| 29 |
+
from typing import Dict, List, Literal, Optional, Sequence, Text, Tuple, Union
|
| 30 |
+
|
| 31 |
+
import scipy.special
|
| 32 |
+
|
| 33 |
+
from pyannote_audio_utils.database.protocol.protocol import Scope, Subset
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
Subsets = list(Subset.__args__)
|
| 37 |
+
Scopes = list(Scope.__args__)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Type of machine learning problem
|
| 41 |
+
class Problem(Enum):
    """Type of machine learning problem addressed by a model."""

    BINARY_CLASSIFICATION = 0
    MONO_LABEL_CLASSIFICATION = 1
    MULTI_LABEL_CLASSIFICATION = 2
    REPRESENTATION = 3
    REGRESSION = 4
    # any other we could think of?
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# A task takes an audio chunk as input and returns
|
| 51 |
+
# either a temporal sequence of predictions
|
| 52 |
+
# or just one prediction for the whole audio chunk
|
| 53 |
+
class Resolution(Enum):
    """Temporal granularity of a model's output for an audio chunk.

    A task takes an audio chunk as input and returns either a temporal
    sequence of predictions or just one prediction for the whole chunk.
    """

    FRAME = 1  # model outputs a sequence of frames
    CHUNK = 2  # model outputs just one vector for the whole chunk
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class UnknownSpecificationsError(Exception):
    """Raised when task specifications are requested but not available."""
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@dataclass
class Specifications:
    """Static description of what a model ingests and produces."""

    problem: Problem
    resolution: Resolution

    # (maximum) chunk duration in seconds
    duration: float = 10.0

    # (for variable-duration tasks only) minimum chunk duration in seconds
    min_duration: Optional[float] = None

    # use that many seconds on the left- and rightmost parts of each chunk
    # to warm up the model (mostly useful for segmentation tasks): those
    # parts are processed but excluded from loss computation / score
    # aggregation. Defaults to 0. (i.e. no warm-up).
    warm_up: Optional[Tuple[float, float]] = (0.0, 0.0)

    # (for classification tasks only) list of classes
    classes: Optional[List[Text]] = None

    # (for powerset only) max number of simultaneous classes
    # (n choose k with k <= powerset_max_classes)
    powerset_max_classes: Optional[int] = 2

    # whether classes are permutation-invariant (e.g. diarization)
    permutation_invariant: bool = True

    @cached_property
    def powerset(self) -> bool:
        """Whether the model operates in powerset multi-class mode."""
        if self.powerset_max_classes is None:
            return False

        if self.problem != Problem.MONO_LABEL_CLASSIFICATION:
            raise ValueError(
                "`powerset_max_classes` only makes sense with multi-class classification problems."
            )

        return True

    @cached_property
    def num_powerset_classes(self) -> int:
        """Number of subsets of classes of size <= powerset_max_classes.

        e.g. with len(classes) = 3 and powerset_max_classes = 2:
        {}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2} -> 7
        """
        return int(
            sum(
                scipy.special.binom(len(self.classes), k)
                for k in range(self.powerset_max_classes + 1)
            )
        )

    def __len__(self):
        # a single Specifications behaves as a 1-element collection
        return 1

    def __iter__(self):
        yield self
|
| 125 |
+
|
ailia-models/code/pyannote_audio_utils/audio/pipelines/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2020-2022 CNRS
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
# copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
# from .multilabel import MultiLabelSegmentation
|
| 24 |
+
# from .overlapped_speech_detection import OverlappedSpeechDetection
|
| 25 |
+
# from .resegmentation import Resegmentation
|
| 26 |
+
from .speaker_diarization import SpeakerDiarization
|
| 27 |
+
# from .voice_activity_detection import VoiceActivityDetection
|
| 28 |
+
|
| 29 |
+
__all__ = [
|
| 30 |
+
# "VoiceActivityDetection",
|
| 31 |
+
# "OverlappedSpeechDetection",
|
| 32 |
+
"SpeakerDiarization",
|
| 33 |
+
# "Resegmentation",
|
| 34 |
+
# "MultiLabelSegmentation",
|
| 35 |
+
]
|
ailia-models/code/pyannote_audio_utils/audio/pipelines/clustering.py
ADDED
|
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# The MIT License (MIT)
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2021- CNRS
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in
|
| 13 |
+
# all copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
"""Clustering pipelines"""
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
import random
|
| 27 |
+
from enum import Enum
|
| 28 |
+
from typing import Optional, Tuple
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
from pyannote_audio_utils.core import SlidingWindow, SlidingWindowFeature
|
| 32 |
+
from pyannote_audio_utils.pipeline import Pipeline
|
| 33 |
+
from pyannote_audio_utils.pipeline.parameter import Categorical, Integer, Uniform
|
| 34 |
+
from scipy.cluster.hierarchy import fcluster, linkage
|
| 35 |
+
from scipy.optimize import linear_sum_assignment
|
| 36 |
+
from scipy.spatial.distance import cdist
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class BaseClustering(Pipeline):
|
| 40 |
+
def __init__(
    self,
    metric: str = "cosine",
    max_num_embeddings: int = 1000,
    constrained_assignment: bool = False,
):
    """Shared configuration for clustering pipelines.

    Parameters
    ----------
    metric : str, optional
        Distance metric used to compare embeddings (default "cosine").
    max_num_embeddings : int, optional
        Upper bound on the number of embeddings used for clustering.
    constrained_assignment : bool, optional
        Whether to force a one-to-one speaker/cluster assignment per chunk.
    """
    super().__init__()
    self.constrained_assignment = constrained_assignment
    self.max_num_embeddings = max_num_embeddings
    self.metric = metric
|
| 50 |
+
|
| 51 |
+
def set_num_clusters(
    self,
    num_embeddings: int,
    num_clusters: Optional[int] = None,
    min_clusters: Optional[int] = None,
    max_clusters: Optional[int] = None,
):
    """Reconcile requested cluster counts with the number of embeddings.

    Bounds are clamped to [1, num_embeddings]; an explicit
    ``num_clusters`` overrides both bounds, and when the bounds collapse
    to a single value it becomes the (fixed) number of clusters.

    Returns
    -------
    num_clusters : int or None
    min_clusters : int
    max_clusters : int

    Raises
    ------
    ValueError
        When the reconciled lower bound exceeds the upper bound.
    """
    lower = max(1, min(num_embeddings, num_clusters or min_clusters or 1))
    upper = max(1, min(num_embeddings, num_clusters or max_clusters or num_embeddings))

    if lower > upper:
        raise ValueError(
            f"min_clusters must be smaller than (or equal to) max_clusters "
            f"(here: min_clusters={lower:g} and max_clusters={upper:g})."
        )

    if lower == upper:
        num_clusters = lower

    return num_clusters, lower, upper
|
| 73 |
+
|
| 74 |
+
def filter_embeddings(
    self,
    embeddings: np.ndarray,
    segmentations: Optional[SlidingWindowFeature] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Drop NaN/inactive embeddings and subsample to at most
    ``self.max_num_embeddings``.

    Parameters
    ----------
    embeddings : (num_chunks, num_speakers, dimension) array
        Sequence of embeddings.
    segmentations : (num_chunks, num_frames, num_speakers) array
        Binary segmentations.

    Returns
    -------
    filtered_embeddings : (num_embeddings, dimension) array
    chunk_idx : (num_embeddings, ) array
    speaker_idx : (num_embeddings, ) array
    """
    # a (chunk, speaker) slot is kept when the speaker is active...
    is_active = np.sum(segmentations.data, axis=1) > 0
    # ...and its embedding extraction produced no NaN
    is_valid = ~np.any(np.isnan(embeddings), axis=2)

    chunk_idx, speaker_idx = np.where(is_active & is_valid)

    # random (but order-preserving) subsample when there are too many
    count = len(chunk_idx)
    if count > self.max_num_embeddings:
        picked = list(range(count))
        random.shuffle(picked)
        picked = sorted(picked[: self.max_num_embeddings])
        chunk_idx = chunk_idx[picked]
        speaker_idx = speaker_idx[picked]

    return embeddings[chunk_idx, speaker_idx], chunk_idx, speaker_idx
|
| 113 |
+
|
| 114 |
+
def constrained_argmax(self, soft_clusters: np.ndarray) -> np.ndarray:
|
| 115 |
+
soft_clusters = np.nan_to_num(soft_clusters, nan=np.nanmin(soft_clusters))
|
| 116 |
+
num_chunks, num_speakers, num_clusters = soft_clusters.shape
|
| 117 |
+
# num_chunks, num_speakers, num_clusters
|
| 118 |
+
|
| 119 |
+
hard_clusters = -2 * np.ones((num_chunks, num_speakers), dtype=np.int8)
|
| 120 |
+
|
| 121 |
+
for c, cost in enumerate(soft_clusters):
|
| 122 |
+
speakers, clusters = linear_sum_assignment(cost, maximize=True)
|
| 123 |
+
for s, k in zip(speakers, clusters):
|
| 124 |
+
hard_clusters[c, s] = k
|
| 125 |
+
|
| 126 |
+
return hard_clusters
|
| 127 |
+
|
| 128 |
+
def assign_embeddings(
|
| 129 |
+
self,
|
| 130 |
+
embeddings: np.ndarray,
|
| 131 |
+
train_chunk_idx: np.ndarray,
|
| 132 |
+
train_speaker_idx: np.ndarray,
|
| 133 |
+
train_clusters: np.ndarray,
|
| 134 |
+
constrained: bool = False,
|
| 135 |
+
):
|
| 136 |
+
"""Assign embeddings to the closest centroid
|
| 137 |
+
|
| 138 |
+
Cluster centroids are computed as the average of the train embeddings
|
| 139 |
+
previously assigned to them.
|
| 140 |
+
|
| 141 |
+
Parameters
|
| 142 |
+
----------
|
| 143 |
+
embeddings : (num_chunks, num_speakers, dimension)-shaped array
|
| 144 |
+
Complete set of embeddings.
|
| 145 |
+
train_chunk_idx : (num_embeddings,)-shaped array
|
| 146 |
+
train_speaker_idx : (num_embeddings,)-shaped array
|
| 147 |
+
Indices of subset of embeddings used for "training".
|
| 148 |
+
train_clusters : (num_embedding,)-shaped array
|
| 149 |
+
Clusters of the above subset
|
| 150 |
+
constrained : bool, optional
|
| 151 |
+
Use constrained_argmax, instead of (default) argmax.
|
| 152 |
+
|
| 153 |
+
Returns
|
| 154 |
+
-------
|
| 155 |
+
soft_clusters : (num_chunks, num_speakers, num_clusters)-shaped array
|
| 156 |
+
hard_clusters : (num_chunks, num_speakers)-shaped array
|
| 157 |
+
centroids : (num_clusters, dimension)-shaped array
|
| 158 |
+
Clusters centroids
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
# TODO: option to add a new (dummy) cluster in case num_clusters < max(frame_speaker_count)
|
| 162 |
+
|
| 163 |
+
num_clusters = np.max(train_clusters) + 1
|
| 164 |
+
num_chunks, num_speakers, dimension = embeddings.shape
|
| 165 |
+
|
| 166 |
+
train_embeddings = embeddings[train_chunk_idx, train_speaker_idx]
|
| 167 |
+
|
| 168 |
+
centroids = np.vstack(
|
| 169 |
+
[
|
| 170 |
+
np.mean(train_embeddings[train_clusters == k], axis=0)
|
| 171 |
+
for k in range(num_clusters)
|
| 172 |
+
]
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
e2k_distance = cdist(
|
| 176 |
+
embeddings.reshape([-1, dimension]),
|
| 177 |
+
centroids,
|
| 178 |
+
metric=self.metric
|
| 179 |
+
).reshape([num_chunks, num_speakers, -1])
|
| 180 |
+
|
| 181 |
+
soft_clusters = 2 - e2k_distance
|
| 182 |
+
|
| 183 |
+
# assign each embedding to the cluster with the most similar centroid
|
| 184 |
+
if constrained:
|
| 185 |
+
hard_clusters = self.constrained_argmax(soft_clusters)
|
| 186 |
+
else:
|
| 187 |
+
hard_clusters = np.argmax(soft_clusters, axis=2)
|
| 188 |
+
|
| 189 |
+
# NOTE: train_embeddings might be reassigned to a different cluster
|
| 190 |
+
# in the process. based on experiments, this seems to lead to better
|
| 191 |
+
# results than sticking to the original assignment.
|
| 192 |
+
|
| 193 |
+
return hard_clusters, soft_clusters, centroids
|
| 194 |
+
|
| 195 |
+
def __call__(
|
| 196 |
+
self,
|
| 197 |
+
embeddings: np.ndarray,
|
| 198 |
+
segmentations: Optional[SlidingWindowFeature] = None,
|
| 199 |
+
num_clusters: Optional[int] = None,
|
| 200 |
+
min_clusters: Optional[int] = None,
|
| 201 |
+
max_clusters: Optional[int] = None,
|
| 202 |
+
**kwargs,
|
| 203 |
+
) -> np.ndarray:
|
| 204 |
+
"""Apply clustering
|
| 205 |
+
|
| 206 |
+
Parameters
|
| 207 |
+
----------
|
| 208 |
+
embeddings : (num_chunks, num_speakers, dimension) array
|
| 209 |
+
Sequence of embeddings.
|
| 210 |
+
segmentations : (num_chunks, num_frames, num_speakers) array
|
| 211 |
+
Binary segmentations.
|
| 212 |
+
num_clusters : int, optional
|
| 213 |
+
Number of clusters, when known. Default behavior is to use
|
| 214 |
+
internal threshold hyper-parameter to decide on the number
|
| 215 |
+
of clusters.
|
| 216 |
+
min_clusters : int, optional
|
| 217 |
+
Minimum number of clusters. Has no effect when `num_clusters` is provided.
|
| 218 |
+
max_clusters : int, optional
|
| 219 |
+
Maximum number of clusters. Has no effect when `num_clusters` is provided.
|
| 220 |
+
|
| 221 |
+
Returns
|
| 222 |
+
-------
|
| 223 |
+
hard_clusters : (num_chunks, num_speakers) array
|
| 224 |
+
Hard cluster assignment (hard_clusters[c, s] = k means that sth speaker
|
| 225 |
+
of cth chunk is assigned to kth cluster)
|
| 226 |
+
soft_clusters : (num_chunks, num_speakers, num_clusters) array
|
| 227 |
+
Soft cluster assignment (the higher soft_clusters[c, s, k], the most likely
|
| 228 |
+
the sth speaker of cth chunk belongs to kth cluster)
|
| 229 |
+
centroids : (num_clusters, dimension) array
|
| 230 |
+
Centroid vectors of each cluster
|
| 231 |
+
"""
|
| 232 |
+
|
| 233 |
+
train_embeddings, train_chunk_idx, train_speaker_idx = self.filter_embeddings(
|
| 234 |
+
embeddings,
|
| 235 |
+
segmentations=segmentations,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
num_embeddings, _ = train_embeddings.shape
|
| 239 |
+
|
| 240 |
+
num_clusters, min_clusters, max_clusters = self.set_num_clusters(
|
| 241 |
+
num_embeddings,
|
| 242 |
+
num_clusters=num_clusters,
|
| 243 |
+
min_clusters=min_clusters,
|
| 244 |
+
max_clusters=max_clusters,
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
if max_clusters < 2:
|
| 248 |
+
# do NOT apply clustering when min_clusters = max_clusters = 1
|
| 249 |
+
num_chunks, num_speakers, _ = embeddings.shape
|
| 250 |
+
hard_clusters = np.zeros((num_chunks, num_speakers), dtype=np.int8)
|
| 251 |
+
soft_clusters = np.ones((num_chunks, num_speakers, 1))
|
| 252 |
+
centroids = np.mean(train_embeddings, axis=0, keepdims=True)
|
| 253 |
+
return hard_clusters, soft_clusters, centroids
|
| 254 |
+
|
| 255 |
+
train_clusters = self.cluster(
|
| 256 |
+
train_embeddings,
|
| 257 |
+
min_clusters,
|
| 258 |
+
max_clusters,
|
| 259 |
+
num_clusters=num_clusters,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
hard_clusters, soft_clusters, centroids = self.assign_embeddings(
|
| 263 |
+
embeddings,
|
| 264 |
+
train_chunk_idx,
|
| 265 |
+
train_speaker_idx,
|
| 266 |
+
train_clusters,
|
| 267 |
+
constrained=self.constrained_assignment,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
return hard_clusters, soft_clusters, centroids
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class AgglomerativeClustering(BaseClustering):
|
| 274 |
+
"""Agglomerative clustering
|
| 275 |
+
|
| 276 |
+
Parameters
|
| 277 |
+
----------
|
| 278 |
+
metric : {"cosine", "euclidean", ...}, optional
|
| 279 |
+
Distance metric to use. Defaults to "cosine".
|
| 280 |
+
|
| 281 |
+
Hyper-parameters
|
| 282 |
+
----------------
|
| 283 |
+
method : {"average", "centroid", "complete", "median", "single", "ward"}
|
| 284 |
+
Linkage method.
|
| 285 |
+
threshold : float in range [0.0, 2.0]
|
| 286 |
+
Clustering threshold.
|
| 287 |
+
min_cluster_size : int in range [1, 20]
|
| 288 |
+
Minimum cluster size
|
| 289 |
+
"""
|
| 290 |
+
|
| 291 |
+
def __init__(
|
| 292 |
+
self,
|
| 293 |
+
metric: str = "cosine",
|
| 294 |
+
max_num_embeddings: int = np.inf,
|
| 295 |
+
constrained_assignment: bool = False,
|
| 296 |
+
):
|
| 297 |
+
super().__init__(
|
| 298 |
+
metric=metric,
|
| 299 |
+
max_num_embeddings=max_num_embeddings,
|
| 300 |
+
constrained_assignment=constrained_assignment,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
self.threshold = Uniform(0.0, 2.0) # assume unit-normalized embeddings
|
| 304 |
+
self.method = Categorical(
|
| 305 |
+
["average", "centroid", "complete", "median", "single", "ward", "weighted"]
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
# minimum cluster size
|
| 309 |
+
self.min_cluster_size = Integer(1, 20)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def cluster(
|
| 313 |
+
self,
|
| 314 |
+
embeddings: np.ndarray,
|
| 315 |
+
min_clusters: int,
|
| 316 |
+
max_clusters: int,
|
| 317 |
+
num_clusters: Optional[int] = None,
|
| 318 |
+
):
|
| 319 |
+
"""
|
| 320 |
+
|
| 321 |
+
Parameters
|
| 322 |
+
----------
|
| 323 |
+
embeddings : (num_embeddings, dimension) array
|
| 324 |
+
Embeddings
|
| 325 |
+
min_clusters : int
|
| 326 |
+
Minimum number of clusters
|
| 327 |
+
max_clusters : int
|
| 328 |
+
Maximum number of clusters
|
| 329 |
+
num_clusters : int, optional
|
| 330 |
+
Actual number of clusters. Default behavior is to estimate it based
|
| 331 |
+
on values provided for `min_clusters`, `max_clusters`, and `threshold`.
|
| 332 |
+
|
| 333 |
+
Returns
|
| 334 |
+
-------
|
| 335 |
+
clusters : (num_embeddings, ) array
|
| 336 |
+
0-indexed cluster indices.
|
| 337 |
+
"""
|
| 338 |
+
|
| 339 |
+
num_embeddings, _ = embeddings.shape
|
| 340 |
+
|
| 341 |
+
# heuristic to reduce self.min_cluster_size when num_embeddings is very small
|
| 342 |
+
# (0.1 value is kind of arbitrary, though)
|
| 343 |
+
min_cluster_size = min(
|
| 344 |
+
self.min_cluster_size, max(1, round(0.1 * num_embeddings))
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
# linkage function will complain when there is just one embedding to cluster
|
| 349 |
+
if num_embeddings == 1:
|
| 350 |
+
return np.zeros((1,), dtype=np.uint8)
|
| 351 |
+
|
| 352 |
+
# centroid, median, and Ward method only support "euclidean" metric
|
| 353 |
+
# therefore we unit-normalize embeddings to somehow make them "euclidean"
|
| 354 |
+
if self.metric == "cosine" and self.method in ["centroid", "median", "ward"]:
|
| 355 |
+
with np.errstate(divide="ignore", invalid="ignore"):
|
| 356 |
+
embeddings /= np.linalg.norm(embeddings, axis=-1, keepdims=True)
|
| 357 |
+
dendrogram: np.ndarray = linkage(
|
| 358 |
+
embeddings, method=self.method, metric="euclidean"
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
# other methods work just fine with any metric
|
| 362 |
+
else:
|
| 363 |
+
dendrogram: np.ndarray = linkage(
|
| 364 |
+
embeddings, method=self.method, metric=self.metric
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
# apply the predefined threshold
|
| 368 |
+
clusters = fcluster(dendrogram, self.threshold, criterion="distance") - 1
|
| 369 |
+
|
| 370 |
+
# split clusters into two categories based on their number of items:
|
| 371 |
+
# large clusters vs. small clusters
|
| 372 |
+
cluster_unique, cluster_counts = np.unique(
|
| 373 |
+
clusters,
|
| 374 |
+
return_counts=True,
|
| 375 |
+
)
|
| 376 |
+
large_clusters = cluster_unique[cluster_counts >= min_cluster_size]
|
| 377 |
+
num_large_clusters = len(large_clusters)
|
| 378 |
+
|
| 379 |
+
# force num_clusters to min_clusters in case the actual number is too small
|
| 380 |
+
if num_large_clusters < min_clusters:
|
| 381 |
+
num_clusters = min_clusters
|
| 382 |
+
|
| 383 |
+
# force num_clusters to max_clusters in case the actual number is too large
|
| 384 |
+
elif num_large_clusters > max_clusters:
|
| 385 |
+
num_clusters = max_clusters
|
| 386 |
+
|
| 387 |
+
# look for perfect candidate if necessary
|
| 388 |
+
if num_clusters is not None and num_large_clusters != num_clusters:
|
| 389 |
+
# switch stopping criterion from "inter-cluster distance" stopping to "iteration index"
|
| 390 |
+
_dendrogram = np.copy(dendrogram)
|
| 391 |
+
_dendrogram[:, 2] = np.arange(num_embeddings - 1)
|
| 392 |
+
|
| 393 |
+
best_iteration = num_embeddings - 1
|
| 394 |
+
best_num_large_clusters = 1
|
| 395 |
+
|
| 396 |
+
# traverse the dendrogram by going further and further away
|
| 397 |
+
# from the "optimal" threshold
|
| 398 |
+
|
| 399 |
+
for iteration in np.argsort(np.abs(dendrogram[:, 2] - self.threshold)):
|
| 400 |
+
# only consider iterations that might have resulted
|
| 401 |
+
# in changing the number of (large) clusters
|
| 402 |
+
new_cluster_size = _dendrogram[iteration, 3]
|
| 403 |
+
if new_cluster_size < min_cluster_size:
|
| 404 |
+
continue
|
| 405 |
+
|
| 406 |
+
# estimate number of large clusters at considered iteration
|
| 407 |
+
clusters = fcluster(_dendrogram, iteration, criterion="distance") - 1
|
| 408 |
+
cluster_unique, cluster_counts = np.unique(clusters, return_counts=True)
|
| 409 |
+
large_clusters = cluster_unique[cluster_counts >= min_cluster_size]
|
| 410 |
+
num_large_clusters = len(large_clusters)
|
| 411 |
+
|
| 412 |
+
# keep track of iteration that leads to the number of large clusters
|
| 413 |
+
# as close as possible to the target number of clusters.
|
| 414 |
+
if abs(num_large_clusters - num_clusters) < abs(
|
| 415 |
+
best_num_large_clusters - num_clusters
|
| 416 |
+
):
|
| 417 |
+
best_iteration = iteration
|
| 418 |
+
best_num_large_clusters = num_large_clusters
|
| 419 |
+
|
| 420 |
+
# stop traversing the dendrogram as soon as we found a good candidate
|
| 421 |
+
if num_large_clusters == num_clusters:
|
| 422 |
+
break
|
| 423 |
+
|
| 424 |
+
# re-apply best iteration in case we did not find a perfect candidate
|
| 425 |
+
if best_num_large_clusters != num_clusters:
|
| 426 |
+
clusters = (
|
| 427 |
+
fcluster(_dendrogram, best_iteration, criterion="distance") - 1
|
| 428 |
+
)
|
| 429 |
+
cluster_unique, cluster_counts = np.unique(clusters, return_counts=True)
|
| 430 |
+
large_clusters = cluster_unique[cluster_counts >= min_cluster_size]
|
| 431 |
+
num_large_clusters = len(large_clusters)
|
| 432 |
+
print(
|
| 433 |
+
f"Found only {num_large_clusters} clusters. Using a smaller value than {min_cluster_size} for `min_cluster_size` might help."
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
if num_large_clusters == 0:
|
| 437 |
+
clusters[:] = 0
|
| 438 |
+
return clusters
|
| 439 |
+
|
| 440 |
+
small_clusters = cluster_unique[cluster_counts < min_cluster_size]
|
| 441 |
+
if len(small_clusters) == 0:
|
| 442 |
+
return clusters
|
| 443 |
+
|
| 444 |
+
# re-assign each small cluster to the most similar large cluster based on their respective centroids
|
| 445 |
+
large_centroids = np.vstack(
|
| 446 |
+
[
|
| 447 |
+
np.mean(embeddings[clusters == large_k], axis=0)
|
| 448 |
+
for large_k in large_clusters
|
| 449 |
+
]
|
| 450 |
+
)
|
| 451 |
+
small_centroids = np.vstack(
|
| 452 |
+
[
|
| 453 |
+
np.mean(embeddings[clusters == small_k], axis=0)
|
| 454 |
+
for small_k in small_clusters
|
| 455 |
+
]
|
| 456 |
+
)
|
| 457 |
+
centroids_cdist = cdist(large_centroids, small_centroids, metric=self.metric)
|
| 458 |
+
for small_k, large_k in enumerate(np.argmin(centroids_cdist, axis=0)):
|
| 459 |
+
clusters[clusters == small_clusters[small_k]] = large_clusters[large_k]
|
| 460 |
+
|
| 461 |
+
# re-number clusters from 0 to num_large_clusters
|
| 462 |
+
_, clusters = np.unique(clusters, return_inverse=True)
|
| 463 |
+
return clusters
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
class Clustering(Enum):
|
| 467 |
+
AgglomerativeClustering = AgglomerativeClustering
|
| 468 |
+
|
ailia-models/code/pyannote_audio_utils/audio/pipelines/speaker_diarization.py
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# The MIT License (MIT)
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2021- CNRS
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in
|
| 13 |
+
# all copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
"""Speaker diarization pipelines"""
|
| 24 |
+
|
| 25 |
+
import functools
|
| 26 |
+
import itertools
|
| 27 |
+
import math
|
| 28 |
+
import textwrap
|
| 29 |
+
import warnings
|
| 30 |
+
import numpy as np
|
| 31 |
+
|
| 32 |
+
from typing import Callable, Optional, Text, Union, Mapping
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
|
| 35 |
+
from pyannote_audio_utils.core import Annotation, SlidingWindow, SlidingWindowFeature
|
| 36 |
+
from pyannote_audio_utils.pipeline.parameter import ParamDict, Uniform
|
| 37 |
+
from pyannote_audio_utils.audio import Audio, Inference, Pipeline
|
| 38 |
+
from pyannote_audio_utils.audio.core.io import AudioFile
|
| 39 |
+
from pyannote_audio_utils.audio.pipelines.clustering import Clustering
|
| 40 |
+
from pyannote_audio_utils.audio.pipelines.speaker_verification import ONNXWeSpeakerPretrainedSpeakerEmbedding
|
| 41 |
+
from pyannote_audio_utils.audio.pipelines.utils import SpeakerDiarizationMixin
|
| 42 |
+
|
| 43 |
+
AudioFile = Union[Text, Path, Mapping]
|
| 44 |
+
PipelineModel = Union[Text, Mapping]
|
| 45 |
+
|
| 46 |
+
def batchify(iterable, batch_size: int = 32, fillvalue=None):
|
| 47 |
+
"""Batchify iterable"""
|
| 48 |
+
# batchify('ABCDEFG', 3) --> ['A', 'B', 'C'] ['D', 'E', 'F'] [G, ]
|
| 49 |
+
args = [iter(iterable)] * batch_size
|
| 50 |
+
return itertools.zip_longest(*args, fillvalue=fillvalue)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class SpeakerDiarization(SpeakerDiarizationMixin, Pipeline):
|
| 54 |
+
"""Speaker diarization pipeline
|
| 55 |
+
|
| 56 |
+
Parameters
|
| 57 |
+
----------
|
| 58 |
+
segmentation : Model, str, or dict, optional
|
| 59 |
+
Pretrained segmentation model. Defaults to "pyannote_audio_utils/segmentation@2022.07".
|
| 60 |
+
See pyannote_audio_utils.audio.pipelines.utils.get_model for supported format.
|
| 61 |
+
segmentation_step: float, optional
|
| 62 |
+
The segmentation model is applied on a window sliding over the whole audio file.
|
| 63 |
+
`segmentation_step` controls the step of this window, provided as a ratio of its
|
| 64 |
+
duration. Defaults to 0.1 (i.e. 90% overlap between two consecuive windows).
|
| 65 |
+
embedding : Model, str, or dict, optional
|
| 66 |
+
Pretrained embedding model. Defaults to "pyannote_audio_utils/embedding@2022.07".
|
| 67 |
+
See pyannote_audio_utils.audio.pipelines.utils.get_model for supported format.
|
| 68 |
+
embedding_exclude_overlap : bool, optional
|
| 69 |
+
Exclude overlapping speech regions when extracting embeddings.
|
| 70 |
+
Defaults (False) to use the whole speech.
|
| 71 |
+
clustering : str, optional
|
| 72 |
+
Clustering algorithm. See pyannote_audio_utils.audio.pipelines.clustering.Clustering
|
| 73 |
+
for available options. Defaults to "AgglomerativeClustering".
|
| 74 |
+
segmentation_batch_size : int, optional
|
| 75 |
+
Batch size used for speaker segmentation. Defaults to 1.
|
| 76 |
+
embedding_batch_size : int, optional
|
| 77 |
+
Batch size used for speaker embedding. Defaults to 1.
|
| 78 |
+
der_variant : dict, optional
|
| 79 |
+
Optimize for a variant of diarization error rate.
|
| 80 |
+
Defaults to {"collar": 0.0, "skip_overlap": False}. This is used in `get_metric`
|
| 81 |
+
when instantiating the metric: GreedyDiarizationErrorRate(**der_variant).
|
| 82 |
+
use_auth_token : str, optional
|
| 83 |
+
When loading private huggingface.co models, set `use_auth_token`
|
| 84 |
+
to True or to a string containing your hugginface.co authentication
|
| 85 |
+
token that can be obtained by running `huggingface-cli login`
|
| 86 |
+
|
| 87 |
+
Usage
|
| 88 |
+
-----
|
| 89 |
+
# perform (unconstrained) diarization
|
| 90 |
+
>>> diarization = pipeline("/path/to/audio.wav")
|
| 91 |
+
|
| 92 |
+
# perform diarization, targetting exactly 4 speakers
|
| 93 |
+
>>> diarization = pipeline("/path/to/audio.wav", num_speakers=4)
|
| 94 |
+
|
| 95 |
+
# perform diarization, with at least 2 speakers and at most 10 speakers
|
| 96 |
+
>>> diarization = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10)
|
| 97 |
+
|
| 98 |
+
# perform diarization and get one representative embedding per speaker
|
| 99 |
+
>>> diarization, embeddings = pipeline("/path/to/audio.wav", return_embeddings=True)
|
| 100 |
+
>>> for s, speaker in enumerate(diarization.labels()):
|
| 101 |
+
... # embeddings[s] is the embedding of speaker `speaker`
|
| 102 |
+
|
| 103 |
+
Hyper-parameters
|
| 104 |
+
----------------
|
| 105 |
+
segmentation.threshold
|
| 106 |
+
segmentation.min_duration_off
|
| 107 |
+
clustering.???
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
def __init__(
|
| 111 |
+
self,
|
| 112 |
+
segmentation: PipelineModel = "pyannote_audio_utils/segmentation@2022.07",
|
| 113 |
+
segmentation_step: float = 0.1,
|
| 114 |
+
embedding: PipelineModel = "speechbrain/spkrec-ecapa-voxceleb@5c0be3875fda05e81f3c004ed8c7c06be308de1e",
|
| 115 |
+
embedding_exclude_overlap: bool = False,
|
| 116 |
+
clustering: str = "AgglomerativeClustering",
|
| 117 |
+
embedding_batch_size: int = 1,
|
| 118 |
+
segmentation_batch_size: int = 1,
|
| 119 |
+
args = None,
|
| 120 |
+
seg_path = None,
|
| 121 |
+
emb_path = None,
|
| 122 |
+
der_variant: dict = None,
|
| 123 |
+
use_auth_token: Union[Text, None] = None,
|
| 124 |
+
):
|
| 125 |
+
super().__init__()
|
| 126 |
+
|
| 127 |
+
model = segmentation
|
| 128 |
+
self.segmentation_step = segmentation_step
|
| 129 |
+
self.embedding = embedding
|
| 130 |
+
self.embedding_batch_size = embedding_batch_size
|
| 131 |
+
self.embedding_exclude_overlap = embedding_exclude_overlap
|
| 132 |
+
self.klustering = clustering
|
| 133 |
+
self.der_variant = der_variant or {"collar": 0.0, "skip_overlap": False}
|
| 134 |
+
|
| 135 |
+
segmentation_duration = 10.0
|
| 136 |
+
|
| 137 |
+
self._segmentation = Inference(
|
| 138 |
+
model,
|
| 139 |
+
duration=segmentation_duration,
|
| 140 |
+
step=self.segmentation_step * segmentation_duration,
|
| 141 |
+
skip_aggregation=True,
|
| 142 |
+
batch_size=segmentation_batch_size,
|
| 143 |
+
args=args,
|
| 144 |
+
seg_path=seg_path
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
self._frames: SlidingWindow = self._segmentation.example_output.frames
|
| 148 |
+
|
| 149 |
+
self.segmentation = ParamDict(
|
| 150 |
+
min_duration_off=Uniform(0.0, 1.0),
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
self._embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding(
|
| 154 |
+
self.embedding,
|
| 155 |
+
args=args,
|
| 156 |
+
emb_path=emb_path
|
| 157 |
+
)
|
| 158 |
+
self._audio = Audio(sample_rate=self._embedding.sample_rate, mono="downmix")
|
| 159 |
+
|
| 160 |
+
metric = self._embedding.metric
|
| 161 |
+
Klustering = Clustering[clustering]
|
| 162 |
+
|
| 163 |
+
self.clustering = Klustering.value(metric=metric)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def get_segmentations(self, file, hook=None) -> SlidingWindowFeature:
|
| 167 |
+
"""Apply segmentation model
|
| 168 |
+
|
| 169 |
+
Parameter
|
| 170 |
+
---------
|
| 171 |
+
file : AudioFile
|
| 172 |
+
hook : Optional[Callable]
|
| 173 |
+
|
| 174 |
+
Returns
|
| 175 |
+
-------
|
| 176 |
+
segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
if hook is not None:
|
| 180 |
+
hook = functools.partial(hook, "segmentation", None)
|
| 181 |
+
segmentations: SlidingWindowFeature = self._segmentation(file, hook=hook)
|
| 182 |
+
|
| 183 |
+
return segmentations
|
| 184 |
+
|
| 185 |
+
def get_embeddings(
|
| 186 |
+
self,
|
| 187 |
+
file,
|
| 188 |
+
binary_segmentations: SlidingWindowFeature,
|
| 189 |
+
exclude_overlap: bool = False,
|
| 190 |
+
hook: Optional[Callable] = None,
|
| 191 |
+
):
|
| 192 |
+
"""Extract embeddings for each (chunk, speaker) pair
|
| 193 |
+
|
| 194 |
+
Parameters
|
| 195 |
+
----------
|
| 196 |
+
file : AudioFile
|
| 197 |
+
binary_segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
|
| 198 |
+
Binarized segmentation.
|
| 199 |
+
exclude_overlap : bool, optional
|
| 200 |
+
Exclude overlapping speech regions when extracting embeddings.
|
| 201 |
+
In case non-overlapping speech is too short, use the whole speech.
|
| 202 |
+
hook: Optional[Callable]
|
| 203 |
+
Called during embeddings after every batch to report the progress
|
| 204 |
+
|
| 205 |
+
Returns
|
| 206 |
+
-------
|
| 207 |
+
embeddings : (num_chunks, num_speakers, dimension) array
|
| 208 |
+
"""
|
| 209 |
+
|
| 210 |
+
# when optimizing the hyper-parameters of this pipeline with frozen
|
| 211 |
+
# "segmentation.threshold", one can reuse the embeddings from the first trial,
|
| 212 |
+
# bringing a massive speed up to the optimization process (and hence allowing to use
|
| 213 |
+
# a larger search space).
|
| 214 |
+
|
| 215 |
+
duration = binary_segmentations.sliding_window.duration
|
| 216 |
+
num_chunks, num_frames, num_speakers = binary_segmentations.data.shape
|
| 217 |
+
|
| 218 |
+
if exclude_overlap:
|
| 219 |
+
|
| 220 |
+
# minimum number of samples needed to extract an embedding
|
| 221 |
+
# (a lower number of samples would result in an error)
|
| 222 |
+
min_num_samples = self._embedding.min_num_samples
|
| 223 |
+
|
| 224 |
+
# corresponding minimum number of frames
|
| 225 |
+
num_samples = duration * self._embedding.sample_rate
|
| 226 |
+
min_num_frames = math.ceil(num_frames * min_num_samples / num_samples)
|
| 227 |
+
|
| 228 |
+
# zero-out frames with overlapping speech
|
| 229 |
+
clean_frames = 1.0 * (
|
| 230 |
+
np.sum(binary_segmentations.data, axis=2, keepdims=True) < 2
|
| 231 |
+
)
|
| 232 |
+
clean_segmentations = SlidingWindowFeature(
|
| 233 |
+
binary_segmentations.data * clean_frames,
|
| 234 |
+
binary_segmentations.sliding_window,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
else:
|
| 238 |
+
min_num_frames = -1
|
| 239 |
+
clean_segmentations = SlidingWindowFeature(
|
| 240 |
+
binary_segmentations.data, binary_segmentations.sliding_window
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
def iter_waveform_and_mask():
|
| 244 |
+
for (chunk, masks), (_, clean_masks) in zip(binary_segmentations, clean_segmentations):
|
| 245 |
+
# chunk: Segment(t, t + duration)
|
| 246 |
+
# masks: (num_frames, local_num_speakers) np.ndarray
|
| 247 |
+
|
| 248 |
+
waveform, _ = self._audio.crop(
|
| 249 |
+
file,
|
| 250 |
+
chunk,
|
| 251 |
+
duration=duration,
|
| 252 |
+
mode="pad",
|
| 253 |
+
)
|
| 254 |
+
# waveform: (1, num_samples) torch.Tensor
|
| 255 |
+
|
| 256 |
+
# mask may contain NaN (in case of partial stitching)
|
| 257 |
+
masks = np.nan_to_num(masks, nan=0.0).astype(np.float32)
|
| 258 |
+
clean_masks = np.nan_to_num(clean_masks, nan=0.0).astype(np.float32)
|
| 259 |
+
|
| 260 |
+
for mask, clean_mask in zip(masks.T, clean_masks.T):
|
| 261 |
+
# mask: (num_frames, ) np.ndarray
|
| 262 |
+
|
| 263 |
+
if np.sum(clean_mask) > min_num_frames:
|
| 264 |
+
used_mask = clean_mask
|
| 265 |
+
else:
|
| 266 |
+
used_mask = mask
|
| 267 |
+
|
| 268 |
+
# yield waveform[None], torch.from_numpy(used_mask)[None]
|
| 269 |
+
yield waveform[None], used_mask[None]
|
| 270 |
+
|
| 271 |
+
# w: (1, 1, num_samples) torch.Tensor
|
| 272 |
+
# m: (1, num_frames) torch.Tensor
|
| 273 |
+
|
| 274 |
+
batches = batchify(
|
| 275 |
+
iter_waveform_and_mask(),
|
| 276 |
+
batch_size=self.embedding_batch_size,
|
| 277 |
+
fillvalue=(None, None),
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
batch_count = math.ceil(num_chunks * num_speakers / self.embedding_batch_size)
|
| 282 |
+
|
| 283 |
+
embedding_batches = []
|
| 284 |
+
|
| 285 |
+
if hook is not None:
|
| 286 |
+
hook("embeddings", None, total=batch_count, completed=0)
|
| 287 |
+
|
| 288 |
+
for i, batch in enumerate(batches, 1):
|
| 289 |
+
waveforms, masks = zip(*filter(lambda b: b[0] is not None, batch))
|
| 290 |
+
|
| 291 |
+
waveform_batch = np.vstack(waveforms)
|
| 292 |
+
# (batch_size, 1, num_samples) torch.Tensor
|
| 293 |
+
|
| 294 |
+
mask_batch = np.vstack(masks)
|
| 295 |
+
# (batch_size, num_frames) torch.Tensor
|
| 296 |
+
|
| 297 |
+
embedding_batch: np.ndarray = self._embedding(
|
| 298 |
+
waveform_batch, masks=mask_batch
|
| 299 |
+
)
|
| 300 |
+
# (batch_size, dimension) np.ndarray
|
| 301 |
+
|
| 302 |
+
embedding_batches.append(embedding_batch)
|
| 303 |
+
|
| 304 |
+
if hook is not None:
|
| 305 |
+
hook("embeddings", embedding_batch, total=batch_count, completed=i)
|
| 306 |
+
|
| 307 |
+
embedding_batches = np.vstack(embedding_batches)
|
| 308 |
+
embeddings = embedding_batches.reshape([num_chunks, -1 , embedding_batches.shape[-1]])
|
| 309 |
+
|
| 310 |
+
return embeddings
|
| 311 |
+
|
| 312 |
+
def reconstruct(
    self,
    segmentations: SlidingWindowFeature,
    hard_clusters: np.ndarray,
    count: SlidingWindowFeature,
) -> SlidingWindowFeature:
    """Build final discrete diarization out of clustered segmentation

    Parameters
    ----------
    segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
        Raw speaker segmentation.
    hard_clusters : (num_chunks, num_speakers) array
        Output of clustering step.
    count : (total_num_frames, 1) SlidingWindowFeature
        Instantaneous number of active speakers.

    Returns
    -------
    discrete_diarization : SlidingWindowFeature
        Discrete (0s and 1s) diarization.
    """

    num_chunks, num_frames, local_num_speakers = segmentations.data.shape

    # hard_clusters contains, for each (chunk, local speaker), the index of
    # its global cluster; -2 marks speakers that were never active
    num_clusters = np.max(hard_clusters) + 1

    # start from NaN so that frames never assigned to a cluster stay
    # "missing" for downstream aggregation
    # (np.nan, not np.NAN: the NAN alias was removed in NumPy 2.0)
    clustered_segmentations = np.nan * np.zeros(
        (num_chunks, num_frames, num_clusters)
    )

    for c, (cluster, (chunk, segmentation)) in enumerate(
        zip(hard_clusters, segmentations)
    ):
        # cluster is (local_num_speakers, )-shaped
        # segmentation is (num_frames, local_num_speakers)-shaped
        for k in np.unique(cluster):
            if k == -2:
                # inactive speaker: contributes nothing to any cluster
                continue

            # merge all local speakers mapped to cluster k
            # TODO: can we do better than this max here?
            clustered_segmentations[c, :, k] = np.max(
                segmentation[:, cluster == k], axis=1
            )

    clustered_segmentations = SlidingWindowFeature(
        clustered_segmentations, segmentations.sliding_window
    )

    return self.to_diarization(clustered_segmentations, count)
|
| 361 |
+
|
| 362 |
+
def apply(
    self,
    file: AudioFile,
    num_speakers: int = None,
    min_speakers: int = None,
    max_speakers: int = None,
    return_embeddings: bool = False,
    hook: Optional[Callable] = None,
) -> Annotation:
    """Apply speaker diarization

    Parameters
    ----------
    file : AudioFile
        Processed file.
    num_speakers : int, optional
        Number of speakers, when known.
    min_speakers : int, optional
        Minimum number of speakers. Has no effect when `num_speakers` is provided.
    max_speakers : int, optional
        Maximum number of speakers. Has no effect when `num_speakers` is provided.
    return_embeddings : bool, optional
        Return representative speaker embeddings.
    hook : callable, optional
        Callback called after each major steps of the pipeline as follows:
            hook(step_name,      # human-readable name of current step
                 step_artefact,  # artifact generated by current step
                 file=file)      # file being processed
        Time-consuming steps call `hook` multiple times with the same `step_name`
        and additional `completed` and `total` keyword arguments usable to track
        progress of current step.

    Returns
    -------
    diarization : Annotation
        Speaker diarization
    embeddings : np.array, optional
        Representative speaker embeddings such that `embeddings[i]` is the
        speaker embedding for i-th speaker in diarization.labels().
        Only returned when `return_embeddings` is True.
    """

    # setup hook (e.g. for debugging purposes)
    hook = self.setup_hook(file, hook=hook)

    # validate / resolve the requested speaker-count bounds
    num_speakers, min_speakers, max_speakers = self.set_num_speakers(
        num_speakers=num_speakers,
        min_speakers=min_speakers,
        max_speakers=max_speakers,
    )

    segmentations = self.get_segmentations(file, hook=hook)
    hook("segmentation", segmentations)
    # shape: (num_chunks, num_frames, local_num_speakers)

    # binarize segmentation
    # NOTE(review): no actual binarization happens here -- the raw scores are
    # forwarded as-is; confirm this is intentional for this model's output
    binarized_segmentations = segmentations

    # estimate frame-level number of instantaneous speakers
    count = self.speaker_count(
        binarized_segmentations,
        frames=self._frames,
        warm_up=(0.0, 0.0),
    )
    hook("speaker_counting", count)
    # shape: (num_frames, 1)
    # dtype: int

    # exit early when no speaker is ever active
    if np.nanmax(count.data) == 0.0:
        diarization = Annotation(uri=file["uri"])
        if return_embeddings:
            return diarization, np.zeros((0, self._embedding.dimension))

        return diarization

    embeddings = self.get_embeddings(
        file,
        binarized_segmentations,
        exclude_overlap=self.embedding_exclude_overlap,
        hook=hook,
    )

    hook("embeddings", embeddings)
    # shape: (num_chunks, local_num_speakers, dimension)

    hard_clusters, _, centroids = self.clustering(
        embeddings=embeddings,
        segmentations=binarized_segmentations,
        num_clusters=num_speakers,
        min_clusters=min_speakers,
        max_clusters=max_speakers,
        file=file,  # <== for oracle clustering
        frames=self._frames,  # <== for oracle clustering
    )
    # hard_clusters: (num_chunks, num_speakers)
    # centroids: (num_speakers, dimension)

    # number of detected clusters is the number of different speakers
    num_different_speakers = np.max(hard_clusters) + 1

    # detected number of speakers can still be out of bounds
    # (specifically, lower than `min_speakers`), since there could be too few embeddings
    # to make enough clusters with a given minimum cluster size.
    if num_different_speakers < min_speakers or num_different_speakers > max_speakers:
        warnings.warn(textwrap.dedent(
            f"""
            The detected number of speakers ({num_different_speakers}) is outside
            the given bounds [{min_speakers}, {max_speakers}]. This can happen if the
            given audio file is too short to contain {min_speakers} or more speakers.
            Try to lower the desired minimal number of speakers.
            """
        ))

    # during counting, we could possibly overcount the number of instantaneous
    # speakers due to segmentation errors, so we cap the maximum instantaneous number
    # of speakers by the `max_speakers` value
    count.data = np.minimum(count.data, max_speakers).astype(np.int8)

    # reconstruct discrete diarization from raw hard clusters

    # keep track of inactive speakers (never active anywhere in their chunk)
    inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0
    # shape: (num_chunks, num_speakers)

    # -2 is the sentinel `reconstruct` uses to skip inactive speakers
    hard_clusters[inactive_speakers] = -2
    discrete_diarization = self.reconstruct(
        segmentations,
        hard_clusters,
        count,
    )
    hook("discrete_diarization", discrete_diarization)

    # convert to continuous diarization
    diarization = self.to_annotation(
        discrete_diarization,
        min_duration_on=0.0,
        min_duration_off=self.segmentation.min_duration_off,
    )
    diarization.uri = file["uri"]

    # at this point, `diarization` speaker labels are integers
    # from 0 to `num_speakers - 1`, aligned with `centroids` rows.

    if "annotation" in file and file["annotation"]:
        # when reference is available, use it to map hypothesized speakers
        # to reference speakers (this makes later error analysis easier
        # but does not modify the actual output of the diarization pipeline)
        _, mapping = self.optimal_mapping(
            file["annotation"], diarization, return_mapping=True
        )

        # in case there are more speakers in the hypothesis than in
        # the reference, those extra speakers are missing from `mapping`.
        # we add them back here
        mapping = {key: mapping.get(key, key) for key in diarization.labels()}

    else:
        # when reference is not available, rename hypothesized speakers
        # to human-readable SPEAKER_00, SPEAKER_01, ...
        mapping = {
            label: expected_label
            for label, expected_label in zip(diarization.labels(), self.classes())
        }

    diarization = diarization.rename_labels(mapping=mapping)
    # at this point, `diarization` speaker labels are strings (or mix of
    # strings and integers when reference is available and some hypothesis
    # speakers are not present in the reference)
    if not return_embeddings:
        return diarization

    # this can happen when we use OracleClustering
    if centroids is None:
        return diarization, None

    # The number of centroids may be smaller than the number of speakers
    # in the annotation. This can happen if the number of active speakers
    # obtained from `speaker_count` for some frames is larger than the number
    # of clusters obtained from `clustering`. In this case, we append zero embeddings
    # for extra speakers
    if len(diarization.labels()) > centroids.shape[0]:
        centroids = np.pad(centroids, ((0, len(diarization.labels()) - centroids.shape[0]), (0, 0)))

    # re-order centroids so that they match
    # the order given by diarization.labels()
    inverse_mapping = {label: index for index, label in mapping.items()}
    centroids = centroids[
        [inverse_mapping[label] for label in diarization.labels()]
    ]

    return diarization, centroids
|
ailia-models/code/pyannote_audio_utils/audio/pipelines/speaker_verification.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2021 CNRS
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
# copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
from functools import cached_property
|
| 25 |
+
from typing import Optional, Text, Union, Mapping
|
| 26 |
+
|
| 27 |
+
import numpy as np
|
| 28 |
+
import ailia
|
| 29 |
+
|
| 30 |
+
from pyannote_audio_utils.audio.pipelines.utils.kaldifeat import compute_fbank_feats
|
| 31 |
+
from pyannote_audio_utils.audio.core.inference import BaseInference
|
| 32 |
+
|
| 33 |
+
PipelineModel = Union[Text, Mapping]
|
| 34 |
+
|
| 35 |
+
class ONNXWeSpeakerPretrainedSpeakerEmbedding(BaseInference):
    """Pretrained WeSpeaker speaker embedding (onnxruntime or ailia backend)

    Parameters
    ----------
    embedding : str
        Path to WeSpeaker pretrained speaker embedding (ONNX weight file).
    args : namespace
        Runtime options. `args.onnx` selects the onnxruntime backend,
        otherwise ailia is used with environment `args.env_id`.
    emb_path : str, optional
        Path to the ailia model definition (used only when `args.onnx` is False).

    Usage
    -----
    >>> get_embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding("hbredin/wespeaker-voxceleb-resnet34-LM")
    >>> assert waveforms.ndim == 3
    >>> batch_size, num_channels, num_samples = waveforms.shape
    >>> assert num_channels == 1
    >>> embeddings = get_embedding(waveforms)
    >>> assert embeddings.ndim == 2
    >>> assert embeddings.shape[0] == batch_size

    >>> assert binary_masks.ndim == 1
    >>> assert binary_masks.shape[0] == batch_size
    >>> embeddings = get_embedding(waveforms, masks=binary_masks)
    """

    def __init__(
        self,
        embedding: Text = "hbredin/wespeaker-voxceleb-resnet34-LM",
        args=None,
        emb_path=None,
    ):
        super().__init__()

        self.embedding = embedding

        # NOTE(review): `args` defaults to None but `args.onnx` is accessed
        # unconditionally -- callers must always provide it; confirm intent.
        if args.onnx:
            # onnxruntime backend (import lazily so ailia-only setups work)
            import onnxruntime as ort

            providers = ["CPUExecutionProvider", ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"})]

            sess_options = ort.SessionOptions()
            sess_options.inter_op_num_threads = 1
            sess_options.intra_op_num_threads = 1
            self.session_ = ort.InferenceSession(
                embedding, sess_options=sess_options, providers=providers
            )
        else:
            # ailia backend
            self.session_ = ailia.Net(emb_path, weight=embedding, env_id=args.env_id)

        self.args = args

    @cached_property
    def sample_rate(self) -> int:
        """Expected input sample rate (Hz)."""
        return 16000

    @cached_property
    def dimension(self) -> int:
        """Embedding dimension, probed once with a dummy 1-second waveform."""
        dummy_waveforms = np.random.rand(1, 1, 16000)
        features = self.compute_fbank(dummy_waveforms)

        embeddings = self._infer(features)

        _, dimension = embeddings.shape
        return dimension

    @cached_property
    def metric(self) -> str:
        """Distance metric the embeddings are meant to be compared with."""
        return "cosine"

    @cached_property
    def min_num_samples(self) -> int:
        """Smallest number of samples that yields NaN-free embeddings.

        Found by binary search between 2 samples and half a second: shorter
        inputs either fail fbank extraction (AssertionError) or produce NaNs.
        """
        lower, upper = 2, round(0.5 * self.sample_rate)
        middle = (lower + upper) // 2
        while lower + 1 < upper:
            try:
                features = self.compute_fbank(np.random.randn(1, 1, middle))

            except AssertionError:
                # too short for even one fbank frame
                lower = middle
                middle = (lower + upper) // 2
                continue

            embeddings = self._infer(features)

            if np.any(np.isnan(embeddings)):
                lower = middle
            else:
                upper = middle
            middle = (lower + upper) // 2

        return upper

    @cached_property
    def min_num_frames(self) -> int:
        """Number of fbank frames produced by `min_num_samples` samples."""
        return self.compute_fbank(np.random.randn(1, 1, self.min_num_samples)).shape[1]

    def _infer(self, features: np.ndarray) -> np.ndarray:
        """Run the embedding model on a (batch, num_frames, num_mel_bins) batch."""
        if self.args.onnx:
            return self.session_.run(
                output_names=["embs"], input_feed={"feats": features}
            )[0]
        return self.session_.predict([features])[0]

    def compute_fbank(
        self,
        waveforms: np.ndarray,
        num_mel_bins: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        dither: float = 0.0,
    ) -> np.ndarray:
        """Extract fbank features

        Parameters
        ----------
        waveforms : (batch_size, num_channels, num_samples)

        Returns
        -------
        fbank : (batch_size, num_frames, num_mel_bins)

        Source: https://github.com/wenet-e2e/wespeaker/blob/45941e7cba2c3ea99e232d02bedf617fc71b0dad/wespeaker/bin/infer_onnx.py#L30C1-L50
        """

        # scale [-1, 1] float audio to 16-bit integer range, as WeSpeaker expects
        waveforms = waveforms * (1 << 15)

        # NOTE: results deviate slightly from the reference implementation here
        features_numpy = np.stack([compute_fbank_feats(
            waveform=waveform[0],
            num_mel_bins=num_mel_bins,
            frame_length=frame_length,
            frame_shift=frame_shift,
            dither=dither,
            sample_frequency=self.sample_rate,
            window_type="hamming",
            use_energy=False,
        ) for waveform in waveforms])

        features = features_numpy.astype(np.float32)

        # cepstral mean normalization (per utterance, per mel bin)
        return features - np.mean(features, axis=1, keepdims=True)

    def __call__(
        self, waveforms: np.ndarray, masks: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """

        Parameters
        ----------
        waveforms : (batch_size, num_channels, num_samples)
            Only num_channels == 1 is supported.
        masks : (batch_size, num_frames), optional
            Soft speaker masks; when omitted, each chunk is embedded whole.

        Returns
        -------
        embeddings : (batch_size, dimension)
            NaN rows mark chunks whose masked portion was too short to embed.
        """

        batch_size, num_channels, num_samples = waveforms.shape
        assert num_channels == 1

        features = self.compute_fbank(waveforms)
        _, num_frames, _ = features.shape

        # fix: `masks` is Optional but was dereferenced unconditionally,
        # crashing on masks=None; embed whole chunks in that case
        if masks is None:
            return self._infer(features)

        batch_size_masks, _ = masks.shape
        assert batch_size == batch_size_masks

        def interpolate_numpy(input_array, size):
            # nearest-lower-neighbor resampling along the frame axis
            # (row count is preserved, so the row mapping is the identity);
            # vectorized replacement of the original per-element loops
            num_cols = input_array.shape[1]
            col_idx = np.floor(np.arange(size) * num_cols / size).astype(int)
            return input_array[:, col_idx].astype(float)

        # resample masks to the fbank frame resolution, then binarize
        imasks = interpolate_numpy(masks, size=num_frames)
        imasks = imasks > 0.5

        # np.nan (np.NAN was removed in NumPy 2.0)
        embeddings = np.nan * np.zeros((batch_size, self.dimension))

        for f, (feature, imask) in enumerate(zip(features, imasks)):
            masked_feature = feature[imask]
            if masked_feature.shape[0] < self.min_num_frames:
                # too few frames survive the mask: leave NaN row
                continue

            embeddings[f] = self._infer(masked_feature[None])[0]

        return embeddings
|
| 249 |
+
|
ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2022- CNRS
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
# copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
from .diarization import SpeakerDiarizationMixin
|
| 24 |
+
# from .getter import (
|
| 25 |
+
# PipelineAugmentation,
|
| 26 |
+
# PipelineInference,
|
| 27 |
+
# PipelineModel,
|
| 28 |
+
# get_augmentation,
|
| 29 |
+
# get_devices,
|
| 30 |
+
# get_inference,
|
| 31 |
+
# get_model,
|
| 32 |
+
# )
|
| 33 |
+
# from .oracle import oracle_segmentation
|
| 34 |
+
|
| 35 |
+
__all__ = [
|
| 36 |
+
"SpeakerDiarizationMixin",
|
| 37 |
+
]
|
ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/diarization.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2022- CNRS
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
# copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
from typing import Dict, Mapping, Optional, Tuple, Union
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
from pyannote_audio_utils.core import Annotation, SlidingWindow, SlidingWindowFeature
|
| 27 |
+
from pyannote_audio_utils.core.utils.types import Label
|
| 28 |
+
from pyannote_audio_utils.metrics.diarization import DiarizationErrorRate
|
| 29 |
+
|
| 30 |
+
from pyannote_audio_utils.audio.core.inference import Inference
|
| 31 |
+
from pyannote_audio_utils.audio.utils.signal import Binarize
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# TODO: move to dedicated module
|
| 35 |
+
class SpeakerDiarizationMixin:
|
| 36 |
+
"""Defines a bunch of methods common to speaker diarization pipelines"""
|
| 37 |
+
|
| 38 |
+
@staticmethod
|
| 39 |
+
def set_num_speakers(
|
| 40 |
+
num_speakers: Optional[int] = None,
|
| 41 |
+
min_speakers: Optional[int] = None,
|
| 42 |
+
max_speakers: Optional[int] = None,
|
| 43 |
+
):
|
| 44 |
+
"""Validate number of speakers
|
| 45 |
+
|
| 46 |
+
Parameters
|
| 47 |
+
----------
|
| 48 |
+
num_speakers : int, optional
|
| 49 |
+
Number of speakers.
|
| 50 |
+
min_speakers : int, optional
|
| 51 |
+
Minimum number of speakers.
|
| 52 |
+
max_speakers : int, optional
|
| 53 |
+
Maximum number of speakers.
|
| 54 |
+
|
| 55 |
+
Returns
|
| 56 |
+
-------
|
| 57 |
+
num_speakers : int or None
|
| 58 |
+
min_speakers : int
|
| 59 |
+
max_speakers : int or np.inf
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
# override {min|max}_num_speakers by num_speakers when available
|
| 63 |
+
min_speakers = num_speakers or min_speakers or 1
|
| 64 |
+
max_speakers = num_speakers or max_speakers or np.inf
|
| 65 |
+
|
| 66 |
+
if min_speakers > max_speakers:
|
| 67 |
+
raise ValueError(
|
| 68 |
+
f"min_speakers must be smaller than (or equal to) max_speakers "
|
| 69 |
+
f"(here: min_speakers={min_speakers:g} and max_speakers={max_speakers:g})."
|
| 70 |
+
)
|
| 71 |
+
if min_speakers == max_speakers:
|
| 72 |
+
num_speakers = min_speakers
|
| 73 |
+
|
| 74 |
+
return num_speakers, min_speakers, max_speakers
|
| 75 |
+
|
| 76 |
+
@staticmethod
def optimal_mapping(
    reference: Union[Mapping, Annotation],
    hypothesis: Annotation,
    return_mapping: bool = False,
) -> Union[Annotation, Tuple[Annotation, Dict[Label, Label]]]:
    """Find the optimal bijective mapping between reference and hypothesis labels

    Parameters
    ----------
    reference : Annotation or Mapping
        Reference annotation. Can be an Annotation instance or
        a mapping with an "annotation" key.
    hypothesis : Annotation
        Hypothesized annotation.
    return_mapping : bool, optional
        Return the label mapping itself along with the mapped annotation. Defaults to False.

    Returns
    -------
    mapped : Annotation
        Hypothesis mapped to reference speakers.
    mapping : dict, optional
        Mapping between hypothesis (key) and reference (value) labels
        Only returned if `return_mapping` is True.
    """

    if isinstance(reference, Mapping):
        # fix: read "annotated" from the mapping BEFORE replacing
        # `reference` with the Annotation -- `in` on an Annotation tests
        # segment/label membership, not dict keys, so the previous order
        # silently never found the "annotated" UEM
        annotated = reference["annotated"] if "annotated" in reference else None
        reference = reference["annotation"]
    else:
        annotated = None

    mapping = DiarizationErrorRate().optimal_mapping(
        reference, hypothesis, uem=annotated
    )
    mapped_hypothesis = hypothesis.rename_labels(mapping=mapping)

    if return_mapping:
        return mapped_hypothesis, mapping

    else:
        return mapped_hypothesis
|
| 119 |
+
|
| 120 |
+
# TODO: get rid of warm-up parameter (trimming should be applied before calling speaker_count)
|
| 121 |
+
@staticmethod
def speaker_count(
    binarized_segmentations: SlidingWindowFeature,
    frames: SlidingWindow,
    warm_up: Tuple[float, float] = (0.1, 0.1),
) -> SlidingWindowFeature:
    """Estimate frame-level number of instantaneous speakers

    Parameters
    ----------
    binarized_segmentations : SlidingWindowFeature
        (num_chunks, num_frames, num_classes)-shaped binarized scores.
    warm_up : (float, float) tuple, optional
        Left/right warm up ratio of chunk duration.
        Defaults to (0.1, 0.1), i.e. 10% on both sides.
    frames : SlidingWindow
        Frames resolution. Defaults to estimate it automatically based on
        `segmentations` shape and chunk size. Providing the exact frame
        resolution (when known) leads to better temporal precision.

    Returns
    -------
    count : SlidingWindowFeature
        (num_frames, 1)-shaped instantaneous speaker count
    """

    # drop chunk borders where the model is least reliable
    trimmed = Inference.trim(binarized_segmentations, warm_up=warm_up)

    # per-frame total of active speakers within each chunk, then
    # aggregated over overlapping chunks (plain average, no window)
    per_frame_totals = np.sum(trimmed, axis=-1, keepdims=True)
    count = Inference.aggregate(
        per_frame_totals,
        frames,
        hamming=False,
        missing=0.0,
        skip_average=False,
    )

    # round the averaged count to the nearest integer speaker count
    count.data = np.rint(count.data).astype(np.uint8)

    return count
|
| 160 |
+
|
| 161 |
+
@staticmethod
def to_annotation(
    discrete_diarization: SlidingWindowFeature,
    min_duration_on: float = 0.0,
    min_duration_off: float = 0.0,
) -> Annotation:
    """Convert a discrete diarization into a continuous annotation.

    Parameters
    ----------
    discrete_diarization : SlidingWindowFeature
        (num_frames, num_speakers)-shaped discrete diarization.
    min_duration_on : float, optional
        Remove active regions shorter than this (seconds). Defaults to 0.
    min_duration_off : float, optional
        Fill inactive gaps shorter than this (seconds). Defaults to 0.

    Returns
    -------
    continuous_diarization : Annotation
        Continuous diarization, with speaker labels as integers,
        corresponding to the speaker indices in the discrete diarization.
    """

    # 0.5/0.5 thresholds: the input is already binarized
    thresholder = Binarize(
        onset=0.5,
        offset=0.5,
        min_duration_on=min_duration_on,
        min_duration_off=min_duration_off,
    )

    continuous = thresholder(discrete_diarization)
    return continuous.rename_tracks(generator="string")
|
| 193 |
+
|
| 194 |
+
@staticmethod
def to_diarization(
    segmentations: SlidingWindowFeature,
    count: SlidingWindowFeature,
) -> SlidingWindowFeature:
    """Build diarization out of preprocessed segmentation and precomputed speaker count.

    Parameters
    ----------
    segmentations : SlidingWindowFeature
        (num_chunks, num_frames, num_speakers)-shaped segmentations.
    count : SlidingWindowFeature
        (num_frames, 1)-shaped speaker count.

    Returns
    -------
    discrete_diarization : SlidingWindowFeature
        Discrete (0s and 1s) diarization.
    """

    # TODO: investigate alternative aggregation
    activations = Inference.aggregate(
        segmentations,
        count.sliding_window,
        hamming=False,
        missing=0.0,
        skip_average=True,
    )
    # activations data is (num_frames, num_speakers)

    num_speakers = activations.data.shape[1]
    most_simultaneous = np.max(count.data)
    if num_speakers < most_simultaneous:
        # make room for as many speakers as the count ever requires
        activations.data = np.pad(
            activations.data, ((0, 0), (0, most_simultaneous - num_speakers))
        )

    # restrict both features to their common temporal extent
    shared_extent = activations.extent & count.extent
    activations = activations.crop(shared_extent, return_data=False)
    count = count.crop(shared_extent, return_data=False)

    # per frame, activate the `c` most active speakers
    ranked = np.argsort(-activations, axis=-1)
    binary = np.zeros_like(activations.data)
    for frame, ((_, c), speaker_order) in enumerate(zip(count, ranked)):
        binary[frame, speaker_order[: c.item()]] = 1.0

    return SlidingWindowFeature(binary, activations.sliding_window)
|
| 243 |
+
|
| 244 |
+
def classes(self):
    """Yield an infinite stream of speaker labels: "SPEAKER_00", "SPEAKER_01", ..."""
    index = 0
    while True:
        yield f"SPEAKER_{index:02d}"
        index += 1
|
ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/kaldifeat.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/yuyq96/kaldifeat
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# ---------- feature-window ----------
|
| 7 |
+
|
| 8 |
+
def sliding_window(x, window_size, window_shift):
    """Return every length-`window_size` frame of `x`, taken every `window_shift` samples.

    Zero-copy strided view: rows may share memory with `x`, so callers
    should copy (e.g. via ``astype``) before mutating frames.
    """
    num_windows = x.shape[-1] - window_size + 1
    view_shape = x.shape[:-1] + (num_windows, window_size)
    view_strides = x.strides + (x.strides[-1],)
    every_window = np.lib.stride_tricks.as_strided(x, shape=view_shape, strides=view_strides)
    return every_window[::window_shift]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def func_num_frames(num_samples, window_size, window_shift, snip_edges):
    """Number of frames extracted from `num_samples` samples.

    With `snip_edges`, only frames that fit entirely inside the signal count;
    otherwise the count depends on the shift alone (edges are reflected later).
    """
    if snip_edges:
        if num_samples < window_size:
            return 0
        return 1 + (num_samples - window_size) // window_shift
    return (num_samples + window_shift // 2) // window_shift
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def func_dither(waveform, dither_value):
    """Add Gaussian dither (scaled by `dither_value`) to `waveform`, in place.

    No-op when `dither_value` is 0; otherwise non-deterministic
    (uses the global NumPy RNG).
    """
    if dither_value == 0.0:
        return waveform
    noise = np.random.normal(size=waveform.shape).astype(waveform.dtype)
    waveform += noise * dither_value
    return waveform
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def func_remove_dc_offset(waveform):
    """Return `waveform` with its mean (DC component) subtracted."""
    return waveform - waveform.mean()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def func_log_energy(waveform):
    """Log of the frame energy, floored at dtype epsilon to avoid log(0)."""
    energy = np.dot(waveform, waveform)
    floor = np.finfo(waveform.dtype).eps
    return np.log(energy.clip(min=floor))
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def func_preemphasis(waveform, preemph_coeff):
    """Apply pre-emphasis in place: y[n] = x[n] - k * x[n-1].

    The first sample uses itself as its own predecessor (Kaldi convention).
    No-op when `preemph_coeff` is 0; mutates and returns `waveform`.
    """
    if preemph_coeff == 0.0:
        return waveform
    assert 0 < preemph_coeff <= 1
    # RHS is evaluated before the in-place update, so original values are used
    waveform[1:] -= preemph_coeff * waveform[:-1]
    waveform[0] -= preemph_coeff * waveform[0]
    return waveform
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def sine(M):
    """Sine window of length M: sin(pi * n / (M - 1)) for n in [0, M)."""
    if M < 1:
        return np.array([])
    if M == 1:
        return np.ones(1, float)
    indices = np.arange(M)
    return np.sin(np.pi * indices / (M - 1))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def povey(M):
    """Povey window of length M: a Hann window raised to the power 0.85."""
    if M < 1:
        return np.array([])
    if M == 1:
        return np.ones(1, float)
    indices = np.arange(M)
    hann = 0.5 - 0.5 * np.cos(2.0 * np.pi * indices / (M - 1))
    return hann ** 0.85
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def feature_window_function(window_type, window_size, blackman_coeff):
    """Return the analysis window of length `window_size` for `window_type`.

    `blackman_coeff` shifts the constant term of the generalized Blackman
    window (only used for ``'blackman'``). Raises ValueError on unknown types.
    """
    assert window_size > 0
    if window_type == 'hanning':
        return np.hanning(window_size)
    if window_type == 'sine':
        return sine(window_size)
    if window_type == 'hamming':
        return np.hamming(window_size)
    if window_type == 'povey':
        return povey(window_size)
    if window_type == 'rectangular':
        return np.ones(window_size)
    if window_type == 'blackman':
        window = np.blackman(window_size)
        if blackman_coeff == 0.42:
            # np.blackman already uses the 0.42 constant term
            return window
        # generalized Blackman: replace the 0.42 constant with blackman_coeff
        return window - 0.42 + blackman_coeff
    raise ValueError('Invalid window type {}'.format(window_type))
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def process_window(window, dither, remove_dc_offset, preemphasis_coefficient, window_function, raw_energy):
    """Preprocess a single frame: dither, DC removal, pre-emphasis, windowing.

    Returns ``(processed_window, log_energy)``; with `raw_energy`, the energy
    is measured before pre-emphasis and windowing, otherwise after.
    """
    if dither != 0.0:
        window = func_dither(window, dither)
    if remove_dc_offset:
        window = func_remove_dc_offset(window)
    # raw energy: measured on the un-emphasized, un-windowed frame
    log_energy = func_log_energy(window) if raw_energy else None
    if preemphasis_coefficient != 0.0:
        window = func_preemphasis(window, preemphasis_coefficient)
    window *= window_function
    if not raw_energy:
        log_energy = func_log_energy(window)
    return window, log_energy
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def extract_window(waveform, blackman_coeff, dither, window_size, window_shift,
                   preemphasis_coefficient, raw_energy, remove_dc_offset,
                   snip_edges, window_type, dtype):
    """Split `waveform` into overlapping frames and preprocess each one.

    :param waveform: 1-D input signal.
    :param blackman_coeff: Constant term for the generalized Blackman window.
    :param dither: Dithering constant (0.0 disables dithering).
    :param window_size: Frame length in samples.
    :param window_shift: Frame shift in samples.
    :param preemphasis_coefficient: Pre-emphasis coefficient (0.0 disables).
    :param raw_energy: Measure energy before pre-emphasis/windowing if true.
    :param remove_dc_offset: Subtract the per-frame mean if true.
    :param snip_edges: Only emit frames fully inside the signal; otherwise
        reflect the signal at both ends.
    :param window_type: Window function name (see feature_window_function).
    :param dtype: Output dtype.
    :return: ``(frames, log_energy)`` with shapes (num_frames, window_size)
        and (num_frames,).
    """
    num_samples = len(waveform)
    num_frames = func_num_frames(num_samples, window_size, window_shift, snip_edges)
    num_samples_ = (num_frames - 1) * window_shift + window_size
    if snip_edges:
        waveform = waveform[:num_samples_]
    else:
        # reflect the signal at both ends so every frame is fully defined
        offset = window_shift // 2 - window_size // 2
        waveform = np.concatenate([
            waveform[-offset - 1::-1],
            waveform,
            waveform[:-(offset + num_samples_ - num_samples + 1):-1]
        ])
    frames = sliding_window(waveform, window_size=window_size, window_shift=window_shift)
    frames = frames.astype(dtype)
    # the window function depends only on per-call constants: compute it once
    # instead of once per frame (hoisted out of the loop; also fixes the
    # `log_enery` typo in the original)
    window_function = feature_window_function(
        window_type=window_type,
        window_size=window_size,
        blackman_coeff=blackman_coeff,
    ).astype(dtype)
    log_energy = np.empty(frames.shape[0], dtype=dtype)
    for i in range(frames.shape[0]):
        frames[i], log_energy[i] = process_window(
            window=frames[i],
            dither=dither,
            remove_dc_offset=remove_dc_offset,
            preemphasis_coefficient=preemphasis_coefficient,
            window_function=window_function,
            raw_energy=raw_energy,
        )
    return frames, log_energy
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# ---------- feature-window ----------
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# ---------- feature-functions ----------
|
| 141 |
+
|
| 142 |
+
def compute_spectrum(frames, n):
    """Magnitude spectrum of each frame via a length-`n` real FFT."""
    return np.abs(np.fft.rfft(frames, n))
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def compute_power_spectrum(frames, n):
    """Power spectrum of each frame: squared magnitude of the length-`n` real FFT."""
    magnitude = np.absolute(np.fft.rfft(frames, n))
    return np.square(magnitude)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# ---------- feature-functions ----------
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# ---------- mel-computations ----------
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def mel_scale(freq):
    """Convert frequency in Hz to the mel scale (Kaldi/HTK formula)."""
    return 1127.0 * np.log(freq / 700.0 + 1.0)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def compute_mel_banks(num_bins, sample_frequency, low_freq, high_freq, n):
|
| 162 |
+
""" Compute Mel banks.
|
| 163 |
+
|
| 164 |
+
:param num_bins: Number of triangular mel-frequency bins
|
| 165 |
+
:param sample_frequency: Waveform data sample frequency
|
| 166 |
+
:param low_freq: Low cutoff frequency for mel bins
|
| 167 |
+
:param high_freq: High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
|
| 168 |
+
:param n: Window size
|
| 169 |
+
:return: Mel banks.
|
| 170 |
+
"""
|
| 171 |
+
assert num_bins >= 3, 'Must have at least 3 mel bins'
|
| 172 |
+
num_fft_bins = n // 2
|
| 173 |
+
|
| 174 |
+
nyquist = 0.5 * sample_frequency
|
| 175 |
+
if high_freq <= 0:
|
| 176 |
+
high_freq = nyquist + high_freq
|
| 177 |
+
assert 0 <= low_freq < high_freq <= nyquist
|
| 178 |
+
|
| 179 |
+
fft_bin_width = sample_frequency / n
|
| 180 |
+
|
| 181 |
+
mel_low_freq = mel_scale(low_freq)
|
| 182 |
+
mel_high_freq = mel_scale(high_freq)
|
| 183 |
+
mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
|
| 184 |
+
|
| 185 |
+
mel_banks = np.zeros([num_bins, num_fft_bins + 1])
|
| 186 |
+
for i in range(num_bins):
|
| 187 |
+
left_mel = mel_low_freq + mel_freq_delta * i
|
| 188 |
+
center_mel = left_mel + mel_freq_delta
|
| 189 |
+
right_mel = center_mel + mel_freq_delta
|
| 190 |
+
for j in range(num_fft_bins):
|
| 191 |
+
mel = mel_scale(fft_bin_width * j)
|
| 192 |
+
if left_mel < mel < right_mel:
|
| 193 |
+
if mel <= center_mel:
|
| 194 |
+
mel_banks[i, j] = (mel - left_mel) / (center_mel - left_mel)
|
| 195 |
+
else:
|
| 196 |
+
mel_banks[i, j] = (right_mel - mel) / (right_mel - center_mel)
|
| 197 |
+
return mel_banks
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
# ---------- mel-computations ----------
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# ---------- compute-fbank-feats ----------
|
| 204 |
+
|
| 205 |
+
def compute_fbank_feats(
        waveform,
        blackman_coeff=0.42,
        dither=1.0,
        energy_floor=0.0,
        frame_length=25,
        frame_shift=10,
        high_freq=0,
        low_freq=20,
        num_mel_bins=23,
        preemphasis_coefficient=0.97,
        raw_energy=True,
        remove_dc_offset=True,
        round_to_power_of_two=True,
        sample_frequency=16000,
        snip_edges=True,
        use_energy=False,
        use_log_fbank=True,
        use_power=True,
        window_type='povey',
        dtype=np.float32):
    """Compute (log) Mel filter bank energies (Kaldi-style fbank features).

    :param waveform: Input waveform.
    :param blackman_coeff: Constant coefficient for generalized Blackman window. (float, default = 0.42)
    :param dither: Dithering constant (0.0 means no dither). If you turn this off, you should set the --energy-floor option, e.g. to 1.0 or 0.1 (float, default = 1)
    :param energy_floor: Floor on energy (absolute, not relative); only used when use_energy is true. (float, default = 0)
    :param frame_length: Frame length in milliseconds (float, default = 25)
    :param frame_shift: Frame shift in milliseconds (float, default = 10)
    :param high_freq: High cutoff frequency for mel bins (if <= 0, offset from Nyquist) (float, default = 0)
    :param low_freq: Low cutoff frequency for mel bins (float, default = 20)
    :param num_mel_bins: Number of triangular mel-frequency bins (int, default = 23)
    :param preemphasis_coefficient: Coefficient for signal preemphasis (float, default = 0.97)
    :param raw_energy: If true, compute energy before preemphasis and windowing (bool, default = true)
    :param remove_dc_offset: Subtract mean from waveform on each frame (bool, default = true)
    :param round_to_power_of_two: If true, round FFT size up to a power of two. (bool, default = true)
    :param sample_frequency: Waveform sample frequency (float, default = 16000)
    :param snip_edges: If true, output only frames that completely fit in the file; otherwise reflect data at the edges. (bool, default = true)
    :param use_energy: Add an extra energy output. (bool, default = false)
    :param use_log_fbank: If true, produce log-filterbank, else linear. (bool, default = true)
    :param use_power: If true, use power, else magnitude. (bool, default = true)
    :param window_type: Window type ("hamming"|"hanning"|"povey"|"rectangular"|"sine"|"blackman") (default = "povey")
    :param dtype: Output dtype (default = np.float32)
    :return: feat, or (feat, log_energy) when use_energy is true.
    """
    window_size = int(frame_length * sample_frequency * 0.001)
    window_shift = int(frame_shift * sample_frequency * 0.001)
    frames, log_energy = extract_window(
        waveform=waveform,
        blackman_coeff=blackman_coeff,
        dither=dither,
        window_size=window_size,
        window_shift=window_shift,
        preemphasis_coefficient=preemphasis_coefficient,
        raw_energy=raw_energy,
        remove_dc_offset=remove_dc_offset,
        snip_edges=snip_edges,
        window_type=window_type,
        dtype=dtype
    )
    # FFT length: window size, optionally rounded up to a power of two
    if round_to_power_of_two:
        n = 1
        while n < window_size:
            n *= 2
    else:
        n = window_size
    if use_power:
        spectrum = compute_power_spectrum(frames, n)
    else:
        spectrum = compute_spectrum(frames, n)
    mel_banks = compute_mel_banks(
        num_bins=num_mel_bins,
        sample_frequency=sample_frequency,
        low_freq=low_freq,
        high_freq=high_freq,
        n=n
    ).astype(dtype)
    feat = np.dot(spectrum, mel_banks.T)
    if use_log_fbank:
        feat = np.log(feat.clip(min=np.finfo(dtype).eps))
    if use_energy:
        if energy_floor > 0.0:
            # BUGFIX: ndarray.clip returns a NEW array -- the original
            # discarded the result, so the energy floor was never applied.
            # Also replaces the removed `np.math` alias with np.log.
            log_energy = log_energy.clip(min=np.log(energy_floor))
        return feat, log_energy
    return feat
|
| 290 |
+
|
| 291 |
+
# ---------- compute-fbank-feats ----------
|
ailia-models/code/pyannote_audio_utils/audio/utils/multi_task.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023- CNRS
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
# copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
from typing import Any, Callable, Tuple, Union
|
| 25 |
+
|
| 26 |
+
from pyannote_audio_utils.audio.core.task import Specifications
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def map_with_specifications(
    specifications: Union[Specifications, Tuple[Specifications]],
    func: Callable,
    *iterables,
) -> Union[Any, Tuple[Any]]:
    """Apply `func` once per specification, drawing arguments from `iterables`.

    Returns a tuple when `specifications` is a tuple (multi-task),
    otherwise the plain return value of the single call.

    Parameters
    ----------
    specifications : (tuple of) Specifications
        Specifications or tuple of specifications.
    func : callable
        Called as ``func(*iterables[i], specifications=specifications[i])``.
    *iterables :
        Iterables with the same length as `specifications`.

    Returns
    -------
    output : (tuple of) `func` return value(s)
    """

    # single-task case: one call, plain return value
    if isinstance(specifications, Specifications):
        return func(*iterables, specifications=specifications)

    # multi-task case: one call per specification, zipped with the iterables
    results = []
    for spec, *args in zip(specifications, *iterables):
        results.append(func(*args, specifications=spec))
    return tuple(results)
|
ailia-models/code/pyannote_audio_utils/audio/utils/powerset.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023- CNRS
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
# copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
# AUTHORS
|
| 24 |
+
# Hervé BREDIN - https://herve.niderb.fr
|
| 25 |
+
# Alexis PLAQUET
|
| 26 |
+
|
| 27 |
+
from functools import cached_property
|
| 28 |
+
from itertools import combinations
|
| 29 |
+
|
| 30 |
+
import scipy.special
|
| 31 |
+
import numpy as np
|
| 32 |
+
|
| 33 |
+
class Powerset():
    """Powerset-to-multilabel conversion.

    Parameters
    ----------
    num_classes : int
        Number of regular classes.
    max_set_size : int
        Maximum number of classes in each set.
    """

    def __init__(self, num_classes: int, max_set_size: int):
        super().__init__()
        self.num_classes = num_classes
        self.max_set_size = max_set_size
        # membership matrix and per-set sizes, precomputed once
        self.mapping = self.build_mapping()
        self.cardinality = self.build_cardinality()

    @cached_property
    def num_powerset_classes(self) -> int:
        """Number of subsets of size at most `max_set_size`.

        e.g. with num_classes = 3 and max_set_size = 2:
        {}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2}
        """
        subset_counts = (
            scipy.special.binom(self.num_classes, size)
            for size in range(self.max_set_size + 1)
        )
        return int(sum(subset_counts))

    def build_mapping(self) -> np.ndarray:
        """Compute the powerset-to-regular membership matrix.

        Returns
        -------
        mapping : (num_powerset_classes, num_classes) np.ndarray
            mapping[i, j] == 1 iff the j-th regular class belongs to the
            i-th powerset class; rows are ordered by subset size, then
            lexicographically.

        Example
        -------
        With num_classes == 3 and max_set_size == 2, returns

            [0, 0, 0]  # none
            [1, 0, 0]  # class #1
            [0, 1, 0]  # class #2
            [0, 0, 1]  # class #3
            [1, 1, 0]  # classes #1 and #2
            [1, 0, 1]  # classes #1 and #3
            [0, 1, 1]  # classes #2 and #3
        """
        membership = np.zeros((self.num_powerset_classes, self.num_classes))
        row = 0
        for size in range(self.max_set_size + 1):
            for subset in combinations(range(self.num_classes), size):
                membership[row, subset] = 1
                row += 1
        return membership

    def build_cardinality(self) -> np.ndarray:
        """Size of each powerset class (row sums of the mapping)."""
        return np.sum(self.mapping, axis=1)

    def to_multilabel(self, powerset: np.ndarray, soft: bool = False) -> np.ndarray:
        """Convert powerset logits to hard multi-label predictions.

        Parameters
        ----------
        powerset : (batch_size, num_frames, num_powerset_classes) np.ndarray
            Predictions in "powerset" space (treated as logits).
        soft : bool, optional
            Accepted for API compatibility; hard decoding (argmax)
            is always used in this implementation.

        Returns
        -------
        multi_label : (batch_size, num_frames, num_classes) np.ndarray
            Predictions in "multi-label" space.
        """
        winners = np.argmax(powerset, axis=-1)
        one_hot = np.identity(self.num_powerset_classes)[winners]
        return np.matmul(one_hot, self.mapping)

    def __call__(self, powerset: np.ndarray, soft: bool = False) -> np.ndarray:
        """Alias for `to_multilabel`."""
        return self.to_multilabel(powerset, soft=soft)
|
ailia-models/code/pyannote_audio_utils/audio/utils/signal.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
#
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
#
|
| 6 |
+
# Copyright (c) 2016-2021 CNRS
|
| 7 |
+
#
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
|
| 29 |
+
"""
|
| 30 |
+
# Signal processing
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
from functools import singledispatch
|
| 34 |
+
from itertools import zip_longest
|
| 35 |
+
from typing import Optional, Union
|
| 36 |
+
|
| 37 |
+
import numpy as np
|
| 38 |
+
import scipy.signal
|
| 39 |
+
from pyannote_audio_utils.core import Annotation, Segment, SlidingWindowFeature, Timeline
|
| 40 |
+
from pyannote_audio_utils.core.utils.generators import pairwise
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@singledispatch
def binarize(
    scores,
    onset: float = 0.5,
    offset: Optional[float] = None,
    initial_state: Optional[Union[bool, np.ndarray]] = None,
):
    """(Batch) hysteresis thresholding.

    Generic entry point: concrete implementations are registered for
    numpy.ndarray and SlidingWindowFeature; any other type raises.

    Parameters
    ----------
    scores : numpy.ndarray or SlidingWindowFeature
        (num_chunks, num_frames, num_classes)- or (num_frames, num_classes)-shaped scores.
    onset : float, optional
        Onset threshold. Defaults to 0.5.
    offset : float, optional
        Offset threshold. Defaults to `onset`.
    initial_state : np.ndarray or bool, optional
        Initial state.

    Returns
    -------
    binarized : same as scores
        Binarized scores with same shape and type as scores.

    Reference
    ---------
    https://stackoverflow.com/questions/23289976/how-to-find-zero-crossings-with-hysteresis
    """
    raise NotImplementedError(
        "scores must be of type numpy.ndarray or SlidingWindowFeatures"
    )
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@binarize.register
|
| 78 |
+
def binarize_ndarray(
|
| 79 |
+
scores: np.ndarray,
|
| 80 |
+
onset: float = 0.5,
|
| 81 |
+
offset: Optional[float] = None,
|
| 82 |
+
initial_state: Optional[Union[bool, np.ndarray]] = None,
|
| 83 |
+
):
|
| 84 |
+
"""(Batch) hysteresis thresholding
|
| 85 |
+
|
| 86 |
+
Parameters
|
| 87 |
+
----------
|
| 88 |
+
scores : numpy.ndarray
|
| 89 |
+
(num_frames, num_classes)-shaped scores.
|
| 90 |
+
onset : float, optional
|
| 91 |
+
Onset threshold. Defaults to 0.5.
|
| 92 |
+
offset : float, optional
|
| 93 |
+
Offset threshold. Defaults to `onset`.
|
| 94 |
+
initial_state : np.ndarray or bool, optional
|
| 95 |
+
Initial state.
|
| 96 |
+
|
| 97 |
+
Returns
|
| 98 |
+
-------
|
| 99 |
+
binarized : same as scores
|
| 100 |
+
Binarized scores with same shape and type as scores.
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
offset = offset or onset
|
| 104 |
+
|
| 105 |
+
batch_size, num_frames = scores.shape
|
| 106 |
+
|
| 107 |
+
scores = np.nan_to_num(scores)
|
| 108 |
+
|
| 109 |
+
if initial_state is None:
|
| 110 |
+
initial_state = scores[:, 0] >= 0.5 * (onset + offset)
|
| 111 |
+
|
| 112 |
+
elif isinstance(initial_state, bool):
|
| 113 |
+
initial_state = initial_state * np.ones((batch_size,), dtype=bool)
|
| 114 |
+
|
| 115 |
+
elif isinstance(initial_state, np.ndarray):
|
| 116 |
+
assert initial_state.shape == (batch_size,)
|
| 117 |
+
assert initial_state.dtype == bool
|
| 118 |
+
|
| 119 |
+
initial_state = np.tile(initial_state, (num_frames, 1)).T
|
| 120 |
+
|
| 121 |
+
on = scores > onset
|
| 122 |
+
off_or_on = (scores < offset) | on
|
| 123 |
+
|
| 124 |
+
# indices of frames for which the on/off state is well-defined
|
| 125 |
+
well_defined_idx = np.array(
|
| 126 |
+
list(zip_longest(*[np.nonzero(oon)[0] for oon in off_or_on], fillvalue=-1))
|
| 127 |
+
).T
|
| 128 |
+
|
| 129 |
+
# corner case where well_defined_idx is empty
|
| 130 |
+
if not well_defined_idx.size:
|
| 131 |
+
return np.zeros_like(scores, dtype=bool) | initial_state
|
| 132 |
+
|
| 133 |
+
# points to the index of the previous well-defined frame
|
| 134 |
+
same_as = np.cumsum(off_or_on, axis=1)
|
| 135 |
+
|
| 136 |
+
samples = np.tile(np.arange(batch_size), (num_frames, 1)).T
|
| 137 |
+
|
| 138 |
+
return np.where(
|
| 139 |
+
same_as, on[samples, well_defined_idx[samples, same_as - 1]], initial_state
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@binarize.register
def binarize_swf(
    scores: SlidingWindowFeature,
    onset: float = 0.5,
    offset: Optional[float] = None,
    initial_state: Optional[bool] = None,
):
    """(Batch) hysteresis thresholding of SlidingWindowFeature scores

    Parameters
    ----------
    scores : SlidingWindowFeature
        (num_chunks, num_frames, num_classes)- or (num_frames, num_classes)-shaped scores.
    onset : float, optional
        Onset threshold. Defaults to 0.5.
    offset : float, optional
        Offset threshold. Defaults to `onset`.
    initial_state : np.ndarray or bool, optional
        Initial state.

    Returns
    -------
    binarized : SlidingWindowFeature
        Binarized scores (0.0/1.0 floats) with the same shape and sliding
        window as `scores`.

    Raises
    ------
    ValueError
        When `scores.data` is neither 2- nor 3-dimensional.
    """

    # `offset = offset or onset` would silently replace a legitimate 0.0
    # offset with `onset`; compare against None explicitly instead.
    offset = onset if offset is None else offset

    if scores.data.ndim == 2:
        num_frames, num_classes = scores.data.shape
        # binarize_ndarray works row-wise, so binarize one class per row
        data = scores.data.transpose()
        binarized = binarize(
            data, onset=onset, offset=offset, initial_state=initial_state
        )
        # `1.0 *` converts the boolean result back into float scores
        return SlidingWindowFeature(
            1.0 * binarized.transpose(),
            scores.sliding_window,
        )

    elif scores.data.ndim == 3:
        num_chunks, num_frames, num_classes = scores.data.shape
        # NOTE(review): unlike the 2-dim branch, rows here are
        # (chunk, frame) pairs across classes; this mirrors the upstream
        # pyannote implementation — confirm before changing.
        data = scores.data.reshape([-1, num_classes])
        binarized = binarize(
            data, onset=onset, offset=offset, initial_state=initial_state
        )
        return SlidingWindowFeature(
            1.0 * binarized.reshape([num_chunks, num_frames, num_classes]),
            scores.sliding_window,
        )

    else:
        raise ValueError(
            "Shape of scores must be (num_chunks, num_frames, num_classes) or (num_frames, num_classes)."
        )
|
| 202 |
+
class Binarize:
    """Binarize detection scores using hysteresis thresholding

    Parameters
    ----------
    onset : float, optional
        Onset threshold. Defaults to 0.5.
    offset : float, optional
        Offset threshold. Defaults to `onset`.
    min_duration_on : float, optional
        Remove active regions shorter than that many seconds. Defaults to 0s.
    min_duration_off : float, optional
        Fill inactive regions shorter than that many seconds. Defaults to 0s.
    pad_onset : float, optional
        Extend active regions by moving their start time by that many seconds.
        Defaults to 0s.
    pad_offset : float, optional
        Extend active regions by moving their end time by that many seconds.
        Defaults to 0s.

    Reference
    ---------
    Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
    RNN-based Voice Activity Detection", InterSpeech 2015.
    """

    def __init__(
        self,
        onset: float = 0.5,
        offset: Optional[float] = None,
        min_duration_on: float = 0.0,
        min_duration_off: float = 0.0,
        pad_onset: float = 0.0,
        pad_offset: float = 0.0,
    ):

        super().__init__()

        self.onset = onset
        # `offset or onset` would override a legitimate 0.0 offset;
        # compare against None explicitly instead.
        self.offset = onset if offset is None else offset

        self.pad_onset = pad_onset
        self.pad_offset = pad_offset

        self.min_duration_on = min_duration_on
        self.min_duration_off = min_duration_off

    def __call__(self, scores: "SlidingWindowFeature") -> "Annotation":
        """Binarize detection scores

        Parameters
        ----------
        scores : SlidingWindowFeature
            Detection scores, one column per class.

        Returns
        -------
        active : Annotation
            Binarized scores, one track per (class, active region).
        """

        num_frames, num_classes = scores.data.shape
        frames = scores.sliding_window
        # use the middle of each frame as its timestamp
        timestamps = [frames[i].middle for i in range(num_frames)]

        # annotation meant to store 'active' regions
        active = Annotation()

        # process each class independently
        for k, k_scores in enumerate(scores.data.T):

            label = k if scores.labels is None else scores.labels[k]

            # initial state
            start = timestamps[0]
            is_active = k_scores[0] > self.onset

            # guard against single-frame scores: the loop below would never
            # run and `t` would otherwise be unbound in the epilogue.
            t = start

            for t, y in zip(timestamps[1:], k_scores[1:]):

                # currently active
                if is_active:
                    # switching from active to inactive
                    if y < self.offset:
                        region = Segment(start - self.pad_onset, t + self.pad_offset)
                        active[region, k] = label
                        start = t
                        is_active = False

                # currently inactive
                else:
                    # switching from inactive to active
                    if y > self.onset:
                        start = t
                        is_active = True

            # if active at the end, add final region
            if is_active:
                region = Segment(start - self.pad_onset, t + self.pad_offset)
                active[region, k] = label

        # because of padding, some active regions might be overlapping: merge them.
        # also: fill same speaker gaps shorter than min_duration_off
        if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
            active = active.support(collar=self.min_duration_off)

        # remove tracks shorter than min_duration_on
        if self.min_duration_on > 0:
            for segment, track in list(active.itertracks()):
                if segment.duration < self.min_duration_on:
                    del active[segment, track]

        return active
|
| 315 |
+
class Peak:
    """Peak detection

    Partitions the annotated extent at the location of score peaks that are
    both above `alpha` and at least `min_duration` seconds apart.

    Parameters
    ----------
    alpha : float, optional
        Peak threshold. Defaults to 0.5.
    min_duration : float, optional
        Minimum elapsed time between two consecutive peaks. Defaults to 1 second.
    """

    def __init__(
        self,
        alpha: float = 0.5,
        min_duration: float = 1.0,
    ):
        super().__init__()
        self.alpha = alpha
        self.min_duration = min_duration

    def __call__(self, scores: "SlidingWindowFeature"):
        """Peak detection

        Parameter
        ---------
        scores : SlidingWindowFeature
            One-dimensional detection scores.

        Returns
        -------
        segmentation : Timeline
            Partition with one segment between each pair of consecutive
            boundaries (extent start, detected peaks, extent end).

        Raises
        ------
        ValueError
            When `scores` is not one-dimensional.
        """

        if scores.dimension != 1:
            raise ValueError("Peak expects one-dimensional scores.")

        num_frames = len(scores)
        frames = scores.sliding_window

        # convert `min_duration` (seconds) into a number of frames
        precision = frames.step
        order = max(1, int(np.rint(self.min_duration / precision)))
        # local maxima that are at least `order` frames apart
        indices = scipy.signal.argrelmax(scores[:], order=order)[0]

        # keep only peaks whose score exceeds the `alpha` threshold
        peak_time = np.array(
            [frames[i].middle for i in indices if scores[i] > self.alpha]
        )
        boundaries = np.hstack([[frames[0].start], peak_time, [frames[num_frames].end]])

        # build one segment between each pair of consecutive boundaries
        # (the original code used enumerate() with an unused index)
        segmentation = Timeline()
        for start, end in pairwise(boundaries):
            segmentation.add(Segment(start, end))

        return segmentation
|
ailia-models/code/pyannote_audio_utils/audio/version.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__version__ = '3.1.1'
|
ailia-models/code/pyannote_audio_utils/core/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
# encoding: utf-8

# The MIT License (MIT)

# Copyright (c) 2014-2019 CNRS

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# AUTHORS
# Hervé BREDIN - http://herve.niderb.fr

# Package version is provided by the versioneer-generated _version module;
# `get_versions` is deleted afterwards to keep the package namespace clean.
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions

# Canonical field names shared across pyannote.core data structures
# (e.g. used as column names when converting annotations to DataFrames).
PYANNOTE_URI = 'uri'
PYANNOTE_MODALITY = 'modality'
PYANNOTE_SEGMENT = 'segment'
PYANNOTE_TRACK = 'track'
PYANNOTE_LABEL = 'label'
PYANNOTE_SCORE = 'score'
PYANNOTE_IDENTITY = 'identity'

# Core data structures re-exported at package level.
from .segment import Segment, SlidingWindow
from .timeline import Timeline
from .annotation import Annotation
from .feature import SlidingWindowFeature

# Initialize the default precision used by Segment at import time.
Segment.set_precision()
ailia-models/code/pyannote_audio_utils/core/_version.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# This file was generated by 'versioneer.py' (0.15) from
|
| 3 |
+
# revision-control system data, or from the parent directory name of an
|
| 4 |
+
# unpacked source archive. Distribution tarballs contain a pre-generated copy
|
| 5 |
+
# of this file.
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
# Frozen version metadata, in the exact shape versioneer emits for a
# release tarball (revision id, dirty flag, error slot, version string).
version_json = '''
{
 "dirty": false,
 "error": null,
 "full-revisionid": "4b0fd5302d8fa3ba249b42d3ab7b4cb51ee61ba2",
 "version": "5.0.0"
}
'''  # END VERSION_JSON


def get_versions():
    """Parse the embedded JSON blob and return it as a dictionary."""
    versions = json.loads(version_json)
    return versions
ailia-models/code/pyannote_audio_utils/core/annotation.py
ADDED
|
@@ -0,0 +1,1551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2014-2021 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
# Paul LERNER
|
| 29 |
+
|
| 30 |
+
"""
|
| 31 |
+
##########
|
| 32 |
+
Annotation
|
| 33 |
+
##########
|
| 34 |
+
|
| 35 |
+
.. plot:: pyplots/annotation.py
|
| 36 |
+
|
| 37 |
+
:class:`pyannote.core.Annotation` instances are ordered sets of non-empty
|
| 38 |
+
tracks:
|
| 39 |
+
|
| 40 |
+
- ordered, because segments are sorted by start time (and end time in case of tie)
|
| 41 |
+
- set, because one cannot add twice the same track
|
| 42 |
+
- non-empty, because one cannot add empty track
|
| 43 |
+
|
| 44 |
+
A track is a (support, name) pair where `support` is a Segment instance,
|
| 45 |
+
and `name` is an additional identifier so that it is possible to add multiple
|
| 46 |
+
tracks with the same support.
|
| 47 |
+
|
| 48 |
+
To define the annotation depicted above:
|
| 49 |
+
|
| 50 |
+
.. code-block:: ipython
|
| 51 |
+
|
| 52 |
+
In [1]: from pyannote.core import Annotation, Segment
|
| 53 |
+
|
| 54 |
+
In [6]: annotation = Annotation()
|
| 55 |
+
...: annotation[Segment(1, 5)] = 'Carol'
|
| 56 |
+
...: annotation[Segment(6, 8)] = 'Bob'
|
| 57 |
+
...: annotation[Segment(12, 18)] = 'Carol'
|
| 58 |
+
...: annotation[Segment(7, 20)] = 'Alice'
|
| 59 |
+
...:
|
| 60 |
+
|
| 61 |
+
which is actually a shortcut for
|
| 62 |
+
|
| 63 |
+
.. code-block:: ipython
|
| 64 |
+
|
| 65 |
+
In [6]: annotation = Annotation()
|
| 66 |
+
...: annotation[Segment(1, 5), '_'] = 'Carol'
|
| 67 |
+
...: annotation[Segment(6, 8), '_'] = 'Bob'
|
| 68 |
+
...: annotation[Segment(12, 18), '_'] = 'Carol'
|
| 69 |
+
...: annotation[Segment(7, 20), '_'] = 'Alice'
|
| 70 |
+
...:
|
| 71 |
+
|
| 72 |
+
where all tracks share the same (default) name ``'_'``.
|
| 73 |
+
|
| 74 |
+
In case two tracks share the same support, use a different track name:
|
| 75 |
+
|
| 76 |
+
.. code-block:: ipython
|
| 77 |
+
|
| 78 |
+
In [6]: annotation = Annotation(uri='my_video_file', modality='speaker')
|
| 79 |
+
...: annotation[Segment(1, 5), 1] = 'Carol' # track name = 1
|
| 80 |
+
...: annotation[Segment(1, 5), 2] = 'Bob' # track name = 2
|
| 81 |
+
...: annotation[Segment(12, 18)] = 'Carol'
|
| 82 |
+
...:
|
| 83 |
+
|
| 84 |
+
The track name does not have to be unique over the whole set of tracks.
|
| 85 |
+
|
| 86 |
+
.. note::
|
| 87 |
+
|
| 88 |
+
The optional *uri* and *modality* keywords argument can be used to remember
|
| 89 |
+
which document and modality (e.g. speaker or face) it describes.
|
| 90 |
+
|
| 91 |
+
Several convenient methods are available. Here are a few examples:
|
| 92 |
+
|
| 93 |
+
.. code-block:: ipython
|
| 94 |
+
|
| 95 |
+
In [9]: annotation.labels() # sorted list of labels
|
| 96 |
+
Out[9]: ['Bob', 'Carol']
|
| 97 |
+
|
| 98 |
+
In [10]: annotation.chart() # label duration chart
|
| 99 |
+
Out[10]: [('Carol', 10), ('Bob', 4)]
|
| 100 |
+
|
| 101 |
+
In [11]: list(annotation.itertracks())
|
| 102 |
+
Out[11]: [(<Segment(1, 5)>, 1), (<Segment(1, 5)>, 2), (<Segment(12, 18)>, u'_')]
|
| 103 |
+
|
| 104 |
+
In [12]: annotation.label_timeline('Carol')
|
| 105 |
+
Out[12]: <Timeline(uri=my_video_file, segments=[<Segment(1, 5)>, <Segment(12, 18)>])>
|
| 106 |
+
|
| 107 |
+
See :class:`pyannote.core.Annotation` for the complete reference.
|
| 108 |
+
"""
|
| 109 |
+
import itertools
|
| 110 |
+
import warnings
|
| 111 |
+
from collections import defaultdict
|
| 112 |
+
from typing import (
|
| 113 |
+
Hashable,
|
| 114 |
+
Optional,
|
| 115 |
+
Dict,
|
| 116 |
+
Union,
|
| 117 |
+
Iterable,
|
| 118 |
+
List,
|
| 119 |
+
Set,
|
| 120 |
+
TextIO,
|
| 121 |
+
Tuple,
|
| 122 |
+
Iterator,
|
| 123 |
+
Text,
|
| 124 |
+
TYPE_CHECKING,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
import numpy as np
|
| 128 |
+
from sortedcontainers import SortedDict
|
| 129 |
+
|
| 130 |
+
from . import (
|
| 131 |
+
PYANNOTE_SEGMENT,
|
| 132 |
+
PYANNOTE_TRACK,
|
| 133 |
+
PYANNOTE_LABEL,
|
| 134 |
+
)
|
| 135 |
+
from .segment import Segment, SlidingWindow
|
| 136 |
+
from .timeline import Timeline
|
| 137 |
+
from .feature import SlidingWindowFeature
|
| 138 |
+
from .utils.generators import string_generator, int_generator
|
| 139 |
+
from .utils.types import Label, Key, Support, LabelGenerator, TrackName, CropMode
|
| 140 |
+
|
| 141 |
+
if TYPE_CHECKING:
|
| 142 |
+
import pandas as pd
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class Annotation:
|
| 146 |
+
"""Annotation
|
| 147 |
+
|
| 148 |
+
Parameters
|
| 149 |
+
----------
|
| 150 |
+
uri : string, optional
|
| 151 |
+
name of annotated resource (e.g. audio or video file)
|
| 152 |
+
modality : string, optional
|
| 153 |
+
name of annotated modality
|
| 154 |
+
|
| 155 |
+
Returns
|
| 156 |
+
-------
|
| 157 |
+
annotation : Annotation
|
| 158 |
+
New annotation
|
| 159 |
+
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
@classmethod
|
| 163 |
+
def from_df(
|
| 164 |
+
cls,
|
| 165 |
+
df: "pd.DataFrame",
|
| 166 |
+
uri: Optional[str] = None,
|
| 167 |
+
modality: Optional[str] = None,
|
| 168 |
+
) -> "Annotation":
|
| 169 |
+
|
| 170 |
+
df = df[[PYANNOTE_SEGMENT, PYANNOTE_TRACK, PYANNOTE_LABEL]]
|
| 171 |
+
return Annotation.from_records(df.itertuples(index=False), uri, modality)
|
| 172 |
+
|
| 173 |
+
def __init__(self, uri: Optional[str] = None, modality: Optional[str] = None):
|
| 174 |
+
|
| 175 |
+
self._uri: Optional[str] = uri
|
| 176 |
+
self.modality: Optional[str] = modality
|
| 177 |
+
|
| 178 |
+
# sorted dictionary
|
| 179 |
+
# keys: annotated segments
|
| 180 |
+
# values: {track: label} dictionary
|
| 181 |
+
self._tracks: Dict[Segment, Dict[TrackName, Label]] = SortedDict()
|
| 182 |
+
|
| 183 |
+
# dictionary
|
| 184 |
+
# key: label
|
| 185 |
+
# value: timeline
|
| 186 |
+
self._labels: Dict[Label, Timeline] = {}
|
| 187 |
+
self._labelNeedsUpdate: Dict[Label, bool] = {}
|
| 188 |
+
|
| 189 |
+
# timeline meant to store all annotated segments
|
| 190 |
+
self._timeline: Timeline = None
|
| 191 |
+
self._timelineNeedsUpdate: bool = True
|
| 192 |
+
|
| 193 |
+
@property
|
| 194 |
+
def uri(self):
|
| 195 |
+
return self._uri
|
| 196 |
+
|
| 197 |
+
@uri.setter
|
| 198 |
+
def uri(self, uri: str):
|
| 199 |
+
# update uri for all internal timelines
|
| 200 |
+
for label in self.labels():
|
| 201 |
+
timeline = self.label_timeline(label, copy=False)
|
| 202 |
+
timeline.uri = uri
|
| 203 |
+
timeline = self.get_timeline(copy=False)
|
| 204 |
+
timeline.uri = uri
|
| 205 |
+
self._uri = uri
|
| 206 |
+
|
| 207 |
+
def _updateLabels(self):
|
| 208 |
+
|
| 209 |
+
# list of labels that needs to be updated
|
| 210 |
+
update = set(
|
| 211 |
+
label for label, update in self._labelNeedsUpdate.items() if update
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
# accumulate segments for updated labels
|
| 215 |
+
_segments = {label: [] for label in update}
|
| 216 |
+
for segment, track, label in self.itertracks(yield_label=True):
|
| 217 |
+
if label in update:
|
| 218 |
+
_segments[label].append(segment)
|
| 219 |
+
|
| 220 |
+
# create timeline with accumulated segments for updated labels
|
| 221 |
+
for label in update:
|
| 222 |
+
if _segments[label]:
|
| 223 |
+
self._labels[label] = Timeline(segments=_segments[label], uri=self.uri)
|
| 224 |
+
self._labelNeedsUpdate[label] = False
|
| 225 |
+
else:
|
| 226 |
+
self._labels.pop(label, None)
|
| 227 |
+
self._labelNeedsUpdate.pop(label, None)
|
| 228 |
+
|
| 229 |
+
def __len__(self):
    """Number of segments

    >>> len(annotation)  # annotation contains three segments
    3
    """
    # one entry per segment in the internal track dictionary
    return len(self._tracks)
|
| 236 |
+
|
| 237 |
+
def __nonzero__(self):
    # Python 2 truthiness hook: defer to __bool__
    return self.__bool__()
|
| 239 |
+
|
| 240 |
+
def __bool__(self):
    """Emptiness

    >>> if annotation:
    ...    # annotation is not empty
    ... else:
    ...    # annotation is empty
    """
    # non-empty iff at least one segment is annotated
    return bool(self._tracks)
|
| 249 |
+
|
| 250 |
+
def itersegments(self):
    """Iterate over segments (in chronological order)

    >>> for segment in annotation.itersegments():
    ...     # do something with the segment

    See also
    --------
    :class:`pyannote.core.Segment` describes how segments are sorted.
    """
    # segments are the (sorted) keys of the internal track dictionary
    yield from self._tracks
|
| 261 |
+
|
| 262 |
+
def itertracks(
    self, yield_label: bool = False
) -> Iterator[Union[Tuple[Segment, TrackName], Tuple[Segment, TrackName, Label]]]:
    """Iterate over tracks (in chronological order)

    Parameters
    ----------
    yield_label : bool, optional
        When True, yield (segment, track, label) tuples, such that
        annotation[segment, track] == label. Defaults to yielding
        (segment, track) tuples.

    Examples
    --------
    >>> for segment, track in annotation.itertracks():
    ...     # do something with the track

    >>> for segment, track, label in annotation.itertracks(yield_label=True):
    ...     # do something with the track and its label
    """
    # within a segment, tracks are ordered by (track name, label), as strings
    by_name_then_label = lambda item: (str(item[0]), str(item[1]))
    for segment, tracks in self._tracks.items():
        for name, label in sorted(tracks.items(), key=by_name_then_label):
            yield (segment, name, label) if yield_label else (segment, name)
|
| 292 |
+
|
| 293 |
+
def _updateTimeline(self):
    """Rebuild the cached global timeline from the current segments."""
    self._timeline = Timeline(segments=self._tracks, uri=self.uri)
    self._timelineNeedsUpdate = False
|
| 296 |
+
|
| 297 |
+
def get_timeline(self, copy: bool = True) -> Timeline:
    """Get timeline made of all annotated segments

    Parameters
    ----------
    copy : bool, optional
        Defaults (True) to returning a copy of the internal timeline.
        Set to False to return the actual internal timeline (faster).

    Returns
    -------
    timeline : Timeline
        Timeline made of all annotated segments.

    Note
    ----
    In case copy is set to False, be careful **not** to modify the returned
    timeline, as it may lead to weird subsequent behavior of the annotation
    instance.
    """
    # lazily rebuild the cached timeline on first access after a change
    if self._timelineNeedsUpdate:
        self._updateTimeline()
    return self._timeline.copy() if copy else self._timeline
|
| 323 |
+
|
| 324 |
+
def __eq__(self, other: "Annotation"):
    """Equality

    >>> annotation == other

    Two annotations are equal if and only if their tracks and associated
    labels are equal.
    """
    # zip_longest pads with None, so differing lengths compare unequal
    zipped = itertools.zip_longest(
        self.itertracks(yield_label=True), other.itertracks(yield_label=True)
    )
    return all(mine == theirs for mine, theirs in zipped)
|
| 336 |
+
|
| 337 |
+
def __ne__(self, other: "Annotation"):
    """Inequality"""
    # zip_longest pads with None, so differing lengths compare unequal
    zipped = itertools.zip_longest(
        self.itertracks(yield_label=True), other.itertracks(yield_label=True)
    )
    return any(mine != theirs for mine, theirs in zipped)
|
| 344 |
+
|
| 345 |
+
def __contains__(self, included: Union[Segment, Timeline]):
    """Inclusion

    Check whether every segment of `included` does exist in annotation.

    Parameters
    ----------
    included : Segment or Timeline
        Segment or timeline being checked for inclusion

    Returns
    -------
    contains : bool
        True if every segment in `included` exists in timeline,
        False otherwise
    """
    # delegate inclusion test to the global timeline (no copy needed)
    return included in self.get_timeline(copy=False)
|
| 363 |
+
|
| 364 |
+
def _iter_rttm(self) -> Iterator[Text]:
    """Generate lines for an RTTM file for this annotation

    Returns
    -------
    iterator: Iterator[str]
        An iterator over RTTM text lines

    Raises
    ------
    ValueError
        If the URI or any label contains a space (RTTM is space-separated).
    """
    uri = self.uri if self.uri else "<NA>"
    if isinstance(uri, Text) and " " in uri:
        raise ValueError(
            f"Space-separated RTTM file format does not allow file URIs "
            f'containing spaces (got: "{uri}").'
        )
    for segment, _, label in self.itertracks(yield_label=True):
        if isinstance(label, Text) and " " in label:
            raise ValueError(
                f"Space-separated RTTM file format does not allow labels "
                f'containing spaces (got: "{label}").'
            )
        # one SPEAKER record per track; unused RTTM fields are "<NA>"
        yield (
            f"SPEAKER {uri} 1 {segment.start:.3f} {segment.duration:.3f} "
            f"<NA> <NA> {label} <NA> <NA>\n"
        )
|
| 390 |
+
|
| 391 |
+
def to_rttm(self) -> Text:
    """Serialize annotation as a string using RTTM format

    Returns
    -------
    serialized: str
        RTTM string

    Raises
    ------
    ValueError
        Propagated from `_iter_rttm` when the URI or a label contains a space.
    """
    # join the generator directly: no need to materialize an intermediate list
    return "".join(self._iter_rttm())
|
| 400 |
+
|
| 401 |
+
def write_rttm(self, file: TextIO):
    """Dump annotation to file using RTTM format

    Parameters
    ----------
    file : file object
        Writable text file.

    Usage
    -----
    >>> with open('file.rttm', 'w') as file:
    ...    annotation.write_rttm(file)
    """
    # writelines consumes the line generator lazily
    file.writelines(self._iter_rttm())
|
| 415 |
+
|
| 416 |
+
def _iter_lab(self) -> Iterator[Text]:
    """Generate lines for a LAB file for this annotation

    Returns
    -------
    iterator: Iterator[str]
        An iterator over LAB text lines

    Raises
    ------
    ValueError
        If a label contains a space (LAB is space-separated).
    """
    for segment, _, label in self.itertracks(yield_label=True):
        if isinstance(label, Text) and " " in label:
            raise ValueError(
                f"Space-separated LAB file format does not allow labels "
                f'containing spaces (got: "{label}").'
            )
        # "start end label" with millisecond precision
        yield f"{segment.start:.3f} {segment.start + segment.duration:.3f} {label}\n"
|
| 432 |
+
|
| 433 |
+
def to_lab(self) -> Text:
    """Serialize annotation as a string using LAB format

    Returns
    -------
    serialized: str
        LAB string

    Raises
    ------
    ValueError
        Propagated from `_iter_lab` when a label contains a space.
    """
    # join the generator directly: no need to materialize an intermediate list
    return "".join(self._iter_lab())
|
| 442 |
+
|
| 443 |
+
def write_lab(self, file: TextIO):
    """Dump annotation to file using LAB format

    Parameters
    ----------
    file : file object
        Writable text file.

    Usage
    -----
    >>> with open('file.lab', 'w') as file:
    ...    annotation.write_lab(file)
    """
    # writelines consumes the line generator lazily
    file.writelines(self._iter_lab())
|
| 457 |
+
|
| 458 |
+
def crop(self, support: Support, mode: CropMode = "intersection") -> "Annotation":
    """Crop annotation to new support

    Parameters
    ----------
    support : Segment or Timeline
        If `support` is a `Timeline`, its support is used.
    mode : {'strict', 'loose', 'intersection'}, optional
        Controls how segments that are not fully included in `support` are
        handled. 'strict' mode only keeps fully included segments. 'loose'
        mode keeps any intersecting segment. 'intersection' mode keeps any
        intersecting segment but replace them by their actual intersection.

    Returns
    -------
    cropped : Annotation
        Cropped annotation

    Raises
    ------
    NotImplementedError
        If `mode` is not one of 'strict', 'loose', 'intersection'.

    Note
    ----
    In 'intersection' mode, the best is done to keep the track names
    unchanged. However, in some cases where two original segments are
    cropped into the same resulting segments, conflicting track names are
    modified to make sure no track is lost.
    """

    # TODO speed things up by working directly with annotation internals

    if isinstance(support, Segment):
        # normalize a bare Segment into a single-segment Timeline and recurse
        support = Timeline(segments=[support], uri=self.uri)
        return self.crop(support, mode=mode)

    elif isinstance(support, Timeline):

        # if 'support' is a `Timeline`, we use its support
        support = support.support()
        cropped = self.__class__(uri=self.uri, modality=self.modality)

        # 'loose' and 'strict' only differ by one filter, so they share
        # a single implementation (previously two near-duplicate branches)
        if mode in ("loose", "strict"):

            _tracks = {}
            _labels = set()

            for segment, other_segment in self.get_timeline(copy=False).co_iter(
                support
            ):
                # 'strict' additionally requires full inclusion in support
                if mode == "strict" and segment not in other_segment:
                    continue

                tracks = dict(self._tracks[segment])
                _tracks[segment] = tracks
                _labels.update(tracks.values())

            cropped._tracks = SortedDict(_tracks)

            cropped._labelNeedsUpdate = {label: True for label in _labels}
            cropped._labels = {label: None for label in _labels}

            cropped._timelineNeedsUpdate = True
            cropped._timeline = None

            return cropped

        elif mode == "intersection":

            for segment, other_segment in self.get_timeline(copy=False).co_iter(
                support
            ):

                intersection = segment & other_segment
                for track, label in self._tracks[segment].items():
                    # avoid track-name collisions when two segments
                    # crop down to the same intersection
                    track = cropped.new_track(intersection, candidate=track)
                    cropped[intersection, track] = label

            return cropped

        else:
            raise NotImplementedError("unsupported mode: '%s'" % mode)
|
| 558 |
+
|
| 559 |
+
def extrude(
    self, removed: Support, mode: CropMode = "intersection"
) -> "Annotation":
    """Remove segments that overlap `removed` support.

    A simple illustration:

    annotation
    A |------|     |------|      |----|
    B  |----------|

    removed `Timeline`
        |-------|                 |-----------|

    With mode="intersection", only the parts outside `removed` survive;
    with mode="loose", any intersecting segment is dropped entirely;
    with mode="strict", only segments fully inside `removed` are dropped.

    Parameters
    ----------
    removed : Segment or Timeline
        If `support` is a `Timeline`, its support is used.
    mode : {'strict', 'loose', 'intersection'}, optional
        Controls how segments that are not fully included in `removed` are
        handled. 'strict' mode only removes fully included segments. 'loose'
        mode removes any intersecting segment. 'intersection' mode removes
        the overlapping part of any intersecting segment.

    Returns
    -------
    extruded : Annotation
        Extruded annotation

    Note
    ----
    In 'intersection' mode, the best is done to keep the track names
    unchanged. However, in some cases where two original segments are
    cropped into the same resulting segments, conflicting track names are
    modified to make sure no track is lost.
    """
    if isinstance(removed, Segment):
        removed = Timeline([removed])

    # keep everything that is NOT removed: crop to the complement of
    # `removed` within the annotation's extent
    extent_tl = Timeline([self.get_timeline().extent()], uri=self.uri)
    truncating_support = removed.gaps(support=extent_tl)

    # loose for truncate means strict for crop and vice-versa
    mode = {"loose": "strict", "strict": "loose"}.get(mode, mode)
    return self.crop(truncating_support, mode=mode)
|
| 620 |
+
|
| 621 |
+
def get_overlap(self, labels: Optional[Iterable[Label]] = None) -> "Timeline":
    """Get overlapping parts of the annotation.

    A simple illustration:

    annotation
    A |------|    |------|      |----|
    B  |--|    |-----|      |----------|

    annotation.get_overlap()
       |--|    |--|          |----|

    Parameters
    ----------
    labels : optional list of labels
        Labels for which to consider the overlap

    Returns
    -------
    overlap : `pyannote.core.Timeline`
        Timeline of the overlaps.
    """
    # optionally restrict to a subset of labels
    annotation = self.subset(labels) if labels else self

    overlaps_tl = Timeline(uri=annotation.uri)
    for (s1, t1), (s2, t2) in annotation.co_iter(annotation):
        # if labels are the same for the two segments, skipping
        if self[s1, t1] == self[s2, t2]:
            continue
        overlaps_tl.add(s1 & s2)
    return overlaps_tl.support()
|
| 659 |
+
|
| 660 |
+
def get_tracks(self, segment: Segment) -> Set[TrackName]:
    """Query tracks by segment

    Parameters
    ----------
    segment : Segment
        Query

    Returns
    -------
    tracks : set
        Set of tracks

    Note
    ----
    This will return an empty set if segment does not exist.
    """
    # iterating a dict yields its keys; missing segment -> empty dict
    return set(self._tracks.get(segment, {}))
|
| 678 |
+
|
| 679 |
+
def has_track(self, segment: Segment, track: TrackName) -> bool:
    """Check whether a given track exists

    Parameters
    ----------
    segment : Segment
        Query segment
    track :
        Query track

    Returns
    -------
    exists : bool
        True if track exists for segment
    """
    tracks = self._tracks.get(segment)
    return tracks is not None and track in tracks
|
| 695 |
+
|
| 696 |
+
def copy(self) -> "Annotation":
    """Get a copy of the annotation

    Returns
    -------
    annotation : Annotation
        Copy of the annotation
    """

    # fresh empty annotation sharing uri/modality
    duplicate = self.__class__(uri=self.uri, modality=self.modality)

    # shallow-copy each per-segment track dictionary,
    # collecting every label on the way
    pairs = []
    seen_labels = set()
    for segment, tracks in self._tracks.items():
        seen_labels.update(tracks.values())
        pairs.append((segment, dict(tracks)))

    duplicate._tracks = SortedDict(pairs)

    # label timelines are rebuilt lazily on first access
    duplicate._labels = dict.fromkeys(seen_labels)
    duplicate._labelNeedsUpdate = dict.fromkeys(seen_labels, True)

    duplicate._timeline = None
    duplicate._timelineNeedsUpdate = True

    return duplicate
|
| 723 |
+
|
| 724 |
+
def new_track(
    self,
    segment: Segment,
    candidate: Optional[TrackName] = None,
    prefix: Optional[str] = None,
) -> TrackName:
    """Generate a new track name for given segment

    Ensures that the returned track name does not already
    exist for the given segment.

    Parameters
    ----------
    segment : Segment
        Segment for which a new track name is generated.
    candidate : any valid track name, optional
        When provided, try this candidate name first.
    prefix : str, optional
        Track name prefix. Defaults to the empty string ''.

    Returns
    -------
    name : str
        New track name
    """

    # names already taken for this segment (empty if segment unknown)
    taken = set(self._tracks.get(segment, {}))

    # honour the caller's suggestion when it is still free
    if candidate is not None and candidate not in taken:
        return candidate

    # otherwise probe '<prefix>0', '<prefix>1', ... until a free name appears
    prefix = prefix or ""
    for count in itertools.count():
        name = "%s%d" % (prefix, count)
        if name not in taken:
            return name
|
| 773 |
+
|
| 774 |
+
def __str__(self):
    """Human-friendly representation"""
    # TODO: use pandas.DataFrame
    lines = [
        "%s %s %s" % (segment, track, label)
        for segment, track, label in self.itertracks(yield_label=True)
    ]
    return "\n".join(lines)
|
| 780 |
+
|
| 781 |
+
def __delitem__(self, key: Key):
    """Delete one track

    >>> del annotation[segment, track]

    Delete all tracks of a segment

    >>> del annotation[segment]
    """

    # del annotation[segment]
    if isinstance(key, Segment):

        # raises KeyError if the segment does not exist
        removed_tracks = self._tracks.pop(key)

        # global timeline lost a segment
        self._timelineNeedsUpdate = True

        # every label carried by the removed tracks is now stale
        for label in removed_tracks.values():
            self._labelNeedsUpdate[label] = True

    # del annotation[segment, track]
    elif isinstance(key, tuple) and len(key) == 2:

        segment, track = key

        # raises KeyError if the segment does not exist
        segment_tracks = self._tracks[segment]

        # raises KeyError if the track does not exist
        label = segment_tracks.pop(track)

        # the removed track's label is now stale
        self._labelNeedsUpdate[label] = True

        # drop the segment entirely once its last track is gone
        if not segment_tracks:
            self._tracks.pop(segment)
            self._timelineNeedsUpdate = True

    else:
        raise NotImplementedError(
            "Deletion only works with Segment or (Segment, track) keys."
        )
|
| 832 |
+
|
| 833 |
+
# label = annotation[segment, track]
def __getitem__(self, key: Key) -> Label:
    """Get track label

    >>> label = annotation[segment, track]

    Note
    ----
    ``annotation[segment]`` is equivalent to ``annotation[segment, '_']``

    """

    # bare segment: default to the '_' track
    if isinstance(key, Segment):
        key = (key, "_")

    segment, track = key
    return self._tracks[segment][track]
|
| 849 |
+
|
| 850 |
+
# annotation[segment, track] = label
def __setitem__(self, key: Key, label: Label):
    """Add new or update existing track

    >>> annotation[segment, track] = label

    If (segment, track) does not exist, it is added.
    If (segment, track) already exists, it is updated.

    Note
    ----
    ``annotation[segment] = label`` is equivalent to ``annotation[segment, '_'] = label``

    Note
    ----
    If `segment` is empty, it does nothing.
    """

    # bare segment: default to the '_' track
    if isinstance(key, Segment):
        key = (key, "_")

    segment, track = key

    # silently ignore empty segments
    if not segment:
        return

    # brand new segment: the global timeline is now stale
    if segment not in self._tracks:
        self._tracks[segment] = {}
        self._timelineNeedsUpdate = True

    tracks = self._tracks[segment]

    # overwriting an existing track: its previous label becomes stale
    if track in tracks:
        self._labelNeedsUpdate[tracks[track]] = True

    # store the track and mark the new label as stale
    tracks[track] = label
    self._labelNeedsUpdate[label] = True
|
| 892 |
+
|
| 893 |
+
def empty(self) -> "Annotation":
    """Return an empty copy

    Returns
    -------
    empty : Annotation
        Empty annotation using the same 'uri' and 'modality' attributes.
    """
    # instantiate the same (possibly derived) class with shared metadata
    return self.__class__(uri=self.uri, modality=self.modality)
|
| 903 |
+
|
| 904 |
+
def labels(self) -> List[Label]:
    """Get sorted list of labels

    Returns
    -------
    labels : list
        Sorted list of labels
    """
    # any() over the dict values directly: no intermediate list needed
    if any(self._labelNeedsUpdate.values()):
        self._updateLabels()
    # sort by string representation so heterogeneous labels compare safely
    return sorted(self._labels, key=str)
|
| 915 |
+
|
| 916 |
+
def get_labels(
    self, segment: Segment, unique: bool = True
) -> Union[Set[Label], List[Label]]:
    """Query labels by segment

    Parameters
    ----------
    segment : Segment
        Query
    unique : bool, optional
        When False, return the list of (possibly repeated) labels.
        Defaults to returning the set of labels.

    Returns
    -------
    labels : set or list
        Set (resp. list) of labels for `segment` if it exists,
        empty set (resp. list) otherwise.

    Examples
    --------
    >>> annotation = Annotation()
    >>> segment = Segment(0, 2)
    >>> annotation[segment, 'speaker1'] = 'Bernard'
    >>> annotation[segment, 'speaker2'] = 'John'
    >>> print sorted(annotation.get_labels(segment))
    set(['Bernard', 'John'])
    >>> print annotation.get_labels(Segment(1, 2))
    set([])
    """
    # missing segment -> empty dict -> empty result
    found = self._tracks.get(segment, {}).values()
    return set(found) if unique else list(found)
|
| 954 |
+
|
| 955 |
+
def subset(self, labels: Iterable[Label], invert: bool = False) -> "Annotation":
    """Filter annotation by labels

    Parameters
    ----------
    labels : iterable
        List of filtered labels
    invert : bool, optional
        If invert is True, extract all but requested labels

    Returns
    -------
    filtered : Annotation
        Filtered annotation
    """

    requested = set(labels)
    existing = set(self.labels())
    # labels to keep: requested ones (or everything else when inverted)
    kept = (existing - requested) if invert else (requested & existing)

    filtered = self.__class__(uri=self.uri, modality=self.modality)

    _tracks = {}
    _labels = set()
    for segment, tracks in self._tracks.items():
        matching = {name: lbl for name, lbl in tracks.items() if lbl in kept}
        if matching:
            _tracks[segment] = matching
            _labels.update(matching.values())

    filtered._tracks = SortedDict(_tracks)

    # per-label timelines are rebuilt lazily
    filtered._labelNeedsUpdate = dict.fromkeys(_labels, True)
    filtered._labels = dict.fromkeys(_labels)

    filtered._timelineNeedsUpdate = True
    filtered._timeline = None

    return filtered
|
| 998 |
+
|
| 999 |
+
def update(self, annotation: "Annotation", copy: bool = False) -> "Annotation":
    """Add every track of an existing annotation (in place)

    Parameters
    ----------
    annotation : Annotation
        Annotation whose tracks are being added
    copy : bool, optional
        Return a copy of the annotation. Defaults to updating the
        annotation in-place.

    Returns
    -------
    self : Annotation
        Updated annotation

    Note
    ----
    Existing tracks are updated with the new label.
    """
    target = self.copy() if copy else self

    # TODO speed things up by working directly with annotation internals
    for segment, track, label in annotation.itertracks(yield_label=True):
        target[segment, track] = label

    return target
|
| 1027 |
+
|
| 1028 |
+
def label_timeline(self, label: Label, copy: bool = True) -> Timeline:
    """Query segments by label

    Parameters
    ----------
    label : object
        Query
    copy : bool, optional
        Defaults (True) to returning a copy of the internal timeline.
        Set to False to return the actual internal timeline (faster).

    Returns
    -------
    timeline : Timeline
        Timeline made of all segments for which at least one track is
        annotated as label

    Note
    ----
    If label does not exist, this will return an empty timeline.

    Note
    ----
    In case copy is set to False, be careful **not** to modify the returned
    timeline, as it may lead to weird subsequent behavior of the annotation
    instance.
    """
    # unknown label: empty timeline
    if label not in self.labels():
        return Timeline(uri=self.uri)

    # lazily rebuild the cached timeline for this label
    if self._labelNeedsUpdate[label]:
        self._updateLabels()

    return self._labels[label].copy() if copy else self._labels[label]
|
| 1066 |
+
|
| 1067 |
+
def label_support(self, label: Label) -> Timeline:
    """Label support

    Equivalent to ``Annotation.label_timeline(label).support()``

    Parameters
    ----------
    label : object
        Query

    Returns
    -------
    support : Timeline
        Label support

    See also
    --------
    :func:`~pyannote.core.Annotation.label_timeline`
    :func:`~pyannote.core.Timeline.support`
    """
    # no copy needed: support() builds a new timeline anyway
    return self.label_timeline(label, copy=False).support()
|
| 1089 |
+
|
| 1090 |
+
def label_duration(self, label: Label) -> float:
    """Label duration

    Equivalent to ``Annotation.label_timeline(label).duration()``

    Parameters
    ----------
    label : object
        Query

    Returns
    -------
    duration : float
        Duration, in seconds.

    See also
    --------
    :func:`~pyannote.core.Annotation.label_timeline`
    :func:`~pyannote.core.Timeline.duration`
    """
    # no copy needed: duration() does not mutate the timeline
    return self.label_timeline(label, copy=False).duration()
|
| 1113 |
+
|
| 1114 |
+
def chart(self, percent: bool = False) -> List[Tuple[Label, float]]:
    """Get labels chart (from longest to shortest duration)

    Parameters
    ----------
    percent : bool, optional
        Return list of (label, percentage) tuples.
        Defaults to returning list of (label, duration) tuples.

    Returns
    -------
    chart : list
        List of (label, duration), sorted by duration in decreasing order.
    """
    ranked = sorted(
        ((label, self.label_duration(label)) for label in self.labels()),
        key=lambda item: item[1],
        reverse=True,
    )

    if percent:
        # normalize durations to fractions of the total
        total = np.sum([duration for _, duration in ranked])
        ranked = [(label, duration / total) for label, duration in ranked]

    return ranked
|
| 1140 |
+
|
| 1141 |
+
def argmax(self, support: Optional[Support] = None) -> Optional[Label]:
    """Get label with longest duration

    Parameters
    ----------
    support : Segment or Timeline, optional
        Find label with longest duration within provided support.
        Defaults to whole extent.

    Returns
    -------
    label : any existing label or None
        Label with longest intersection

    Examples
    --------
    >>> annotation = Annotation(modality='speaker')
    >>> annotation[Segment(0, 10), 'speaker1'] = 'Alice'
    >>> annotation[Segment(8, 20), 'speaker1'] = 'Bob'
    >>> print("%s is such a talker!" % annotation.argmax())
    Bob is such a talker!
    >>> segment = Segment(22, 23)
    >>> if not annotation.argmax(segment):
    ...     print("No label intersecting %s" % segment)
    No label intersecting [ 00:00:22.000 -->  00:00:23.000]

    """

    cropped = self
    if support is not None:
        # keep only the parts of each track intersecting the support
        cropped = cropped.crop(support, mode="intersection")

    # no track left: there is no label to return
    if not cropped:
        return None

    # label whose cumulative duration (within the support) is largest
    return max(cropped.labels(), key=cropped.label_duration)
|
| 1180 |
+
|
| 1181 |
+
def rename_tracks(self, generator: LabelGenerator = "string") -> "Annotation":
    """Rename all tracks

    Parameters
    ----------
    generator : 'string', 'int', or iterable, optional
        If 'string' (default) rename tracks to 'A', 'B', 'C', etc.
        If 'int', rename tracks to 0, 1, 2, etc.
        If iterable, use it to generate track names.

    Returns
    -------
    renamed : Annotation
        Copy of the original annotation where tracks are renamed.

    Example
    -------
    >>> annotation = Annotation()
    >>> annotation[Segment(0, 1), 'a'] = 'a'
    >>> annotation[Segment(0, 1), 'b'] = 'b'
    >>> annotation[Segment(1, 2), 'a'] = 'a'
    >>> annotation[Segment(1, 3), 'c'] = 'c'
    >>> print(annotation.rename_tracks(generator='int'))
    [ 00:00:00.000 -->  00:00:01.000] 0 a
    [ 00:00:00.000 -->  00:00:01.000] 1 b
    [ 00:00:01.000 -->  00:00:02.000] 2 a
    [ 00:00:01.000 -->  00:00:03.000] 3 c
    """
    # resolve the track-name source; any other value is assumed to be
    # an iterator of names already
    if generator == "string":
        track_names = string_generator()
    elif generator == "int":
        track_names = int_generator()
    else:
        track_names = generator

    renamed = self.__class__(uri=self.uri, modality=self.modality)

    # TODO speed things up by working directly with annotation internals
    for segment, _, label in self.itertracks(yield_label=True):
        renamed[segment, next(track_names)] = label

    return renamed
|
| 1226 |
+
|
| 1227 |
+
def rename_labels(
    self,
    mapping: Optional[Dict] = None,
    generator: LabelGenerator = "string",
    copy: bool = True,
) -> "Annotation":
    """Rename labels

    Parameters
    ----------
    mapping : dict, optional
        {old_name: new_name} mapping dictionary.
    generator : 'string', 'int' or iterable, optional
        If 'string' (default) rename label to 'A', 'B', 'C', ... If 'int',
        rename to 0, 1, 2, etc. If iterable, use it to generate labels.
    copy : bool, optional
        Set to True to return a copy of the annotation. Set to False to
        update the annotation in-place. Defaults to True.

    Returns
    -------
    renamed : Annotation
        Annotation where labels have been renamed

    Note
    ----
    Unmapped labels are kept unchanged.

    Note
    ----
    Parameter `generator` has no effect when `mapping` is provided.
    """
    if mapping is None:
        if generator == "string":
            generator = string_generator()
        elif generator == "int":
            generator = int_generator()
        # build a one-to-one mapping from the generator
        mapping = {label: next(generator) for label in self.labels()}

    renamed = self.copy() if copy else self

    # invalidate cached per-label timelines for every label involved
    for old_label, new_label in mapping.items():
        renamed._labelNeedsUpdate[old_label] = True
        renamed._labelNeedsUpdate[new_label] = True

    # apply the mapping track by track; labels absent from the mapping
    # are left as they are
    for segment, tracks in self._tracks.items():
        renamed._tracks[segment] = {
            track: mapping.get(label, label) for track, label in tracks.items()
        }

    return renamed
|
| 1282 |
+
|
| 1283 |
+
def relabel_tracks(self, generator: LabelGenerator = "string") -> "Annotation":
    """Relabel tracks

    Create a new annotation where each track has a unique label.

    Parameters
    ----------
    generator : 'string', 'int' or iterable, optional
        If 'string' (default) relabel tracks to 'A', 'B', 'C', ... If 'int'
        relabel to 0, 1, 2, ... If iterable, use it to generate labels.

    Returns
    -------
    renamed : Annotation
        New annotation with relabeled tracks.
    """
    # resolve the label source; other values are used as iterators directly
    if generator == "string":
        label_source = string_generator()
    elif generator == "int":
        label_source = int_generator()
    else:
        label_source = generator

    relabeled = self.empty()
    for segment, track, _ in self.itertracks(yield_label=True):
        # original label is discarded: each track gets a fresh unique one
        relabeled[segment, track] = next(label_source)

    return relabeled
|
| 1310 |
+
|
| 1311 |
+
def support(self, collar: float = 0.0) -> "Annotation":
    """Annotation support

    The support of an annotation is an annotation where contiguous tracks
    with same label are merged into one unique covering track.

    A picture is worth a thousand words::

        collar
        |---|

        annotation
        |--A--| |--A--|     |-B-|
          |-B-|    |--C--|     |----B-----|

        annotation.support(collar)
        |------A------|     |------B------|
          |-B-|    |--C--|

    Parameters
    ----------
    collar : float, optional
        Merge tracks with same label and separated by less than `collar`
        seconds. This is why 'A' tracks are merged in above figure.
        Defaults to 0.

    Returns
    -------
    support : Annotation
        Annotation support

    Note
    ----
    Track names are lost in the process.
    """
    track_names = string_generator()

    # start from an empty annotation sharing uri/modality with the original
    merged = self.empty()

    for label in self.labels():
        # timeline of the current label, with gaps shorter than
        # `collar` filled in
        label_timeline = self.label_timeline(label, copy=True).support(collar)

        # one merged track per contiguous chunk
        for segment in label_timeline.support():
            merged[segment, next(track_names)] = label

    return merged
|
| 1365 |
+
|
| 1366 |
+
def co_iter(
    self, other: "Annotation"
) -> Iterator[Tuple[Tuple[Segment, TrackName], Tuple[Segment, TrackName]]]:
    """Iterate over pairs of intersecting tracks

    Parameters
    ----------
    other : Annotation
        Second annotation

    Returns
    -------
    iterable : (Segment, object), (Segment, object) iterable
        Yields pairs of intersecting tracks, in chronological (then
        alphabetical) order.

    See also
    --------
    :func:`~pyannote.core.Timeline.co_iter`
    """
    own_timeline = self.get_timeline(copy=False)
    other_timeline = other.get_timeline(copy=False)

    for own_segment, other_segment in own_timeline.co_iter(other_timeline):
        # alphabetical order within each segment (sorted on str repr)
        own_tracks = sorted(self.get_tracks(own_segment), key=str)
        other_tracks = sorted(other.get_tracks(other_segment), key=str)
        for own_track in own_tracks:
            for other_track in other_tracks:
                yield (own_segment, own_track), (other_segment, other_track)
|
| 1394 |
+
|
| 1395 |
+
def __mul__(self, other: "Annotation") -> np.ndarray:
    """Cooccurrence (or confusion) matrix

    >>> matrix = annotation * other

    Parameters
    ----------
    other : Annotation
        Second annotation

    Returns
    -------
    cooccurrence : (n_self, n_other) np.ndarray
        Cooccurrence matrix where `n_self` (resp. `n_other`) is the number
        of labels in `self` (resp. `other`).
    """
    if not isinstance(other, Annotation):
        raise TypeError(
            "computing cooccurrence matrix only works with Annotation " "instances."
        )

    # map each label to its row (self) / column (other) index
    row_index = {label: i for i, label in enumerate(self.labels())}
    col_index = {label: j for j, label in enumerate(other.labels())}

    cooccurrence = np.zeros((len(row_index), len(col_index)))

    # accumulate intersection durations over pairs of co-occurring tracks
    for (segment, track), (other_segment, other_track) in self.co_iter(other):
        row = row_index[self[segment, track]]
        col = col_index[other[other_segment, other_track]]
        cooccurrence[row, col] += (segment & other_segment).duration

    return cooccurrence
|
| 1433 |
+
|
| 1434 |
+
def discretize(
    self,
    support: Optional[Segment] = None,
    resolution: Union[float, SlidingWindow] = 0.01,
    labels: Optional[List[Hashable]] = None,
    duration: Optional[float] = None,
):
    """Discretize

    Parameters
    ----------
    support : Segment, optional
        Part of annotation to discretize.
        Defaults to annotation full extent.
    resolution : float or SlidingWindow, optional
        Defaults to 10ms frames.
    labels : list of labels, optional
        Defaults to self.labels()
    duration : float, optional
        Overrides support duration and ensures that the number of
        returned frames is fixed (which might otherwise not be the case
        because of rounding errors).

    Returns
    -------
    discretized : SlidingWindowFeature
        (num_frames, num_labels)-shaped binary features.
    """
    if support is None:
        support = self.get_timeline().extent()
    start_time, end_time = support

    cropped = self.crop(support, mode="intersection")

    if labels is None:
        labels = cropped.labels()

    # build a sliding window anchored at the support start; either reuse
    # the step/duration of the provided window, or use square frames of
    # width `resolution`
    if isinstance(resolution, SlidingWindow):
        frames = SlidingWindow(
            start=start_time, step=resolution.step, duration=resolution.duration
        )
    else:
        frames = SlidingWindow(start=start_time, step=resolution, duration=resolution)

    start_frame = frames.closest_frame(start_time)
    if duration is None:
        num_frames = frames.closest_frame(end_time) - start_frame
    else:
        # fixed number of frames, immune to rounding errors
        num_frames = int(round(duration / frames.step))

    data = np.zeros((num_frames, len(labels)), dtype=np.uint8)
    for k, label in enumerate(labels):
        segments = cropped.label_timeline(label)
        for first, last in frames.crop(segments, mode="center", return_ranges=True):
            # clip ranges to the [0, num_frames) window
            data[max(0, first) : min(last, num_frames), k] += 1
    # overlapping ranges may have incremented twice: binarize in place
    data = np.minimum(data, 1, out=data)

    return SlidingWindowFeature(data, frames, labels=labels)
|
| 1498 |
+
|
| 1499 |
+
@classmethod
def from_records(
    cls,
    records: Iterator[Tuple[Segment, TrackName, Label]],
    uri: Optional[str] = None,
    modality: Optional[str] = None,
) -> "Annotation":
    """Build annotation from (segment, track, label) records

    Parameters
    ----------
    records : iterator of tuples
        (segment, track, label) tuples
    uri : string, optional
        name of annotated resource (e.g. audio or video file)
    modality : string, optional
        name of annotated modality

    Returns
    -------
    annotation : Annotation
        New annotation
    """
    annotation = cls(uri=uri, modality=modality)

    segment_tracks = defaultdict(dict)
    seen_labels = set()
    for segment, track, label in records:
        segment_tracks[segment][track] = label
        seen_labels.add(label)

    # populate internals directly, bypassing per-record bookkeeping;
    # cached label timelines and global timeline are marked stale
    annotation._tracks = SortedDict(segment_tracks)
    annotation._labels = {label: None for label in seen_labels}
    annotation._labelNeedsUpdate = {label: True for label in annotation._labels}
    annotation._timeline = None
    annotation._timelineNeedsUpdate = True

    return annotation
|
| 1536 |
+
|
| 1537 |
+
def _repr_png_(self):
    """IPython notebook support

    See also
    --------
    :mod:`pyannote.core.notebook`
    """
    # imported lazily so that matplotlib stays an optional dependency
    from .notebook import MATPLOTLIB_IS_AVAILABLE, MATPLOTLIB_WARNING

    if MATPLOTLIB_IS_AVAILABLE:
        from .notebook import repr_annotation

        return repr_annotation(self)

    warnings.warn(MATPLOTLIB_WARNING.format(klass=self.__class__.__name__))
    return None
|
ailia-models/code/pyannote_audio_utils/core/feature.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2014-2019 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
"""
|
| 31 |
+
########
|
| 32 |
+
Features
|
| 33 |
+
########
|
| 34 |
+
|
| 35 |
+
See :class:`pyannote_audio_utils.core.SlidingWindowFeature` for the complete reference.
|
| 36 |
+
"""
|
| 37 |
+
import numbers
|
| 38 |
+
import warnings
|
| 39 |
+
from typing import Tuple, Optional, Union, Iterator, List, Text
|
| 40 |
+
|
| 41 |
+
import numpy as np
|
| 42 |
+
|
| 43 |
+
from pyannote_audio_utils.core.utils.types import Alignment
|
| 44 |
+
from .segment import Segment
|
| 45 |
+
from .segment import SlidingWindow
|
| 46 |
+
from .timeline import Timeline
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class SlidingWindowFeature(np.lib.mixins.NDArrayOperatorsMixin):
    """Periodic feature vectors

    Parameters
    ----------
    data : (n_frames, n_features) numpy array
    sliding_window : SlidingWindow
        Temporal position of each frame.
    labels : list, optional
        Textual description of each dimension.
    """

    def __init__(
        self,
        data: np.ndarray,
        sliding_window: SlidingWindow,
        labels: Optional[List[Text]] = None,
    ):
        self.sliding_window: SlidingWindow = sliding_window
        self.data = data
        self.labels = labels
        # iterator state used by __iter__/__next__ (index of last yielded frame)
        self.__i: int = -1

    def __len__(self):
        """Number of feature vectors"""
        return self.data.shape[0]

    @property
    def extent(self):
        # time span covered by all frames
        return self.sliding_window.range_to_segment(0, len(self))

    @property
    def dimension(self):
        """Dimension of feature vectors"""
        return self.data.shape[1]

    def getNumber(self):
        """Deprecated. Use ``len(features)`` instead."""
        warnings.warn("This is deprecated in favor of `__len__`", DeprecationWarning)
        return self.data.shape[0]

    def getDimension(self):
        """Deprecated. Use the ``dimension`` property instead."""
        warnings.warn(
            "This is deprecated in favor of `dimension` property", DeprecationWarning
        )
        return self.dimension

    def getExtent(self):
        """Deprecated. Use the ``extent`` property instead."""
        warnings.warn(
            "This is deprecated in favor of `extent` property", DeprecationWarning
        )
        return self.extent

    def __getitem__(self, i: int) -> np.ndarray:
        """Get ith feature vector"""
        return self.data[i]

    def __iter__(self):
        self.__i = -1
        return self

    def __next__(self) -> Tuple[Segment, np.ndarray]:
        """Yield (frame, feature vector) pairs in chronological order."""
        self.__i += 1
        try:
            return self.sliding_window[self.__i], self.data[self.__i]
        except IndexError:
            # exhausted: translate out-of-bounds access into end of iteration
            raise StopIteration() from None

    def next(self):
        # Python 2 style alias kept for backward compatibility
        return self.__next__()

    def iterfeatures(
        self, window: Optional[bool] = False
    ) -> Iterator[Union[Tuple[np.ndarray, Segment], np.ndarray]]:
        """Feature vector iterator

        Parameters
        ----------
        window : bool, optional
            When True, yield both feature vector and corresponding window.
            Default is to only yield feature vector

        """
        n_samples = self.data.shape[0]
        for i in range(n_samples):
            if window:
                yield self.data[i], self.sliding_window[i]
            else:
                yield self.data[i]

    def crop(
        self,
        focus: Union[Segment, Timeline],
        mode: Alignment = "loose",
        fixed: Optional[float] = None,
        return_data: bool = True,
    ) -> Union[np.ndarray, "SlidingWindowFeature"]:
        """Extract frames

        Parameters
        ----------
        focus : Segment or Timeline
        mode : {'loose', 'strict', 'center'}, optional
            In 'strict' mode, only frames fully included in 'focus' support are
            returned. In 'loose' mode, any intersecting frames are returned. In
            'center' mode, first and last frames are chosen to be the ones
            whose centers are the closest to 'focus' start and end times.
            Defaults to 'loose'.
        fixed : float, optional
            Overrides `Segment` 'focus' duration and ensures that the number of
            returned frames is fixed (which might otherwise not be the case
            because of rounding errors).
        return_data : bool, optional
            Return a numpy array (default). For `Segment` 'focus', setting it
            to False will return a `SlidingWindowFeature` instance.

        Returns
        -------
        data : `numpy.ndarray` or `SlidingWindowFeature`
            Frame features.

        See also
        --------
        SlidingWindow.crop

        """

        if (not return_data) and (not isinstance(focus, Segment)):
            msg = (
                '"focus" must be a "Segment" instance when "return_data"'
                "is set to False."
            )
            raise ValueError(msg)

        if (not return_data) and (fixed is not None):
            msg = '"fixed" cannot be set when "return_data" is set to False.'
            raise ValueError(msg)

        ranges = self.sliding_window.crop(
            focus, mode=mode, fixed=fixed, return_ranges=True
        )

        # total number of samples in features
        n_samples = self.data.shape[0]

        # 1 for vector features (e.g. MFCC in pyannote_audio_utils.audio)
        # 2 for matrix features (e.g. grey-level frames in pyannote_audio_utils.video)
        # 3 for 3rd order tensor (e.g. RBG frames in pyannote_audio_utils.video)
        n_dimensions = len(self.data.shape) - 1

        # clip ranges
        clipped_ranges, repeat_first, repeat_last = [], 0, 0
        for start, end in ranges:
            # count number of requested samples before first sample
            repeat_first += min(end, 0) - min(start, 0)
            # count number of requested samples after last sample
            repeat_last += max(end, n_samples) - max(start, n_samples)
            # if all requested samples are out of bounds, skip
            if end < 0 or start >= n_samples:
                continue
            else:
                # keep track of non-empty clipped ranges
                clipped_ranges += [[max(start, 0), min(end, n_samples)]]

        if clipped_ranges:
            data = np.vstack([self.data[start:end, :] for start, end in clipped_ranges])
        else:
            # if all ranges are out of bounds, just return empty data
            shape = (0,) + self.data.shape[1:]
            data = np.empty(shape)

        # corner case when "fixed" duration cropping is requested:
        # correct number of samples even with out-of-bounds indices
        if fixed is not None:
            data = np.vstack(
                [
                    # repeat first sample as many times as needed
                    np.tile(self.data[0], (repeat_first,) + (1,) * n_dimensions),
                    data,
                    # repeat last sample as many times as needed
                    np.tile(
                        self.data[n_samples - 1], (repeat_last,) + (1,) * n_dimensions
                    ),
                ]
            )

        # return data
        if return_data:
            return data

        # wrap data in a SlidingWindowFeature and return
        # NOTE(review): assumes at least one in-bounds range here —
        # clipped_ranges[0] would raise IndexError on a fully
        # out-of-bounds focus; confirm callers never hit this path.
        sliding_window = SlidingWindow(
            start=self.sliding_window[clipped_ranges[0][0]].start,
            duration=self.sliding_window.duration,
            step=self.sliding_window.step,
        )

        return SlidingWindowFeature(data, sliding_window, labels=self.labels)

    def _repr_png_(self):
        """IPython notebook support (rendered via optional matplotlib)."""
        from .notebook import MATPLOTLIB_IS_AVAILABLE, MATPLOTLIB_WARNING

        if not MATPLOTLIB_IS_AVAILABLE:
            warnings.warn(MATPLOTLIB_WARNING.format(klass=self.__class__.__name__))
            return None

        from .notebook import repr_feature

        return repr_feature(self)

    _HANDLED_TYPES = (np.ndarray, numbers.Number)

    def __array__(self) -> np.ndarray:
        return self.data

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Apply numpy ufuncs to the underlying data, re-wrapping results."""
        out = kwargs.get("out", ())
        for x in inputs + out:
            # Only support operations with instances of _HANDLED_TYPES.
            # Use SlidingWindowFeature instead of type(self) for isinstance to
            # allow subclasses that don't override __array_ufunc__ to
            # handle SlidingWindowFeature objects.
            if not isinstance(x, self._HANDLED_TYPES + (SlidingWindowFeature,)):
                return NotImplemented

        # Defer to the implementation of the ufunc on unwrapped values.
        inputs = tuple(
            x.data if isinstance(x, SlidingWindowFeature) else x for x in inputs
        )
        if out:
            kwargs["out"] = tuple(
                x.data if isinstance(x, SlidingWindowFeature) else x for x in out
            )
        data = getattr(ufunc, method)(*inputs, **kwargs)

        if type(data) is tuple:
            # multiple return values
            return tuple(
                type(self)(x, self.sliding_window, labels=self.labels) for x in data
            )
        elif method == "at":
            # no return value
            return None
        else:
            # one return value
            return type(self)(data, self.sliding_window, labels=self.labels)

    def align(self, to: "SlidingWindowFeature") -> "SlidingWindowFeature":
        """Align features by linear temporal interpolation

        Parameters
        ----------
        to : SlidingWindowFeature
            Features to align with.

        Returns
        -------
        aligned : SlidingWindowFeature
            Aligned features
        """

        old_start = self.sliding_window.start
        old_step = self.sliding_window.step
        old_duration = self.sliding_window.duration
        old_samples = len(self)
        # center time of each existing frame
        old_t = old_start + 0.5 * old_duration + np.arange(old_samples) * old_step

        new_start = to.sliding_window.start
        new_step = to.sliding_window.step
        new_duration = to.sliding_window.duration
        new_samples = len(to)
        # center time of each target frame
        new_t = new_start + 0.5 * new_duration + np.arange(new_samples) * new_step

        # interpolate each feature dimension independently
        new_data = np.hstack(
            [
                np.interp(new_t, old_t, old_data)[:, np.newaxis]
                for old_data in self.data.T
            ]
        )
        return SlidingWindowFeature(new_data, to.sliding_window, labels=self.labels)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
if __name__ == "__main__":
    # run this module's doctests when executed as a script
    import doctest

    doctest.testmod()
|
ailia-models/code/pyannote_audio_utils/core/notebook.py
ADDED
|
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2014-2019 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
|
| 29 |
+
"""
|
| 30 |
+
#############
|
| 31 |
+
Visualization
|
| 32 |
+
#############
|
| 33 |
+
|
| 34 |
+
:class:`pyannote.core.Segment`, :class:`pyannote.core.Timeline`,
|
| 35 |
+
:class:`pyannote.core.Annotation` and :class:`pyannote.core.SlidingWindowFeature`
|
| 36 |
+
instances can be directly visualized in notebooks.
|
| 37 |
+
|
| 38 |
+
You will however need to install ``pyannote.core``'s additional dependencies
|
| 39 |
+
for notebook representations (namely, matplotlib):
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
.. code-block:: bash
|
| 43 |
+
|
| 44 |
+
pip install pyannote.core[notebook]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
Segments
|
| 48 |
+
--------
|
| 49 |
+
|
| 50 |
+
.. code-block:: ipython
|
| 51 |
+
|
| 52 |
+
In [1]: from pyannote.core import Segment
|
| 53 |
+
|
| 54 |
+
In [2]: segment = Segment(start=5, end=15)
|
| 55 |
+
....: segment
|
| 56 |
+
|
| 57 |
+
.. plot:: pyplots/segment.py
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
Timelines
|
| 61 |
+
---------
|
| 62 |
+
|
| 63 |
+
.. code-block:: ipython
|
| 64 |
+
|
| 65 |
+
In [25]: from pyannote.core import Timeline, Segment
|
| 66 |
+
|
| 67 |
+
In [26]: timeline = Timeline()
|
| 68 |
+
....: timeline.add(Segment(1, 5))
|
| 69 |
+
....: timeline.add(Segment(6, 8))
|
| 70 |
+
....: timeline.add(Segment(12, 18))
|
| 71 |
+
....: timeline.add(Segment(7, 20))
|
| 72 |
+
....: timeline
|
| 73 |
+
|
| 74 |
+
.. plot:: pyplots/timeline.py
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
Annotations
|
| 78 |
+
-----------
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
.. code-block:: ipython
|
| 82 |
+
|
| 83 |
+
In [1]: from pyannote.core import Annotation, Segment
|
| 84 |
+
|
| 85 |
+
In [6]: annotation = Annotation()
|
| 86 |
+
...: annotation[Segment(1, 5)] = 'Carol'
|
| 87 |
+
...: annotation[Segment(6, 8)] = 'Bob'
|
| 88 |
+
...: annotation[Segment(12, 18)] = 'Carol'
|
| 89 |
+
...: annotation[Segment(7, 20)] = 'Alice'
|
| 90 |
+
...: annotation
|
| 91 |
+
|
| 92 |
+
.. plot:: pyplots/annotation.py
|
| 93 |
+
|
| 94 |
+
"""
|
| 95 |
+
from typing import Iterable, Dict, Optional
|
| 96 |
+
|
| 97 |
+
from .utils.types import Label, LabelStyle, Resource
|
| 98 |
+
|
| 99 |
+
# try:
|
| 100 |
+
# from IPython.core.pylabtools import print_figure
|
| 101 |
+
# except Exception as e:
|
| 102 |
+
# pass
|
| 103 |
+
import numpy as np
|
| 104 |
+
from itertools import cycle, product, groupby
|
| 105 |
+
from .segment import Segment, SlidingWindow
|
| 106 |
+
from .timeline import Timeline
|
| 107 |
+
from .annotation import Annotation
|
| 108 |
+
from .feature import SlidingWindowFeature
|
| 109 |
+
|
| 110 |
+
try:
|
| 111 |
+
import matplotlib
|
| 112 |
+
except ImportError:
|
| 113 |
+
MATPLOTLIB_IS_AVAILABLE = False
|
| 114 |
+
else:
|
| 115 |
+
MATPLOTLIB_IS_AVAILABLE = True
|
| 116 |
+
|
| 117 |
+
# User-facing warning shown when matplotlib cannot be imported;
# ``{klass}`` is substituted with the class being rendered at use site.
# (Fixes two typos in the original message: "pyannore.core" -> "pyannote.core"
# in the install hint, and "vizualization" -> "visualization".)
MATPLOTLIB_WARNING = (
    "Couldn't import matplotlib to render the visualization "
    "for object {klass}. To enable, install the required dependencies "
    "with 'pip install pyannote.core[notebook]'"
)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class Notebook:
    """Stateful helper that renders pyannote objects with matplotlib.

    Keeps a per-label line style (so the same label is drawn consistently
    across calls) plus two display settings:

    - ``crop``: temporal extent (a `Segment`) used as the x-axis limits;
      ``None`` means "infer from the first object plotted".
    - ``width``: figure width used by the module-level ``repr_*`` helpers.
    """

    def __init__(self):
        # delegate all state initialization to reset() so callers can
        # re-initialize an existing instance the same way
        self.reset()

    def reset(self):
        """(Re)initialize label styles and display settings."""
        from matplotlib.cm import get_cmap

        # candidate line widths / styles; combined with colors below to
        # give each label a visually distinct (style, width, color) triple
        linewidth = [3, 1]
        linestyle = ["solid", "dashed", "dotted"]

        cm = get_cmap("Set1")
        colors = [cm(1.0 * i / 8) for i in range(9)]

        # infinite generator of styles; each new label consumes the next one
        self._style_generator = cycle(product(linestyle, linewidth, colors))
        # label None (unlabeled segments) is always thin solid black
        self._style: Dict[Optional[Label], LabelStyle] = {
            None: ("solid", 1, (0.0, 0.0, 0.0))
        }
        # `del` triggers the property deleters below, which reset the
        # backing attributes to their defaults (None and 20)
        del self.crop
        del self.width

    @property
    def crop(self):
        """Temporal extent used as x-axis limits (None = infer)."""
        return self._crop

    @crop.setter
    def crop(self, segment: Segment):
        self._crop = segment

    @crop.deleter
    def crop(self):
        # default: no fixed extent; inferred from first plotted object
        self._crop = None

    @property
    def width(self):
        """Figure width (in inches) used by the repr_* helpers."""
        return self._width

    @width.setter
    def width(self, value: int):
        self._width = value

    @width.deleter
    def width(self):
        # default figure width
        self._width = 20

    def __getitem__(self, label: Label) -> LabelStyle:
        """Get line style for a given label (assigning a new one on first use)."""
        if label not in self._style:
            self._style[label] = next(self._style_generator)
        return self._style[label]

    def setup(self, ax=None, ylim=(0, 1), yaxis=False, time=True):
        """Prepare an axes: x-limits from self.crop, optional time label.

        Returns the (possibly newly obtained) axes.
        """
        import matplotlib.pyplot as plt

        if ax is None:
            ax = plt.gca()
        ax.set_xlim(self.crop)
        if time:
            ax.set_xlabel("Time")
        else:
            # hide tick labels when the time axis is not the bottom row
            ax.set_xticklabels([])
        ax.set_ylim(ylim)
        ax.axes.get_yaxis().set_visible(yaxis)
        return ax

    def draw_segment(self, ax, segment: Segment, y, label=None, boundaries=True):
        """Draw one segment as a horizontal line at height `y`.

        `boundaries` adds small vertical ticks at start and end.
        """

        # do nothing if segment is empty
        if not segment:
            return

        # style is looked up (and lazily assigned) per label
        linestyle, linewidth, color = self[label]

        # draw segment
        ax.hlines(
            y,
            segment.start,
            segment.end,
            color,
            linewidth=linewidth,
            linestyle=linestyle,
            label=label,
        )
        if boundaries:
            ax.vlines(
                segment.start, y + 0.05, y - 0.05, color, linewidth=1, linestyle="solid"
            )
            ax.vlines(
                segment.end, y + 0.05, y - 0.05, color, linewidth=1, linestyle="solid"
            )

        if label is None:
            return

    def get_y(self, segments: Iterable[Segment]) -> np.ndarray:
        """Compute a y coordinate per segment so overlapping segments
        land on different horizontal "lines" (greedy first-fit packing).

        Parameters
        ----------
        segments : Iterable
            `Segment` iterable (sorted)

        Returns
        -------
        y : np.array
            y coordinates of each segment

        """

        # up_to stores the largest end time
        # displayed in each line (at the current iteration)
        # (at the beginning, there is only one empty line)
        up_to = [-np.inf]

        # y[k] indicates on which line to display kth segment
        y = []

        for segment in segments:
            # so far, we do not know which line to use
            found = False
            # try each line until we find one that is ok
            for i, u in enumerate(up_to):
                # if segment starts after the previous one
                # on the same line, then we add it to the line
                if segment.start >= u:
                    found = True
                    y.append(i)
                    up_to[i] = segment.end
                    break
            # in case we went out of lines, create a new one
            if not found:
                y.append(len(up_to))
                up_to.append(segment.end)

        # from line numbers to actual y coordinates
        # (lines are spread evenly in (0, 1), top to bottom)
        y = 1.0 - 1.0 / (len(up_to) + 1) * (1 + np.array(y))

        return y

    def __call__(self, resource: Resource, time: bool = True, legend: bool = True):
        """Dispatch to the plot_* method matching the resource type.

        Unknown resource types are silently ignored.
        """

        if isinstance(resource, Segment):
            self.plot_segment(resource, time=time)

        elif isinstance(resource, Timeline):
            self.plot_timeline(resource, time=time)

        elif isinstance(resource, Annotation):
            self.plot_annotation(resource, time=time, legend=legend)

        elif isinstance(resource, SlidingWindowFeature):
            self.plot_feature(resource, time=time)

    def plot_segment(self, segment, ax=None, time=True):
        """Plot a single segment at mid-height."""

        # first plotted object fixes the x-axis extent
        if not self.crop:
            self.crop = segment

        ax = self.setup(ax=ax, time=time)
        self.draw_segment(ax, segment, 0.5)

    def plot_timeline(self, timeline: Timeline, ax=None, time=True):
        """Plot every segment of a timeline, packed to avoid overlap."""

        if not self.crop and timeline:
            self.crop = timeline.extent()

        # keep only segments intersecting the displayed extent
        cropped = timeline.crop(self.crop, mode="loose")

        ax = self.setup(ax=ax, time=time)

        for segment, y in zip(cropped, self.get_y(cropped)):
            self.draw_segment(ax, segment, y)

        # ax.set_aspect(3. / self.crop.duration)

    def plot_annotation(self, annotation: Annotation, ax=None, time=True, legend=True):
        """Plot an annotation with one style per label and an optional legend."""

        if not self.crop:
            self.crop = annotation.get_timeline(copy=False).extent()

        cropped = annotation.crop(self.crop, mode="intersection")
        # NOTE: `labels` is computed but not used below (kept as-is)
        labels = cropped.labels()
        segments = [s for s, _ in cropped.itertracks()]

        ax = self.setup(ax=ax, time=time)

        for (segment, track, label), y in zip(
            cropped.itertracks(yield_label=True), self.get_y(segments)
        ):
            self.draw_segment(ax, segment, y, label=label)

        if legend:
            H, L = ax.get_legend_handles_labels()

            # corner case when no segment is visible
            if not H:
                return

            # this gets exactly one legend handle and one legend label per label
            # (avoids repeated legends for repeated tracks with same label)
            HL = groupby(
                sorted(zip(H, L), key=lambda h_l: h_l[1]), key=lambda h_l: h_l[1]
            )
            H, L = zip(*list((next(h_l)[0], l) for l, h_l in HL))
            ax.legend(
                H,
                L,
                bbox_to_anchor=(0, 1),
                loc=3,
                ncol=5,
                borderaxespad=0.0,
                frameon=False,
            )

    def plot_feature(
        self, feature: SlidingWindowFeature, ax=None, time=True, ylim=None
    ):
        """Plot a 2D feature as one curve per dimension.

        `ylim` defaults to the data range padded by 10% on each side.
        """

        if not self.crop:
            self.crop = feature.getExtent()

        window = feature.sliding_window
        n, dimension = feature.data.shape
        # frame index range (possibly out of bounds) covering the crop
        ((start, stop),) = window.crop(self.crop, mode="loose", return_ranges=True)
        # x-limits computed BEFORE clamping, so the view matches the crop
        xlim = (window[start].middle, window[stop].middle)

        # clamp to actually available frames
        start = max(0, start)
        stop = min(stop, n)
        t = window[0].middle + window.step * np.arange(start, stop)
        data = feature[start:stop]

        if ylim is None:
            # pad the data range by 10% on both sides
            m = np.nanmin(data)
            M = np.nanmax(data)
            ylim = (m - 0.1 * (M - m), M + 0.1 * (M - m))

        ax = self.setup(ax=ax, yaxis=False, ylim=ylim, time=time)
        ax.plot(t, data)
        ax.set_xlim(xlim)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# Module-level singleton shared by the repr_* helpers below.
notebook = Notebook()
|
| 367 |
+
|
| 368 |
+
def repr_segment(segment: Segment):
    """Render `segment` with matplotlib and save it to ./output."""
    import matplotlib.pyplot as plt

    # remember the current default figure size so it can be restored
    saved_figsize = plt.rcParams["figure.figsize"]
    plt.rcParams["figure.figsize"] = (notebook.width, 1)

    fig, ax = plt.subplots()
    notebook.plot_segment(segment, ax=ax)
    # data = print_figure(fig, "png")
    plt.savefig('./output')
    plt.close(fig)

    # restore the caller's default figure size
    plt.rcParams["figure.figsize"] = saved_figsize
    return
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def repr_timeline(timeline: Timeline):
    """Render `timeline` with matplotlib and save it to ./output.

    Fixes over the previous version:
    - removed a stray ``breakpoint()`` left over from debugging, which
      would stop any non-interactive run;
    - replaced ``plt.cla(fig)`` (a TypeError: ``cla()`` takes no figure
      argument, and it would not release the figure anyway) with
      ``plt.close(fig)``, consistent with repr_segment / repr_annotation.
    """
    import matplotlib.pyplot as plt

    # remember the current default figure size so it can be restored
    figsize = plt.rcParams["figure.figsize"]
    plt.rcParams["figure.figsize"] = (notebook.width, 1)
    fig, ax = plt.subplots()
    notebook.plot_timeline(timeline, ax=ax)
    # data = print_figure(fig, "png")
    plt.savefig('./output')
    plt.close(fig)
    plt.rcParams["figure.figsize"] = figsize
    return
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def repr_annotation(annotation: Annotation):
    """Render `annotation` with matplotlib and save it to ./output."""
    import matplotlib.pyplot as plt

    # remember the current default figure size so it can be restored
    saved_figsize = plt.rcParams["figure.figsize"]
    plt.rcParams["figure.figsize"] = (notebook.width, 2)

    fig, ax = plt.subplots()
    notebook.plot_annotation(annotation, ax=ax)
    # data = print_figure(fig, "png")
    plt.savefig('./output')
    plt.close(fig)

    # restore the caller's default figure size
    plt.rcParams["figure.figsize"] = saved_figsize
    return
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def repr_feature(feature: SlidingWindowFeature):
    """Render `feature` with matplotlib and save it to ./output.

    2D data (frames x dimensions) is drawn as a single plot; 3D data
    (chunks x frames x dimensions) is drawn chunk by chunk, spread over
    several stacked axes so that overlapping chunks do not collide.
    """
    import matplotlib.pyplot as plt

    # remember the current default figure size so it can be restored
    figsize = plt.rcParams["figure.figsize"]

    if feature.data.ndim == 2:

        plt.rcParams["figure.figsize"] = (notebook.width, 2)
        fig, ax = plt.subplots()
        notebook.plot_feature(feature, ax=ax)
        # data = print_figure(fig, "png")
        plt.savefig('./output')
        plt.close(fig)

    elif feature.data.ndim == 3:

        num_chunks = len(feature)

        if notebook.crop is None:
            # no crop set yet: show the full extent of all chunks
            notebook.crop = Segment(
                start=feature.sliding_window.start,
                end=feature.sliding_window[num_chunks - 1].end,
            )
        else:
            # keep only chunks intersecting the displayed extent
            feature = feature.crop(notebook.crop, mode="loose", return_data=False)

        # number of axes needed so that temporally overlapping chunks
        # never share an axis (duration/step consecutive chunks overlap)
        num_overlap = (
            round(feature.sliding_window.duration // feature.sliding_window.step) + 1
        )

        # NOTE(review): `num_chunks` is the pre-crop chunk count even when
        # `feature` was cropped above — presumably an upper bound; confirm.
        num_overlap = min(num_chunks, num_overlap)

        plt.rcParams["figure.figsize"] = (notebook.width, 1.5 * num_overlap)

        fig, axes = plt.subplots(nrows=num_overlap, ncols=1,)
        # shared y-limits across all axes: data range padded by 20%
        mini, maxi = np.nanmin(feature.data), np.nanmax(feature.data)
        ylim = (mini - 0.2 * (maxi - mini), maxi + 0.2 * (maxi - mini))
        for c, (window, data) in enumerate(feature):
            # round-robin over the stacked axes
            ax = axes[c % num_overlap]
            # rebuild a per-frame sliding window for this chunk's data
            step = duration = window.duration / len(data)
            frames = SlidingWindow(start=window.start, step=step, duration=duration)
            window_feature = SlidingWindowFeature(data, frames, labels=feature.labels)
            notebook.plot_feature(
                window_feature,
                ax=ax,
                # only the bottom axis of each column shows the time label
                time=c % num_overlap == (num_overlap - 1),
                ylim=ylim,
            )
            # reset the color cycle so every chunk uses the same colors
            ax.set_prop_cycle(None)
        # data = print_figure(fig, "png")
        plt.savefig('./output')
        plt.close(fig)

    # restore the caller's default figure size
    plt.rcParams["figure.figsize"] = figsize
    return
|
ailia-models/code/pyannote_audio_utils/core/segment.py
ADDED
|
@@ -0,0 +1,910 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2014-2021 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
|
| 29 |
+
"""
|
| 30 |
+
#######
|
| 31 |
+
Segment
|
| 32 |
+
#######
|
| 33 |
+
|
| 34 |
+
.. plot:: pyplots/segment.py
|
| 35 |
+
|
| 36 |
+
:class:`pyannote.core.Segment` instances describe temporal fragments (*e.g.* of an audio file). The segment depicted above can be defined like that:
|
| 37 |
+
|
| 38 |
+
.. code-block:: ipython
|
| 39 |
+
|
| 40 |
+
In [1]: from pyannote.core import Segment
|
| 41 |
+
|
| 42 |
+
In [2]: segment = Segment(start=5, end=15)
|
| 43 |
+
|
| 44 |
+
In [3]: print(segment)
|
| 45 |
+
|
| 46 |
+
It is nothing more than 2-tuples augmented with several useful methods and properties:
|
| 47 |
+
|
| 48 |
+
.. code-block:: ipython
|
| 49 |
+
|
| 50 |
+
In [4]: start, end = segment
|
| 51 |
+
|
| 52 |
+
In [5]: start
|
| 53 |
+
|
| 54 |
+
In [6]: segment.end
|
| 55 |
+
|
| 56 |
+
In [7]: segment.duration # duration (read-only)
|
| 57 |
+
|
| 58 |
+
In [8]: segment.middle # middle (read-only)
|
| 59 |
+
|
| 60 |
+
In [9]: segment & Segment(3, 12) # intersection
|
| 61 |
+
|
| 62 |
+
In [10]: segment | Segment(3, 12) # union
|
| 63 |
+
|
| 64 |
+
In [11]: segment.overlaps(3) # does segment overlap time t=3?
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
Use `Segment.set_precision(ndigits)` to automatically round start and end timestamps to `ndigits` precision after the decimal point.
|
| 68 |
+
To ensure consistency between `Segment` instances, it is recommended to call this method only once, right after importing `pyannote.core.Segment`.
|
| 69 |
+
|
| 70 |
+
.. code-block:: ipython
|
| 71 |
+
|
| 72 |
+
In [12]: Segment(1/1000, 330/1000) == Segment(1/1000, 90/1000+240/1000)
|
| 73 |
+
Out[12]: False
|
| 74 |
+
|
| 75 |
+
In [13]: Segment.set_precision(ndigits=4)
|
| 76 |
+
|
| 77 |
+
In [14]: Segment(1/1000, 330/1000) == Segment(1/1000, 90/1000+240/1000)
|
| 78 |
+
Out[14]: True
|
| 79 |
+
|
| 80 |
+
See :class:`pyannote.core.Segment` for the complete reference.
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
import warnings
|
| 84 |
+
from typing import Union, Optional, Tuple, List, Iterator, Iterable
|
| 85 |
+
|
| 86 |
+
from .utils.types import Alignment
|
| 87 |
+
|
| 88 |
+
import numpy as np
|
| 89 |
+
from dataclasses import dataclass
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# setting 'frozen' to True makes it hashable and immutable
|
| 93 |
+
@dataclass(frozen=True, order=True)
|
| 94 |
+
class Segment:
|
| 95 |
+
"""
|
| 96 |
+
Time interval
|
| 97 |
+
|
| 98 |
+
Parameters
|
| 99 |
+
----------
|
| 100 |
+
start : float
|
| 101 |
+
interval start time, in seconds.
|
| 102 |
+
end : float
|
| 103 |
+
interval end time, in seconds.
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
Segments can be compared and sorted using the standard operators:
|
| 107 |
+
|
| 108 |
+
>>> Segment(0, 1) == Segment(0, 1.)
|
| 109 |
+
True
|
| 110 |
+
>>> Segment(0, 1) != Segment(3, 4)
|
| 111 |
+
True
|
| 112 |
+
>>> Segment(0, 1) < Segment(2, 3)
|
| 113 |
+
True
|
| 114 |
+
>>> Segment(0, 1) < Segment(0, 2)
|
| 115 |
+
True
|
| 116 |
+
>>> Segment(1, 2) < Segment(0, 3)
|
| 117 |
+
False
|
| 118 |
+
|
| 119 |
+
Note
|
| 120 |
+
----
|
| 121 |
+
A segment is smaller than another segment if one of these two conditions is verified:
|
| 122 |
+
|
| 123 |
+
- `segment.start < other_segment.start`
|
| 124 |
+
- `segment.start == other_segment.start` and `segment.end < other_segment.end`
|
| 125 |
+
|
| 126 |
+
"""
|
| 127 |
+
start: float = 0.0
|
| 128 |
+
end: float = 0.0
|
| 129 |
+
|
| 130 |
+
@staticmethod
|
| 131 |
+
def set_precision(ndigits: Optional[int] = None):
|
| 132 |
+
"""Automatically round start and end timestamps to `ndigits` precision after the decimal point
|
| 133 |
+
|
| 134 |
+
To ensure consistency between `Segment` instances, it is recommended to call this method only
|
| 135 |
+
once, right after importing `pyannote.core.Segment`.
|
| 136 |
+
|
| 137 |
+
Usage
|
| 138 |
+
-----
|
| 139 |
+
>>> from pyannote.core import Segment
|
| 140 |
+
>>> Segment.set_precision(2)
|
| 141 |
+
>>> Segment(1/3, 2/3)
|
| 142 |
+
<Segment(0.33, 0.67)>
|
| 143 |
+
"""
|
| 144 |
+
global AUTO_ROUND_TIME
|
| 145 |
+
global SEGMENT_PRECISION
|
| 146 |
+
|
| 147 |
+
if ndigits is None:
|
| 148 |
+
# backward compatibility
|
| 149 |
+
AUTO_ROUND_TIME = False
|
| 150 |
+
# 1 μs (one microsecond)
|
| 151 |
+
SEGMENT_PRECISION = 1e-6
|
| 152 |
+
else:
|
| 153 |
+
AUTO_ROUND_TIME = True
|
| 154 |
+
SEGMENT_PRECISION = 10 ** (-ndigits)
|
| 155 |
+
|
| 156 |
+
def __bool__(self):
|
| 157 |
+
"""Emptiness
|
| 158 |
+
|
| 159 |
+
>>> if segment:
|
| 160 |
+
... # segment is not empty.
|
| 161 |
+
... else:
|
| 162 |
+
... # segment is empty.
|
| 163 |
+
|
| 164 |
+
Note
|
| 165 |
+
----
|
| 166 |
+
A segment is considered empty if its end time is smaller than its
|
| 167 |
+
start time, or its duration is smaller than 1μs.
|
| 168 |
+
"""
|
| 169 |
+
return bool((self.end - self.start) > SEGMENT_PRECISION)
|
| 170 |
+
|
| 171 |
+
def __post_init__(self):
|
| 172 |
+
"""Round start and end up to SEGMENT_PRECISION precision (when required)"""
|
| 173 |
+
if AUTO_ROUND_TIME:
|
| 174 |
+
object.__setattr__(self, 'start', int(self.start / SEGMENT_PRECISION + 0.5) * SEGMENT_PRECISION)
|
| 175 |
+
object.__setattr__(self, 'end', int(self.end / SEGMENT_PRECISION + 0.5) * SEGMENT_PRECISION)
|
| 176 |
+
|
| 177 |
+
@property
|
| 178 |
+
def duration(self) -> float:
|
| 179 |
+
"""Segment duration (read-only)"""
|
| 180 |
+
return self.end - self.start if self else 0.
|
| 181 |
+
|
| 182 |
+
@property
|
| 183 |
+
def middle(self) -> float:
|
| 184 |
+
"""Segment mid-time (read-only)"""
|
| 185 |
+
return .5 * (self.start + self.end)
|
| 186 |
+
|
| 187 |
+
def __iter__(self) -> Iterator[float]:
|
| 188 |
+
"""Unpack segment boundaries
|
| 189 |
+
>>> segment = Segment(start, end)
|
| 190 |
+
>>> start, end = segment
|
| 191 |
+
"""
|
| 192 |
+
yield self.start
|
| 193 |
+
yield self.end
|
| 194 |
+
|
| 195 |
+
def copy(self) -> 'Segment':
|
| 196 |
+
"""Get a copy of the segment
|
| 197 |
+
|
| 198 |
+
Returns
|
| 199 |
+
-------
|
| 200 |
+
copy : Segment
|
| 201 |
+
Copy of the segment.
|
| 202 |
+
"""
|
| 203 |
+
return Segment(start=self.start, end=self.end)
|
| 204 |
+
|
| 205 |
+
# ------------------------------------------------------- #
|
| 206 |
+
# Inclusion (in), intersection (&), union (|) and gap (^) #
|
| 207 |
+
# ------------------------------------------------------- #
|
| 208 |
+
|
| 209 |
+
def __contains__(self, other: 'Segment'):
|
| 210 |
+
"""Inclusion
|
| 211 |
+
|
| 212 |
+
>>> segment = Segment(start=0, end=10)
|
| 213 |
+
>>> Segment(start=3, end=10) in segment:
|
| 214 |
+
True
|
| 215 |
+
>>> Segment(start=5, end=15) in segment:
|
| 216 |
+
False
|
| 217 |
+
"""
|
| 218 |
+
return (self.start <= other.start) and (self.end >= other.end)
|
| 219 |
+
|
| 220 |
+
def __and__(self, other):
|
| 221 |
+
"""Intersection
|
| 222 |
+
|
| 223 |
+
>>> segment = Segment(0, 10)
|
| 224 |
+
>>> other_segment = Segment(5, 15)
|
| 225 |
+
>>> segment & other_segment
|
| 226 |
+
<Segment(5, 10)>
|
| 227 |
+
|
| 228 |
+
Note
|
| 229 |
+
----
|
| 230 |
+
When the intersection is empty, an empty segment is returned:
|
| 231 |
+
|
| 232 |
+
>>> segment = Segment(0, 10)
|
| 233 |
+
>>> other_segment = Segment(15, 20)
|
| 234 |
+
>>> intersection = segment & other_segment
|
| 235 |
+
>>> if not intersection:
|
| 236 |
+
... # intersection is empty.
|
| 237 |
+
"""
|
| 238 |
+
start = max(self.start, other.start)
|
| 239 |
+
end = min(self.end, other.end)
|
| 240 |
+
return Segment(start=start, end=end)
|
| 241 |
+
|
| 242 |
+
def intersects(self, other: 'Segment') -> bool:
|
| 243 |
+
"""Check whether two segments intersect each other
|
| 244 |
+
|
| 245 |
+
Parameters
|
| 246 |
+
----------
|
| 247 |
+
other : Segment
|
| 248 |
+
Other segment
|
| 249 |
+
|
| 250 |
+
Returns
|
| 251 |
+
-------
|
| 252 |
+
intersect : bool
|
| 253 |
+
True if segments intersect, False otherwise
|
| 254 |
+
"""
|
| 255 |
+
|
| 256 |
+
return (self.start < other.start and
|
| 257 |
+
other.start < self.end - SEGMENT_PRECISION) or \
|
| 258 |
+
(self.start > other.start and
|
| 259 |
+
self.start < other.end - SEGMENT_PRECISION) or \
|
| 260 |
+
(self.start == other.start)
|
| 261 |
+
|
| 262 |
+
def overlaps(self, t: float) -> bool:
|
| 263 |
+
"""Check if segment overlaps a given time
|
| 264 |
+
|
| 265 |
+
Parameters
|
| 266 |
+
----------
|
| 267 |
+
t : float
|
| 268 |
+
Time, in seconds.
|
| 269 |
+
|
| 270 |
+
Returns
|
| 271 |
+
-------
|
| 272 |
+
overlap: bool
|
| 273 |
+
True if segment overlaps time t, False otherwise.
|
| 274 |
+
"""
|
| 275 |
+
return self.start <= t and self.end >= t
|
| 276 |
+
|
| 277 |
+
def __or__(self, other: 'Segment') -> 'Segment':
|
| 278 |
+
"""Union
|
| 279 |
+
|
| 280 |
+
>>> segment = Segment(0, 10)
|
| 281 |
+
>>> other_segment = Segment(5, 15)
|
| 282 |
+
>>> segment | other_segment
|
| 283 |
+
<Segment(0, 15)>
|
| 284 |
+
|
| 285 |
+
Note
|
| 286 |
+
----
|
| 287 |
+
When a gap exists between the segment, their union covers the gap as well:
|
| 288 |
+
|
| 289 |
+
>>> segment = Segment(0, 10)
|
| 290 |
+
>>> other_segment = Segment(15, 20)
|
| 291 |
+
>>> segment | other_segment
|
| 292 |
+
<Segment(0, 20)
|
| 293 |
+
"""
|
| 294 |
+
|
| 295 |
+
# if segment is empty, union is the other one
|
| 296 |
+
if not self:
|
| 297 |
+
return other
|
| 298 |
+
# if other one is empty, union is self
|
| 299 |
+
if not other:
|
| 300 |
+
return self
|
| 301 |
+
|
| 302 |
+
# otherwise, do what's meant to be...
|
| 303 |
+
start = min(self.start, other.start)
|
| 304 |
+
end = max(self.end, other.end)
|
| 305 |
+
return Segment(start=start, end=end)
|
| 306 |
+
|
| 307 |
+
def __xor__(self, other: 'Segment') -> 'Segment':
|
| 308 |
+
"""Gap
|
| 309 |
+
|
| 310 |
+
>>> segment = Segment(0, 10)
|
| 311 |
+
>>> other_segment = Segment(15, 20)
|
| 312 |
+
>>> segment ^ other_segment
|
| 313 |
+
<Segment(10, 15)
|
| 314 |
+
|
| 315 |
+
Note
|
| 316 |
+
----
|
| 317 |
+
The gap between a segment and an empty segment is not defined.
|
| 318 |
+
|
| 319 |
+
>>> segment = Segment(0, 10)
|
| 320 |
+
>>> empty_segment = Segment(11, 11)
|
| 321 |
+
>>> segment ^ empty_segment
|
| 322 |
+
ValueError: The gap between a segment and an empty segment is not defined.
|
| 323 |
+
"""
|
| 324 |
+
|
| 325 |
+
# if segment is empty, xor is not defined
|
| 326 |
+
if (not self) or (not other):
|
| 327 |
+
raise ValueError(
|
| 328 |
+
'The gap between a segment and an empty segment '
|
| 329 |
+
'is not defined.')
|
| 330 |
+
|
| 331 |
+
start = min(self.end, other.end)
|
| 332 |
+
end = max(self.start, other.start)
|
| 333 |
+
return Segment(start=start, end=end)
|
| 334 |
+
|
| 335 |
+
def _str_helper(self, seconds: float) -> str:
|
| 336 |
+
from datetime import timedelta
|
| 337 |
+
negative = seconds < 0
|
| 338 |
+
seconds = abs(seconds)
|
| 339 |
+
td = timedelta(seconds=seconds)
|
| 340 |
+
seconds = td.seconds + 86400 * td.days
|
| 341 |
+
microseconds = td.microseconds
|
| 342 |
+
hours, remainder = divmod(seconds, 3600)
|
| 343 |
+
minutes, seconds = divmod(remainder, 60)
|
| 344 |
+
return '%s%02d:%02d:%02d.%03d' % (
|
| 345 |
+
'-' if negative else ' ', hours, minutes,
|
| 346 |
+
seconds, microseconds / 1000)
|
| 347 |
+
|
| 348 |
+
def __str__(self):
|
| 349 |
+
"""Human-readable representation
|
| 350 |
+
|
| 351 |
+
>>> print(Segment(1337, 1337 + 0.42))
|
| 352 |
+
[ 00:22:17.000 --> 00:22:17.420]
|
| 353 |
+
|
| 354 |
+
Note
|
| 355 |
+
----
|
| 356 |
+
Empty segments are printed as "[]"
|
| 357 |
+
"""
|
| 358 |
+
if self:
|
| 359 |
+
return '[%s --> %s]' % (self._str_helper(self.start),
|
| 360 |
+
self._str_helper(self.end))
|
| 361 |
+
return '[]'
|
| 362 |
+
|
| 363 |
+
def __repr__(self):
|
| 364 |
+
"""Computer-readable representation
|
| 365 |
+
|
| 366 |
+
>>> Segment(1337, 1337 + 0.42)
|
| 367 |
+
<Segment(1337, 1337.42)>
|
| 368 |
+
"""
|
| 369 |
+
return '<Segment(%g, %g)>' % (self.start, self.end)
|
| 370 |
+
|
| 371 |
+
def _repr_png_(self):
|
| 372 |
+
"""IPython notebook support
|
| 373 |
+
|
| 374 |
+
See also
|
| 375 |
+
--------
|
| 376 |
+
:mod:`pyannote.core.notebook`
|
| 377 |
+
"""
|
| 378 |
+
from .notebook import MATPLOTLIB_IS_AVAILABLE, MATPLOTLIB_WARNING
|
| 379 |
+
if not MATPLOTLIB_IS_AVAILABLE:
|
| 380 |
+
warnings.warn(MATPLOTLIB_WARNING.format(klass=self.__class__.__name__))
|
| 381 |
+
return None
|
| 382 |
+
|
| 383 |
+
from .notebook import repr_segment
|
| 384 |
+
try:
|
| 385 |
+
return repr_segment(self)
|
| 386 |
+
except ImportError:
|
| 387 |
+
warnings.warn(
|
| 388 |
+
f"Couldn't import matplotlib to render the vizualization for object {self}. To enable, install the required dependencies with 'pip install pyannore.core[notebook]'")
|
| 389 |
+
return None
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class SlidingWindow:
|
| 393 |
+
"""Sliding window
|
| 394 |
+
|
| 395 |
+
Parameters
|
| 396 |
+
----------
|
| 397 |
+
duration : float > 0, optional
|
| 398 |
+
Window duration, in seconds. Default is 30 ms.
|
| 399 |
+
step : float > 0, optional
|
| 400 |
+
Step between two consecutive position, in seconds. Default is 10 ms.
|
| 401 |
+
start : float, optional
|
| 402 |
+
First start position of window, in seconds. Default is 0.
|
| 403 |
+
end : float > `start`, optional
|
| 404 |
+
Default is infinity (ie. window keeps sliding forever)
|
| 405 |
+
|
| 406 |
+
Examples
|
| 407 |
+
--------
|
| 408 |
+
|
| 409 |
+
>>> sw = SlidingWindow(duration, step, start)
|
| 410 |
+
>>> frame_range = (a, b)
|
| 411 |
+
>>> frame_range == sw.toFrameRange(sw.toSegment(*frame_range))
|
| 412 |
+
... True
|
| 413 |
+
|
| 414 |
+
>>> segment = Segment(A, B)
|
| 415 |
+
>>> new_segment = sw.toSegment(*sw.toFrameRange(segment))
|
| 416 |
+
>>> abs(segment) - abs(segment & new_segment) < .5 * sw.step
|
| 417 |
+
|
| 418 |
+
>>> sw = SlidingWindow(end=0.1)
|
| 419 |
+
>>> print(next(sw))
|
| 420 |
+
[ 00:00:00.000 --> 00:00:00.030]
|
| 421 |
+
>>> print(next(sw))
|
| 422 |
+
[ 00:00:00.010 --> 00:00:00.040]
|
| 423 |
+
"""
|
| 424 |
+
|
| 425 |
+
def __init__(self, duration=0.030, step=0.010, start=0.000, end=None):
|
| 426 |
+
|
| 427 |
+
# duration must be a float > 0
|
| 428 |
+
if duration <= 0:
|
| 429 |
+
raise ValueError("'duration' must be a float > 0.")
|
| 430 |
+
self.__duration = duration
|
| 431 |
+
|
| 432 |
+
# step must be a float > 0
|
| 433 |
+
if step <= 0:
|
| 434 |
+
raise ValueError("'step' must be a float > 0.")
|
| 435 |
+
self.__step: float = step
|
| 436 |
+
|
| 437 |
+
# start must be a float.
|
| 438 |
+
self.__start: float = start
|
| 439 |
+
|
| 440 |
+
# if end is not provided, set it to infinity
|
| 441 |
+
if end is None:
|
| 442 |
+
self.__end: float = np.inf
|
| 443 |
+
else:
|
| 444 |
+
# end must be greater than start
|
| 445 |
+
if end <= start:
|
| 446 |
+
raise ValueError("'end' must be greater than 'start'.")
|
| 447 |
+
self.__end: float = end
|
| 448 |
+
|
| 449 |
+
# current index of iterator
|
| 450 |
+
self.__i: int = -1
|
| 451 |
+
|
| 452 |
+
@property
|
| 453 |
+
def start(self) -> float:
|
| 454 |
+
"""Sliding window start time in seconds."""
|
| 455 |
+
return self.__start
|
| 456 |
+
|
| 457 |
+
@property
|
| 458 |
+
def end(self) -> float:
|
| 459 |
+
"""Sliding window end time in seconds."""
|
| 460 |
+
return self.__end
|
| 461 |
+
|
| 462 |
+
@property
|
| 463 |
+
def step(self) -> float:
|
| 464 |
+
"""Sliding window step in seconds."""
|
| 465 |
+
return self.__step
|
| 466 |
+
|
| 467 |
+
@property
|
| 468 |
+
def duration(self) -> float:
|
| 469 |
+
"""Sliding window duration in seconds."""
|
| 470 |
+
return self.__duration
|
| 471 |
+
|
| 472 |
+
def closest_frame(self, t: float) -> int:
|
| 473 |
+
"""Closest frame to timestamp.
|
| 474 |
+
|
| 475 |
+
Parameters
|
| 476 |
+
----------
|
| 477 |
+
t : float
|
| 478 |
+
Timestamp, in seconds.
|
| 479 |
+
|
| 480 |
+
Returns
|
| 481 |
+
-------
|
| 482 |
+
index : int
|
| 483 |
+
Index of frame whose middle is the closest to `timestamp`
|
| 484 |
+
|
| 485 |
+
"""
|
| 486 |
+
return int(np.rint(
|
| 487 |
+
(t - self.__start - .5 * self.__duration) / self.__step
|
| 488 |
+
))
|
| 489 |
+
|
| 490 |
+
def samples(self, from_duration: float, mode: Alignment = 'strict') -> int:
|
| 491 |
+
"""Number of frames
|
| 492 |
+
|
| 493 |
+
Parameters
|
| 494 |
+
----------
|
| 495 |
+
from_duration : float
|
| 496 |
+
Duration in seconds.
|
| 497 |
+
mode : {'strict', 'loose', 'center'}
|
| 498 |
+
In 'strict' mode, computes the maximum number of consecutive frames
|
| 499 |
+
that can be fitted into a segment with duration `from_duration`.
|
| 500 |
+
In 'loose' mode, computes the maximum number of consecutive frames
|
| 501 |
+
intersecting a segment with duration `from_duration`.
|
| 502 |
+
In 'center' mode, computes the average number of consecutive frames
|
| 503 |
+
where the first one is centered on the start time and the last one
|
| 504 |
+
is centered on the end time of a segment with duration
|
| 505 |
+
`from_duration`.
|
| 506 |
+
|
| 507 |
+
"""
|
| 508 |
+
if mode == 'strict':
|
| 509 |
+
return int(np.floor((from_duration - self.duration) / self.step)) + 1
|
| 510 |
+
|
| 511 |
+
elif mode == 'loose':
|
| 512 |
+
return int(np.floor((from_duration + self.duration) / self.step))
|
| 513 |
+
|
| 514 |
+
elif mode == 'center':
|
| 515 |
+
return int(np.rint((from_duration / self.step)))
|
| 516 |
+
|
| 517 |
+
def crop(self, focus: Union[Segment, 'Timeline'],
|
| 518 |
+
mode: Alignment = 'loose',
|
| 519 |
+
fixed: Optional[float] = None,
|
| 520 |
+
return_ranges: Optional[bool] = False) -> \
|
| 521 |
+
Union[np.ndarray, List[List[int]]]:
|
| 522 |
+
"""Crop sliding window
|
| 523 |
+
|
| 524 |
+
Parameters
|
| 525 |
+
----------
|
| 526 |
+
focus : `Segment` or `Timeline`
|
| 527 |
+
mode : {'strict', 'loose', 'center'}, optional
|
| 528 |
+
In 'strict' mode, only indices of segments fully included in
|
| 529 |
+
'focus' support are returned. In 'loose' mode, indices of any
|
| 530 |
+
intersecting segments are returned. In 'center' mode, first and
|
| 531 |
+
last positions are chosen to be the positions whose centers are the
|
| 532 |
+
closest to 'focus' start and end times. Defaults to 'loose'.
|
| 533 |
+
fixed : float, optional
|
| 534 |
+
Overrides `Segment` 'focus' duration and ensures that the number of
|
| 535 |
+
returned frames is fixed (which might otherwise not be the case
|
| 536 |
+
because of rounding erros).
|
| 537 |
+
return_ranges : bool, optional
|
| 538 |
+
Return as list of ranges. Defaults to indices numpy array.
|
| 539 |
+
|
| 540 |
+
Returns
|
| 541 |
+
-------
|
| 542 |
+
indices : np.array (or list of ranges)
|
| 543 |
+
Array of unique indices of matching segments
|
| 544 |
+
"""
|
| 545 |
+
|
| 546 |
+
from .timeline import Timeline
|
| 547 |
+
|
| 548 |
+
if not isinstance(focus, (Segment, Timeline)):
|
| 549 |
+
msg = '"focus" must be a `Segment` or `Timeline` instance.'
|
| 550 |
+
raise TypeError(msg)
|
| 551 |
+
|
| 552 |
+
if isinstance(focus, Timeline):
|
| 553 |
+
|
| 554 |
+
if fixed is not None:
|
| 555 |
+
msg = "'fixed' is not supported with `Timeline` 'focus'."
|
| 556 |
+
raise ValueError(msg)
|
| 557 |
+
|
| 558 |
+
if return_ranges:
|
| 559 |
+
ranges = []
|
| 560 |
+
|
| 561 |
+
for i, s in enumerate(focus.support()):
|
| 562 |
+
rng = self.crop(s, mode=mode, fixed=fixed,
|
| 563 |
+
return_ranges=True)
|
| 564 |
+
|
| 565 |
+
# if first or disjoint segment, add it
|
| 566 |
+
if i == 0 or rng[0][0] > ranges[-1][1]:
|
| 567 |
+
ranges += rng
|
| 568 |
+
|
| 569 |
+
# if overlapping segment, update last range
|
| 570 |
+
else:
|
| 571 |
+
ranges[-1][1] = rng[0][1]
|
| 572 |
+
|
| 573 |
+
return ranges
|
| 574 |
+
|
| 575 |
+
# concatenate all indices
|
| 576 |
+
indices = np.hstack([
|
| 577 |
+
self.crop(s, mode=mode, fixed=fixed, return_ranges=False)
|
| 578 |
+
for s in focus.support()])
|
| 579 |
+
|
| 580 |
+
# remove duplicate indices
|
| 581 |
+
return np.unique(indices)
|
| 582 |
+
|
| 583 |
+
# 'focus' is a `Segment` instance
|
| 584 |
+
|
| 585 |
+
if mode == 'loose':
|
| 586 |
+
|
| 587 |
+
# find smallest integer i such that
|
| 588 |
+
# self.start + i x self.step + self.duration >= focus.start
|
| 589 |
+
i_ = (focus.start - self.duration - self.start) / self.step
|
| 590 |
+
i = int(np.ceil(i_))
|
| 591 |
+
|
| 592 |
+
if fixed is None:
|
| 593 |
+
# find largest integer j such that
|
| 594 |
+
# self.start + j x self.step <= focus.end
|
| 595 |
+
j_ = (focus.end - self.start) / self.step
|
| 596 |
+
j = int(np.floor(j_))
|
| 597 |
+
rng = (i, j + 1)
|
| 598 |
+
|
| 599 |
+
else:
|
| 600 |
+
n = self.samples(fixed, mode='loose')
|
| 601 |
+
rng = (i, i + n)
|
| 602 |
+
|
| 603 |
+
elif mode == 'strict':
|
| 604 |
+
|
| 605 |
+
# find smallest integer i such that
|
| 606 |
+
# self.start + i x self.step >= focus.start
|
| 607 |
+
i_ = (focus.start - self.start) / self.step
|
| 608 |
+
i = int(np.ceil(i_))
|
| 609 |
+
|
| 610 |
+
if fixed is None:
|
| 611 |
+
|
| 612 |
+
# find largest integer j such that
|
| 613 |
+
# self.start + j x self.step + self.duration <= focus.end
|
| 614 |
+
j_ = (focus.end - self.duration - self.start) / self.step
|
| 615 |
+
j = int(np.floor(j_))
|
| 616 |
+
rng = (i, j + 1)
|
| 617 |
+
|
| 618 |
+
else:
|
| 619 |
+
n = self.samples(fixed, mode='strict')
|
| 620 |
+
rng = (i, i + n)
|
| 621 |
+
|
| 622 |
+
elif mode == 'center':
|
| 623 |
+
|
| 624 |
+
# find window position whose center is the closest to focus.start
|
| 625 |
+
i = self.closest_frame(focus.start)
|
| 626 |
+
|
| 627 |
+
if fixed is None:
|
| 628 |
+
# find window position whose center is the closest to focus.end
|
| 629 |
+
j = self.closest_frame(focus.end)
|
| 630 |
+
rng = (i, j + 1)
|
| 631 |
+
else:
|
| 632 |
+
n = self.samples(fixed, mode='center')
|
| 633 |
+
rng = (i, i + n)
|
| 634 |
+
|
| 635 |
+
else:
|
| 636 |
+
msg = "'mode' must be one of {'loose', 'strict', 'center'}."
|
| 637 |
+
raise ValueError(msg)
|
| 638 |
+
|
| 639 |
+
if return_ranges:
|
| 640 |
+
return [list(rng)]
|
| 641 |
+
|
| 642 |
+
return np.array(range(*rng), dtype=np.int64)
|
| 643 |
+
|
| 644 |
+
def segmentToRange(self, segment: Segment) -> Tuple[int, int]:
|
| 645 |
+
warnings.warn("Deprecated in favor of `segment_to_range`",
|
| 646 |
+
DeprecationWarning)
|
| 647 |
+
return self.segment_to_range(segment)
|
| 648 |
+
|
| 649 |
+
def segment_to_range(self, segment: Segment) -> Tuple[int, int]:
|
| 650 |
+
"""Convert segment to 0-indexed frame range
|
| 651 |
+
|
| 652 |
+
Parameters
|
| 653 |
+
----------
|
| 654 |
+
segment : Segment
|
| 655 |
+
|
| 656 |
+
Returns
|
| 657 |
+
-------
|
| 658 |
+
i0 : int
|
| 659 |
+
Index of first frame
|
| 660 |
+
n : int
|
| 661 |
+
Number of frames
|
| 662 |
+
|
| 663 |
+
Examples
|
| 664 |
+
--------
|
| 665 |
+
|
| 666 |
+
>>> window = SlidingWindow()
|
| 667 |
+
>>> print window.segment_to_range(Segment(10, 15))
|
| 668 |
+
i0, n
|
| 669 |
+
|
| 670 |
+
"""
|
| 671 |
+
# find closest frame to segment start
|
| 672 |
+
i0 = self.closest_frame(segment.start)
|
| 673 |
+
|
| 674 |
+
# number of steps to cover segment duration
|
| 675 |
+
n = int(segment.duration / self.step) + 1
|
| 676 |
+
|
| 677 |
+
return i0, n
|
| 678 |
+
|
| 679 |
+
def rangeToSegment(self, i0: int, n: int) -> Segment:
|
| 680 |
+
warnings.warn("This is deprecated in favor of `range_to_segment`",
|
| 681 |
+
DeprecationWarning)
|
| 682 |
+
return self.range_to_segment(i0, n)
|
| 683 |
+
|
| 684 |
+
def range_to_segment(self, i0: int, n: int) -> Segment:
|
| 685 |
+
"""Convert 0-indexed frame range to segment
|
| 686 |
+
|
| 687 |
+
Each frame represents a unique segment of duration 'step', centered on
|
| 688 |
+
the middle of the frame.
|
| 689 |
+
|
| 690 |
+
The very first frame (i0 = 0) is the exception. It is extended to the
|
| 691 |
+
sliding window start time.
|
| 692 |
+
|
| 693 |
+
Parameters
|
| 694 |
+
----------
|
| 695 |
+
i0 : int
|
| 696 |
+
Index of first frame
|
| 697 |
+
n : int
|
| 698 |
+
Number of frames
|
| 699 |
+
|
| 700 |
+
Returns
|
| 701 |
+
-------
|
| 702 |
+
segment : Segment
|
| 703 |
+
|
| 704 |
+
Examples
|
| 705 |
+
--------
|
| 706 |
+
|
| 707 |
+
>>> window = SlidingWindow()
|
| 708 |
+
>>> print window.range_to_segment(3, 2)
|
| 709 |
+
[ --> ]
|
| 710 |
+
|
| 711 |
+
"""
|
| 712 |
+
|
| 713 |
+
# frame start time
|
| 714 |
+
# start = self.start + i0 * self.step
|
| 715 |
+
# frame middle time
|
| 716 |
+
# start += .5 * self.duration
|
| 717 |
+
# subframe start time
|
| 718 |
+
# start -= .5 * self.step
|
| 719 |
+
start = self.__start + (i0 - .5) * self.__step + .5 * self.__duration
|
| 720 |
+
duration = n * self.__step
|
| 721 |
+
end = start + duration
|
| 722 |
+
|
| 723 |
+
# extend segment to the beginning of the timeline
|
| 724 |
+
if i0 == 0:
|
| 725 |
+
start = self.start
|
| 726 |
+
|
| 727 |
+
return Segment(start, end)
|
| 728 |
+
|
| 729 |
+
def samplesToDuration(self, nSamples: int) -> float:
|
| 730 |
+
warnings.warn("This is deprecated in favor of `samples_to_duration`",
|
| 731 |
+
DeprecationWarning)
|
| 732 |
+
return self.samples_to_duration(nSamples)
|
| 733 |
+
|
| 734 |
+
def samples_to_duration(self, n_samples: int) -> float:
|
| 735 |
+
"""Returns duration of samples"""
|
| 736 |
+
return self.range_to_segment(0, n_samples).duration
|
| 737 |
+
|
| 738 |
+
def durationToSamples(self, duration: float) -> int:
|
| 739 |
+
warnings.warn("This is deprecated in favor of `duration_to_samples`",
|
| 740 |
+
DeprecationWarning)
|
| 741 |
+
return self.duration_to_samples(duration)
|
| 742 |
+
|
| 743 |
+
def duration_to_samples(self, duration: float) -> int:
|
| 744 |
+
"""Returns samples in duration"""
|
| 745 |
+
return self.segment_to_range(Segment(0, duration))[1]
|
| 746 |
+
|
| 747 |
+
def __getitem__(self, i: int) -> Segment:
|
| 748 |
+
"""
|
| 749 |
+
Parameters
|
| 750 |
+
----------
|
| 751 |
+
i : int
|
| 752 |
+
Index of sliding window position
|
| 753 |
+
|
| 754 |
+
Returns
|
| 755 |
+
-------
|
| 756 |
+
segment : :class:`Segment`
|
| 757 |
+
Sliding window at ith position
|
| 758 |
+
|
| 759 |
+
"""
|
| 760 |
+
|
| 761 |
+
# window start time at ith position
|
| 762 |
+
start = self.__start + i * self.__step
|
| 763 |
+
|
| 764 |
+
# in case segment starts after the end,
|
| 765 |
+
# return an empty segment
|
| 766 |
+
if start >= self.__end:
|
| 767 |
+
return None
|
| 768 |
+
|
| 769 |
+
return Segment(start=start, end=start + self.__duration)
|
| 770 |
+
|
| 771 |
+
def next(self) -> Segment:
|
| 772 |
+
return self.__next__()
|
| 773 |
+
|
| 774 |
+
def __next__(self) -> Segment:
|
| 775 |
+
self.__i += 1
|
| 776 |
+
window = self[self.__i]
|
| 777 |
+
|
| 778 |
+
if window:
|
| 779 |
+
return window
|
| 780 |
+
else:
|
| 781 |
+
raise StopIteration()
|
| 782 |
+
|
| 783 |
+
def __iter__(self) -> 'SlidingWindow':
|
| 784 |
+
"""Sliding window iterator
|
| 785 |
+
|
| 786 |
+
Use expression 'for segment in sliding_window'
|
| 787 |
+
|
| 788 |
+
Examples
|
| 789 |
+
--------
|
| 790 |
+
|
| 791 |
+
>>> window = SlidingWindow(end=0.1)
|
| 792 |
+
>>> for segment in window:
|
| 793 |
+
... print(segment)
|
| 794 |
+
[ 00:00:00.000 --> 00:00:00.030]
|
| 795 |
+
[ 00:00:00.010 --> 00:00:00.040]
|
| 796 |
+
[ 00:00:00.020 --> 00:00:00.050]
|
| 797 |
+
[ 00:00:00.030 --> 00:00:00.060]
|
| 798 |
+
[ 00:00:00.040 --> 00:00:00.070]
|
| 799 |
+
[ 00:00:00.050 --> 00:00:00.080]
|
| 800 |
+
[ 00:00:00.060 --> 00:00:00.090]
|
| 801 |
+
[ 00:00:00.070 --> 00:00:00.100]
|
| 802 |
+
[ 00:00:00.080 --> 00:00:00.110]
|
| 803 |
+
[ 00:00:00.090 --> 00:00:00.120]
|
| 804 |
+
"""
|
| 805 |
+
|
| 806 |
+
# reset iterator index
|
| 807 |
+
self.__i = -1
|
| 808 |
+
return self
|
| 809 |
+
|
| 810 |
+
def __len__(self) -> int:
|
| 811 |
+
"""Number of positions
|
| 812 |
+
|
| 813 |
+
Equivalent to len([segment for segment in window])
|
| 814 |
+
|
| 815 |
+
Returns
|
| 816 |
+
-------
|
| 817 |
+
length : int
|
| 818 |
+
Number of positions taken by the sliding window
|
| 819 |
+
(from start times to end times)
|
| 820 |
+
|
| 821 |
+
"""
|
| 822 |
+
if np.isinf(self.__end):
|
| 823 |
+
raise ValueError('infinite sliding window.')
|
| 824 |
+
|
| 825 |
+
# start looking for last position
|
| 826 |
+
# based on frame closest to the end
|
| 827 |
+
i = self.closest_frame(self.__end)
|
| 828 |
+
|
| 829 |
+
while (self[i]):
|
| 830 |
+
i += 1
|
| 831 |
+
length = i
|
| 832 |
+
|
| 833 |
+
return length
|
| 834 |
+
|
| 835 |
+
def copy(self) -> 'SlidingWindow':
|
| 836 |
+
"""Duplicate sliding window"""
|
| 837 |
+
duration = self.duration
|
| 838 |
+
step = self.step
|
| 839 |
+
start = self.start
|
| 840 |
+
end = self.end
|
| 841 |
+
sliding_window = self.__class__(
|
| 842 |
+
duration=duration, step=step, start=start, end=end
|
| 843 |
+
)
|
| 844 |
+
return sliding_window
|
| 845 |
+
|
| 846 |
+
def __call__(self,
|
| 847 |
+
support: Union[Segment, 'Timeline'],
|
| 848 |
+
align_last: bool = False) -> Iterable[Segment]:
|
| 849 |
+
"""Slide window over support
|
| 850 |
+
|
| 851 |
+
Parameter
|
| 852 |
+
---------
|
| 853 |
+
support : Segment or Timeline
|
| 854 |
+
Support on which to slide the window.
|
| 855 |
+
align_last : bool, optional
|
| 856 |
+
Yield a final segment so that it aligns exactly with end of support.
|
| 857 |
+
|
| 858 |
+
Yields
|
| 859 |
+
------
|
| 860 |
+
chunk : Segment
|
| 861 |
+
|
| 862 |
+
Example
|
| 863 |
+
-------
|
| 864 |
+
>>> window = SlidingWindow(duration=2., step=1.)
|
| 865 |
+
>>> for chunk in window(Segment(3, 7.5)):
|
| 866 |
+
... print(tuple(chunk))
|
| 867 |
+
(3.0, 5.0)
|
| 868 |
+
(4.0, 6.0)
|
| 869 |
+
(5.0, 7.0)
|
| 870 |
+
>>> for chunk in window(Segment(3, 7.5), align_last=True):
|
| 871 |
+
... print(tuple(chunk))
|
| 872 |
+
(3.0, 5.0)
|
| 873 |
+
(4.0, 6.0)
|
| 874 |
+
(5.0, 7.0)
|
| 875 |
+
(5.5, 7.5)
|
| 876 |
+
"""
|
| 877 |
+
|
| 878 |
+
from pyannote.core import Timeline
|
| 879 |
+
if isinstance(support, Timeline):
|
| 880 |
+
segments = support
|
| 881 |
+
|
| 882 |
+
elif isinstance(support, Segment):
|
| 883 |
+
segments = Timeline(segments=[support])
|
| 884 |
+
|
| 885 |
+
else:
|
| 886 |
+
msg = (
|
| 887 |
+
f'"support" must be either a Segment or a Timeline '
|
| 888 |
+
f'instance (is {type(support)})'
|
| 889 |
+
)
|
| 890 |
+
raise TypeError(msg)
|
| 891 |
+
|
| 892 |
+
for segment in segments:
|
| 893 |
+
|
| 894 |
+
if segment.duration < self.duration:
|
| 895 |
+
continue
|
| 896 |
+
|
| 897 |
+
window = SlidingWindow(duration=self.duration,
|
| 898 |
+
step=self.step,
|
| 899 |
+
start=segment.start,
|
| 900 |
+
end=segment.end)
|
| 901 |
+
|
| 902 |
+
for s in window:
|
| 903 |
+
# ugly hack to account for floating point imprecision
|
| 904 |
+
if s in segment:
|
| 905 |
+
yield s
|
| 906 |
+
last = s
|
| 907 |
+
|
| 908 |
+
if align_last and last.end < segment.end:
|
| 909 |
+
yield Segment(start=segment.end - self.duration,
|
| 910 |
+
end=segment.end)
|
ailia-models/code/pyannote_audio_utils/core/timeline.py
ADDED
|
@@ -0,0 +1,1126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2014-2020 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
# Grant JENKS - http://www.grantjenks.com/
|
| 29 |
+
# Paul LERNER
|
| 30 |
+
|
| 31 |
+
"""
|
| 32 |
+
########
|
| 33 |
+
Timeline
|
| 34 |
+
########
|
| 35 |
+
|
| 36 |
+
.. plot:: pyplots/timeline.py
|
| 37 |
+
|
| 38 |
+
:class:`pyannote.core.Timeline` instances are ordered sets of non-empty
|
| 39 |
+
segments:
|
| 40 |
+
|
| 41 |
+
- ordered, because segments are sorted by start time (and end time in case of tie)
|
| 42 |
+
- set, because one cannot add twice the same segment
|
| 43 |
+
- non-empty, because one cannot add empty segments (*i.e.* start >= end)
|
| 44 |
+
|
| 45 |
+
There are two ways to define the timeline depicted above:
|
| 46 |
+
|
| 47 |
+
.. code-block:: ipython
|
| 48 |
+
|
| 49 |
+
In [25]: from pyannote.core import Timeline, Segment
|
| 50 |
+
|
| 51 |
+
In [26]: timeline = Timeline()
|
| 52 |
+
....: timeline.add(Segment(1, 5))
|
| 53 |
+
....: timeline.add(Segment(6, 8))
|
| 54 |
+
....: timeline.add(Segment(12, 18))
|
| 55 |
+
....: timeline.add(Segment(7, 20))
|
| 56 |
+
....:
|
| 57 |
+
|
| 58 |
+
In [27]: segments = [Segment(1, 5), Segment(6, 8), Segment(12, 18), Segment(7, 20)]
|
| 59 |
+
....: timeline = Timeline(segments=segments, uri='my_audio_file') # faster
|
| 60 |
+
....:
|
| 61 |
+
|
| 62 |
+
In [9]: for segment in timeline:
|
| 63 |
+
...: print(segment)
|
| 64 |
+
...:
|
| 65 |
+
[ 00:00:01.000 --> 00:00:05.000]
|
| 66 |
+
[ 00:00:06.000 --> 00:00:08.000]
|
| 67 |
+
[ 00:00:07.000 --> 00:00:20.000]
|
| 68 |
+
[ 00:00:12.000 --> 00:00:18.000]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
.. note::
|
| 72 |
+
|
| 73 |
+
The optional *uri* keyword argument can be used to remember which document it describes.
|
| 74 |
+
|
| 75 |
+
Several convenient methods are available. Here are a few examples:
|
| 76 |
+
|
| 77 |
+
.. code-block:: ipython
|
| 78 |
+
|
| 79 |
+
In [3]: timeline.extent() # extent
|
| 80 |
+
Out[3]: <Segment(1, 20)>
|
| 81 |
+
|
| 82 |
+
In [5]: timeline.support() # support
|
| 83 |
+
Out[5]: <Timeline(uri=my_audio_file, segments=[<Segment(1, 5)>, <Segment(6, 20)>])>
|
| 84 |
+
|
| 85 |
+
In [6]: timeline.duration() # support duration
|
| 86 |
+
Out[6]: 18
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
See :class:`pyannote.core.Timeline` for the complete reference.
|
| 90 |
+
"""
|
| 91 |
+
import warnings
|
| 92 |
+
from typing import (Optional, Iterable, List, Union, Callable,
|
| 93 |
+
TextIO, Tuple, TYPE_CHECKING, Iterator, Dict, Text)
|
| 94 |
+
|
| 95 |
+
from sortedcontainers import SortedList
|
| 96 |
+
|
| 97 |
+
from . import PYANNOTE_SEGMENT
|
| 98 |
+
from .segment import Segment
|
| 99 |
+
from .utils.types import Support, Label, CropMode
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# this is a moderately ugly way to import `Annotation` to the namespace
|
| 103 |
+
# without causing some circular imports :
|
| 104 |
+
# https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
|
| 105 |
+
if TYPE_CHECKING:
|
| 106 |
+
from .annotation import Annotation
|
| 107 |
+
import pandas as pd
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# =====================================================================
|
| 111 |
+
# Timeline class
|
| 112 |
+
# =====================================================================
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class Timeline:
|
| 116 |
+
"""
|
| 117 |
+
Ordered set of segments.
|
| 118 |
+
|
| 119 |
+
A timeline can be seen as an ordered set of non-empty segments (Segment).
|
| 120 |
+
Segments can overlap -- though adding an already exisiting segment to a
|
| 121 |
+
timeline does nothing.
|
| 122 |
+
|
| 123 |
+
Parameters
|
| 124 |
+
----------
|
| 125 |
+
segments : Segment iterator, optional
|
| 126 |
+
initial set of (non-empty) segments
|
| 127 |
+
uri : string, optional
|
| 128 |
+
name of segmented resource
|
| 129 |
+
|
| 130 |
+
Returns
|
| 131 |
+
-------
|
| 132 |
+
timeline : Timeline
|
| 133 |
+
New timeline
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
@classmethod
def from_df(cls, df: 'pd.DataFrame', uri: Optional[str] = None) -> 'Timeline':
    """Build a timeline from a pandas DataFrame.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame whose PYANNOTE_SEGMENT column contains Segment instances.
    uri : str, optional
        Name of segmented resource.

    Returns
    -------
    timeline : Timeline
        New timeline made of the segments found in the DataFrame.
    """
    segments = list(df[PYANNOTE_SEGMENT])
    timeline = cls(segments=segments, uri=uri)
    return timeline
|
| 141 |
+
|
| 142 |
+
def __init__(self,
             segments: Optional[Iterable[Segment]] = None,
             uri: Optional[str] = None):
    """Initialize timeline, silently dropping empty segments.

    Three parallel structures are maintained: a set (membership tests),
    a sorted list (chronological iteration) and a sorted list of
    boundaries (fast extent computation).
    """
    if segments is None:
        segments = ()

    # set of segments (used for checking inclusion)
    # Store only non-empty Segments.
    segments_set = set([segment for segment in segments if segment])

    self.segments_set_ = segments_set

    # sorted list of segments (used for sorted iteration)
    self.segments_list_ = SortedList(segments_set)

    # sorted list of (possibly redundant) segment boundaries
    # (a Segment iterates over its start and end times)
    boundaries = (boundary for segment in segments_set for boundary in segment)
    self.segments_boundaries_ = SortedList(boundaries)

    # path to (or any identifier of) segmented resource
    self.uri: Optional[str] = uri
|
| 163 |
+
|
| 164 |
+
def __len__(self):
    """Return the number of segments in the timeline.

    >>> len(timeline)  # timeline contains three segments
    3
    """
    return len(self.segments_set_)
|
| 171 |
+
|
| 172 |
+
def __nonzero__(self):
    """Python 2 truthiness hook: delegate to :func:`__bool__`."""
    return self.__bool__()
|
| 174 |
+
|
| 175 |
+
def __bool__(self):
    """Return True when the timeline contains at least one segment.

    >>> if timeline:
    ...    # timeline is not empty
    ... else:
    ...    # timeline is empty
    """
    # an empty set is falsy, so this matches len(...) > 0
    return bool(self.segments_set_)
|
| 184 |
+
|
| 185 |
+
def __iter__(self) -> Iterable[Segment]:
    """Iterate over segments, in chronological order.

    >>> for segment in timeline:
    ...     # do something with the segment

    See also
    --------
    :class:`pyannote.core.Segment` describes how segments are sorted.
    """
    # the internal SortedList is always kept sorted
    return iter(self.segments_list_)
|
| 196 |
+
|
| 197 |
+
def __getitem__(self, k: int) -> Segment:
    """Return the k-th segment in chronological order.

    Negative indices are supported:

    >>> first_segment = timeline[0]
    >>> penultimate_segment = timeline[-2]
    """
    return self.segments_list_[k]
|
| 204 |
+
|
| 205 |
+
def __eq__(self, other: 'Timeline'):
    """Equality: two timelines are equal iff they hold the same segments.

    Order of insertion does not matter:

    >>> timeline1 = Timeline([Segment(0, 1), Segment(2, 3)])
    >>> timeline2 = Timeline([Segment(2, 3), Segment(0, 1)])
    >>> timeline3 = Timeline([Segment(2, 3)])
    >>> timeline1 == timeline2
    True
    >>> timeline1 == timeline3
    False
    """
    return other.segments_set_ == self.segments_set_
|
| 219 |
+
|
| 220 |
+
def __ne__(self, other: 'Timeline'):
    """Inequality: negation of :func:`__eq__`."""
    return not (self.segments_set_ == other.segments_set_)
|
| 223 |
+
|
| 224 |
+
def index(self, segment: Segment) -> int:
    """Get index of (existing) segment

    Parameters
    ----------
    segment : Segment
        Segment that is being looked for.

    Returns
    -------
    position : int
        Index of `segment` in timeline (chronological order)

    Raises
    ------
    ValueError if `segment` is not present.
    """
    # SortedList.index performs a O(log n) bisection lookup
    return self.segments_list_.index(segment)
|
| 242 |
+
|
| 243 |
+
def add(self, segment: Segment) -> 'Timeline':
    """Add a segment in place and return the (updated) timeline.

    Parameters
    ----------
    segment : Segment
        Segment that is being added

    Returns
    -------
    self : Timeline
        Updated timeline.

    Note
    ----
    A timeline is a **set** of non-empty segments: adding an empty
    segment, or one that is already present, is a silent no-op.
    """
    known = self.segments_set_
    # reject empty segments and duplicates up front
    if not segment or segment in known:
        return self

    # keep the three internal structures consistent
    known.add(segment)
    self.segments_list_.add(segment)

    boundaries = self.segments_boundaries_
    boundaries.add(segment.start)
    boundaries.add(segment.end)

    return self
|
| 278 |
+
|
| 279 |
+
def remove(self, segment: Segment) -> 'Timeline':
    """Remove a segment in place and return the (updated) timeline.

    Parameters
    ----------
    segment : Segment
        Segment that is being removed

    Returns
    -------
    self : Timeline
        Updated timeline.

    Note
    ----
    Removing a segment that is not part of the timeline is a silent
    no-op.
    """
    known = self.segments_set_
    if segment not in known:
        return self

    # keep the three internal structures consistent
    known.remove(segment)
    self.segments_list_.remove(segment)

    boundaries = self.segments_boundaries_
    boundaries.remove(segment.start)
    boundaries.remove(segment.end)

    return self
|
| 310 |
+
|
| 311 |
+
def discard(self, segment: Segment) -> 'Timeline':
    """Alias for :func:`pyannote.core.Timeline.remove`.

    Provided for symmetry with the built-in ``set`` API (where
    ``remove`` raises on missing elements but ``discard`` does not;
    here, both are silent no-ops).
    """
    return self.remove(segment)
|
| 319 |
+
|
| 320 |
+
def __ior__(self, timeline: 'Timeline') -> 'Timeline':
    """In-place union operator (``|=``): delegate to :func:`update`."""
    return self.update(timeline)
|
| 322 |
+
|
| 323 |
+
def update(self, timeline: 'Timeline') -> 'Timeline':
    """Add every segments of an existing timeline (in place)

    Parameters
    ----------
    timeline : Timeline
        Timeline whose segments are being added

    Returns
    -------
    self : Timeline
        Updated timeline

    Note
    ----
    Only segments that do not already exist will be added, as a timeline is
    meant to be a **set** of segments (not a list).

    """

    segments_set = self.segments_set_

    segments_set |= timeline.segments_set_

    # rebuild the derived structures wholesale (cheaper than adding
    # the new segments one by one)

    # sorted list of segments (used for sorted iteration)
    self.segments_list_ = SortedList(segments_set)

    # sorted list of (possibly redundant) segment boundaries
    boundaries = (boundary for segment in segments_set for boundary in segment)
    self.segments_boundaries_ = SortedList(boundaries)

    return self
|
| 355 |
+
|
| 356 |
+
def __or__(self, timeline: 'Timeline') -> 'Timeline':
    """Union operator (``|``): delegate to :func:`union`."""
    return self.union(timeline)
|
| 358 |
+
|
| 359 |
+
def union(self, timeline: 'Timeline') -> 'Timeline':
    """Return a new timeline made of the segments of both timelines.

    Parameters
    ----------
    timeline : Timeline
        Timeline whose segments are being added

    Returns
    -------
    union : Timeline
        New timeline containing the union of both timelines.

    Note
    ----
    Same as :func:`update`, except a new timeline is returned and the
    original one is left untouched. The new timeline keeps `self`'s uri.
    """
    combined = self.segments_set_ | timeline.segments_set_
    return Timeline(segments=combined, uri=self.uri)
|
| 379 |
+
|
| 380 |
+
def co_iter(self, other: 'Timeline') -> Iterator[Tuple[Segment, Segment]]:
    """Iterate over pairs of intersecting segments

    >>> timeline1 = Timeline([Segment(0, 2), Segment(1, 2), Segment(3, 4)])
    >>> timeline2 = Timeline([Segment(1, 3), Segment(3, 5)])
    >>> for segment1, segment2 in timeline1.co_iter(timeline2):
    ...     print(segment1, segment2)
    (<Segment(0, 2)>, <Segment(1, 3)>)
    (<Segment(1, 2)>, <Segment(1, 3)>)
    (<Segment(3, 4)>, <Segment(3, 5)>)

    Parameters
    ----------
    other : Timeline
        Second timeline

    Returns
    -------
    iterable : (Segment, Segment) iterable
        Yields pairs of intersecting segments in chronological order.
    """

    for segment in self.segments_list_:

        # only consider other segments that start before 'segment' ends:
        # a zero-duration probe Segment at segment.end serves as the
        # upper bound for SortedList.irange (candidates sort before it)
        temp = Segment(start=segment.end, end=segment.end)
        for other_segment in other.segments_list_.irange(maximum=temp):
            # irange only prunes by start time, so an explicit
            # intersection test is still required
            if segment.intersects(other_segment):
                yield segment, other_segment
|
| 409 |
+
|
| 410 |
+
def crop_iter(self,
              support: Support,
              mode: CropMode = 'intersection',
              returns_mapping: bool = False) \
        -> Iterator[Union[Tuple[Segment, Segment], Segment]]:
    """Like `crop` but returns a segment iterator instead

    See also
    --------
    :func:`pyannote.core.Timeline.crop`
    """

    if mode not in {'loose', 'strict', 'intersection'}:
        raise ValueError("Mode must be one of 'loose', 'strict', or "
                         "'intersection'.")

    if not isinstance(support, (Segment, Timeline)):
        raise TypeError("Support must be a Segment or a Timeline.")

    # a Segment support is wrapped into a Timeline and the call is
    # re-dispatched, so only the Timeline branch implements the logic
    if isinstance(support, Segment):
        # corner case where "support" is empty
        if support:
            segments = [support]
        else:
            segments = []

        support = Timeline(segments=segments, uri=self.uri)
        for yielded in self.crop_iter(support, mode=mode,
                                      returns_mapping=returns_mapping):
            yield yielded
        return

    # if 'support' is a `Timeline`, we use its support
    support = support.support()

    # loose mode: keep any segment intersecting the support, as is
    if mode == 'loose':
        for segment, _ in self.co_iter(support):
            yield segment
        return

    # strict mode: keep only segments fully included in the support
    if mode == 'strict':
        for segment, other_segment in self.co_iter(support):
            if segment in other_segment:
                yield segment
        return

    # intersection mode: replace each intersecting segment by its
    # actual intersection with the support
    for segment, other_segment in self.co_iter(support):
        mapped_to = segment & other_segment
        # skip empty intersections (segments merely touching)
        if not mapped_to:
            continue
        if returns_mapping:
            yield segment, mapped_to
        else:
            yield mapped_to
|
| 467 |
+
|
| 468 |
+
def crop(self,
         support: Support,
         mode: CropMode = 'intersection',
         returns_mapping: bool = False) \
        -> Union['Timeline', Tuple['Timeline', Dict[Segment, Segment]]]:
    """Crop timeline to new support

    Parameters
    ----------
    support : Segment or Timeline
        If `support` is a `Timeline`, its support is used.
    mode : {'strict', 'loose', 'intersection'}, optional
        Controls how segments that are not fully included in `support` are
        handled. 'strict' mode only keeps fully included segments. 'loose'
        mode keeps any intersecting segment. 'intersection' mode keeps any
        intersecting segment but replace them by their actual intersection.
    returns_mapping : bool, optional
        In 'intersection' mode, return a dictionary whose keys are segments
        of the cropped timeline, and values are list of the original
        segments that were cropped. Defaults to False.

    Returns
    -------
    cropped : Timeline
        Cropped timeline
    mapping : dict
        When 'returns_mapping' is True, dictionary whose keys are segments
        of 'cropped', and values are lists of corresponding original
        segments.

    Examples
    --------

    >>> timeline = Timeline([Segment(0, 2), Segment(1, 2), Segment(3, 4)])
    >>> timeline.crop(Segment(1, 3))
    <Timeline(uri=None, segments=[<Segment(1, 2)>])>

    >>> timeline.crop(Segment(1, 3), mode='loose')
    <Timeline(uri=None, segments=[<Segment(0, 2)>, <Segment(1, 2)>])>

    >>> timeline.crop(Segment(1, 3), mode='strict')
    <Timeline(uri=None, segments=[<Segment(1, 2)>])>

    >>> cropped, mapping = timeline.crop(Segment(1, 3), returns_mapping=True)
    >>> print(mapping)
    {<Segment(1, 2)>: [<Segment(0, 2)>, <Segment(1, 2)>]}

    """

    # mapping is only meaningful in 'intersection' mode; several
    # original segments may crop down to the same intersection, hence
    # the list-valued mapping
    if mode == 'intersection' and returns_mapping:
        segments, mapping = [], {}
        for segment, mapped_to in self.crop_iter(support,
                                                 mode='intersection',
                                                 returns_mapping=True):
            segments.append(mapped_to)
            mapping[mapped_to] = mapping.get(mapped_to, list()) + [segment]
        return Timeline(segments=segments, uri=self.uri), mapping

    # all other cases: materialize the crop_iter generator
    return Timeline(segments=self.crop_iter(support, mode=mode),
                    uri=self.uri)
|
| 528 |
+
|
| 529 |
+
def overlapping(self, t: float) -> List[Segment]:
    """Return the list of segments containing timestamp `t`.

    Parameters
    ----------
    t : float
        Timestamp, in seconds.

    Returns
    -------
    segments : list
        All segments of the timeline containing time t
    """
    return list(self.overlapping_iter(t))
|
| 543 |
+
|
| 544 |
+
def overlapping_iter(self, t: float) -> Iterator[Segment]:
    """Like `overlapping` but returns a segment iterator instead

    See also
    --------
    :func:`pyannote.core.Timeline.overlapping`
    """
    # zero-duration probe at t: segments starting after t sort after it,
    # so irange prunes them away
    probe = Segment(start=t, end=t)
    for candidate in self.segments_list_.irange(maximum=probe):
        if candidate.overlaps(t):
            yield candidate
|
| 555 |
+
|
| 556 |
+
def get_overlap(self) -> 'Timeline':
    """Return the parts of the timeline where segments overlap.

    A simple illustration:

        timeline
        |------|    |------|      |----|
          |--|    |-----|      |----------|

        timeline.get_overlap()
          |--|      |---|      |----|

    Returns
    -------
    overlap : `pyannote.core.Timeline`
        Timeline of the overlaps.
    """
    overlaps = Timeline(uri=self.uri)
    # co_iter of a timeline with itself yields every intersecting pair,
    # including each segment paired with itself -- skip those
    for first, second in self.co_iter(self):
        if first != second:
            overlaps.add(first & second)
    return overlaps.support()
|
| 580 |
+
|
| 581 |
+
def extrude(self,
            removed: Support,
            mode: CropMode = 'intersection') -> 'Timeline':
    """Remove segments that overlap `removed` support.

    Parameters
    ----------
    removed : Segment or Timeline
        If `removed` is a `Timeline`, its support is used.
    mode : {'strict', 'loose', 'intersection'}, optional
        Controls how segments that are not fully included in `removed` are
        handled. 'strict' mode only removes fully included segments. 'loose'
        mode removes any intersecting segment. 'intersection' mode removes
        the overlapping part of any intersecting segment.

    Returns
    -------
    extruded : Timeline
        Extruded timeline

    Examples
    --------

    >>> timeline = Timeline([Segment(0, 2), Segment(1, 2), Segment(3, 5)])
    >>> timeline.extrude(Segment(1, 2))
    <Timeline(uri=None, segments=[<Segment(0, 1)>, <Segment(3, 5)>])>

    >>> timeline.extrude(Segment(1, 3), mode='loose')
    <Timeline(uri=None, segments=[<Segment(3, 5)>])>

    >>> timeline.extrude(Segment(1, 3), mode='strict')
    <Timeline(uri=None, segments=[<Segment(0, 2)>, <Segment(3, 5)>])>

    """
    if isinstance(removed, Segment):
        removed = Timeline([removed])

    # extrusion is implemented as cropping to the complement of
    # `removed` (its gaps within this timeline's extent)
    extent_tl = Timeline([self.extent()], uri=self.uri)
    truncating_support = removed.gaps(support=extent_tl)
    # loose for truncate means strict for crop and vice-versa:
    # removing "any intersecting segment" keeps only segments fully
    # inside the complement, and conversely
    if mode == "loose":
        mode = "strict"
    elif mode == "strict":
        mode = "loose"
    return self.crop(truncating_support, mode=mode)
|
| 626 |
+
|
| 627 |
+
def __str__(self):
    """Human-readable representation

    >>> timeline = Timeline(segments=[Segment(0, 10), Segment(1, 13.37)])
    >>> print(timeline)
    [[ 00:00:00.000 -->  00:00:10.000]
     [ 00:00:01.000 -->  00:00:13.370]]

    """
    # one segment per line, separated by newline + one-space indent,
    # the whole thing wrapped in square brackets
    inner = "\n ".join(str(segment) for segment in self.segments_list_)
    return "[" + inner + "]"
|
| 644 |
+
|
| 645 |
+
def __repr__(self):
    """Computer-readable representation

    >>> Timeline(segments=[Segment(0, 10), Segment(1, 13.37)])
    <Timeline(uri=None, segments=[<Segment(0, 10)>, <Segment(1, 13.37)>])>

    """
    return "<Timeline(uri=%s, segments=%s)>" % (
        self.uri, list(self.segments_list_))
|
| 655 |
+
|
| 656 |
+
def __contains__(self, included: Union[Segment, 'Timeline']):
    """Inclusion

    Check whether every segment of `included` does exist in timeline.

    Parameters
    ----------
    included : Segment or Timeline
        Segment or timeline being checked for inclusion

    Returns
    -------
    contains : bool
        True if every segment in `included` exists in timeline,
        False otherwise

    Raises
    ------
    TypeError
        If `included` is neither a Segment nor a Timeline.

    Examples
    --------
    >>> timeline1 = Timeline(segments=[Segment(0, 10), Segment(1, 13.37)])
    >>> timeline2 = Timeline(segments=[Segment(0, 10)])
    >>> timeline1 in timeline2
    False
    >>> timeline2 in timeline1
    True
    >>> Segment(1, 13.37) in timeline1
    True

    """

    # single segment: O(1) set membership test
    if isinstance(included, Segment):
        return included in self.segments_set_

    # timeline: every one of its segments must be present
    elif isinstance(included, Timeline):
        return self.segments_set_.issuperset(included.segments_set_)

    else:
        raise TypeError(
            'Checking for inclusion only supports Segment and '
            'Timeline instances')
|
| 694 |
+
|
| 695 |
+
def empty(self) -> 'Timeline':
    """Return a new, empty timeline sharing this one's 'uri'.

    Returns
    -------
    empty : Timeline
        Empty timeline using the same 'uri' attribute.
    """
    return Timeline(uri=self.uri)
|
| 705 |
+
|
| 706 |
+
def covers(self, other: 'Timeline') -> bool:
    """Check whether other timeline is fully covered by the timeline

    Parameter
    ---------
    other : Timeline
        Second timeline

    Returns
    -------
    covers : bool
        True if timeline covers "other" timeline entirely. False if at
        least one segment of "other" is not fully covered by timeline
    """

    # gaps of "self" within the extent of "other" are the only places
    # where coverage can fail
    uncovered = self.gaps(support=other.extent())

    # "self" covers "other" entirely iff no such gap intersects
    # a segment of "other" (co_iter is lazy: stop at the first hit)
    return next(iter(uncovered.co_iter(other)), None) is None
|
| 733 |
+
|
| 734 |
+
def copy(self, segment_func: Optional[Callable[[Segment], Segment]] = None) \
        -> 'Timeline':
    """Return a copy of the timeline.

    If `segment_func` is provided, it is applied to each segment first.

    Parameters
    ----------
    segment_func : callable, optional
        Callable that takes a segment as input, and returns a segment.
        Defaults to identity function (segment_func(segment) = segment)

    Returns
    -------
    timeline : Timeline
        Copy of the timeline

    """

    # without a transform, simply rebuild from the existing segments
    if segment_func is None:
        return Timeline(segments=self.segments_list_, uri=self.uri)

    # otherwise apply the transform to each segment before adding it
    transformed = (segment_func(segment) for segment in self.segments_list_)
    return Timeline(segments=transformed, uri=self.uri)
|
| 762 |
+
|
| 763 |
+
def extent(self) -> Segment:
    """Extent

    The extent of a timeline is the segment of minimum duration that
    contains every segments of the timeline. It is unique, by definition.
    The extent of an empty timeline is an empty segment.

    A picture is worth a thousand words::

        timeline
        |------|    |------|     |----|
          |--|    |-----|     |----------|

        timeline.extent()
        |--------------------------------|

    Returns
    -------
    extent : Segment
        Timeline extent

    Examples
    --------
    >>> timeline = Timeline(segments=[Segment(0, 1), Segment(9, 10)])
    >>> timeline.extent()
    <Segment(0, 10)>

    """
    # empty timeline --> empty segment
    if not self.segments_set_:
        return Segment(start=0.0, end=0.0)

    # boundaries are kept sorted, so the extent is simply
    # [first boundary, last boundary]
    boundaries = self.segments_boundaries_
    return Segment(start=boundaries[0], end=boundaries[-1])
|
| 798 |
+
|
| 799 |
+
def support_iter(self, collar: float = 0.0) -> Iterator[Segment]:
    """Like `support` but returns a segment generator instead

    See also
    --------
    :func:`pyannote.core.Timeline.support`
    """

    # The support of an empty timeline is an empty timeline.
    if not self:
        return

    # Principle:
    # * gather all segments with no gap between them
    # * add one segment per resulting group (their union |)
    # Note:
    # Since segments are kept sorted internally,
    # there is no need to perform an exhaustive segment clustering.
    # We just have to consider them in their natural order.

    # Initialize new support segment
    # as very first segment of the timeline
    new_segment = self.segments_list_[0]

    # NOTE: the first iteration compares the first segment with itself;
    # segment ^ new_segment is then empty, so it merges harmlessly
    for segment in self:

        # If there is no gap between new support segment and next segment
        # OR there is a gap with duration < collar seconds,
        # (^ computes the gap between two segments)
        possible_gap = segment ^ new_segment
        if not possible_gap or possible_gap.duration < collar:
            # Extend new support segment using next segment
            new_segment |= segment

        # If there actually is a gap and the gap duration >= collar
        # seconds,
        else:
            yield new_segment

            # Initialize new support segment as next segment
            # (right after the gap)
            new_segment = segment

    # Add new segment to the timeline support
    yield new_segment
|
| 843 |
+
|
| 844 |
+
def support(self, collar: float = 0.) -> 'Timeline':
    """Timeline support

    The support of a timeline is the timeline with the minimum number of
    segments with exactly the same time span as the original timeline. It
    is (by definition) unique and does not contain any overlapping
    segments.

    A picture is worth a thousand words::

        collar
        |---|

        timeline
        |------|    |------|      |----|
          |--|    |-----|      |----------|

        timeline.support()
        |------|  |--------|   |----------|

        timeline.support(collar)
        |------------------|   |----------|

    Parameters
    ----------
    collar : float, optional
        Merge segments separated by less than `collar` seconds. This is
        why there are only two segments in the final timeline in the
        above figure. Defaults to 0.

    Returns
    -------
    support : Timeline
        Timeline support
    """
    # support_iter does the actual merging work
    return Timeline(segments=self.support_iter(collar), uri=self.uri)
|
| 880 |
+
|
| 881 |
+
def duration(self) -> float:
    """Timeline duration

    Returns
    -------
    duration : float
        Duration of timeline support, in seconds.
    """
    # summing over the support (rather than over the raw segments)
    # avoids counting overlapping regions twice
    return sum(chunk.duration for chunk in self.support_iter())
|
| 896 |
+
|
| 897 |
+
def gaps_iter(self, support: Optional[Support] = None) -> Iterator[Segment]:
    """Like `gaps` but returns a segment generator instead

    See also
    --------
    :func:`pyannote.core.Timeline.gaps`

    Raises
    ------
    TypeError
        If `support` is neither a Segment nor a Timeline.
    """

    if support is None:
        support = self.extent()

    if not isinstance(support, (Segment, Timeline)):
        # FIX: message used to read "for -':" (unbalanced quote and
        # missing space before the type name)
        raise TypeError("unsupported operand type(s) for '-': "
                        "%s and Timeline." % type(support).__name__)

    # segment support
    if isinstance(support, Segment):

        # `end` is meant to store the end time of former segment
        # initialize it with beginning of provided segment `support`
        end = support.start

        # support on the intersection of timeline and provided segment
        for segment in self.crop(support, mode='intersection').support():

            # add gap between each pair of consecutive segments
            # if there is no gap, segment is empty, therefore not added
            gap = Segment(start=end, end=segment.start)
            if gap:
                yield gap

            # keep track of the end of former segment
            end = segment.end

        # add final gap (if not empty)
        gap = Segment(start=end, end=support.end)
        if gap:
            yield gap

    # timeline support
    elif isinstance(support, Timeline):

        # yield gaps for every segment in support of provided timeline
        for segment in support.support():
            for gap in self.gaps_iter(support=segment):
                yield gap
|
| 944 |
+
|
| 945 |
+
def gaps(self, support: Optional[Support] = None) \
        -> 'Timeline':
    """Return a new timeline made of the gaps of this one.

    A picture is worth a thousand words::

        timeline
        |------|    |------|     |----|
          |--|    |-----|     |----------|

        timeline.gaps()
               |--|        |--|

    Parameters
    ----------
    support : None, Segment or Timeline
        Support in which gaps are looked for. Defaults to timeline extent

    Returns
    -------
    gaps : Timeline
        Timeline made of all gaps from original timeline, and delimited
        by provided support

    See also
    --------
    :func:`pyannote.core.Timeline.extent`
    """
    # gaps_iter does the actual work; just materialize it
    return Timeline(segments=self.gaps_iter(support=support),
                    uri=self.uri)
|
| 976 |
+
|
| 977 |
+
def segmentation(self) -> 'Timeline':
    """Segmentation

    Create the unique timeline with same support and same set of segment
    boundaries as original timeline, but with no overlapping segments.

    A picture is worth a thousand words::

        timeline
        |------|    |------|     |----|
          |--|    |-----|     |----------|

        timeline.segmentation()
        |-|--|-|  |-|---|--| |--|----|--|

    Returns
    -------
    timeline : Timeline
        (unique) timeline with same support and same set of segment
        boundaries as original timeline, but with no overlapping segments.
    """
    # COMPLEXITY: O(n)
    support = self.support()

    # COMPLEXITY: O(n.log n)
    # get all segment boundaries, deduplicated and sorted
    # |------|    |------|     |----|
    #   |--|    |-----|     |----------|
    # becomes
    # | |  | |  | |   |  | |  |    |  |
    timestamps = set()
    for (start, end) in self:
        timestamps.add(start)
        timestamps.add(end)
    timestamps = sorted(timestamps)

    # corner case: an empty timeline has an empty segmentation
    if not timestamps:
        return Timeline(uri=self.uri)

    # create new partition timeline
    # | |  | |  | |   |  | |  |    |  |
    # becomes
    # |-|--|-|  |-|---|--| |--|----|--|
    #
    # build one candidate segment per pair of consecutive boundaries,
    # keeping only non-empty segments whose midpoint is covered by the
    # original timeline.
    segments = []
    for start, end in zip(timestamps, timestamps[1:]):
        segment = Segment(start=start, end=end)
        if segment and support.overlapping(segment.middle):
            segments.append(segment)

    return Timeline(segments=segments, uri=self.uri)
|
| 1035 |
+
|
| 1036 |
+
def to_annotation(self,
                  generator: Union[str, Iterable[Label], None] = 'string',
                  modality: Optional[str] = None) \
        -> 'Annotation':
    """Turn timeline into an annotation

    Each segment is labeled by a unique label.

    Parameters
    ----------
    generator : 'string', 'int', or iterable, optional
        If 'string' (default) generate string labels. If 'int', generate
        integer labels. If iterable, use it to generate labels.
    modality : str, optional
        Modality of the resulting annotation.

    Returns
    -------
    annotation : Annotation
        Annotation with one uniquely-labeled track per segment.
    """
    # imported here (not at module level) to avoid a circular import
    # between timeline and annotation modules
    from .annotation import Annotation
    annotation = Annotation(uri=self.uri, modality=modality)

    # replace generation *mode* by an actual label generator
    if generator == 'string':
        from .utils.generators import string_generator
        generator = string_generator()
    elif generator == 'int':
        from .utils.generators import int_generator
        generator = int_generator()

    # one fresh label per segment
    for segment in self:
        annotation[segment] = next(generator)

    return annotation
|
| 1070 |
+
|
| 1071 |
+
def _iter_uem(self) -> Iterator[Text]:
    """Generate lines for a UEM file for this timeline

    Yields one space-separated UEM record per segment, with start/end
    times rendered with millisecond precision.

    Returns
    -------
    iterator: Iterator[str]
        An iterator over UEM text lines
    """
    # fall back to the UEM "not available" placeholder for missing URIs
    uri = self.uri if self.uri else "<NA>"

    # the UEM format is space-separated: a URI containing a space
    # would produce an unparseable file, so refuse it upfront.
    has_space = isinstance(uri, Text) and ' ' in uri
    if has_space:
        msg = (f'Space-separated UEM file format does not allow file URIs '
               f'containing spaces (got: "{uri}").')
        raise ValueError(msg)

    for segment in self:
        yield f"{uri} 1 {segment.start:.3f} {segment.end:.3f}\n"
|
| 1086 |
+
|
| 1087 |
+
def to_uem(self) -> Text:
    """Serialize timeline as a string using UEM format

    Returns
    -------
    serialized: str
        UEM string
    """
    # str.join accepts any iterable: no need to materialize an
    # intermediate list of lines first.
    return "".join(self._iter_uem())
|
| 1096 |
+
|
| 1097 |
+
def write_uem(self, file: TextIO):
    """Dump timeline to file using UEM format

    Parameters
    ----------
    file : file object
        Writable text file.

    Usage
    -----
    >>> with open('file.uem', 'w') as file:
    ...     timeline.write_uem(file)
    """
    # writelines consumes the line iterator directly, instead of
    # issuing one write() call per line in a Python-level loop
    file.writelines(self._iter_uem())
|
| 1111 |
+
|
| 1112 |
+
def _repr_png_(self):
    """IPython notebook support

    Returns a PNG rendering of the timeline, or None (with a warning)
    when matplotlib is not installed.

    See also
    --------
    :mod:`pyannote.core.notebook`
    """
    from .notebook import MATPLOTLIB_IS_AVAILABLE, MATPLOTLIB_WARNING

    if MATPLOTLIB_IS_AVAILABLE:
        # import deferred until we know matplotlib is usable
        from .notebook import repr_timeline
        return repr_timeline(self)

    warnings.warn(MATPLOTLIB_WARNING.format(klass=self.__class__.__name__))
    return None
|
ailia-models/code/pyannote_audio_utils/core/utils/generators.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2014-2018 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
|
| 29 |
+
import itertools
|
| 30 |
+
from string import ascii_uppercase
|
| 31 |
+
from typing import Iterable, Union, List, Set, Optional, Iterator
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def pairwise(iterable: Iterable):
    """s -> (s0,s1), (s1,s2), (s2, s3), ..."""
    # duplicate the iterator and advance the second copy by one element,
    # so zipping them yields overlapping consecutive pairs
    first, second = itertools.tee(iterable)
    return zip(first, itertools.islice(second, 1, None))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def string_generator(skip: Optional[Union[List, Set]] = None) \
        -> Iterator[str]:
    """Label generator

    Parameters
    ----------
    skip : list or set
        List of labels that must be skipped.
        This option is useful in case you want to make sure generated labels
        are different from a pre-existing set of labels.

    Usage
    -----
    t = string_generator()
    next(t) -> 'A'     # start with 1-letter labels
    ...                # from A to Z
    next(t) -> 'Z'
    next(t) -> 'AA'    # then 2-letters labels
    next(t) -> 'AB'    # from AA to ZZ
    ...
    next(t) -> 'ZY'
    next(t) -> 'ZZ'
    next(t) -> 'AAA'   # then 3-letters labels
    ...                # (you get the idea)
    """
    # normalize to a set for O(1) membership tests
    skipped = set() if skip is None else set(skip)

    # label length
    r = 1

    # infinite loop
    while True:

        # generate labels with current length
        for c in itertools.product(ascii_uppercase, repeat=r):
            # BUGFIX: join *before* the membership test. `c` is a tuple of
            # characters (e.g. ('A',)) while `skip` contains strings
            # (e.g. 'A'), so testing `c in skip` could never match and
            # skipped labels were emitted anyway.
            label = ''.join(c)
            if label in skipped:
                continue
            yield label

        # increment label length when all possibilities are exhausted
        r = r + 1
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def int_generator() -> Iterator[int]:
    """Yield consecutive integer labels: 0, 1, 2, ..."""
    # itertools.count() counts from 0 by 1, exactly like the manual loop
    yield from itertools.count()
|
ailia-models/code/pyannote_audio_utils/core/utils/types.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Hashable, Union, Tuple, Iterator, Literal

# Shared type aliases for pyannote_audio_utils.core.
# Quoted names are forward references to classes defined in sibling
# modules (Segment, Timeline, Annotation, SlidingWindowFeature).

# anything hashable can be used as an annotation label
Label = Hashable
# a "support" is either a single segment or a whole timeline
Support = Union['Segment', 'Timeline']
# automatic label generation mode
LabelGeneratorMode = Literal['int', 'string']
# either a generation mode or an explicit iterator of labels
LabelGenerator = Union[LabelGeneratorMode, Iterator[Label]]
# tracks are identified by a string or an integer
TrackName = Union[str, int]
# annotation indexing key: a segment, or a (segment, track name) pair
Key = Union['Segment', Tuple['Segment', TrackName]]
# any of the core pyannote resource types
Resource = Union['Segment', 'Timeline', 'SlidingWindowFeature',
                 'Annotation']
# cropping behavior at support boundaries
CropMode = Literal['intersection', 'loose', 'strict']
# window/segment alignment behavior
Alignment = Literal['center', 'loose', 'strict']
# (style, width, color) triple -- presumably used by notebook rendering;
# TODO confirm against notebook.py
LabelStyle = Tuple[str, int, Tuple[float, float, float]]
|
ailia-models/code/pyannote_audio_utils/database/__init__.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2016- CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
# Alexis PLAQUET
|
| 29 |
+
|
| 30 |
+
"""pyannote.database"""
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
from typing import Optional
|
| 34 |
+
import warnings
|
| 35 |
+
|
| 36 |
+
# from .registry import registry, LoadingMode
|
| 37 |
+
|
| 38 |
+
# from .database import Database
|
| 39 |
+
|
| 40 |
+
from .protocol.protocol import Protocol
|
| 41 |
+
from .protocol.protocol import ProtocolFile
|
| 42 |
+
from .protocol.protocol import Subset
|
| 43 |
+
from .protocol.protocol import Preprocessors
|
| 44 |
+
|
| 45 |
+
# from .file_finder import FileFinder
|
| 46 |
+
# from .util import get_annotated
|
| 47 |
+
# from .util import get_unique_identifier
|
| 48 |
+
# from .util import get_label_identifier
|
| 49 |
+
|
| 50 |
+
# from ._version import get_versions
|
| 51 |
+
#
|
| 52 |
+
|
| 53 |
+
# __version__ = get_versions()["version"]
|
| 54 |
+
# del get_versions
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_protocol(name, preprocessors: Optional[Preprocessors] = None) -> Protocol:
    """Get protocol by full name

    name : str
        Protocol full name (e.g. "Etape.SpeakerDiarization.TV")
    preprocessors : dict or (key, preprocessor) iterable
        When provided, each protocol item (dictionary) are preprocessed, such
        that item[key] = preprocessor(item). In case 'preprocessor' is not
        callable, it should be a string containing placeholder for item keys
        (e.g. {'audio': '/path/to/{uri}.wav'})

    Returns
    -------
    protocol : Protocol
        Protocol instance
    """
    warnings.warn(
        "`get_protocol` has been deprecated in favor of `pyannote.database.registry.get_protocol`.",
        DeprecationWarning)
    # NOTE(review): `registry` is never imported in this vendored copy (the
    # `from .registry import registry, ...` line above is commented out), so
    # calling this deprecated helper raises NameError. Confirm whether the
    # registry import should be restored or this helper removed.
    return registry.get_protocol(name, preprocessors=preprocessors)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# Public API of this vendored, trimmed-down copy of pyannote.database.
# BUGFIX: the original __all__ listed names whose imports are commented out
# above (registry, LoadingMode, Database, FileFinder, get_annotated,
# get_unique_identifier, get_label_identifier); keeping them would make
# `from pyannote_audio_utils.database import *` fail with AttributeError.
# Only names actually defined or imported in this module are exported.
__all__ = [
    "get_protocol",
    "Protocol",
    "ProtocolFile",
    "Subset",
    "Preprocessors",
]
|
ailia-models/code/pyannote_audio_utils/database/protocol/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2016- CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
|
| 29 |
+
from .protocol import Protocol
|
| 30 |
+
|
| 31 |
+
__all__ = [
|
| 32 |
+
"Protocol",
|
| 33 |
+
]
|
| 34 |
+
|
ailia-models/code/pyannote_audio_utils/database/protocol/protocol.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2016-2020 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
|
| 29 |
+
"""
|
| 30 |
+
#########
|
| 31 |
+
Protocols
|
| 32 |
+
#########
|
| 33 |
+
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import warnings
|
| 37 |
+
import collections
|
| 38 |
+
import threading
|
| 39 |
+
import itertools
|
| 40 |
+
from typing import Union, Dict, Iterator, Callable, Any, Text, Optional
|
| 41 |
+
|
| 42 |
+
# try:
|
| 43 |
+
from typing import Literal
|
| 44 |
+
# except ImportError:
|
| 45 |
+
# from typing_extensions import Literal
|
| 46 |
+
|
| 47 |
+
# canonical subset names used throughout the protocol API
Subset = Literal["train", "development", "test"]
# mapping from canonical subset names to the legacy short names used by
# older protocol implementations (see subset_helper fallback)
LEGACY_SUBSET_MAPPING = {"train": "trn", "development": "dev", "test": "tst"}
# scope of a label -- presumably whether speaker labels are consistent
# within a file, a database, or globally; TODO confirm against callers
Scope = Literal["file", "database", "global"]

# a preprocessor computes one value from a protocol file
Preprocessor = Callable[["ProtocolFile"], Any]
# mapping from protocol-file key to the preprocessor that computes it
Preprocessors = Dict[Text, Preprocessor]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class ProtocolFile(collections.abc.MutableMapping):
    """Protocol file with lazy preprocessors

    This is a dict-like data structure where some values may depend on other
    values, and are only computed if/when requested. Once computed, they are
    cached and never recomputed again.

    All mapping operations are guarded by a re-entrant lock, so a single
    instance can be shared across threads.

    Parameters
    ----------
    precomputed : dict
        Regular dictionary with precomputed values
    lazy : dict, optional
        Dictionary describing how lazy value needs to be computed.
        Values are callable expecting a dictionary as input and returning the
        computed value.

    """

    def __init__(self, precomputed: Union[Dict, "ProtocolFile"], lazy: Dict = None):

        if lazy is None:
            lazy = dict()

        if isinstance(precomputed, ProtocolFile):
            # when 'precomputed' is a ProtocolFile, it may already contain lazy keys.

            # we use 'precomputed' precomputed keys as precomputed keys
            # (abs() is repurposed as "snapshot of the precomputed store",
            # see __abs__ below)
            self._store: Dict = abs(precomputed)

            # we handle the corner case where the intersection of 'precomputed' lazy keys
            # and 'lazy' keys is not empty. this is currently achieved by "unlazying" the
            # 'precomputed' one (which is probably not the most efficient solution).
            for key in set(precomputed.lazy) & set(lazy):
                self._store[key] = precomputed[key]

            # we use the union of 'precomputed' lazy keys and provided 'lazy' keys as lazy keys
            compound_lazy = dict(precomputed.lazy)
            compound_lazy.update(lazy)
            self.lazy: Dict = compound_lazy

        else:
            # when 'precomputed' is a Dict, we use it directly as precomputed keys
            # and 'lazy' as lazy keys.
            self._store = dict(precomputed)
            self.lazy = dict(lazy)

        # re-entrant lock used below to make ProtocolFile thread-safe
        self.lock_ = threading.RLock()

        # this is needed to avoid infinite recursion
        # when a key is both in precomputed and lazy.
        # keys with evaluating_ > 0 are currently being evaluated
        # and therefore should be taken from precomputed
        self.evaluating_ = collections.Counter()

    # since RLock is not pickable, remove it before pickling...
    def __getstate__(self):
        d = dict(self.__dict__)
        del d["lock_"]
        return d

    # ... and add it back when unpickling
    def __setstate__(self, d):
        self.__dict__.update(d)
        self.lock_ = threading.RLock()

    # NOTE: abs(file) is (unconventionally) used as a thread-safe way to get
    # a shallow copy of the *precomputed* store, without triggering any lazy
    # computation.
    def __abs__(self):
        with self.lock_:
            return dict(self._store)

    def __getitem__(self, key):
        with self.lock_:

            # compute the lazy value unless this very key is already being
            # evaluated higher up the stack (which would recurse forever)
            if key in self.lazy and self.evaluating_[key] == 0:

                # mark lazy key as being evaluated
                self.evaluating_.update([key])

                # apply preprocessor once and remove it
                # (the preprocessor receives the whole ProtocolFile and may
                # itself read other -- possibly lazy -- keys)
                value = self.lazy[key](self)
                del self.lazy[key]

                # warn the user when a precomputed key is modified
                # NOTE(review): `!=` may misbehave for values without scalar
                # equality (e.g. numpy arrays) -- confirm expected value types
                if key in self._store and value != self._store[key]:
                    msg = 'Existing precomputed key "{key}" has been modified by a preprocessor.'
                    warnings.warn(msg.format(key=key))

                # store the output of the lazy computation
                # so that it is available for future access
                self._store[key] = value

                # lazy evaluation is finished for key
                self.evaluating_.subtract([key])

            return self._store[key]

    def __setitem__(self, key, value):
        with self.lock_:

            # an explicit assignment overrides any pending lazy computation
            if key in self.lazy:
                del self.lazy[key]

            self._store[key] = value

    def __delitem__(self, key):
        with self.lock_:

            # also drop any pending lazy computation for this key
            if key in self.lazy:
                del self.lazy[key]

            del self._store[key]

    def __iter__(self):
        with self.lock_:

            # precomputed keys first...
            store_keys = list(self._store)
            for key in store_keys:
                yield key

            # ...then lazy keys that have not been computed yet
            lazy_keys = list(self.lazy)
            for key in lazy_keys:
                if key in self._store:
                    continue
                yield key

    def __len__(self):
        with self.lock_:
            # union, since a key may exist in both stores
            return len(set(self._store) | set(self.lazy))

    def files(self) -> Iterator["ProtocolFile"]:
        """Iterate over all files

        When `current_file` refers to only one file,
        yield it and return.
        When `current_file` refers to a list of file (i.e. 'uri' is a list),
        yield each file separately.

        Examples
        --------
        >>> current_file = ProtocolFile({
        ...     'uri': 'my_uri',
        ...     'database': 'my_database'})
        >>> for file in current_file.files():
        ...     print(file['uri'], file['database'])
        my_uri my_database

        >>> current_file = {
        ...     'uri': ['my_uri1', 'my_uri2', 'my_uri3'],
        ...     'database': 'my_database'}
        >>> for file in current_file.files():
        ...     print(file['uri'], file['database'])
        my_uri1 my_database
        my_uri2 my_database
        my_uri3 my_database

        """

        uris = self["uri"]
        if not isinstance(uris, list):
            yield self
            return

        n_uris = len(uris)

        # iterate over precomputed keys and make sure
        # per-file (list) values are consistent with the number of uris;
        # scalar values are broadcast to every file

        precomputed = {"uri": uris}
        for key, value in abs(self).items():

            if key == "uri":
                continue

            if not isinstance(value, list):
                # broadcast the scalar value to all files
                precomputed[key] = itertools.repeat(value)

            else:
                if len(value) != n_uris:
                    msg = (
                        f'Mismatch between number of "uris" ({n_uris}) '
                        f'and number of "{key}" ({len(value)}).'
                    )
                    raise ValueError(msg)
                precomputed[key] = value

        keys = list(precomputed.keys())
        for values in zip(*precomputed.values()):
            precomputed_one = dict(zip(keys, values))
            # each per-uri file inherits the (shared) lazy preprocessors
            yield ProtocolFile(precomputed_one, self.lazy)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class Protocol:
|
| 246 |
+
"""Experimental protocol
|
| 247 |
+
|
| 248 |
+
An experimental protocol usually defines three subsets: a training subset,
|
| 249 |
+
a development subset, and a test subset.
|
| 250 |
+
|
| 251 |
+
An experimental protocol can be defined programmatically by creating a
|
| 252 |
+
class that inherits from Protocol and implements at least
|
| 253 |
+
one of `train_iter`, `development_iter` and `test_iter` methods:
|
| 254 |
+
|
| 255 |
+
>>> class MyProtocol(Protocol):
|
| 256 |
+
... def train_iter(self) -> Iterator[Dict]:
|
| 257 |
+
... yield {"uri": "filename1", "any_other_key": "..."}
|
| 258 |
+
... yield {"uri": "filename2", "any_other_key": "..."}
|
| 259 |
+
|
| 260 |
+
`{subset}_iter` should return an iterator of dictionnaries with
|
| 261 |
+
- "uri" key (mandatory) that provides a unique file identifier (usually
|
| 262 |
+
the filename),
|
| 263 |
+
- any other key that the protocol may provide.
|
| 264 |
+
|
| 265 |
+
It can then be used in Python like this:
|
| 266 |
+
|
| 267 |
+
>>> protocol = MyProtocol()
|
| 268 |
+
>>> for file in protocol.train():
|
| 269 |
+
... print(file["uri"])
|
| 270 |
+
filename1
|
| 271 |
+
filename2
|
| 272 |
+
|
| 273 |
+
An experimental protocol can also be defined using `pyannote_audio_utils.database`
|
| 274 |
+
configuration file, whose (configurable) path defaults to "~/database.yml".
|
| 275 |
+
|
| 276 |
+
~~~ Content of ~/database.yml ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 277 |
+
Protocols:
|
| 278 |
+
MyDatabase:
|
| 279 |
+
Protocol:
|
| 280 |
+
MyProtocol:
|
| 281 |
+
train:
|
| 282 |
+
uri: /path/to/collection.lst
|
| 283 |
+
any_other_key: ... # see custom loader documentation
|
| 284 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 285 |
+
|
| 286 |
+
where "/path/to/collection.lst" contains the list of identifiers of the
|
| 287 |
+
files in the collection:
|
| 288 |
+
|
| 289 |
+
~~~ Content of "/path/to/collection.lst ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 290 |
+
filename1
|
| 291 |
+
filename2
|
| 292 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 293 |
+
|
| 294 |
+
It can then be used in Python like this:
|
| 295 |
+
|
| 296 |
+
>>> from pyannote_audio_utils.database import registry
|
| 297 |
+
>>> protocol = registry.get_protocol('MyDatabase.Protocol.MyProtocol')
|
| 298 |
+
>>> for file in protocol.train():
|
| 299 |
+
... print(file["uri"])
|
| 300 |
+
filename1
|
| 301 |
+
filename2
|
| 302 |
+
|
| 303 |
+
This class is usually inherited from, but can be used directly.
|
| 304 |
+
|
| 305 |
+
Parameters
|
| 306 |
+
----------
|
| 307 |
+
preprocessors : dict
|
| 308 |
+
Preprocess protocol files so that `file[key] = preprocessors[key](file)`
|
| 309 |
+
for each key in `preprocessors`. In case `preprocessors[key]` is not
|
| 310 |
+
callable, it should be a string containing placeholders for `file` keys
|
| 311 |
+
(e.g. {'audio': '/path/to/{uri}.wav'})
|
| 312 |
+
"""
|
| 313 |
+
|
| 314 |
+
def __init__(self, preprocessors: Optional[Preprocessors] = None):
    """Normalize `preprocessors` into a dict of callables.

    Parameters
    ----------
    preprocessors : dict, optional
        Mapping from key to either a callable or a template string
        (e.g. '/path/to/{uri}.wav').

    Raises
    ------
    ValueError
        When a preprocessor is neither callable nor a string.
    """
    super().__init__()

    if preprocessors is None:
        preprocessors = dict()

    self.preprocessors = dict()
    for key, preprocessor in preprocessors.items():

        if callable(preprocessor):
            self.preprocessors[key] = preprocessor

        # when `preprocessor` is not callable, it should be a string
        # containing placeholder for item key (e.g. '/path/to/{uri}.wav')
        elif isinstance(preprocessor, str):

            # BUGFIX: bind the template as a default argument. The previous
            # code closed over a loop-scoped variable, so with several
            # string preprocessors every generated function formatted with
            # the *last* template seen (late-binding closure bug).
            def func(current_file, _template=str(preprocessor)):
                return _template.format(**current_file)

            self.preprocessors[key] = func

        else:
            msg = f'"{key}" preprocessor is neither a callable nor a string.'
            raise ValueError(msg)
|
| 339 |
+
|
| 340 |
+
def preprocess(self, current_file: Union[Dict, ProtocolFile]) -> ProtocolFile:
    """Wrap `current_file` into a ProtocolFile with lazy preprocessors attached."""
    lazy_evaluators = self.preprocessors
    return ProtocolFile(current_file, lazy=lazy_evaluators)
|
| 342 |
+
|
| 343 |
+
def __str__(self):
    """Use the (subclass) docstring as the human-readable description."""
    description = self.__doc__
    return description
|
| 345 |
+
|
| 346 |
+
def train_iter(self) -> Iterator[Union[Dict, ProtocolFile]]:
    """Iterate over files in the training subset.

    Raises
    ------
    NotImplementedError
        Always; subclasses are expected to override this method.
    """
    raise NotImplementedError()
|
| 349 |
+
|
| 350 |
+
def development_iter(self) -> Iterator[Union[Dict, ProtocolFile]]:
    """Iterate over files in the development subset.

    Raises
    ------
    NotImplementedError
        Always; subclasses are expected to override this method.
    """
    raise NotImplementedError()
|
| 353 |
+
|
| 354 |
+
def test_iter(self) -> Iterator[Union[Dict, ProtocolFile]]:
    """Iterate over files in the test subset.

    Raises
    ------
    NotImplementedError
        Always; subclasses are expected to override this method.
    """
    raise NotImplementedError()
|
| 357 |
+
|
| 358 |
+
def subset_helper(self, subset: Subset) -> Iterator[ProtocolFile]:
    """Yield preprocessed files of `subset`, with legacy-name fallback.

    Older pyannote_audio_utils.database versions implemented `trn_iter`,
    `dev_iter` and `tst_iter` instead of `train_iter`, `development_iter`
    and `test_iter`.  When the modern method is missing (or raises
    NotImplementedError), the legacy one is tried before giving up.
    """
    try:
        files = getattr(self, f"{subset}_iter")()
    except (AttributeError, NotImplementedError):
        legacy_name = LEGACY_SUBSET_MAPPING[subset]
        try:
            files = getattr(self, f"{legacy_name}_iter")()
        except AttributeError:
            raise NotImplementedError(
                f"Protocol does not implement a {subset} subset."
            )

    for current_file in files:
        yield self.preprocess(current_file)
|
| 376 |
+
|
| 377 |
+
def train(self) -> Iterator[ProtocolFile]:
    """Iterate over preprocessed files of the "train" subset."""
    return self.subset_helper("train")
|
| 379 |
+
|
| 380 |
+
def development(self) -> Iterator[ProtocolFile]:
    """Iterate over preprocessed files of the "development" subset."""
    return self.subset_helper("development")
|
| 382 |
+
|
| 383 |
+
def test(self) -> Iterator[ProtocolFile]:
    """Iterate over preprocessed files of the "test" subset."""
    return self.subset_helper("test")
|
| 385 |
+
|
| 386 |
+
def files(self) -> Iterator[ProtocolFile]:
    """Iterate over all files in `protocol`"""

    # imported here to avoid circular imports
    from pyannote_audio_utils.database.util import get_unique_identifier

    # identifiers already yielded, so that a file reachable through
    # several subsets is only produced once
    yielded_uris = set()

    for method in [
        "development",
        "development_enrolment",
        "development_trial",
        "test",
        "test_enrolment",
        "test_trial",
        "train",
        "train_enrolment",
        "train_trial",
    ]:

        if not hasattr(self, method):
            continue

        # wrap the subset generator so that a subset that is declared but
        # not actually implemented is silently skipped.
        # NOTE: `iterate` closes over the loop variable `method`; this is
        # safe only because the generator is fully consumed below before
        # `method` is rebound on the next iteration.
        def iterate():
            try:
                for file in getattr(self, method)():
                    yield file
            except (AttributeError, NotImplementedError):
                return

        for current_file in iterate():

            # skip "files" that do not contain a "uri" entry.
            # this happens for speaker verification trials that contain
            # two nested files "file1" and "file2"
            # see https://github.com/pyannote_audio_utils/pyannote_audio_utils-db-voxceleb/issues/4
            if "uri" not in current_file:
                continue

            for current_file_ in current_file.files():

                # corner case when the same file is yielded several times
                uri = get_unique_identifier(current_file_)
                if uri in yielded_uris:
                    continue

                yield current_file_

                yielded_uris.add(uri)
|
ailia-models/code/pyannote_audio_utils/database/util.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2016-2020 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
import warnings
|
| 31 |
+
import pandas as pd
|
| 32 |
+
from pyannote_audio_utils.core import Segment, Timeline, Annotation
|
| 33 |
+
|
| 34 |
+
from typing import Text
|
| 35 |
+
|
| 36 |
+
DatabaseName = Text
|
| 37 |
+
PathTemplate = Text
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_unique_identifier(item):
    """Return unique item identifier

    The complete format is {database}/{uri}_{channel}:
    * prefixed by "{database}/" only when `item` has a 'database' key.
    * suffixed by "_{channel}" only when `item` has a 'channel' key.

    Parameters
    ----------
    item : dict
        Item as yielded by pyannote_audio_utils.database protocols

    Returns
    -------
    identifier : str
        Unique item identifier
    """

    parts = []

    # optional "{database}/" prefix
    database = item.get("database", None)
    if database is not None:
        parts.append(f"{database}/")

    # mandatory "{uri}" part
    parts.append(item["uri"])

    # optional "_{channel}" suffix
    channel = item.get("channel", None)
    if channel is not None:
        parts.append(f"_{channel:d}")

    return "".join(parts)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# This function is used in custom.py
|
| 73 |
+
def get_annotated(current_file):
    """Get part of the file that is annotated.

    Parameters
    ----------
    current_file : `dict`
        File generated by a `pyannote_audio_utils.database` protocol.

    Returns
    -------
    annotated : `pyannote_audio_utils.core.Timeline`
        Part of the file that is annotated. Defaults to
        `current_file["annotated"]`. When it does not exist, try to use the
        full audio extent. When that fails, use "annotation" extent.
    """

    # best case: the protocol provides "annotated" directly
    if "annotated" in current_file:
        return current_file["annotated"]

    # second best: approximate it as [0, audio duration].
    # NOTE: reading "duration" may trigger a lazy preprocessor, which may
    # raise ImportError when its audio backend is unavailable — in that
    # case, fall through to the "annotation" extent approximation.
    if "duration" in current_file:
        try:
            duration = current_file["duration"]
        except ImportError:
            pass
        else:
            annotated = Timeline([Segment(0, duration)])
            warnings.warn('"annotated" was approximated by [0, audio duration].')
            return annotated

    # last resort: use the extent of the annotation itself
    extent = current_file["annotation"].get_timeline().extent()
    annotated = Timeline([extent])
    warnings.warn(
        '"annotated" was approximated by "annotation" extent. '
        'Please provide "annotated" directly, or at the very '
        'least, use a "duration" preprocessor.'
    )
    return annotated
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def get_label_identifier(label, current_file):
    """Return unique label identifier

    Parameters
    ----------
    label : str
        Database-internal label
    current_file
        Yielded by pyannote_audio_utils.database protocols

    Returns
    -------
    unique_label : str
        Global label
    """

    # TODO. when the "true" name of a person is used,
    # do not prepend the database name.
    return "|".join([current_file["database"], label])
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def load_rttm(file_rttm, keep_type="SPEAKER"):
    """Load RTTM file

    Parameters
    ----------
    file_rttm : `str`
        Path to RTTM file.
    keep_type : str, optional
        Only keep lines with this type (field #1 in RTTM specs).
        Defaults to "SPEAKER".

    Returns
    -------
    annotations : `dict`
        Speaker diarization as a {uri: pyannote_audio_utils.core.Annotation} dictionary.
    """

    names = [
        "type",
        "uri",
        "NA2",
        "start",
        "duration",
        "NA3",
        "NA4",
        "speaker",
        "NA5",
        "NA6",
    ]
    dtype = {"uri": str, "start": float, "duration": float, "speaker": str}
    data = pd.read_csv(
        file_rttm,
        names=names,
        dtype=dtype,
        # raw string: a plain '\s+' is an invalid escape sequence
        # (SyntaxWarning since Python 3.12, error in a future release)
        sep=r"\s+",
        keep_default_na=True,
    )

    annotations = dict()
    for uri, turns in data.groupby("uri"):
        annotation = Annotation(uri=uri)
        for i, turn in turns.iterrows():
            # keep only records of the requested type (e.g. skip NON-SPEECH)
            if turn.type != keep_type:
                continue
            segment = Segment(turn.start, turn.start + turn.duration)
            annotation[segment, i] = turn.speaker
        annotations[uri] = annotation

    return annotations
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def load_stm(file_stm):
    """Load STM file (speaker-info only)

    Parameters
    ----------
    file_stm : str
        Path to STM file

    Returns
    -------
    annotations : `dict`
        Speaker diarization as a {uri: pyannote_audio_utils.core.Annotation} dictionary.
    """

    dtype = {"uri": str, "speaker": str, "start": float, "end": float}
    data = pd.read_csv(
        file_stm,
        # raw string: a plain '\s+' is an invalid escape sequence
        # (SyntaxWarning since Python 3.12, error in a future release)
        sep=r"\s+",
        # field #2 (channel) is ignored on purpose: only speaker info is kept
        usecols=[0, 2, 3, 4],
        dtype=dtype,
        names=list(dtype),
    )

    annotations = dict()
    for uri, turns in data.groupby("uri"):
        annotation = Annotation(uri=uri)
        for i, turn in turns.iterrows():
            segment = Segment(turn.start, turn.end)
            annotation[segment, i] = turn.speaker
        annotations[uri] = annotation

    return annotations
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def load_mdtm(file_mdtm):
    """Load MDTM file

    Parameters
    ----------
    file_mdtm : `str`
        Path to MDTM file.

    Returns
    -------
    annotations : `dict`
        Speaker diarization as a {uri: pyannote_audio_utils.core.Annotation} dictionary.
    """

    names = ["uri", "NA1", "start", "duration", "NA2", "NA3", "NA4", "speaker"]
    dtype = {"uri": str, "start": float, "duration": float, "speaker": str}
    data = pd.read_csv(
        file_mdtm,
        names=names,
        dtype=dtype,
        # raw string: a plain '\s+' is an invalid escape sequence
        # (SyntaxWarning since Python 3.12, error in a future release)
        sep=r"\s+",
        keep_default_na=False,
    )

    annotations = dict()
    for uri, turns in data.groupby("uri"):
        annotation = Annotation(uri=uri)
        for i, turn in turns.iterrows():
            segment = Segment(turn.start, turn.start + turn.duration)
            annotation[segment, i] = turn.speaker
        annotations[uri] = annotation

    return annotations
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def load_uem(file_uem):
    """Load UEM file

    Parameters
    ----------
    file_uem : `str`
        Path to UEM file.

    Returns
    -------
    timelines : `dict`
        Evaluation map as a {uri: pyannote_audio_utils.core.Timeline} dictionary.
    """

    names = ["uri", "NA1", "start", "end"]
    dtype = {"uri": str, "start": float, "end": float}
    # raw string: a plain '\s+' is an invalid escape sequence
    # (SyntaxWarning since Python 3.12, error in a future release)
    data = pd.read_csv(file_uem, names=names, dtype=dtype, sep=r"\s+")

    timelines = dict()
    for uri, parts in data.groupby("uri"):
        segments = [Segment(part.start, part.end) for _, part in parts.iterrows()]
        timelines[uri] = Timeline(segments=segments, uri=uri)

    return timelines
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def load_lab(path, uri: str = None) -> Annotation:
    """Load LAB file

    Parameters
    ----------
    path : `str`
        Path to LAB file
    uri : str, optional
        URI attached to the returned annotation.

    Returns
    -------
    data : `pyannote_audio_utils.core.Annotation`
    """

    names = ["start", "end", "label"]
    dtype = {"start": float, "end": float, "label": str}
    # raw string: a plain '\s+' is an invalid escape sequence
    # (SyntaxWarning since Python 3.12, error in a future release)
    data = pd.read_csv(path, names=names, dtype=dtype, sep=r"\s+")

    annotation = Annotation(uri=uri)
    for i, turn in data.iterrows():
        segment = Segment(turn.start, turn.end)
        annotation[segment, i] = turn.label

    return annotation
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def load_lst(file_lst):
    """Load LST file

    LST files provide a list of URIs (one line per URI)

    Parameters
    ----------
    file_lst : `str`
        Path to LST file.

    Returns
    -------
    uris : `list`
        List of uris
    """

    with open(file_lst, mode="r") as fp:
        return [line.strip() for line in fp]
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def load_mapping(mapping_txt):
    """Load mapping file

    Parameters
    ----------
    mapping_txt : `str`
        Path to mapping file

    Returns
    -------
    mapping : `dict`
        {1st field: 2nd field} dictionary
    """

    mapping = dict()
    with open(mapping_txt, mode="r") as fp:
        for line in fp:
            # extra fields beyond the first two are ignored on purpose
            key, value, *_ = line.strip().split()
            mapping[key] = value

    return mapping
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class LabelMapper(object):
    """Label mapper for use as pyannote_audio_utils.database preprocessor

    Parameters
    ----------
    mapping : `dict`
        Mapping dictionary as used in `Annotation.rename_labels()`.
    keep_missing : `bool`, optional
        In case a label has no mapping, a `ValueError` will be raised.
        Set "keep_missing" to True to keep those labels unchanged instead.

    Usage
    -----
    >>> mapping = {'Hadrien': 'MAL', 'Marvin': 'MAL',
    ...            'Wassim': 'CHI', 'Herve': 'GOD'}
    >>> preprocessors = {'annotation': LabelMapper(mapping=mapping)}
    >>> protocol = registry.get_protocol('AMI.SpeakerDiarization.MixHeadset',
                                         preprocessors=preprocessors)

    """

    def __init__(self, mapping, keep_missing=False):
        self.mapping = mapping
        self.keep_missing = keep_missing

    def __call__(self, current_file):
        """Return `current_file`'s annotation with labels renamed.

        Raises
        ------
        ValueError
            When `keep_missing` is False and at least one label of the
            annotation has no entry in `mapping`.
        """

        if not self.keep_missing:
            missing = set(current_file["annotation"].labels()) - set(self.mapping)
            # NOTE: the original also re-tested `not self.keep_missing`
            # here, which is redundant inside this branch.
            if missing:
                label = missing.pop()
                msg = (
                    f'No mapping found for label "{label}". Set "keep_missing" '
                    f"to True to keep labels with no mapping."
                )
                raise ValueError(msg)

        return current_file["annotation"].rename_labels(mapping=self.mapping)
|
ailia-models/code/pyannote_audio_utils/metrics/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2012-2021 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
|
| 29 |
+
from ._version import get_versions
|
| 30 |
+
from .base import f_measure
|
| 31 |
+
|
| 32 |
+
__version__ = get_versions()["version"]
|
| 33 |
+
del get_versions
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
__all__ = ["f_measure"]
|
ailia-models/code/pyannote_audio_utils/metrics/_version.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# This file was generated by 'versioneer.py' (0.15) from
|
| 3 |
+
# revision-control system data, or from the parent directory name of an
|
| 4 |
+
# unpacked source archive. Distribution tarballs contain a pre-generated copy
|
| 5 |
+
# of this file.
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
# Version metadata embedded by versioneer at packaging time.
version_json = '''
{
 "dirty": false,
 "error": null,
 "full-revisionid": "babbd1c68adc50c0e2199676c7ae741194c520da",
 "version": "3.2.1"
}
'''  # END VERSION_JSON


def get_versions():
    """Return the version metadata parsed from the embedded JSON blob."""
    versions = json.loads(version_json)
    return versions
|
ailia-models/code/pyannote_audio_utils/metrics/base.py
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2012- CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
from typing import List, Union, Optional, Set, Tuple
|
| 29 |
+
|
| 30 |
+
import warnings
|
| 31 |
+
import numpy as np
|
| 32 |
+
import pandas as pd
|
| 33 |
+
import scipy.stats
|
| 34 |
+
from pyannote_audio_utils.core import Annotation, Timeline
|
| 35 |
+
|
| 36 |
+
from pyannote_audio_utils.metrics.types import Details, MetricComponents
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class BaseMetric:
|
| 40 |
+
"""
|
| 41 |
+
:class:`BaseMetric` is the base class for most pyannote_audio_utils evaluation metrics.
|
| 42 |
+
|
| 43 |
+
Attributes
|
| 44 |
+
----------
|
| 45 |
+
name : str
|
| 46 |
+
Human-readable name of the metric (eg. 'diarization error rate')
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
@classmethod
def metric_name(cls) -> str:
    """Abstract hook: subclasses must return the metric name as a string."""
    raise NotImplementedError(
        cls.__name__ + " is missing a 'metric_name' class method. "
        "It should return the name of the metric as string."
    )
|
| 55 |
+
|
| 56 |
+
@classmethod
def metric_components(cls) -> MetricComponents:
    """Abstract hook: subclasses must return the list of component names."""
    raise NotImplementedError(
        cls.__name__ + " is missing a 'metric_components' class method. "
        "It should return the list of names of metric components."
    )
|
| 62 |
+
|
| 63 |
+
def __init__(self, **kwargs):
    """Cache metric name/components and start from a clean slate."""
    super(BaseMetric, self).__init__()
    # metric name and component names are class-level properties,
    # cached once on the instance for fast repeated access
    self.metric_name_ = type(self).metric_name()
    self.components_: Set[str] = set(type(self).metric_components())
    self.reset()
|
| 68 |
+
|
| 69 |
+
def init_components(self):
    """Return a fresh {component name: 0.0} dictionary."""
    return dict.fromkeys(self.components_, 0.0)
|
| 71 |
+
|
| 72 |
+
def reset(self):
    """Reset accumulated components and metric values"""
    self.results_: List = list()
    # one zeroed accumulator per component
    self.accumulated_: Details = {component: 0.0 for component in self.components_}
|
| 78 |
+
|
| 79 |
+
@property
def name(self):
    """Metric name."""
    metric_name = self.metric_name()
    return metric_name
|
| 83 |
+
|
| 84 |
+
# TODO: use joblib/locky to allow parallel processing?
|
| 85 |
+
# TODO: signature could be something like __call__(self, reference_iterator, hypothesis_iterator, ...)
|
| 86 |
+
|
| 87 |
+
def __call__(self, reference: Union[Timeline, Annotation],
             hypothesis: Union[Timeline, Annotation],
             detailed: bool = False, uri: Optional[str] = None, **kwargs):
    """Compute metric value and accumulate components

    Parameters
    ----------
    reference : type depends on the metric
        Manual `reference`
    hypothesis : type depends on the metric
        Evaluated `hypothesis`
    uri : optional
        Override uri.
    detailed : bool, optional
        By default (False), return metric value only.
        Set `detailed` to True to return dictionary where keys are
        components names and values are component values

    Returns
    -------
    value : float (if `detailed` is False)
        Metric value
    components : dict (if `detailed` is True)
        `components` updated with metric value
    """

    # per-file components, plus the derived metric value under the
    # metric's own name
    components = self.compute_components(reference, hypothesis, **kwargs)
    components[self.metric_name_] = self.compute_metric(components)

    # remember this computation; any falsy `uri` (None, "") falls back to
    # the reference's uri, then to "NA"
    effective_uri = uri or getattr(reference, "uri", "NA")
    self.results_.append((effective_uri, components))

    # fold per-file components into the running totals
    for component_name in self.components_:
        self.accumulated_[component_name] += components[component_name]

    if detailed:
        return components

    return components[self.metric_name_]
|
| 131 |
+
|
| 132 |
+
def report(self, display: bool = False) -> pd.DataFrame:
    """Evaluation report

    Parameters
    ----------
    display : bool, optional
        Set to True to print the report to stdout.

    Returns
    -------
    report : pandas.DataFrame
        Dataframe with one column per metric component, one row per
        evaluated item, and one final row for accumulated results.
    """

    # "%" columns are only meaningful when a 'total' component exists
    # to normalize against
    percent = "total" in self.metric_components()

    def as_row(components) -> dict:
        # Turn one component dictionary into a report row keyed by
        # (component, sub-column) tuples. The metric value itself is
        # reported as a percentage; 'total' is kept raw; every other
        # component gets a raw column and, when possible, a "%" column.
        row = {}
        if percent:
            total = components["total"]
        for key, value in components.items():
            if key == self.name:
                row[key, "%"] = 100 * value
            elif key == "total":
                row[key, ""] = value
            else:
                row[key, ""] = value
                if percent:
                    # np.nan: the np.NaN alias was removed in NumPy 2.0
                    row[key, "%"] = 100 * value / total if total > 0 else np.nan
        return row

    report = []
    uris = []
    for uri, components in self.results_:
        report.append(as_row(components))
        uris.append(uri)

    # final "TOTAL" row, computed from accumulated components;
    # the global metric value comes from abs(self), not from the
    # accumulated per-file values
    row = as_row(self.accumulated_)
    row[self.name, "%"] = 100 * abs(self)
    report.append(row)
    uris.append("TOTAL")

    df = pd.DataFrame(report)
    df["item"] = uris
    df = df.set_index("item")
    df.columns = pd.MultiIndex.from_tuples(df.columns)
    # order columns: metric value first, then components in canonical order
    df = df[[self.name] + self.metric_components()]

    if display:
        print(
            df.to_string(
                index=True,
                sparsify=False,
                justify="right",
                float_format=lambda f: "{0:.2f}".format(f),
            )
        )

    return df
|
| 215 |
+
|
| 216 |
+
def __str__(self):
    """Render the evaluation report (see `report`) as a plain string."""
    return self.report(display=False).to_string(
        sparsify=False, float_format=lambda f: "{0:.2f}".format(f)
    )
|
| 221 |
+
|
| 222 |
+
def __abs__(self):
    """Compute metric value from accumulated components

    `abs(metric)` is the conventional way to obtain the global
    (corpus-level) value after several `metric(reference, hypothesis)`
    calls have accumulated components.
    """
    return self.compute_metric(self.accumulated_)
|
| 225 |
+
|
| 226 |
+
def __getitem__(self, component: str) -> Union[float, Details]:
    """Get value of accumulated `component`.

    Parameters
    ----------
    component : str
        Name of `component`. The full slice (``metric[:]``) returns a
        copy of the whole accumulated component dictionary instead.

    Returns
    -------
    value : type depends on the metric
        Value of accumulated `component`
    """
    # note: only the *full* slice is special-cased, exactly like the
    # documented `metric[:]` usage; any other key is a component name
    full_slice = slice(None, None, None)
    if component == full_slice:
        return dict(self.accumulated_)
    return self.accumulated_[component]
|
| 244 |
+
|
| 245 |
+
def __iter__(self):
|
| 246 |
+
"""Iterator over the accumulated (uri, value)"""
|
| 247 |
+
for uri, component in self.results_:
|
| 248 |
+
yield uri, component
|
| 249 |
+
|
| 250 |
+
def compute_components(self,
                       reference: Union[Timeline, Annotation],
                       hypothesis: Union[Timeline, Annotation],
                       **kwargs) -> Details:
    """Compute metric components

    Abstract: subclasses must override this method.

    Parameters
    ----------
    reference : type depends on the metric
        Manual `reference`
    hypothesis : same as `reference`
        Evaluated `hypothesis`

    Returns
    -------
    components : dict
        Dictionary where keys are component names and values are component
        values
    """
    message = (
        self.__class__.__name__ + " is missing a 'compute_components' method."
        "It should return a dictionary where keys are component names "
        "and values are component values."
    )
    raise NotImplementedError(message)
|
| 275 |
+
|
| 276 |
+
def compute_metric(self, components: Details):
    """Compute metric value from computed `components`

    Abstract: subclasses must override this method.

    Parameters
    ----------
    components : dict
        Dictionary where keys are components names and values are component
        values

    Returns
    -------
    value : type depends on the metric
        Metric value
    """
    message = (
        self.__class__.__name__ + " is missing a 'compute_metric' method. "
        "It should return the actual value of the metric based "
        "on the precomputed component dictionary given as input."
    )
    raise NotImplementedError(message)
|
| 295 |
+
|
| 296 |
+
def confidence_interval(self, alpha: float = 0.9) \
        -> Tuple[float, Tuple[float, float]]:
    """Compute confidence interval on accumulated metric values

    Parameters
    ----------
    alpha : float, optional
        Probability that the returned confidence interval contains
        the true metric value.

    Returns
    -------
    (center, (lower, upper))
        with center the mean of the conditional pdf of the metric value
        and (lower, upper) is a confidence interval centered on the median,
        containing the estimate to a probability alpha.

    See Also:
    ---------
    scipy.stats.bayes_mvs
    """
    values = [components[self.metric_name_]
              for _, components in self.results_]

    if not values:
        raise ValueError(
            "Please evaluate a bunch of files before computing confidence interval.")

    if len(values) == 1:
        # degenerate interval: a single observation carries no spread
        warnings.warn(
            "Cannot compute a reliable confidence interval out of just one file.")
        only = values[0]
        return only, (only, only)

    # bayes_mvs returns (mean, var, std) estimates; keep the mean one
    return scipy.stats.bayes_mvs(values, alpha=alpha)[0]
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
# Metric name and component names used as dictionary keys in
# compute_components() results and as report column headers.
PRECISION_NAME = "precision"
PRECISION_RETRIEVED = "# retrieved"
PRECISION_RELEVANT_RETRIEVED = "# relevant retrieved"
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
class Precision(BaseMetric):
    """
    :class:`Precision` is a base class for precision-like evaluation metrics.

    It defines two components '# retrieved' and '# relevant retrieved' and the
    compute_metric() method to compute the actual precision:

        Precision = # relevant retrieved / # retrieved

    (note: the formula above is the standard one; an earlier version of this
    docstring had the ratio inverted, while the code was always correct)

    Inheriting classes must implement compute_components().
    """

    @classmethod
    def metric_name(cls):
        return PRECISION_NAME

    @classmethod
    def metric_components(cls) -> MetricComponents:
        return [PRECISION_RETRIEVED, PRECISION_RELEVANT_RETRIEVED]

    def compute_metric(self, components: Details) -> float:
        """Compute precision from `components`

        Returns 1.0 (perfect precision, vacuously) when nothing was
        retrieved at all.

        Raises
        ------
        ValueError
            If the component dictionary is inconsistent, i.e. relevant
            items were retrieved while '# retrieved' is zero.
        """
        numerator = components[PRECISION_RELEVANT_RETRIEVED]
        denominator = components[PRECISION_RETRIEVED]
        if denominator == 0.0:
            if numerator == 0:
                return 1.0
            # a correct compute_components() can never retrieve relevant
            # items out of an empty retrieved set
            raise ValueError(
                "inconsistent components: '# relevant retrieved' is "
                "positive while '# retrieved' is zero")
        else:
            return numerator / denominator
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
# Metric name and component names used as dictionary keys in
# compute_components() results and as report column headers.
RECALL_NAME = "recall"
RECALL_RELEVANT = "# relevant"
RECALL_RELEVANT_RETRIEVED = "# relevant retrieved"
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class Recall(BaseMetric):
    """
    :class:`Recall` is a base class for recall-like evaluation metrics.

    It defines two components '# relevant' and '# relevant retrieved' and the
    compute_metric() method to compute the actual recall:

        Recall = # relevant retrieved / # relevant

    Inheriting classes must implement compute_components().
    """

    @classmethod
    def metric_name(cls):
        return RECALL_NAME

    @classmethod
    def metric_components(cls) -> MetricComponents:
        return [RECALL_RELEVANT, RECALL_RELEVANT_RETRIEVED]

    def compute_metric(self, components: Details) -> float:
        """Compute recall from `components`

        Returns 1.0 (perfect recall, vacuously) when there is nothing
        relevant to retrieve.

        Raises
        ------
        ValueError
            If the component dictionary is inconsistent, i.e. relevant
            items were retrieved while '# relevant' is zero.
        """
        numerator = components[RECALL_RELEVANT_RETRIEVED]
        denominator = components[RECALL_RELEVANT]
        if denominator == 0.0:
            if numerator == 0:
                return 1.0
            # a correct compute_components() can never retrieve more
            # relevant items than exist
            raise ValueError(
                "inconsistent components: '# relevant retrieved' is "
                "positive while '# relevant' is zero")
        else:
            return numerator / denominator
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def f_measure(precision: float, recall: float, beta=1.0) -> float:
    """Compute f-measure

    f-measure is defined as follows:
        F(P, R, b) = (1+b²).P.R / (b².P + R)

    where P is `precision`, R is `recall` and b is `beta`
    """
    # degenerate case: both precision and recall are zero
    if precision + recall == 0.0:
        return 0
    beta_squared = beta * beta
    weighted_sum = beta_squared * precision + recall
    return (1 + beta_squared) * precision * recall / weighted_sum
|
ailia-models/code/pyannote_audio_utils/metrics/diarization.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2012-2019 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
|
| 29 |
+
"""Metrics for diarization"""
|
| 30 |
+
from typing import Optional, Dict, TYPE_CHECKING
|
| 31 |
+
|
| 32 |
+
from pyannote_audio_utils.core import Annotation, Timeline
|
| 33 |
+
from pyannote_audio_utils.core.utils.types import Label
|
| 34 |
+
|
| 35 |
+
from .identification import IdentificationErrorRate
|
| 36 |
+
from .matcher import HungarianMapper
|
| 37 |
+
from .types import Details, MetricComponents
|
| 38 |
+
|
| 39 |
+
if TYPE_CHECKING:
|
| 40 |
+
pass
|
| 41 |
+
|
| 42 |
+
# TODO: can't we put these as class attributes?
# Dictionary key / report column header for the diarization error rate value.
DER_NAME = 'diarization error rate'
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class DiarizationErrorRate(IdentificationErrorRate):
    """Diarization error rate

    The optimal mapping between reference and hypothesis labels is found
    first, using the Hungarian algorithm. The diarization error rate is then
    the identification error rate obtained after translating every hypothesis
    label into its mapped reference label.

    Parameters
    ----------
    collar : float, optional
        Duration (in seconds) of collars removed from evaluation around
        boundaries of reference segments.
    skip_overlap : bool, optional
        Set to True to not evaluate overlap regions.
        Defaults to False (i.e. keep overlap regions).

    Usage
    -----

    * Diarization error rate between `reference` and `hypothesis` annotations

        >>> metric = DiarizationErrorRate()
        >>> reference = Annotation(...)                          # doctest: +SKIP
        >>> hypothesis = Annotation(...)                         # doctest: +SKIP
        >>> value = metric(reference, hypothesis)                # doctest: +SKIP

    * Global diarization error rate and confidence interval
      over multiple documents

        >>> for reference, hypothesis in ...                     # doctest: +SKIP
        ...     metric(reference, hypothesis)                    # doctest: +SKIP
        >>> global_value = abs(metric)                           # doctest: +SKIP
        >>> mean, (lower, upper) = metric.confidence_interval()  # doctest: +SKIP

    * Detailed and accumulated components

        >>> components = metric(reference, hypothesis, detailed=True)  # doctest +SKIP
        >>> components = metric[:]                               # doctest: +SKIP
        >>> metric['confusion']                                  # doctest: +SKIP

    See Also
    --------
    :class:`pyannote_audio_utils.metric.base.BaseMetric`: details on accumulation
    :class:`pyannote_audio_utils.metric.identification.IdentificationErrorRate`: identification error rate
    """

    @classmethod
    def metric_name(cls) -> str:
        return DER_NAME

    def __init__(self, collar: float = 0.0, skip_overlap: bool = False,
                 **kwargs):
        super().__init__(collar=collar, skip_overlap=skip_overlap, **kwargs)
        # Hungarian-algorithm label mapper used by optimal_mapping()
        self.mapper_ = HungarianMapper()

    def optimal_mapping(self,
                        reference: Annotation,
                        hypothesis: Annotation,
                        uem: Optional[Timeline] = None) -> Dict[Label, Label]:
        """Optimal label mapping

        Parameters
        ----------
        reference : Annotation
        hypothesis : Annotation
            Reference and hypothesis diarization
        uem : Timeline
            Evaluation map

        Returns
        -------
        mapping : dict
            Mapping between hypothesis (key) and reference (value) labels
        """
        # 'uemification' is only needed when optimal_mapping() is called
        # directly: compute_components() below has already applied it by
        # the time it calls this method (with uem=None).
        if uem:
            reference, hypothesis = self.uemify(reference, hypothesis, uem=uem)

        # delegate the assignment problem to the Hungarian mapper
        return self.mapper_(hypothesis, reference)

    def compute_components(self,
                           reference: Annotation,
                           hypothesis: Annotation,
                           uem: Optional[Timeline] = None,
                           **kwargs) -> Details:
        # Crop reference and hypothesis to the evaluated regions (uem),
        # remove collars around reference boundaries and, if requested,
        # overlap regions. This must happen *before* the mapping search,
        # because it may change which mapping is optimal.
        reference, hypothesis, uem = self.uemify(
            reference, hypothesis, uem=uem,
            collar=self.collar, skip_overlap=self.skip_overlap,
            returns_uem=True)

        # use disjoint label spaces so that reference and hypothesis labels
        # can never collide: strings ('A', 'B', ...) vs. integers (1, 2, ...)
        reference = reference.rename_labels(generator='string')
        hypothesis = hypothesis.rename_labels(generator='int')

        # optimal (int --> str) mapping, then translate hypothesis labels
        mapping = self.optimal_mapping(reference, hypothesis)
        mapped = hypothesis.rename_labels(mapping=mapping)

        # identification error rate on the mapped hypothesis; collar and
        # skip_overlap are neutralized because uemify() above already
        # applied them
        return super().compute_components(reference, mapped, uem=uem,
                                          collar=0.0, skip_overlap=False,
                                          **kwargs)
|
| 167 |
+
|
ailia-models/code/pyannote_audio_utils/metrics/identification.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2012-2019 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
from typing import Optional
|
| 29 |
+
|
| 30 |
+
from pyannote_audio_utils.core import Annotation, Timeline
|
| 31 |
+
|
| 32 |
+
from .base import BaseMetric
|
| 33 |
+
from .base import Precision, PRECISION_RETRIEVED, PRECISION_RELEVANT_RETRIEVED
|
| 34 |
+
from .base import Recall, RECALL_RELEVANT, RECALL_RELEVANT_RETRIEVED
|
| 35 |
+
from .matcher import LabelMatcher, \
|
| 36 |
+
MATCH_TOTAL, MATCH_CORRECT, MATCH_CONFUSION, \
|
| 37 |
+
MATCH_MISSED_DETECTION, MATCH_FALSE_ALARM
|
| 38 |
+
from .types import MetricComponents, Details
|
| 39 |
+
from .utils import UEMSupportMixin
|
| 40 |
+
|
| 41 |
+
# TODO: can't we put these as class attributes?
# Component keys, re-exported from the matcher module so that
# identification metrics and LabelMatcher agree on dictionary keys.
IER_TOTAL = MATCH_TOTAL
IER_CORRECT = MATCH_CORRECT
IER_CONFUSION = MATCH_CONFUSION
IER_FALSE_ALARM = MATCH_FALSE_ALARM
IER_MISS = MATCH_MISSED_DETECTION
IER_NAME = 'identification error rate'
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class IdentificationErrorRate(UEMSupportMixin, BaseMetric):
    """Identification error rate

    ``ier = (wc x confusion + wf x false_alarm + wm x miss) / total``

    where
    - `confusion` is the total confusion duration in seconds
    - `false_alarm` is the total false alarm duration in seconds
    - `miss` is the total missed detection duration in seconds
    - `total` is the total duration of all tracks
    - wc, wf and wm are optional weights (default to 1)

    Parameters
    ----------
    collar : float, optional
        Duration (in seconds) of collars removed from evaluation around
        boundaries of reference segments.
    skip_overlap : bool, optional
        Set to True to not evaluate overlap regions.
        Defaults to False (i.e. keep overlap regions).
    confusion, miss, false_alarm: float, optional
        Optional weights for confusion, miss and false alarm respectively.
        Default to 1. (no weight)
    """

    @classmethod
    def metric_name(cls) -> str:
        return IER_NAME

    @classmethod
    def metric_components(cls) -> MetricComponents:
        return [
            IER_TOTAL,
            IER_CORRECT,
            IER_FALSE_ALARM, IER_MISS,
            IER_CONFUSION]

    def __init__(self,
                 confusion: float = 1.,
                 miss: float = 1.,
                 false_alarm: float = 1.,
                 collar: float = 0.,
                 skip_overlap: bool = False,
                 **kwargs):

        super().__init__(**kwargs)
        # per-segment label matcher (exact label equality)
        self.matcher_ = LabelMatcher()
        # error-type weights used by compute_metric()
        self.confusion = confusion
        self.miss = miss
        self.false_alarm = false_alarm
        self.collar = collar
        self.skip_overlap = skip_overlap

    def compute_components(self,
                           reference: Annotation,
                           hypothesis: Annotation,
                           uem: Optional[Timeline] = None,
                           collar: Optional[float] = None,
                           skip_overlap: Optional[bool] = None,
                           **kwargs) -> Details:
        """Accumulate duration-weighted match counts over all segments.

        Parameters
        ----------
        collar : float, optional
            Override self.collar
        skip_overlap : bool, optional
            Override self.skip_overlap

        See also
        --------
        :class:`pyannote_audio_utils.metric.diarization.DiarizationErrorRate` uses these
        two options in its `compute_components` method.

        """

        detail = self.init_components()

        # per-call overrides fall back to the instance-level settings
        if collar is None:
            collar = self.collar
        if skip_overlap is None:
            skip_overlap = self.skip_overlap

        # crop both annotations to the evaluated regions and obtain the
        # common segmentation timeline over which labels are compared
        R, H, common_timeline = self.uemify(
            reference, hypothesis, uem=uem,
            collar=collar, skip_overlap=skip_overlap,
            returns_timeline=True)

        # loop on all segments
        for segment in common_timeline:
            # segment duration
            duration = segment.duration

            # list of IDs in reference segment
            r = R.get_labels(segment, unique=False)

            # list of IDs in hypothesis segment
            h = H.get_labels(segment, unique=False)

            counts, _ = self.matcher_(r, h)

            # counts are per-segment label counts; weighting by segment
            # duration turns them into durations (seconds)
            detail[IER_TOTAL] += duration * counts[IER_TOTAL]
            detail[IER_CORRECT] += duration * counts[IER_CORRECT]
            detail[IER_CONFUSION] += duration * counts[IER_CONFUSION]
            detail[IER_MISS] += duration * counts[IER_MISS]
            detail[IER_FALSE_ALARM] += duration * counts[IER_FALSE_ALARM]

        return detail

    def compute_metric(self, detail: Details) -> float:
        """Compute the weighted identification error rate from `detail`."""

        numerator = 1. * (
            self.confusion * detail[IER_CONFUSION] +
            self.false_alarm * detail[IER_FALSE_ALARM] +
            self.miss * detail[IER_MISS]
        )
        denominator = 1. * detail[IER_TOTAL]
        # empty reference: rate is 0 when there are no errors either,
        # 1 (worst case) as soon as any error was counted
        if denominator == 0.:
            if numerator == 0:
                return 0.
            else:
                return 1.
        else:
            return numerator / denominator
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class IdentificationPrecision(UEMSupportMixin, Precision):
    """Identification Precision

    Parameters
    ----------
    collar : float, optional
        Duration (in seconds) of collars removed from evaluation around
        boundaries of reference segments.
    skip_overlap : bool, optional
        Set to True to not evaluate overlap regions.
        Defaults to False (i.e. keep overlap regions).
    """

    def __init__(self, collar: float = 0., skip_overlap: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.collar = collar
        self.skip_overlap = skip_overlap
        # per-segment label matcher (exact label equality)
        self.matcher_ = LabelMatcher()

    def compute_components(self,
                           reference: Annotation,
                           hypothesis: Annotation,
                           uem: Optional[Timeline] = None,
                           **kwargs) -> Details:
        detail = self.init_components()

        # crop both annotations to the evaluated regions and obtain the
        # common segmentation timeline over which labels are compared
        ref, hyp, shared_timeline = self.uemify(
            reference, hypothesis, uem=uem,
            collar=self.collar, skip_overlap=self.skip_overlap,
            returns_timeline=True)

        for segment in shared_timeline:
            seg_duration = segment.duration

            # labels active in this segment, on each side
            ref_labels = ref.get_labels(segment, unique=False)
            hyp_labels = hyp.get_labels(segment, unique=False)

            counts, _ = self.matcher_(ref_labels, hyp_labels)

            # every hypothesized label counts as "retrieved"; only
            # correctly identified ones count as "relevant retrieved"
            detail[PRECISION_RETRIEVED] += seg_duration * len(hyp_labels)
            detail[PRECISION_RELEVANT_RETRIEVED] += (
                seg_duration * counts[IER_CORRECT])

        return detail
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class IdentificationRecall(UEMSupportMixin, Recall):
    """Identification Recall

    Parameters
    ----------
    collar : float, optional
        Duration (in seconds) of collars removed from evaluation around
        boundaries of reference segments.
    skip_overlap : bool, optional
        Set to True to not evaluate overlap regions.
        Defaults to False (i.e. keep overlap regions).
    """

    def __init__(self, collar: float = 0., skip_overlap: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.collar = collar
        self.skip_overlap = skip_overlap
        # per-segment label matcher (exact label equality)
        self.matcher_ = LabelMatcher()

    def compute_components(self,
                           reference: Annotation,
                           hypothesis: Annotation,
                           uem: Optional[Timeline] = None,
                           **kwargs) -> Details:
        detail = self.init_components()

        # crop both annotations to the evaluated regions and obtain the
        # common segmentation timeline over which labels are compared
        ref, hyp, shared_timeline = self.uemify(
            reference, hypothesis, uem=uem,
            collar=self.collar, skip_overlap=self.skip_overlap,
            returns_timeline=True)

        for segment in shared_timeline:
            seg_duration = segment.duration

            # labels active in this segment, on each side
            ref_labels = ref.get_labels(segment, unique=False)
            hyp_labels = hyp.get_labels(segment, unique=False)

            counts, _ = self.matcher_(ref_labels, hyp_labels)

            # every reference label counts as "relevant"; only correctly
            # identified ones count as "relevant retrieved"
            detail[RECALL_RELEVANT] += seg_duration * counts[IER_TOTAL]
            detail[RECALL_RELEVANT_RETRIEVED] += (
                seg_duration * counts[IER_CORRECT])

        return detail
|
ailia-models/code/pyannote_audio_utils/metrics/matcher.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2012-2019 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
from typing import Dict, Tuple, Iterable, List, TYPE_CHECKING
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
from pyannote_audio_utils.core import Annotation
|
| 32 |
+
from scipy.optimize import linear_sum_assignment
|
| 33 |
+
|
| 34 |
+
if TYPE_CHECKING:
|
| 35 |
+
from pyannote_audio_utils.core.utils.types import Label
|
| 36 |
+
|
| 37 |
+
MATCH_CORRECT = 'correct'
MATCH_CONFUSION = 'confusion'
MATCH_MISSED_DETECTION = 'missed detection'
MATCH_FALSE_ALARM = 'false alarm'
MATCH_TOTAL = 'total'


class LabelMatcher:
    """Base label matcher.

    Subclasses may override :meth:`match` to redefine when a reference label
    and a hypothesis label are considered identical. The default behavior is
    plain equality.
    """

    def match(self, rlabel: 'Label', hlabel: 'Label') -> bool:
        """Decide whether two labels match.

        Parameters
        ----------
        rlabel :
            Reference label
        hlabel :
            Hypothesis label

        Returns
        -------
        match : bool
            True if labels match, False otherwise.
        """
        # default policy: labels match when they are equal
        return rlabel == hlabel

    def __call__(self, rlabels: Iterable['Label'], hlabels: Iterable['Label']) \
            -> Tuple[Dict[str, int], Dict[str, List['Label']]]:
        """Find an optimal one-to-one mapping and count match errors.

        Parameters
        ----------
        rlabels, hlabels : iterable
            Reference and hypothesis labels

        Returns
        -------
        counts : dict
            Number of correct / confusion / missed detection / false alarm
            matches, plus the total number of reference labels.
        details : dict
            Labels (or label pairs) behind each count.
        """

        # per-category counters and the labels behind them
        counts = {
            MATCH_CORRECT: 0,
            MATCH_CONFUSION: 0,
            MATCH_MISSED_DETECTION: 0,
            MATCH_FALSE_ALARM: 0,
            MATCH_TOTAL: 0,
        }
        details = {key: [] for key in (MATCH_CORRECT, MATCH_CONFUSION,
                                       MATCH_MISSED_DETECTION,
                                       MATCH_FALSE_ALARM)}

        # materialize both sides so they can be indexed later
        rlabels = list(rlabels)
        hlabels = list(hlabels)

        n_ref = len(rlabels)
        n_hyp = len(hlabels)
        size = max(n_ref, n_hyp)

        # corner case: nothing to match at all
        if size == 0:
            return counts, details

        # square boolean matrix, padded with False so that the assignment
        # problem is always balanced (extra rows/columns model "no label")
        pairwise = np.zeros((size, size), dtype=bool)
        for i, rlabel in enumerate(rlabels):
            for j, hlabel in enumerate(hlabels):
                pairwise[i, j] = self.match(rlabel, hlabel)

        # Hungarian algorithm on the negated matrix maximizes the number of
        # matched pairs; classify each assigned pair accordingly
        for i, j in zip(*linear_sum_assignment(~pairwise)):

            if i >= n_ref:
                # hypothesis label paired with a padding row ==> false alarm
                counts[MATCH_FALSE_ALARM] += 1
                details[MATCH_FALSE_ALARM].append(hlabels[j])

            elif j >= n_hyp:
                # reference label paired with a padding column ==> missed detection
                counts[MATCH_MISSED_DETECTION] += 1
                details[MATCH_MISSED_DETECTION].append(rlabels[i])

            elif pairwise[i, j]:
                # matching pair ==> correct detection
                counts[MATCH_CORRECT] += 1
                details[MATCH_CORRECT].append((rlabels[i], hlabels[j]))

            else:
                # non-matching pair ==> confusion
                counts[MATCH_CONFUSION] += 1
                details[MATCH_CONFUSION].append((rlabels[i], hlabels[j]))

        # total is the number of reference labels
        counts[MATCH_TOTAL] += n_ref

        return counts, details
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class HungarianMapper:
    """Optimal one-to-one label mapping.

    Pairs labels of two annotations so that total co-occurrence is maximized
    (Hungarian algorithm); pairs that never co-occur are dropped.
    """

    def __call__(self, A: 'Annotation', B: 'Annotation') -> Dict['Label', 'Label']:
        """Map labels of `A` to labels of `B`.

        Parameters
        ----------
        A, B : Annotation

        Returns
        -------
        mapping : dict
            Label of `A` --> co-occurring label of `B`.
        """
        # label-wise co-occurrence matrix (rows: A labels, columns: B labels)
        cooccurrence = A * B
        labels_a = A.labels()
        labels_b = B.labels()

        # negate so that the cost minimizer maximizes total co-occurrence
        rows, cols = linear_sum_assignment(-cooccurrence)

        # keep only pairs that actually co-occur
        return {
            labels_a[i]: labels_b[j]
            for i, j in zip(rows, cols)
            if cooccurrence[i, j] > 0
        }
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class GreedyMapper:
    """Greedy one-to-one label mapping.

    Repeatedly pairs the two labels with the largest remaining co-occurrence,
    removing both from further consideration, until no positive co-occurrence
    is left.
    """

    def __call__(self, A: 'Annotation', B: 'Annotation') -> Dict['Label', 'Label']:
        """Map labels of `A` to labels of `B`.

        Parameters
        ----------
        A, B : Annotation

        Returns
        -------
        mapping : dict
            Label of `A` --> co-occurring label of `B`.
        """
        # label-wise co-occurrence matrix (rows: A labels, columns: B labels)
        cooccurrence = A * B
        n_rows, n_cols = cooccurrence.shape
        labels_a = A.labels()
        labels_b = B.labels()

        mapping = {}

        # at most min(n_rows, n_cols) pairs can be formed
        for _ in range(min(n_rows, n_cols)):
            # locate the best remaining pair
            row, col = np.unravel_index(np.argmax(cooccurrence), (n_rows, n_cols))

            # stop as soon as the best remaining pair does not co-occur
            if not cooccurrence[row, col] > 0:
                break

            mapping[labels_a[row]] = labels_b[col]

            # remove both labels from further consideration
            cooccurrence[row, :] = 0.
            cooccurrence[:, col] = 0.

        return mapping
|
ailia-models/code/pyannote_audio_utils/metrics/types.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Literal
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# Name of an individual metric component (e.g. "correct", "false alarm").
MetricComponent = str
# Supported probability calibration methods
# (presumably matching scikit-learn's CalibratedClassifierCV -- confirm).
CalibrationMethod = Literal["isotonic", "sigmoid"]
# Ordered collection of metric component names.
MetricComponents = List[MetricComponent]
# Mapping from metric component name to its accumulated value.
Details = Dict[MetricComponent, float]
|
ailia-models/code/pyannote_audio_utils/metrics/utils.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2012-2019 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
|
| 29 |
+
import warnings
|
| 30 |
+
from typing import Optional, Tuple, Union
|
| 31 |
+
|
| 32 |
+
from pyannote_audio_utils.core import Timeline, Segment, Annotation
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class UEMSupportMixin:
    """Provides 'uemify' method with optional (à la NIST) collar"""

    def extrude(self,
                uem: Timeline,
                reference: Annotation,
                collar: float = 0.0,
                skip_overlap: bool = False) -> Timeline:
        """Extrude reference boundary collars from uem

        reference |----|     |--------------|       |-------------|
        uem       |---------------------|    |-------------------------------|
        extruded  |--| |--| |---| |-----|    |-| |-----| |-----------| |-----|

        Parameters
        ----------
        uem : Timeline
            Evaluation map.
        reference : Annotation
            Reference annotation.
        collar : float, optional
            When provided, set the duration of collars centered around
            reference segment boundaries that are extruded from both reference
            and hypothesis. Defaults to 0. (i.e. no collar).
        skip_overlap : bool, optional
            Set to True to not evaluate overlap regions.
            Defaults to False (i.e. keep overlap regions).

        Returns
        -------
        extruded_uem : Timeline
        """

        # fast path: nothing to extrude, return uem untouched
        if collar == 0. and not skip_overlap:
            return uem

        collars, overlap_regions = [], []

        # build list of collars if needed
        if collar > 0.:
            # iterate over all segments in reference
            for segment in reference.itersegments():

                # add collar centered on start time
                t = segment.start
                collars.append(Segment(t - .5 * collar, t + .5 * collar))

                # add collar centered on end time
                t = segment.end
                collars.append(Segment(t - .5 * collar, t + .5 * collar))

        # build list of overlap regions if needed
        if skip_overlap:
            # iterate over pair of intersecting segments
            for (segment1, track1), (segment2, track2) in reference.co_iter(reference):
                # skip a (segment, track) paired with itself: every segment
                # trivially co-occurs with itself
                if segment1 == segment2 and track1 == track2:
                    continue
                # add their intersection
                overlap_regions.append(segment1 & segment2)

        segments = collars + overlap_regions

        # merge all extruded regions, then keep only the parts of uem that do
        # NOT intersect them (gaps of the merged regions within uem support)
        return Timeline(segments=segments).support().gaps(support=uem)

    def common_timeline(self, reference: Annotation, hypothesis: Annotation) \
            -> Timeline:
        """Return timeline common to both reference and hypothesis

        reference  |--------|    |------------|      |---------|    |----|
        hypothesis     |--------------|     |------|    |----------------|
        timeline   |--|-----|----|---|-|------| |-|---------|----| |----|

        Parameters
        ----------
        reference : Annotation
        hypothesis : Annotation

        Returns
        -------
        timeline : Timeline
        """
        # copy so the reference's own timeline is not mutated by update()
        timeline = reference.get_timeline(copy=True)
        timeline.update(hypothesis.get_timeline(copy=False))
        # split into the smallest set of non-overlapping segments
        return timeline.segmentation()

    def project(self, annotation: Annotation, timeline: Timeline) -> Annotation:
        """Project annotation onto timeline segments

        reference |__A__|     |__B__|
                    |____C____|

        timeline  |---|---|---|  |---|

        projection |_A_|_A_|_C_|  |_B_|
                       |_C_|

        Parameters
        ----------
        annotation : Annotation
        timeline : Timeline

        Returns
        -------
        projection : Annotation
        """
        projection = annotation.empty()
        timeline_ = annotation.get_timeline(copy=False)
        # for every (annotation segment, timeline segment) intersecting pair,
        # copy each track's label onto the timeline segment
        for segment_, segment in timeline_.co_iter(timeline):
            for track_ in annotation.get_tracks(segment_):
                track = projection.new_track(segment, candidate=track_)
                projection[segment, track] = annotation[segment_, track_]
        return projection

    def uemify(self,
               reference: Annotation,
               hypothesis: Annotation,
               uem: Optional[Timeline] = None,
               collar: float = 0.,
               skip_overlap: bool = False,
               returns_uem: bool = False,
               returns_timeline: bool = False) \
            -> Union[
                Tuple[Annotation, Annotation],
                Tuple[Annotation, Annotation, Timeline],
                Tuple[Annotation, Annotation, Timeline, Timeline],
            ]:
        """Crop 'reference' and 'hypothesis' to 'uem' support

        Parameters
        ----------
        reference, hypothesis : Annotation
            Reference and hypothesis annotations.
        uem : Timeline, optional
            Evaluation map.
        collar : float, optional
            When provided, set the duration of collars centered around
            reference segment boundaries that are extruded from both reference
            and hypothesis. Defaults to 0. (i.e. no collar).
        skip_overlap : bool, optional
            Set to True to not evaluate overlap regions.
            Defaults to False (i.e. keep overlap regions).
        returns_uem : bool, optional
            Set to True to return extruded uem as well.
            Defaults to False (i.e. only return reference and hypothesis)
        returns_timeline : bool, optional
            Set to True to oversegment reference and hypothesis so that they
            share the same internal timeline.

        Returns
        -------
        reference, hypothesis : Annotation
            Extruded reference and hypothesis annotations
        uem : Timeline
            Extruded uem (returned only when 'returns_uem' is True)
        timeline : Timeline:
            Common timeline (returned only when 'returns_timeline' is True)
        """

        # when uem is not provided, use the union of reference and hypothesis
        # extents -- and warn the user about that.
        if uem is None:
            r_extent = reference.get_timeline().extent()
            h_extent = hypothesis.get_timeline().extent()
            extent = r_extent | h_extent
            uem = Timeline(segments=[extent] if extent else [],
                           uri=reference.uri)
            warnings.warn(
                "'uem' was approximated by the union of 'reference' "
                "and 'hypothesis' extents.")

        # extrude collars (and overlap regions) from uem
        uem = self.extrude(uem, reference, collar=collar,
                           skip_overlap=skip_overlap)

        # extrude regions outside of uem
        reference = reference.crop(uem, mode='intersection')
        hypothesis = hypothesis.crop(uem, mode='intersection')

        # project reference and hypothesis on common timeline
        if returns_timeline:
            timeline = self.common_timeline(reference, hypothesis)
            reference = self.project(reference, timeline)
            hypothesis = self.project(hypothesis, timeline)

        # assemble the result tuple in documented order:
        # (reference, hypothesis[, uem][, timeline])
        result = (reference, hypothesis)
        if returns_uem:
            result += (uem,)

        if returns_timeline:
            result += (timeline,)

        return result
|
ailia-models/code/pyannote_audio_utils/pipeline/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2018-2020 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# from ._version import get_versions
|
| 31 |
+
|
| 32 |
+
# __version__ = get_versions()["version"]
|
| 33 |
+
# del get_versions
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
from .pipeline import Pipeline
|
| 37 |
+
# from .optimizer import Optimizer
|
ailia-models/code/pyannote_audio_utils/pipeline/parameter.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2018-2020 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
# Hadrien TITEUX - https://github.com/hadware
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
from typing import Iterable, Any
|
| 32 |
+
|
| 33 |
+
from .pipeline import Pipeline
|
| 34 |
+
from collections.abc import Mapping
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Parameter:
    """Marker base class shared by all tunable hyper-parameters."""

    pass
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Categorical(Parameter):
    """Categorical hyper-parameter.

    The value is sampled from a finite set of candidates.

    Parameters
    ----------
    choices : iterable
        Candidate hyper-parameter values.
    """

    def __init__(self, choices: Iterable):
        super().__init__()
        # materialize once so candidates can be indexed and reused
        self.choices = list(choices)

    # Optuna hook, kept for reference (optimizer support is disabled):
    # def __call__(self, name: str, trial: Trial):
    #     return trial.suggest_categorical(name, self.choices)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class DiscreteUniform(Parameter):
    """Discrete uniform hyper-parameter.

    The value is sampled from the range [low, high] (both endpoints
    included), discretized with step `q`.

    Parameters
    ----------
    low : `float`
        Lower endpoint of the range of suggested values (included).
    high : `float`
        Upper endpoint of the range of suggested values (included).
    q : `float`
        A step of discretization.
    """

    def __init__(self, low: float, high: float, q: float):
        super().__init__()
        # normalize all bounds to float once, at construction time
        self.low, self.high, self.q = float(low), float(high), float(q)

    # Optuna hook, kept for reference (optimizer support is disabled):
    # def __call__(self, name: str, trial: Trial):
    #     return trial.suggest_discrete_uniform(name, self.low, self.high, self.q)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class Integer(Parameter):
    """Integer hyper-parameter.

    The value is sampled from the integers in [low, high]
    (both endpoints included).

    Parameters
    ----------
    low : `int`
        Lower endpoint of the range of suggested values (included).
    high : `int`
        Upper endpoint of the range of suggested values (included).
    """

    def __init__(self, low: int, high: int):
        super().__init__()
        # normalize both bounds to int once, at construction time
        self.low, self.high = int(low), int(high)

    # Optuna hook, kept for reference (optimizer support is disabled):
    # def __call__(self, name: str, trial: Trial):
    #     return trial.suggest_int(name, self.low, self.high)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class LogUniform(Parameter):
    """Log-uniform hyper-parameter.

    The value is sampled from the range [low, high) in the log domain.

    Parameters
    ----------
    low : `float`
        Lower endpoint of the range of suggested values (included).
    high : `float`
        Upper endpoint of the range of suggested values (excluded).
    """

    def __init__(self, low: float, high: float):
        super().__init__()
        # normalize both bounds to float once, at construction time
        self.low, self.high = float(low), float(high)

    # Optuna hook, kept for reference (optimizer support is disabled):
    # def __call__(self, name: str, trial: Trial):
    #     return trial.suggest_loguniform(name, self.low, self.high)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class Uniform(Parameter):
    """Uniform hyper-parameter.

    The value is sampled from the range [low, high) in the linear domain.

    Parameters
    ----------
    low : `float`
        Lower endpoint of the range of suggested values (included).
    high : `float`
        Upper endpoint of the range of suggested values (excluded).
    """

    def __init__(self, low: float, high: float):
        super().__init__()
        # normalize both bounds to float once, at construction time
        self.low, self.high = float(low), float(high)

    # Optuna hook, kept for reference (optimizer support is disabled):
    # def __call__(self, name: str, trial: Trial):
    #     return trial.suggest_uniform(name, self.low, self.high)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class Frozen(Parameter):
    """Frozen hyper-parameter.

    The value is fixed a priori and never tuned.

    Parameters
    ----------
    value :
        Fixed value.
    """

    def __init__(self, value: Any):
        super().__init__()
        self.value = value

    # Optuna hook, kept for reference (optimizer support is disabled):
    # def __call__(self, name: str, trial: Trial):
    #     return self.value
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class ParamDict(Pipeline, Mapping):
    """Dict-like structured hyper-parameter.

    Behaves like a read-only mapping whose values are exposed as attributes
    of the underlying `Pipeline`.

    Usage
    -----
    >>> params = ParamDict(param1=Uniform(0.0, 1.0), param2=Uniform(-1.0, 1.0))
    >>> params = ParamDict(**{"param1": Uniform(0.0, 1.0), "param2": Uniform(-1.0, 1.0)})
    """

    def __init__(self, **params):
        super().__init__()
        self.__params = params
        # expose each parameter as an attribute so Pipeline's attribute
        # machinery can track it
        for name in params:
            setattr(self, name, params[name])

    def __len__(self):
        return len(self.__params)

    def __iter__(self):
        return iter(self.__params)

    def __getitem__(self, param_name):
        # go through getattr so instantiated values (not the raw Parameter
        # objects) are returned once the pipeline is instantiated
        return getattr(self, param_name)
|
ailia-models/code/pyannote_audio_utils/pipeline/pipeline.py
ADDED
|
@@ -0,0 +1,614 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
# The MIT License (MIT)
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2018-2022 CNRS
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in
|
| 16 |
+
# all copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
+
# SOFTWARE.
|
| 25 |
+
|
| 26 |
+
# AUTHORS
|
| 27 |
+
# Hervé BREDIN - http://herve.niderb.fr
|
| 28 |
+
|
| 29 |
+
from typing import Optional, TextIO, Union, Dict, Any
|
| 30 |
+
|
| 31 |
+
from collections import OrderedDict
|
| 32 |
+
from .typing import PipelineInput
|
| 33 |
+
from .typing import PipelineOutput
|
| 34 |
+
import warnings
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Pipeline:
|
| 38 |
+
"""Base tunable pipeline"""
|
| 39 |
+
|
    def __init__(self):
        """Set up the three attribute registries and training flag.

        NOTE: these assignments themselves go through the custom
        ``__setattr__`` defined below; the registries must therefore be
        created (in this order) before any parameter or sub-pipeline
        can be assigned to the pipeline.
        """

        # un-instantiated parameters (= `Parameter` instances)
        self._parameters: Dict[str, Parameter] = OrderedDict()

        # instantiated parameters (plain values assigned to a declared parameter)
        self._instantiated: Dict[str, Any] = OrderedDict()

        # sub-pipelines (attributes that are themselves `Pipeline` instances)
        self._pipelines: Dict[str, Pipeline] = OrderedDict()

        # whether pipeline is currently being optimized
        # (propagated to sub-pipelines by the `training` setter below)
        self.training = False
| 53 |
+
|
| 54 |
+
@property
|
| 55 |
+
def training(self):
|
| 56 |
+
return self._training
|
| 57 |
+
|
| 58 |
+
@training.setter
|
| 59 |
+
def training(self, training):
|
| 60 |
+
self._training = training
|
| 61 |
+
# recursively set sub-pipeline training attribute
|
| 62 |
+
for _, pipeline in self._pipelines.items():
|
| 63 |
+
pipeline.training = training
|
| 64 |
+
|
| 65 |
+
def __hash__(self):
|
| 66 |
+
# FIXME -- also keep track of (sub)pipeline attributes
|
| 67 |
+
frozen = self.parameters(frozen=True)
|
| 68 |
+
return hash(tuple(sorted(self._flatten(frozen).items())))
|
| 69 |
+
|
| 70 |
+
def __getattr__(self, name):
|
| 71 |
+
"""(Advanced) attribute getter"""
|
| 72 |
+
|
| 73 |
+
# in case `name` corresponds to an instantiated parameter value, returns it
|
| 74 |
+
if "_instantiated" in self.__dict__:
|
| 75 |
+
_instantiated = self.__dict__["_instantiated"]
|
| 76 |
+
if name in _instantiated:
|
| 77 |
+
return _instantiated[name]
|
| 78 |
+
|
| 79 |
+
# in case `name` corresponds to a parameter, returns it
|
| 80 |
+
if "_parameters" in self.__dict__:
|
| 81 |
+
_parameters = self.__dict__["_parameters"]
|
| 82 |
+
if name in _parameters:
|
| 83 |
+
return _parameters[name]
|
| 84 |
+
|
| 85 |
+
# in case `name` corresponds to a sub-pipeline, returns it
|
| 86 |
+
if "_pipelines" in self.__dict__:
|
| 87 |
+
_pipelines = self.__dict__["_pipelines"]
|
| 88 |
+
if name in _pipelines:
|
| 89 |
+
return _pipelines[name]
|
| 90 |
+
|
| 91 |
+
msg = "'{}' object has no attribute '{}'".format(type(self).__name__, name)
|
| 92 |
+
raise AttributeError(msg)
|
| 93 |
+
|
| 94 |
+
def __setattr__(self, name, value):
|
| 95 |
+
"""(Advanced) attribute setter
|
| 96 |
+
|
| 97 |
+
If `value` is an instance of `Parameter`, store it in `_parameters`.
|
| 98 |
+
elif `value` is an instance of `Pipeline`, store it in `_pipelines`.
|
| 99 |
+
elif `value` isn't an instance of `Parameter` and `name` is in `_parameters`,
|
| 100 |
+
store `value` in `_instantiated`.
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
# imported here to avoid circular import
|
| 104 |
+
from .parameter import Parameter
|
| 105 |
+
|
| 106 |
+
def remove_from(*dicts):
|
| 107 |
+
for d in dicts:
|
| 108 |
+
if name in d:
|
| 109 |
+
del d[name]
|
| 110 |
+
|
| 111 |
+
_parameters = self.__dict__.get("_parameters")
|
| 112 |
+
_instantiated = self.__dict__.get("_instantiated")
|
| 113 |
+
_pipelines = self.__dict__.get("_pipelines")
|
| 114 |
+
|
| 115 |
+
# if `value` is an instance of `Parameter`, store it in `_parameters`
|
| 116 |
+
|
| 117 |
+
if isinstance(value, Parameter):
|
| 118 |
+
if _parameters is None:
|
| 119 |
+
msg = (
|
| 120 |
+
"cannot assign hyper-parameters " "before Pipeline.__init__() call"
|
| 121 |
+
)
|
| 122 |
+
raise AttributeError(msg)
|
| 123 |
+
remove_from(self.__dict__, _instantiated, _pipelines)
|
| 124 |
+
_parameters[name] = value
|
| 125 |
+
return
|
| 126 |
+
|
| 127 |
+
# add/update one sub-pipeline
|
| 128 |
+
if isinstance(value, Pipeline):
|
| 129 |
+
if _pipelines is None:
|
| 130 |
+
msg = "cannot assign sub-pipelines " "before Pipeline.__init__() call"
|
| 131 |
+
raise AttributeError(msg)
|
| 132 |
+
remove_from(self.__dict__, _parameters, _instantiated)
|
| 133 |
+
_pipelines[name] = value
|
| 134 |
+
return
|
| 135 |
+
|
| 136 |
+
# store instantiated parameter value
|
| 137 |
+
if _parameters is not None and name in _parameters:
|
| 138 |
+
_instantiated[name] = value
|
| 139 |
+
return
|
| 140 |
+
|
| 141 |
+
object.__setattr__(self, name, value)
|
| 142 |
+
|
| 143 |
+
def __delattr__(self, name):
|
| 144 |
+
|
| 145 |
+
if name in self._parameters:
|
| 146 |
+
del self._parameters[name]
|
| 147 |
+
|
| 148 |
+
elif name in self._instantiated:
|
| 149 |
+
del self._instantiated[name]
|
| 150 |
+
|
| 151 |
+
elif name in self._pipelines:
|
| 152 |
+
del self._pipelines[name]
|
| 153 |
+
|
| 154 |
+
else:
|
| 155 |
+
object.__delattr__(self, name)
|
| 156 |
+
|
| 157 |
+
def _flattened_parameters(
|
| 158 |
+
self, frozen: Optional[bool] = False, instantiated: Optional[bool] = False
|
| 159 |
+
) -> dict:
|
| 160 |
+
"""Get flattened dictionary of parameters
|
| 161 |
+
|
| 162 |
+
Parameters
|
| 163 |
+
----------
|
| 164 |
+
frozen : `bool`, optional
|
| 165 |
+
Only return value of frozen parameters.
|
| 166 |
+
instantiated : `bool`, optional
|
| 167 |
+
Only return value of instantiated parameters.
|
| 168 |
+
|
| 169 |
+
Returns
|
| 170 |
+
-------
|
| 171 |
+
params : `dict`
|
| 172 |
+
Flattened dictionary of parameters.
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
# imported here to avoid circular imports
|
| 176 |
+
from .parameter import Frozen
|
| 177 |
+
|
| 178 |
+
if frozen and instantiated:
|
| 179 |
+
msg = "one must choose between `frozen` and `instantiated`."
|
| 180 |
+
raise ValueError(msg)
|
| 181 |
+
|
| 182 |
+
# initialize dictionary with root parameters
|
| 183 |
+
if instantiated:
|
| 184 |
+
params = dict(self._instantiated)
|
| 185 |
+
|
| 186 |
+
elif frozen:
|
| 187 |
+
params = {
|
| 188 |
+
n: p.value for n, p in self._parameters.items() if isinstance(p, Frozen)
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
else:
|
| 192 |
+
params = dict(self._parameters)
|
| 193 |
+
|
| 194 |
+
# recursively add sub-pipeline parameters
|
| 195 |
+
for pipeline_name, pipeline in self._pipelines.items():
|
| 196 |
+
pipeline_params = pipeline._flattened_parameters(
|
| 197 |
+
frozen=frozen, instantiated=instantiated
|
| 198 |
+
)
|
| 199 |
+
for name, value in pipeline_params.items():
|
| 200 |
+
params[f"{pipeline_name}>{name}"] = value
|
| 201 |
+
|
| 202 |
+
return params
|
| 203 |
+
|
| 204 |
+
def _flatten(self, nested_params: dict) -> dict:
|
| 205 |
+
"""Convert nested dictionary to flattened dictionary
|
| 206 |
+
|
| 207 |
+
For instance, a nested dictionary like this one:
|
| 208 |
+
|
| 209 |
+
~~~~~~~~~~~~~~~~~~~~~
|
| 210 |
+
param: value1
|
| 211 |
+
pipeline:
|
| 212 |
+
param: value2
|
| 213 |
+
subpipeline:
|
| 214 |
+
param: value3
|
| 215 |
+
~~~~~~~~~~~~~~~~~~~~~
|
| 216 |
+
|
| 217 |
+
becomes the following flattened dictionary:
|
| 218 |
+
|
| 219 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 220 |
+
param : value1
|
| 221 |
+
pipeline>param : value2
|
| 222 |
+
pipeline>subpipeline>param : value3
|
| 223 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 224 |
+
|
| 225 |
+
Parameter
|
| 226 |
+
---------
|
| 227 |
+
nested_params : `dict`
|
| 228 |
+
|
| 229 |
+
Returns
|
| 230 |
+
-------
|
| 231 |
+
flattened_params : `dict`
|
| 232 |
+
"""
|
| 233 |
+
flattened_params = dict()
|
| 234 |
+
for name, value in nested_params.items():
|
| 235 |
+
if isinstance(value, dict):
|
| 236 |
+
for subname, subvalue in self._flatten(value).items():
|
| 237 |
+
flattened_params[f"{name}>{subname}"] = subvalue
|
| 238 |
+
else:
|
| 239 |
+
flattened_params[name] = value
|
| 240 |
+
return flattened_params
|
| 241 |
+
|
| 242 |
+
def _unflatten(self, flattened_params: dict) -> dict:
|
| 243 |
+
"""Convert flattened dictionary to nested dictionary
|
| 244 |
+
|
| 245 |
+
For instance, a flattened dictionary like this one:
|
| 246 |
+
|
| 247 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 248 |
+
param : value1
|
| 249 |
+
pipeline>param : value2
|
| 250 |
+
pipeline>subpipeline>param : value3
|
| 251 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 252 |
+
|
| 253 |
+
becomes the following nested dictionary:
|
| 254 |
+
|
| 255 |
+
~~~~~~~~~~~~~~~~~~~~~
|
| 256 |
+
param: value1
|
| 257 |
+
pipeline:
|
| 258 |
+
param: value2
|
| 259 |
+
subpipeline:
|
| 260 |
+
param: value3
|
| 261 |
+
~~~~~~~~~~~~~~~~~~~~~
|
| 262 |
+
|
| 263 |
+
Parameter
|
| 264 |
+
---------
|
| 265 |
+
flattened_params : `dict`
|
| 266 |
+
|
| 267 |
+
Returns
|
| 268 |
+
-------
|
| 269 |
+
nested_params : `dict`
|
| 270 |
+
"""
|
| 271 |
+
|
| 272 |
+
nested_params = {}
|
| 273 |
+
|
| 274 |
+
pipeline_params = {name: {} for name in self._pipelines}
|
| 275 |
+
for name, value in flattened_params.items():
|
| 276 |
+
# if name contains has multipe ">"-separated tokens
|
| 277 |
+
# it means that it is a sub-pipeline parameter
|
| 278 |
+
tokens = name.split(">")
|
| 279 |
+
if len(tokens) > 1:
|
| 280 |
+
# read sub-pipeline name
|
| 281 |
+
pipeline_name = tokens[0]
|
| 282 |
+
# read parameter name
|
| 283 |
+
param_name = ">".join(tokens[1:])
|
| 284 |
+
# update sub-pipeline flattened dictionary
|
| 285 |
+
pipeline_params[pipeline_name][param_name] = value
|
| 286 |
+
|
| 287 |
+
# otherwise, it is an actual parameter of this pipeline
|
| 288 |
+
else:
|
| 289 |
+
# store it as such
|
| 290 |
+
nested_params[name] = value
|
| 291 |
+
|
| 292 |
+
# recursively unflatten sub-pipeline flattened dictionary
|
| 293 |
+
for name, pipeline in self._pipelines.items():
|
| 294 |
+
nested_params[name] = pipeline._unflatten(pipeline_params[name])
|
| 295 |
+
|
| 296 |
+
return nested_params
|
| 297 |
+
|
| 298 |
+
def parameters(
|
| 299 |
+
self,
|
| 300 |
+
trial = None,
|
| 301 |
+
frozen: Optional[bool] = False,
|
| 302 |
+
instantiated: Optional[bool] = False,
|
| 303 |
+
) -> dict:
|
| 304 |
+
"""Returns nested dictionary of (optionnaly instantiated) parameters.
|
| 305 |
+
|
| 306 |
+
For a pipeline with one `param`, one sub-pipeline with its own param
|
| 307 |
+
and its own sub-pipeline, it will returns something like:
|
| 308 |
+
|
| 309 |
+
~~~~~~~~~~~~~~~~~~~~~
|
| 310 |
+
param: value1
|
| 311 |
+
pipeline:
|
| 312 |
+
param: value2
|
| 313 |
+
subpipeline:
|
| 314 |
+
param: value3
|
| 315 |
+
~~~~~~~~~~~~~~~~~~~~~
|
| 316 |
+
|
| 317 |
+
Parameter
|
| 318 |
+
---------
|
| 319 |
+
trial : `Trial`, optional
|
| 320 |
+
When provided, use trial to suggest new parameter values
|
| 321 |
+
and return them.
|
| 322 |
+
frozen : `bool`, optional
|
| 323 |
+
Return frozen parameter value
|
| 324 |
+
instantiated : `bool`, optional
|
| 325 |
+
Return instantiated parameter values.
|
| 326 |
+
|
| 327 |
+
Returns
|
| 328 |
+
-------
|
| 329 |
+
params : `dict`
|
| 330 |
+
Nested dictionary of parameters. See above for the actual format.
|
| 331 |
+
"""
|
| 332 |
+
|
| 333 |
+
if (instantiated or frozen) and trial is not None:
|
| 334 |
+
msg = "One must choose between `trial`, `instantiated`, or `frozen`"
|
| 335 |
+
raise ValueError(msg)
|
| 336 |
+
|
| 337 |
+
# get flattened dictionary of uninstantiated parameters
|
| 338 |
+
params = self._flattened_parameters(frozen=frozen, instantiated=instantiated)
|
| 339 |
+
|
| 340 |
+
if trial is not None:
|
| 341 |
+
# use provided `trial` to suggest values for parameters
|
| 342 |
+
params = {name: param(name, trial) for name, param in params.items()}
|
| 343 |
+
|
| 344 |
+
# un-flatten flattened dictionary
|
| 345 |
+
return self._unflatten(params)
|
| 346 |
+
|
| 347 |
+
def initialize(self):
|
| 348 |
+
"""Instantiate root pipeline with current set of parameters"""
|
| 349 |
+
pass
|
| 350 |
+
|
| 351 |
+
# def freeze(self, params: dict) -> "Pipeline":
|
| 352 |
+
# """Recursively freeze pipeline parameters
|
| 353 |
+
|
| 354 |
+
# Parameters
|
| 355 |
+
# ----------
|
| 356 |
+
# params : `dict`
|
| 357 |
+
# Nested dictionary of parameters.
|
| 358 |
+
|
| 359 |
+
# Returns
|
| 360 |
+
# -------
|
| 361 |
+
# self : `Pipeline`
|
| 362 |
+
# Pipeline.
|
| 363 |
+
# """
|
| 364 |
+
|
| 365 |
+
# # imported here to avoid circular imports
|
| 366 |
+
# from .parameter import Frozen
|
| 367 |
+
|
| 368 |
+
# for name, value in params.items():
|
| 369 |
+
|
| 370 |
+
# # recursively freeze sub-pipelines parameters
|
| 371 |
+
# if name in self._pipelines:
|
| 372 |
+
# if not isinstance(value, dict):
|
| 373 |
+
# msg = (
|
| 374 |
+
# f"only parameters of '{name}' pipeline can "
|
| 375 |
+
# f"be frozen (not the whole pipeline)"
|
| 376 |
+
# )
|
| 377 |
+
# raise ValueError(msg)
|
| 378 |
+
# self._pipelines[name].freeze(value)
|
| 379 |
+
# continue
|
| 380 |
+
|
| 381 |
+
# # instantiate parameter value
|
| 382 |
+
# if name in self._parameters:
|
| 383 |
+
# setattr(self, name, Frozen(value))
|
| 384 |
+
# continue
|
| 385 |
+
|
| 386 |
+
# msg = f"parameter '{name}' does not exist"
|
| 387 |
+
# raise ValueError(msg)
|
| 388 |
+
|
| 389 |
+
# return self
|
| 390 |
+
|
| 391 |
+
def instantiate(self, params: dict) -> "Pipeline":
|
| 392 |
+
"""Recursively instantiate all pipelines
|
| 393 |
+
|
| 394 |
+
Parameters
|
| 395 |
+
----------
|
| 396 |
+
params : `dict`
|
| 397 |
+
Nested dictionary of parameters.
|
| 398 |
+
|
| 399 |
+
Returns
|
| 400 |
+
-------
|
| 401 |
+
self : `Pipeline`
|
| 402 |
+
Instantiated pipeline.
|
| 403 |
+
"""
|
| 404 |
+
|
| 405 |
+
# imported here to avoid circular imports
|
| 406 |
+
from .parameter import Frozen
|
| 407 |
+
|
| 408 |
+
for name, value in params.items():
|
| 409 |
+
|
| 410 |
+
# recursively call `instantiate` with sub-pipelines
|
| 411 |
+
if name in self._pipelines:
|
| 412 |
+
if not isinstance(value, dict):
|
| 413 |
+
msg = (
|
| 414 |
+
f"only parameters of '{name}' pipeline can "
|
| 415 |
+
f"be instantiated (not the whole pipeline)"
|
| 416 |
+
)
|
| 417 |
+
raise ValueError(msg)
|
| 418 |
+
self._pipelines[name].instantiate(value)
|
| 419 |
+
continue
|
| 420 |
+
|
| 421 |
+
# instantiate parameter value
|
| 422 |
+
if name in self._parameters:
|
| 423 |
+
param = getattr(self, name)
|
| 424 |
+
# overwrite provided value of frozen parameters
|
| 425 |
+
if isinstance(param, Frozen) and param.value != value:
|
| 426 |
+
msg = (
|
| 427 |
+
f"Parameter '{name}' is frozen: using its frozen value "
|
| 428 |
+
f"({param.value}) instead of the one provided ({value})."
|
| 429 |
+
)
|
| 430 |
+
warnings.warn(msg)
|
| 431 |
+
value = param.value
|
| 432 |
+
setattr(self, name, value)
|
| 433 |
+
continue
|
| 434 |
+
|
| 435 |
+
msg = f"parameter '{name}' does not exist"
|
| 436 |
+
raise ValueError(msg)
|
| 437 |
+
|
| 438 |
+
self.initialize()
|
| 439 |
+
|
| 440 |
+
return self
|
| 441 |
+
|
| 442 |
+
@property
|
| 443 |
+
def instantiated(self):
|
| 444 |
+
"""Whether pipeline has been instantiated (and therefore can be applied)"""
|
| 445 |
+
parameters = set(self._flatten(self.parameters()))
|
| 446 |
+
instantiated = set(self._flatten(self.parameters(instantiated=True)))
|
| 447 |
+
return parameters == instantiated
|
| 448 |
+
|
| 449 |
+
# def dump_params(
|
| 450 |
+
# self,
|
| 451 |
+
# params_yml: Path,
|
| 452 |
+
# params: Optional[dict] = None,
|
| 453 |
+
# loss: Optional[float] = None,
|
| 454 |
+
# ) -> str:
|
| 455 |
+
# """Dump parameters to disk
|
| 456 |
+
|
| 457 |
+
# Parameters
|
| 458 |
+
# ----------
|
| 459 |
+
# params_yml : `Path`
|
| 460 |
+
# Path to YAML file.
|
| 461 |
+
# params : `dict`, optional
|
| 462 |
+
# Nested Parameters. Defaults to pipeline current parameters.
|
| 463 |
+
# loss : `float`, optional
|
| 464 |
+
# Loss value. Defaults to not write loss to file.
|
| 465 |
+
|
| 466 |
+
# Returns
|
| 467 |
+
# -------
|
| 468 |
+
# content : `str`
|
| 469 |
+
# Content written in `param_yml`.
|
| 470 |
+
# """
|
| 471 |
+
# # use instantiated parameters when `params` is not provided
|
| 472 |
+
# if params is None:
|
| 473 |
+
# params = self.parameters(instantiated=True)
|
| 474 |
+
|
| 475 |
+
# content = {"params": params}
|
| 476 |
+
# if loss is not None:
|
| 477 |
+
# content["loss"] = loss
|
| 478 |
+
|
| 479 |
+
# # format as valid YAML
|
| 480 |
+
# content_yml = yaml.dump(content, default_flow_style=False)
|
| 481 |
+
|
| 482 |
+
# # (safely) dump YAML content
|
| 483 |
+
# with FileLock(params_yml.with_suffix(".lock")):
|
| 484 |
+
# with open(params_yml, mode="w") as fp:
|
| 485 |
+
# fp.write(content_yml)
|
| 486 |
+
|
| 487 |
+
# return content_yml
|
| 488 |
+
|
| 489 |
+
# def load_params(self, params_yml: Path) -> "Pipeline":
|
| 490 |
+
# """Instantiate pipeline using parameters from disk
|
| 491 |
+
|
| 492 |
+
# Parameters
|
| 493 |
+
# ----------
|
| 494 |
+
# param_yml : `Path`
|
| 495 |
+
# Path to YAML file.
|
| 496 |
+
|
| 497 |
+
# Returns
|
| 498 |
+
# -------
|
| 499 |
+
# self : `Pipeline`
|
| 500 |
+
# Instantiated pipeline
|
| 501 |
+
|
| 502 |
+
# """
|
| 503 |
+
|
| 504 |
+
# with open(params_yml, mode="r") as fp:
|
| 505 |
+
# params = yaml.load(fp, Loader=yaml.SafeLoader)
|
| 506 |
+
# return self.instantiate(params["params"])
|
| 507 |
+
|
    def __call__(self, input: PipelineInput) -> PipelineOutput:
        """Apply pipeline on input and return its output.

        Subclasses must override this method; the base class
        unconditionally raises NotImplementedError.
        """
        raise NotImplementedError
| 511 |
+
|
| 512 |
+
# def get_metric(self) -> "pyannote.metrics.base.BaseMetric":
|
| 513 |
+
# """Return new metric (from pyannote.metrics)
|
| 514 |
+
|
| 515 |
+
# When this method is implemented, the returned metric is used as a
|
| 516 |
+
# replacement for the loss method below.
|
| 517 |
+
|
| 518 |
+
# Returns
|
| 519 |
+
# -------
|
| 520 |
+
# metric : `pyannote.metrics.base.BaseMetric`
|
| 521 |
+
# """
|
| 522 |
+
# raise NotImplementedError()
|
| 523 |
+
|
| 524 |
+
# def get_direction(self) -> Direction:
|
| 525 |
+
# return "minimize"
|
| 526 |
+
|
| 527 |
+
# def loss(self, input: PipelineInput, output: PipelineOutput) -> float:
|
| 528 |
+
# """Compute loss for given input/output pair
|
| 529 |
+
|
| 530 |
+
# Parameters
|
| 531 |
+
# ----------
|
| 532 |
+
# input : object
|
| 533 |
+
# Pipeline input.
|
| 534 |
+
# output : object
|
| 535 |
+
# Pipeline output
|
| 536 |
+
|
| 537 |
+
# Returns
|
| 538 |
+
# -------
|
| 539 |
+
# loss : `float`
|
| 540 |
+
# Loss value
|
| 541 |
+
# """
|
| 542 |
+
# raise NotImplementedError()
|
| 543 |
+
|
| 544 |
+
# @property
|
| 545 |
+
# def write_format(self):
|
| 546 |
+
# return "rttm"
|
| 547 |
+
|
| 548 |
+
# def write(self, file: TextIO, output: PipelineOutput):
|
| 549 |
+
# """Write pipeline output to file
|
| 550 |
+
|
| 551 |
+
# Parameters
|
| 552 |
+
# ----------
|
| 553 |
+
# file : file object
|
| 554 |
+
# output : object
|
| 555 |
+
# Pipeline output
|
| 556 |
+
# """
|
| 557 |
+
|
| 558 |
+
# return getattr(self, f"write_{self.write_format}")(file, output)
|
| 559 |
+
|
| 560 |
+
# def write_rttm(self, file: TextIO, output: Union[Timeline, Annotation]):
|
| 561 |
+
# """Write pipeline output to "rttm" file
|
| 562 |
+
|
| 563 |
+
# Parameters
|
| 564 |
+
# ----------
|
| 565 |
+
# file : file object
|
| 566 |
+
# output : `pyannote.core.Timeline` or `pyannote.core.Annotation`
|
| 567 |
+
# Pipeline output
|
| 568 |
+
# """
|
| 569 |
+
|
| 570 |
+
# if isinstance(output, Timeline):
|
| 571 |
+
# output = output.to_annotation(generator="string")
|
| 572 |
+
|
| 573 |
+
# if isinstance(output, Annotation):
|
| 574 |
+
# for s, t, l in output.itertracks(yield_label=True):
|
| 575 |
+
# line = (
|
| 576 |
+
# f"SPEAKER {output.uri} 1 {s.start:.3f} {s.duration:.3f} "
|
| 577 |
+
# f"<NA> <NA> {l} <NA> <NA>\n"
|
| 578 |
+
# )
|
| 579 |
+
# file.write(line)
|
| 580 |
+
# return
|
| 581 |
+
|
| 582 |
+
# msg = (
|
| 583 |
+
# f'Dumping {output.__class__.__name__} instances to "rttm" files '
|
| 584 |
+
# f"is not supported."
|
| 585 |
+
# )
|
| 586 |
+
# raise NotImplementedError(msg)
|
| 587 |
+
|
| 588 |
+
# def write_txt(self, file: TextIO, output: Union[Timeline, Annotation]):
|
| 589 |
+
# """Write pipeline output to "txt" file
|
| 590 |
+
|
| 591 |
+
# Parameters
|
| 592 |
+
# ----------
|
| 593 |
+
# file : file object
|
| 594 |
+
# output : `pyannote.core.Timeline` or `pyannote.core.Annotation`
|
| 595 |
+
# Pipeline output
|
| 596 |
+
# """
|
| 597 |
+
|
| 598 |
+
# if isinstance(output, Timeline):
|
| 599 |
+
# for s in output:
|
| 600 |
+
# line = f"{output.uri} {s.start:.3f} {s.end:.3f}\n"
|
| 601 |
+
# file.write(line)
|
| 602 |
+
# return
|
| 603 |
+
|
| 604 |
+
# if isinstance(output, Annotation):
|
| 605 |
+
# for s, t, l in output.itertracks(yield_label=True):
|
| 606 |
+
# line = f"{output.uri} {s.start:.3f} {s.end:.3f} {t} {l}\n"
|
| 607 |
+
# file.write(line)
|
| 608 |
+
# return
|
| 609 |
+
|
| 610 |
+
# msg = (
|
| 611 |
+
# f'Dumping {output.__class__.__name__} instances to "txt" files '
|
| 612 |
+
# f"is not supported."
|
| 613 |
+
# )
|
| 614 |
+
# raise NotImplementedError(msg)
|