niobures committed on
Commit 8c838e7 · verified · 1 Parent(s): 7e98d38

Pyannote (models, models_onnx)

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitattributes +2 -0
  2. ailia-models/code/LICENSE +21 -0
  3. ailia-models/code/README.md +135 -0
  4. ailia-models/code/config.yaml +17 -0
  5. ailia-models/code/data/sample.rttm +10 -0
  6. ailia-models/code/data/sample.wav +3 -0
  7. ailia-models/code/output.png +0 -0
  8. ailia-models/code/output_ground.png +0 -0
  9. ailia-models/code/pyannote-audio.py +181 -0
  10. ailia-models/code/pyannote_audio_utils/__init__.py +23 -0
  11. ailia-models/code/pyannote_audio_utils/audio/__init__.py +33 -0
  12. ailia-models/code/pyannote_audio_utils/audio/core/inference.py +596 -0
  13. ailia-models/code/pyannote_audio_utils/audio/core/io.py +352 -0
  14. ailia-models/code/pyannote_audio_utils/audio/core/pipeline.py +218 -0
  15. ailia-models/code/pyannote_audio_utils/audio/core/task.py +125 -0
  16. ailia-models/code/pyannote_audio_utils/audio/pipelines/__init__.py +35 -0
  17. ailia-models/code/pyannote_audio_utils/audio/pipelines/clustering.py +468 -0
  18. ailia-models/code/pyannote_audio_utils/audio/pipelines/speaker_diarization.py +553 -0
  19. ailia-models/code/pyannote_audio_utils/audio/pipelines/speaker_verification.py +249 -0
  20. ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/__init__.py +37 -0
  21. ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/diarization.py +248 -0
  22. ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/kaldifeat.py +291 -0
  23. ailia-models/code/pyannote_audio_utils/audio/utils/multi_task.py +59 -0
  24. ailia-models/code/pyannote_audio_utils/audio/utils/powerset.py +125 -0
  25. ailia-models/code/pyannote_audio_utils/audio/utils/signal.py +369 -0
  26. ailia-models/code/pyannote_audio_utils/audio/version.py +1 -0
  27. ailia-models/code/pyannote_audio_utils/core/__init__.py +48 -0
  28. ailia-models/code/pyannote_audio_utils/core/_version.py +20 -0
  29. ailia-models/code/pyannote_audio_utils/core/annotation.py +1551 -0
  30. ailia-models/code/pyannote_audio_utils/core/feature.py +329 -0
  31. ailia-models/code/pyannote_audio_utils/core/notebook.py +468 -0
  32. ailia-models/code/pyannote_audio_utils/core/segment.py +910 -0
  33. ailia-models/code/pyannote_audio_utils/core/timeline.py +1126 -0
  34. ailia-models/code/pyannote_audio_utils/core/utils/generators.py +89 -0
  35. ailia-models/code/pyannote_audio_utils/core/utils/types.py +13 -0
  36. ailia-models/code/pyannote_audio_utils/database/__init__.py +91 -0
  37. ailia-models/code/pyannote_audio_utils/database/protocol/__init__.py +34 -0
  38. ailia-models/code/pyannote_audio_utils/database/protocol/protocol.py +434 -0
  39. ailia-models/code/pyannote_audio_utils/database/util.py +400 -0
  40. ailia-models/code/pyannote_audio_utils/metrics/__init__.py +36 -0
  41. ailia-models/code/pyannote_audio_utils/metrics/_version.py +21 -0
  42. ailia-models/code/pyannote_audio_utils/metrics/base.py +419 -0
  43. ailia-models/code/pyannote_audio_utils/metrics/diarization.py +167 -0
  44. ailia-models/code/pyannote_audio_utils/metrics/identification.py +274 -0
  45. ailia-models/code/pyannote_audio_utils/metrics/matcher.py +192 -0
  46. ailia-models/code/pyannote_audio_utils/metrics/types.py +7 -0
  47. ailia-models/code/pyannote_audio_utils/metrics/utils.py +225 -0
  48. ailia-models/code/pyannote_audio_utils/pipeline/__init__.py +37 -0
  49. ailia-models/code/pyannote_audio_utils/pipeline/parameter.py +203 -0
  50. ailia-models/code/pyannote_audio_utils/pipeline/pipeline.py +614 -0
.gitattributes CHANGED
@@ -37,3 +37,5 @@ brouhaha/brouhaha.gif filter=lfs diff=lfs merge=lfs -text
  separation-ami-1.0/model.png filter=lfs diff=lfs merge=lfs -text
  speaker-diarization/technical_report_2.1.pdf filter=lfs diff=lfs merge=lfs -text
  speech-separation-ami-1.0/pipeline.png filter=lfs diff=lfs merge=lfs -text
+ ailia-models/code/data/sample.wav filter=lfs diff=lfs merge=lfs -text
+ speaker-diarization-community-1/diarization.gif filter=lfs diff=lfs merge=lfs -text
ailia-models/code/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2020 CNRS
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
ailia-models/code/README.md ADDED
@@ -0,0 +1,135 @@
+ # Pyannote-audio : Speaker Diarization
+
+ ## Input
+
+ Audio file (.wav format).
+ ```
+ Example
+ input: data/sample.wav
+ ```
+ (Wav file from https://github.com/pyannote/pyannote-audio/tree/develop/pyannote/audio/sample)
+
+ ## Output
+
+ When and who spoke.
+ ![Output](output.png)
+
+ ```
+ [ 00:00:06.714 --> 00:00:07.003] A speaker91
+ [ 00:00:07.003 --> 00:00:07.173] B speaker90
+ [ 00:00:07.580 --> 00:00:08.310] C speaker91
+ [ 00:00:08.310 --> 00:00:09.923] D speaker90
+ [ 00:00:09.923 --> 00:00:10.976] E speaker91
+ [ 00:00:10.466 --> 00:00:14.745] F speaker90
+ [ 00:00:14.303 --> 00:00:17.886] G speaker91
+ [ 00:00:18.022 --> 00:00:21.502] H speaker90
+ [ 00:00:18.157 --> 00:00:18.446] I speaker91
+ [ 00:00:21.774 --> 00:00:28.531] J speaker91
+ [ 00:00:27.886 --> 00:00:29.991] K speaker90
+ ```
+
+ ## Requirements
+
+ This model requires additional modules.
+ ```bash
+ $ pip3 install -r requirements.txt
+ ```
+
+ ## Usage
+
+ The ONNX and prototxt files are downloaded automatically on the first run,
+ so an Internet connection is required while downloading.
+
+ For the sample
+ ```bash
+ $ python pyannote-audio.py -i ./data/sample.wav
+ ```
+
+ For the sample with plot
+ ```bash
+ $ python pyannote-audio.py -i ./data/sample.wav --plt
+ ```
+
+ For the sample with verification
+ ```bash
+ $ python pyannote-audio.py -i ./data/sample.wav -g ./data/sample.rttm
+ ```
+
+ If you want to specify the input audio, put the file path after the `-i` or `--input` option.
+
+ ```bash
+ $ python pyannote-audio.py -i FILE_PATH
+ ```
+
+ If you want to specify the ground truth, put the file path after the `--ig` or `-ground` option.
+
+ ```bash
+ $ python pyannote-audio.py --ig FILE_PATH
+ ```
+
+ If you want to specify the output file, put the file path after the `--o` or `-output` option.
+
+ ```bash
+ $ python pyannote-audio.py --o FILE_PATH
+ ```
+
+ If you want to specify the output ground truth file, put the file path after the `--og` or `-output_ground` option.
+
+ ```bash
+ $ python pyannote-audio.py --og FILE_PATH
+ ```
+
+ If you know the number of speakers, put the number after the `--num` or `-num_speaker` option.
+ ```bash
+ $ python pyannote-audio.py --num 2
+ ```
+
+ If you know the maximum number of speakers, put the number after the `--max` or `-max_speaker` option.
+ ```bash
+ $ python pyannote-audio.py --max 4
+ ```
+
+ If you know the minimum number of speakers, put the number after the `--min` or `-min_speaker` option.
+ ```bash
+ $ python pyannote-audio.py --min 2
+ ```
+
+ By giving the `--e` or `-error` option together with a ground truth file, you can get the diarization error rate.
+ ```bash
+ $ python pyannote-audio.py --ig ./data/sample.rttm --e
+ ```
+
+ By giving the `--plt` option, you can visualize the results.
+ ```bash
+ $ python pyannote-audio.py --plt
+ ```
+
+ By giving the `--onnx` option, you can run inference with ONNX Runtime instead of ailia.
+ ```bash
+ $ python pyannote-audio.py --onnx
+ ```
+
+ By giving the `--embed` option, you can get the embedding vector of each speaker in the input file.
+ ```bash
+ $ python pyannote-audio.py --embed
+ ```
+
+ ## Reference
+
+ - [Pyannote-audio](https://github.com/pyannote/pyannote-audio)
+ - [Hugging Face - pyannote/speaker-diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1)
+ - [Hugging Face - hbredin/wespeaker-voxceleb-resnet34-LM](https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/tree/main)
+ - [KaldiFeat](https://github.com/yuyq96/kaldifeat)
+
+ ## Framework
+
+ PyTorch
+
+ ## Model Format
+
+ ONNX opset=14,17
+
+ ## Netron
+
+ - [segmentation.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/pyannote-audio/segmentation.onnx.prototxt)
+ - [speaker-embedding.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/pyannote-audio/speaker-embedding.onnx.prototxt)
ailia-models/code/config.yaml ADDED
@@ -0,0 +1,17 @@
+ params:
+   clustering:
+     method: centroid
+     min_cluster_size: 12
+     threshold: 0.7045654963945799
+   segmentation:
+     min_duration_off: 0.0
+ pipeline:
+   name: pyannote.audio.pipelines.SpeakerDiarization
+   params:
+     clustering: AgglomerativeClustering
+     embedding: speaker-embedding.onnx
+     embedding_batch_size: 32
+     embedding_exclude_overlap: true
+     segmentation: segmentation.onnx
+     segmentation_batch_size: 32
+ version: 3.1.0
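For orientation, `pyannote-audio.py` (added below) consumes this file in two steps: it rewrites the `pipeline.params` model entries to point at the locally downloaded ONNX files, then builds `SpeakerDiarization(**params)` and freezes the tuned `params`. A minimal sketch of the config handling, assuming the file layout of this commit:

```python
import yaml

# Load the frozen configuration shipped with this commit, then point
# the pipeline at the locally downloaded ONNX weights, exactly as
# pyannote-audio.py does before instantiating SpeakerDiarization.
with open("config.yaml") as fp:
    config = yaml.safe_load(fp)

config["pipeline"]["params"]["segmentation"] = "segmentation.onnx"
config["pipeline"]["params"]["embedding"] = "speaker-embedding.onnx"

print(config["pipeline"]["name"])                    # pyannote.audio.pipelines.SpeakerDiarization
print(config["params"]["clustering"]["threshold"])   # 0.7045654963945799
```

The script then calls `pipeline.instantiate(config["params"])` so the clustering threshold and segmentation settings above are applied as frozen hyperparameters.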
ailia-models/code/data/sample.rttm ADDED
@@ -0,0 +1,10 @@
+ SPEAKER sample 1 6.690 0.430 <NA> <NA> speaker90 <NA> <NA>
+ SPEAKER sample 1 7.550 0.800 <NA> <NA> speaker91 <NA> <NA>
+ SPEAKER sample 1 8.320 1.700 <NA> <NA> speaker90 <NA> <NA>
+ SPEAKER sample 1 9.920 1.110 <NA> <NA> speaker91 <NA> <NA>
+ SPEAKER sample 1 10.570 4.130 <NA> <NA> speaker90 <NA> <NA>
+ SPEAKER sample 1 14.490 3.430 <NA> <NA> speaker91 <NA> <NA>
+ SPEAKER sample 1 18.050 3.440 <NA> <NA> speaker90 <NA> <NA>
+ SPEAKER sample 1 18.150 0.440 <NA> <NA> speaker91 <NA> <NA>
+ SPEAKER sample 1 21.780 6.720 <NA> <NA> speaker91 <NA> <NA>
+ SPEAKER sample 1 27.850 2.150 <NA> <NA> speaker90 <NA> <NA>
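Each row follows the NIST RTTM convention `SPEAKER <uri> <channel> <start> <duration> <NA> <NA> <speaker> <NA> <NA>`, with start and duration in seconds. A hypothetical minimal reader for these rows (the repository's real loader is `pyannote_audio_utils.database.util.load_rttm`, which returns full `Annotation` objects instead of tuples):

```python
# Hypothetical minimal RTTM reader, for illustration only.
def read_rttm(path):
    segments = []
    with open(path) as f:
        for line in f:
            fields = line.split()
            # SPEAKER <uri> <channel> <start> <duration> <NA> <NA> <speaker> <NA> <NA>
            start, duration, speaker = float(fields[3]), float(fields[4]), fields[7]
            segments.append((start, start + duration, speaker))
    return segments

print(read_rttm("data/sample.rttm")[0])  # (6.69, 7.12, 'speaker90')
```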
ailia-models/code/data/sample.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c319b4abca767b124e41432d364fd7df006cb26bb79d09326c487d606a134e6e
+ size 960104
ailia-models/code/output.png ADDED
ailia-models/code/output_ground.png ADDED
ailia-models/code/pyannote-audio.py ADDED
@@ -0,0 +1,181 @@
+ import yaml
+ import sys
+ import matplotlib.pyplot as plt
+ import time
+
+ from pyannote_audio_utils.audio.pipelines.speaker_diarization import SpeakerDiarization
+ from pyannote_audio_utils.core import Segment, Annotation
+ from pyannote_audio_utils.core.notebook import Notebook
+ from pyannote_audio_utils.database.util import load_rttm
+ from pyannote_audio_utils.metrics.diarization import DiarizationErrorRate
+
+ sys.path.append('../../util')
+ from arg_utils import get_base_parser, update_parser  # noqa: E402
+ from model_utils import check_and_download_models  # noqa: E402
+ from logging import getLogger  # noqa: E402
+ logger = getLogger(__name__)
+
+ WEIGHT_SEGMENTATION_PATH = 'segmentation.onnx'
+ MODEL_SEGMENTATION_PATH = 'segmentation.onnx.prototxt'
+ WEIGHT_EMBEDDING_PATH = 'speaker-embedding.onnx'
+ MODEL_EMBEDDING_PATH = 'speaker-embedding.onnx.prototxt'
+ REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/pyannote-audio/'
+ YAML_PATH = 'config.yaml'
+ OUT_PATH = 'output.png'
+
+ parser = get_base_parser(
+     'Pyannote-audio', './data/sample.wav', None, input_ftype='audio'
+ )
+ parser.add_argument(
+     '--num', '-num_speaker', default=0, type=int,
+     help='If the number of speakers is fixed',
+ )
+ parser.add_argument(
+     '--max', '-max_speaker', default=0, type=int,
+     help='If the maximum number of speakers is fixed',
+ )
+ parser.add_argument(
+     '--min', '-min_speaker', default=0, type=int,
+     help='If the minimum number of speakers is fixed',
+ )
+ parser.add_argument(
+     '--ig', '-ground', default=None,
+     help='Specify an rttm file as ground truth. If you need the diarization error rate, you need this file'
+ )
+ parser.add_argument(
+     '--o', '-output', default='output.png',
+     help='Specify an output file'
+ )
+ parser.add_argument(
+     '--og', '-output_ground', default='output_ground.png',
+     help='Specify an output ground truth file'
+ )
+ parser.add_argument(
+     '--e', '-error',
+     action='store_true',
+     help='If you need the diarization error rate'
+ )
+ parser.add_argument(
+     '--plt',
+     action='store_true',
+     help='If you want to visualize the result'
+ )
+ parser.add_argument(
+     '--embed',
+     action='store_true',
+     help='If you need the embedding vectors',
+ )
+ parser.add_argument(
+     '--onnx',
+     action='store_true',
+     help='execute onnxruntime version'
+ )
+
+ args = update_parser(parser)
+
+
+ def repr_annotation(args, annotation: Annotation, notebook: Notebook, ground: bool = False):
+     """Save a png plot of `annotation`"""
+     figsize = plt.rcParams["figure.figsize"]
+     plt.rcParams["figure.figsize"] = (notebook.width, 2)
+     fig, ax = plt.subplots()
+     notebook.plot_annotation(annotation, ax=ax)
+     if ground:
+         plt.savefig(args.og)
+     else:
+         plt.savefig(args.o)
+     plt.close(fig)
+     plt.rcParams["figure.figsize"] = figsize
+     return
+
+
+ def main(args):
+     check_and_download_models(WEIGHT_SEGMENTATION_PATH, MODEL_SEGMENTATION_PATH, remote_path=REMOTE_PATH)
+     check_and_download_models(WEIGHT_EMBEDDING_PATH, MODEL_EMBEDDING_PATH, remote_path=REMOTE_PATH)
+
+     if args.benchmark:
+         start = int(round(time.time() * 1000))
+
+     # rewrite config.yaml so that the pipeline points at the local ONNX files
+     with open(YAML_PATH, 'r') as yml:
+         config = yaml.safe_load(yml)
+
+     config["pipeline"]["params"]["segmentation"] = WEIGHT_SEGMENTATION_PATH
+     config["pipeline"]["params"]["embedding"] = WEIGHT_EMBEDDING_PATH
+     with open(YAML_PATH, 'w') as f:
+         yaml.dump(config, f)
+
+     audio_file = args.input[0]
+     checkpoint_path = YAML_PATH
+     config_yml = checkpoint_path
+
+     with open(config_yml, "r") as fp:
+         config = yaml.load(fp, Loader=yaml.SafeLoader)
+
+     params = config["pipeline"].get("params", {})
+     pipeline = SpeakerDiarization(
+         **params,
+         args=args,
+         seg_path=MODEL_SEGMENTATION_PATH,
+         emb_path=MODEL_EMBEDDING_PATH,
+     )
+
+     if "params" in config:
+         pipeline.instantiate(config["params"])
+
+     if args.embed:
+         if args.num > 0:
+             diarization, embeddings = pipeline(audio_file, return_embeddings=True, num_speakers=args.num)
+         elif args.max > 0 or args.min > 0:
+             diarization, embeddings = pipeline(audio_file, return_embeddings=True, min_speakers=args.min, max_speakers=args.max)
+         else:
+             diarization, embeddings = pipeline(audio_file, return_embeddings=True)
+         for s, speaker in enumerate(diarization.labels()):
+             print(speaker, embeddings[s].shape)
+     else:
+         if args.num > 0:
+             diarization = pipeline(audio_file, num_speakers=args.num)
+         elif args.max > 0 or args.min > 0:
+             diarization = pipeline(audio_file, min_speakers=args.min, max_speakers=args.max)
+         else:
+             diarization = pipeline(audio_file)
+
+     if args.benchmark:
+         end = int(round(time.time() * 1000))
+         print(f'\tailia processing time {end - start} ms')
+
+     if args.ig:
+         _, groundtruth = load_rttm(args.ig).popitem()
+         metric = DiarizationErrorRate()
+         result = metric(groundtruth, diarization, detailed=False)
+
+         # rename hypothesis labels to the best-matching ground truth labels
+         mapping = metric.optimal_mapping(groundtruth, diarization)
+         diarization = diarization.rename_labels(mapping=mapping)
+
+         print(diarization)
+         if args.e:
+             print(f'diarization error rate = {100 * result:.1f}%')
+
+         if args.plt:
+             EXCERPT = Segment(0, 30)
+             notebook = Notebook()
+             notebook.crop = EXCERPT
+             repr_annotation(args, diarization, notebook)
+             repr_annotation(args, groundtruth, notebook, ground=True)
+         return
+
+     else:
+         print(diarization)
+
+         if args.plt:
+             EXCERPT = Segment(0, 30)
+             notebook = Notebook()
+             notebook.crop = EXCERPT
+             repr_annotation(args, diarization, notebook)
+         return
+
+
+ if __name__ == "__main__":
+     main(args)
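The evaluation path of the script (`--ig` plus `--e`) boils down to the calls below, shown as a standalone sketch; the ground truth is reused as a stand-in hypothesis, so the reported rate is 0%:

```python
from pyannote_audio_utils.database.util import load_rttm
from pyannote_audio_utils.metrics.diarization import DiarizationErrorRate

_, reference = load_rttm("data/sample.rttm").popitem()
_, hypothesis = load_rttm("data/sample.rttm").popitem()  # stand-in for a pipeline output

metric = DiarizationErrorRate()
result = metric(reference, hypothesis, detailed=False)

# map hypothesis labels onto reference labels, as the script does before printing
mapping = metric.optimal_mapping(reference, hypothesis)
hypothesis = hypothesis.rename_labels(mapping=mapping)

print(f"diarization error rate = {100 * result:.1f}%")  # 0.0% here
```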
ailia-models/code/pyannote_audio_utils/__init__.py ADDED
@@ -0,0 +1,23 @@
+ # MIT License
+ #
+ # Copyright (c) 2020 CNRS
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ __import__("pkg_resources").declare_namespace(__name__)
ailia-models/code/pyannote_audio_utils/audio/__init__.py ADDED
@@ -0,0 +1,33 @@
+ # MIT License
+ #
+ # Copyright (c) 2020-2021 CNRS
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ try:
+     from .version import __version__, git_version  # noqa: F401
+ except ImportError:
+     pass
+
+
+ from .core.inference import Inference
+ from .core.io import Audio
+ from .core.pipeline import Pipeline
+
+ __all__ = ["Audio", "Inference", "Pipeline"]
ailia-models/code/pyannote_audio_utils/audio/core/inference.py ADDED
@@ -0,0 +1,596 @@
+ # MIT License
+ #
+ # Copyright (c) 2020- CNRS
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+
+ import ailia
+ import warnings
+ from pathlib import Path
+ from typing import Callable, List, Optional, Text, Tuple, Union
+ from functools import cached_property
+ from dataclasses import dataclass
+ import numpy as np
+
+ from pyannote_audio_utils.core import Segment, SlidingWindow, SlidingWindowFeature
+ from pyannote_audio_utils.audio.core.io import AudioFile, Audio
+ from pyannote_audio_utils.audio.core.task import Resolution, Specifications, Problem
+ from pyannote_audio_utils.audio.utils.multi_task import map_with_specifications
+ from pyannote_audio_utils.audio.utils.powerset import Powerset
+
+
+ class BaseInference:
+     pass
+
+
+ @dataclass
+ class Output:
+     num_frames: int
+     dimension: int
+     frames: SlidingWindow
+
+
+ class Inference(BaseInference):
+     """Inference
+
+     Parameters
+     ----------
+     model : str or Path
+         Path to the segmentation model (ONNX weights). Loaded with ailia,
+         or with ONNX Runtime when `args.onnx` is set.
+     window : {"sliding", "whole"}, optional
+         Use a "sliding" window and aggregate the corresponding outputs (default)
+         or just one (potentially long) window covering the "whole" file or chunk.
+     duration : float, optional
+         Chunk duration, in seconds. Defaults to duration used for training the model.
+         Has no effect when `window` is "whole".
+     step : float, optional
+         Step between consecutive chunks, in seconds. Defaults to warm-up duration when
+         greater than 0s, otherwise 10% of duration. Has no effect when `window` is "whole".
+     pre_aggregation_hook : callable, optional
+         When a callable is provided, it is applied to the model output, just before aggregation.
+         Takes a (num_chunks, num_frames, dimension) numpy array as input and returns a modified
+         (num_chunks, num_frames, other_dimension) numpy array passed to overlap-add aggregation.
+     skip_aggregation : bool, optional
+         Do not aggregate outputs when using "sliding" window. Defaults to False.
+     skip_conversion : bool, optional
+         In case a task has been trained with `powerset` mode, output is automatically
+         converted to `multi-label`, unless `skip_conversion` is set to True.
+     batch_size : int, optional
+         Batch size. Larger values (should) make inference faster. Defaults to 32.
+     use_auth_token : str, optional
+         When loading a private huggingface.co model, set `use_auth_token`
+         to True or to a string containing your huggingface.co authentication
+         token that can be obtained by running `huggingface-cli login`
+     """
+
+     def __init__(
+         self,
+         model: Union[Text, Path],
+         window: Text = "sliding",
+         duration: float = None,
+         step: float = None,
+         pre_aggregation_hook: Callable[[np.ndarray], np.ndarray] = None,
+         skip_aggregation: bool = False,
+         skip_conversion: bool = False,
+         batch_size: int = 32,
+         use_auth_token: Union[Text, None] = None,
+         args=None,
+         seg_path=None,
+     ):
+         # ~~~~ model ~~~~~
+
+         if args.onnx:
+             # use ONNX Runtime
+             import onnxruntime
+             model = onnxruntime.InferenceSession(model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+         else:
+             # use ailia
+             model = ailia.Net(seg_path, weight=model, env_id=args.env_id)
+
+         self.model = model
+         self.args = args
+
+         specifications = Specifications(
+             problem=Problem.MONO_LABEL_CLASSIFICATION,
+             resolution=Resolution.FRAME,
+             classes=['speaker#1', 'speaker#2', 'speaker#3'],
+         )
+
+         self.specifications = specifications
+         self.audio = Audio(sample_rate=16000, mono="downmix")
+
+         # ~~~~ sliding window ~~~~~
+
+         if window not in ["sliding", "whole"]:
+             raise ValueError('`window` must be "sliding" or "whole".')
+
+         if window == "whole" and any(
+             s.resolution == Resolution.FRAME for s in specifications
+         ):
+             warnings.warn(
+                 'Using "whole" `window` inference with a frame-based model might lead to bad results '
+                 'and huge memory consumption: it is recommended to set `window` to "sliding".'
+             )
+         self.window = window
+
+         training_duration = next(iter(specifications)).duration
+         duration = duration or training_duration
+         if training_duration != duration:
+             warnings.warn(
+                 f"Model was trained with {training_duration:g}s chunks, and you requested "
+                 f"{duration:g}s chunks for inference: this might lead to suboptimal results."
+             )
+         self.duration = duration
+
+         # ~~~~ powerset to multilabel conversion ~~~~
+
+         self.skip_conversion = skip_conversion
+
+         conversion = list()
+         for s in specifications:
+             if s.powerset and not skip_conversion:
+                 c = Powerset(len(s.classes), s.powerset_max_classes)
+                 conversion.append(c)
+
+         if isinstance(specifications, Specifications):
+             self.conversion = conversion[0]
+
+         # ~~~~ overlap-add aggregation ~~~~~
+
+         self.skip_aggregation = skip_aggregation
+         self.pre_aggregation_hook = pre_aggregation_hook
+
+         self.warm_up = next(iter(specifications)).warm_up
+         # Use that many seconds on the left- and rightmost parts of each chunk
+         # to warm up the model. While the model does process those left- and right-most
+         # parts, only the remaining central part of each chunk is used for aggregating
+         # scores during inference.
+
+         # step between consecutive chunks
+         step = step or (
+             0.1 * self.duration if self.warm_up[0] == 0.0 else self.warm_up[0]
+         )
+         if step > self.duration:
+             raise ValueError(
+                 f"Step between consecutive chunks is set to {step:g}s, while chunks are "
+                 f"only {self.duration:g}s long, leading to gaps between consecutive chunks. "
+                 f"Either decrease step or increase duration."
+             )
+         self.step = step
+
+         self.batch_size = batch_size
+
+     def infer(self, chunks: np.ndarray) -> Union[np.ndarray, Tuple[np.ndarray]]:
+         """Forward pass
+
+         Takes care of sending chunks to the right backend and converting outputs to numpy
+
+         Parameters
+         ----------
+         chunks : (batch_size, num_channels, num_samples) np.ndarray
+             Batch of audio chunks.
+
+         Returns
+         -------
+         outputs : (tuple of) (batch_size, ...) np.ndarray
+             Model output.
+         """
+         chunks = chunks.astype(np.float32)
+         if self.args.onnx:
+             outputs = self.model.run(None, {"input": chunks})[0]
+         else:
+             outputs = self.model.predict([chunks])[0]
+
+         def __convert(output: np.ndarray, conversion, **kwargs):
+             return conversion(output)
+
+         return map_with_specifications(self.specifications, __convert, outputs, self.conversion)
+
+     @cached_property
+     def example_output(self) -> Union[Output, Tuple[Output]]:
+         """Example output"""
+         example_input_array = np.random.randn(
+             1, 1, self.audio.get_num_samples(self.specifications.duration)
+         ).astype(np.float32)
+
+         example_outputs = self.infer(example_input_array)
+
+         def __example_output(
+             example_output: np.ndarray,
+             specifications: Specifications = None,
+         ) -> Output:
+             _, num_frames, dimension = example_output.shape
+
+             if specifications.resolution == Resolution.FRAME:
+                 frame_duration = specifications.duration / num_frames
+                 frames = SlidingWindow(step=frame_duration, duration=frame_duration)
+             else:
+                 frames = None
+
+             return Output(
+                 num_frames=num_frames,
+                 dimension=dimension,
+                 frames=frames,
+             )
+
+         return map_with_specifications(
+             self.specifications, __example_output, example_outputs
+         )
+
+     def slide(
+         self,
+         waveform: np.ndarray,
+         sample_rate: int,
+         hook: Optional[Callable],
+     ) -> Union[SlidingWindowFeature, Tuple[SlidingWindowFeature]]:
+         """Slide model on a waveform
+
+         Parameters
+         ----------
+         waveform : (num_channels, num_samples) np.ndarray
+             Waveform.
+         sample_rate : int
+             Sample rate.
+         hook : Optional[Callable]
+             When a callable is provided, it is called every time a batch is
+             processed with two keyword arguments:
+             - `completed`: the number of chunks that have been processed so far
+             - `total`: the total number of chunks
+
+         Returns
+         -------
+         output : (tuple of) SlidingWindowFeature
+             Model output. Shape is (num_chunks, dimension) for chunk-level tasks,
+             and (num_frames, dimension) for frame-level tasks.
+         """
+
+         window_size: int = self.audio.get_num_samples(self.duration)
+         step_size: int = round(self.step * sample_rate)
+
+         _, num_samples = waveform.shape
+
+         def __frames(
+             example_output, specifications: Optional[Specifications] = None
+         ) -> SlidingWindow:
+             if specifications.resolution == Resolution.CHUNK:
+                 return SlidingWindow(start=0.0, duration=self.duration, step=self.step)
+             return example_output.frames
+
+         frames: Union[SlidingWindow, Tuple[SlidingWindow]] = map_with_specifications(
+             self.specifications,
+             __frames,
+             self.example_output,
+         )
+
+         # prepare complete chunks
+         def unfold_numpy(waveform, window_size, step_size):
+             # numpy equivalent of torch.Tensor.unfold: build a zero-copy strided
+             # view of shape (batch, num_windows, window_size)
+             batch_size, waveform_size = waveform.shape
+             num_windows = (waveform_size - window_size) // step_size + 1
+             shape = (batch_size, num_windows, window_size)
+             strides = (
+                 waveform.strides[0],
+                 step_size * waveform.strides[1],
+                 waveform.strides[1],
+             )
+             return np.lib.stride_tricks.as_strided(waveform, shape=shape, strides=strides)
+
+         if num_samples >= window_size:
+             chunks: np.ndarray = (unfold_numpy(waveform, window_size, step_size)).transpose(1, 0, 2)
+             num_chunks = chunks.shape[0]
+         else:
+             num_chunks = 0
+
+         # prepare last incomplete chunk
+         has_last_chunk = (num_samples < window_size) or (num_samples - window_size) % step_size > 0
+         if has_last_chunk:
+             # pad last chunk with zeros
+             last_chunk: np.ndarray = waveform[:, num_chunks * step_size :]
+             _, last_window_size = last_chunk.shape
+             last_pad = window_size - last_window_size
+             last_chunk = np.pad(last_chunk, ((0, 0), (0, last_pad)))
+
+         def __empty_list(**kwargs):
+             return list()
+
+         outputs: Union[
+             List[np.ndarray], Tuple[List[np.ndarray]]
+         ] = map_with_specifications(self.specifications, __empty_list)
+
+         if hook is not None:
+             hook(completed=0, total=num_chunks + has_last_chunk)
+
+         def __append_batch(output, batch_output, **kwargs) -> None:
+             output.append(batch_output)
+             return
+
+         # slide over audio chunks in batch
+         for c in np.arange(0, num_chunks, self.batch_size):
+             batch: np.ndarray = chunks[c : c + self.batch_size]
+             batch_outputs: Union[np.ndarray, Tuple[np.ndarray]] = self.infer(batch)
+
+             _ = map_with_specifications(
+                 self.specifications, __append_batch, outputs, batch_outputs
+             )
+
+             if hook is not None:
+                 hook(completed=c + self.batch_size, total=num_chunks + has_last_chunk)
+
+         # process orphan last chunk
+         if has_last_chunk:
+             last_outputs = self.infer(last_chunk[None])
+
+             _ = map_with_specifications(
+                 self.specifications, __append_batch, outputs, last_outputs
+             )
+
+             if hook is not None:
+                 hook(
+                     completed=num_chunks + has_last_chunk,
+                     total=num_chunks + has_last_chunk,
+                 )
+
+         def __vstack(output: List[np.ndarray], **kwargs) -> np.ndarray:
+             return np.vstack(output)
+
+         outputs: Union[np.ndarray, Tuple[np.ndarray]] = map_with_specifications(
+             self.specifications, __vstack, outputs
+         )
+
+         def __aggregate(
+             outputs: np.ndarray,
+             frames: SlidingWindow,
+             specifications: Optional[Specifications] = None,
+         ) -> SlidingWindowFeature:
+             # skip aggregation when requested,
+             # or when model outputs just one vector per chunk
+             # or when model is permutation-invariant (and not post-processed)
+             if (
+                 self.skip_aggregation
+                 or specifications.resolution == Resolution.CHUNK
+                 or (
+                     specifications.permutation_invariant
+                     and self.pre_aggregation_hook is None
+                 )
+             ):
+                 frames = SlidingWindow(
+                     start=0.0, duration=self.duration, step=self.step
+                 )
+
+             return SlidingWindowFeature(outputs, frames)
+
+         return map_with_specifications(
+             self.specifications, __aggregate, outputs, frames
+         )
+
+     def __call__(
+         self, file: AudioFile, hook: Optional[Callable] = None
+     ) -> Union[
+         Tuple[Union[SlidingWindowFeature, np.ndarray]],
+         Union[SlidingWindowFeature, np.ndarray],
+     ]:
+         """Run inference on a whole file
+
+         Parameters
+         ----------
+         file : AudioFile
+             Audio file.
+         hook : callable, optional
+             When a callable is provided, it is called every time a batch is processed
+             with two keyword arguments:
+             - `completed`: the number of chunks that have been processed so far
+             - `total`: the total number of chunks
+
+         Returns
+         -------
+         output : (tuple of) SlidingWindowFeature or np.ndarray
+             Model output, as `SlidingWindowFeature` if `window` is set to "sliding"
+             and `np.ndarray` if it is set to "whole".
+         """
+
+         waveform, sample_rate = self.audio(file)
+
+         # this port only implements the "sliding" window
+         if self.window == "sliding":
+             return self.slide(waveform, sample_rate, hook=hook)
+
+     @staticmethod
+     def aggregate(
+         scores: SlidingWindowFeature,
+         frames: SlidingWindow = None,
+         warm_up: Tuple[float, float] = (0.0, 0.0),
+         epsilon: float = 1e-12,
+         hamming: bool = False,
+         missing: float = np.nan,
+         skip_average: bool = False,
+     ) -> SlidingWindowFeature:
+         """Aggregation
+
+         Parameters
+         ----------
+         scores : SlidingWindowFeature
+             Raw (unaggregated) scores. Shape is (num_chunks, num_frames_per_chunk, num_classes).
+         frames : SlidingWindow, optional
+             Frames resolution. Defaults to estimate it automatically based on `scores` shape
+             and chunk size. Providing the exact frame resolution (when known) leads to better
+             temporal precision.
+         warm_up : (float, float) tuple, optional
+             Left/right warm up duration (in seconds).
+         missing : float, optional
+             Value used to replace missing (ie all NaNs) values.
+         skip_average : bool, optional
+             Skip final averaging step.
+
+         Returns
+         -------
+         aggregated_scores : SlidingWindowFeature
+             Aggregated scores. Shape is (num_frames, num_classes)
+         """
+
+         num_chunks, num_frames_per_chunk, num_classes = scores.data.shape
+
+         chunks = scores.sliding_window
+         if frames is None:
+             duration = step = chunks.duration / num_frames_per_chunk
+             frames = SlidingWindow(start=chunks.start, duration=duration, step=step)
+         else:
+             frames = SlidingWindow(
+                 start=chunks.start,
+                 duration=frames.duration,
+                 step=frames.step,
+             )
+
+         masks = 1 - np.isnan(scores)
+         scores.data = np.nan_to_num(scores.data, copy=True, nan=0.0)
+
+         # Hamming window used for overlap-add aggregation
+         hamming_window = (
+             np.hamming(num_frames_per_chunk).reshape(-1, 1)
+             if hamming
+             else np.ones((num_frames_per_chunk, 1))
+         )
+
+         # anything before warm_up_left (and after num_frames_per_chunk - warm_up_right)
+         # will not be used in the final aggregation
+
+         # warm-up windows used for overlap-add aggregation
+         warm_up_window = np.ones((num_frames_per_chunk, 1))
+         # anything before warm_up_left will not contribute to aggregation
+         warm_up_left = round(
+             warm_up[0] / scores.sliding_window.duration * num_frames_per_chunk
+         )
+         warm_up_window[:warm_up_left] = epsilon
+         # anything after num_frames_per_chunk - warm_up_right will not contribute either
+         warm_up_right = round(
+             warm_up[1] / scores.sliding_window.duration * num_frames_per_chunk
+         )
+         warm_up_window[num_frames_per_chunk - warm_up_right :] = epsilon
+
+         # aggregated_output[i] will be used to store the sum of all predictions
+         # for frame #i
+         num_frames = (
+             frames.closest_frame(
+                 scores.sliding_window.start
+                 + scores.sliding_window.duration
+                 + (num_chunks - 1) * scores.sliding_window.step
+             )
+             + 1
+         )
+         aggregated_output: np.ndarray = np.zeros(
+             (num_frames, num_classes), dtype=np.float32
+         )
+
+         # overlapping_chunk_count[i] will be used to store the number of chunks
+         # that contributed to frame #i
+         overlapping_chunk_count: np.ndarray = np.zeros(
+             (num_frames, num_classes), dtype=np.float32
+         )
+
+         # aggregated_mask[i] will be used to indicate whether
+         # at least one non-NAN frame contributed to frame #i
+         aggregated_mask: np.ndarray = np.zeros(
+             (num_frames, num_classes), dtype=np.float32
+         )
+
+         # loop on the scores of sliding chunks
+         for (chunk, score), (_, mask) in zip(scores, masks):
+             # chunk ~ Segment
+             # score ~ (num_frames_per_chunk, num_classes)-shaped np.ndarray
+             # mask ~ (num_frames_per_chunk, num_classes)-shaped np.ndarray
+
+             start_frame = frames.closest_frame(chunk.start)
+             aggregated_output[start_frame : start_frame + num_frames_per_chunk] += (
+                 score * mask * hamming_window * warm_up_window
+             )
+
+             overlapping_chunk_count[
+                 start_frame : start_frame + num_frames_per_chunk
+             ] += (mask * hamming_window * warm_up_window)
+
+             aggregated_mask[
+                 start_frame : start_frame + num_frames_per_chunk
+             ] = np.maximum(
+                 aggregated_mask[start_frame : start_frame + num_frames_per_chunk],
+                 mask,
+             )
+
+         if skip_average:
+             average = aggregated_output
+         else:
+             average = aggregated_output / np.maximum(overlapping_chunk_count, epsilon)
+
+         average[aggregated_mask == 0.0] = missing
+
+         return SlidingWindowFeature(average, frames)
+
+     @staticmethod
+     def trim(
+         scores: SlidingWindowFeature,
+         warm_up: Tuple[float, float] = (0.1, 0.1),
+     ) -> SlidingWindowFeature:
+         """Trim left and right warm-up regions
+
+         Parameters
+         ----------
+         scores : SlidingWindowFeature
+             (num_chunks, num_frames, num_classes)-shaped scores.
+         warm_up : (float, float) tuple
+             Left/right warm up ratio of chunk duration.
+             Defaults to (0.1, 0.1), i.e. 10% on both sides.
+
+         Returns
+         -------
+         trimmed : SlidingWindowFeature
+             (num_chunks, trimmed_num_frames, num_speakers)-shaped scores
+         """
+
+         assert (
+             scores.data.ndim == 3
+         ), "Inference.trim expects (num_chunks, num_frames, num_classes)-shaped `scores`"
+         _, num_frames, _ = scores.data.shape
+
+         chunks = scores.sliding_window
+
+         num_frames_left = round(num_frames * warm_up[0])
+         num_frames_right = round(num_frames * warm_up[1])
+
+         num_frames_step = round(num_frames * chunks.step / chunks.duration)
+         if num_frames - num_frames_left - num_frames_right < num_frames_step:
+             warnings.warn(
+                 f"Total `warm_up` is so large ({sum(warm_up) * 100:g}% of each chunk) "
+                 f"that the resulting trimmed scores do not cover a whole step ({chunks.step:g}s)"
+             )
+         new_data = scores.data[:, num_frames_left : num_frames - num_frames_right]
+
+         new_chunks = SlidingWindow(
+             start=chunks.start + warm_up[0] * chunks.duration,
+             step=chunks.step,
+             duration=(1 - warm_up[0] - warm_up[1]) * chunks.duration,
+         )
+
+         return SlidingWindowFeature(new_data, new_chunks)
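A quick toy check of the chunking logic above: `unfold_numpy` is a zero-copy numpy stand-in for `torch.Tensor.unfold`, and for a 10-sample mono batch with a 4-sample window and 2-sample step it should produce `(10 - 4) // 2 + 1 = 4` windows:

```python
import numpy as np

# unfold_numpy copied from Inference.slide above
def unfold_numpy(waveform, window_size, step_size):
    batch_size, waveform_size = waveform.shape
    num_windows = (waveform_size - window_size) // step_size + 1
    shape = (batch_size, num_windows, window_size)
    strides = (
        waveform.strides[0],
        step_size * waveform.strides[1],
        waveform.strides[1],
    )
    return np.lib.stride_tricks.as_strided(waveform, shape=shape, strides=strides)

x = np.arange(10, dtype=np.float32).reshape(1, 10)
windows = unfold_numpy(x, window_size=4, step_size=2)
print(windows.shape)  # (1, 4, 4)
print(windows[0, 1])  # [2. 3. 4. 5.] -- the window starting at sample 2
```

`slide` then transposes this view to (num_chunks, num_channels, window_size), pads the orphan tail chunk with zeros, and batches the chunks through `infer`.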
ailia-models/code/pyannote_audio_utils/audio/core/io.py ADDED
@@ -0,0 +1,352 @@
+ # MIT License
+ #
+ # Copyright (c) 2020- CNRS
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ """
+ # Audio IO
+
+ This port reads audio with soundfile and resamples with ailia.audio
+ (upstream pyannote.audio relies on torchaudio for both).
+ """
+
+ import math
+ import random
+ import warnings
+
+ from pathlib import Path
+ from typing import Mapping, Optional, Text, Tuple, Union
+ import numpy as np
+
+ from pyannote_audio_utils.core import Segment
+
+ import ailia.audio
+ import soundfile
+
+ AudioFile = Union[Text, Path, Mapping]
+
+ AudioFileDocString = """
+ Audio files can be provided to the Audio class using different types:
+     - a "str" or "Path" instance: "audio.wav" or Path("audio.wav")
+     - a "IOBase" instance with "read" and "seek" support: open("audio.wav", "rb")
+     - a "Mapping" with any of the above as "audio" key: {"audio": ...}
+     - a "Mapping" with both "waveform" and "sample_rate" key:
+         {"waveform": (channel, time) numpy.ndarray, "sample_rate": 44100}
+
+ For last two options, an additional "channel" key can be provided as a zero-indexed
+ integer to load a specific channel: {"audio": "stereo.wav", "channel": 0}
+ """
+
+
+ class Audio:
+     """Audio IO
+
+     Parameters
+     ----------
+     sample_rate : int, optional
+         Target sampling rate. Defaults to using native sampling rate.
+     mono : {'random', 'downmix'}, optional
+         In case of multi-channel audio, convert to single-channel audio
+         using one of the following strategies: select one channel at
+         'random' or 'downmix' by averaging all channels.
+
+     Usage
+     -----
+     >>> audio = Audio(sample_rate=16000, mono='downmix')
+     >>> waveform, sample_rate = audio({"audio": "/path/to/audio.wav"})
+     >>> assert sample_rate == 16000
+     >>> sample_rate = 44100
+     >>> two_seconds_stereo = np.random.rand(2, 2 * sample_rate)
+     >>> waveform, sample_rate = audio({"waveform": two_seconds_stereo, "sample_rate": sample_rate})
+     >>> assert sample_rate == 16000
+     >>> assert waveform.shape[0] == 1
+     """
+
+     PRECISION = 0.001
+
+     @staticmethod
+     def validate_file(file: AudioFile) -> Mapping:
+         """Validate file for use with the other Audio methods
+
+         Parameter
+         ---------
+         file : AudioFile
+
+         Returns
+         -------
+         validated_file : Mapping
+             {"audio": str, "uri": str, ...}
+             {"waveform": array or tensor, "sample_rate": int, "uri": str, ...}
+             {"audio": file, "uri": "stream"} if `file` is an IOBase instance
+
+         Raises
+         ------
+         ValueError if file format is not valid or file does not exist.
+         """
+
+         if isinstance(file, Mapping):
+             pass
+         elif isinstance(file, (str, Path)):
+             file = {"audio": str(file), "uri": Path(file).stem}
+
+         # elif isinstance(file, IOBase):
+         #     return {"audio": file, "uri": "stream"}
+
+         # else:
+         #     raise ValueError(AudioFileDocString)
+
+         if "waveform" in file:
+             waveform: np.ndarray = file["waveform"]
+             if len(waveform.shape) != 2 or waveform.shape[0] > waveform.shape[1]:
+                 raise ValueError(
+                     "'waveform' must be provided as a (channel, time) np.ndarray."
+                 )
+
+             sample_rate: int = file.get("sample_rate", None)
+             if sample_rate is None:
+                 raise ValueError(
+                     "'waveform' must be provided with their 'sample_rate'."
+                 )
+
+             file.setdefault("uri", "waveform")
+
+         elif "audio" in file:
+             # if isinstance(file["audio"], IOBase):
+             #     return file
+
+             path = Path(file["audio"])
+             if not path.is_file():
+                 raise ValueError(f"File {path} does not exist")
+
+             file.setdefault("uri", path.stem)
+
+         else:
+             raise ValueError(
+                 "Neither 'waveform' nor 'audio' is available for this file."
+             )
+
+         return file
+
+     def __init__(self, sample_rate=None, mono=None):
+         super().__init__()
+         self.sample_rate = sample_rate
+         self.mono = mono
+
+     def downmix_and_resample(self, waveform: np.ndarray, sample_rate: int) -> Tuple[np.ndarray, int]:
+         """Downmix and resample
+
+         Parameters
+         ----------
+         waveform : (channel, time) np.ndarray
+             Waveform.
+         sample_rate : int
+             Sample rate.
+
+         Returns
+         -------
+         waveform : (channel, time) np.ndarray
+             Remixed and resampled waveform
+         sample_rate : int
+             New sample rate
+         """
+
+         # downmix to mono
+         num_channels = waveform.shape[0]
+         if num_channels > 1:
+             if self.mono == "random":
+                 channel = random.randint(0, num_channels - 1)
+                 waveform = waveform[channel : channel + 1]
+             elif self.mono == "downmix":
+                 waveform = np.mean(waveform, axis=0, keepdims=True)
+
+         ######## this is where the alignment can shift ##########
+         if (self.sample_rate is not None) and (self.sample_rate != sample_rate):
+             waveform = ailia.audio.resample(
+                 waveform, org_sr=sample_rate, target_sr=self.sample_rate)
+
+             sample_rate = self.sample_rate
+
+         return waveform, sample_rate
+
+     def get_num_samples(
+         self, duration: float, sample_rate: Optional[int] = None
+     ) -> int:
+         """Deterministic number of samples from duration and sample rate"""
+
+         sample_rate = sample_rate or self.sample_rate
+
+         if sample_rate is None:
+             raise ValueError(
+                 "`sample_rate` must be provided to compute number of samples."
+             )
+
+         return math.floor(duration * sample_rate)
+
+     def __call__(self, file: AudioFile) -> Tuple[np.ndarray, int]:
+         """Obtain waveform
+
+         Parameters
+         ----------
+         file : AudioFile
+
+         Returns
+         -------
+         waveform : (channel, time) np.ndarray
+             Waveform
+         sample_rate : int
+             Sample rate
+
+         See also
+         --------
+         AudioFile
+         """
+
+         file = self.validate_file(file)
+
+         if "waveform" in file:
+             waveform = file["waveform"]
+             sample_rate = file["sample_rate"]
+         else:
+             waveform, sample_rate = soundfile.read(file["audio"])
+
+             if waveform.ndim == 1:
+                 waveform = np.expand_dims(waveform, axis=0)
+             else:
+                 waveform = waveform.T
+
+         channel = file.get("channel", None)
+         if channel is not None:
+             waveform = waveform[channel : channel + 1]
+
+         return self.downmix_and_resample(waveform, sample_rate)
+
+     def crop(
+         self,
+         file: AudioFile,
+         segment: Segment,
+         duration: Optional[float] = None,
+         mode="raise",
+     ) -> Tuple[np.ndarray, int]:
+         """Fast version of self(file).crop(segment, **kwargs)
+
+         Parameters
+         ----------
+         file : AudioFile
+             Audio file.
+         segment : `pyannote.core.Segment`
+             Temporal segment to load.
+         duration : float, optional
+             Overrides `Segment` 'focus' duration and ensures that the number of
+             returned frames is fixed (which might otherwise not be the case
+             because of rounding errors).
+         mode : {'raise', 'pad'}, optional
+             Specifies how out-of-bounds segments will behave.
+             * 'raise' -- raise an error (default)
+             * 'pad' -- zero pad
+
+         Returns
+         -------
+         waveform : (channel, time) np.ndarray
+             Waveform
+         sample_rate : int
+             Sample rate
+         """
+
+         file = self.validate_file(file)
+
+         if "waveform" in file:
+             waveform = file["waveform"]
+             frames = waveform.shape[1]
+             sample_rate = file["sample_rate"]
+
+         elif "torchaudio.info" in file:
+             info = file["torchaudio.info"]
+             frames = info.num_frames
+             sample_rate = info.sample_rate
+
+         else:
+             info = soundfile.read(file["audio"])
+             frames = info[0].shape[0]
+             sample_rate = info[1]
+
+         channel = file.get("channel", None)
+
+         # infer which samples to load from sample rate and requested chunk
+         start_frame = math.floor(segment.start * sample_rate)
+
+         if duration:
+             num_frames = math.floor(duration * sample_rate)
+             end_frame = start_frame + num_frames
+
+         else:
+             end_frame = math.floor(segment.end * sample_rate)
+             num_frames = end_frame - start_frame
+
+         if mode == "pad":
+             pad_start = -min(0, start_frame)
+             pad_end = max(end_frame, frames) - frames
+             start_frame = max(0, start_frame)
+             end_frame = min(end_frame, frames)
+             num_frames = end_frame - start_frame
+
+         if "waveform" in file:
+             data = file["waveform"][:, start_frame:end_frame]
+
+         else:
+             try:
+                 data, _ = soundfile.read(file["audio"], start=start_frame, frames=num_frames)
+                 if data.ndim == 1:
+                     data = np.expand_dims(data, axis=0)
+                 else:
+                     data = data.T
+
+             except RuntimeError:
+                 msg = (
+                     f"soundfile failed to seek-and-read in {file['audio']}: "
+                     f"loading the whole file instead."
+                 )
+
+                 warnings.warn(msg)
+                 waveform, sample_rate = self.__call__(file)
+                 data = waveform[:, start_frame:end_frame]
+
+                 # storing waveform and sample_rate for next time
+                 # as it is very likely that seek-and-read will
+                 # fail again for this particular file
+                 file["waveform"] = waveform
+                 file["sample_rate"] = sample_rate
+
+         if channel is not None:
+             data = data[channel : channel + 1, :]
+
+         if mode == "pad":
+             data = np.pad(data, ((0, 0), (pad_start, pad_end)))
+
+         return self.downmix_and_resample(data, sample_rate)
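A short usage sketch of the ported `Audio` class above, matching its docstring (the path assumes the sample shipped in this commit):

```python
from pyannote_audio_utils.audio.core.io import Audio
from pyannote_audio_utils.core import Segment

audio = Audio(sample_rate=16000, mono="downmix")

# full file: (1, num_samples) float array, resampled to 16 kHz if needed
waveform, sample_rate = audio("data/sample.wav")

# first five seconds only, zero-padded if the file were shorter
excerpt, _ = audio.crop("data/sample.wav", Segment(0.0, 5.0), mode="pad")

print(waveform.shape, excerpt.shape, sample_rate)
```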
ailia-models/code/pyannote_audio_utils/audio/core/pipeline.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2021 CNRS
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import os
24
+ import warnings
25
+ from collections import OrderedDict
26
+ from collections.abc import Iterator
27
+ from functools import partial
28
+ from pathlib import Path
29
+ from typing import Callable, Dict, List, Optional, Text, Union
30
+
31
+
32
+ import yaml
33
+ from importlib import import_module
34
+ from pyannote_audio_utils.database import ProtocolFile
35
+ from pyannote_audio_utils.pipeline import Pipeline as _Pipeline
36
+
37
+ from pyannote_audio_utils.audio import Audio, __version__
38
+ from pyannote_audio_utils.audio.core.inference import BaseInference
39
+ from pyannote_audio_utils.audio.core.io import AudioFile
40
+
41
+ PIPELINE_PARAMS_NAME = "config.yaml"
42
+
43
+
44
+ class Pipeline(_Pipeline):
45
+ @classmethod
46
+ def from_pretrained(
47
+ cls,
48
+ checkpoint_path: Union[Text, Path],
49
+ hparams_file: Union[Text, Path] = None,
50
+ use_auth_token: Union[Text, None] = None,
51
+ ) -> "Pipeline":
52
+ """Load pretrained pipeline
53
+
54
+ Parameters
55
+ ----------
56
+ checkpoint_path : Path or str
57
+ Path to pipeline checkpoint, or a remote URL,
58
+ or a pipeline identifier from the huggingface.co model hub.
59
+ hparams_file: Path or str, optional
60
+ use_auth_token : str, optional
61
+ When loading a private huggingface.co pipeline, set `use_auth_token`
62
+ to True or to a string containing your huggingface.co authentication
63
+ token that can be obtained by running `huggingface-cli login`
64
+ cache_dir: Path or str, optional
65
+ Path to model cache directory. Defaults to "~/.cache/torch/pyannote_audio_utils" when unset.
66
+ """
67
+
68
+ checkpoint_path = str(checkpoint_path)
69
+ config_yml = checkpoint_path
70
+
71
+ with open(config_yml, "r") as fp:
72
+ config = yaml.load(fp, Loader=yaml.SafeLoader)
73
+
74
+ # initialize pipeline
75
+ pipeline_name = config["pipeline"]["name"]
76
+ tokens = pipeline_name.split('.')
77
+ module_name = '.'.join(tokens[:-1])
78
+ class_name = tokens[-1]
79
+ Klass = getattr(import_module(module_name), class_name)
80
+ params = config["pipeline"].get("params", {})
81
+ pipeline = Klass(**params)
82
+
83
+ # freeze parameters
84
+ if "params" in config:
85
+ pipeline.instantiate(config["params"])
86
+
87
+ return pipeline
88
+
89
+ def __init__(self):
90
+ super().__init__()
91
+ self._models: Dict[str, object] = OrderedDict()
92
+ self._inferences: Dict[str, BaseInference] = OrderedDict()
93
+
94
+ def __getattr__(self, name):
95
+ """(Advanced) attribute getter
96
+
97
+ Adds support for Model and Inference attributes,
98
+ which are iterated over by Pipeline.to() method.
99
+
100
+ See pyannote_audio_utils.pipeline.Pipeline.__getattr__.
101
+ """
102
+
103
+ if "_models" in self.__dict__:
104
+ _models = self.__dict__["_models"]
105
+ if name in _models:
106
+ return _models[name]
107
+
108
+ if "_inferences" in self.__dict__:
109
+ _inferences = self.__dict__["_inferences"]
110
+ if name in _inferences:
111
+ return _inferences[name]
112
+
113
+ return super().__getattr__(name)
114
+
115
+ def __setattr__(self, name, value):
116
+ """(Advanced) attribute setter
117
+
118
+ Adds support for Model and Inference attributes,
119
+ which are iterated over by Pipeline.to() method.
120
+
121
+ See pyannote_audio_utils.pipeline.Pipeline.__setattr__.
122
+ """
123
+
124
+ def remove_from(*dicts):
125
+ for d in dicts:
126
+ if name in d:
127
+ del d[name]
128
+
129
+ _parameters = self.__dict__.get("_parameters")
130
+ _instantiated = self.__dict__.get("_instantiated")
131
+ _pipelines = self.__dict__.get("_pipelines")
132
+ _models = self.__dict__.get("_models")
133
+ _inferences = self.__dict__.get("_inferences")
134
+
135
+
136
+
137
+ if isinstance(value, BaseInference):
138
+ if _inferences is None:
139
+ msg = "cannot assign inferences before Pipeline.__init__() call"
140
+ raise AttributeError(msg)
141
+ remove_from(self.__dict__, _models, _parameters, _instantiated, _pipelines)
142
+ _inferences[name] = value
143
+ return
144
+
145
+ super().__setattr__(name, value)
146
+
147
+ def __delattr__(self, name):
148
+ if name in self._models:
149
+ del self._models[name]
150
+
151
+ elif name in self._inferences:
152
+ del self._inferences[name]
153
+
154
+ else:
155
+ super().__delattr__(name)
156
+
157
+ @staticmethod
158
+ def setup_hook(file: AudioFile, hook: Optional[Callable] = None) -> Callable:
159
+ def noop(*args, **kwargs):
160
+ return
161
+
162
+ return partial(hook or noop, file=file)
163
+
164
+ def default_parameters(self):
165
+ raise NotImplementedError()
166
+
167
+ def classes(self) -> Union[List, Iterator]:
168
+ """Classes returned by the pipeline
169
+
170
+ Returns
171
+ -------
172
+ classes : list of string or string iterator
173
+ Finite list of strings when classes are known in advance
174
+ (e.g. ["MALE", "FEMALE"] for gender classification), or
175
+ infinite string iterator when they depend on the file
176
+ (e.g. "SPEAKER_00", "SPEAKER_01", ... for speaker diarization)
177
+
178
+ Usage
179
+ -----
180
+ >>> from collections.abc import Iterator
181
+ >>> classes = pipeline.classes()
182
+ >>> if isinstance(classes, Iterator): # classes depend on the input file
183
+ >>> if isinstance(classes, list): # classes are known in advance
184
+
185
+ """
186
+ raise NotImplementedError()
187
+
188
+ def __call__(self, file: AudioFile, **kwargs):
189
+ # breakpoint()
190
+ # fix_reproducibility(getattr(self, "device", torch.device("cpu")))
191
+
192
+ if not self.instantiated:
193
+ # instantiate with default parameters when available
194
+ try:
195
+ default_parameters = self.default_parameters()
196
+ except NotImplementedError:
197
+ raise RuntimeError(
198
+ "A pipeline must be instantiated with `pipeline.instantiate(parameters)` before it can be applied."
199
+ )
200
+
201
+ try:
202
+ self.instantiate(default_parameters)
203
+ except ValueError:
204
+ raise RuntimeError(
205
+ "A pipeline must be instantiated with `pipeline.instantiate(paramaters)` before it can be applied. "
206
+ "Tried to use parameters provided by `pipeline.default_parameters()` but those are not compatible. "
207
+ )
208
+
209
+ warnings.warn(
210
+ f"The pipeline has been automatically instantiated with {default_parameters}."
211
+ )
212
+
213
+ file = Audio.validate_file(file)
214
+
215
+ if hasattr(self, "preprocessors"):
216
+ file = ProtocolFile(file, lazy=self.preprocessors)
217
+
218
+ return self.apply(file, **kwargs)
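
`from_pretrained` above is driven entirely by a local `config.yaml` (the repository ships one; see `config.yaml` in the file list). A minimal sketch of the resolution logic it applies, using the same keys (`pipeline.name`, `pipeline.params`, `params`); `collections.OrderedDict` stands in for the real pipeline class only so the snippet runs anywhere without model files:

```python
from importlib import import_module
import yaml

# stand-in config: the real config.yaml points "name" at
# pyannote_audio_utils.audio.pipelines.SpeakerDiarization
CONFIG = """
pipeline:
  name: collections.OrderedDict
  params: {}
params: {}
"""

config = yaml.load(CONFIG, Loader=yaml.SafeLoader)

# same resolution as from_pretrained: split "a.b.C" into module and class
tokens = config["pipeline"]["name"].split(".")
Klass = getattr(import_module(".".join(tokens[:-1])), tokens[-1])
pipeline = Klass(**config["pipeline"].get("params", {}))
print(type(pipeline))  # <class 'collections.OrderedDict'>
```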
ailia-models/code/pyannote_audio_utils/audio/core/task.py ADDED
@@ -0,0 +1,125 @@
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2020- CNRS
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+
24
+ from __future__ import annotations
25
+
26
+ from dataclasses import dataclass
27
+ from enum import Enum
28
+ from functools import cached_property, partial
29
+ from typing import Dict, List, Literal, Optional, Sequence, Text, Tuple, Union
30
+
31
+ import scipy.special
32
+
33
+ from pyannote_audio_utils.database.protocol.protocol import Scope, Subset
34
+
35
+
36
+ Subsets = list(Subset.__args__)
37
+ Scopes = list(Scope.__args__)
38
+
39
+
40
+ # Type of machine learning problem
41
+ class Problem(Enum):
42
+ BINARY_CLASSIFICATION = 0
43
+ MONO_LABEL_CLASSIFICATION = 1
44
+ MULTI_LABEL_CLASSIFICATION = 2
45
+ REPRESENTATION = 3
46
+ REGRESSION = 4
47
+ # any other we could think of?
48
+
49
+
50
+ # A task takes an audio chunk as input and returns
51
+ # either a temporal sequence of predictions
52
+ # or just one prediction for the whole audio chunk
53
+ class Resolution(Enum):
54
+ FRAME = 1 # model outputs a sequence of frames
55
+ CHUNK = 2 # model outputs just one vector for the whole chunk
56
+
57
+
58
+ class UnknownSpecificationsError(Exception):
59
+ pass
60
+
61
+
62
+ @dataclass
63
+ class Specifications:
64
+ problem: Problem
65
+ resolution: Resolution
66
+
67
+ # (maximum) chunk duration in seconds
68
+ # duration: float
69
+ duration: float = 10.0
70
+
71
+ # (for variable-duration tasks only) minimum chunk duration in seconds
72
+ min_duration: Optional[float] = None
73
+
74
+ # use that many seconds on the left- and rightmost parts of each chunk
75
+ # to warm up the model. This is mostly useful for segmentation tasks.
76
+ # While the model does process those left- and right-most parts, only
77
+ # the remaining central part of each chunk is used for computing the
78
+ # loss during training, and for aggregating scores during inference.
79
+ # Defaults to 0. (i.e. no warm-up).
80
+ warm_up: Optional[Tuple[float, float]] = (0.0, 0.0)
81
+
82
+ # (for classification tasks only) list of classes
83
+ classes: Optional[List[Text]] = None
84
+ # classes: Optional[List[Text]] = ['speaker#1', 'speaker#2', 'speaker#3']
85
+
86
+ # (for powerset only) max number of simultaneous classes
87
+ # (n choose k with k <= powerset_max_classes)
88
+ # powerset_max_classes: Optional[int] = None
89
+ powerset_max_classes: Optional[int] = 2
90
+
91
+ # whether classes are permutation-invariant (e.g. diarization)
92
+ # permutation_invariant: bool = False
93
+ permutation_invariant: bool = True
94
+
95
+
96
+ @cached_property
97
+ def powerset(self) -> bool:
98
+ if self.powerset_max_classes is None:
99
+ return False
100
+
101
+ if self.problem != Problem.MONO_LABEL_CLASSIFICATION:
102
+ raise ValueError(
103
+ "`powerset_max_classes` only makes sense with multi-class classification problems."
104
+ )
105
+
106
+ return True
107
+
108
+ @cached_property
109
+ def num_powerset_classes(self) -> int:
110
+ # compute number of subsets of size at most "powerset_max_classes"
111
+ # e.g. with len(classes) = 3 and powerset_max_classes = 2:
112
+ # {}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2}
113
+ return int(
114
+ sum(
115
+ scipy.special.binom(len(self.classes), i)
116
+ for i in range(0, self.powerset_max_classes + 1)
117
+ )
118
+ )
119
+
120
+ def __len__(self):
121
+ return 1
122
+
123
+ def __iter__(self):
124
+ yield self
125
+
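
`num_powerset_classes` above is just a sum of binomial coefficients. A quick standalone check with the values from the inline comment (3 classes, at most 2 simultaneously active):

```python
import scipy.special

classes = ["speaker#1", "speaker#2", "speaker#3"]
powerset_max_classes = 2

# {}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2} -> 1 + 3 + 3 = 7
num_powerset_classes = int(
    sum(
        scipy.special.binom(len(classes), i)
        for i in range(0, powerset_max_classes + 1)
    )
)
print(num_powerset_classes)  # 7
```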
ailia-models/code/pyannote_audio_utils/audio/pipelines/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2020-2022 CNRS
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # from .multilabel import MultiLabelSegmentation
24
+ # from .overlapped_speech_detection import OverlappedSpeechDetection
25
+ # from .resegmentation import Resegmentation
26
+ from .speaker_diarization import SpeakerDiarization
27
+ # from .voice_activity_detection import VoiceActivityDetection
28
+
29
+ __all__ = [
30
+ # "VoiceActivityDetection",
31
+ # "OverlappedSpeechDetection",
32
+ "SpeakerDiarization",
33
+ # "Resegmentation",
34
+ # "MultiLabelSegmentation",
35
+ ]
ailia-models/code/pyannote_audio_utils/audio/pipelines/clustering.py ADDED
@@ -0,0 +1,468 @@
1
+ # The MIT License (MIT)
2
+ #
3
+ # Copyright (c) 2021- CNRS
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in
13
+ # all copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ """Clustering pipelines"""
24
+
25
+
26
+ import random
27
+ from enum import Enum
28
+ from typing import Optional, Tuple
29
+
30
+ import numpy as np
31
+ from pyannote_audio_utils.core import SlidingWindow, SlidingWindowFeature
32
+ from pyannote_audio_utils.pipeline import Pipeline
33
+ from pyannote_audio_utils.pipeline.parameter import Categorical, Integer, Uniform
34
+ from scipy.cluster.hierarchy import fcluster, linkage
35
+ from scipy.optimize import linear_sum_assignment
36
+ from scipy.spatial.distance import cdist
37
+
38
+
39
+ class BaseClustering(Pipeline):
40
+ def __init__(
41
+ self,
42
+ metric: str = "cosine",
43
+ max_num_embeddings: int = 1000,
44
+ constrained_assignment: bool = False,
45
+ ):
46
+ super().__init__()
47
+ self.metric = metric
48
+ self.max_num_embeddings = max_num_embeddings
49
+ self.constrained_assignment = constrained_assignment
50
+
51
+ def set_num_clusters(
52
+ self,
53
+ num_embeddings: int,
54
+ num_clusters: Optional[int] = None,
55
+ min_clusters: Optional[int] = None,
56
+ max_clusters: Optional[int] = None,
57
+ ):
58
+ min_clusters = num_clusters or min_clusters or 1
59
+ min_clusters = max(1, min(num_embeddings, min_clusters))
60
+ max_clusters = num_clusters or max_clusters or num_embeddings
61
+ max_clusters = max(1, min(num_embeddings, max_clusters))
62
+
63
+ if min_clusters > max_clusters:
64
+ raise ValueError(
65
+ f"min_clusters must be smaller than (or equal to) max_clusters "
66
+ f"(here: min_clusters={min_clusters:g} and max_clusters={max_clusters:g})."
67
+ )
68
+
69
+ if min_clusters == max_clusters:
70
+ num_clusters = min_clusters
71
+
72
+ return num_clusters, min_clusters, max_clusters
73
+
74
+ def filter_embeddings(
75
+ self,
76
+ embeddings: np.ndarray,
77
+ segmentations: Optional[SlidingWindowFeature] = None,
78
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
79
+ """Filter NaN embeddings and downsample embeddings
80
+
81
+ Parameters
82
+ ----------
83
+ embeddings : (num_chunks, num_speakers, dimension) array
84
+ Sequence of embeddings.
85
+ segmentations : (num_chunks, num_frames, num_speakers) array
86
+ Binary segmentations.
87
+
88
+ Returns
89
+ -------
90
+ filtered_embeddings : (num_embeddings, dimension) array
91
+ chunk_idx : (num_embeddings, ) array
92
+ speaker_idx : (num_embeddings, ) array
93
+ """
94
+
95
+ # whether speaker is active
96
+ active = np.sum(segmentations.data, axis=1) > 0
97
+ # whether speaker embedding extraction went fine
98
+ valid = ~np.any(np.isnan(embeddings), axis=2)
99
+
100
+ # indices of embeddings that are both active and valid
101
+ chunk_idx, speaker_idx = np.where(active * valid)
102
+
103
+ # sample max_num_embeddings embeddings
104
+ num_embeddings = len(chunk_idx)
105
+ if num_embeddings > self.max_num_embeddings:
106
+ indices = list(range(num_embeddings))
107
+ random.shuffle(indices)
108
+ indices = sorted(indices[: self.max_num_embeddings])
109
+ chunk_idx = chunk_idx[indices]
110
+ speaker_idx = speaker_idx[indices]
111
+
112
+ return embeddings[chunk_idx, speaker_idx], chunk_idx, speaker_idx
113
+
114
+ def constrained_argmax(self, soft_clusters: np.ndarray) -> np.ndarray:
115
+ soft_clusters = np.nan_to_num(soft_clusters, nan=np.nanmin(soft_clusters))
116
+ num_chunks, num_speakers, num_clusters = soft_clusters.shape
117
+ # num_chunks, num_speakers, num_clusters
118
+
119
+ hard_clusters = -2 * np.ones((num_chunks, num_speakers), dtype=np.int8)
120
+
121
+ for c, cost in enumerate(soft_clusters):
122
+ speakers, clusters = linear_sum_assignment(cost, maximize=True)
123
+ for s, k in zip(speakers, clusters):
124
+ hard_clusters[c, s] = k
125
+
126
+ return hard_clusters
127
+
128
+ def assign_embeddings(
129
+ self,
130
+ embeddings: np.ndarray,
131
+ train_chunk_idx: np.ndarray,
132
+ train_speaker_idx: np.ndarray,
133
+ train_clusters: np.ndarray,
134
+ constrained: bool = False,
135
+ ):
136
+ """Assign embeddings to the closest centroid
137
+
138
+ Cluster centroids are computed as the average of the train embeddings
139
+ previously assigned to them.
140
+
141
+ Parameters
142
+ ----------
143
+ embeddings : (num_chunks, num_speakers, dimension)-shaped array
144
+ Complete set of embeddings.
145
+ train_chunk_idx : (num_embeddings,)-shaped array
146
+ train_speaker_idx : (num_embeddings,)-shaped array
147
+ Indices of subset of embeddings used for "training".
148
+ train_clusters : (num_embedding,)-shaped array
149
+ Clusters of the above subset
150
+ constrained : bool, optional
151
+ Use constrained_argmax, instead of (default) argmax.
152
+
153
+ Returns
154
+ -------
155
+ soft_clusters : (num_chunks, num_speakers, num_clusters)-shaped array
156
+ hard_clusters : (num_chunks, num_speakers)-shaped array
157
+ centroids : (num_clusters, dimension)-shaped array
158
+ Clusters centroids
159
+ """
160
+
161
+ # TODO: option to add a new (dummy) cluster in case num_clusters < max(frame_speaker_count)
162
+
163
+ num_clusters = np.max(train_clusters) + 1
164
+ num_chunks, num_speakers, dimension = embeddings.shape
165
+
166
+ train_embeddings = embeddings[train_chunk_idx, train_speaker_idx]
167
+
168
+ centroids = np.vstack(
169
+ [
170
+ np.mean(train_embeddings[train_clusters == k], axis=0)
171
+ for k in range(num_clusters)
172
+ ]
173
+ )
174
+
175
+ e2k_distance = cdist(
176
+ embeddings.reshape([-1, dimension]),
177
+ centroids,
178
+ metric=self.metric
179
+ ).reshape([num_chunks, num_speakers, -1])
180
+
181
+ soft_clusters = 2 - e2k_distance
182
+
183
+ # assign each embedding to the cluster with the most similar centroid
184
+ if constrained:
185
+ hard_clusters = self.constrained_argmax(soft_clusters)
186
+ else:
187
+ hard_clusters = np.argmax(soft_clusters, axis=2)
188
+
189
+ # NOTE: train_embeddings might be reassigned to a different cluster
190
+ # in the process. based on experiments, this seems to lead to better
191
+ # results than sticking to the original assignment.
192
+
193
+ return hard_clusters, soft_clusters, centroids
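
The similarity trick in `assign_embeddings` is worth spelling out: cosine distances live in [0, 2], so `2 - distance` turns them into similarities and `argmax` picks the closest centroid. A standalone sketch with random data (shapes only; the arrays are not real embeddings):

```python
import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
num_chunks, num_speakers, dimension, num_clusters = 4, 3, 8, 2

embeddings = rng.standard_normal((num_chunks, num_speakers, dimension))
centroids = rng.standard_normal((num_clusters, dimension))

# (num_chunks * num_speakers, num_clusters) cosine distances ...
e2k_distance = cdist(embeddings.reshape(-1, dimension), centroids, metric="cosine")
# ... turned into similarities and reshaped back per chunk/speaker
soft_clusters = (2 - e2k_distance).reshape(num_chunks, num_speakers, num_clusters)
hard_clusters = np.argmax(soft_clusters, axis=2)
print(hard_clusters.shape)  # (4, 3)
```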
194
+
195
+ def __call__(
196
+ self,
197
+ embeddings: np.ndarray,
198
+ segmentations: Optional[SlidingWindowFeature] = None,
199
+ num_clusters: Optional[int] = None,
200
+ min_clusters: Optional[int] = None,
201
+ max_clusters: Optional[int] = None,
202
+ **kwargs,
203
+ ) -> np.ndarray:
204
+ """Apply clustering
205
+
206
+ Parameters
207
+ ----------
208
+ embeddings : (num_chunks, num_speakers, dimension) array
209
+ Sequence of embeddings.
210
+ segmentations : (num_chunks, num_frames, num_speakers) array
211
+ Binary segmentations.
212
+ num_clusters : int, optional
213
+ Number of clusters, when known. Default behavior is to use
214
+ internal threshold hyper-parameter to decide on the number
215
+ of clusters.
216
+ min_clusters : int, optional
217
+ Minimum number of clusters. Has no effect when `num_clusters` is provided.
218
+ max_clusters : int, optional
219
+ Maximum number of clusters. Has no effect when `num_clusters` is provided.
220
+
221
+ Returns
222
+ -------
223
+ hard_clusters : (num_chunks, num_speakers) array
224
+ Hard cluster assignment (hard_clusters[c, s] = k means that sth speaker
225
+ of cth chunk is assigned to kth cluster)
226
+ soft_clusters : (num_chunks, num_speakers, num_clusters) array
227
+ Soft cluster assignment (the higher soft_clusters[c, s, k], the more likely
228
+ the sth speaker of cth chunk belongs to kth cluster)
229
+ centroids : (num_clusters, dimension) array
230
+ Centroid vectors of each cluster
231
+ """
232
+
233
+ train_embeddings, train_chunk_idx, train_speaker_idx = self.filter_embeddings(
234
+ embeddings,
235
+ segmentations=segmentations,
236
+ )
237
+
238
+ num_embeddings, _ = train_embeddings.shape
239
+
240
+ num_clusters, min_clusters, max_clusters = self.set_num_clusters(
241
+ num_embeddings,
242
+ num_clusters=num_clusters,
243
+ min_clusters=min_clusters,
244
+ max_clusters=max_clusters,
245
+ )
246
+
247
+ if max_clusters < 2:
248
+ # do NOT apply clustering when min_clusters = max_clusters = 1
249
+ num_chunks, num_speakers, _ = embeddings.shape
250
+ hard_clusters = np.zeros((num_chunks, num_speakers), dtype=np.int8)
251
+ soft_clusters = np.ones((num_chunks, num_speakers, 1))
252
+ centroids = np.mean(train_embeddings, axis=0, keepdims=True)
253
+ return hard_clusters, soft_clusters, centroids
254
+
255
+ train_clusters = self.cluster(
256
+ train_embeddings,
257
+ min_clusters,
258
+ max_clusters,
259
+ num_clusters=num_clusters,
260
+ )
261
+
262
+ hard_clusters, soft_clusters, centroids = self.assign_embeddings(
263
+ embeddings,
264
+ train_chunk_idx,
265
+ train_speaker_idx,
266
+ train_clusters,
267
+ constrained=self.constrained_assignment,
268
+ )
269
+
270
+ return hard_clusters, soft_clusters, centroids
271
+
272
+
273
+ class AgglomerativeClustering(BaseClustering):
274
+ """Agglomerative clustering
275
+
276
+ Parameters
277
+ ----------
278
+ metric : {"cosine", "euclidean", ...}, optional
279
+ Distance metric to use. Defaults to "cosine".
280
+
281
+ Hyper-parameters
282
+ ----------------
283
+ method : {"average", "centroid", "complete", "median", "single", "ward"}
284
+ Linkage method.
285
+ threshold : float in range [0.0, 2.0]
286
+ Clustering threshold.
287
+ min_cluster_size : int in range [1, 20]
288
+ Minimum cluster size
289
+ """
290
+
291
+ def __init__(
292
+ self,
293
+ metric: str = "cosine",
294
+ max_num_embeddings: int = np.inf,
295
+ constrained_assignment: bool = False,
296
+ ):
297
+ super().__init__(
298
+ metric=metric,
299
+ max_num_embeddings=max_num_embeddings,
300
+ constrained_assignment=constrained_assignment,
301
+ )
302
+
303
+ self.threshold = Uniform(0.0, 2.0) # assume unit-normalized embeddings
304
+ self.method = Categorical(
305
+ ["average", "centroid", "complete", "median", "single", "ward", "weighted"]
306
+ )
307
+
308
+ # minimum cluster size
309
+ self.min_cluster_size = Integer(1, 20)
310
+
311
+
312
+ def cluster(
313
+ self,
314
+ embeddings: np.ndarray,
315
+ min_clusters: int,
316
+ max_clusters: int,
317
+ num_clusters: Optional[int] = None,
318
+ ):
319
+ """
320
+
321
+ Parameters
322
+ ----------
323
+ embeddings : (num_embeddings, dimension) array
324
+ Embeddings
325
+ min_clusters : int
326
+ Minimum number of clusters
327
+ max_clusters : int
328
+ Maximum number of clusters
329
+ num_clusters : int, optional
330
+ Actual number of clusters. Default behavior is to estimate it based
331
+ on values provided for `min_clusters`, `max_clusters`, and `threshold`.
332
+
333
+ Returns
334
+ -------
335
+ clusters : (num_embeddings, ) array
336
+ 0-indexed cluster indices.
337
+ """
338
+
339
+ num_embeddings, _ = embeddings.shape
340
+
341
+ # heuristic to reduce self.min_cluster_size when num_embeddings is very small
342
+ # (0.1 value is kind of arbitrary, though)
343
+ min_cluster_size = min(
344
+ self.min_cluster_size, max(1, round(0.1 * num_embeddings))
345
+ )
346
+
347
+
348
+ # linkage function will complain when there is just one embedding to cluster
349
+ if num_embeddings == 1:
350
+ return np.zeros((1,), dtype=np.uint8)
351
+
352
+ # centroid, median, and Ward method only support "euclidean" metric
353
+ # therefore we unit-normalize embeddings to somehow make them "euclidean"
354
+ if self.metric == "cosine" and self.method in ["centroid", "median", "ward"]:
355
+ with np.errstate(divide="ignore", invalid="ignore"):
356
+ embeddings /= np.linalg.norm(embeddings, axis=-1, keepdims=True)
357
+ dendrogram: np.ndarray = linkage(
358
+ embeddings, method=self.method, metric="euclidean"
359
+ )
360
+
361
+ # other methods work just fine with any metric
362
+ else:
363
+ dendrogram: np.ndarray = linkage(
364
+ embeddings, method=self.method, metric=self.metric
365
+ )
366
+
367
+ # apply the predefined threshold
368
+ clusters = fcluster(dendrogram, self.threshold, criterion="distance") - 1
369
+
370
+ # split clusters into two categories based on their number of items:
371
+ # large clusters vs. small clusters
372
+ cluster_unique, cluster_counts = np.unique(
373
+ clusters,
374
+ return_counts=True,
375
+ )
376
+ large_clusters = cluster_unique[cluster_counts >= min_cluster_size]
377
+ num_large_clusters = len(large_clusters)
378
+
379
+ # force num_clusters to min_clusters in case the actual number is too small
380
+ if num_large_clusters < min_clusters:
381
+ num_clusters = min_clusters
382
+
383
+ # force num_clusters to max_clusters in case the actual number is too large
384
+ elif num_large_clusters > max_clusters:
385
+ num_clusters = max_clusters
386
+
387
+ # look for perfect candidate if necessary
388
+ if num_clusters is not None and num_large_clusters != num_clusters:
389
+ # switch stopping criterion from "inter-cluster distance" stopping to "iteration index"
390
+ _dendrogram = np.copy(dendrogram)
391
+ _dendrogram[:, 2] = np.arange(num_embeddings - 1)
392
+
393
+ best_iteration = num_embeddings - 1
394
+ best_num_large_clusters = 1
395
+
396
+ # traverse the dendrogram by going further and further away
397
+ # from the "optimal" threshold
398
+
399
+ for iteration in np.argsort(np.abs(dendrogram[:, 2] - self.threshold)):
400
+ # only consider iterations that might have resulted
401
+ # in changing the number of (large) clusters
402
+ new_cluster_size = _dendrogram[iteration, 3]
403
+ if new_cluster_size < min_cluster_size:
404
+ continue
405
+
406
+ # estimate number of large clusters at considered iteration
407
+ clusters = fcluster(_dendrogram, iteration, criterion="distance") - 1
408
+ cluster_unique, cluster_counts = np.unique(clusters, return_counts=True)
409
+ large_clusters = cluster_unique[cluster_counts >= min_cluster_size]
410
+ num_large_clusters = len(large_clusters)
411
+
412
+ # keep track of iteration that leads to the number of large clusters
413
+ # as close as possible to the target number of clusters.
414
+ if abs(num_large_clusters - num_clusters) < abs(
415
+ best_num_large_clusters - num_clusters
416
+ ):
417
+ best_iteration = iteration
418
+ best_num_large_clusters = num_large_clusters
419
+
420
+ # stop traversing the dendrogram as soon as we found a good candidate
421
+ if num_large_clusters == num_clusters:
422
+ break
423
+
424
+ # re-apply best iteration in case we did not find a perfect candidate
425
+ if best_num_large_clusters != num_clusters:
426
+ clusters = (
427
+ fcluster(_dendrogram, best_iteration, criterion="distance") - 1
428
+ )
429
+ cluster_unique, cluster_counts = np.unique(clusters, return_counts=True)
430
+ large_clusters = cluster_unique[cluster_counts >= min_cluster_size]
431
+ num_large_clusters = len(large_clusters)
432
+ print(
433
+ f"Found only {num_large_clusters} clusters. Using a smaller value than {min_cluster_size} for `min_cluster_size` might help."
434
+ )
435
+
436
+ if num_large_clusters == 0:
437
+ clusters[:] = 0
438
+ return clusters
439
+
440
+ small_clusters = cluster_unique[cluster_counts < min_cluster_size]
441
+ if len(small_clusters) == 0:
442
+ return clusters
443
+
444
+ # re-assign each small cluster to the most similar large cluster based on their respective centroids
445
+ large_centroids = np.vstack(
446
+ [
447
+ np.mean(embeddings[clusters == large_k], axis=0)
448
+ for large_k in large_clusters
449
+ ]
450
+ )
451
+ small_centroids = np.vstack(
452
+ [
453
+ np.mean(embeddings[clusters == small_k], axis=0)
454
+ for small_k in small_clusters
455
+ ]
456
+ )
457
+ centroids_cdist = cdist(large_centroids, small_centroids, metric=self.metric)
458
+ for small_k, large_k in enumerate(np.argmin(centroids_cdist, axis=0)):
459
+ clusters[clusters == small_clusters[small_k]] = large_clusters[large_k]
460
+
461
+ # re-number clusters from 0 to num_large_clusters
462
+ _, clusters = np.unique(clusters, return_inverse=True)
463
+ return clusters
464
+
465
+
466
+ class Clustering(Enum):
467
+ AgglomerativeClustering = AgglomerativeClustering
468
+
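
At its core, `AgglomerativeClustering.cluster` above rests on scipy's `linkage`/`fcluster` pair: build a dendrogram, cut it at a distance threshold, and 0-index the labels. A minimal standalone illustration on toy 2-D points (the 1.0 threshold is arbitrary):

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage

points = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])

dendrogram = linkage(points, method="centroid", metric="euclidean")
clusters = fcluster(dendrogram, 1.0, criterion="distance") - 1  # 0-indexed
print(clusters)  # two clusters, e.g. [0 0 1 1]
```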
ailia-models/code/pyannote_audio_utils/audio/pipelines/speaker_diarization.py ADDED
@@ -0,0 +1,553 @@
1
+ # The MIT License (MIT)
2
+ #
3
+ # Copyright (c) 2021- CNRS
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in
13
+ # all copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ """Speaker diarization pipelines"""
24
+
25
+ import functools
26
+ import itertools
27
+ import math
28
+ import textwrap
29
+ import warnings
30
+ import numpy as np
31
+
32
+ from typing import Callable, Optional, Text, Union, Mapping
33
+ from pathlib import Path
34
+
35
+ from pyannote_audio_utils.core import Annotation, SlidingWindow, SlidingWindowFeature
36
+ from pyannote_audio_utils.pipeline.parameter import ParamDict, Uniform
37
+ from pyannote_audio_utils.audio import Audio, Inference, Pipeline
38
+ from pyannote_audio_utils.audio.core.io import AudioFile
39
+ from pyannote_audio_utils.audio.pipelines.clustering import Clustering
40
+ from pyannote_audio_utils.audio.pipelines.speaker_verification import ONNXWeSpeakerPretrainedSpeakerEmbedding
41
+ from pyannote_audio_utils.audio.pipelines.utils import SpeakerDiarizationMixin
42
+
43
+ AudioFile = Union[Text, Path, Mapping]
44
+ PipelineModel = Union[Text, Mapping]
45
+
46
+ def batchify(iterable, batch_size: int = 32, fillvalue=None):
47
+ """Batchify iterable"""
48
+ # batchify('ABCDEFG', 3) --> ('A', 'B', 'C') ('D', 'E', 'F') ('G', None, None)
49
+ args = [iter(iterable)] * batch_size
50
+ return itertools.zip_longest(*args, fillvalue=fillvalue)
51
+
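
For instance, the trailing batch is padded with `fillvalue`, which `get_embeddings` later strips out with `filter(lambda b: b[0] is not None, batch)`:

```python
import itertools

def batchify(iterable, batch_size=3, fillvalue=None):
    args = [iter(iterable)] * batch_size
    return itertools.zip_longest(*args, fillvalue=fillvalue)

print(list(batchify("ABCDEFG")))
# [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', None, None)]
```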
52
+
53
+ class SpeakerDiarization(SpeakerDiarizationMixin, Pipeline):
54
+ """Speaker diarization pipeline
55
+
56
+ Parameters
57
+ ----------
58
+ segmentation : Model, str, or dict, optional
59
+ Pretrained segmentation model. Defaults to "pyannote_audio_utils/segmentation@2022.07".
60
+ See pyannote_audio_utils.audio.pipelines.utils.get_model for supported format.
61
+ segmentation_step: float, optional
62
+ The segmentation model is applied on a window sliding over the whole audio file.
63
+ `segmentation_step` controls the step of this window, provided as a ratio of its
64
+ duration. Defaults to 0.1 (i.e. 90% overlap between two consecutive windows).
65
+ embedding : Model, str, or dict, optional
66
+ Pretrained embedding model. Defaults to "pyannote_audio_utils/embedding@2022.07".
67
+ See pyannote_audio_utils.audio.pipelines.utils.get_model for supported format.
68
+ embedding_exclude_overlap : bool, optional
69
+ Exclude overlapping speech regions when extracting embeddings.
70
+ Defaults (False) to use the whole speech.
71
+ clustering : str, optional
72
+ Clustering algorithm. See pyannote_audio_utils.audio.pipelines.clustering.Clustering
73
+ for available options. Defaults to "AgglomerativeClustering".
74
+ segmentation_batch_size : int, optional
75
+ Batch size used for speaker segmentation. Defaults to 1.
76
+ embedding_batch_size : int, optional
77
+ Batch size used for speaker embedding. Defaults to 1.
78
+ der_variant : dict, optional
79
+ Optimize for a variant of diarization error rate.
80
+ Defaults to {"collar": 0.0, "skip_overlap": False}. This is used in `get_metric`
81
+ when instantiating the metric: GreedyDiarizationErrorRate(**der_variant).
82
+ use_auth_token : str, optional
83
+ When loading private huggingface.co models, set `use_auth_token`
84
+ to True or to a string containing your huggingface.co authentication
85
+ token that can be obtained by running `huggingface-cli login`
86
+
87
+ Usage
88
+ -----
89
+ # perform (unconstrained) diarization
90
+ >>> diarization = pipeline("/path/to/audio.wav")
91
+
92
+ # perform diarization, targetting exactly 4 speakers
93
+ >>> diarization = pipeline("/path/to/audio.wav", num_speakers=4)
94
+
95
+ # perform diarization, with at least 2 speakers and at most 10 speakers
96
+ >>> diarization = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10)
97
+
98
+ # perform diarization and get one representative embedding per speaker
99
+ >>> diarization, embeddings = pipeline("/path/to/audio.wav", return_embeddings=True)
100
+ >>> for s, speaker in enumerate(diarization.labels()):
101
+ ... # embeddings[s] is the embedding of speaker `speaker`
102
+
103
+ Hyper-parameters
104
+ ----------------
105
+ segmentation.threshold
106
+ segmentation.min_duration_off
107
+ clustering.???
108
+ """
109
+
110
+ def __init__(
111
+ self,
112
+ segmentation: PipelineModel = "pyannote_audio_utils/segmentation@2022.07",
113
+ segmentation_step: float = 0.1,
114
+ embedding: PipelineModel = "speechbrain/spkrec-ecapa-voxceleb@5c0be3875fda05e81f3c004ed8c7c06be308de1e",
115
+ embedding_exclude_overlap: bool = False,
116
+ clustering: str = "AgglomerativeClustering",
117
+ embedding_batch_size: int = 1,
118
+ segmentation_batch_size: int = 1,
119
+ args = None,
120
+ seg_path = None,
121
+ emb_path = None,
122
+ der_variant: dict = None,
123
+ use_auth_token: Union[Text, None] = None,
124
+ ):
125
+ super().__init__()
126
+
127
+ model = segmentation
128
+ self.segmentation_step = segmentation_step
129
+ self.embedding = embedding
130
+ self.embedding_batch_size = embedding_batch_size
131
+ self.embedding_exclude_overlap = embedding_exclude_overlap
132
+ self.klustering = clustering
133
+ self.der_variant = der_variant or {"collar": 0.0, "skip_overlap": False}
134
+
135
+ segmentation_duration = 10.0
136
+
137
+ self._segmentation = Inference(
138
+ model,
139
+ duration=segmentation_duration,
140
+ step=self.segmentation_step * segmentation_duration,
141
+ skip_aggregation=True,
142
+ batch_size=segmentation_batch_size,
143
+ args=args,
144
+ seg_path=seg_path
145
+ )
146
+
147
+ self._frames: SlidingWindow = self._segmentation.example_output.frames
148
+
149
+ self.segmentation = ParamDict(
150
+ min_duration_off=Uniform(0.0, 1.0),
151
+ )
152
+
153
+ self._embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding(
154
+ self.embedding,
155
+ args=args,
156
+ emb_path=emb_path
157
+ )
158
+ self._audio = Audio(sample_rate=self._embedding.sample_rate, mono="downmix")
159
+
160
+ metric = self._embedding.metric
161
+ Klustering = Clustering[clustering]
162
+
163
+ self.clustering = Klustering.value(metric=metric)
164
+
165
+
166
+ def get_segmentations(self, file, hook=None) -> SlidingWindowFeature:
167
+ """Apply segmentation model
168
+
169
+ Parameter
170
+ ---------
171
+ file : AudioFile
172
+ hook : Optional[Callable]
173
+
174
+ Returns
175
+ -------
176
+ segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
177
+ """
178
+
179
+ if hook is not None:
180
+ hook = functools.partial(hook, "segmentation", None)
181
+ segmentations: SlidingWindowFeature = self._segmentation(file, hook=hook)
182
+
183
+ return segmentations
184
+
185
+ def get_embeddings(
186
+ self,
187
+ file,
188
+ binary_segmentations: SlidingWindowFeature,
189
+ exclude_overlap: bool = False,
190
+ hook: Optional[Callable] = None,
191
+ ):
192
+ """Extract embeddings for each (chunk, speaker) pair
193
+
194
+ Parameters
195
+ ----------
196
+ file : AudioFile
197
+ binary_segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
198
+ Binarized segmentation.
199
+ exclude_overlap : bool, optional
200
+ Exclude overlapping speech regions when extracting embeddings.
201
+ In case non-overlapping speech is too short, use the whole speech.
202
+ hook: Optional[Callable]
203
+ Called during embeddings after every batch to report the progress
204
+
205
+ Returns
206
+ -------
207
+ embeddings : (num_chunks, num_speakers, dimension) array
208
+ """
209
+
210
+ # when optimizing the hyper-parameters of this pipeline with frozen
211
+ # "segmentation.threshold", one can reuse the embeddings from the first trial,
212
+ # bringing a massive speed up to the optimization process (and hence allowing to use
213
+ # a larger search space).
214
+
215
+ duration = binary_segmentations.sliding_window.duration
216
+ num_chunks, num_frames, num_speakers = binary_segmentations.data.shape
217
+
218
+ if exclude_overlap:
219
+
220
+ # minimum number of samples needed to extract an embedding
221
+ # (a lower number of samples would result in an error)
222
+ min_num_samples = self._embedding.min_num_samples
223
+
224
+ # corresponding minimum number of frames
225
+ num_samples = duration * self._embedding.sample_rate
226
+ min_num_frames = math.ceil(num_frames * min_num_samples / num_samples)
227
+
228
+ # zero-out frames with overlapping speech
229
+ clean_frames = 1.0 * (
230
+ np.sum(binary_segmentations.data, axis=2, keepdims=True) < 2
231
+ )
232
+ clean_segmentations = SlidingWindowFeature(
233
+ binary_segmentations.data * clean_frames,
234
+ binary_segmentations.sliding_window,
235
+ )
236
+
237
+ else:
238
+ min_num_frames = -1
239
+ clean_segmentations = SlidingWindowFeature(
240
+ binary_segmentations.data, binary_segmentations.sliding_window
241
+ )
242
+
243
+ def iter_waveform_and_mask():
244
+ for (chunk, masks), (_, clean_masks) in zip(binary_segmentations, clean_segmentations):
245
+ # chunk: Segment(t, t + duration)
246
+ # masks: (num_frames, local_num_speakers) np.ndarray
247
+
248
+ waveform, _ = self._audio.crop(
249
+ file,
250
+ chunk,
251
+ duration=duration,
252
+ mode="pad",
253
+ )
254
+ # waveform: (1, num_samples) torch.Tensor
255
+
256
+ # mask may contain NaN (in case of partial stitching)
257
+ masks = np.nan_to_num(masks, nan=0.0).astype(np.float32)
258
+ clean_masks = np.nan_to_num(clean_masks, nan=0.0).astype(np.float32)
259
+
260
+ for mask, clean_mask in zip(masks.T, clean_masks.T):
261
+ # mask: (num_frames, ) np.ndarray
262
+
263
+ if np.sum(clean_mask) > min_num_frames:
264
+ used_mask = clean_mask
265
+ else:
266
+ used_mask = mask
267
+
268
+ # yield waveform[None], torch.from_numpy(used_mask)[None]
269
+ yield waveform[None], used_mask[None]
270
+
271
+ # w: (1, 1, num_samples) torch.Tensor
272
+ # m: (1, num_frames) torch.Tensor
273
+
274
+ batches = batchify(
275
+ iter_waveform_and_mask(),
276
+ batch_size=self.embedding_batch_size,
277
+ fillvalue=(None, None),
278
+ )
279
+
280
+
281
+ batch_count = math.ceil(num_chunks * num_speakers / self.embedding_batch_size)
282
+
283
+ embedding_batches = []
284
+
285
+ if hook is not None:
286
+ hook("embeddings", None, total=batch_count, completed=0)
287
+
288
+ for i, batch in enumerate(batches, 1):
289
+ waveforms, masks = zip(*filter(lambda b: b[0] is not None, batch))
290
+
291
+ waveform_batch = np.vstack(waveforms)
292
+ # (batch_size, 1, num_samples) torch.Tensor
293
+
294
+ mask_batch = np.vstack(masks)
295
+ # (batch_size, num_frames) torch.Tensor
296
+
297
+ embedding_batch: np.ndarray = self._embedding(
298
+ waveform_batch, masks=mask_batch
299
+ )
300
+ # (batch_size, dimension) np.ndarray
301
+
302
+ embedding_batches.append(embedding_batch)
303
+
304
+ if hook is not None:
305
+ hook("embeddings", embedding_batch, total=batch_count, completed=i)
306
+
307
+ embedding_batches = np.vstack(embedding_batches)
308
+ embeddings = embedding_batches.reshape([num_chunks, -1 , embedding_batches.shape[-1]])
309
+
310
+ return embeddings
311
+
312
+ def reconstruct(
313
+ self,
314
+ segmentations: SlidingWindowFeature,
315
+ hard_clusters: np.ndarray,
316
+ count: SlidingWindowFeature,
317
+ ) -> SlidingWindowFeature:
318
+ """Build final discrete diarization out of clustered segmentation
319
+
320
+ Parameters
321
+ ----------
322
+ segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
323
+ Raw speaker segmentation.
324
+ hard_clusters : (num_chunks, num_speakers) array
325
+ Output of clustering step.
326
+ count : (total_num_frames, 1) SlidingWindowFeature
327
+ Instantaneous number of active speakers.
328
+
329
+ Returns
330
+ -------
331
+ discrete_diarization : SlidingWindowFeature
332
+ Discrete (0s and 1s) diarization.
333
+ """
334
+
335
+ num_chunks, num_frames, local_num_speakers = segmentations.data.shape
336
+
337
+ num_clusters = np.max(hard_clusters) + 1
338
+ clustered_segmentations = np.nan * np.zeros(
339
+ (num_chunks, num_frames, num_clusters)
340
+ )
341
+
342
+ for c, (cluster, (chunk, segmentation)) in enumerate(
343
+ zip(hard_clusters, segmentations)
344
+ ):
345
+ # cluster is (local_num_speakers, )-shaped
346
+ # segmentation is (num_frames, local_num_speakers)-shaped
347
+ for k in np.unique(cluster):
348
+ if k == -2:
349
+ continue
350
+
351
+ # TODO: can we do better than this max here?
352
+ clustered_segmentations[c, :, k] = np.max(
353
+ segmentation[:, cluster == k], axis=1
354
+ )
355
+
356
+ clustered_segmentations = SlidingWindowFeature(
357
+ clustered_segmentations, segmentations.sliding_window
358
+ )
359
+
360
+ return self.to_diarization(clustered_segmentations, count)
361
+
362
+ def apply(
363
+ self,
364
+ file: AudioFile,
365
+ num_speakers: int = None,
366
+ min_speakers: int = None,
367
+ max_speakers: int = None,
368
+ return_embeddings: bool = False,
369
+ hook: Optional[Callable] = None,
370
+ ) -> Annotation:
371
+ """Apply speaker diarization
372
+
373
+ Parameters
374
+ ----------
375
+ file : AudioFile
376
+ Processed file.
377
+ num_speakers : int, optional
378
+ Number of speakers, when known.
379
+ min_speakers : int, optional
380
+ Minimum number of speakers. Has no effect when `num_speakers` is provided.
381
+ max_speakers : int, optional
382
+ Maximum number of speakers. Has no effect when `num_speakers` is provided.
383
+ return_embeddings : bool, optional
384
+ Return representative speaker embeddings.
385
+ hook : callable, optional
386
+ Callback called after each major steps of the pipeline as follows:
387
+ hook(step_name, # human-readable name of current step
388
+ step_artefact, # artifact generated by current step
389
+ file=file) # file being processed
390
+ Time-consuming steps call `hook` multiple times with the same `step_name`
391
+ and additional `completed` and `total` keyword arguments usable to track
392
+ progress of current step.
393
+
394
+ Returns
395
+ -------
396
+ diarization : Annotation
397
+ Speaker diarization
398
+ embeddings : np.array, optional
399
+ Representative speaker embeddings such that `embeddings[i]` is the
400
+ speaker embedding for i-th speaker in diarization.labels().
401
+ Only returned when `return_embeddings` is True.
402
+ """
403
+
404
+ # setup hook (e.g. for debugging purposes)
405
+ hook = self.setup_hook(file, hook=hook)
406
+
407
+ num_speakers, min_speakers, max_speakers = self.set_num_speakers(
408
+ num_speakers=num_speakers,
409
+ min_speakers=min_speakers,
410
+ max_speakers=max_speakers,
411
+ )
412
+
413
+ segmentations = self.get_segmentations(file, hook=hook)
414
+ hook("segmentation", segmentations)
415
+ # shape: (num_chunks, num_frames, local_num_speakers)
416
+
417
+ # binarize segmentation
418
+
419
+ binarized_segmentations = segmentations
420
+ # estimate frame-level number of instantaneous speakers
421
+ count = self.speaker_count(
422
+ binarized_segmentations,
423
+ frames=self._frames,
424
+ warm_up=(0.0, 0.0),
425
+ )
426
+ hook("speaker_counting", count)
427
+ # shape: (num_frames, 1)
428
+ # dtype: int
429
+
430
+ # exit early when no speaker is ever active
431
+ if np.nanmax(count.data) == 0.0:
432
+ diarization = Annotation(uri=file["uri"])
433
+ if return_embeddings:
434
+ return diarization, np.zeros((0, self._embedding.dimension))
435
+
436
+ return diarization
437
+
438
+ embeddings = self.get_embeddings(
439
+ file,
440
+ binarized_segmentations,
441
+ exclude_overlap=self.embedding_exclude_overlap,
442
+ hook=hook,
443
+ )
444
+
445
+ hook("embeddings", embeddings)
446
+ # shape: (num_chunks, local_num_speakers, dimension)
447
+
448
+ hard_clusters, _, centroids = self.clustering(
449
+ embeddings=embeddings,
450
+ segmentations=binarized_segmentations,
451
+ num_clusters=num_speakers,
452
+ min_clusters=min_speakers,
453
+ max_clusters=max_speakers,
454
+ file=file, # <== for oracle clustering
455
+ frames=self._frames, # <== for oracle clustering
456
+ )
457
+ # hard_clusters: (num_chunks, num_speakers)
458
+ # centroids: (num_speakers, dimension)
459
+
460
+ # number of detected clusters is the number of different speakers
461
+ num_different_speakers = np.max(hard_clusters) + 1
462
+
463
+ # detected number of speakers can still be out of bounds
464
+ # (specifically, lower than `min_speakers`), since there could be too few embeddings
465
+ # to make enough clusters with a given minimum cluster size.
466
+ if num_different_speakers < min_speakers or num_different_speakers > max_speakers:
467
+ warnings.warn(textwrap.dedent(
468
+ f"""
469
+ The detected number of speakers ({num_different_speakers}) is outside
470
+ the given bounds [{min_speakers}, {max_speakers}]. This can happen if the
471
+ given audio file is too short to contain {min_speakers} or more speakers.
472
+ Try to lower the desired minimal number of speakers.
473
+ """
474
+ ))
475
+
476
+ # during counting, we could possibly overcount the number of instantaneous
477
+ # speakers due to segmentation errors, so we cap the maximum instantaneous number
478
+ # of speakers by the `max_speakers` value
479
+ count.data = np.minimum(count.data, max_speakers).astype(np.int8)
480
+
481
+ # reconstruct discrete diarization from raw hard clusters
482
+
483
+ # keep track of inactive speakers
484
+ inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0
485
+ # shape: (num_chunks, num_speakers)
486
+
487
+ hard_clusters[inactive_speakers] = -2
488
+ discrete_diarization = self.reconstruct(
489
+ segmentations,
490
+ hard_clusters,
491
+ count,
492
+ )
493
+ hook("discrete_diarization", discrete_diarization)
494
+
495
+ # convert to continuous diarization
496
+ diarization = self.to_annotation(
497
+ discrete_diarization,
498
+ min_duration_on=0.0,
499
+ min_duration_off=self.segmentation.min_duration_off,
500
+ )
501
+ diarization.uri = file["uri"]
502
+
503
+ # at this point, `diarization` speaker labels are integers
504
+ # from 0 to `num_speakers - 1`, aligned with `centroids` rows.
505
+
506
+ if "annotation" in file and file["annotation"]:
507
+ # when reference is available, use it to map hypothesized speakers
508
+ # to reference speakers (this makes later error analysis easier
509
+ # but does not modify the actual output of the diarization pipeline)
510
+ _, mapping = self.optimal_mapping(
511
+ file["annotation"], diarization, return_mapping=True
512
+ )
513
+
514
+ # in case there are more speakers in the hypothesis than in
515
+ # the reference, those extra speakers are missing from `mapping`.
516
+ # we add them back here
517
+ mapping = {key: mapping.get(key, key) for key in diarization.labels()}
518
+
519
+ else:
520
+ # when reference is not available, rename hypothesized speakers
521
+ # to human-readable SPEAKER_00, SPEAKER_01, ...
522
+ mapping = {
523
+ label: expected_label
524
+ for label, expected_label in zip(diarization.labels(), self.classes())
525
+ }
526
+
527
+ diarization = diarization.rename_labels(mapping=mapping)
528
+ # at this point, `diarization` speaker labels are strings (or mix of
529
+ # strings and integers when reference is available and some hypothesis
530
+ # speakers are not present in the reference)
531
+ if not return_embeddings:
532
+ return diarization
533
+
534
+ # this can happen when we use OracleClustering
535
+ if centroids is None:
536
+ return diarization, None
537
+
538
+ # The number of centroids may be smaller than the number of speakers
539
+ # in the annotation. This can happen if the number of active speakers
540
+ # obtained from `speaker_count` for some frames is larger than the number
541
+ # of clusters obtained from `clustering`. In this case, we append zero embeddings
542
+ # for extra speakers
543
+ if len(diarization.labels()) > centroids.shape[0]:
544
+ centroids = np.pad(centroids, ((0, len(diarization.labels()) - centroids.shape[0]), (0, 0)))
545
+
546
+ # re-order centroids so that they match
547
+ # the order given by diarization.labels()
548
+ inverse_mapping = {label: index for index, label in mapping.items()}
549
+ centroids = centroids[
550
+ [inverse_mapping[label] for label in diarization.labels()]
551
+ ]
552
+
553
+ return diarization, centroids
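
Tying `apply` together end to end, a hedged sketch of how this pipeline is typically driven (the model paths and constructor plumbing below are hypothetical; the actual wiring of ailia/ONNX sessions in this repository happens in `pyannote-audio.py` together with `config.yaml`):

```python
from pyannote_audio_utils.audio.pipelines import SpeakerDiarization

# hypothetical local model paths and arguments
pipeline = SpeakerDiarization(
    segmentation="segmentation.onnx",
    embedding="speaker-embedding.onnx",
)
diarization = pipeline("data/sample.wav", min_speakers=2, max_speakers=10)

# Annotation supports RTTM-style iteration over (segment, track, label)
for segment, _, speaker in diarization.itertracks(yield_label=True):
    print(f"{segment.start:.1f}s - {segment.end:.1f}s: {speaker}")
```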
ailia-models/code/pyannote_audio_utils/audio/pipelines/speaker_verification.py ADDED
@@ -0,0 +1,249 @@
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2021 CNRS
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+
24
+ from functools import cached_property
25
+ from typing import Optional, Text, Union, Mapping
26
+
27
+ import numpy as np
28
+ import ailia
29
+
30
+ from pyannote_audio_utils.audio.pipelines.utils.kaldifeat import compute_fbank_feats
31
+ from pyannote_audio_utils.audio.core.inference import BaseInference
32
+
33
+ PipelineModel = Union[Text, Mapping]
34
+
35
+ class ONNXWeSpeakerPretrainedSpeakerEmbedding(BaseInference):
36
+ """Pretrained WeSpeaker speaker embedding
37
+
38
+ Parameters
39
+ ----------
40
+ embedding : str
41
+ Path to WeSpeaker pretrained speaker embedding
42
+ device : torch.device, optional
43
+ Device
44
+
45
+ Usage
46
+ -----
47
+ >>> get_embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding("hbredin/wespeaker-voxceleb-resnet34-LM")
48
+ >>> assert waveforms.ndim == 3
49
+ >>> batch_size, num_channels, num_samples = waveforms.shape
50
+ >>> assert num_channels == 1
51
+ >>> embeddings = get_embedding(waveforms)
52
+ >>> assert embeddings.ndim == 2
53
+ >>> assert embeddings.shape[0] == batch_size
54
+
55
+ >>> assert binary_masks.ndim == 1
56
+ >>> assert binary_masks.shape[0] == batch_size
57
+ >>> embeddings = get_embedding(waveforms, masks=binary_masks)
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ embedding: Text = "hbredin/wespeaker-voxceleb-resnet34-LM",
63
+ # device: Optional[torch.device] = None,
64
+ args = None,
65
+ emb_path = None
66
+ ):
67
+ # if not ONNX_IS_AVAILABLE:
68
+ # raise ImportError(
69
+ # f"'onnxruntime' must be installed to use '{embedding}' embeddings."
70
+ # )
71
+
72
+ super().__init__()
73
+
74
+ # if not Path(embedding).exists():
75
+ # try:
76
+ # embedding = hf_hub_download(
77
+ # repo_id=embedding,
78
+ # filename="speaker-embedding.onnx",
79
+ # )
80
+ # except RepositoryNotFoundError:
81
+ # raise ValueError(
82
+ # f"Could not find '{embedding}' on huggingface.co nor on local disk."
83
+ # )
84
+
85
+ self.embedding = embedding
86
+
87
+ if args.onnx:
88
+ import onnxruntime as ort
89
+ #print("use onnx runtime")
90
+ providers = ["CPUExecutionProvider", ("CUDAExecutionProvider",{"cudnn_conv_algo_search": "DEFAULT"})]
91
+
92
+ sess_options = ort.SessionOptions()
93
+ sess_options.inter_op_num_threads = 1
94
+ sess_options.intra_op_num_threads = 1
95
+ self.session_ = ort.InferenceSession(
96
+ embedding, sess_options=sess_options, providers=providers
97
+ )
98
+ else:
99
+ #print("use ailia")
100
+
101
+ self.session_ = ailia.Net(emb_path, weight=embedding, env_id=args.env_id)
102
+
103
+ self.args = args
104
+
105
+ @cached_property
106
+ def sample_rate(self) -> int:
107
+ return 16000
108
+
109
+ @cached_property
110
+ def dimension(self) -> int:
111
+ dummy_waveforms = np.random.rand(1, 1, 16000)
112
+ features = self.compute_fbank(dummy_waveforms)
113
+
114
+ if self.args.onnx:
115
+ embeddings = self.session_.run(output_names=["embs"], input_feed={"feats": features}
116
+ )[0]
117
+ else:
118
+ embeddings = self.session_.predict([features])[0]
119
+
120
+ _, dimension = embeddings.shape
121
+ return dimension
122
+
123
+ @cached_property
124
+ def metric(self) -> str:
125
+ return "cosine"
126
+
127
+ @cached_property
128
+ def min_num_samples(self) -> int:
129
+ lower, upper = 2, round(0.5 * self.sample_rate)
130
+ middle = (lower + upper) // 2
131
+ while lower + 1 < upper:
132
+ try:
133
+ features = self.compute_fbank(np.random.randn(1, 1, middle))
134
+
135
+ except AssertionError:
136
+ lower = middle
137
+ middle = (lower + upper) // 2
138
+ continue
139
+
140
+ if self.args.onnx:
141
+ embeddings = self.session_.run(output_names=["embs"], input_feed={"feats": features})[0]
142
+ else:
143
+ embeddings = self.session_.predict([features])[0]
144
+
145
+ if np.any(np.isnan(embeddings)):
146
+ lower = middle
147
+ else:
148
+ upper = middle
149
+ middle = (lower + upper) // 2
150
+
151
+ return upper
152
+
153
+ @cached_property
154
+ def min_num_frames(self) -> int:
155
+ return self.compute_fbank(np.random.randn(1, 1, self.min_num_samples)).shape[1]
156
+
157
+ def compute_fbank(
158
+ self,
159
+ waveforms: np.ndarray,
160
+ num_mel_bins: int = 80,
161
+ frame_length: int = 25,
162
+ frame_shift: int = 10,
163
+ dither: float = 0.0,
164
+ ) -> np.ndarray:
165
+ """Extract fbank features
166
+
167
+ Parameters
168
+ ----------
169
+ waveforms : (batch_size, num_channels, num_samples)
170
+
171
+ Returns
172
+ -------
173
+ fbank : (batch_size, num_frames, num_mel_bins)
174
+
175
+ Source: https://github.com/wenet-e2e/wespeaker/blob/45941e7cba2c3ea99e232d02bedf617fc71b0dad/wespeaker/bin/infer_onnx.py#L30C1-L50
176
+ """
177
+
178
+ waveforms = waveforms * (1 << 15)
179
+
180
+ ### ここで少しずれる ###
181
+ features_numpy = np.stack([compute_fbank_feats(
182
+ waveform=waveform[0],
183
+ num_mel_bins=num_mel_bins,
184
+ frame_length=frame_length,
185
+ frame_shift=frame_shift,
186
+ dither=dither,
187
+ sample_frequency=self.sample_rate,
188
+ window_type="hamming",
189
+ use_energy=False,
190
+ )for waveform in waveforms])
191
+ ### ここで少しずれる ###
192
+
193
+ features = features_numpy.astype(np.float32)
194
+
195
+ return features - np.mean(features, axis=1, keepdims=True)
196
+
197
+ def __call__(
198
+ self, waveforms: np.ndarray, masks: Optional[np.ndarray] = None
199
+ ) -> np.ndarray:
200
+ """
201
+
202
+ Parameters
203
+ ----------
204
+ waveforms : (batch_size, num_channels, num_samples)
205
+ Only num_channels == 1 is supported.
206
+ masks : (batch_size, num_samples), optional
207
+
208
+ Returns
209
+ -------
210
+ embeddings : (batch_size, dimension)
211
+
212
+ """
213
+
214
+ batch_size, num_channels, num_samples = waveforms.shape
215
+ assert num_channels == 1
216
+
217
+ features = self.compute_fbank(waveforms)
218
+ _, num_frames, _ = features.shape
219
+
220
+ batch_size_masks, _ = masks.shape
221
+ assert batch_size == batch_size_masks
222
+
223
+ def interpolate_numpy(input_array, size):
224
+ output_array = np.zeros((input_array.shape[0],size))
225
+
226
+ for i in range(output_array.shape[0]):
227
+ for j in range(output_array.shape[1]):
228
+ ii = int(np.floor(i * input_array.shape[0] / output_array.shape[0]))
229
+ jj = int(np.floor(j * input_array.shape[1] / output_array.shape[1]))
230
+ output_array[i, j] = input_array[ii, jj]
231
+ return output_array
232
+
233
+ imasks = interpolate_numpy(masks,size=num_frames)
234
+ imasks = imasks > 0.5
235
+
236
+ embeddings = np.NAN * np.zeros((batch_size, self.dimension))
237
+
238
+ for f, (feature, imask) in enumerate(zip(features, imasks)):
239
+ masked_feature = feature[imask]
240
+ if masked_feature.shape[0] < self.min_num_frames:
241
+ continue
242
+
243
+ if self.args.onnx:
244
+ embeddings[f] = self.session_.run(output_names=["embs"],input_feed={"feats": masked_feature[None]},)[0][0]
245
+ else:
246
+ embeddings[f] = self.session_.predict([masked_feature[None]])[0][0]
247
+
248
+ return embeddings
249
+
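A minimal sketch of how this class is typically driven, assuming a local `speaker-embedding.onnx` file and an argparse-style `args` object (both names are illustrative, not part of the diff):

```python
import numpy as np
from types import SimpleNamespace
from pyannote_audio_utils.audio.pipelines.speaker_verification import (
    ONNXWeSpeakerPretrainedSpeakerEmbedding,
)

args = SimpleNamespace(onnx=True, env_id=0)  # hypothetical CLI arguments
get_embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding(
    "speaker-embedding.onnx", args=args
)

# (batch_size, num_channels=1, num_samples) waveforms in [-1, 1]
waveforms = 0.1 * np.random.randn(4, 1, 32000).astype(np.float32)
masks = np.ones((4, 32000), dtype=np.float32)  # keep every sample

embeddings = get_embedding(waveforms, masks=masks)
print(embeddings.shape)  # (4, embedding_dimension); rows stay NaN if a mask is too short
```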
ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/__init__.py ADDED
@@ -0,0 +1,37 @@
+ # MIT License
+ #
+ # Copyright (c) 2022- CNRS
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ from .diarization import SpeakerDiarizationMixin
+ # from .getter import (
+ #     PipelineAugmentation,
+ #     PipelineInference,
+ #     PipelineModel,
+ #     get_augmentation,
+ #     get_devices,
+ #     get_inference,
+ #     get_model,
+ # )
+ # from .oracle import oracle_segmentation
+
+ __all__ = [
+     "SpeakerDiarizationMixin",
+ ]
ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/diarization.py ADDED
@@ -0,0 +1,248 @@
+ # MIT License
+ #
+ # Copyright (c) 2022- CNRS
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ from typing import Dict, Mapping, Optional, Tuple, Union
+
+ import numpy as np
+ from pyannote_audio_utils.core import Annotation, SlidingWindow, SlidingWindowFeature
+ from pyannote_audio_utils.core.utils.types import Label
+ from pyannote_audio_utils.metrics.diarization import DiarizationErrorRate
+
+ from pyannote_audio_utils.audio.core.inference import Inference
+ from pyannote_audio_utils.audio.utils.signal import Binarize
+
+
+ # TODO: move to dedicated module
+ class SpeakerDiarizationMixin:
+     """Defines a bunch of methods common to speaker diarization pipelines"""
+
+     @staticmethod
+     def set_num_speakers(
+         num_speakers: Optional[int] = None,
+         min_speakers: Optional[int] = None,
+         max_speakers: Optional[int] = None,
+     ):
+         """Validate number of speakers
+
+         Parameters
+         ----------
+         num_speakers : int, optional
+             Number of speakers.
+         min_speakers : int, optional
+             Minimum number of speakers.
+         max_speakers : int, optional
+             Maximum number of speakers.
+
+         Returns
+         -------
+         num_speakers : int or None
+         min_speakers : int
+         max_speakers : int or np.inf
+         """
+
+         # override {min|max}_num_speakers by num_speakers when available
+         min_speakers = num_speakers or min_speakers or 1
+         max_speakers = num_speakers or max_speakers or np.inf
+
+         if min_speakers > max_speakers:
+             raise ValueError(
+                 f"min_speakers must be smaller than (or equal to) max_speakers "
+                 f"(here: min_speakers={min_speakers:g} and max_speakers={max_speakers:g})."
+             )
+         if min_speakers == max_speakers:
+             num_speakers = min_speakers
+
+         return num_speakers, min_speakers, max_speakers
+
+     @staticmethod
+     def optimal_mapping(
+         reference: Union[Mapping, Annotation],
+         hypothesis: Annotation,
+         return_mapping: bool = False,
+     ) -> Union[Annotation, Tuple[Annotation, Dict[Label, Label]]]:
+         """Find the optimal bijective mapping between reference and hypothesis labels
+
+         Parameters
+         ----------
+         reference : Annotation or Mapping
+             Reference annotation. Can be an Annotation instance or
+             a mapping with an "annotation" key.
+         hypothesis : Annotation
+             Hypothesized annotation.
+         return_mapping : bool, optional
+             Return the label mapping itself along with the mapped annotation. Defaults to False.
+
+         Returns
+         -------
+         mapped : Annotation
+             Hypothesis mapped to reference speakers.
+         mapping : dict, optional
+             Mapping between hypothesis (key) and reference (value) labels.
+             Only returned if `return_mapping` is True.
+         """
+
+         if isinstance(reference, Mapping):
+             reference = reference["annotation"]
+             annotated = reference["annotated"] if "annotated" in reference else None
+         else:
+             annotated = None
+
+         mapping = DiarizationErrorRate().optimal_mapping(
+             reference, hypothesis, uem=annotated
+         )
+         mapped_hypothesis = hypothesis.rename_labels(mapping=mapping)
+
+         if return_mapping:
+             return mapped_hypothesis, mapping
+
+         else:
+             return mapped_hypothesis
+
+     # TODO: get rid of warm-up parameter (trimming should be applied before calling speaker_count)
+     @staticmethod
+     def speaker_count(
+         binarized_segmentations: SlidingWindowFeature,
+         frames: SlidingWindow,
+         warm_up: Tuple[float, float] = (0.1, 0.1),
+     ) -> SlidingWindowFeature:
+         """Estimate frame-level number of instantaneous speakers
+
+         Parameters
+         ----------
+         binarized_segmentations : SlidingWindowFeature
+             (num_chunks, num_frames, num_classes)-shaped binarized scores.
+         warm_up : (float, float) tuple, optional
+             Left/right warm up ratio of chunk duration.
+             Defaults to (0.1, 0.1), i.e. 10% on both sides.
+         frames : SlidingWindow
+             Frames resolution; providing the exact frame resolution (when known)
+             leads to better temporal precision than estimating it from
+             `segmentations` shape and chunk size.
+
+         Returns
+         -------
+         count : SlidingWindowFeature
+             (num_frames, 1)-shaped instantaneous speaker count
+         """
+
+         trimmed = Inference.trim(binarized_segmentations, warm_up=warm_up)
+
+         count = Inference.aggregate(
+             np.sum(trimmed, axis=-1, keepdims=True),
+             frames,
+             hamming=False,
+             missing=0.0,
+             skip_average=False,
+         )
+
+         count.data = np.rint(count.data).astype(np.uint8)
+
+         return count
+
+     @staticmethod
+     def to_annotation(
+         discrete_diarization: SlidingWindowFeature,
+         min_duration_on: float = 0.0,
+         min_duration_off: float = 0.0,
+     ) -> Annotation:
+         """
+
+         Parameters
+         ----------
+         discrete_diarization : SlidingWindowFeature
+             (num_frames, num_speakers)-shaped discrete diarization
+         min_duration_on : float, optional
+             Defaults to 0.
+         min_duration_off : float, optional
+             Defaults to 0.
+
+         Returns
+         -------
+         continuous_diarization : Annotation
+             Continuous diarization, with speaker labels as integers,
+             corresponding to the speaker indices in the discrete diarization.
+         """
+
+         binarize = Binarize(
+             onset=0.5,
+             offset=0.5,
+             min_duration_on=min_duration_on,
+             min_duration_off=min_duration_off,
+         )
+
+         return binarize(discrete_diarization).rename_tracks(generator="string")
+
+     @staticmethod
+     def to_diarization(
+         segmentations: SlidingWindowFeature,
+         count: SlidingWindowFeature,
+     ) -> SlidingWindowFeature:
+         """Build diarization out of preprocessed segmentation and precomputed speaker count
+
+         Parameters
+         ----------
+         segmentations : SlidingWindowFeature
+             (num_chunks, num_frames, num_speakers)-shaped segmentations
+         count : SlidingWindowFeature
+             (num_frames, 1)-shaped speaker count
+
+         Returns
+         -------
+         discrete_diarization : SlidingWindowFeature
+             Discrete (0s and 1s) diarization.
+         """
+
+         # TODO: investigate alternative aggregation
+         activations = Inference.aggregate(
+             segmentations,
+             count.sliding_window,
+             hamming=False,
+             missing=0.0,
+             skip_average=True,
+         )
+         # shape is (num_frames, num_speakers)
+
+         _, num_speakers = activations.data.shape
+         max_speakers_per_frame = np.max(count.data)
+         if num_speakers < max_speakers_per_frame:
+             activations.data = np.pad(
+                 activations.data, ((0, 0), (0, max_speakers_per_frame - num_speakers))
+             )
+
+         extent = activations.extent & count.extent
+         activations = activations.crop(extent, return_data=False)
+         count = count.crop(extent, return_data=False)
+
+         sorted_speakers = np.argsort(-activations, axis=-1)
+         binary = np.zeros_like(activations.data)
+
+         for t, ((_, c), speakers) in enumerate(zip(count, sorted_speakers)):
+             for i in range(c.item()):
+                 binary[t, speakers[i]] = 1.0
+
+         return SlidingWindowFeature(binary, activations.sliding_window)
+
+     def classes(self):
+         speaker = 0
+         while True:
+             yield f"SPEAKER_{speaker:02d}"
+             speaker += 1
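`set_num_speakers` is pure bookkeeping and can be sanity-checked on its own:

```python
from pyannote_audio_utils.audio.pipelines.utils import SpeakerDiarizationMixin

# a fixed speaker count collapses the min/max bounds onto that value
print(SpeakerDiarizationMixin.set_num_speakers(num_speakers=2))   # (2, 2, 2)

# only an upper bound: the exact count stays unknown (None)
print(SpeakerDiarizationMixin.set_num_speakers(max_speakers=5))   # (None, 1, 5)

# no constraint at all: anything from 1 speaker upwards
print(SpeakerDiarizationMixin.set_num_speakers())                 # (None, 1, inf)
```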
ailia-models/code/pyannote_audio_utils/audio/pipelines/utils/kaldifeat.py ADDED
@@ -0,0 +1,291 @@
+ # https://github.com/yuyq96/kaldifeat
+
+ import numpy as np
+
+
+ # ---------- feature-window ----------
+
+ def sliding_window(x, window_size, window_shift):
+     shape = x.shape[:-1] + (x.shape[-1] - window_size + 1, window_size)
+     strides = x.strides + (x.strides[-1],)
+     return np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)[::window_shift]
+
+
+ def func_num_frames(num_samples, window_size, window_shift, snip_edges):
+     if snip_edges:
+         if num_samples < window_size:
+             return 0
+         else:
+             return 1 + ((num_samples - window_size) // window_shift)
+     else:
+         return (num_samples + (window_shift // 2)) // window_shift
+
+
+ def func_dither(waveform, dither_value):
+     if dither_value == 0.0:
+         return waveform
+     waveform += np.random.normal(size=waveform.shape).astype(waveform.dtype) * dither_value
+     return waveform
+
+
+ def func_remove_dc_offset(waveform):
+     return waveform - np.mean(waveform)
+
+
+ def func_log_energy(waveform):
+     return np.log(np.dot(waveform, waveform).clip(min=np.finfo(waveform.dtype).eps))
+
+
+ def func_preemphasis(waveform, preemph_coeff):
+     if preemph_coeff == 0.0:
+         return waveform
+     assert 0 < preemph_coeff <= 1
+     waveform[1:] -= preemph_coeff * waveform[:-1]
+     waveform[0] -= preemph_coeff * waveform[0]
+     return waveform
+
+
+ def sine(M):
+     if M < 1:
+         return np.array([])
+     if M == 1:
+         return np.ones(1, float)
+     n = np.arange(0, M)
+     return np.sin(np.pi * n / (M - 1))
+
+
+ def povey(M):
+     if M < 1:
+         return np.array([])
+     if M == 1:
+         return np.ones(1, float)
+     n = np.arange(0, M)
+     return (0.5 - 0.5 * np.cos(2.0 * np.pi * n / (M - 1))) ** 0.85
+
+
+ def feature_window_function(window_type, window_size, blackman_coeff):
+     assert window_size > 0
+     if window_type == 'hanning':
+         return np.hanning(window_size)
+     elif window_type == 'sine':
+         return sine(window_size)
+     elif window_type == 'hamming':
+         return np.hamming(window_size)
+     elif window_type == 'povey':
+         return povey(window_size)
+     elif window_type == 'rectangular':
+         return np.ones(window_size)
+     elif window_type == 'blackman':
+         window_func = np.blackman(window_size)
+         if blackman_coeff == 0.42:
+             return window_func
+         else:
+             return window_func - 0.42 + blackman_coeff
+     else:
+         raise ValueError('Invalid window type {}'.format(window_type))
+
+
+ def process_window(window, dither, remove_dc_offset, preemphasis_coefficient, window_function, raw_energy):
+     if dither != 0.0:
+         window = func_dither(window, dither)
+     if remove_dc_offset:
+         window = func_remove_dc_offset(window)
+     if raw_energy:
+         log_energy = func_log_energy(window)
+     if preemphasis_coefficient != 0.0:
+         window = func_preemphasis(window, preemphasis_coefficient)
+     window *= window_function
+     if not raw_energy:
+         log_energy = func_log_energy(window)
+     return window, log_energy
+
+
+ def extract_window(waveform, blackman_coeff, dither, window_size, window_shift,
+                    preemphasis_coefficient, raw_energy, remove_dc_offset,
+                    snip_edges, window_type, dtype):
+     num_samples = len(waveform)
+     num_frames = func_num_frames(num_samples, window_size, window_shift, snip_edges)
+     num_samples_ = (num_frames - 1) * window_shift + window_size
+     if snip_edges:
+         waveform = waveform[:num_samples_]
+     else:
+         offset = window_shift // 2 - window_size // 2
+         waveform = np.concatenate([
+             waveform[-offset - 1::-1],
+             waveform,
+             waveform[:-(offset + num_samples_ - num_samples + 1):-1]
+         ])
+     frames = sliding_window(waveform, window_size=window_size, window_shift=window_shift)
+     frames = frames.astype(dtype)
+     log_energy = np.empty(frames.shape[0], dtype=dtype)
+     for i in range(frames.shape[0]):
+         frames[i], log_energy[i] = process_window(
+             window=frames[i],
+             dither=dither,
+             remove_dc_offset=remove_dc_offset,
+             preemphasis_coefficient=preemphasis_coefficient,
+             window_function=feature_window_function(
+                 window_type=window_type,
+                 window_size=window_size,
+                 blackman_coeff=blackman_coeff
+             ).astype(dtype),
+             raw_energy=raw_energy
+         )
+     return frames, log_energy
+
+
+ # ---------- feature-window ----------
+
+
+ # ---------- feature-functions ----------
+
+ def compute_spectrum(frames, n):
+     complex_spec = np.fft.rfft(frames, n)
+     return np.absolute(complex_spec)
+
+
+ def compute_power_spectrum(frames, n):
+     return np.square(compute_spectrum(frames, n))
+
+
+ # ---------- feature-functions ----------
+
+
+ # ---------- mel-computations ----------
+
+
+ def mel_scale(freq):
+     return 1127.0 * np.log(1.0 + freq / 700.0)
+
+
+ def compute_mel_banks(num_bins, sample_frequency, low_freq, high_freq, n):
+     """Compute Mel banks.
+
+     :param num_bins: Number of triangular mel-frequency bins
+     :param sample_frequency: Waveform data sample frequency
+     :param low_freq: Low cutoff frequency for mel bins
+     :param high_freq: High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
+     :param n: Window size
+     :return: Mel banks.
+     """
+     assert num_bins >= 3, 'Must have at least 3 mel bins'
+     num_fft_bins = n // 2
+
+     nyquist = 0.5 * sample_frequency
+     if high_freq <= 0:
+         high_freq = nyquist + high_freq
+     assert 0 <= low_freq < high_freq <= nyquist
+
+     fft_bin_width = sample_frequency / n
+
+     mel_low_freq = mel_scale(low_freq)
+     mel_high_freq = mel_scale(high_freq)
+     mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
+
+     mel_banks = np.zeros([num_bins, num_fft_bins + 1])
+     for i in range(num_bins):
+         left_mel = mel_low_freq + mel_freq_delta * i
+         center_mel = left_mel + mel_freq_delta
+         right_mel = center_mel + mel_freq_delta
+         for j in range(num_fft_bins):
+             mel = mel_scale(fft_bin_width * j)
+             if left_mel < mel < right_mel:
+                 if mel <= center_mel:
+                     mel_banks[i, j] = (mel - left_mel) / (center_mel - left_mel)
+                 else:
+                     mel_banks[i, j] = (right_mel - mel) / (right_mel - center_mel)
+     return mel_banks
+
+
+ # ---------- mel-computations ----------
+
+
+ # ---------- compute-fbank-feats ----------
+
+ def compute_fbank_feats(
+         waveform,
+         blackman_coeff=0.42,
+         dither=1.0,
+         energy_floor=0.0,
+         frame_length=25,
+         frame_shift=10,
+         high_freq=0,
+         low_freq=20,
+         num_mel_bins=23,
+         preemphasis_coefficient=0.97,
+         raw_energy=True,
+         remove_dc_offset=True,
+         round_to_power_of_two=True,
+         sample_frequency=16000,
+         snip_edges=True,
+         use_energy=False,
+         use_log_fbank=True,
+         use_power=True,
+         window_type='povey',
+         dtype=np.float32):
+     """Compute (log) Mel filter bank energies
+
+     :param waveform: Input waveform.
+     :param blackman_coeff: Constant coefficient for generalized Blackman window. (float, default = 0.42)
+     :param dither: Dithering constant (0.0 means no dither). If you turn this off, you should set the --energy-floor option, e.g. to 1.0 or 0.1 (float, default = 1)
+     :param energy_floor: Floor on energy (absolute, not relative) in FBANK computation. Only makes a difference if --use-energy=true; only necessary if --dither=0.0. Suggested values: 0.1 or 1.0 (float, default = 0)
+     :param frame_length: Frame length in milliseconds (float, default = 25)
+     :param frame_shift: Frame shift in milliseconds (float, default = 10)
+     :param high_freq: High cutoff frequency for mel bins (if <= 0, offset from Nyquist) (float, default = 0)
+     :param low_freq: Low cutoff frequency for mel bins (float, default = 20)
+     :param num_mel_bins: Number of triangular mel-frequency bins (int, default = 23)
+     :param preemphasis_coefficient: Coefficient for use in signal preemphasis (float, default = 0.97)
+     :param raw_energy: If true, compute energy before preemphasis and windowing (bool, default = true)
+     :param remove_dc_offset: Subtract mean from waveform on each frame (bool, default = true)
+     :param round_to_power_of_two: If true, round window size to power of two by zero-padding input to FFT. (bool, default = true)
+     :param sample_frequency: Waveform data sample frequency (must match the waveform file, if specified there) (float, default = 16000)
+     :param snip_edges: If true, end effects will be handled by outputting only frames that completely fit in the file, and the number of frames depends on the frame-length. If false, the number of frames depends only on the frame-shift, and we reflect the data at the ends. (bool, default = true)
+     :param use_energy: Add an extra energy output. (bool, default = false)
+     :param use_log_fbank: If true, produce log-filterbank, else produce linear. (bool, default = true)
+     :param use_power: If true, use power, else use magnitude. (bool, default = true)
+     :param window_type: Type of window ("hamming"|"hanning"|"povey"|"rectangular"|"sine"|"blackmann") (string, default = "povey")
+     :param dtype: Type of array (np.float32|np.float64) (dtype or string, default=np.float32)
+     :return: (Log) Mel filter bank energies.
+     """
+     window_size = int(frame_length * sample_frequency * 0.001)
+     window_shift = int(frame_shift * sample_frequency * 0.001)
+     frames, log_energy = extract_window(
+         waveform=waveform,
+         blackman_coeff=blackman_coeff,
+         dither=dither,
+         window_size=window_size,
+         window_shift=window_shift,
+         preemphasis_coefficient=preemphasis_coefficient,
+         raw_energy=raw_energy,
+         remove_dc_offset=remove_dc_offset,
+         snip_edges=snip_edges,
+         window_type=window_type,
+         dtype=dtype
+     )
+     if round_to_power_of_two:
+         n = 1
+         while n < window_size:
+             n *= 2
+     else:
+         n = window_size
+     if use_power:
+         spectrum = compute_power_spectrum(frames, n)
+     else:
+         spectrum = compute_spectrum(frames, n)
+     mel_banks = compute_mel_banks(
+         num_bins=num_mel_bins,
+         sample_frequency=sample_frequency,
+         low_freq=low_freq,
+         high_freq=high_freq,
+         n=n
+     ).astype(dtype)
+     feat = np.dot(spectrum, mel_banks.T)
+     if use_log_fbank:
+         feat = np.log(feat.clip(min=np.finfo(dtype).eps))
+     if use_energy:
+         if energy_floor > 0.0:
+             # clip() returns a copy, so keep the result (the original discarded it);
+             # np.log replaces np.math.log, which was removed in NumPy 2.x
+             log_energy = log_energy.clip(min=np.log(energy_floor))
+         return feat, log_energy
+     return feat
+
+ # ---------- compute-fbank-feats ----------
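A quick shape check of the fbank path, using the same parameters the embedding wrapper above passes in (random input, so only the shape is meaningful):

```python
import numpy as np
from pyannote_audio_utils.audio.pipelines.utils.kaldifeat import compute_fbank_feats

# one second of fake 16 kHz audio, pre-scaled to 16-bit range like the caller does
waveform = np.random.randn(16000).astype(np.float32) * (1 << 15)

feat = compute_fbank_feats(
    waveform,
    num_mel_bins=80,
    frame_length=25,   # ms
    frame_shift=10,    # ms
    dither=0.0,
    sample_frequency=16000,
    window_type="hamming",
    use_energy=False,
)
# 25 ms window / 10 ms shift, snip_edges=True: 1 + (16000 - 400) // 160 = 98 frames
print(feat.shape)  # (98, 80)
```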
ailia-models/code/pyannote_audio_utils/audio/utils/multi_task.py ADDED
@@ -0,0 +1,59 @@
+ # MIT License
+ #
+ # Copyright (c) 2023- CNRS
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+
+ from typing import Any, Callable, Tuple, Union
+
+ from pyannote_audio_utils.audio.core.task import Specifications
+
+
+ def map_with_specifications(
+     specifications: Union[Specifications, Tuple[Specifications]],
+     func: Callable,
+     *iterables,
+ ) -> Union[Any, Tuple[Any]]:
+     """Compute the function using arguments from each of the iterables
+
+     Returns a tuple if provided `specifications` is a tuple,
+     otherwise returns the function return value.
+
+     Parameters
+     ----------
+     specifications : (tuple of) Specifications
+         Specifications or tuple of specifications
+     func : callable
+         Function called for each specification with
+         `func(*iterables[i], specifications=specifications[i])`
+     *iterables :
+         List of iterables with same length as `specifications`.
+
+     Returns
+     -------
+     output : (tuple of) `func` return value(s)
+     """
+
+     if isinstance(specifications, Specifications):
+         return func(*iterables, specifications=specifications)
+
+     return tuple(
+         func(*i, specifications=s) for s, *i in zip(specifications, *iterables)
+     )
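The tuple branch is just a zip over specifications and per-task iterables; here is the same pattern with plain strings standing in for `Specifications` objects (the real code type-checks against the actual class):

```python
specifications = ("spec_a", "spec_b")   # stand-ins for Specifications instances
iterables = ([1, 2], [10, 20])          # one iterable per positional argument

def func(x, y, specifications=None):
    return f"{specifications}: {x + y}"

out = tuple(
    func(*i, specifications=s) for s, *i in zip(specifications, *iterables)
)
print(out)  # ('spec_a: 11', 'spec_b: 22')
```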
ailia-models/code/pyannote_audio_utils/audio/utils/powerset.py ADDED
@@ -0,0 +1,125 @@
+ # MIT License
+ #
+ # Copyright (c) 2023- CNRS
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # AUTHORS
+ # Hervé BREDIN - https://herve.niderb.fr
+ # Alexis PLAQUET
+
+ from functools import cached_property
+ from itertools import combinations
+
+ import scipy.special
+ import numpy as np
+
+
+ class Powerset:
+     """Powerset to multilabel conversion, and back.
+
+     Parameters
+     ----------
+     num_classes : int
+         Number of regular classes.
+     max_set_size : int
+         Maximum number of classes in each set.
+     """
+
+     def __init__(self, num_classes: int, max_set_size: int):
+         super().__init__()
+         self.num_classes = num_classes
+         self.max_set_size = max_set_size
+         self.mapping = self.build_mapping()
+         self.cardinality = self.build_cardinality()
+
+     @cached_property
+     def num_powerset_classes(self) -> int:
+         # compute number of subsets of size at most "max_set_size"
+         # e.g. with num_classes = 3 and max_set_size = 2:
+         # {}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2}
+         return int(
+             sum(
+                 scipy.special.binom(self.num_classes, i)
+                 for i in range(0, self.max_set_size + 1)
+             )
+         )
+
+     def build_mapping(self) -> np.ndarray:
+         """Compute powerset to regular mapping
+
+         Returns
+         -------
+         mapping : (num_powerset_classes, num_classes) np.ndarray
+             mapping[i, j] == 1 if jth regular class is a member of ith powerset class
+             mapping[i, j] == 0 otherwise
+
+         Example
+         -------
+         With num_classes == 3 and max_set_size == 2, returns
+
+             [0, 0, 0]  # none
+             [1, 0, 0]  # class #1
+             [0, 1, 0]  # class #2
+             [0, 0, 1]  # class #3
+             [1, 1, 0]  # classes #1 and #2
+             [1, 0, 1]  # classes #1 and #3
+             [0, 1, 1]  # classes #2 and #3
+
+         """
+         mapping = np.zeros((self.num_powerset_classes, self.num_classes))
+
+         powerset_k = 0
+         for set_size in range(0, self.max_set_size + 1):
+             for current_set in combinations(range(self.num_classes), set_size):
+                 mapping[powerset_k, current_set] = 1
+                 powerset_k += 1
+
+         return mapping
+
+     def build_cardinality(self) -> np.ndarray:
+         """Compute size of each powerset class"""
+         return np.sum(self.mapping, axis=1)
+
+     def to_multilabel(self, powerset: np.ndarray, soft: bool = False) -> np.ndarray:
+         """Convert predictions from powerset to multi-label
+
+         Parameter
+         ---------
+         powerset : (batch_size, num_frames, num_powerset_classes) np.ndarray
+             Soft predictions in "powerset" space.
+         soft : bool, optional
+             Return soft multi-label predictions. Defaults to False (i.e. hard predictions)
+             Assumes that `powerset` are "logits" (not "probabilities").
+
+         Returns
+         -------
+         multi_label : (batch_size, num_frames, num_classes) np.ndarray
+             Predictions in "multi-label" space.
+         """
+
+         # note: this numpy port always takes the argmax, i.e. it returns hard
+         # predictions regardless of `soft`
+         powerset_probs = np.identity(self.num_powerset_classes)[np.argmax(powerset, axis=-1)]
+         return np.matmul(powerset_probs, self.mapping)
+
+     def __call__(self, powerset: np.ndarray, soft: bool = False) -> np.ndarray:
+         """Alias for `to_multilabel`"""
+
+         return self.to_multilabel(powerset, soft=soft)
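End-to-end, the class turns per-frame powerset logits into per-class binary activations:

```python
import numpy as np
from pyannote_audio_utils.audio.utils.powerset import Powerset

ps = Powerset(num_classes=3, max_set_size=2)
print(ps.num_powerset_classes)  # 7: {}, {0}, {1}, {2}, {0,1}, {0,2}, {1,2}

# fake logits for 1 chunk x 2 frames: frame 0 peaks on {0}, frame 1 on {1, 2}
logits = np.zeros((1, 2, 7))
logits[0, 0, 1] = 5.0   # powerset class {0}
logits[0, 1, 6] = 5.0   # powerset class {1, 2}

print(ps.to_multilabel(logits))
# [[[1. 0. 0.]
#   [0. 1. 1.]]]
```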
ailia-models/code/pyannote_audio_utils/audio/utils/signal.py ADDED
@@ -0,0 +1,369 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+ #
+ # The MIT License (MIT)
+ #
+ # Copyright (c) 2016-2021 CNRS
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in
+ # all copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # AUTHORS
+ # Hervé BREDIN - http://herve.niderb.fr
+
+ """
+ # Signal processing
+ """
+
+ from functools import singledispatch
+ from itertools import zip_longest
+ from typing import Optional, Union
+
+ import numpy as np
+ import scipy.signal
+ from pyannote_audio_utils.core import Annotation, Segment, SlidingWindowFeature, Timeline
+ from pyannote_audio_utils.core.utils.generators import pairwise
+
+
+ @singledispatch
+ def binarize(
+     scores,
+     onset: float = 0.5,
+     offset: Optional[float] = None,
+     initial_state: Optional[Union[bool, np.ndarray]] = None,
+ ):
+     """(Batch) hysteresis thresholding
+
+     Parameters
+     ----------
+     scores : numpy.ndarray or SlidingWindowFeature
+         (num_chunks, num_frames, num_classes)- or (num_frames, num_classes)-shaped scores.
+     onset : float, optional
+         Onset threshold. Defaults to 0.5.
+     offset : float, optional
+         Offset threshold. Defaults to `onset`.
+     initial_state : np.ndarray or bool, optional
+         Initial state.
+
+     Returns
+     -------
+     binarized : same as scores
+         Binarized scores with same shape and type as scores.
+
+     Reference
+     ---------
+     https://stackoverflow.com/questions/23289976/how-to-find-zero-crossings-with-hysteresis
+     """
+     raise NotImplementedError(
+         "scores must be of type numpy.ndarray or SlidingWindowFeature"
+     )
+
+
+ @binarize.register
+ def binarize_ndarray(
+     scores: np.ndarray,
+     onset: float = 0.5,
+     offset: Optional[float] = None,
+     initial_state: Optional[Union[bool, np.ndarray]] = None,
+ ):
+     """(Batch) hysteresis thresholding
+
+     Parameters
+     ----------
+     scores : numpy.ndarray
+         (batch_size, num_frames)-shaped scores.
+     onset : float, optional
+         Onset threshold. Defaults to 0.5.
+     offset : float, optional
+         Offset threshold. Defaults to `onset`.
+     initial_state : np.ndarray or bool, optional
+         Initial state.
+
+     Returns
+     -------
+     binarized : same as scores
+         Binarized scores with same shape and type as scores.
+     """
+
+     offset = offset or onset
+
+     batch_size, num_frames = scores.shape
+
+     scores = np.nan_to_num(scores)
+
+     if initial_state is None:
+         initial_state = scores[:, 0] >= 0.5 * (onset + offset)
+
+     elif isinstance(initial_state, bool):
+         initial_state = initial_state * np.ones((batch_size,), dtype=bool)
+
+     elif isinstance(initial_state, np.ndarray):
+         assert initial_state.shape == (batch_size,)
+         assert initial_state.dtype == bool
+
+     initial_state = np.tile(initial_state, (num_frames, 1)).T
+
+     on = scores > onset
+     off_or_on = (scores < offset) | on
+
+     # indices of frames for which the on/off state is well-defined
+     well_defined_idx = np.array(
+         list(zip_longest(*[np.nonzero(oon)[0] for oon in off_or_on], fillvalue=-1))
+     ).T
+
+     # corner case where well_defined_idx is empty
+     if not well_defined_idx.size:
+         return np.zeros_like(scores, dtype=bool) | initial_state
+
+     # points to the index of the previous well-defined frame
+     same_as = np.cumsum(off_or_on, axis=1)
+
+     samples = np.tile(np.arange(batch_size), (num_frames, 1)).T
+
+     return np.where(
+         same_as, on[samples, well_defined_idx[samples, same_as - 1]], initial_state
+     )
+
+
+ @binarize.register
+ def binarize_swf(
+     scores: SlidingWindowFeature,
+     onset: float = 0.5,
+     offset: Optional[float] = None,
+     initial_state: Optional[bool] = None,
+ ):
+     """(Batch) hysteresis thresholding
+
+     Parameters
+     ----------
+     scores : SlidingWindowFeature
+         (num_chunks, num_frames, num_classes)- or (num_frames, num_classes)-shaped scores.
+     onset : float, optional
+         Onset threshold. Defaults to 0.5.
+     offset : float, optional
+         Offset threshold. Defaults to `onset`.
+     initial_state : np.ndarray or bool, optional
+         Initial state.
+
+     Returns
+     -------
+     binarized : same as scores
+         Binarized scores with same shape and type as scores.
+
+     """
+
+     offset = offset or onset
+
+     if scores.data.ndim == 2:
+         num_frames, num_classes = scores.data.shape
+         data = scores.data.transpose()
+         binarized = binarize(
+             data, onset=onset, offset=offset, initial_state=initial_state
+         )
+         return SlidingWindowFeature(
+             1.0 * binarized.transpose(),
+             scores.sliding_window,
+         )
+
+     elif scores.data.ndim == 3:
+         num_chunks, num_frames, num_classes = scores.data.shape
+         data = scores.data.reshape([-1, num_classes])
+         binarized = binarize(
+             data, onset=onset, offset=offset, initial_state=initial_state
+         )
+         return SlidingWindowFeature(
+             1.0 * binarized.reshape([num_chunks, num_frames, num_classes]),
+             scores.sliding_window,
+         )
+
+     else:
+         raise ValueError(
+             "Shape of scores must be (num_chunks, num_frames, num_classes) or (num_frames, num_classes)."
+         )
+
+
+ class Binarize:
+     """Binarize detection scores using hysteresis thresholding
+
+     Parameters
+     ----------
+     onset : float, optional
+         Onset threshold. Defaults to 0.5.
+     offset : float, optional
+         Offset threshold. Defaults to `onset`.
+     min_duration_on : float, optional
+         Remove active regions shorter than that many seconds. Defaults to 0s.
+     min_duration_off : float, optional
+         Fill inactive regions shorter than that many seconds. Defaults to 0s.
+     pad_onset : float, optional
+         Extend active regions by moving their start time by that many seconds.
+         Defaults to 0s.
+     pad_offset : float, optional
+         Extend active regions by moving their end time by that many seconds.
+         Defaults to 0s.
+
+     Reference
+     ---------
+     Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
+     RNN-based Voice Activity Detection", InterSpeech 2015.
+     """
+
+     def __init__(
+         self,
+         onset: float = 0.5,
+         offset: Optional[float] = None,
+         min_duration_on: float = 0.0,
+         min_duration_off: float = 0.0,
+         pad_onset: float = 0.0,
+         pad_offset: float = 0.0,
+     ):
+
+         super().__init__()
+
+         self.onset = onset
+         self.offset = offset or onset
+
+         self.pad_onset = pad_onset
+         self.pad_offset = pad_offset
+
+         self.min_duration_on = min_duration_on
+         self.min_duration_off = min_duration_off
+
+     def __call__(self, scores: SlidingWindowFeature) -> Annotation:
+         """Binarize detection scores
+
+         Parameters
+         ----------
+         scores : SlidingWindowFeature
+             Detection scores.
+
+         Returns
+         -------
+         active : Annotation
+             Binarized scores.
+         """
+
+         num_frames, num_classes = scores.data.shape
+         frames = scores.sliding_window
+         timestamps = [frames[i].middle for i in range(num_frames)]
+
+         # annotation meant to store 'active' regions
+         active = Annotation()
+
+         for k, k_scores in enumerate(scores.data.T):
+
+             label = k if scores.labels is None else scores.labels[k]
+
+             # initial state
+             start = timestamps[0]
+             is_active = k_scores[0] > self.onset
+
+             for t, y in zip(timestamps[1:], k_scores[1:]):
+
+                 # currently active
+                 if is_active:
+                     # switching from active to inactive
+                     if y < self.offset:
+                         region = Segment(start - self.pad_onset, t + self.pad_offset)
+                         active[region, k] = label
+                         start = t
+                         is_active = False
+
+                 # currently inactive
+                 else:
+                     # switching from inactive to active
+                     if y > self.onset:
+                         start = t
+                         is_active = True
+
+             # if active at the end, add final region
+             if is_active:
+                 region = Segment(start - self.pad_onset, t + self.pad_offset)
+                 active[region, k] = label
+
+         # because of padding, some active regions might be overlapping: merge them.
+         # also: fill same speaker gaps shorter than min_duration_off
+         if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
+             active = active.support(collar=self.min_duration_off)
+
+         # remove tracks shorter than min_duration_on
+         if self.min_duration_on > 0:
+             for segment, track in list(active.itertracks()):
+                 if segment.duration < self.min_duration_on:
+                     del active[segment, track]
+
+         return active
+
+
+ class Peak:
+     """Peak detection
+
+     Parameters
+     ----------
+     alpha : float, optional
+         Peak threshold. Defaults to 0.5
+     min_duration : float, optional
+         Minimum elapsed time between two consecutive peaks. Defaults to 1 second.
+     """
+
+     def __init__(
+         self,
+         alpha: float = 0.5,
+         min_duration: float = 1.0,
+     ):
+         super(Peak, self).__init__()
+         self.alpha = alpha
+         self.min_duration = min_duration
+
+     def __call__(self, scores: SlidingWindowFeature):
+         """Peak detection
+
+         Parameter
+         ---------
+         scores : SlidingWindowFeature
+             Detection scores.
+
+         Returns
+         -------
+         segmentation : Timeline
+             Partition.
+         """
+
+         if scores.dimension != 1:
+             raise ValueError("Peak expects one-dimensional scores.")
+
+         num_frames = len(scores)
+         frames = scores.sliding_window
+
+         precision = frames.step
+         order = max(1, int(np.rint(self.min_duration / precision)))
+         indices = scipy.signal.argrelmax(scores[:], order=order)[0]
+
+         peak_time = np.array(
+             [frames[i].middle for i in indices if scores[i] > self.alpha]
+         )
+         boundaries = np.hstack([[frames[0].start], peak_time, [frames[num_frames].end]])
+
+         segmentation = Timeline()
+         for i, (start, end) in enumerate(pairwise(boundaries)):
+             segment = Segment(start, end)
+             segmentation.add(segment)
+
+         return segmentation
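The hysteresis behavior of `binarize` is easiest to see on a toy ndarray: scores between the two thresholds keep the previous state instead of flickering. (Per the code above, each row is one independent sequence.)

```python
import numpy as np
from pyannote_audio_utils.audio.utils.signal import binarize

scores = np.array([
    [0.1, 0.8, 0.5, 0.5, 0.2, 0.1],   # the 0.5s stay "on" after the 0.8 onset
    [0.9, 0.6, 0.4, 0.2, 0.8, 0.5],   # drops only once a score falls below 0.3
])
print(binarize(scores, onset=0.7, offset=0.3))
# [[False  True  True  True False False]
#  [ True  True  True False  True  True]]
```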
ailia-models/code/pyannote_audio_utils/audio/version.py ADDED
@@ -0,0 +1 @@
+ __version__ = '3.1.1'
ailia-models/code/pyannote_audio_utils/core/__init__.py ADDED
@@ -0,0 +1,48 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+
+ # The MIT License (MIT)
+
+ # Copyright (c) 2014-2019 CNRS
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in
+ # all copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # AUTHORS
+ # Hervé BREDIN - http://herve.niderb.fr
+
+
+ from ._version import get_versions
+ __version__ = get_versions()['version']
+ del get_versions
+
+ PYANNOTE_URI = 'uri'
+ PYANNOTE_MODALITY = 'modality'
+ PYANNOTE_SEGMENT = 'segment'
+ PYANNOTE_TRACK = 'track'
+ PYANNOTE_LABEL = 'label'
+ PYANNOTE_SCORE = 'score'
+ PYANNOTE_IDENTITY = 'identity'
+
+ from .segment import Segment, SlidingWindow
+ from .timeline import Timeline
+ from .annotation import Annotation
+ from .feature import SlidingWindowFeature
+
+ Segment.set_precision()
+
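The re-exported core types compose as in upstream pyannote.core; a tiny smoke test:

```python
from pyannote_audio_utils.core import Annotation, Segment

annotation = Annotation(uri="sample")
annotation[Segment(0.0, 2.5)] = "SPEAKER_00"
annotation[Segment(2.0, 4.0)] = "SPEAKER_01"

print(annotation.labels())                      # ['SPEAKER_00', 'SPEAKER_01']
print(annotation.label_timeline("SPEAKER_00"))  # timeline with one segment
```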
ailia-models/code/pyannote_audio_utils/core/_version.py ADDED
@@ -0,0 +1,20 @@
+
+ # This file was generated by 'versioneer.py' (0.15) from
+ # revision-control system data, or from the parent directory name of an
+ # unpacked source archive. Distribution tarballs contain a pre-generated copy
+ # of this file.
+
+ import json
+
+ version_json = '''
+ {
+  "dirty": false,
+  "error": null,
+  "full-revisionid": "4b0fd5302d8fa3ba249b42d3ab7b4cb51ee61ba2",
+  "version": "5.0.0"
+ }
+ '''  # END VERSION_JSON
+
+
+ def get_versions():
+     return json.loads(version_json)
ailia-models/code/pyannote_audio_utils/core/annotation.py ADDED
@@ -0,0 +1,1551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2014-2021 CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+ # Paul LERNER
29
+
30
+ """
31
+ ##########
32
+ Annotation
33
+ ##########
34
+
35
+ .. plot:: pyplots/annotation.py
36
+
37
+ :class:`pyannote.core.Annotation` instances are ordered sets of non-empty
38
+ tracks:
39
+
40
+ - ordered, because segments are sorted by start time (and end time in case of tie)
41
+ - set, because one cannot add twice the same track
42
+ - non-empty, because one cannot add empty track
43
+
44
+ A track is a (support, name) pair where `support` is a Segment instance,
45
+ and `name` is an additional identifier so that it is possible to add multiple
46
+ tracks with the same support.
47
+
48
+ To define the annotation depicted above:
49
+
50
+ .. code-block:: ipython
51
+
52
+ In [1]: from pyannote.core import Annotation, Segment
53
+
54
+ In [6]: annotation = Annotation()
55
+ ...: annotation[Segment(1, 5)] = 'Carol'
56
+ ...: annotation[Segment(6, 8)] = 'Bob'
57
+ ...: annotation[Segment(12, 18)] = 'Carol'
58
+ ...: annotation[Segment(7, 20)] = 'Alice'
59
+ ...:
60
+
61
+ which is actually a shortcut for
62
+
63
+ .. code-block:: ipython
64
+
65
+ In [6]: annotation = Annotation()
66
+ ...: annotation[Segment(1, 5), '_'] = 'Carol'
67
+ ...: annotation[Segment(6, 8), '_'] = 'Bob'
68
+ ...: annotation[Segment(12, 18), '_'] = 'Carol'
69
+ ...: annotation[Segment(7, 20), '_'] = 'Alice'
70
+ ...:
71
+
72
+ where all tracks share the same (default) name ``'_'``.
73
+
74
+ In case two tracks share the same support, use a different track name:
75
+
76
+ .. code-block:: ipython
77
+
78
+ In [6]: annotation = Annotation(uri='my_video_file', modality='speaker')
79
+ ...: annotation[Segment(1, 5), 1] = 'Carol' # track name = 1
80
+ ...: annotation[Segment(1, 5), 2] = 'Bob' # track name = 2
81
+ ...: annotation[Segment(12, 18)] = 'Carol'
82
+ ...:
83
+
84
+ The track name does not have to be unique over the whole set of tracks.
85
+
86
+ .. note::
87
+
88
+ The optional *uri* and *modality* keywords argument can be used to remember
89
+ which document and modality (e.g. speaker or face) it describes.
90
+
91
+ Several convenient methods are available. Here are a few examples:
92
+
93
+ .. code-block:: ipython
94
+
95
+ In [9]: annotation.labels() # sorted list of labels
96
+ Out[9]: ['Bob', 'Carol']
97
+
98
+ In [10]: annotation.chart() # label duration chart
99
+ Out[10]: [('Carol', 10), ('Bob', 4)]
100
+
101
+ In [11]: list(annotation.itertracks())
102
+ Out[11]: [(<Segment(1, 5)>, 1), (<Segment(1, 5)>, 2), (<Segment(12, 18)>, '_')]
103
+
104
+ In [12]: annotation.label_timeline('Carol')
105
+ Out[12]: <Timeline(uri=my_video_file, segments=[<Segment(1, 5)>, <Segment(12, 18)>])>
106
+
107
+ See :class:`pyannote.core.Annotation` for the complete reference.
108
+ """
109
+ import itertools
110
+ import warnings
111
+ from collections import defaultdict
112
+ from typing import (
113
+ Hashable,
114
+ Optional,
115
+ Dict,
116
+ Union,
117
+ Iterable,
118
+ List,
119
+ Set,
120
+ TextIO,
121
+ Tuple,
122
+ Iterator,
123
+ Text,
124
+ TYPE_CHECKING,
125
+ )
126
+
127
+ import numpy as np
128
+ from sortedcontainers import SortedDict
129
+
130
+ from . import (
131
+ PYANNOTE_SEGMENT,
132
+ PYANNOTE_TRACK,
133
+ PYANNOTE_LABEL,
134
+ )
135
+ from .segment import Segment, SlidingWindow
136
+ from .timeline import Timeline
137
+ from .feature import SlidingWindowFeature
138
+ from .utils.generators import string_generator, int_generator
139
+ from .utils.types import Label, Key, Support, LabelGenerator, TrackName, CropMode
140
+
141
+ if TYPE_CHECKING:
142
+ import pandas as pd
143
+
144
+
145
+ class Annotation:
146
+ """Annotation
147
+
148
+ Parameters
149
+ ----------
150
+ uri : string, optional
151
+ name of annotated resource (e.g. audio or video file)
152
+ modality : string, optional
153
+ name of annotated modality
154
+
155
+ Returns
156
+ -------
157
+ annotation : Annotation
158
+ New annotation
159
+
160
+ """
161
+
162
+ @classmethod
163
+ def from_df(
164
+ cls,
165
+ df: "pd.DataFrame",
166
+ uri: Optional[str] = None,
167
+ modality: Optional[str] = None,
168
+ ) -> "Annotation":
169
+
170
+ df = df[[PYANNOTE_SEGMENT, PYANNOTE_TRACK, PYANNOTE_LABEL]]
171
+ return Annotation.from_records(df.itertuples(index=False), uri, modality)
172
+
173
+ def __init__(self, uri: Optional[str] = None, modality: Optional[str] = None):
174
+
175
+ self._uri: Optional[str] = uri
176
+ self.modality: Optional[str] = modality
177
+
178
+ # sorted dictionary
179
+ # keys: annotated segments
180
+ # values: {track: label} dictionary
181
+ self._tracks: Dict[Segment, Dict[TrackName, Label]] = SortedDict()
182
+
183
+ # dictionary
184
+ # key: label
185
+ # value: timeline
186
+ self._labels: Dict[Label, Timeline] = {}
187
+ self._labelNeedsUpdate: Dict[Label, bool] = {}
188
+
189
+ # timeline meant to store all annotated segments
190
+ self._timeline: Timeline = None
191
+ self._timelineNeedsUpdate: bool = True
192
+
193
+ @property
194
+ def uri(self):
195
+ return self._uri
196
+
197
+ @uri.setter
198
+ def uri(self, uri: str):
199
+ # update uri for all internal timelines
200
+ for label in self.labels():
201
+ timeline = self.label_timeline(label, copy=False)
202
+ timeline.uri = uri
203
+ timeline = self.get_timeline(copy=False)
204
+ timeline.uri = uri
205
+ self._uri = uri
206
+
207
+ def _updateLabels(self):
208
+
209
+ # set of labels that need to be updated
210
+ update = set(
211
+ label for label, update in self._labelNeedsUpdate.items() if update
212
+ )
213
+
214
+ # accumulate segments for updated labels
215
+ _segments = {label: [] for label in update}
216
+ for segment, track, label in self.itertracks(yield_label=True):
217
+ if label in update:
218
+ _segments[label].append(segment)
219
+
220
+ # create timeline with accumulated segments for updated labels
221
+ for label in update:
222
+ if _segments[label]:
223
+ self._labels[label] = Timeline(segments=_segments[label], uri=self.uri)
224
+ self._labelNeedsUpdate[label] = False
225
+ else:
226
+ self._labels.pop(label, None)
227
+ self._labelNeedsUpdate.pop(label, None)
228
+
229
+ def __len__(self):
230
+ """Number of segments
231
+
232
+ >>> len(annotation) # annotation contains three segments
233
+ 3
234
+ """
235
+ return len(self._tracks)
236
+
237
+ def __nonzero__(self):
238
+ return self.__bool__()
239
+
240
+ def __bool__(self):
241
+ """Emptiness
242
+
243
+ >>> if annotation:
244
+ ... # annotation is not empty
245
+ ... else:
246
+ ... # annotation is empty
247
+ """
248
+ return len(self._tracks) > 0
249
+
250
+ def itersegments(self):
251
+ """Iterate over segments (in chronological order)
252
+
253
+ >>> for segment in annotation.itersegments():
254
+ ... # do something with the segment
255
+
256
+ See also
257
+ --------
258
+ :class:`pyannote.core.Segment` describes how segments are sorted.
259
+ """
260
+ return iter(self._tracks)
261
+
262
+ def itertracks(
263
+ self, yield_label: bool = False
264
+ ) -> Iterator[Union[Tuple[Segment, TrackName], Tuple[Segment, TrackName, Label]]]:
265
+ """Iterate over tracks (in chronological order)
266
+
267
+ Parameters
268
+ ----------
269
+ yield_label : bool, optional
270
+ When True, yield (segment, track, label) tuples, such that
271
+ annotation[segment, track] == label. Defaults to yielding
272
+ (segment, track) tuple.
273
+
274
+ Examples
275
+ --------
276
+
277
+ >>> for segment, track in annotation.itertracks():
278
+ ... # do something with the track
279
+
280
+ >>> for segment, track, label in annotation.itertracks(yield_label=True):
281
+ ... # do something with the track and its label
282
+ """
283
+
284
+ for segment, tracks in self._tracks.items():
285
+ for track, lbl in sorted(
286
+ tracks.items(), key=lambda tl: (str(tl[0]), str(tl[1]))
287
+ ):
288
+ if yield_label:
289
+ yield segment, track, lbl
290
+ else:
291
+ yield segment, track
292
+
293
+ def _updateTimeline(self):
294
+ self._timeline = Timeline(segments=self._tracks, uri=self.uri)
295
+ self._timelineNeedsUpdate = False
296
+
297
+ def get_timeline(self, copy: bool = True) -> Timeline:
298
+ """Get timeline made of all annotated segments
299
+
300
+ Parameters
301
+ ----------
302
+ copy : bool, optional
303
+ Defaults (True) to returning a copy of the internal timeline.
304
+ Set to False to return the actual internal timeline (faster).
305
+
306
+ Returns
307
+ -------
308
+ timeline : Timeline
309
+ Timeline made of all annotated segments.
310
+
311
+ Note
312
+ ----
313
+ In case copy is set to False, be careful **not** to modify the returned
314
+ timeline, as it may lead to weird subsequent behavior of the annotation
315
+ instance.
316
+
317
+ """
318
+ if self._timelineNeedsUpdate:
319
+ self._updateTimeline()
320
+ if copy:
321
+ return self._timeline.copy()
322
+ return self._timeline
323
+
324
+ def __eq__(self, other: "Annotation"):
325
+ """Equality
326
+
327
+ >>> annotation == other
328
+
329
+ Two annotations are equal if and only if their tracks and associated
330
+ labels are equal.
331
+ """
332
+ pairOfTracks = itertools.zip_longest(
333
+ self.itertracks(yield_label=True), other.itertracks(yield_label=True)
334
+ )
335
+ return all(t1 == t2 for t1, t2 in pairOfTracks)
336
+
337
+ def __ne__(self, other: "Annotation"):
338
+ """Inequality"""
339
+ pairOfTracks = itertools.zip_longest(
340
+ self.itertracks(yield_label=True), other.itertracks(yield_label=True)
341
+ )
342
+
343
+ return any(t1 != t2 for t1, t2 in pairOfTracks)
344
+
345
+ def __contains__(self, included: Union[Segment, Timeline]):
346
+ """Inclusion
347
+
348
+ Check whether every segment of `included` exists in the annotation.
349
+
350
+ Parameters
351
+ ----------
352
+ included : Segment or Timeline
353
+ Segment or timeline being checked for inclusion
354
+
355
+ Returns
356
+ -------
357
+ contains : bool
358
+ True if every segment in `included` exists in timeline,
359
+ False otherwise
360
+
361
+ """
362
+ return included in self.get_timeline(copy=False)
363
+
364
+ def _iter_rttm(self) -> Iterator[Text]:
365
+ """Generate lines for an RTTM file for this annotation
366
+
367
+ Returns
368
+ -------
369
+ iterator: Iterator[str]
370
+ An iterator over RTTM text lines
371
+ """
372
+ uri = self.uri if self.uri else "<NA>"
373
+ if isinstance(uri, Text) and " " in uri:
374
+ msg = (
375
+ f"Space-separated RTTM file format does not allow file URIs "
376
+ f'containing spaces (got: "{uri}").'
377
+ )
378
+ raise ValueError(msg)
379
+ for segment, _, label in self.itertracks(yield_label=True):
380
+ if isinstance(label, Text) and " " in label:
381
+ msg = (
382
+ f"Space-separated RTTM file format does not allow labels "
383
+ f'containing spaces (got: "{label}").'
384
+ )
385
+ raise ValueError(msg)
386
+ yield (
387
+ f"SPEAKER {uri} 1 {segment.start:.3f} {segment.duration:.3f} "
388
+ f"<NA> <NA> {label} <NA> <NA>\n"
389
+ )
390
+
391
+ def to_rttm(self) -> Text:
392
+ """Serialize annotation as a string using RTTM format
393
+
394
+ Returns
395
+ -------
396
+ serialized: str
397
+ RTTM string
398
+ """
399
+ return "".join([line for line in self._iter_rttm()])
400
+
401
+ def write_rttm(self, file: TextIO):
402
+ """Dump annotation to file using RTTM format
403
+
404
+ Parameters
405
+ ----------
406
+ file : file object
407
+
408
+ Usage
409
+ -----
410
+ >>> with open('file.rttm', 'w') as file:
411
+ ... annotation.write_rttm(file)
412
+ """
413
+ for line in self._iter_rttm():
414
+ file.write(line)
415
+
416
+ def _iter_lab(self) -> Iterator[Text]:
417
+ """Generate lines for a LAB file for this annotation
418
+
419
+ Returns
420
+ -------
421
+ iterator: Iterator[str]
422
+ An iterator over LAB text lines
423
+ """
424
+ for segment, _, label in self.itertracks(yield_label=True):
425
+ if isinstance(label, Text) and " " in label:
426
+ msg = (
427
+ f"Space-separated LAB file format does not allow labels "
428
+ f'containing spaces (got: "{label}").'
429
+ )
430
+ raise ValueError(msg)
431
+ yield f"{segment.start:.3f} {segment.start + segment.duration:.3f} {label}\n"
432
+
433
+ def to_lab(self) -> Text:
434
+ """Serialize annotation as a string using LAB format
435
+
436
+ Returns
437
+ -------
438
+ serialized: str
439
+ LAB string
440
+ """
441
+ return "".join([line for line in self._iter_lab()])
442
+
443
+ def write_lab(self, file: TextIO):
444
+ """Dump annotation to file using LAB format
445
+
446
+ Parameters
447
+ ----------
448
+ file : file object
449
+
450
+ Usage
451
+ -----
452
+ >>> with open('file.lab', 'w') as file:
453
+ ... annotation.write_lab(file)
454
+ """
455
+ for line in self._iter_lab():
456
+ file.write(line)
457
+
458
+ def crop(self, support: Support, mode: CropMode = "intersection") -> "Annotation":
459
+ """Crop annotation to new support
460
+
461
+ Parameters
462
+ ----------
463
+ support : Segment or Timeline
464
+ If `support` is a `Timeline`, its support is used.
465
+ mode : {'strict', 'loose', 'intersection'}, optional
466
+ Controls how segments that are not fully included in `support` are
467
+ handled. 'strict' mode only keeps fully included segments. 'loose'
468
+ mode keeps any intersecting segment. 'intersection' mode keeps any
469
+ intersecting segment but replaces it with its actual intersection.
470
+
471
+ Returns
472
+ -------
473
+ cropped : Annotation
474
+ Cropped annotation
475
+
476
+ Note
477
+ ----
478
+ In 'intersection' mode, a best effort is made to keep track names
479
+ unchanged. However, when two original segments are
480
+ cropped into the same resulting segment, conflicting track names are
481
+ modified to make sure no track is lost.
482
+
483
+ """
484
+
485
+ # TODO speed things up by working directly with annotation internals
486
+
487
+ if isinstance(support, Segment):
488
+ support = Timeline(segments=[support], uri=self.uri)
489
+ return self.crop(support, mode=mode)
490
+
491
+ elif isinstance(support, Timeline):
492
+
493
+ # if 'support' is a `Timeline`, we use its support
494
+ support = support.support()
495
+ cropped = self.__class__(uri=self.uri, modality=self.modality)
496
+
497
+ if mode == "loose":
498
+
499
+ _tracks = {}
500
+ _labels = set([])
501
+
502
+ for segment, _ in self.get_timeline(copy=False).co_iter(support):
503
+ tracks = dict(self._tracks[segment])
504
+ _tracks[segment] = tracks
505
+ _labels.update(tracks.values())
506
+
507
+ cropped._tracks = SortedDict(_tracks)
508
+
509
+ cropped._labelNeedsUpdate = {label: True for label in _labels}
510
+ cropped._labels = {label: None for label in _labels}
511
+
512
+ cropped._timelineNeedsUpdate = True
513
+ cropped._timeline = None
514
+
515
+ return cropped
516
+
517
+ elif mode == "strict":
518
+
519
+ _tracks = {}
520
+ _labels = set([])
521
+
522
+ for segment, other_segment in self.get_timeline(copy=False).co_iter(
523
+ support
524
+ ):
525
+
526
+ if segment not in other_segment:
527
+ continue
528
+
529
+ tracks = dict(self._tracks[segment])
530
+ _tracks[segment] = tracks
531
+ _labels.update(tracks.values())
532
+
533
+ cropped._tracks = SortedDict(_tracks)
534
+
535
+ cropped._labelNeedsUpdate = {label: True for label in _labels}
536
+ cropped._labels = {label: None for label in _labels}
537
+
538
+ cropped._timelineNeedsUpdate = True
539
+ cropped._timeline = None
540
+
541
+ return cropped
542
+
543
+ elif mode == "intersection":
544
+
545
+ for segment, other_segment in self.get_timeline(copy=False).co_iter(
546
+ support
547
+ ):
548
+
549
+ intersection = segment & other_segment
550
+ for track, label in self._tracks[segment].items():
551
+ track = cropped.new_track(intersection, candidate=track)
552
+ cropped[intersection, track] = label
553
+
554
+ return cropped
555
+
556
+ else:
557
+ raise NotImplementedError("unsupported mode: '%s'" % mode)
558
+
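+ # --- Illustrative sketch (not part of the original file) ---
+ # How the three crop modes treat a segment that sticks out of the support;
+ # the toy values below are hypothetical.
+ #
+ #     ann = Annotation()
+ #     ann[Segment(0, 10)] = 'A'
+ #     ann.crop(Segment(5, 20), mode='strict')        # empty: [0, 10] is not fully included
+ #     ann.crop(Segment(5, 20), mode='loose')         # keeps [0, 10] unchanged
+ #     ann.crop(Segment(5, 20), mode='intersection')  # keeps only the [5, 10] intersection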
559
+ def extrude(
560
+ self, removed: Support, mode: CropMode = "intersection"
561
+ ) -> "Annotation":
562
+ """Remove segments that overlap `removed` support.
563
+
564
+ A simple illustration:
565
+
566
+ annotation
567
+ A |------| |------|
568
+ B |----------|
569
+ C |--------------| |------|
570
+
571
+ removed `Timeline`
572
+ |-------| |-----------|
573
+
574
+ extruded Annotation with mode="intersection"
575
+ B |---|
576
+ C |--| |------|
577
+
578
+ extruded Annotation with mode="loose"
579
+ C |------|
580
+
581
+ extruded Annotation with mode="strict"
582
+ A |------|
583
+ B |----------|
584
+ C |--------------| |------|
585
+
586
+ Parameters
587
+ ----------
588
+ removed : Segment or Timeline
589
+ If `removed` is a `Timeline`, its support is used.
590
+ mode : {'strict', 'loose', 'intersection'}, optional
591
+ Controls how segments that are not fully included in `removed` are
592
+ handled. 'strict' mode only removes fully included segments. 'loose'
593
+ mode removes any intersecting segment. 'intersection' mode removes
594
+ the overlapping part of any intersecting segment.
595
+
596
+ Returns
597
+ -------
598
+ extruded : Annotation
599
+ Extruded annotation
600
+
601
+ Note
602
+ ----
603
+ In 'intersection' mode, a best effort is made to keep track names
604
+ unchanged. However, when two original segments are
605
+ cropped into the same resulting segment, conflicting track names are
606
+ modified to make sure no track is lost.
607
+
608
+ """
609
+ if isinstance(removed, Segment):
610
+ removed = Timeline([removed])
611
+
612
+ extent_tl = Timeline([self.get_timeline().extent()], uri=self.uri)
613
+ truncating_support = removed.gaps(support=extent_tl)
614
+ # loose for truncate means strict for crop and vice-versa
615
+ if mode == "loose":
616
+ mode = "strict"
617
+ elif mode == "strict":
618
+ mode = "loose"
619
+ return self.crop(truncating_support, mode=mode)
620
+
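+ # --- Illustrative sketch (not part of the original file) ---
+ # extrude() removes what crop() would keep, hence the loose/strict swap
+ # above. With a hypothetical annotation covering [0, 10]:
+ #
+ #     ann.extrude(Segment(5, 20), mode='loose')         # removes [0, 10] entirely
+ #     ann.extrude(Segment(5, 20), mode='strict')        # keeps [0, 10] (not fully inside `removed`)
+ #     ann.extrude(Segment(5, 20), mode='intersection')  # keeps the remaining [0, 5] part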
621
+ def get_overlap(self, labels: Optional[Iterable[Label]] = None) -> "Timeline":
622
+ """Get overlapping parts of the annotation.
623
+
624
+ A simple illustration:
625
+
626
+ annotation
627
+ A |------| |------| |----|
628
+ B |--| |-----| |----------|
629
+ C |--------------| |------|
630
+
631
+ annotation.get_overlap()
632
+ |------| |-----| |--------|
633
+
634
+ annotation.get_overlap(labels=["A", "B"])
635
+ |--| |--| |----|
636
+
637
+ Parameters
638
+ ----------
639
+ labels : optional list of labels
640
+ Labels for which to consider the overlap
641
+
642
+ Returns
643
+ -------
644
+ overlap : `pyannote.core.Timeline`
645
+ Timeline of the overlaps.
646
+ """
647
+ if labels:
648
+ annotation = self.subset(labels)
649
+ else:
650
+ annotation = self
651
+
652
+ overlaps_tl = Timeline(uri=annotation.uri)
653
+ for (s1, t1), (s2, t2) in annotation.co_iter(annotation):
654
+ # skip when the two tracks share the same label
655
+ if self[s1, t1] == self[s2, t2]:
656
+ continue
657
+ overlaps_tl.add(s1 & s2)
658
+ return overlaps_tl.support()
659
+
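+ # --- Illustrative sketch (not part of the original file) ---
+ # Overlapped-speech regions of a hypothetical two-speaker annotation:
+ #
+ #     ann = Annotation()
+ #     ann[Segment(0, 10), 'a'] = 'Alice'
+ #     ann[Segment(5, 15), 'b'] = 'Bob'
+ #     ann.get_overlap()   # Timeline with the single segment [5, 10]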
660
+ def get_tracks(self, segment: Segment) -> Set[TrackName]:
661
+ """Query tracks by segment
662
+
663
+ Parameters
664
+ ----------
665
+ segment : Segment
666
+ Query
667
+
668
+ Returns
669
+ -------
670
+ tracks : set
671
+ Set of tracks
672
+
673
+ Note
674
+ ----
675
+ This will return an empty set if segment does not exist.
676
+ """
677
+ return set(self._tracks.get(segment, {}).keys())
678
+
679
+ def has_track(self, segment: Segment, track: TrackName) -> bool:
680
+ """Check whether a given track exists
681
+
682
+ Parameters
683
+ ----------
684
+ segment : Segment
685
+ Query segment
686
+ track :
687
+ Query track
688
+
689
+ Returns
690
+ -------
691
+ exists : bool
692
+ True if track exists for segment
693
+ """
694
+ return track in self._tracks.get(segment, {})
695
+
696
+ def copy(self) -> "Annotation":
697
+ """Get a copy of the annotation
698
+
699
+ Returns
700
+ -------
701
+ annotation : Annotation
702
+ Copy of the annotation
703
+ """
704
+
705
+ # create new empty annotation
706
+ copied = self.__class__(uri=self.uri, modality=self.modality)
707
+
708
+ # deep copy internal track dictionary
709
+ _tracks, _labels = [], set([])
710
+ for key, value in self._tracks.items():
711
+ _labels.update(value.values())
712
+ _tracks.append((key, dict(value)))
713
+
714
+ copied._tracks = SortedDict(_tracks)
715
+
716
+ copied._labels = {label: None for label in _labels}
717
+ copied._labelNeedsUpdate = {label: True for label in _labels}
718
+
719
+ copied._timeline = None
720
+ copied._timelineNeedsUpdate = True
721
+
722
+ return copied
723
+
724
+ def new_track(
725
+ self,
726
+ segment: Segment,
727
+ candidate: Optional[TrackName] = None,
728
+ prefix: Optional[str] = None,
729
+ ) -> TrackName:
730
+ """Generate a new track name for given segment
731
+
732
+ Ensures that the returned track name does not already
733
+ exist for the given segment.
734
+
735
+ Parameters
736
+ ----------
737
+ segment : Segment
738
+ Segment for which a new track name is generated.
739
+ candidate : any valid track name, optional
740
+ When provided, try this candidate name first.
741
+ prefix : str, optional
742
+ Track name prefix. Defaults to the empty string ''.
743
+
744
+ Returns
745
+ -------
746
+ name : str
747
+ New track name
748
+ """
749
+
750
+ # obtain list of existing tracks for segment
751
+ existing_tracks = set(self._tracks.get(segment, {}))
752
+
753
+ # if candidate is provided, check whether it already exists
754
+ # in case it does not, use it
755
+ if (candidate is not None) and (candidate not in existing_tracks):
756
+ return candidate
757
+
758
+ # no candidate was provided or the provided candidate already exists
759
+ # we need to create a brand new one
760
+
761
+ # by default (if prefix is not provided), use ''
762
+ if prefix is None:
763
+ prefix = ""
764
+
765
+ # find first non-existing track name for segment
766
+ # eg. if '0' exists, try '1', then '2', ...
767
+ count = 0
768
+ while ("%s%d" % (prefix, count)) in existing_tracks:
769
+ count += 1
770
+
771
+ # return first non-existing track name
772
+ return "%s%d" % (prefix, count)
773
+
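+ # --- Illustrative sketch (not part of the original file) ---
+ # Assuming a segment whose existing track names are '0' and '1':
+ #
+ #     ann.new_track(segment)                   # returns '2'
+ #     ann.new_track(segment, candidate='xyz')  # returns 'xyz' (not taken yet)
+ #     ann.new_track(segment, prefix='spk')     # returns 'spk0'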
774
+ def __str__(self):
775
+ """Human-friendly representation"""
776
+ # TODO: use pandas.DataFrame
777
+ return "\n".join(
778
+ ["%s %s %s" % (s, t, l) for s, t, l in self.itertracks(yield_label=True)]
779
+ )
780
+
781
+ def __delitem__(self, key: Key):
782
+ """Delete one track
783
+
784
+ >>> del annotation[segment, track]
785
+
786
+ Delete all tracks of a segment
787
+
788
+ >>> del annotation[segment]
789
+ """
790
+
791
+ # del annotation[segment]
792
+ if isinstance(key, Segment):
793
+
794
+ # Pop segment out of dictionary
795
+ # and get corresponding tracks
796
+ # Raises KeyError if segment does not exist
797
+ tracks = self._tracks.pop(key)
798
+
799
+ # mark timeline as modified
800
+ self._timelineNeedsUpdate = True
801
+
802
+ # mark every label in tracks as modified
803
+ for track, label in tracks.items():
804
+ self._labelNeedsUpdate[label] = True
805
+
806
+ # del annotation[segment, track]
807
+ elif isinstance(key, tuple) and len(key) == 2:
808
+
809
+ # get segment tracks as dictionary
810
+ # if segment does not exist, get empty dictionary
811
+ # Raises KeyError if segment does not exist
812
+ tracks = self._tracks[key[0]]
813
+
814
+ # pop track out of tracks dictionary
815
+ # and get corresponding label
816
+ # Raises KeyError if track does not exist
817
+ label = tracks.pop(key[1])
818
+
819
+ # mark label as modified
820
+ self._labelNeedsUpdate[label] = True
821
+
822
+ # if tracks dictionary is now empty,
823
+ # remove segment as well
824
+ if not tracks:
825
+ self._tracks.pop(key[0])
826
+ self._timelineNeedsUpdate = True
827
+
828
+ else:
829
+ raise NotImplementedError(
830
+ "Deletion only works with Segment or (Segment, track) keys."
831
+ )
832
+
833
+ # label = annotation[segment, track]
834
+ def __getitem__(self, key: Key) -> Label:
835
+ """Get track label
836
+
837
+ >>> label = annotation[segment, track]
838
+
839
+ Note
840
+ ----
841
+ ``annotation[segment]`` is equivalent to ``annotation[segment, '_']``
842
+
843
+ """
844
+
845
+ if isinstance(key, Segment):
846
+ key = (key, "_")
847
+
848
+ return self._tracks[key[0]][key[1]]
849
+
850
+ # annotation[segment, track] = label
851
+ def __setitem__(self, key: Key, label: Label):
852
+ """Add new or update existing track
853
+
854
+ >>> annotation[segment, track] = label
855
+
856
+ If (segment, track) does not exist, it is added.
857
+ If (segment, track) already exists, it is updated.
858
+
859
+ Note
860
+ ----
861
+ ``annotation[segment] = label`` is equivalent to ``annotation[segment, '_'] = label``
862
+
863
+ Note
864
+ ----
865
+ If `segment` is empty, it does nothing.
866
+ """
867
+
868
+ if isinstance(key, Segment):
869
+ key = (key, "_")
870
+
871
+ segment, track = key
872
+
873
+ # do not add empty track
874
+ if not segment:
875
+ return
876
+
877
+ # in case we create a new segment
878
+ # mark timeline as modified
879
+ if segment not in self._tracks:
880
+ self._tracks[segment] = {}
881
+ self._timelineNeedsUpdate = True
882
+
883
+ # in case we modify an existing track
884
+ # mark old label as modified
885
+ if track in self._tracks[segment]:
886
+ old_label = self._tracks[segment][track]
887
+ self._labelNeedsUpdate[old_label] = True
888
+
889
+ # mark new label as modified
890
+ self._tracks[segment][track] = label
891
+ self._labelNeedsUpdate[label] = True
892
+
893
+ def empty(self) -> "Annotation":
894
+ """Return an empty copy
895
+
896
+ Returns
897
+ -------
898
+ empty : Annotation
899
+ Empty annotation using the same 'uri' and 'modality' attributes.
900
+
901
+ """
902
+ return self.__class__(uri=self.uri, modality=self.modality)
903
+
904
+ def labels(self) -> List[Label]:
905
+ """Get sorted list of labels
906
+
907
+ Returns
908
+ -------
909
+ labels : list
910
+ Sorted list of labels
911
+ """
912
+ if any(self._labelNeedsUpdate.values()):
913
+ self._updateLabels()
914
+ return sorted(self._labels, key=str)
915
+
916
+ def get_labels(
917
+ self, segment: Segment, unique: bool = True
918
+ ) -> Union[Set[Label], List[Label]]:
919
+ """Query labels by segment
920
+
921
+ Parameters
922
+ ----------
923
+ segment : Segment
924
+ Query
925
+ unique : bool, optional
926
+ When False, return the list of (possibly repeated) labels.
927
+ Defaults to returning the set of labels.
928
+
929
+ Returns
930
+ -------
931
+ labels : set or list
932
+ Set of labels for `segment` if it exists (empty set otherwise) when
933
+ `unique` is True; list of (possibly repeated) labels otherwise.
934
+
935
+ Examples
936
+ --------
937
+ >>> annotation = Annotation()
938
+ >>> segment = Segment(0, 2)
939
+ >>> annotation[segment, 'speaker1'] = 'Bernard'
940
+ >>> annotation[segment, 'speaker2'] = 'John'
941
+ >>> sorted(annotation.get_labels(segment))
942
+ ['Bernard', 'John']
943
+ >>> annotation.get_labels(Segment(1, 2))
944
+ set()
945
+
946
+ """
947
+
948
+ labels = self._tracks.get(segment, {}).values()
949
+
950
+ if unique:
951
+ return set(labels)
952
+
953
+ return list(labels)
954
+
955
+ def subset(self, labels: Iterable[Label], invert: bool = False) -> "Annotation":
956
+ """Filter annotation by labels
957
+
958
+ Parameters
959
+ ----------
960
+ labels : iterable
961
+ List of filtered labels
962
+ invert : bool, optional
963
+ If invert is True, extract all but requested labels
964
+
965
+ Returns
966
+ -------
967
+ filtered : Annotation
968
+ Filtered annotation
969
+ """
970
+
971
+ labels = set(labels)
972
+
973
+ if invert:
974
+ labels = set(self.labels()) - labels
975
+ else:
976
+ labels = labels & set(self.labels())
977
+
978
+ sub = self.__class__(uri=self.uri, modality=self.modality)
979
+
980
+ _tracks, _labels = {}, set([])
981
+ for segment, tracks in self._tracks.items():
982
+ sub_tracks = {
983
+ track: label for track, label in tracks.items() if label in labels
984
+ }
985
+ if sub_tracks:
986
+ _tracks[segment] = sub_tracks
987
+ _labels.update(sub_tracks.values())
988
+
989
+ sub._tracks = SortedDict(_tracks)
990
+
991
+ sub._labelNeedsUpdate = {label: True for label in _labels}
992
+ sub._labels = {label: None for label in _labels}
993
+
994
+ sub._timelineNeedsUpdate = True
995
+ sub._timeline = None
996
+
997
+ return sub
998
+
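+ # --- Illustrative sketch (not part of the original file) ---
+ # Keeping or dropping labels, assuming a hypothetical annotation `ann`:
+ #
+ #     ann.subset(['Alice'])               # only tracks labeled 'Alice'
+ #     ann.subset(['Alice'], invert=True)  # every track except 'Alice'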
999
+ def update(self, annotation: "Annotation", copy: bool = False) -> "Annotation":
1000
+ """Add every track of an existing annotation (in place)
1001
+
1002
+ Parameters
1003
+ ----------
1004
+ annotation : Annotation
1005
+ Annotation whose tracks are being added
1006
+ copy : bool, optional
1007
+ Return a copy of the annotation. Defaults to updating the
1008
+ annotation in-place.
1009
+
1010
+ Returns
1011
+ -------
1012
+ self : Annotation
1013
+ Updated annotation
1014
+
1015
+ Note
1016
+ ----
1017
+ Existing tracks are updated with the new label.
1018
+ """
1019
+
1020
+ result = self.copy() if copy else self
1021
+
1022
+ # TODO speed things up by working directly with annotation internals
1023
+ for segment, track, label in annotation.itertracks(yield_label=True):
1024
+ result[segment, track] = label
1025
+
1026
+ return result
1027
+
1028
+ def label_timeline(self, label: Label, copy: bool = True) -> Timeline:
1029
+ """Query segments by label
1030
+
1031
+ Parameters
1032
+ ----------
1033
+ label : object
1034
+ Query
1035
+ copy : bool, optional
1036
+ Defaults (True) to returning a copy of the internal timeline.
1037
+ Set to False to return the actual internal timeline (faster).
1038
+
1039
+ Returns
1040
+ -------
1041
+ timeline : Timeline
1042
+ Timeline made of all segments for which at least one track is
1043
+ annotated as label
1044
+
1045
+ Note
1046
+ ----
1047
+ If label does not exist, this will return an empty timeline.
1048
+
1049
+ Note
1050
+ ----
1051
+ In case copy is set to False, be careful **not** to modify the returned
1052
+ timeline, as it may lead to weird subsequent behavior of the annotation
1053
+ instance.
1054
+
1055
+ """
1056
+ if label not in self.labels():
1057
+ return Timeline(uri=self.uri)
1058
+
1059
+ if self._labelNeedsUpdate[label]:
1060
+ self._updateLabels()
1061
+
1062
+ if copy:
1063
+ return self._labels[label].copy()
1064
+
1065
+ return self._labels[label]
1066
+
1067
+ def label_support(self, label: Label) -> Timeline:
1068
+ """Label support
1069
+
1070
+ Equivalent to ``Annotation.label_timeline(label).support()``
1071
+
1072
+ Parameters
1073
+ ----------
1074
+ label : object
1075
+ Query
1076
+
1077
+ Returns
1078
+ -------
1079
+ support : Timeline
1080
+ Label support
1081
+
1082
+ See also
1083
+ --------
1084
+ :func:`~pyannote.core.Annotation.label_timeline`
1085
+ :func:`~pyannote.core.Timeline.support`
1086
+
1087
+ """
1088
+ return self.label_timeline(label, copy=False).support()
1089
+
1090
+ def label_duration(self, label: Label) -> float:
1091
+ """Label duration
1092
+
1093
+ Equivalent to ``Annotation.label_timeline(label).duration()``
1094
+
1095
+ Parameters
1096
+ ----------
1097
+ label : object
1098
+ Query
1099
+
1100
+ Returns
1101
+ -------
1102
+ duration : float
1103
+ Duration, in seconds.
1104
+
1105
+ See also
1106
+ --------
1107
+ :func:`~pyannote.core.Annotation.label_timeline`
1108
+ :func:`~pyannote.core.Timeline.duration`
1109
+
1110
+ """
1111
+
1112
+ return self.label_timeline(label, copy=False).duration()
1113
+
1114
+ def chart(self, percent: bool = False) -> List[Tuple[Label, float]]:
1115
+ """Get labels chart (from longest to shortest duration)
1116
+
1117
+ Parameters
1118
+ ----------
1119
+ percent : bool, optional
1120
+ Return list of (label, percentage) tuples.
1121
+ Defaults to returning list of (label, duration) tuples.
1122
+
1123
+ Returns
1124
+ -------
1125
+ chart : list
1126
+ List of (label, duration), sorted by duration in decreasing order.
1127
+ """
1128
+
1129
+ chart = sorted(
1130
+ ((L, self.label_duration(L)) for L in self.labels()),
1131
+ key=lambda x: x[1],
1132
+ reverse=True,
1133
+ )
1134
+
1135
+ if percent:
1136
+ total = np.sum([duration for _, duration in chart])
1137
+ chart = [(label, duration / total) for (label, duration) in chart]
1138
+
1139
+ return chart
1140
+
1141
+ def argmax(self, support: Optional[Support] = None) -> Optional[Label]:
1142
+ """Get label with longest duration
1143
+
1144
+ Parameters
1145
+ ----------
1146
+ support : Segment or Timeline, optional
1147
+ Find label with longest duration within provided support.
1148
+ Defaults to whole extent.
1149
+
1150
+ Returns
1151
+ -------
1152
+ label : any existing label or None
1153
+ Label with longest intersection
1154
+
1155
+ Examples
1156
+ --------
1157
+ >>> annotation = Annotation(modality='speaker')
1158
+ >>> annotation[Segment(0, 10), 'speaker1'] = 'Alice'
1159
+ >>> annotation[Segment(8, 20), 'speaker1'] = 'Bob'
1160
+ >>> print "%s is such a talker!" % annotation.argmax()
1161
+ Bob is such a talker!
1162
+ >>> segment = Segment(22, 23)
1163
+ >>> if not annotation.argmax(support):
1164
+ ... print "No label intersecting %s" % segment
1165
+ No label intersection [22 --> 23]
1166
+
1167
+ """
1168
+
1169
+ cropped = self
1170
+ if support is not None:
1171
+ cropped = cropped.crop(support, mode="intersection")
1172
+
1173
+ if not cropped:
1174
+ return None
1175
+
1176
+ return max(
1177
+ ((_, cropped.label_duration(_)) for _ in cropped.labels()),
1178
+ key=lambda x: x[1],
1179
+ )[0]
1180
+
1181
+ def rename_tracks(self, generator: LabelGenerator = "string") -> "Annotation":
1182
+ """Rename all tracks
1183
+
1184
+ Parameters
1185
+ ----------
1186
+ generator : 'string', 'int', or iterable, optional
1187
+ If 'string' (default) rename tracks to 'A', 'B', 'C', etc.
1188
+ If 'int', rename tracks to 0, 1, 2, etc.
1189
+ If iterable, use it to generate track names.
1190
+
1191
+ Returns
1192
+ -------
1193
+ renamed : Annotation
1194
+ Copy of the original annotation where tracks are renamed.
1195
+
1196
+ Example
1197
+ -------
1198
+ >>> annotation = Annotation()
1199
+ >>> annotation[Segment(0, 1), 'a'] = 'a'
1200
+ >>> annotation[Segment(0, 1), 'b'] = 'b'
1201
+ >>> annotation[Segment(1, 2), 'a'] = 'a'
1202
+ >>> annotation[Segment(1, 3), 'c'] = 'c'
1203
+ >>> print(annotation)
1204
+ [ 00:00:00.000 --> 00:00:01.000] a a
1205
+ [ 00:00:00.000 --> 00:00:01.000] b b
1206
+ [ 00:00:01.000 --> 00:00:02.000] a a
1207
+ [ 00:00:01.000 --> 00:00:03.000] c c
1208
+ >>> print(annotation.rename_tracks(generator='int'))
1209
+ [ 00:00:00.000 --> 00:00:01.000] 0 a
1210
+ [ 00:00:00.000 --> 00:00:01.000] 1 b
1211
+ [ 00:00:01.000 --> 00:00:02.000] 2 a
1212
+ [ 00:00:01.000 --> 00:00:03.000] 3 c
1213
+ """
1214
+
1215
+ renamed = self.__class__(uri=self.uri, modality=self.modality)
1216
+
1217
+ if generator == "string":
1218
+ generator = string_generator()
1219
+ elif generator == "int":
1220
+ generator = int_generator()
1221
+
1222
+ # TODO speed things up by working directly with annotation internals
1223
+ for s, _, label in self.itertracks(yield_label=True):
1224
+ renamed[s, next(generator)] = label
1225
+ return renamed
1226
+
1227
+ def rename_labels(
1228
+ self,
1229
+ mapping: Optional[Dict] = None,
1230
+ generator: LabelGenerator = "string",
1231
+ copy: bool = True,
1232
+ ) -> "Annotation":
1233
+ """Rename labels
1234
+
1235
+ Parameters
1236
+ ----------
1237
+ mapping : dict, optional
1238
+ {old_name: new_name} mapping dictionary.
1239
+ generator : 'string', 'int' or iterable, optional
1240
+ If 'string' (default) rename label to 'A', 'B', 'C', ... If 'int',
1241
+ rename to 0, 1, 2, etc. If iterable, use it to generate labels.
1242
+ copy : bool, optional
1243
+ Set to True to return a copy of the annotation. Set to False to
1244
+ update the annotation in-place. Defaults to True.
1245
+
1246
+ Returns
1247
+ -------
1248
+ renamed : Annotation
1249
+ Annotation where labels have been renamed
1250
+
1251
+ Note
1252
+ ----
1253
+ Unmapped labels are kept unchanged.
1254
+
1255
+ Note
1256
+ ----
1257
+ Parameter `generator` has no effect when `mapping` is provided.
1258
+
1259
+ """
1260
+
1261
+ if mapping is None:
1262
+ if generator == "string":
1263
+ generator = string_generator()
1264
+ elif generator == "int":
1265
+ generator = int_generator()
1266
+ # generate mapping
1267
+ mapping = {label: next(generator) for label in self.labels()}
1268
+
1269
+ renamed = self.copy() if copy else self
1270
+
1271
+ for old_label, new_label in mapping.items():
1272
+ renamed._labelNeedsUpdate[old_label] = True
1273
+ renamed._labelNeedsUpdate[new_label] = True
1274
+
1275
+ for segment, tracks in self._tracks.items():
1276
+ new_tracks = {
1277
+ track: mapping.get(label, label) for track, label in tracks.items()
1278
+ }
1279
+ renamed._tracks[segment] = new_tracks
1280
+
1281
+ return renamed
1282
+
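+ # --- Illustrative sketch (not part of the original file) ---
+ # Anonymizing speakers with an explicit mapping; unmapped labels are kept:
+ #
+ #     ann.rename_labels(mapping={'Alice': 'SPEAKER_00', 'Bob': 'SPEAKER_01'})
+ #     ann.rename_labels()  # no mapping: auto-rename labels to 'A', 'B', 'C', ...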
1283
+ def relabel_tracks(self, generator: LabelGenerator = "string") -> "Annotation":
1284
+ """Relabel tracks
1285
+
1286
+ Create a new annotation where each track has a unique label.
1287
+
1288
+ Parameters
1289
+ ----------
1290
+ generator : 'string', 'int' or iterable, optional
1291
+ If 'string' (default) relabel tracks to 'A', 'B', 'C', ... If 'int'
1292
+ relabel to 0, 1, 2, ... If iterable, use it to generate labels.
1293
+
1294
+ Returns
1295
+ -------
1296
+ renamed : Annotation
1297
+ New annotation with relabeled tracks.
1298
+ """
1299
+
1300
+ if generator == "string":
1301
+ generator = string_generator()
1302
+ elif generator == "int":
1303
+ generator = int_generator()
1304
+
1305
+ relabeled = self.empty()
1306
+ for s, t, _ in self.itertracks(yield_label=True):
1307
+ relabeled[s, t] = next(generator)
1308
+
1309
+ return relabeled
1310
+
1311
+ def support(self, collar: float = 0.0) -> "Annotation":
1312
+ """Annotation support
1313
+
1314
+ The support of an annotation is an annotation where contiguous tracks
1315
+ with same label are merged into one unique covering track.
1316
+
1317
+ A picture is worth a thousand words::
1318
+
1319
+ collar
1320
+ |---|
1321
+
1322
+ annotation
1323
+ |--A--| |--A--| |-B-|
1324
+ |-B-| |--C--| |----B-----|
1325
+
1326
+ annotation.support(collar)
1327
+ |------A------| |------B------|
1328
+ |-B-| |--C--|
1329
+
1330
+ Parameters
1331
+ ----------
1332
+ collar : float, optional
1333
+ Merge tracks with same label and separated by less than `collar`
1334
+ seconds. This is why 'A' tracks are merged in above figure.
1335
+ Defaults to 0.
1336
+
1337
+ Returns
1338
+ -------
1339
+ support : Annotation
1340
+ Annotation support
1341
+
1342
+ Note
1343
+ ----
1344
+ Track names are lost in the process.
1345
+ """
1346
+
1347
+ generator = string_generator()
1348
+
1349
+ # initialize an empty annotation
1350
+ # with same uri and modality as original
1351
+ support = self.empty()
1352
+ for label in self.labels():
1353
+
1354
+ # get timeline for current label
1355
+ timeline = self.label_timeline(label, copy=True)
1356
+
1357
+ # fill the gaps shorter than collar
1358
+ timeline = timeline.support(collar)
1359
+
1360
+ # reconstruct annotation with merged tracks
1361
+ for segment in timeline.support():
1362
+ support[segment, next(generator)] = label
1363
+
1364
+ return support
1365
+
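+ # --- Illustrative sketch (not part of the original file) ---
+ # A 0.5 s collar merges same-label tracks separated by shorter gaps:
+ #
+ #     ann = Annotation()
+ #     ann[Segment(0, 1)] = 'A'
+ #     ann[Segment(1.2, 2)] = 'A'   # 0.2 s gap, shorter than the collar
+ #     ann.support(collar=0.5)      # single 'A' track covering [0, 2]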
1366
+ def co_iter(
1367
+ self, other: "Annotation"
1368
+ ) -> Iterator[Tuple[Tuple[Segment, TrackName], Tuple[Segment, TrackName]]]:
1369
+ """Iterate over pairs of intersecting tracks
1370
+
1371
+ Parameters
1372
+ ----------
1373
+ other : Annotation
1374
+ Second annotation
1375
+
1376
+ Returns
1377
+ -------
1378
+ iterable : (Segment, object), (Segment, object) iterable
1379
+ Yields pairs of intersecting tracks, in chronological (then
1380
+ alphabetical) order.
1381
+
1382
+ See also
1383
+ --------
1384
+ :func:`~pyannote.core.Timeline.co_iter`
1385
+
1386
+ """
1387
+ timeline = self.get_timeline(copy=False)
1388
+ other_timeline = other.get_timeline(copy=False)
1389
+ for s, S in timeline.co_iter(other_timeline):
1390
+ tracks = sorted(self.get_tracks(s), key=str)
1391
+ other_tracks = sorted(other.get_tracks(S), key=str)
1392
+ for t, T in itertools.product(tracks, other_tracks):
1393
+ yield (s, t), (S, T)
1394
+
1395
+ def __mul__(self, other: "Annotation") -> np.ndarray:
1396
+ """Cooccurrence (or confusion) matrix
1397
+
1398
+ >>> matrix = annotation * other
1399
+
1400
+ Parameters
1401
+ ----------
1402
+ other : Annotation
1403
+ Second annotation
1404
+
1405
+ Returns
1406
+ -------
1407
+ cooccurrence : (n_self, n_other) np.ndarray
1408
+ Cooccurrence matrix where `n_self` (resp. `n_other`) is the number
1409
+ of labels in `self` (resp. `other`).
1410
+ """
1411
+
1412
+ if not isinstance(other, Annotation):
1413
+ raise TypeError(
1414
+ "computing cooccurrence matrix only works with Annotation " "instances."
1415
+ )
1416
+
1417
+ i_labels = self.labels()
1418
+ j_labels = other.labels()
1419
+
1420
+ I = {label: i for i, label in enumerate(i_labels)}
1421
+ J = {label: j for j, label in enumerate(j_labels)}
1422
+
1423
+ matrix = np.zeros((len(I), len(J)))
1424
+
1425
+ # iterate over intersecting tracks and accumulate durations
1426
+ for (segment, track), (other_segment, other_track) in self.co_iter(other):
1427
+ i = I[self[segment, track]]
1428
+ j = J[other[other_segment, other_track]]
1429
+ duration = (segment & other_segment).duration
1430
+ matrix[i, j] += duration
1431
+
1432
+ return matrix
1433
+
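+ # --- Illustrative sketch (not part of the original file) ---
+ # Cooccurrence between a reference and a hypothesis annotation
+ # (both names hypothetical), indexed by their sorted label lists:
+ #
+ #     matrix = reference * hypothesis
+ #     # matrix[i, j] = total duration during which reference.labels()[i]
+ #     # and hypothesis.labels()[j] are active simultaneously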
1434
+ def discretize(
1435
+ self,
1436
+ support: Optional[Segment] = None,
1437
+ resolution: Union[float, SlidingWindow] = 0.01,
1438
+ labels: Optional[List[Hashable]] = None,
1439
+ duration: Optional[float] = None,
1440
+ ):
1441
+ """Discretize
1442
+
1443
+ Parameters
1444
+ ----------
1445
+ support : Segment, optional
1446
+ Part of annotation to discretize.
1447
+ Defaults to annotation full extent.
1448
+ resolution : float or SlidingWindow, optional
1449
+ Defaults to 10ms frames.
1450
+ labels : list of labels, optional
1451
+ Defaults to self.labels()
1452
+ duration : float, optional
1453
+ Overrides support duration and ensures that the number of
1454
+ returned frames is fixed (which might otherwise not be the case
1455
+ because of rounding errors).
1456
+
1457
+ Returns
1458
+ -------
1459
+ discretized : SlidingWindowFeature
1460
+ (num_frames, num_labels)-shaped binary features.
1461
+ """
1462
+
1463
+ if support is None:
1464
+ support = self.get_timeline().extent()
1465
+ start_time, end_time = support
1466
+
1467
+ cropped = self.crop(support, mode="intersection")
1468
+
1469
+ if labels is None:
1470
+ labels = cropped.labels()
1471
+
1472
+ if isinstance(resolution, SlidingWindow):
1473
+ resolution = SlidingWindow(
1474
+ start=start_time, step=resolution.step, duration=resolution.duration
1475
+ )
1476
+ else:
1477
+ resolution = SlidingWindow(
1478
+ start=start_time, step=resolution, duration=resolution
1479
+ )
1480
+
1481
+ start_frame = resolution.closest_frame(start_time)
1482
+ if duration is None:
1483
+ end_frame = resolution.closest_frame(end_time)
1484
+ num_frames = end_frame - start_frame
1485
+ else:
1486
+ num_frames = int(round(duration / resolution.step))
1487
+
1488
+ data = np.zeros((num_frames, len(labels)), dtype=np.uint8)
1489
+ for k, label in enumerate(labels):
1490
+ segments = cropped.label_timeline(label)
1491
+ for start, stop in resolution.crop(
1492
+ segments, mode="center", return_ranges=True
1493
+ ):
1494
+ data[max(0, start) : min(stop, num_frames), k] += 1
1495
+ data = np.minimum(data, 1, out=data)
1496
+
1497
+ return SlidingWindowFeature(data, resolution, labels=labels)
1498
+
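+ # --- Illustrative sketch (not part of the original file) ---
+ # Rasterizing a hypothetical annotation into 10 ms binary frames,
+ # one column per label:
+ #
+ #     swf = ann.discretize(resolution=0.01)
+ #     swf.data.shape   # (num_frames, len(ann.labels()))
+ #     swf.data.dtype   # uint8, 1 wherever the label is active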
1499
+ @classmethod
1500
+ def from_records(
1501
+ cls,
1502
+ records: Iterator[Tuple[Segment, TrackName, Label]],
1503
+ uri: Optional[str] = None,
1504
+ modality: Optional[str] = None,
1505
+ ) -> "Annotation":
1506
+ """Annotation
1507
+
1508
+ Parameters
1509
+ ----------
1510
+ records : iterator of tuples
1511
+ (segment, track, label) tuples
1512
+ uri : string, optional
1513
+ name of annotated resource (e.g. audio or video file)
1514
+ modality : string, optional
1515
+ name of annotated modality
1516
+
1517
+ Returns
1518
+ -------
1519
+ annotation : Annotation
1520
+ New annotation
1521
+
1522
+ """
1523
+ annotation = cls(uri=uri, modality=modality)
1524
+ tracks = defaultdict(dict)
1525
+ labels = set()
1526
+ for segment, track, label in records:
1527
+ tracks[segment][track] = label
1528
+ labels.add(label)
1529
+ annotation._tracks = SortedDict(tracks)
1530
+ annotation._labels = {label: None for label in labels}
1531
+ annotation._labelNeedsUpdate = {label: True for label in annotation._labels}
1532
+ annotation._timeline = None
1533
+ annotation._timelineNeedsUpdate = True
1534
+
1535
+ return annotation
1536
+
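+ # --- Illustrative sketch (not part of the original file) ---
+ # Bulk-building an annotation from (segment, track, label) tuples:
+ #
+ #     records = [(Segment(0, 1), 'a', 'Alice'), (Segment(1, 2), 'b', 'Bob')]
+ #     ann = Annotation.from_records(iter(records), uri='demo')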
1537
+ def _repr_png_(self):
1538
+ """IPython notebook support
1539
+
1540
+ See also
1541
+ --------
1542
+ :mod:`pyannote.core.notebook`
1543
+ """
1544
+ from .notebook import MATPLOTLIB_IS_AVAILABLE, MATPLOTLIB_WARNING
1545
+
1546
+ if not MATPLOTLIB_IS_AVAILABLE:
1547
+ warnings.warn(MATPLOTLIB_WARNING.format(klass=self.__class__.__name__))
1548
+ return None
1549
+
1550
+ from .notebook import repr_annotation
1551
+ return repr_annotation(self)
ailia-models/code/pyannote_audio_utils/core/feature.py ADDED
@@ -0,0 +1,329 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2014-2019 CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+
29
+
30
+ """
31
+ ########
32
+ Features
33
+ ########
34
+
35
+ See :class:`pyannote_audio_utils.core.SlidingWindowFeature` for the complete reference.
36
+ """
37
+ import numbers
38
+ import warnings
39
+ from typing import Tuple, Optional, Union, Iterator, List, Text
40
+
41
+ import numpy as np
42
+
43
+ from pyannote_audio_utils.core.utils.types import Alignment
44
+ from .segment import Segment
45
+ from .segment import SlidingWindow
46
+ from .timeline import Timeline
47
+
48
+
49
+ class SlidingWindowFeature(np.lib.mixins.NDArrayOperatorsMixin):
50
+ """Periodic feature vectors
51
+
52
+ Parameters
53
+ ----------
54
+ data : (n_frames, n_features) numpy array
55
+ sliding_window : SlidingWindow
56
+ labels : list, optional
57
+ Textual description of each dimension.
58
+ """
59
+
60
+ def __init__(
61
+ self, data: np.ndarray, sliding_window: SlidingWindow, labels: List[Text] = None
62
+ ):
63
+ self.sliding_window: SlidingWindow = sliding_window
64
+ self.data = data
65
+ self.labels = labels
66
+ self.__i: int = -1
67
+
68
+ def __len__(self):
69
+ """Number of feature vectors"""
70
+ return self.data.shape[0]
71
+
72
+ @property
73
+ def extent(self):
74
+ return self.sliding_window.range_to_segment(0, len(self))
75
+
76
+ @property
77
+ def dimension(self):
78
+ """Dimension of feature vectors"""
79
+ return self.data.shape[1]
80
+
81
+ def getNumber(self):
82
+ warnings.warn("This is deprecated in favor of `__len__`", DeprecationWarning)
83
+ return self.data.shape[0]
84
+
85
+ def getDimension(self):
86
+ warnings.warn(
87
+ "This is deprecated in favor of `dimension` property", DeprecationWarning
88
+ )
89
+ return self.dimension
90
+
91
+ def getExtent(self):
92
+ warnings.warn(
93
+ "This is deprecated in favor of `extent` property", DeprecationWarning
94
+ )
95
+ return self.extent
96
+
97
+ def __getitem__(self, i: int) -> np.ndarray:
98
+ """Get ith feature vector"""
99
+ return self.data[i]
100
+
101
+ def __iter__(self):
102
+ self.__i = -1
103
+ return self
104
+
105
+ def __next__(self) -> Tuple[Segment, np.ndarray]:
106
+ self.__i += 1
107
+ try:
108
+ return self.sliding_window[self.__i], self.data[self.__i]
109
+ except IndexError:
110
+ raise StopIteration()
111
+
112
+ def next(self):
113
+ return self.__next__()
114
+
115
+ def iterfeatures(
116
+ self, window: Optional[bool] = False
117
+ ) -> Iterator[Union[Tuple[np.ndarray, Segment], np.ndarray]]:
118
+ """Feature vector iterator
119
+
120
+ Parameters
121
+ ----------
122
+ window : bool, optional
123
+ When True, yield both feature vector and corresponding window.
124
+ Default is to only yield feature vector
125
+
126
+ """
127
+ n_samples = self.data.shape[0]
128
+ for i in range(n_samples):
129
+ if window:
130
+ yield self.data[i], self.sliding_window[i]
131
+ else:
132
+ yield self.data[i]
133
+
134
+ def crop(
135
+ self,
136
+ focus: Union[Segment, Timeline],
137
+ mode: Alignment = "loose",
138
+ fixed: Optional[float] = None,
139
+ return_data: bool = True,
140
+ ) -> Union[np.ndarray, "SlidingWindowFeature"]:
141
+ """Extract frames
142
+
143
+ Parameters
144
+ ----------
145
+ focus : Segment or Timeline
146
+ mode : {'loose', 'strict', 'center'}, optional
147
+ In 'strict' mode, only frames fully included in 'focus' support are
148
+ returned. In 'loose' mode, any intersecting frames are returned. In
149
+ 'center' mode, first and last frames are chosen to be the ones
150
+ whose centers are the closest to 'focus' start and end times.
151
+ Defaults to 'loose'.
152
+ fixed : float, optional
153
+ Overrides `Segment` 'focus' duration and ensures that the number of
154
+ returned frames is fixed (which might otherwise not be the case
155
+ because of rounding errors).
156
+ return_data : bool, optional
157
+ Return a numpy array (default). For `Segment` 'focus', setting it
158
+ to False will return a `SlidingWindowFeature` instance.
159
+
160
+ Returns
161
+ -------
162
+ data : `numpy.ndarray` or `SlidingWindowFeature`
163
+ Frame features.
164
+
165
+ See also
166
+ --------
167
+ SlidingWindow.crop
168
+
169
+ """
170
+
171
+ if (not return_data) and (not isinstance(focus, Segment)):
172
+ msg = (
173
+ '"focus" must be a "Segment" instance when "return_data"'
174
+ "is set to False."
175
+ )
176
+ raise ValueError(msg)
177
+
178
+ if (not return_data) and (fixed is not None):
179
+ msg = '"fixed" cannot be set when "return_data" is set to False.'
180
+ raise ValueError(msg)
181
+
182
+ ranges = self.sliding_window.crop(
183
+ focus, mode=mode, fixed=fixed, return_ranges=True
184
+ )
185
+
186
+ # total number of samples in features
187
+ n_samples = self.data.shape[0]
188
+
189
+ # 1 for vector features (e.g. MFCC in pyannote_audio_utils.audio)
190
+ # 2 for matrix features (e.g. grey-level frames in pyannote_audio_utils.video)
191
+ # 3 for 3rd order tensor (e.g. RBG frames in pyannote_audio_utils.video)
192
+ n_dimensions = len(self.data.shape) - 1
193
+
194
+ # clip ranges
195
+ clipped_ranges, repeat_first, repeat_last = [], 0, 0
196
+ for start, end in ranges:
197
+ # count number of requested samples before first sample
198
+ repeat_first += min(end, 0) - min(start, 0)
199
+ # count number of requested samples after last sample
200
+ repeat_last += max(end, n_samples) - max(start, n_samples)
201
+ # if all requested samples are out of bounds, skip
202
+ if end < 0 or start >= n_samples:
203
+ continue
204
+ else:
205
+ # keep track of non-empty clipped ranges
206
+ clipped_ranges += [[max(start, 0), min(end, n_samples)]]
207
+
208
+ if clipped_ranges:
209
+ data = np.vstack([self.data[start:end, :] for start, end in clipped_ranges])
210
+ else:
211
+ # if all ranges are out of bounds, just return empty data
212
+ shape = (0,) + self.data.shape[1:]
213
+ data = np.empty(shape)
214
+
215
+ # corner case when "fixed" duration cropping is requested:
216
+ # correct number of samples even with out-of-bounds indices
217
+ if fixed is not None:
218
+ data = np.vstack(
219
+ [
220
+ # repeat first sample as many times as needed
221
+ np.tile(self.data[0], (repeat_first,) + (1,) * n_dimensions),
222
+ data,
223
+ # repeat last sample as many times as needed
224
+ np.tile(
225
+ self.data[n_samples - 1], (repeat_last,) + (1,) * n_dimensions
226
+ ),
227
+ ]
228
+ )
229
+
230
+ # return data
231
+ if return_data:
232
+ return data
233
+
234
+ # wrap data in a SlidingWindowFeature and return
235
+ sliding_window = SlidingWindow(
236
+ start=self.sliding_window[clipped_ranges[0][0]].start,
237
+ duration=self.sliding_window.duration,
238
+ step=self.sliding_window.step,
239
+ )
240
+
241
+ return SlidingWindowFeature(data, sliding_window, labels=self.labels)
242
+
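+ # --- Illustrative sketch (not part of the original file) ---
+ # Extracting the frames of a hypothetical feature matrix that fall
+ # inside a segment (100 frames of 20 ms windows with a 10 ms step):
+ #
+ #     swf = SlidingWindowFeature(
+ #         np.random.rand(100, 4),
+ #         SlidingWindow(start=0.0, duration=0.02, step=0.01))
+ #     swf.crop(Segment(0.25, 0.75))             # roughly 50 x 4 numpy array
+ #     swf.crop(Segment(0.25, 0.75), fixed=0.5)  # fixed frame count (~0.5 / step)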
243
+ def _repr_png_(self):
244
+ from .notebook import MATPLOTLIB_IS_AVAILABLE, MATPLOTLIB_WARNING
245
+
246
+ if not MATPLOTLIB_IS_AVAILABLE:
247
+ warnings.warn(MATPLOTLIB_WARNING.format(klass=self.__class__.__name__))
248
+ return None
249
+
250
+ from .notebook import repr_feature
251
+
252
+ return repr_feature(self)
253
+
254
+ _HANDLED_TYPES = (np.ndarray, numbers.Number)
255
+
256
+ def __array__(self) -> np.ndarray:
257
+ return self.data
258
+
259
+ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
260
+ out = kwargs.get("out", ())
261
+ for x in inputs + out:
262
+ # Only support operations with instances of _HANDLED_TYPES.
263
+ # Use SlidingWindowFeature instead of type(self) for isinstance to
264
+ # allow subclasses that don't override __array_ufunc__ to
265
+ # handle SlidingWindowFeature objects.
266
+ if not isinstance(x, self._HANDLED_TYPES + (SlidingWindowFeature,)):
267
+ return NotImplemented
268
+
269
+ # Defer to the implementation of the ufunc on unwrapped values.
270
+ inputs = tuple(
271
+ x.data if isinstance(x, SlidingWindowFeature) else x for x in inputs
272
+ )
273
+ if out:
274
+ kwargs["out"] = tuple(
275
+ x.data if isinstance(x, SlidingWindowFeature) else x for x in out
276
+ )
277
+ data = getattr(ufunc, method)(*inputs, **kwargs)
278
+
279
+ if type(data) is tuple:
280
+ # multiple return values
281
+ return tuple(
282
+ type(self)(x, self.sliding_window, labels=self.labels) for x in data
283
+ )
284
+ elif method == "at":
285
+ # no return value
286
+ return None
287
+ else:
288
+ # one return value
289
+ return type(self)(data, self.sliding_window, labels=self.labels)
290
+
291
+ def align(self, to: "SlidingWindowFeature") -> "SlidingWindowFeature":
292
+ """Align features by linear temporal interpolation
293
+
294
+ Parameters
295
+ ----------
296
+ to : SlidingWindowFeature
297
+ Features to align with.
298
+
299
+ Returns
300
+ -------
301
+ aligned : SlidingWindowFeature
302
+ Aligned features
303
+ """
304
+
305
+ old_start = self.sliding_window.start
306
+ old_step = self.sliding_window.step
307
+ old_duration = self.sliding_window.duration
308
+ old_samples = len(self)
309
+ old_t = old_start + 0.5 * old_duration + np.arange(old_samples) * old_step
310
+
311
+ new_start = to.sliding_window.start
312
+ new_step = to.sliding_window.step
313
+ new_duration = to.sliding_window.duration
314
+ new_samples = len(to)
315
+ new_t = new_start + 0.5 * new_duration + np.arange(new_samples) * new_step
316
+
317
+ new_data = np.hstack(
318
+ [
319
+ np.interp(new_t, old_t, old_data)[:, np.newaxis]
320
+ for old_data in self.data.T
321
+ ]
322
+ )
323
+ return SlidingWindowFeature(new_data, to.sliding_window, labels=self.labels)
324
+
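+ # --- Illustrative sketch (not part of the original file) ---
+ # Resampling one feature sequence onto the frame grid of another
+ # (`coarse_swf` and `fine_swf` are hypothetical SlidingWindowFeature instances):
+ #
+ #     aligned = coarse_swf.align(to=fine_swf)
+ #     len(aligned) == len(fine_swf)   # True: one vector per target frame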
325
+
326
+ if __name__ == "__main__":
327
+ import doctest
328
+
329
+ doctest.testmod()
ailia-models/code/pyannote_audio_utils/core/notebook.py ADDED
@@ -0,0 +1,468 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2014-2019 CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+
29
+ """
30
+ #############
31
+ Visualization
32
+ #############
33
+
34
+ :class:`pyannote.core.Segment`, :class:`pyannote.core.Timeline`,
35
+ :class:`pyannote.core.Annotation` and :class:`pyannote.core.SlidingWindowFeature`
36
+ instances can be directly visualized in notebooks.
37
+
38
+ You will however need to install ``pyannote.core``'s additional dependencies
39
+ for notebook representations (namely, matplotlib):
40
+
41
+
42
+ .. code-block:: bash
43
+
44
+ pip install pyannote.core[notebook]
45
+
46
+
47
+ Segments
48
+ --------
49
+
50
+ .. code-block:: ipython
51
+
52
+ In [1]: from pyannote.core import Segment
53
+
54
+ In [2]: segment = Segment(start=5, end=15)
55
+ ....: segment
56
+
57
+ .. plot:: pyplots/segment.py
58
+
59
+
60
+ Timelines
61
+ ---------
62
+
63
+ .. code-block:: ipython
64
+
65
+ In [25]: from pyannote.core import Timeline, Segment
66
+
67
+ In [26]: timeline = Timeline()
68
+ ....: timeline.add(Segment(1, 5))
69
+ ....: timeline.add(Segment(6, 8))
70
+ ....: timeline.add(Segment(12, 18))
71
+ ....: timeline.add(Segment(7, 20))
72
+ ....: timeline
73
+
74
+ .. plot:: pyplots/timeline.py
75
+
76
+
77
+ Annotations
78
+ -----------
79
+
80
+
81
+ .. code-block:: ipython
82
+
83
+ In [1]: from pyannote.core import Annotation, Segment
84
+
85
+ In [6]: annotation = Annotation()
86
+ ...: annotation[Segment(1, 5)] = 'Carol'
87
+ ...: annotation[Segment(6, 8)] = 'Bob'
88
+ ...: annotation[Segment(12, 18)] = 'Carol'
89
+ ...: annotation[Segment(7, 20)] = 'Alice'
90
+ ...: annotation
91
+
92
+ .. plot:: pyplots/annotation.py
93
+
94
+ """
95
+ from typing import Iterable, Dict, Optional
96
+
97
+ from .utils.types import Label, LabelStyle, Resource
98
+
99
+ # try:
100
+ # from IPython.core.pylabtools import print_figure
101
+ # except Exception as e:
102
+ # pass
103
+ import numpy as np
104
+ from itertools import cycle, product, groupby
105
+ from .segment import Segment, SlidingWindow
106
+ from .timeline import Timeline
107
+ from .annotation import Annotation
108
+ from .feature import SlidingWindowFeature
109
+
110
+ try:
111
+ import matplotlib
112
+ except ImportError:
113
+ MATPLOTLIB_IS_AVAILABLE = False
114
+ else:
115
+ MATPLOTLIB_IS_AVAILABLE = True
116
+
117
+ MATPLOTLIB_WARNING = (
118
+ "Couldn't import matplotlib to render the visualization "
119
+ "for object {klass}. To enable, install the required dependencies "
120
+ "with 'pip install pyannote.core[notebook]'"
121
+ )
122
+
123
+
124
+ class Notebook:
125
+ def __init__(self):
126
+ self.reset()
127
+
128
+ def reset(self):
129
+ from matplotlib.cm import get_cmap
130
+
131
+ linewidth = [3, 1]
132
+ linestyle = ["solid", "dashed", "dotted"]
133
+
134
+ cm = get_cmap("Set1")
135
+ colors = [cm(1.0 * i / 8) for i in range(9)]
136
+
137
+ self._style_generator = cycle(product(linestyle, linewidth, colors))
138
+ self._style: Dict[Optional[Label], LabelStyle] = {
139
+ None: ("solid", 1, (0.0, 0.0, 0.0))
140
+ }
141
+ del self.crop
142
+ del self.width
143
+
144
+ @property
145
+ def crop(self):
146
+ """The crop property."""
147
+ return self._crop
148
+
149
+ @crop.setter
150
+ def crop(self, segment: Segment):
151
+ self._crop = segment
152
+
153
+ @crop.deleter
154
+ def crop(self):
155
+ self._crop = None
156
+
157
+ @property
158
+ def width(self):
159
+ """The width property"""
160
+ return self._width
161
+
162
+ @width.setter
163
+ def width(self, value: int):
164
+ self._width = value
165
+
166
+ @width.deleter
167
+ def width(self):
168
+ self._width = 20
169
+
170
+ def __getitem__(self, label: Label) -> LabelStyle:
171
+ """Get line style for a given label"""
172
+ if label not in self._style:
173
+ self._style[label] = next(self._style_generator)
174
+ return self._style[label]
175
+
176
+ def setup(self, ax=None, ylim=(0, 1), yaxis=False, time=True):
177
+ import matplotlib.pyplot as plt
178
+
179
+ if ax is None:
180
+ ax = plt.gca()
181
+ ax.set_xlim(self.crop)
182
+ if time:
183
+ ax.set_xlabel("Time")
184
+ else:
185
+ ax.set_xticklabels([])
186
+ ax.set_ylim(ylim)
187
+ ax.axes.get_yaxis().set_visible(yaxis)
188
+ return ax
189
+
190
+ def draw_segment(self, ax, segment: Segment, y, label=None, boundaries=True):
191
+
192
+ # do nothing if segment is empty
193
+ if not segment:
194
+ return
195
+
196
+ linestyle, linewidth, color = self[label]
197
+
198
+ # draw segment
199
+ ax.hlines(
200
+ y,
201
+ segment.start,
202
+ segment.end,
203
+ color,
204
+ linewidth=linewidth,
205
+ linestyle=linestyle,
206
+ label=label,
207
+ )
208
+ if boundaries:
209
+ ax.vlines(
210
+ segment.start, y + 0.05, y - 0.05, color, linewidth=1, linestyle="solid"
211
+ )
212
+ ax.vlines(
213
+ segment.end, y + 0.05, y - 0.05, color, linewidth=1, linestyle="solid"
214
+ )
215
+
216
+ if label is None:
217
+ return
218
+
219
+ def get_y(self, segments: Iterable[Segment]) -> np.ndarray:
220
+ """
221
+
222
+ Parameters
223
+ ----------
224
+ segments : Iterable
225
+ `Segment` iterable (sorted)
226
+
227
+ Returns
228
+ -------
229
+ y : np.array
230
+ y coordinates of each segment
231
+
232
+ """
233
+
234
+ # up_to stores the largest end time
235
+ # displayed in each line (at the current iteration)
236
+ # (at the beginning, there is only one empty line)
237
+ up_to = [-np.inf]
238
+
239
+ # y[k] indicates on which line to display kth segment
240
+ y = []
241
+
242
+ for segment in segments:
243
+ # so far, we do not know which line to use
244
+ found = False
245
+ # try each line until we find one that is ok
246
+ for i, u in enumerate(up_to):
247
+ # if segment starts after the previous one
248
+ # on the same line, then we add it to the line
249
+ if segment.start >= u:
250
+ found = True
251
+ y.append(i)
252
+ up_to[i] = segment.end
253
+ break
254
+ # in case we went out of lines, create a new one
255
+ if not found:
256
+ y.append(len(up_to))
257
+ up_to.append(segment.end)
258
+
259
+ # from line numbers to actual y coordinates
260
+ y = 1.0 - 1.0 / (len(up_to) + 1) * (1 + np.array(y))
261
+
262
+ return y
263
+
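To make the line-packing in `get_y` concrete, a small sketch (hypothetical segments; requires matplotlib, which `Notebook.reset` imports): each segment goes to the first line whose previous segment has already ended, otherwise a new line is opened.

from pyannote_audio_utils.core.segment import Segment
from pyannote_audio_utils.core.notebook import Notebook

nb = Notebook()
segments = [Segment(1, 5), Segment(2, 6), Segment(5, 7)]  # sorted by start time
y = nb.get_y(segments)
# Segment(2, 6) overlaps Segment(1, 5) and opens a second line;
# Segment(5, 7) starts exactly when Segment(1, 5) ends, so line 0 is reused.
# With two lines, y = 1 - (1 + line) / 3, i.e. [2/3, 1/3, 2/3].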
264
+ def __call__(self, resource: Resource, time: bool = True, legend: bool = True):
265
+
266
+ if isinstance(resource, Segment):
267
+ self.plot_segment(resource, time=time)
268
+
269
+ elif isinstance(resource, Timeline):
270
+ self.plot_timeline(resource, time=time)
271
+
272
+ elif isinstance(resource, Annotation):
273
+ self.plot_annotation(resource, time=time, legend=legend)
274
+
275
+ elif isinstance(resource, SlidingWindowFeature):
276
+ self.plot_feature(resource, time=time)
277
+
278
+ def plot_segment(self, segment, ax=None, time=True):
279
+
280
+ if not self.crop:
281
+ self.crop = segment
282
+
283
+ ax = self.setup(ax=ax, time=time)
284
+ self.draw_segment(ax, segment, 0.5)
285
+
286
+ def plot_timeline(self, timeline: Timeline, ax=None, time=True):
287
+
288
+ if not self.crop and timeline:
289
+ self.crop = timeline.extent()
290
+
291
+ cropped = timeline.crop(self.crop, mode="loose")
292
+
293
+ ax = self.setup(ax=ax, time=time)
294
+
295
+ for segment, y in zip(cropped, self.get_y(cropped)):
296
+ self.draw_segment(ax, segment, y)
297
+
298
+ # ax.set_aspect(3. / self.crop.duration)
299
+
300
+ def plot_annotation(self, annotation: Annotation, ax=None, time=True, legend=True):
301
+
302
+ if not self.crop:
303
+ self.crop = annotation.get_timeline(copy=False).extent()
304
+
305
+ cropped = annotation.crop(self.crop, mode="intersection")
306
+ labels = cropped.labels()
307
+ segments = [s for s, _ in cropped.itertracks()]
308
+
309
+ ax = self.setup(ax=ax, time=time)
310
+
311
+ for (segment, track, label), y in zip(
312
+ cropped.itertracks(yield_label=True), self.get_y(segments)
313
+ ):
314
+ self.draw_segment(ax, segment, y, label=label)
315
+
316
+ if legend:
317
+ H, L = ax.get_legend_handles_labels()
318
+
319
+ # corner case when no segment is visible
320
+ if not H:
321
+ return
322
+
323
+ # this gets exactly one legend handle and one legend label per label
324
+ # (avoids repeated legends for repeated tracks with same label)
325
+ HL = groupby(
326
+ sorted(zip(H, L), key=lambda h_l: h_l[1]), key=lambda h_l: h_l[1]
327
+ )
328
+ H, L = zip(*list((next(h_l)[0], l) for l, h_l in HL))
329
+ ax.legend(
330
+ H,
331
+ L,
332
+ bbox_to_anchor=(0, 1),
333
+ loc=3,
334
+ ncol=5,
335
+ borderaxespad=0.0,
336
+ frameon=False,
337
+ )
338
+
339
+ def plot_feature(
340
+ self, feature: SlidingWindowFeature, ax=None, time=True, ylim=None
341
+ ):
342
+
343
+ if not self.crop:
344
+ self.crop = feature.getExtent()
345
+
346
+ window = feature.sliding_window
347
+ n, dimension = feature.data.shape
348
+ ((start, stop),) = window.crop(self.crop, mode="loose", return_ranges=True)
349
+ xlim = (window[start].middle, window[stop].middle)
350
+
351
+ start = max(0, start)
352
+ stop = min(stop, n)
353
+ t = window[0].middle + window.step * np.arange(start, stop)
354
+ data = feature[start:stop]
355
+
356
+ if ylim is None:
357
+ m = np.nanmin(data)
358
+ M = np.nanmax(data)
359
+ ylim = (m - 0.1 * (M - m), M + 0.1 * (M - m))
360
+
361
+ ax = self.setup(ax=ax, yaxis=False, ylim=ylim, time=time)
362
+ ax.plot(t, data)
363
+ ax.set_xlim(xlim)
364
+
365
+
366
+ notebook = Notebook()
367
+
368
+ def repr_segment(segment: Segment):
369
+ """Render `segment` with matplotlib and save the figure to ./output.png"""
370
+ import matplotlib.pyplot as plt
371
+
372
+ figsize = plt.rcParams["figure.figsize"]
373
+ plt.rcParams["figure.figsize"] = (notebook.width, 1)
374
+ fig, ax = plt.subplots()
375
+ notebook.plot_segment(segment, ax=ax)
376
+ # data = print_figure(fig, "png")
377
+ plt.savefig('./output')
378
+ plt.close(fig)
379
+ plt.rcParams["figure.figsize"] = figsize
380
+ return
381
+
382
+
383
+ def repr_timeline(timeline: Timeline):
384
+ """Render `timeline` with matplotlib and save the figure to ./output.png"""
385
+ import matplotlib.pyplot as plt
387
+ figsize = plt.rcParams["figure.figsize"]
388
+ plt.rcParams["figure.figsize"] = (notebook.width, 1)
389
+ fig, ax = plt.subplots()
390
+ notebook.plot_timeline(timeline, ax=ax)
391
+ # data = print_figure(fig, "png")
392
+ plt.savefig('./output')
393
+ plt.close(fig)
394
+ plt.rcParams["figure.figsize"] = figsize
395
+ return
396
+
397
+
398
+ def repr_annotation(annotation: Annotation):
399
+ """Render `annotation` with matplotlib and save the figure to ./output.png"""
400
+ import matplotlib.pyplot as plt
401
+
402
+ figsize = plt.rcParams["figure.figsize"]
403
+ plt.rcParams["figure.figsize"] = (notebook.width, 2)
404
+ fig, ax = plt.subplots()
405
+ notebook.plot_annotation(annotation, ax=ax)
406
+ # data = print_figure(fig, "png")
407
+ plt.savefig('./output')
408
+ plt.close(fig)
409
+ plt.rcParams["figure.figsize"] = figsize
410
+ return
411
+
412
+
413
+ def repr_feature(feature: SlidingWindowFeature):
414
+ """Render `feature` with matplotlib and save the figure to ./output.png"""
415
+ import matplotlib.pyplot as plt
416
+
417
+ figsize = plt.rcParams["figure.figsize"]
418
+
419
+ if feature.data.ndim == 2:
420
+
421
+ plt.rcParams["figure.figsize"] = (notebook.width, 2)
422
+ fig, ax = plt.subplots()
423
+ notebook.plot_feature(feature, ax=ax)
424
+ # data = print_figure(fig, "png")
425
+ plt.savefig('./output')
426
+ plt.close(fig)
427
+
428
+ elif feature.data.ndim == 3:
429
+
430
+ num_chunks = len(feature)
431
+
432
+ if notebook.crop is None:
433
+ notebook.crop = Segment(
434
+ start=feature.sliding_window.start,
435
+ end=feature.sliding_window[num_chunks - 1].end,
436
+ )
437
+ else:
438
+ feature = feature.crop(notebook.crop, mode="loose", return_data=False)
439
+
440
+ num_overlap = (
441
+ round(feature.sliding_window.duration // feature.sliding_window.step) + 1
442
+ )
443
+
444
+ num_overlap = min(num_chunks, num_overlap)
445
+
446
+ plt.rcParams["figure.figsize"] = (notebook.width, 1.5 * num_overlap)
447
+
448
+ fig, axes = plt.subplots(nrows=num_overlap, ncols=1,)
449
+ mini, maxi = np.nanmin(feature.data), np.nanmax(feature.data)
450
+ ylim = (mini - 0.2 * (maxi - mini), maxi + 0.2 * (maxi - mini))
451
+ for c, (window, data) in enumerate(feature):
452
+ ax = axes[c % num_overlap]
453
+ step = duration = window.duration / len(data)
454
+ frames = SlidingWindow(start=window.start, step=step, duration=duration)
455
+ window_feature = SlidingWindowFeature(data, frames, labels=feature.labels)
456
+ notebook.plot_feature(
457
+ window_feature,
458
+ ax=ax,
459
+ time=c % num_overlap == (num_overlap - 1),
460
+ ylim=ylim,
461
+ )
462
+ ax.set_prop_cycle(None)
463
+ # data = print_figure(fig, "png")
464
+ plt.savefig('./output')
465
+ plt.close(fig)
466
+
467
+ plt.rcParams["figure.figsize"] = figsize
468
+ return
ailia-models/code/pyannote_audio_utils/core/segment.py ADDED
@@ -0,0 +1,910 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2014-2021 CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+
29
+ """
30
+ #######
31
+ Segment
32
+ #######
33
+
34
+ .. plot:: pyplots/segment.py
35
+
36
+ :class:`pyannote.core.Segment` instances describe temporal fragments (*e.g.* of an audio file). The segment depicted above can be defined as follows:
37
+
38
+ .. code-block:: ipython
39
+
40
+ In [1]: from pyannote.core import Segment
41
+
42
+ In [2]: segment = Segment(start=5, end=15)
43
+
44
+ In [3]: print(segment)
45
+
46
+ A segment is nothing more than a 2-tuple augmented with several useful methods and properties:
47
+
48
+ .. code-block:: ipython
49
+
50
+ In [4]: start, end = segment
51
+
52
+ In [5]: start
53
+
54
+ In [6]: segment.end
55
+
56
+ In [7]: segment.duration # duration (read-only)
57
+
58
+ In [8]: segment.middle # middle (read-only)
59
+
60
+ In [9]: segment & Segment(3, 12) # intersection
61
+
62
+ In [10]: segment | Segment(3, 12) # union
63
+
64
+ In [11]: segment.overlaps(3) # does segment overlap time t=3?
65
+
66
+
67
+ Use `Segment.set_precision(ndigits)` to automatically round start and end timestamps to `ndigits` precision after the decimal point.
68
+ To ensure consistency between `Segment` instances, it is recommended to call this method only once, right after importing `pyannote.core.Segment`.
69
+
70
+ .. code-block:: ipython
71
+
72
+ In [12]: Segment(1/1000, 330/1000) == Segment(1/1000, 90/1000+240/1000)
73
+ Out[12]: False
74
+
75
+ In [13]: Segment.set_precision(ndigits=4)
76
+
77
+ In [14]: Segment(1/1000, 330/1000) == Segment(1/1000, 90/1000+240/1000)
78
+ Out[14]: True
79
+
80
+ See :class:`pyannote.core.Segment` for the complete reference.
81
+ """
82
+
83
+ import warnings
84
+ from typing import Union, Optional, Tuple, List, Iterator, Iterable
85
+
86
+ from .utils.types import Alignment
87
+
88
+ import numpy as np
89
+ from dataclasses import dataclass
90
+ # Module-level defaults used by `set_precision`, `__post_init__` and `__bool__`
+ # below: timestamps are not rounded automatically, and segments shorter than
+ # one microsecond are considered empty.
+ AUTO_ROUND_TIME = False
91
+ SEGMENT_PRECISION = 1e-6
92
+ # setting 'frozen' to True makes it hashable and immutable
93
+ @dataclass(frozen=True, order=True)
94
+ class Segment:
95
+ """
96
+ Time interval
97
+
98
+ Parameters
99
+ ----------
100
+ start : float
101
+ interval start time, in seconds.
102
+ end : float
103
+ interval end time, in seconds.
104
+
105
+
106
+ Segments can be compared and sorted using the standard operators:
107
+
108
+ >>> Segment(0, 1) == Segment(0, 1.)
109
+ True
110
+ >>> Segment(0, 1) != Segment(3, 4)
111
+ True
112
+ >>> Segment(0, 1) < Segment(2, 3)
113
+ True
114
+ >>> Segment(0, 1) < Segment(0, 2)
115
+ True
116
+ >>> Segment(1, 2) < Segment(0, 3)
117
+ False
118
+
119
+ Note
120
+ ----
121
+ A segment is smaller than another segment if one of the following two conditions holds:
122
+
123
+ - `segment.start < other_segment.start`
124
+ - `segment.start == other_segment.start` and `segment.end < other_segment.end`
125
+
126
+ """
127
+ start: float = 0.0
128
+ end: float = 0.0
129
+
130
+ @staticmethod
131
+ def set_precision(ndigits: Optional[int] = None):
132
+ """Automatically round start and end timestamps to `ndigits` precision after the decimal point
133
+
134
+ To ensure consistency between `Segment` instances, it is recommended to call this method only
135
+ once, right after importing `pyannote.core.Segment`.
136
+
137
+ Usage
138
+ -----
139
+ >>> from pyannote.core import Segment
140
+ >>> Segment.set_precision(2)
141
+ >>> Segment(1/3, 2/3)
142
+ <Segment(0.33, 0.67)>
143
+ """
144
+ global AUTO_ROUND_TIME
145
+ global SEGMENT_PRECISION
146
+
147
+ if ndigits is None:
148
+ # backward compatibility
149
+ AUTO_ROUND_TIME = False
150
+ # 1 μs (one microsecond)
151
+ SEGMENT_PRECISION = 1e-6
152
+ else:
153
+ AUTO_ROUND_TIME = True
154
+ SEGMENT_PRECISION = 10 ** (-ndigits)
155
+
156
+ def __bool__(self):
157
+ """Emptiness
158
+
159
+ >>> if segment:
160
+ ... # segment is not empty.
161
+ ... else:
162
+ ... # segment is empty.
163
+
164
+ Note
165
+ ----
166
+ A segment is considered empty if its end time is smaller than its
167
+ start time, or its duration is smaller than 1μs.
168
+ """
169
+ return bool((self.end - self.start) > SEGMENT_PRECISION)
170
+
171
+ def __post_init__(self):
172
+ """Round start and end up to SEGMENT_PRECISION precision (when required)"""
173
+ if AUTO_ROUND_TIME:
174
+ object.__setattr__(self, 'start', int(self.start / SEGMENT_PRECISION + 0.5) * SEGMENT_PRECISION)
175
+ object.__setattr__(self, 'end', int(self.end / SEGMENT_PRECISION + 0.5) * SEGMENT_PRECISION)
176
+
177
+ @property
178
+ def duration(self) -> float:
179
+ """Segment duration (read-only)"""
180
+ return self.end - self.start if self else 0.
181
+
182
+ @property
183
+ def middle(self) -> float:
184
+ """Segment mid-time (read-only)"""
185
+ return .5 * (self.start + self.end)
186
+
187
+ def __iter__(self) -> Iterator[float]:
188
+ """Unpack segment boundaries
189
+ >>> segment = Segment(start, end)
190
+ >>> start, end = segment
191
+ """
192
+ yield self.start
193
+ yield self.end
194
+
195
+ def copy(self) -> 'Segment':
196
+ """Get a copy of the segment
197
+
198
+ Returns
199
+ -------
200
+ copy : Segment
201
+ Copy of the segment.
202
+ """
203
+ return Segment(start=self.start, end=self.end)
204
+
205
+ # ------------------------------------------------------- #
206
+ # Inclusion (in), intersection (&), union (|) and gap (^) #
207
+ # ------------------------------------------------------- #
208
+
209
+ def __contains__(self, other: 'Segment'):
210
+ """Inclusion
211
+
212
+ >>> segment = Segment(start=0, end=10)
213
+ >>> Segment(start=3, end=10) in segment
214
+ True
215
+ >>> Segment(start=5, end=15) in segment
216
+ False
217
+ """
218
+ return (self.start <= other.start) and (self.end >= other.end)
219
+
220
+ def __and__(self, other):
221
+ """Intersection
222
+
223
+ >>> segment = Segment(0, 10)
224
+ >>> other_segment = Segment(5, 15)
225
+ >>> segment & other_segment
226
+ <Segment(5, 10)>
227
+
228
+ Note
229
+ ----
230
+ When the intersection is empty, an empty segment is returned:
231
+
232
+ >>> segment = Segment(0, 10)
233
+ >>> other_segment = Segment(15, 20)
234
+ >>> intersection = segment & other_segment
235
+ >>> if not intersection:
236
+ ... # intersection is empty.
237
+ """
238
+ start = max(self.start, other.start)
239
+ end = min(self.end, other.end)
240
+ return Segment(start=start, end=end)
241
+
242
+ def intersects(self, other: 'Segment') -> bool:
243
+ """Check whether two segments intersect each other
244
+
245
+ Parameters
246
+ ----------
247
+ other : Segment
248
+ Other segment
249
+
250
+ Returns
251
+ -------
252
+ intersect : bool
253
+ True if segments intersect, False otherwise
254
+ """
255
+
256
+ return (self.start < other.start and
257
+ other.start < self.end - SEGMENT_PRECISION) or \
258
+ (self.start > other.start and
259
+ self.start < other.end - SEGMENT_PRECISION) or \
260
+ (self.start == other.start)
261
+
262
+ def overlaps(self, t: float) -> bool:
263
+ """Check if segment overlaps a given time
264
+
265
+ Parameters
266
+ ----------
267
+ t : float
268
+ Time, in seconds.
269
+
270
+ Returns
271
+ -------
272
+ overlap: bool
273
+ True if segment overlaps time t, False otherwise.
274
+ """
275
+ return self.start <= t and self.end >= t
276
+
277
+ def __or__(self, other: 'Segment') -> 'Segment':
278
+ """Union
279
+
280
+ >>> segment = Segment(0, 10)
281
+ >>> other_segment = Segment(5, 15)
282
+ >>> segment | other_segment
283
+ <Segment(0, 15)>
284
+
285
+ Note
286
+ ----
287
+ When a gap exists between the segments, their union covers the gap as well:
288
+
289
+ >>> segment = Segment(0, 10)
290
+ >>> other_segment = Segment(15, 20)
291
+ >>> segment | other_segment
292
+ <Segment(0, 20)>
293
+ """
294
+
295
+ # if segment is empty, union is the other one
296
+ if not self:
297
+ return other
298
+ # if other one is empty, union is self
299
+ if not other:
300
+ return self
301
+
302
+ # otherwise, do what's meant to be...
303
+ start = min(self.start, other.start)
304
+ end = max(self.end, other.end)
305
+ return Segment(start=start, end=end)
306
+
307
+ def __xor__(self, other: 'Segment') -> 'Segment':
308
+ """Gap
309
+
310
+ >>> segment = Segment(0, 10)
311
+ >>> other_segment = Segment(15, 20)
312
+ >>> segment ^ other_segment
313
+ <Segment(10, 15)>
314
+
315
+ Note
316
+ ----
317
+ The gap between a segment and an empty segment is not defined.
318
+
319
+ >>> segment = Segment(0, 10)
320
+ >>> empty_segment = Segment(11, 11)
321
+ >>> segment ^ empty_segment
322
+ ValueError: The gap between a segment and an empty segment is not defined.
323
+ """
324
+
325
+ # if segment is empty, xor is not defined
326
+ if (not self) or (not other):
327
+ raise ValueError(
328
+ 'The gap between a segment and an empty segment '
329
+ 'is not defined.')
330
+
331
+ start = min(self.end, other.end)
332
+ end = max(self.start, other.start)
333
+ return Segment(start=start, end=end)
334
+
335
+ def _str_helper(self, seconds: float) -> str:
336
+ from datetime import timedelta
337
+ negative = seconds < 0
338
+ seconds = abs(seconds)
339
+ td = timedelta(seconds=seconds)
340
+ seconds = td.seconds + 86400 * td.days
341
+ microseconds = td.microseconds
342
+ hours, remainder = divmod(seconds, 3600)
343
+ minutes, seconds = divmod(remainder, 60)
344
+ return '%s%02d:%02d:%02d.%03d' % (
345
+ '-' if negative else ' ', hours, minutes,
346
+ seconds, microseconds / 1000)
347
+
348
+ def __str__(self):
349
+ """Human-readable representation
350
+
351
+ >>> print(Segment(1337, 1337 + 0.42))
352
+ [ 00:22:17.000 --> 00:22:17.420]
353
+
354
+ Note
355
+ ----
356
+ Empty segments are printed as "[]"
357
+ """
358
+ if self:
359
+ return '[%s --> %s]' % (self._str_helper(self.start),
360
+ self._str_helper(self.end))
361
+ return '[]'
362
+
363
+ def __repr__(self):
364
+ """Computer-readable representation
365
+
366
+ >>> Segment(1337, 1337 + 0.42)
367
+ <Segment(1337, 1337.42)>
368
+ """
369
+ return '<Segment(%g, %g)>' % (self.start, self.end)
370
+
371
+ def _repr_png_(self):
372
+ """IPython notebook support
373
+
374
+ See also
375
+ --------
376
+ :mod:`pyannote.core.notebook`
377
+ """
378
+ from .notebook import MATPLOTLIB_IS_AVAILABLE, MATPLOTLIB_WARNING
379
+ if not MATPLOTLIB_IS_AVAILABLE:
380
+ warnings.warn(MATPLOTLIB_WARNING.format(klass=self.__class__.__name__))
381
+ return None
382
+
383
+ from .notebook import repr_segment
384
+ try:
385
+ return repr_segment(self)
386
+ except ImportError:
387
+ warnings.warn(
388
+ f"Couldn't import matplotlib to render the visualization for object {self}. To enable, install the required dependencies with 'pip install pyannote.core[notebook]'")
389
+ return None
390
+
391
+
392
+ class SlidingWindow:
393
+ """Sliding window
394
+
395
+ Parameters
396
+ ----------
397
+ duration : float > 0, optional
398
+ Window duration, in seconds. Default is 30 ms.
399
+ step : float > 0, optional
400
+ Step between two consecutive positions, in seconds. Default is 10 ms.
401
+ start : float, optional
402
+ First start position of window, in seconds. Default is 0.
403
+ end : float > `start`, optional
404
+ Default is infinity (i.e. the window keeps sliding forever)
405
+
406
+ Examples
407
+ --------
408
+
409
+ >>> sw = SlidingWindow(duration, step, start)
410
+ >>> frame_range = (a, b)
411
+ >>> frame_range == sw.toFrameRange(sw.toSegment(*frame_range))
412
+ ... True
413
+
414
+ >>> segment = Segment(A, B)
415
+ >>> new_segment = sw.toSegment(*sw.toFrameRange(segment))
416
+ >>> abs(segment) - abs(segment & new_segment) < .5 * sw.step
417
+
418
+ >>> sw = SlidingWindow(end=0.1)
419
+ >>> print(next(sw))
420
+ [ 00:00:00.000 --> 00:00:00.030]
421
+ >>> print(next(sw))
422
+ [ 00:00:00.010 --> 00:00:00.040]
423
+ """
424
+
425
+ def __init__(self, duration=0.030, step=0.010, start=0.000, end=None):
426
+
427
+ # duration must be a float > 0
428
+ if duration <= 0:
429
+ raise ValueError("'duration' must be a float > 0.")
430
+ self.__duration = duration
431
+
432
+ # step must be a float > 0
433
+ if step <= 0:
434
+ raise ValueError("'step' must be a float > 0.")
435
+ self.__step: float = step
436
+
437
+ # start must be a float.
438
+ self.__start: float = start
439
+
440
+ # if end is not provided, set it to infinity
441
+ if end is None:
442
+ self.__end: float = np.inf
443
+ else:
444
+ # end must be greater than start
445
+ if end <= start:
446
+ raise ValueError("'end' must be greater than 'start'.")
447
+ self.__end: float = end
448
+
449
+ # current index of iterator
450
+ self.__i: int = -1
451
+
452
+ @property
453
+ def start(self) -> float:
454
+ """Sliding window start time in seconds."""
455
+ return self.__start
456
+
457
+ @property
458
+ def end(self) -> float:
459
+ """Sliding window end time in seconds."""
460
+ return self.__end
461
+
462
+ @property
463
+ def step(self) -> float:
464
+ """Sliding window step in seconds."""
465
+ return self.__step
466
+
467
+ @property
468
+ def duration(self) -> float:
469
+ """Sliding window duration in seconds."""
470
+ return self.__duration
471
+
472
+ def closest_frame(self, t: float) -> int:
473
+ """Closest frame to timestamp.
474
+
475
+ Parameters
476
+ ----------
477
+ t : float
478
+ Timestamp, in seconds.
479
+
480
+ Returns
481
+ -------
482
+ index : int
483
+ Index of frame whose middle is the closest to `timestamp`
484
+
485
+ """
486
+ return int(np.rint(
487
+ (t - self.__start - .5 * self.__duration) / self.__step
488
+ ))
489
+
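A worked instance of the `closest_frame` formula (hypothetical parameters): with duration=2 and step=1, the i-th window covers [i, i + 2] and its middle is i + 1, so the closest frame to time t is the nearest integer to t - 1.

from pyannote_audio_utils.core.segment import SlidingWindow

sw = SlidingWindow(duration=2.0, step=1.0, start=0.0)
assert sw.closest_frame(3.2) == 2  # rint((3.2 - 0 - 0.5 * 2.0) / 1.0) = rint(2.2)
assert sw.closest_frame(3.6) == 3  # rint(2.6)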
490
+ def samples(self, from_duration: float, mode: Alignment = 'strict') -> int:
491
+ """Number of frames
492
+
493
+ Parameters
494
+ ----------
495
+ from_duration : float
496
+ Duration in seconds.
497
+ mode : {'strict', 'loose', 'center'}
498
+ In 'strict' mode, computes the maximum number of consecutive frames
499
+ that can be fitted into a segment with duration `from_duration`.
500
+ In 'loose' mode, computes the maximum number of consecutive frames
501
+ intersecting a segment with duration `from_duration`.
502
+ In 'center' mode, computes the average number of consecutive frames
503
+ where the first one is centered on the start time and the last one
504
+ is centered on the end time of a segment with duration
505
+ `from_duration`.
506
+
507
+ """
508
+ if mode == 'strict':
509
+ return int(np.floor((from_duration - self.duration) / self.step)) + 1
510
+
511
+ elif mode == 'loose':
512
+ return int(np.floor((from_duration + self.duration) / self.step))
513
+
514
+ elif mode == 'center':
515
+ return int(np.rint((from_duration / self.step)))
516
+
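The three modes above only differ in how much of each window must fit inside the given duration; a worked sketch with hypothetical parameters:

from pyannote_audio_utils.core.segment import SlidingWindow

sw = SlidingWindow(duration=2.0, step=1.0)
assert sw.samples(5.0, mode='strict') == 4  # floor((5 - 2) / 1) + 1: fully inside
assert sw.samples(5.0, mode='loose') == 7   # floor((5 + 2) / 1): merely intersecting
assert sw.samples(5.0, mode='center') == 5  # rint(5 / 1): one frame per step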
517
+ def crop(self, focus: Union[Segment, 'Timeline'],
518
+ mode: Alignment = 'loose',
519
+ fixed: Optional[float] = None,
520
+ return_ranges: Optional[bool] = False) -> \
521
+ Union[np.ndarray, List[List[int]]]:
522
+ """Crop sliding window
523
+
524
+ Parameters
525
+ ----------
526
+ focus : `Segment` or `Timeline`
527
+ mode : {'strict', 'loose', 'center'}, optional
528
+ In 'strict' mode, only indices of segments fully included in
529
+ 'focus' support are returned. In 'loose' mode, indices of any
530
+ intersecting segments are returned. In 'center' mode, first and
531
+ last positions are chosen to be the positions whose centers are the
532
+ closest to 'focus' start and end times. Defaults to 'loose'.
533
+ fixed : float, optional
534
+ Overrides `Segment` 'focus' duration and ensures that the number of
535
+ returned frames is fixed (which might otherwise not be the case
536
+ because of rounding errors).
537
+ return_ranges : bool, optional
538
+ Return as list of ranges. Defaults to indices numpy array.
539
+
540
+ Returns
541
+ -------
542
+ indices : np.array (or list of ranges)
543
+ Array of unique indices of matching segments
544
+ """
545
+
546
+ from .timeline import Timeline
547
+
548
+ if not isinstance(focus, (Segment, Timeline)):
549
+ msg = '"focus" must be a `Segment` or `Timeline` instance.'
550
+ raise TypeError(msg)
551
+
552
+ if isinstance(focus, Timeline):
553
+
554
+ if fixed is not None:
555
+ msg = "'fixed' is not supported with `Timeline` 'focus'."
556
+ raise ValueError(msg)
557
+
558
+ if return_ranges:
559
+ ranges = []
560
+
561
+ for i, s in enumerate(focus.support()):
562
+ rng = self.crop(s, mode=mode, fixed=fixed,
563
+ return_ranges=True)
564
+
565
+ # if first or disjoint segment, add it
566
+ if i == 0 or rng[0][0] > ranges[-1][1]:
567
+ ranges += rng
568
+
569
+ # if overlapping segment, update last range
570
+ else:
571
+ ranges[-1][1] = rng[0][1]
572
+
573
+ return ranges
574
+
575
+ # concatenate all indices
576
+ indices = np.hstack([
577
+ self.crop(s, mode=mode, fixed=fixed, return_ranges=False)
578
+ for s in focus.support()])
579
+
580
+ # remove duplicate indices
581
+ return np.unique(indices)
582
+
583
+ # 'focus' is a `Segment` instance
584
+
585
+ if mode == 'loose':
586
+
587
+ # find smallest integer i such that
588
+ # self.start + i x self.step + self.duration >= focus.start
589
+ i_ = (focus.start - self.duration - self.start) / self.step
590
+ i = int(np.ceil(i_))
591
+
592
+ if fixed is None:
593
+ # find largest integer j such that
594
+ # self.start + j x self.step <= focus.end
595
+ j_ = (focus.end - self.start) / self.step
596
+ j = int(np.floor(j_))
597
+ rng = (i, j + 1)
598
+
599
+ else:
600
+ n = self.samples(fixed, mode='loose')
601
+ rng = (i, i + n)
602
+
603
+ elif mode == 'strict':
604
+
605
+ # find smallest integer i such that
606
+ # self.start + i x self.step >= focus.start
607
+ i_ = (focus.start - self.start) / self.step
608
+ i = int(np.ceil(i_))
609
+
610
+ if fixed is None:
611
+
612
+ # find largest integer j such that
613
+ # self.start + j x self.step + self.duration <= focus.end
614
+ j_ = (focus.end - self.duration - self.start) / self.step
615
+ j = int(np.floor(j_))
616
+ rng = (i, j + 1)
617
+
618
+ else:
619
+ n = self.samples(fixed, mode='strict')
620
+ rng = (i, i + n)
621
+
622
+ elif mode == 'center':
623
+
624
+ # find window position whose center is the closest to focus.start
625
+ i = self.closest_frame(focus.start)
626
+
627
+ if fixed is None:
628
+ # find window position whose center is the closest to focus.end
629
+ j = self.closest_frame(focus.end)
630
+ rng = (i, j + 1)
631
+ else:
632
+ n = self.samples(fixed, mode='center')
633
+ rng = (i, i + n)
634
+
635
+ else:
636
+ msg = "'mode' must be one of {'loose', 'strict', 'center'}."
637
+ raise ValueError(msg)
638
+
639
+ if return_ranges:
640
+ return [list(rng)]
641
+
642
+ return np.array(range(*rng), dtype=np.int64)
643
+
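A short sketch of the crop modes (hypothetical parameters): 'loose' returns every window intersecting the focus, 'strict' only those fully inside it, and `fixed` pins the number of returned frames regardless of rounding.

from pyannote_audio_utils.core.segment import Segment, SlidingWindow

sw = SlidingWindow(duration=2.0, step=1.0, start=0.0)  # window i covers [i, i + 2]
focus = Segment(2.5, 6.5)
print(sw.crop(focus, mode='loose'))              # [1 2 3 4 5 6]
print(sw.crop(focus, mode='strict'))             # [3 4]
print(sw.crop(focus, mode='strict', fixed=4.0))  # [3 4 5] -- exactly 3 frames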
644
+ def segmentToRange(self, segment: Segment) -> Tuple[int, int]:
645
+ warnings.warn("Deprecated in favor of `segment_to_range`",
646
+ DeprecationWarning)
647
+ return self.segment_to_range(segment)
648
+
649
+ def segment_to_range(self, segment: Segment) -> Tuple[int, int]:
650
+ """Convert segment to 0-indexed frame range
651
+
652
+ Parameters
653
+ ----------
654
+ segment : Segment
655
+
656
+ Returns
657
+ -------
658
+ i0 : int
659
+ Index of first frame
660
+ n : int
661
+ Number of frames
662
+
663
+ Examples
664
+ --------
665
+
666
+ >>> window = SlidingWindow()
667
+ >>> print(window.segment_to_range(Segment(10, 15)))
668
+ i0, n
669
+
670
+ """
671
+ # find closest frame to segment start
672
+ i0 = self.closest_frame(segment.start)
673
+
674
+ # number of steps to cover segment duration
675
+ n = int(segment.duration / self.step) + 1
676
+
677
+ return i0, n
678
+
679
+ def rangeToSegment(self, i0: int, n: int) -> Segment:
680
+ warnings.warn("This is deprecated in favor of `range_to_segment`",
681
+ DeprecationWarning)
682
+ return self.range_to_segment(i0, n)
683
+
684
+ def range_to_segment(self, i0: int, n: int) -> Segment:
685
+ """Convert 0-indexed frame range to segment
686
+
687
+ Each frame represents a unique segment of duration 'step', centered on
688
+ the middle of the frame.
689
+
690
+ The very first frame (i0 = 0) is the exception. It is extended to the
691
+ sliding window start time.
692
+
693
+ Parameters
694
+ ----------
695
+ i0 : int
696
+ Index of first frame
697
+ n : int
698
+ Number of frames
699
+
700
+ Returns
701
+ -------
702
+ segment : Segment
703
+
704
+ Examples
705
+ --------
706
+
707
+ >>> window = SlidingWindow()
708
+ >>> print(window.range_to_segment(3, 2))
709
+ [ --> ]
710
+
711
+ """
712
+
713
+ # frame start time
714
+ # start = self.start + i0 * self.step
715
+ # frame middle time
716
+ # start += .5 * self.duration
717
+ # subframe start time
718
+ # start -= .5 * self.step
719
+ start = self.__start + (i0 - .5) * self.__step + .5 * self.__duration
720
+ duration = n * self.__step
721
+ end = start + duration
722
+
723
+ # extend segment to the beginning of the timeline
724
+ if i0 == 0:
725
+ start = self.start
726
+
727
+ return Segment(start, end)
728
+
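A worked example of the arithmetic in `range_to_segment` (hypothetical parameters): every frame maps to a step-long sub-segment centered on its middle, except frame 0, which is stretched back to the window start.

from pyannote_audio_utils.core.segment import Segment, SlidingWindow

sw = SlidingWindow(duration=2.0, step=1.0, start=0.0)
# start = 0 + (3 - 0.5) * 1 + 0.5 * 2 = 3.5 ; duration = 2 * 1 = 2
assert sw.range_to_segment(3, 2) == Segment(3.5, 5.5)
# frame 0: start falls back to the sliding window start, end is unchanged
assert sw.range_to_segment(0, 2) == Segment(0.0, 2.5)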
729
+ def samplesToDuration(self, nSamples: int) -> float:
730
+ warnings.warn("This is deprecated in favor of `samples_to_duration`",
731
+ DeprecationWarning)
732
+ return self.samples_to_duration(nSamples)
733
+
734
+ def samples_to_duration(self, n_samples: int) -> float:
735
+ """Returns duration of samples"""
736
+ return self.range_to_segment(0, n_samples).duration
737
+
738
+ def durationToSamples(self, duration: float) -> int:
739
+ warnings.warn("This is deprecated in favor of `duration_to_samples`",
740
+ DeprecationWarning)
741
+ return self.duration_to_samples(duration)
742
+
743
+ def duration_to_samples(self, duration: float) -> int:
744
+ """Returns samples in duration"""
745
+ return self.segment_to_range(Segment(0, duration))[1]
746
+
747
+ def __getitem__(self, i: int) -> Optional[Segment]:
748
+ """
749
+ Parameters
750
+ ----------
751
+ i : int
752
+ Index of sliding window position
753
+
754
+ Returns
755
+ -------
756
+ segment : :class:`Segment`
757
+ Sliding window at ith position
758
+
759
+ """
760
+
761
+ # window start time at ith position
762
+ start = self.__start + i * self.__step
763
+
764
+ # in case segment starts after the end,
765
+ # return None (falsy), which makes iteration stop in __next__
766
+ if start >= self.__end:
767
+ return None
768
+
769
+ return Segment(start=start, end=start + self.__duration)
770
+
771
+ def next(self) -> Segment:
772
+ return self.__next__()
773
+
774
+ def __next__(self) -> Segment:
775
+ self.__i += 1
776
+ window = self[self.__i]
777
+
778
+ if window:
779
+ return window
780
+ else:
781
+ raise StopIteration()
782
+
783
+ def __iter__(self) -> 'SlidingWindow':
784
+ """Sliding window iterator
785
+
786
+ Use expression 'for segment in sliding_window'
787
+
788
+ Examples
789
+ --------
790
+
791
+ >>> window = SlidingWindow(end=0.1)
792
+ >>> for segment in window:
793
+ ... print(segment)
794
+ [ 00:00:00.000 --> 00:00:00.030]
795
+ [ 00:00:00.010 --> 00:00:00.040]
796
+ [ 00:00:00.020 --> 00:00:00.050]
797
+ [ 00:00:00.030 --> 00:00:00.060]
798
+ [ 00:00:00.040 --> 00:00:00.070]
799
+ [ 00:00:00.050 --> 00:00:00.080]
800
+ [ 00:00:00.060 --> 00:00:00.090]
801
+ [ 00:00:00.070 --> 00:00:00.100]
802
+ [ 00:00:00.080 --> 00:00:00.110]
803
+ [ 00:00:00.090 --> 00:00:00.120]
804
+ """
805
+
806
+ # reset iterator index
807
+ self.__i = -1
808
+ return self
809
+
810
+ def __len__(self) -> int:
811
+ """Number of positions
812
+
813
+ Equivalent to len([segment for segment in window])
814
+
815
+ Returns
816
+ -------
817
+ length : int
818
+ Number of positions taken by the sliding window
819
+ (from start times to end times)
820
+
821
+ """
822
+ if np.isinf(self.__end):
823
+ raise ValueError('infinite sliding window.')
824
+
825
+ # start looking for last position
826
+ # based on frame closest to the end
827
+ i = self.closest_frame(self.__end)
828
+
829
+ while (self[i]):
830
+ i += 1
831
+ length = i
832
+
833
+ return length
834
+
835
+ def copy(self) -> 'SlidingWindow':
836
+ """Duplicate sliding window"""
837
+ duration = self.duration
838
+ step = self.step
839
+ start = self.start
840
+ end = self.end
841
+ sliding_window = self.__class__(
842
+ duration=duration, step=step, start=start, end=end
843
+ )
844
+ return sliding_window
845
+
846
+ def __call__(self,
847
+ support: Union[Segment, 'Timeline'],
848
+ align_last: bool = False) -> Iterable[Segment]:
849
+ """Slide window over support
850
+
851
+ Parameter
852
+ ---------
853
+ support : Segment or Timeline
854
+ Support on which to slide the window.
855
+ align_last : bool, optional
856
+ Yield a final segment so that it aligns exactly with end of support.
857
+
858
+ Yields
859
+ ------
860
+ chunk : Segment
861
+
862
+ Example
863
+ -------
864
+ >>> window = SlidingWindow(duration=2., step=1.)
865
+ >>> for chunk in window(Segment(3, 7.5)):
866
+ ... print(tuple(chunk))
867
+ (3.0, 5.0)
868
+ (4.0, 6.0)
869
+ (5.0, 7.0)
870
+ >>> for chunk in window(Segment(3, 7.5), align_last=True):
871
+ ... print(tuple(chunk))
872
+ (3.0, 5.0)
873
+ (4.0, 6.0)
874
+ (5.0, 7.0)
875
+ (5.5, 7.5)
876
+ """
877
+
878
+ from pyannote.core import Timeline
879
+ if isinstance(support, Timeline):
880
+ segments = support
881
+
882
+ elif isinstance(support, Segment):
883
+ segments = Timeline(segments=[support])
884
+
885
+ else:
886
+ msg = (
887
+ f'"support" must be either a Segment or a Timeline '
888
+ f'instance (is {type(support)})'
889
+ )
890
+ raise TypeError(msg)
891
+
892
+ for segment in segments:
893
+
894
+ if segment.duration < self.duration:
895
+ continue
896
+
897
+ window = SlidingWindow(duration=self.duration,
898
+ step=self.step,
899
+ start=segment.start,
900
+ end=segment.end)
901
+
902
+ for s in window:
903
+ # ugly hack to account for floating point imprecision
904
+ if s in segment:
905
+ yield s
906
+ last = s
907
+
908
+ if align_last and last.end < segment.end:
909
+ yield Segment(start=segment.end - self.duration,
910
+ end=segment.end)
ailia-models/code/pyannote_audio_utils/core/timeline.py ADDED
@@ -0,0 +1,1126 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2014-2020 CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+ # Grant JENKS - http://www.grantjenks.com/
29
+ # Paul LERNER
30
+
31
+ """
32
+ ########
33
+ Timeline
34
+ ########
35
+
36
+ .. plot:: pyplots/timeline.py
37
+
38
+ :class:`pyannote.core.Timeline` instances are ordered sets of non-empty
39
+ segments:
40
+
41
+ - ordered, because segments are sorted by start time (and end time in case of tie)
42
+ - set, because one cannot add twice the same segment
43
+ - non-empty, because one cannot add empty segments (*i.e.* start >= end)
44
+
45
+ There are two ways to define the timeline depicted above:
46
+
47
+ .. code-block:: ipython
48
+
49
+ In [25]: from pyannote.core import Timeline, Segment
50
+
51
+ In [26]: timeline = Timeline()
52
+ ....: timeline.add(Segment(1, 5))
53
+ ....: timeline.add(Segment(6, 8))
54
+ ....: timeline.add(Segment(12, 18))
55
+ ....: timeline.add(Segment(7, 20))
56
+ ....:
57
+
58
+ In [27]: segments = [Segment(1, 5), Segment(6, 8), Segment(12, 18), Segment(7, 20)]
59
+ ....: timeline = Timeline(segments=segments, uri='my_audio_file') # faster
60
+ ....:
61
+
62
+ In [9]: for segment in timeline:
63
+ ...: print(segment)
64
+ ...:
65
+ [ 00:00:01.000 --> 00:00:05.000]
66
+ [ 00:00:06.000 --> 00:00:08.000]
67
+ [ 00:00:07.000 --> 00:00:20.000]
68
+ [ 00:00:12.000 --> 00:00:18.000]
69
+
70
+
71
+ .. note::
72
+
73
+ The optional *uri* keyword argument can be used to remember which document it describes.
74
+
75
+ Several convenient methods are available. Here are a few examples:
76
+
77
+ .. code-block:: ipython
78
+
79
+ In [3]: timeline.extent() # extent
80
+ Out[3]: <Segment(1, 20)>
81
+
82
+ In [5]: timeline.support() # support
83
+ Out[5]: <Timeline(uri=my_audio_file, segments=[<Segment(1, 5)>, <Segment(6, 20)>])>
84
+
85
+ In [6]: timeline.duration() # support duration
86
+ Out[6]: 18
87
+
88
+
89
+ See :class:`pyannote.core.Timeline` for the complete reference.
90
+ """
91
+ import warnings
92
+ from typing import (Optional, Iterable, List, Union, Callable,
93
+ TextIO, Tuple, TYPE_CHECKING, Iterator, Dict, Text)
94
+
95
+ from sortedcontainers import SortedList
96
+
97
+ from . import PYANNOTE_SEGMENT
98
+ from .segment import Segment
99
+ from .utils.types import Support, Label, CropMode
100
+
101
+
102
+ # this is a moderately ugly way to import `Annotation` to the namespace
103
+ # without causing some circular imports :
104
+ # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
105
+ if TYPE_CHECKING:
106
+ from .annotation import Annotation
107
+ import pandas as pd
108
+
109
+
110
+ # =====================================================================
111
+ # Timeline class
112
+ # =====================================================================
113
+
114
+
115
+ class Timeline:
116
+ """
117
+ Ordered set of segments.
118
+
119
+ A timeline can be seen as an ordered set of non-empty segments (Segment).
120
+ Segments can overlap -- though adding an already existing segment to a
121
+ timeline does nothing.
122
+
123
+ Parameters
124
+ ----------
125
+ segments : Segment iterator, optional
126
+ initial set of (non-empty) segments
127
+ uri : string, optional
128
+ name of segmented resource
129
+
130
+ Returns
131
+ -------
132
+ timeline : Timeline
133
+ New timeline
134
+ """
135
+
136
+ @classmethod
137
+ def from_df(cls, df: 'pd.DataFrame', uri: Optional[str] = None) -> 'Timeline':
138
+ segments = list(df[PYANNOTE_SEGMENT])
139
+ timeline = cls(segments=segments, uri=uri)
140
+ return timeline
141
+
142
+ def __init__(self,
143
+ segments: Optional[Iterable[Segment]] = None,
144
+ uri: str = None):
145
+ if segments is None:
146
+ segments = ()
147
+
148
+ # set of segments (used for checking inclusion)
149
+ # Store only non-empty Segments.
150
+ segments_set = set([segment for segment in segments if segment])
151
+
152
+ self.segments_set_ = segments_set
153
+
154
+ # sorted list of segments (used for sorted iteration)
155
+ self.segments_list_ = SortedList(segments_set)
156
+
157
+ # sorted list of (possibly redundant) segment boundaries
158
+ boundaries = (boundary for segment in segments_set for boundary in segment)
159
+ self.segments_boundaries_ = SortedList(boundaries)
160
+
161
+ # path to (or any identifier of) segmented resource
162
+ self.uri: str = uri
163
+
164
+ def __len__(self):
165
+ """Number of segments
166
+
167
+ >>> len(timeline) # timeline contains three segments
168
+ 3
169
+ """
170
+ return len(self.segments_set_)
171
+
172
+ def __nonzero__(self):
173
+ return self.__bool__()
174
+
175
+ def __bool__(self):
176
+ """Emptiness
177
+
178
+ >>> if timeline:
179
+ ... # timeline is not empty
180
+ ... else:
181
+ ... # timeline is empty
182
+ """
183
+ return len(self.segments_set_) > 0
184
+
185
+ def __iter__(self) -> Iterable[Segment]:
186
+ """Iterate over segments (in chronological order)
187
+
188
+ >>> for segment in timeline:
189
+ ... # do something with the segment
190
+
191
+ See also
192
+ --------
193
+ :class:`pyannote.core.Segment` describes how segments are sorted.
194
+ """
195
+ return iter(self.segments_list_)
196
+
197
+ def __getitem__(self, k: int) -> Segment:
198
+ """Get segment by index (in chronological order)
199
+
200
+ >>> first_segment = timeline[0]
201
+ >>> penultimate_segment = timeline[-2]
202
+ """
203
+ return self.segments_list_[k]
204
+
205
+ def __eq__(self, other: 'Timeline'):
206
+ """Equality
207
+
208
+ Two timelines are equal if and only if their segments are equal.
209
+
210
+ >>> timeline1 = Timeline([Segment(0, 1), Segment(2, 3)])
211
+ >>> timeline2 = Timeline([Segment(2, 3), Segment(0, 1)])
212
+ >>> timeline3 = Timeline([Segment(2, 3)])
213
+ >>> timeline1 == timeline2
214
+ True
215
+ >>> timeline1 == timeline3
216
+ False
217
+ """
218
+ return self.segments_set_ == other.segments_set_
219
+
220
+ def __ne__(self, other: 'Timeline'):
221
+ """Inequality"""
222
+ return self.segments_set_ != other.segments_set_
223
+
224
+ def index(self, segment: Segment) -> int:
225
+ """Get index of (existing) segment
226
+
227
+ Parameters
228
+ ----------
229
+ segment : Segment
230
+ Segment that is being looked for.
231
+
232
+ Returns
233
+ -------
234
+ position : int
235
+ Index of `segment` in timeline
236
+
237
+ Raises
238
+ ------
239
+ ValueError if `segment` is not present.
240
+ """
241
+ return self.segments_list_.index(segment)
242
+
243
+ def add(self, segment: Segment) -> 'Timeline':
244
+ """Add a segment (in place)
245
+
246
+ Parameters
247
+ ----------
248
+ segment : Segment
249
+ Segment that is being added
250
+
251
+ Returns
252
+ -------
253
+ self : Timeline
254
+ Updated timeline.
255
+
256
+ Note
257
+ ----
258
+ If the timeline already contains this segment, it will not be added
259
+ again, as a timeline is meant to be a **set** of segments (not a list).
260
+
261
+ If the segment is empty, it will not be added either, as a timeline
262
+ only contains non-empty segments.
263
+ """
264
+
265
+ segments_set_ = self.segments_set_
266
+ if segment in segments_set_ or not segment:
267
+ return self
268
+
269
+ segments_set_.add(segment)
270
+
271
+ self.segments_list_.add(segment)
272
+
273
+ segments_boundaries_ = self.segments_boundaries_
274
+ segments_boundaries_.add(segment.start)
275
+ segments_boundaries_.add(segment.end)
276
+
277
+ return self
278
+
279
+ def remove(self, segment: Segment) -> 'Timeline':
280
+ """Remove a segment (in place)
281
+
282
+ Parameters
283
+ ----------
284
+ segment : Segment
285
+ Segment that is being removed
286
+
287
+ Returns
288
+ -------
289
+ self : Timeline
290
+ Updated timeline.
291
+
292
+ Note
293
+ ----
294
+ If the timeline does not contain this segment, this does nothing
295
+ """
296
+
297
+ segments_set_ = self.segments_set_
298
+ if segment not in segments_set_:
299
+ return self
300
+
301
+ segments_set_.remove(segment)
302
+
303
+ self.segments_list_.remove(segment)
304
+
305
+ segments_boundaries_ = self.segments_boundaries_
306
+ segments_boundaries_.remove(segment.start)
307
+ segments_boundaries_.remove(segment.end)
308
+
309
+ return self
310
+
311
+ def discard(self, segment: Segment) -> 'Timeline':
312
+ """Same as `remove`
313
+
314
+ See also
315
+ --------
316
+ :func:`pyannote.core.Timeline.remove`
317
+ """
318
+ return self.remove(segment)
319
+
320
+ def __ior__(self, timeline: 'Timeline') -> 'Timeline':
321
+ return self.update(timeline)
322
+
323
+ def update(self, timeline: 'Timeline') -> 'Timeline':
324
+ """Add every segment of an existing timeline (in place)
325
+
326
+ Parameters
327
+ ----------
328
+ timeline : Timeline
329
+ Timeline whose segments are being added
330
+
331
+ Returns
332
+ -------
333
+ self : Timeline
334
+ Updated timeline
335
+
336
+ Note
337
+ ----
338
+ Only segments that do not already exist will be added, as a timeline is
339
+ meant to be a **set** of segments (not a list).
340
+
341
+ """
342
+
343
+ segments_set = self.segments_set_
344
+
345
+ segments_set |= timeline.segments_set_
346
+
347
+ # sorted list of segments (used for sorted iteration)
348
+ self.segments_list_ = SortedList(segments_set)
349
+
350
+ # sorted list of (possibly redundant) segment boundaries
351
+ boundaries = (boundary for segment in segments_set for boundary in segment)
352
+ self.segments_boundaries_ = SortedList(boundaries)
353
+
354
+ return self
355
+
356
+ def __or__(self, timeline: 'Timeline') -> 'Timeline':
357
+ return self.union(timeline)
358
+
359
+ def union(self, timeline: 'Timeline') -> 'Timeline':
360
+ """Create new timeline made of union of segments
361
+
362
+ Parameters
363
+ ----------
364
+ timeline : Timeline
365
+ Timeline whose segments are being added
366
+
367
+ Returns
368
+ -------
369
+ union : Timeline
370
+ New timeline containing the union of both timelines.
371
+
372
+ Note
373
+ ----
374
+ This does the same as timeline.update(...) except it returns a new
375
+ timeline, and the original one is not modified.
376
+ """
377
+ segments = self.segments_set_ | timeline.segments_set_
378
+ return Timeline(segments=segments, uri=self.uri)
379
+
380
+ def co_iter(self, other: 'Timeline') -> Iterator[Tuple[Segment, Segment]]:
381
+ """Iterate over pairs of intersecting segments
382
+
383
+ >>> timeline1 = Timeline([Segment(0, 2), Segment(1, 2), Segment(3, 4)])
384
+ >>> timeline2 = Timeline([Segment(1, 3), Segment(3, 5)])
385
+ >>> for segment1, segment2 in timeline1.co_iter(timeline2):
386
+ ... print(segment1, segment2)
387
+ (<Segment(0, 2)>, <Segment(1, 3)>)
388
+ (<Segment(1, 2)>, <Segment(1, 3)>)
389
+ (<Segment(3, 4)>, <Segment(3, 5)>)
390
+
391
+ Parameters
392
+ ----------
393
+ other : Timeline
394
+ Second timeline
395
+
396
+ Returns
397
+ -------
398
+ iterable : (Segment, Segment) iterable
399
+ Yields pairs of intersecting segments in chronological order.
400
+ """
401
+
402
+ for segment in self.segments_list_:
403
+
404
+ # iterate over segments that start before 'segment' ends
405
+ temp = Segment(start=segment.end, end=segment.end)
406
+ for other_segment in other.segments_list_.irange(maximum=temp):
407
+ if segment.intersects(other_segment):
408
+ yield segment, other_segment
409
+
410
+ def crop_iter(self,
411
+ support: Support,
412
+ mode: CropMode = 'intersection',
413
+ returns_mapping: bool = False) \
414
+ -> Iterator[Union[Tuple[Segment, Segment], Segment]]:
415
+ """Like `crop` but returns a segment iterator instead
416
+
417
+ See also
418
+ --------
419
+ :func:`pyannote.core.Timeline.crop`
420
+ """
421
+
422
+ if mode not in {'loose', 'strict', 'intersection'}:
423
+ raise ValueError("Mode must be one of 'loose', 'strict', or "
424
+ "'intersection'.")
425
+
426
+ if not isinstance(support, (Segment, Timeline)):
427
+ raise TypeError("Support must be a Segment or a Timeline.")
428
+
429
+ if isinstance(support, Segment):
430
+ # corner case where "support" is empty
431
+ if support:
432
+ segments = [support]
433
+ else:
434
+ segments = []
435
+
436
+ support = Timeline(segments=segments, uri=self.uri)
437
+ for yielded in self.crop_iter(support, mode=mode,
438
+ returns_mapping=returns_mapping):
439
+ yield yielded
440
+ return
441
+
442
+ # if 'support' is a `Timeline`, we use its support
443
+ support = support.support()
444
+
445
+ # loose mode
446
+ if mode == 'loose':
447
+ for segment, _ in self.co_iter(support):
448
+ yield segment
449
+ return
450
+
451
+ # strict mode
452
+ if mode == 'strict':
453
+ for segment, other_segment in self.co_iter(support):
454
+ if segment in other_segment:
455
+ yield segment
456
+ return
457
+
458
+ # intersection mode
459
+ for segment, other_segment in self.co_iter(support):
460
+ mapped_to = segment & other_segment
461
+ if not mapped_to:
462
+ continue
463
+ if returns_mapping:
464
+ yield segment, mapped_to
465
+ else:
466
+ yield mapped_to
467
+
468
+ def crop(self,
469
+ support: Support,
470
+ mode: CropMode = 'intersection',
471
+ returns_mapping: bool = False) \
472
+ -> Union['Timeline', Tuple['Timeline', Dict[Segment, Segment]]]:
473
+ """Crop timeline to new support
474
+
475
+ Parameters
476
+ ----------
477
+ support : Segment or Timeline
478
+ If `support` is a `Timeline`, its support is used.
479
+ mode : {'strict', 'loose', 'intersection'}, optional
480
+ Controls how segments that are not fully included in `support` are
481
+ handled. 'strict' mode only keeps fully included segments. 'loose'
482
+ mode keeps any intersecting segment. 'intersection' mode keeps
483
+ intersecting segments but replaces them with their actual intersection.
484
+ returns_mapping : bool, optional
485
+ In 'intersection' mode, return a dictionary whose keys are segments
486
+ of the cropped timeline, and values are lists of the original
487
+ segments that were cropped. Defaults to False.
488
+
489
+ Returns
490
+ -------
491
+ cropped : Timeline
492
+ Cropped timeline
493
+ mapping : dict
494
+ When 'returns_mapping' is True, dictionary whose keys are segments
495
+ of 'cropped', and values are lists of corresponding original
496
+ segments.
497
+
498
+ Examples
499
+ --------
500
+
501
+ >>> timeline = Timeline([Segment(0, 2), Segment(1, 2), Segment(3, 4)])
502
+ >>> timeline.crop(Segment(1, 3))
503
+ <Timeline(uri=None, segments=[<Segment(1, 2)>])>
504
+
505
+ >>> timeline.crop(Segment(1, 3), mode='loose')
506
+ <Timeline(uri=None, segments=[<Segment(0, 2)>, <Segment(1, 2)>])>
507
+
508
+ >>> timeline.crop(Segment(1, 3), mode='strict')
509
+ <Timeline(uri=None, segments=[<Segment(1, 2)>])>
510
+
511
+ >>> cropped, mapping = timeline.crop(Segment(1, 3), returns_mapping=True)
512
+ >>> print(mapping)
513
+ {<Segment(1, 2)>: [<Segment(0, 2)>, <Segment(1, 2)>]}
514
+
515
+ """
516
+
517
+ if mode == 'intersection' and returns_mapping:
518
+ segments, mapping = [], {}
519
+ for segment, mapped_to in self.crop_iter(support,
520
+ mode='intersection',
521
+ returns_mapping=True):
522
+ segments.append(mapped_to)
523
+ mapping[mapped_to] = mapping.get(mapped_to, list()) + [segment]
524
+ return Timeline(segments=segments, uri=self.uri), mapping
525
+
526
+ return Timeline(segments=self.crop_iter(support, mode=mode),
527
+ uri=self.uri)
528
+
529
+ def overlapping(self, t: float) -> List[Segment]:
530
+ """Get list of segments overlapping `t`
531
+
532
+ Parameters
533
+ ----------
534
+ t : float
535
+ Timestamp, in seconds.
536
+
537
+ Returns
538
+ -------
539
+ segments : list
540
+ List of all segments of timeline containing time t
541
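+
+ Examples
+ --------
+ A quick illustration:
+
+ >>> timeline = Timeline([Segment(0, 2), Segment(1, 3)])
+ >>> timeline.overlapping(1.5)
+ [<Segment(0, 2)>, <Segment(1, 3)>]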
+ """
542
+ return list(self.overlapping_iter(t))
543
+
544
+ def overlapping_iter(self, t: float) -> Iterator[Segment]:
545
+ """Like `overlapping` but returns a segment iterator instead
546
+
547
+ See also
548
+ --------
549
+ :func:`pyannote.core.Timeline.overlapping`
550
+ """
551
+ temp = Segment(start=t, end=t)
552
+ for segment in self.segments_list_.irange(maximum=temp):
553
+ if segment.overlaps(t):
554
+ yield segment
555
+
556
+ def get_overlap(self) -> 'Timeline':
557
+ """Get overlapping parts of the timeline.
558
+
559
+ A simple illustration:
560
+
561
+ timeline
562
+ |------| |------| |----|
563
+ |--| |-----| |----------|
564
+
565
+ timeline.get_overlap()
566
+ |--| |---| |----|
567
+
568
+
569
+ Returns
570
+ -------
571
+ overlap : `pyannote.core.Timeline`
572
+ Timeline of the overlaps.
573
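+
+ Examples
+ --------
+ A minimal numeric counterpart to the illustration above:
+
+ >>> timeline = Timeline([Segment(0, 2), Segment(1, 3), Segment(5, 6)])
+ >>> timeline.get_overlap()
+ <Timeline(uri=None, segments=[<Segment(1, 2)>])>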
+ """
574
+ overlaps_tl = Timeline(uri=self.uri)
575
+ for s1, s2 in self.co_iter(self):
576
+ if s1 == s2:
577
+ continue
578
+ overlaps_tl.add(s1 & s2)
579
+ return overlaps_tl.support()
580
+
581
+ def extrude(self,
582
+ removed: Support,
583
+ mode: CropMode = 'intersection') -> 'Timeline':
584
+ """Remove segments that overlap `removed` support.
585
+
586
+ Parameters
587
+ ----------
588
+ removed : Segment or Timeline
589
+ If `support` is a `Timeline`, its support is used.
590
+ mode : {'strict', 'loose', 'intersection'}, optional
591
+ Controls how segments that are not fully included in `removed` are
592
+ handled. 'strict' mode only removes fully included segments. 'loose'
593
+ mode removes any intersecting segment. 'intersection' mode removes
594
+ the overlapping part of any intersecting segment.
595
+
596
+ Returns
597
+ -------
598
+ extruded : Timeline
599
+ Extruded timeline
600
+
601
+ Examples
602
+ --------
603
+
604
+ >>> timeline = Timeline([Segment(0, 2), Segment(1, 2), Segment(3, 5)])
605
+ >>> timeline.extrude(Segment(1, 2))
606
+ <Timeline(uri=None, segments=[<Segment(0, 1)>, <Segment(3, 5)>])>
607
+
608
+ >>> timeline.extrude(Segment(1, 3), mode='loose')
609
+ <Timeline(uri=None, segments=[<Segment(3, 5)>])>
610
+
611
+ >>> timeline.extrude(Segment(1, 3), mode='strict')
612
+ <Timeline(uri=None, segments=[<Segment(0, 2)>, <Segment(3, 5)>])>
613
+
614
+ """
615
+ if isinstance(removed, Segment):
616
+ removed = Timeline([removed])
617
+
618
+ extent_tl = Timeline([self.extent()], uri=self.uri)
619
+ truncating_support = removed.gaps(support=extent_tl)
620
+ # loose for truncate means strict for crop and vice-versa
621
+ if mode == "loose":
622
+ mode = "strict"
623
+ elif mode == "strict":
624
+ mode = "loose"
625
+ return self.crop(truncating_support, mode=mode)
626
+
627
+ def __str__(self):
628
+ """Human-readable representation
629
+
630
+ >>> timeline = Timeline(segments=[Segment(0, 10), Segment(1, 13.37)])
631
+ >>> print(timeline)
632
+ [[ 00:00:00.000 --> 00:00:10.000]
633
+ [ 00:00:01.000 --> 00:00:13.370]]
634
+
635
+ """
636
+
637
+ n = len(self.segments_list_)
638
+ string = "["
639
+ for i, segment in enumerate(self.segments_list_):
640
+ string += str(segment)
641
+ string += "\n " if i + 1 < n else ""
642
+ string += "]"
643
+ return string
644
+
645
+ def __repr__(self):
646
+ """Computer-readable representation
647
+
648
+ >>> Timeline(segments=[Segment(0, 10), Segment(1, 13.37)])
649
+ <Timeline(uri=None, segments=[<Segment(0, 10)>, <Segment(1, 13.37)>])>
650
+
651
+ """
652
+
653
+ return "<Timeline(uri=%s, segments=%s)>" % (self.uri,
654
+ list(self.segments_list_))
655
+
656
+ def __contains__(self, included: Union[Segment, 'Timeline']):
657
+ """Inclusion
658
+
659
+ Check whether every segment of `included` does exist in timeline.
660
+
661
+ Parameters
662
+ ----------
663
+ included : Segment or Timeline
664
+ Segment or timeline being checked for inclusion
665
+
666
+ Returns
667
+ -------
668
+ contains : bool
669
+ True if every segment in `included` exists in timeline,
670
+ False otherwise
671
+
672
+ Examples
673
+ --------
674
+ >>> timeline1 = Timeline(segments=[Segment(0, 10), Segment(1, 13.37)])
675
+ >>> timeline2 = Timeline(segments=[Segment(0, 10)])
676
+ >>> timeline1 in timeline2
677
+ False
678
+ >>> timeline2 in timeline1
+ True
679
+ >>> Segment(1, 13.37) in timeline1
680
+ True
681
+
682
+ """
683
+
684
+ if isinstance(included, Segment):
685
+ return included in self.segments_set_
686
+
687
+ elif isinstance(included, Timeline):
688
+ return self.segments_set_.issuperset(included.segments_set_)
689
+
690
+ else:
691
+ raise TypeError(
692
+ 'Checking for inclusion only supports Segment and '
693
+ 'Timeline instances')
694
+
695
+ def empty(self) -> 'Timeline':
696
+ """Return an empty copy
697
+
698
+ Returns
699
+ -------
700
+ empty : Timeline
701
+ Empty timeline using the same 'uri' attribute.
702
+
703
+ """
704
+ return Timeline(uri=self.uri)
705
+
706
+ def covers(self, other: 'Timeline') -> bool:
707
+ """Check whether other timeline is fully covered by the timeline
708
+
709
+ Parameter
710
+ ---------
711
+ other : Timeline
712
+ Second timeline
713
+
714
+ Returns
715
+ -------
716
+ covers : bool
717
+ True if timeline covers "other" timeline entirely. False if at least
718
+ one segment of "other" is not fully covered by timeline
719
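+
+ Examples
+ --------
+ For instance:
+
+ >>> timeline = Timeline([Segment(0, 4)])
+ >>> timeline.covers(Timeline([Segment(1, 2)]))
+ True
+ >>> timeline.covers(Timeline([Segment(3, 5)]))
+ False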
+ """
720
+
721
+ # compute gaps within "other" extent
722
+ # this is where we should look for possible faulty segments
723
+ gaps = self.gaps(support=other.extent())
724
+
725
+ # if at least one gap intersects with a segment from "other",
726
+ # "self" does not cover "other" entirely --> return False
727
+ for _ in gaps.co_iter(other):
728
+ return False
729
+
730
+ # if no gap intersects with a segment from "other",
731
+ # "self" covers "other" entirely --> return True
732
+ return True
733
+
734
+ def copy(self, segment_func: Optional[Callable[[Segment], Segment]] = None) \
735
+ -> 'Timeline':
736
+ """Get a copy of the timeline
737
+
738
+ If `segment_func` is provided, it is applied to each segment first.
739
+
740
+ Parameters
741
+ ----------
742
+ segment_func : callable, optional
743
+ Callable that takes a segment as input, and returns a segment.
744
+ Defaults to identity function (segment_func(segment) = segment)
745
+
746
+ Returns
747
+ -------
748
+ timeline : Timeline
749
+ Copy of the timeline
750
+
751
+ """
752
+
753
+ # if segment_func is not provided
754
+ # just add every segment
755
+ if segment_func is None:
756
+ return Timeline(segments=self.segments_list_, uri=self.uri)
757
+
758
+ # if is provided
759
+ # apply it to each segment before adding them
760
+ return Timeline(segments=[segment_func(s) for s in self.segments_list_],
761
+ uri=self.uri)
762
+
763
+ def extent(self) -> Segment:
764
+ """Extent
765
+
766
+ The extent of a timeline is the segment of minimum duration that
767
+ contains every segment of the timeline. It is unique, by definition.
768
+ The extent of an empty timeline is an empty segment.
769
+
770
+ A picture is worth a thousand words::
771
+
772
+ timeline
773
+ |------| |------| |----|
774
+ |--| |-----| |----------|
775
+
776
+ timeline.extent()
777
+ |--------------------------------|
778
+
779
+ Returns
780
+ -------
781
+ extent : Segment
782
+ Timeline extent
783
+
784
+ Examples
785
+ --------
786
+ >>> timeline = Timeline(segments=[Segment(0, 1), Segment(9, 10)])
787
+ >>> timeline.extent()
788
+ <Segment(0, 10)>
789
+
790
+ """
791
+ if self.segments_set_:
792
+ segments_boundaries_ = self.segments_boundaries_
793
+ start = segments_boundaries_[0]
794
+ end = segments_boundaries_[-1]
795
+ return Segment(start=start, end=end)
796
+
797
+ return Segment(start=0.0, end=0.0)
798
+
799
+ def support_iter(self, collar: float = 0.0) -> Iterator[Segment]:
800
+ """Like `support` but returns a segment generator instead
801
+
802
+ See also
803
+ --------
804
+ :func:`pyannote.core.Timeline.support`
805
+ """
806
+
807
+ # The support of an empty timeline is an empty timeline.
808
+ if not self:
809
+ return
810
+
811
+ # Principle:
812
+ # * gather all segments with no gap between them
813
+ # * add one segment per resulting group (their union |)
814
+ # Note:
815
+ # Since segments are kept sorted internally,
816
+ # there is no need to perform an exhaustive segment clustering.
817
+ # We just have to consider them in their natural order.
818
+
819
+ # Initialize new support segment
820
+ # as very first segment of the timeline
821
+ new_segment = self.segments_list_[0]
822
+
823
+ for segment in self:
824
+
825
+ # If there is no gap between new support segment and next segment
826
+ # OR there is a gap with duration < collar seconds,
827
+ possible_gap = segment ^ new_segment
828
+ if not possible_gap or possible_gap.duration < collar:
829
+ # Extend new support segment using next segment
830
+ new_segment |= segment
831
+
832
+ # If there actually is a gap and the gap duration >= collar
833
+ # seconds,
834
+ else:
835
+ yield new_segment
836
+
837
+ # Initialize new support segment as next segment
838
+ # (right after the gap)
839
+ new_segment = segment
840
+
841
+ # Add new segment to the timeline support
842
+ yield new_segment
843
+
844
+ def support(self, collar: float = 0.) -> 'Timeline':
845
+ """Timeline support
846
+
847
+ The support of a timeline is the timeline with the minimum number of
848
+ segments with exactly the same time span as the original timeline. It
849
+ is (by definition) unique and does not contain any overlapping
850
+ segments.
851
+
852
+ A picture is worth a thousand words::
853
+
854
+ collar
855
+ |---|
856
+
857
+ timeline
858
+ |------| |------| |----|
859
+ |--| |-----| |----------|
860
+
861
+ timeline.support()
862
+ |------| |--------| |----------|
863
+
864
+ timeline.support(collar)
865
+ |------------------| |----------|
866
+
867
+ Parameters
868
+ ----------
869
+ collar : float, optional
870
+ Merge segments separated by less than `collar` seconds. This is why there
871
+ are only two segments in the final timeline in the above figure.
872
+ Defaults to 0.
873
+
874
+ Returns
875
+ -------
876
+ support : Timeline
877
+ Timeline support
878
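+
+ Examples
+ --------
+ A numeric counterpart to the figure above:
+
+ >>> timeline = Timeline([Segment(0, 1), Segment(0.5, 3), Segment(4, 5)])
+ >>> timeline.support()
+ <Timeline(uri=None, segments=[<Segment(0, 3)>, <Segment(4, 5)>])>
+ >>> timeline.support(2.0)
+ <Timeline(uri=None, segments=[<Segment(0, 5)>])>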
+ """
879
+ return Timeline(segments=self.support_iter(collar), uri=self.uri)
880
+
881
+ def duration(self) -> float:
882
+ """Timeline duration
883
+
884
+ The timeline duration is the sum of the durations of the segments
885
+ in the timeline support.
886
+
887
+ Returns
888
+ -------
889
+ duration : float
890
+ Duration of timeline support, in seconds.
891
+ """
892
+
893
+ # The timeline duration is the sum of the durations
894
+ # of the segments in the timeline support.
895
+ return sum(s.duration for s in self.support_iter())
896
+
897
+ def gaps_iter(self, support: Optional[Support] = None) -> Iterator[Segment]:
898
+ """Like `gaps` but returns a segment generator instead
899
+
900
+ See also
901
+ --------
902
+ :func:`pyannote.core.Timeline.gaps`
903
+
904
+ """
905
+
906
+ if support is None:
907
+ support = self.extent()
908
+
909
+ if not isinstance(support, (Segment, Timeline)):
910
+ raise TypeError("unsupported operand type(s) for -':"
911
+ "%s and Timeline." % type(support).__name__)
912
+
913
+ # segment support
914
+ if isinstance(support, Segment):
915
+
916
+ # `end` is meant to store the end time of former segment
917
+ # initialize it with beginning of provided segment `support`
918
+ end = support.start
919
+
920
+ # support on the intersection of timeline and provided segment
921
+ for segment in self.crop(support, mode='intersection').support():
922
+
923
+ # add gap between each pair of consecutive segments
924
+ # if there is no gap, segment is empty, therefore not added
925
+ gap = Segment(start=end, end=segment.start)
926
+ if gap:
927
+ yield gap
928
+
929
+ # keep track of the end of former segment
930
+ end = segment.end
931
+
932
+ # add final gap (if not empty)
933
+ gap = Segment(start=end, end=support.end)
934
+ if gap:
935
+ yield gap
936
+
937
+ # timeline support
938
+ elif isinstance(support, Timeline):
939
+
940
+ # yield gaps for every segment in support of provided timeline
941
+ for segment in support.support():
942
+ for gap in self.gaps_iter(support=segment):
943
+ yield gap
944
+
945
+ def gaps(self, support: Optional[Support] = None) \
946
+ -> 'Timeline':
947
+ """Gaps
948
+
949
+ A picture is worth a thousand words::
950
+
951
+ timeline
952
+ |------| |------| |----|
953
+ |--| |-----| |----------|
954
+
955
+ timeline.gaps()
956
+ |--| |--|
957
+
958
+ Parameters
959
+ ----------
960
+ support : None, Segment or Timeline
961
+ Support in which gaps are looked for. Defaults to timeline extent
962
+
963
+ Returns
964
+ -------
965
+ gaps : Timeline
966
+ Timeline made of all gaps from original timeline, and delimited
967
+ by provided support
968
+
969
+ See also
970
+ --------
971
+ :func:`pyannote.core.Timeline.extent`
972
+
973
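+ Examples
+ --------
+ A numeric counterpart to the figure above:
+
+ >>> timeline = Timeline([Segment(0, 1), Segment(2, 3)])
+ >>> timeline.gaps()
+ <Timeline(uri=None, segments=[<Segment(1, 2)>])>
+ >>> timeline.gaps(support=Segment(0, 5))
+ <Timeline(uri=None, segments=[<Segment(1, 2)>, <Segment(3, 5)>])>
+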
+ """
974
+ return Timeline(segments=self.gaps_iter(support=support),
975
+ uri=self.uri)
976
+
977
+ def segmentation(self) -> 'Timeline':
978
+ """Segmentation
979
+
980
+ Create the unique timeline with same support and same set of segment
981
+ boundaries as original timeline, but with no overlapping segments.
982
+
983
+ A picture is worth a thousand words::
984
+
985
+ timeline
986
+ |------| |------| |----|
987
+ |--| |-----| |----------|
988
+
989
+ timeline.segmentation()
990
+ |-|--|-| |-|---|--| |--|----|--|
991
+
992
+ Returns
993
+ -------
994
+ timeline : Timeline
995
+ (unique) timeline with same support and same set of segment
996
+ boundaries as original timeline, but with no overlapping segments.
997
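+
+ Examples
+ --------
+ A numeric counterpart to the figure above:
+
+ >>> timeline = Timeline([Segment(0, 2), Segment(1, 3)])
+ >>> timeline.segmentation()
+ <Timeline(uri=None, segments=[<Segment(0, 1)>, <Segment(1, 2)>, <Segment(2, 3)>])>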
+ """
998
+ # COMPLEXITY: O(n)
999
+ support = self.support()
1000
+
1001
+ # COMPLEXITY: O(n.log n)
1002
+ # get all boundaries (sorted)
1003
+ # |------| |------| |----|
1004
+ # |--| |-----| |----------|
1005
+ # becomes
1006
+ # | | | | | | | | | | | |
1007
+ timestamps = set()
1008
+ for (start, end) in self:
1009
+ timestamps.add(start)
1010
+ timestamps.add(end)
1011
+ timestamps = sorted(timestamps)
1012
+
1013
+ # create new partition timeline
1014
+ # | | | | | | | | | | | |
1015
+ # becomes
1016
+ # |-|--|-| |-|---|--| |--|----|--|
1017
+
1018
+ # corner case: an empty timeline has no boundary at all
1021
+ if len(timestamps) == 0:
1022
+ return Timeline(uri=self.uri)
1023
+
1024
+ segments = []
1025
+ start = timestamps[0]
1026
+ for end in timestamps[1:]:
1027
+ # only add segments that are covered by original timeline
1028
+ segment = Segment(start=start, end=end)
1029
+ if segment and support.overlapping(segment.middle):
1030
+ segments.append(segment)
1031
+ # next segment...
1032
+ start = end
1033
+
1034
+ return Timeline(segments=segments, uri=self.uri)
1035
+
1036
+ def to_annotation(self,
1037
+ generator: Union[str, Iterable[Label], None] = 'string',
1038
+ modality: Optional[str] = None) \
1039
+ -> 'Annotation':
1040
+ """Turn timeline into an annotation
1041
+
1042
+ Each segment is labeled by a unique label.
1043
+
1044
+ Parameters
1045
+ ----------
1046
+ generator : 'string', 'int', or iterable, optional
1047
+ If 'string' (default) generate string labels. If 'int', generate
1048
+ integer labels. If iterable, use it to generate labels.
1049
+ modality : str, optional
+ Modality of the resulting annotation (e.g. "speaker"). Defaults to None.
1050
+
1051
+ Returns
1052
+ -------
1053
+ annotation : Annotation
1054
+ Annotation
1055
+ """
1056
+
1057
+ from .annotation import Annotation
1058
+ annotation = Annotation(uri=self.uri, modality=modality)
1059
+ if generator == 'string':
1060
+ from .utils.generators import string_generator
1061
+ generator = string_generator()
1062
+ elif generator == 'int':
1063
+ from .utils.generators import int_generator
1064
+ generator = int_generator()
1065
+
1066
+ for segment in self:
1067
+ annotation[segment] = next(generator)
1068
+
1069
+ return annotation
1070
+
1071
+ def _iter_uem(self) -> Iterator[Text]:
1072
+ """Generate lines for a UEM file for this timeline
1073
+
1074
+ Returns
1075
+ -------
1076
+ iterator: Iterator[str]
1077
+ An iterator over UEM text lines
1078
+ """
1079
+ uri = self.uri if self.uri else "<NA>"
1080
+ if isinstance(uri, Text) and ' ' in uri:
1081
+ msg = (f'Space-separated UEM file format does not allow file URIs '
1082
+ f'containing spaces (got: "{uri}").')
1083
+ raise ValueError(msg)
1084
+ for segment in self:
1085
+ yield f"{uri} 1 {segment.start:.3f} {segment.end:.3f}\n"
1086
+
1087
+ def to_uem(self) -> Text:
1088
+ """Serialize timeline as a string using UEM format
1089
+
1090
+ Returns
1091
+ -------
1092
+ serialized: str
1093
+ UEM string
1094
+ """
1095
+ return "".join([line for line in self._iter_uem()])
1096
+
1097
+ def write_uem(self, file: TextIO):
1098
+ """Dump timeline to file using UEM format
1099
+
1100
+ Parameters
1101
+ ----------
1102
+ file : file object
1103
+
1104
+ Usage
1105
+ -----
1106
+ >>> with open('file.uem', 'w') as file:
1107
+ ... timeline.write_uem(file)
1108
+ """
1109
+ for line in self._iter_uem():
1110
+ file.write(line)
1111
+
1112
+ def _repr_png_(self):
1113
+ """IPython notebook support
1114
+
1115
+ See also
1116
+ --------
1117
+ :mod:`pyannote.core.notebook`
1118
+ """
1119
+
1120
+ from .notebook import MATPLOTLIB_IS_AVAILABLE, MATPLOTLIB_WARNING
1121
+ if not MATPLOTLIB_IS_AVAILABLE:
1122
+ warnings.warn(MATPLOTLIB_WARNING.format(klass=self.__class__.__name__))
1123
+ return None
1124
+
1125
+ from .notebook import repr_timeline
1126
+ return repr_timeline(self)
ailia-models/code/pyannote_audio_utils/core/utils/generators.py ADDED
@@ -0,0 +1,89 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2014-2018 CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+
29
+ import itertools
30
+ from string import ascii_uppercase
31
+ from typing import Iterable, Union, List, Set, Optional, Iterator
32
+
33
+
34
+ def pairwise(iterable: Iterable):
35
+ """s -> (s0,s1), (s1,s2), (s2, s3), ..."""
36
+ a, b = itertools.tee(iterable)
37
+ next(b, None)
38
+ return zip(a, b)
39
+
40
+
41
+ def string_generator(skip: Optional[Union[List, Set]] = None) \
42
+ -> Iterator[str]:
43
+ """Label generator
44
+
45
+ Parameters
46
+ ----------
47
+ skip : list or set
48
+ List of labels that must be skipped.
49
+ This option is useful in case you want to make sure generated labels
50
+ are different from a pre-existing set of labels.
51
+
52
+ Usage
53
+ -----
54
+ t = string_generator()
55
+ next(t) -> 'A' # start with 1-letter labels
56
+ ... # from A to Z
57
+ next(t) -> 'Z'
58
+ next(t) -> 'AA' # then 2-letters labels
59
+ next(t) -> 'AB' # from AA to ZZ
60
+ ...
61
+ next(t) -> 'ZY'
62
+ next(t) -> 'ZZ'
63
+ next(t) -> 'AAA' # then 3-letters labels
64
+ ... # (you get the idea)
65
+ """
66
+ if skip is None:
67
+ skip = list()
68
+
69
+ # label length
70
+ r = 1
71
+
72
+ # infinite loop
73
+ while True:
74
+
75
+ # generate labels with current length
76
+ for c in itertools.product(ascii_uppercase, repeat=r):
77
+ if ''.join(c) in skip:
78
+ continue
79
+ yield ''.join(c)
80
+
81
+ # increment label length when all possibilities are exhausted
82
+ r = r + 1
83
+
84
+
85
+ def int_generator() -> Iterator[int]:
86
+ i = 0
87
+ while True:
88
+ yield i
89
+ i = i + 1
ailia-models/code/pyannote_audio_utils/core/utils/types.py ADDED
@@ -0,0 +1,13 @@
1
+ from typing import Hashable, Union, Tuple, Iterator, Literal
2
+
3
+ Label = Hashable
4
+ Support = Union['Segment', 'Timeline']
5
+ LabelGeneratorMode = Literal['int', 'string']
6
+ LabelGenerator = Union[LabelGeneratorMode, Iterator[Label]]
7
+ TrackName = Union[str, int]
8
+ Key = Union['Segment', Tuple['Segment', TrackName]]
9
+ Resource = Union['Segment', 'Timeline', 'SlidingWindowFeature',
10
+ 'Annotation']
11
+ CropMode = Literal['intersection', 'loose', 'strict']
12
+ Alignment = Literal['center', 'loose', 'strict']
13
+ LabelStyle = Tuple[str, int, Tuple[float, float, float]]
ailia-models/code/pyannote_audio_utils/database/__init__.py ADDED
@@ -0,0 +1,91 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2016- CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+ # Alexis PLAQUET
29
+
30
+ """pyannote.database"""
31
+
32
+
33
+ from typing import Optional
34
+ import warnings
35
+
36
+ # from .registry import registry, LoadingMode
37
+
38
+ # from .database import Database
39
+
40
+ from .protocol.protocol import Protocol
41
+ from .protocol.protocol import ProtocolFile
42
+ from .protocol.protocol import Subset
43
+ from .protocol.protocol import Preprocessors
44
+
45
+ # from .file_finder import FileFinder
46
+ # from .util import get_annotated
47
+ # from .util import get_unique_identifier
48
+ # from .util import get_label_identifier
49
+
50
+ # from ._version import get_versions
51
+ #
52
+
53
+ # __version__ = get_versions()["version"]
54
+ # del get_versions
55
+
56
+
57
+ def get_protocol(name, preprocessors: Optional[Preprocessors] = None) -> Protocol:
58
+ """Get protocol by full name
59
+
60
+ name : str
61
+ Protocol full name (e.g. "Etape.SpeakerDiarization.TV")
62
+ preprocessors : dict or (key, preprocessor) iterable
63
+ When provided, each protocol item (dictionary) is preprocessed, such
64
+ that item[key] = preprocessor(item). In case 'preprocessor' is not
65
+ callable, it should be a string containing placeholder for item keys
66
+ (e.g. {'audio': '/path/to/{uri}.wav'})
67
+
68
+ Returns
69
+ -------
70
+ protocol : Protocol
71
+ Protocol instance
72
+ """
73
+ warnings.warn(
74
+ "`get_protocol` has been deprecated in favor of `pyannote.database.registry.get_protocol`.",
75
+ DeprecationWarning)
76
+ return registry.get_protocol(name, preprocessors=preprocessors)
77
+
78
+
79
+ # only export the names that are actually defined in this trimmed copy;
+ # the remaining upstream names are commented out above
+ __all__ = [
+ "get_protocol",
+ "Protocol",
+ "ProtocolFile",
+ "Subset",
+ "Preprocessors",
+ ]
ailia-models/code/pyannote_audio_utils/database/protocol/__init__.py ADDED
@@ -0,0 +1,34 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2016- CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+
29
+ from .protocol import Protocol
30
+
31
+ __all__ = [
32
+ "Protocol",
33
+ ]
34
+
ailia-models/code/pyannote_audio_utils/database/protocol/protocol.py ADDED
@@ -0,0 +1,434 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2016-2020 CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+
29
+ """
30
+ #########
31
+ Protocols
32
+ #########
33
+
34
+ """
35
+
36
+ import warnings
37
+ import collections
38
+ import threading
39
+ import itertools
40
+ from typing import Union, Dict, Iterator, Callable, Any, Text, Optional
41
+
42
+ # try:
43
+ from typing import Literal
44
+ # except ImportError:
45
+ # from typing_extensions import Literal
46
+
47
+ Subset = Literal["train", "development", "test"]
48
+ LEGACY_SUBSET_MAPPING = {"train": "trn", "development": "dev", "test": "tst"}
49
+ Scope = Literal["file", "database", "global"]
50
+
51
+ Preprocessor = Callable[["ProtocolFile"], Any]
52
+ Preprocessors = Dict[Text, Preprocessor]
53
+
54
+
55
+ class ProtocolFile(collections.abc.MutableMapping):
56
+ """Protocol file with lazy preprocessors
57
+
58
+ This is a dict-like data structure where some values may depend on other
59
+ values, and are only computed if/when requested. Once computed, they are
60
+ cached and never recomputed again.
61
+
62
+ Parameters
63
+ ----------
64
+ precomputed : dict
65
+ Regular dictionary with precomputed values
66
+ lazy : dict, optional
67
+ Dictionary describing how lazy values need to be computed.
68
+ Values are callables expecting a dictionary as input and returning the
69
+ computed value.
70
+
71
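+ Examples
+ --------
+ A minimal sketch of the lazy behaviour ('greeting' is a hypothetical key):
+
+ >>> file = ProtocolFile({'uri': 'file1'}, lazy={'greeting': lambda f: 'hello ' + f['uri']})
+ >>> file['greeting']  # computed on first access, then cached
+ 'hello file1'
+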
+ """
72
+
73
+ def __init__(self, precomputed: Union[Dict, "ProtocolFile"], lazy: Optional[Dict] = None):
74
+
75
+ if lazy is None:
76
+ lazy = dict()
77
+
78
+ if isinstance(precomputed, ProtocolFile):
79
+ # when 'precomputed' is a ProtocolFile, it may already contain lazy keys.
80
+
81
+ # we use 'precomputed' precomputed keys as precomputed keys
82
+ self._store: Dict = abs(precomputed)
83
+
84
+ # we handle the corner case where the intersection of 'precomputed' lazy keys
85
+ # and 'lazy' keys is not empty. this is currently achieved by "unlazying" the
86
+ # 'precomputed' one (which is probably not the most efficient solution).
87
+ for key in set(precomputed.lazy) & set(lazy):
88
+ self._store[key] = precomputed[key]
89
+
90
+ # we use the union of 'precomputed' lazy keys and provided 'lazy' keys as lazy keys
91
+ compound_lazy = dict(precomputed.lazy)
92
+ compound_lazy.update(lazy)
93
+ self.lazy: Dict = compound_lazy
94
+
95
+ else:
96
+ # when 'precomputed' is a Dict, we use it directly as precomputed keys
97
+ # and 'lazy' as lazy keys.
98
+ self._store = dict(precomputed)
99
+ self.lazy = dict(lazy)
100
+
101
+ # re-entrant lock used below to make ProtocolFile thread-safe
102
+ self.lock_ = threading.RLock()
103
+
104
+ # this is needed to avoid infinite recursion
105
+ # when a key is both in precomputed and lazy.
106
+ # keys with evaluating_ > 0 are currently being evaluated
107
+ # and therefore should be taken from precomputed
108
+ self.evaluating_ = collections.Counter()
109
+
110
+ # since RLock is not pickable, remove it before pickling...
111
+ def __getstate__(self):
112
+ d = dict(self.__dict__)
113
+ del d["lock_"]
114
+ return d
115
+
116
+ # ... and add it back when unpickling
117
+ def __setstate__(self, d):
118
+ self.__dict__.update(d)
119
+ self.lock_ = threading.RLock()
120
+
121
+ def __abs__(self):
122
+ with self.lock_:
123
+ return dict(self._store)
124
+
125
+ def __getitem__(self, key):
126
+ with self.lock_:
127
+
128
+ if key in self.lazy and self.evaluating_[key] == 0:
129
+
130
+ # mark lazy key as being evaluated
131
+ self.evaluating_.update([key])
132
+
133
+ # apply preprocessor once and remove it
134
+ value = self.lazy[key](self)
135
+ del self.lazy[key]
136
+
137
+ # warn the user when a precomputed key is modified
138
+ if key in self._store and value != self._store[key]:
139
+ msg = 'Existing precomputed key "{key}" has been modified by a preprocessor.'
140
+ warnings.warn(msg.format(key=key))
141
+
142
+ # store the output of the lazy computation
143
+ # so that it is available for future access
144
+ self._store[key] = value
145
+
146
+ # lazy evaluation is finished for key
147
+ self.evaluating_.subtract([key])
148
+
149
+ return self._store[key]
150
+
151
+ def __setitem__(self, key, value):
152
+ with self.lock_:
153
+
154
+ if key in self.lazy:
155
+ del self.lazy[key]
156
+
157
+ self._store[key] = value
158
+
159
+ def __delitem__(self, key):
160
+ with self.lock_:
161
+
162
+ if key in self.lazy:
163
+ del self.lazy[key]
164
+
165
+ del self._store[key]
166
+
167
+ def __iter__(self):
168
+ with self.lock_:
169
+
170
+ store_keys = list(self._store)
171
+ for key in store_keys:
172
+ yield key
173
+
174
+ lazy_keys = list(self.lazy)
175
+ for key in lazy_keys:
176
+ if key in self._store:
177
+ continue
178
+ yield key
179
+
180
+ def __len__(self):
181
+ with self.lock_:
182
+ return len(set(self._store) | set(self.lazy))
183
+
184
+ def files(self) -> Iterator["ProtocolFile"]:
185
+ """Iterate over all files
186
+
187
+ When `current_file` refers to only one file,
188
+ yield it and return.
189
+ When `current_file` refers to a list of files (i.e. 'uri' is a list),
190
+ yield each file separately.
191
+
192
+ Examples
193
+ --------
194
+ >>> current_file = ProtocolFile({
195
+ ... 'uri': 'my_uri',
196
+ ... 'database': 'my_database'})
197
+ >>> for file in current_file.files():
198
+ ... print(file['uri'], file['database'])
199
+ my_uri my_database
200
+
201
+ >>> current_file = ProtocolFile({
202
+ ... 'uri': ['my_uri1', 'my_uri2', 'my_uri3'],
203
+ ... 'database': 'my_database'})
204
+ >>> for file in current_file.files():
205
+ ... print(file['uri'], file['database'])
206
+ my_uri1 my_database
207
+ my_uri2 my_database
208
+ my_uri3 my_database
209
+
210
+ """
211
+
212
+ uris = self["uri"]
213
+ if not isinstance(uris, list):
214
+ yield self
215
+ return
216
+
217
+ n_uris = len(uris)
218
+
219
+ # iterate over precomputed keys and make sure list values match the number of uris
220
+
221
+ precomputed = {"uri": uris}
222
+ for key, value in abs(self).items():
223
+
224
+ if key == "uri":
225
+ continue
226
+
227
+ if not isinstance(value, list):
228
+ precomputed[key] = itertools.repeat(value)
229
+
230
+ else:
231
+ if len(value) != n_uris:
232
+ msg = (
233
+ f'Mismatch between number of "uris" ({n_uris}) '
234
+ f'and number of "{key}" ({len(value)}).'
235
+ )
236
+ raise ValueError(msg)
237
+ precomputed[key] = value
238
+
239
+ keys = list(precomputed.keys())
240
+ for values in zip(*precomputed.values()):
241
+ precomputed_one = dict(zip(keys, values))
242
+ yield ProtocolFile(precomputed_one, self.lazy)
243
+
244
+
245
+ class Protocol:
246
+ """Experimental protocol
247
+
248
+ An experimental protocol usually defines three subsets: a training subset,
249
+ a development subset, and a test subset.
250
+
251
+ An experimental protocol can be defined programmatically by creating a
252
+ class that inherits from Protocol and implements at least
253
+ one of `train_iter`, `development_iter` and `test_iter` methods:
254
+
255
+ >>> class MyProtocol(Protocol):
256
+ ... def train_iter(self) -> Iterator[Dict]:
257
+ ... yield {"uri": "filename1", "any_other_key": "..."}
258
+ ... yield {"uri": "filename2", "any_other_key": "..."}
259
+
260
+ `{subset}_iter` should return an iterator of dictionaries with
261
+ - "uri" key (mandatory) that provides a unique file identifier (usually
262
+ the filename),
263
+ - any other key that the protocol may provide.
264
+
265
+ It can then be used in Python like this:
266
+
267
+ >>> protocol = MyProtocol()
268
+ >>> for file in protocol.train():
269
+ ... print(file["uri"])
270
+ filename1
271
+ filename2
272
+
273
+ An experimental protocol can also be defined using `pyannote_audio_utils.database`
274
+ configuration file, whose (configurable) path defaults to "~/database.yml".
275
+
276
+ ~~~ Content of ~/database.yml ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
277
+ Protocols:
278
+ MyDatabase:
279
+ Protocol:
280
+ MyProtocol:
281
+ train:
282
+ uri: /path/to/collection.lst
283
+ any_other_key: ... # see custom loader documentation
284
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+
286
+ where "/path/to/collection.lst" contains the list of identifiers of the
287
+ files in the collection:
288
+
289
+ ~~~ Content of "/path/to/collection.lst ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
290
+ filename1
291
+ filename2
292
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
293
+
294
+ It can then be used in Python like this:
295
+
296
+ >>> from pyannote_audio_utils.database import registry
297
+ >>> protocol = registry.get_protocol('MyDatabase.Protocol.MyProtocol')
298
+ >>> for file in protocol.train():
299
+ ... print(file["uri"])
300
+ filename1
301
+ filename2
302
+
303
+ This class is usually inherited from, but can be used directly.
304
+
305
+ Parameters
306
+ ----------
307
+ preprocessors : dict
308
+ Preprocess protocol files so that `file[key] = preprocessors[key](file)`
309
+ for each key in `preprocessors`. In case `preprocessors[key]` is not
310
+ callable, it should be a string containing placeholders for `file` keys
311
+ (e.g. {'audio': '/path/to/{uri}.wav'})
312
+ """
313
+
314
+ def __init__(self, preprocessors: Optional[Preprocessors] = None):
315
+ super().__init__()
316
+
317
+ if preprocessors is None:
318
+ preprocessors = dict()
319
+
320
+ self.preprocessors = dict()
321
+ for key, preprocessor in preprocessors.items():
322
+
323
+ if callable(preprocessor):
324
+ self.preprocessors[key] = preprocessor
325
+
326
+ # when `preprocessor` is not callable, it should be a string
327
+ # containing placeholder for item key (e.g. '/path/to/{uri}.wav')
328
+ elif isinstance(preprocessor, str):
329
+ preprocessor_copy = str(preprocessor)
330
+
331
+ # bind the template at definition time: a plain closure over
+ # 'preprocessor_copy' would be rebound on the next loop iteration
+ def func(current_file, template=preprocessor_copy):
332
+ return template.format(**current_file)
333
+
334
+ self.preprocessors[key] = func
335
+
336
+ else:
337
+ msg = f'"{key}" preprocessor is neither a callable nor a string.'
338
+ raise ValueError(msg)
339
+
340
+ def preprocess(self, current_file: Union[Dict, ProtocolFile]) -> ProtocolFile:
341
+ return ProtocolFile(current_file, lazy=self.preprocessors)
342
+
343
+ def __str__(self):
344
+ return self.__doc__
345
+
346
+ def train_iter(self) -> Iterator[Union[Dict, ProtocolFile]]:
347
+ """Iterate over files in the training subset"""
348
+ raise NotImplementedError()
349
+
350
+ def development_iter(self) -> Iterator[Union[Dict, ProtocolFile]]:
351
+ """Iterate over files in the development subset"""
352
+ raise NotImplementedError()
353
+
354
+ def test_iter(self) -> Iterator[Union[Dict, ProtocolFile]]:
355
+ """Iterate over files in the test subset"""
356
+ raise NotImplementedError()
357
+
358
+ def subset_helper(self, subset: Subset) -> Iterator[ProtocolFile]:
359
+
360
+ try:
361
+ files = getattr(self, f"{subset}_iter")()
362
+ except (AttributeError, NotImplementedError):
363
+ # previous pyannote_audio_utils.database versions used `trn_iter` instead of
364
+ # `train_iter`, `dev_iter` instead of `development_iter`, and
365
+ # `tst_iter` instead of `test_iter`. therefore, we use the legacy
366
+ # version when it is available (and the new one is not).
367
+ subset_legacy = LEGACY_SUBSET_MAPPING[subset]
368
+ try:
369
+ files = getattr(self, f"{subset_legacy}_iter")()
370
+ except AttributeError:
371
+ msg = f"Protocol does not implement a {subset} subset."
372
+ raise NotImplementedError(msg)
373
+
374
+ for file in files:
375
+ yield self.preprocess(file)
376
+
377
+ def train(self) -> Iterator[ProtocolFile]:
378
+ return self.subset_helper("train")
379
+
380
+ def development(self) -> Iterator[ProtocolFile]:
381
+ return self.subset_helper("development")
382
+
383
+ def test(self) -> Iterator[ProtocolFile]:
384
+ return self.subset_helper("test")
385
+
386
+ def files(self) -> Iterator[ProtocolFile]:
387
+ """Iterate over all files in `protocol`"""
388
+
389
+ # imported here to avoid circular imports
390
+ from pyannote_audio_utils.database.util import get_unique_identifier
391
+
392
+ yielded_uris = set()
393
+
394
+ for method in [
395
+ "development",
396
+ "development_enrolment",
397
+ "development_trial",
398
+ "test",
399
+ "test_enrolment",
400
+ "test_trial",
401
+ "train",
402
+ "train_enrolment",
403
+ "train_trial",
404
+ ]:
405
+
406
+ if not hasattr(self, method):
407
+ continue
408
+
409
+ def iterate():
410
+ try:
411
+ for file in getattr(self, method)():
412
+ yield file
413
+ except (AttributeError, NotImplementedError):
414
+ return
415
+
416
+ for current_file in iterate():
417
+
418
+ # skip "files" that do not contain a "uri" entry.
419
+ # this happens for speaker verification trials that contain
420
+ # two nested files "file1" and "file2"
421
+ # see https://github.com/pyannote_audio_utils/pyannote_audio_utils-db-voxceleb/issues/4
422
+ if "uri" not in current_file:
423
+ continue
424
+
425
+ for current_file_ in current_file.files():
426
+
427
+ # corner case when the same file is yielded several times
428
+ uri = get_unique_identifier(current_file_)
429
+ if uri in yielded_uris:
430
+ continue
431
+
432
+ yield current_file_
433
+
434
+ yielded_uris.add(uri)
ailia-models/code/pyannote_audio_utils/database/util.py ADDED
@@ -0,0 +1,400 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2016-2020 CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+
29
+
30
+ import warnings
31
+ import pandas as pd
32
+ from pyannote_audio_utils.core import Segment, Timeline, Annotation
33
+
34
+ from typing import Text
35
+
36
+ DatabaseName = Text
37
+ PathTemplate = Text
38
+
39
+
40
+ def get_unique_identifier(item):
41
+ """Return unique item identifier
42
+
43
+ The complete format is {database}/{uri}_{channel}:
44
+ * prefixed by "{database}/" only when `item` has a 'database' key.
45
+ * suffixed by "_{channel}" only when `item` has a 'channel' key.
46
+
47
+ Parameters
48
+ ----------
49
+ item : dict
50
+ Item as yielded by pyannote_audio_utils.database protocols
51
+
52
+ Returns
53
+ -------
54
+ identifier : str
55
+ Unique item identifier
56
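+
+ Examples
+ --------
+ Illustration with made-up item values:
+
+ >>> get_unique_identifier({'database': 'MyDatabase', 'uri': 'filename1', 'channel': 1})
+ 'MyDatabase/filename1_1'
+ >>> get_unique_identifier({'uri': 'filename1'})
+ 'filename1'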
+ """
57
+
58
+ IDENTIFIER = ""
59
+
60
+ # {database}/{uri}_{channel}
61
+ database = item.get("database", None)
62
+ if database is not None:
63
+ IDENTIFIER += f"{database}/"
64
+ IDENTIFIER += item["uri"]
65
+ channel = item.get("channel", None)
66
+ if channel is not None:
67
+ IDENTIFIER += f"_{channel:d}"
68
+
69
+ return IDENTIFIER
70
+
71
+
72
+ # This function is used in custom.py
73
+ def get_annotated(current_file):
74
+ """Get part of the file that is annotated.
75
+
76
+ Parameters
77
+ ----------
78
+ current_file : `dict`
79
+ File generated by a `pyannote_audio_utils.database` protocol.
80
+
81
+ Returns
82
+ -------
83
+ annotated : `pyannote_audio_utils.core.Timeline`
84
+ Part of the file that is annotated. Defaults to
85
+ `current_file["annotated"]`. When it does not exist, try to use the
86
+ full audio extent. When that fails, use "annotation" extent.
87
+ """
88
+
89
+ # if protocol provides 'annotated' key, use it
90
+ if "annotated" in current_file:
91
+ annotated = current_file["annotated"]
92
+ return annotated
93
+
94
+ # if it does not, but does provide a "duration" key,
95
+ # use [0, duration] as the annotated part
96
+
97
+ if "duration" in current_file:
98
+ try:
99
+ duration = current_file["duration"]
100
+ except ImportError:
101
+ pass
102
+ else:
103
+ annotated = Timeline([Segment(0, duration)])
104
+ msg = '"annotated" was approximated by [0, audio duration].'
105
+ warnings.warn(msg)
106
+ return annotated
107
+
108
+ extent = current_file["annotation"].get_timeline().extent()
109
+ annotated = Timeline([extent])
110
+
111
+ msg = (
112
+ '"annotated" was approximated by "annotation" extent. '
113
+ 'Please provide "annotated" directly, or at the very '
114
+ 'least, use a "duration" preprocessor.'
115
+ )
116
+ warnings.warn(msg)
117
+
118
+ return annotated
119
+
120
+
121
+ def get_label_identifier(label, current_file):
122
+ """Return unique label identifier
123
+
124
+ Parameters
125
+ ----------
126
+ label : str
127
+ Database-internal label
128
+ current_file
129
+ Yielded by pyannote_audio_utils.database protocols
130
+
131
+ Returns
132
+ -------
133
+ unique_label : str
134
+ Global label
135
+ """
136
+
137
+ # TODO. when the "true" name of a person is used,
138
+ # do not preprend database name.
139
+ database = current_file["database"]
140
+ return database + "|" + label
141
+
142
+
143
+ def load_rttm(file_rttm, keep_type="SPEAKER"):
144
+ """Load RTTM file
145
+
146
+ Parameter
147
+ ---------
148
+ file_rttm : `str`
149
+ Path to RTTM file.
150
+ keep_type : str, optional
151
+ Only keep lines with this type (field #1 in RTTM specs).
152
+ Defaults to "SPEAKER".
153
+
154
+ Returns
155
+ -------
156
+ annotations : `dict`
157
+ Speaker diarization as a {uri: pyannote_audio_utils.core.Annotation} dictionary.
158
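+
+ Example
+ -------
+ Given a file with space-separated lines such as (illustrative values):
+
+ SPEAKER filename1 1 0.20 1.30 <NA> <NA> speaker_A <NA> <NA>
+
+ this returns {'filename1': <Annotation>} with one segment
+ [0.20, 1.50] labeled 'speaker_A'.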
+ """
159
+
160
+ names = [
161
+ "type",
162
+ "uri",
163
+ "NA2",
164
+ "start",
165
+ "duration",
166
+ "NA3",
167
+ "NA4",
168
+ "speaker",
169
+ "NA5",
170
+ "NA6",
171
+ ]
172
+ dtype = {"uri": str, "start": float, "duration": float, "speaker": str}
173
+ data = pd.read_csv(
174
+ file_rttm,
175
+ names=names,
176
+ dtype=dtype,
177
+ # delim_whitespace=True,
178
+ sep=r'\s+',
179
+ keep_default_na=True,
180
+ )
181
+
182
+ annotations = dict()
183
+ for uri, turns in data.groupby("uri"):
184
+ annotation = Annotation(uri=uri)
185
+ for i, turn in turns.iterrows():
186
+ if turn.type != keep_type:
187
+ continue
188
+ segment = Segment(turn.start, turn.start + turn.duration)
189
+ annotation[segment, i] = turn.speaker
190
+ annotations[uri] = annotation
191
+
192
+ return annotations
193
+
194
+
195
+ def load_stm(file_stm):
196
+ """Load STM file (speaker-info only)
197
+
198
+ Parameter
199
+ ---------
200
+ file_stm : str
201
+ Path to STM file
202
+
203
+ Returns
204
+ -------
205
+ annotations : `dict`
206
+ Speaker diarization as a {uri: pyannote_audio_utils.core.Annotation} dictionary.
207
+ """
208
+
209
+ dtype = {"uri": str, "speaker": str, "start": float, "end": float}
210
+ data = pd.read_csv(
211
+ file_stm,
212
+ # delim_whitespace=True,
213
+ sep=r'\s+',
214
+ usecols=[0, 2, 3, 4],
215
+ dtype=dtype,
216
+ names=list(dtype),
217
+ )
218
+
219
+ annotations = dict()
220
+ for uri, turns in data.groupby("uri"):
221
+ annotation = Annotation(uri=uri)
222
+ for i, turn in turns.iterrows():
223
+ segment = Segment(turn.start, turn.end)
224
+ annotation[segment, i] = turn.speaker
225
+ annotations[uri] = annotation
226
+
227
+ return annotations
228
+
229
+
230
+ def load_mdtm(file_mdtm):
231
+ """Load MDTM file
232
+
233
+ Parameter
234
+ ---------
235
+ file_mdtm : `str`
236
+ Path to MDTM file.
237
+
238
+ Returns
239
+ -------
240
+ annotations : `dict`
241
+ Speaker diarization as a {uri: pyannote_audio_utils.core.Annotation} dictionary.
242
+ """
243
+
244
+ names = ["uri", "NA1", "start", "duration", "NA2", "NA3", "NA4", "speaker"]
245
+ dtype = {"uri": str, "start": float, "duration": float, "speaker": str}
246
+ data = pd.read_csv(
247
+ file_mdtm,
248
+ names=names,
249
+ dtype=dtype,
250
+ # delim_whitespace=True,
251
+ sep=r'\s+',
252
+ keep_default_na=False,
253
+ )
254
+
255
+ annotations = dict()
256
+ for uri, turns in data.groupby("uri"):
257
+ annotation = Annotation(uri=uri)
258
+ for i, turn in turns.iterrows():
259
+ segment = Segment(turn.start, turn.start + turn.duration)
260
+ annotation[segment, i] = turn.speaker
261
+ annotations[uri] = annotation
262
+
263
+ return annotations
264
+
265
+
266
+ def load_uem(file_uem):
267
+ """Load UEM file
268
+
269
+ Parameter
270
+ ---------
271
+ file_uem : `str`
272
+ Path to UEM file.
273
+
274
+ Returns
275
+ -------
276
+ timelines : `dict`
277
+ Evaluation map as a {uri: pyannote_audio_utils.core.Timeline} dictionary.
278
+ """
279
+
280
+ names = ["uri", "NA1", "start", "end"]
281
+ dtype = {"uri": str, "start": float, "end": float}
282
+ data = pd.read_csv(file_uem, names=names, dtype=dtype, sep=r'\s+')
283
+
284
+ timelines = dict()
285
+ for uri, parts in data.groupby("uri"):
286
+ segments = [Segment(part.start, part.end) for i, part in parts.iterrows()]
287
+ timelines[uri] = Timeline(segments=segments, uri=uri)
288
+
289
+ return timelines
290
+
291
+
292
+ def load_lab(path, uri: str = None) -> Annotation:
293
+ """Load LAB file
294
+
295
+ Parameter
296
+ ---------
297
+ path : `str`
298
+ Path to LAB file
299
+
300
+ Returns
301
+ -------
302
+ data : `pyannote_audio_utils.core.Annotation`
303
+ """
304
+
305
+ names = ["start", "end", "label"]
306
+ dtype = {"start": float, "end": float, "label": str}
307
+ data = pd.read_csv(path, names=names, dtype=dtype, sep=r'\s+')
308
+
309
+ annotation = Annotation(uri=uri)
310
+ for i, turn in data.iterrows():
311
+ segment = Segment(turn.start, turn.end)
312
+ annotation[segment, i] = turn.label
313
+
314
+ return annotation
315
+
316
+
317
+ def load_lst(file_lst):
318
+ """Load LST file
319
+
320
+ LST files provide a list of URIs (one line per URI)
321
+
322
+ Parameter
323
+ ---------
324
+ file_lst : `str`
325
+ Path to LST file.
326
+
327
+ Returns
328
+ -------
329
+ uris : `list`
330
+ List of URIs
331
+ """
332
+
333
+ with open(file_lst, mode="r") as fp:
334
+ lines = fp.readlines()
335
+ return [line.strip() for line in lines]
336
+
337
+
338
+ def load_mapping(mapping_txt):
339
+ """Load mapping file
340
+
341
+ Parameter
342
+ ---------
343
+ mapping_txt : `str`
344
+ Path to mapping file
345
+
346
+ Returns
347
+ -------
348
+ mapping : `dict`
349
+ {1st field: 2nd field} dictionary
350
+ """
351
+
352
+ with open(mapping_txt, mode="r") as fp:
353
+ lines = fp.readlines()
354
+
355
+ mapping = dict()
356
+ for line in lines:
357
+ key, value, *left = line.strip().split()
358
+ mapping[key] = value
359
+
360
+ return mapping
361
+
362
+
363
+ class LabelMapper(object):
364
+ """Label mapper for use as pyannote_audio_utils.database preprocessor
365
+
366
+ Parameters
367
+ ----------
368
+ mapping : `dict`
369
+ Mapping dictionary as used in `Annotation.rename_labels()`.
370
+ keep_missing : `bool`, optional
371
+ In case a label has no mapping, a `ValueError` will be raised.
372
+ Set "keep_missing" to True to keep those labels unchanged instead.
373
+
374
+ Usage
375
+ -----
376
+ >>> mapping = {'Hadrien': 'MAL', 'Marvin': 'MAL',
377
+ ... 'Wassim': 'CHI', 'Herve': 'GOD'}
378
+ >>> preprocessors = {'annotation': LabelMapper(mapping=mapping)}
379
+ >>> protocol = registry.get_protocol('AMI.SpeakerDiarization.MixHeadset',
380
+ preprocessors=preprocessors)
381
+
382
+ """
383
+
384
+ def __init__(self, mapping, keep_missing=False):
385
+ self.mapping = mapping
386
+ self.keep_missing = keep_missing
387
+
388
+ def __call__(self, current_file):
389
+
390
+ if not self.keep_missing:
391
+ missing = set(current_file["annotation"].labels()) - set(self.mapping)
392
+ if missing:
393
+ label = missing.pop()
394
+ msg = (
395
+ f'No mapping found for label "{label}". Set "keep_missing" '
396
+ f"to True to keep labels with no mapping."
397
+ )
398
+ raise ValueError(msg)
399
+
400
+ return current_file["annotation"].rename_labels(mapping=self.mapping)
ailia-models/code/pyannote_audio_utils/metrics/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2012-2021 CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+
29
+ from ._version import get_versions
30
+ from .base import f_measure
31
+
32
+ __version__ = get_versions()["version"]
33
+ del get_versions
34
+
35
+
36
+ __all__ = ["f_measure"]
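
A quick sanity check of the exported `f_measure` helper (values are made up; the function itself is defined in `base.py` below):

>>> from pyannote_audio_utils.metrics import f_measure
>>> round(f_measure(precision=0.8, recall=0.6), 4)
0.6857
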
ailia-models/code/pyannote_audio_utils/metrics/_version.py ADDED
@@ -0,0 +1,21 @@
1
+
2
+ # This file was generated by 'versioneer.py' (0.15) from
3
+ # revision-control system data, or from the parent directory name of an
4
+ # unpacked source archive. Distribution tarballs contain a pre-generated copy
5
+ # of this file.
6
+
7
+ import json
8
+ import sys
9
+
10
+ version_json = '''
11
+ {
12
+ "dirty": false,
13
+ "error": null,
14
+ "full-revisionid": "babbd1c68adc50c0e2199676c7ae741194c520da",
15
+ "version": "3.2.1"
16
+ }
17
+ ''' # END VERSION_JSON
18
+
19
+
20
+ def get_versions():
21
+ return json.loads(version_json)
ailia-models/code/pyannote_audio_utils/metrics/base.py ADDED
@@ -0,0 +1,419 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2012- CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+ from typing import List, Union, Optional, Set, Tuple
29
+
30
+ import warnings
31
+ import numpy as np
32
+ import pandas as pd
33
+ import scipy.stats
34
+ from pyannote_audio_utils.core import Annotation, Timeline
35
+
36
+ from pyannote_audio_utils.metrics.types import Details, MetricComponents
37
+
38
+
39
+ class BaseMetric:
40
+ """
41
+ :class:`BaseMetric` is the base class for most pyannote_audio_utils evaluation metrics.
42
+
43
+ Attributes
44
+ ----------
45
+ name : str
46
+ Human-readable name of the metric (e.g. 'diarization error rate')
47
+ """
48
+
49
+ @classmethod
50
+ def metric_name(cls) -> str:
51
+ raise NotImplementedError(
52
+ cls.__name__ + " is missing a 'metric_name' class method. "
53
+ "It should return the name of the metric as string."
54
+ )
55
+
56
+ @classmethod
57
+ def metric_components(cls) -> MetricComponents:
58
+ raise NotImplementedError(
59
+ cls.__name__ + " is missing a 'metric_components' class method. "
60
+ "It should return the list of names of metric components."
61
+ )
62
+
63
+ def __init__(self, **kwargs):
64
+ super(BaseMetric, self).__init__()
65
+ self.metric_name_ = self.__class__.metric_name()
66
+ self.components_: Set[str] = set(self.__class__.metric_components())
67
+ self.reset()
68
+
69
+ def init_components(self):
70
+ return {value: 0.0 for value in self.components_}
71
+
72
+ def reset(self):
73
+ """Reset accumulated components and metric values"""
74
+ self.accumulated_: Details = dict()
75
+ self.results_: List = list()
76
+ for value in self.components_:
77
+ self.accumulated_[value] = 0.0
78
+
79
+ @property
80
+ def name(self):
81
+ """Metric name."""
82
+ return self.metric_name()
83
+
84
+ # TODO: use joblib/locky to allow parallel processing?
85
+ # TODO: signature could be something like __call__(self, reference_iterator, hypothesis_iterator, ...)
86
+
87
+ def __call__(self, reference: Union[Timeline, Annotation],
88
+ hypothesis: Union[Timeline, Annotation],
89
+ detailed: bool = False, uri: Optional[str] = None, **kwargs):
90
+ """Compute metric value and accumulate components
91
+
92
+ Parameters
93
+ ----------
94
+ reference : type depends on the metric
95
+ Manual `reference`
96
+ hypothesis : type depends on the metric
97
+ Evaluated `hypothesis`
98
+ uri : optional
99
+ Override uri.
100
+ detailed : bool, optional
101
+ By default (False), return metric value only.
102
+ Set `detailed` to True to return dictionary where keys are
103
+ components names and values are component values
104
+
105
+ Returns
106
+ -------
107
+ value : float (if `detailed` is False)
108
+ Metric value
109
+ components : dict (if `detailed` is True)
110
+ `components` updated with metric value
111
+ """
112
+
113
+ # compute metric components
114
+ components = self.compute_components(reference, hypothesis, **kwargs)
115
+
116
+ # compute rate based on components
117
+ components[self.metric_name_] = self.compute_metric(components)
118
+
119
+ # keep track of this computation
120
+ uri = uri or getattr(reference, "uri", "NA")
121
+ self.results_.append((uri, components))
122
+
123
+ # accumulate components
124
+ for name in self.components_:
125
+ self.accumulated_[name] += components[name]
126
+
127
+ if detailed:
128
+ return components
129
+
130
+ return components[self.metric_name_]
131
+
132
+ def report(self, display: bool = False) -> pd.DataFrame:
133
+ """Evaluation report
134
+
135
+ Parameters
136
+ ----------
137
+ display : bool, optional
138
+ Set to True to print the report to stdout.
139
+
140
+ Returns
141
+ -------
142
+ report : pandas.DataFrame
143
+ Dataframe with one column per metric component, one row per
144
+ evaluated item, and one final row for accumulated results.
145
+ """
146
+
147
+ report = []
148
+ uris = []
149
+
150
+ percent = "total" in self.metric_components()
151
+
152
+ for uri, components in self.results_:
153
+ row = {}
154
+ if percent:
155
+ total = components["total"]
156
+ for key, value in components.items():
157
+ if key == self.name:
158
+ row[key, "%"] = 100 * value
159
+ elif key == "total":
160
+ row[key, ""] = value
161
+ else:
162
+ row[key, ""] = value
163
+ if percent:
164
+ if total > 0:
165
+ row[key, "%"] = 100 * value / total
166
+ else:
167
+ row[key, "%"] = np.NaN
168
+
169
+ report.append(row)
170
+ uris.append(uri)
171
+
172
+ row = {}
173
+ components = self.accumulated_
174
+
175
+ if percent:
176
+ total = components["total"]
177
+
178
+ for key, value in components.items():
179
+ if key == self.name:
180
+ row[key, "%"] = 100 * value
181
+ elif key == "total":
182
+ row[key, ""] = value
183
+ else:
184
+ row[key, ""] = value
185
+ if percent:
186
+ if total > 0:
187
+ row[key, "%"] = 100 * value / total
188
+ else:
189
+ row[key, "%"] = np.NaN
190
+
191
+ row[self.name, "%"] = 100 * abs(self)
192
+ report.append(row)
193
+ uris.append("TOTAL")
194
+
195
+ df = pd.DataFrame(report)
196
+
197
+ df["item"] = uris
198
+ df = df.set_index("item")
199
+
200
+ df.columns = pd.MultiIndex.from_tuples(df.columns)
201
+
202
+ df = df[[self.name] + self.metric_components()]
203
+
204
+ if display:
205
+ print(
206
+ df.to_string(
207
+ index=True,
208
+ sparsify=False,
209
+ justify="right",
210
+ float_format=lambda f: "{0:.2f}".format(f),
211
+ )
212
+ )
213
+
214
+ return df
215
+
216
+ def __str__(self):
217
+ report = self.report(display=False)
218
+ return report.to_string(
219
+ sparsify=False, float_format=lambda f: "{0:.2f}".format(f)
220
+ )
221
+
222
+ def __abs__(self):
223
+ """Compute metric value from accumulated components"""
224
+ return self.compute_metric(self.accumulated_)
225
+
226
+ def __getitem__(self, component: str) -> Union[float, Details]:
227
+ """Get value of accumulated `component`.
228
+
229
+ Parameters
230
+ ----------
231
+ component : str
232
+ Name of `component`
233
+
234
+ Returns
235
+ -------
236
+ value : type depends on the metric
237
+ Value of accumulated `component`
238
+
239
+ """
240
+ if component == slice(None, None, None):
241
+ return dict(self.accumulated_)
242
+ else:
243
+ return self.accumulated_[component]
244
+
245
+ def __iter__(self):
246
+ """Iterator over the accumulated (uri, value)"""
247
+ for uri, component in self.results_:
248
+ yield uri, component
249
+
250
+ def compute_components(self,
251
+ reference: Union[Timeline, Annotation],
252
+ hypothesis: Union[Timeline, Annotation],
253
+ **kwargs) -> Details:
254
+ """Compute metric components
255
+
256
+ Parameters
257
+ ----------
258
+ reference : type depends on the metric
259
+ Manual `reference`
260
+ hypothesis : same as `reference`
261
+ Evaluated `hypothesis`
262
+
263
+ Returns
264
+ -------
265
+ components : dict
266
+ Dictionary where keys are component names and values are component
267
+ values
268
+
269
+ """
270
+ raise NotImplementedError(
271
+ self.__class__.__name__ + " is missing a 'compute_components' method."
272
+ "It should return a dictionary where keys are component names "
273
+ "and values are component values."
274
+ )
275
+
276
+ def compute_metric(self, components: Details):
277
+ """Compute metric value from computed `components`
278
+
279
+ Parameters
280
+ ----------
281
+ components : dict
282
+ Dictionary where keys are components names and values are component
283
+ values
284
+
285
+ Returns
286
+ -------
287
+ value : type depends on the metric
288
+ Metric value
289
+ """
290
+ raise NotImplementedError(
291
+ self.__class__.__name__ + " is missing a 'compute_metric' method. "
292
+ "It should return the actual value of the metric based "
293
+ "on the precomputed component dictionary given as input."
294
+ )
295
+
296
+ def confidence_interval(self, alpha: float = 0.9) \
297
+ -> Tuple[float, Tuple[float, float]]:
298
+ """Compute confidence interval on accumulated metric values
299
+
300
+ Parameters
301
+ ----------
302
+ alpha : float, optional
303
+ Probability that the returned confidence interval contains
304
+ the true metric value.
305
+
306
+ Returns
307
+ -------
308
+ (center, (lower, upper))
309
+ with center the mean of the conditional pdf of the metric value
310
+ and (lower, upper) is a confidence interval centered on the median,
311
+ containing the estimate to a probability alpha.
312
+
313
+ See Also:
314
+ ---------
315
+ scipy.stats.bayes_mvs
316
+
317
+ """
318
+
319
+ values = [r[self.metric_name_] for _, r in self.results_]
320
+
321
+ if len(values) == 0:
322
+ raise ValueError("Please evaluate a bunch of files before computing confidence interval.")
323
+
324
+ elif len(values) == 1:
325
+ warnings.warn("Cannot compute a reliable confidence interval out of just one file.")
326
+ center = lower = upper = values[0]
327
+ return center, (lower, upper)
328
+
329
+ else:
330
+ return scipy.stats.bayes_mvs(values, alpha=alpha)[0]
331
+
332
+
333
+ PRECISION_NAME = "precision"
334
+ PRECISION_RETRIEVED = "# retrieved"
335
+ PRECISION_RELEVANT_RETRIEVED = "# relevant retrieved"
336
+
337
+
338
+ class Precision(BaseMetric):
339
+ """
340
+ :class:`Precision` is a base class for precision-like evaluation metrics.
341
+
342
+ It defines two components '# retrieved' and '# relevant retrieved' and the
343
+ compute_metric() method to compute the actual precision:
344
+
345
+ Precision = # retrieved / # relevant retrieved
346
+
347
+ Inheriting classes must implement compute_components().
348
+ """
349
+
350
+ @classmethod
351
+ def metric_name(cls):
352
+ return PRECISION_NAME
353
+
354
+ @classmethod
355
+ def metric_components(cls) -> MetricComponents:
356
+ return [PRECISION_RETRIEVED, PRECISION_RELEVANT_RETRIEVED]
357
+
358
+ def compute_metric(self, components: Details) -> float:
359
+ """Compute precision from `components`"""
360
+ numerator = components[PRECISION_RELEVANT_RETRIEVED]
361
+ denominator = components[PRECISION_RETRIEVED]
362
+ if denominator == 0.0:
363
+ if numerator == 0:
364
+ return 1.0
365
+ else:
366
+ raise ValueError("Inconsistent components: # relevant retrieved > 0 while # retrieved == 0.")
367
+ else:
368
+ return numerator / denominator
369
+
370
+
371
+ RECALL_NAME = "recall"
372
+ RECALL_RELEVANT = "# relevant"
373
+ RECALL_RELEVANT_RETRIEVED = "# relevant retrieved"
374
+
375
+
376
+ class Recall(BaseMetric):
377
+ """
378
+ :class:`Recall` is a base class for recall-like evaluation metrics.
379
+
380
+ It defines two components '# relevant' and '# relevant retrieved' and the
381
+ compute_metric() method to compute the actual recall:
382
+
383
+ Recall = # relevant retrieved / # relevant
384
+
385
+ Inheriting classes must implement compute_components().
386
+ """
387
+
388
+ @classmethod
389
+ def metric_name(cls):
390
+ return RECALL_NAME
391
+
392
+ @classmethod
393
+ def metric_components(cls) -> MetricComponents:
394
+ return [RECALL_RELEVANT, RECALL_RELEVANT_RETRIEVED]
395
+
396
+ def compute_metric(self, components: Details) -> float:
397
+ """Compute recall from `components`"""
398
+ numerator = components[RECALL_RELEVANT_RETRIEVED]
399
+ denominator = components[RECALL_RELEVANT]
400
+ if denominator == 0.0:
401
+ if numerator == 0:
402
+ return 1.0
403
+ else:
404
+ raise ValueError("Inconsistent components: # relevant retrieved > 0 while # relevant == 0.")
405
+ else:
406
+ return numerator / denominator
407
+
408
+
409
+ def f_measure(precision: float, recall: float, beta=1.0) -> float:
410
+ """Compute f-measure
411
+
412
+ f-measure is defined as follows:
413
+ F(P, R, b) = (1+b²).P.R / (b².P + R)
414
+
415
+ where P is `precision`, R is `recall` and b is `beta`
416
+ """
417
+ if precision + recall == 0.0:
418
+ return 0
419
+ return (1 + beta * beta) * precision * recall / (beta * beta * precision + recall)
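
A minimal sketch of the `BaseMetric` contract defined above: subclasses declare a name and components, compute components per file, and reduce them to a value. `CoverageRatio` is a made-up illustration (not part of this commit) and assumes `pyannote_audio_utils.core` keeps pyannote.core's `Timeline.crop`/`duration` API:

>>> from pyannote_audio_utils.metrics.base import BaseMetric
>>> class CoverageRatio(BaseMetric):
...     @classmethod
...     def metric_name(cls):
...         return "coverage ratio"
...     @classmethod
...     def metric_components(cls):
...         return ["covered", "total"]
...     def compute_components(self, reference, hypothesis, **kwargs):
...         ref = reference.get_timeline()
...         hyp = hypothesis.get_timeline()
...         # duration of reference covered by hypothesis, and total reference duration
...         return {"covered": hyp.crop(ref).duration(), "total": ref.duration()}
...     def compute_metric(self, components):
...         total = components["total"]
...         return components["covered"] / total if total else 1.0

Calling an instance then accumulates components exactly as in `BaseMetric.__call__`, so `abs(metric)` and `metric.report()` work out of the box.
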
ailia-models/code/pyannote_audio_utils/metrics/diarization.py ADDED
@@ -0,0 +1,167 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2012-2019 CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+
29
+ """Metrics for diarization"""
30
+ from typing import Optional, Dict, TYPE_CHECKING
31
+
32
+ from pyannote_audio_utils.core import Annotation, Timeline
33
+ from pyannote_audio_utils.core.utils.types import Label
34
+
35
+ from .identification import IdentificationErrorRate
36
+ from .matcher import HungarianMapper
37
+ from .types import Details, MetricComponents
38
+
39
+ if TYPE_CHECKING:
40
+ pass
41
+
42
+ # TODO: can't we put these as class attributes?
43
+ DER_NAME = 'diarization error rate'
44
+
45
+
46
+ class DiarizationErrorRate(IdentificationErrorRate):
47
+ """Diarization error rate
48
+
49
+ First, the optimal mapping between reference and hypothesis labels
50
+ is obtained using the Hungarian algorithm. Then, the actual diarization
51
+ error rate is computed as the identification error rate with each hypothesis
52
+ label translated into the corresponding reference label.
53
+
54
+ Parameters
55
+ ----------
56
+ collar : float, optional
57
+ Duration (in seconds) of collars removed from evaluation around
58
+ boundaries of reference segments.
59
+ skip_overlap : bool, optional
60
+ Set to True to not evaluate overlap regions.
61
+ Defaults to False (i.e. keep overlap regions).
62
+
63
+ Usage
64
+ -----
65
+
66
+ * Diarization error rate between `reference` and `hypothesis` annotations
67
+
68
+ >>> metric = DiarizationErrorRate()
69
+ >>> reference = Annotation(...) # doctest: +SKIP
70
+ >>> hypothesis = Annotation(...) # doctest: +SKIP
71
+ >>> value = metric(reference, hypothesis) # doctest: +SKIP
72
+
73
+ * Compute global diarization error rate and confidence interval
74
+ over multiple documents
75
+
76
+ >>> for reference, hypothesis in ... # doctest: +SKIP
77
+ ... metric(reference, hypothesis) # doctest: +SKIP
78
+ >>> global_value = abs(metric) # doctest: +SKIP
79
+ >>> mean, (lower, upper) = metric.confidence_interval() # doctest: +SKIP
80
+
81
+ * Get diarization error rate detailed components
82
+
83
+ >>> components = metric(reference, hypothesis, detailed=True) # doctest: +SKIP
84
+
85
+ * Get accumulated components
86
+
87
+ >>> components = metric[:] # doctest: +SKIP
88
+ >>> metric['confusion'] # doctest: +SKIP
89
+
90
+ See Also
91
+ --------
92
+ :class:`pyannote_audio_utils.metrics.base.BaseMetric`: details on accumulation
93
+ :class:`pyannote_audio_utils.metrics.identification.IdentificationErrorRate`: identification error rate
94
+
95
+ """
96
+
97
+ @classmethod
98
+ def metric_name(cls) -> str:
99
+ return DER_NAME
100
+
101
+ def __init__(self, collar: float = 0.0, skip_overlap: bool = False,
102
+ **kwargs):
103
+ super().__init__(collar=collar, skip_overlap=skip_overlap, **kwargs)
104
+ self.mapper_ = HungarianMapper()
105
+
106
+ def optimal_mapping(self,
107
+ reference: Annotation,
108
+ hypothesis: Annotation,
109
+ uem: Optional[Timeline] = None) -> Dict[Label, Label]:
110
+ """Optimal label mapping
111
+
112
+ Parameters
113
+ ----------
114
+ reference : Annotation
115
+ hypothesis : Annotation
116
+ Reference and hypothesis diarization
117
+ uem : Timeline
118
+ Evaluation map
119
+
120
+ Returns
121
+ -------
122
+ mapping : dict
123
+ Mapping between hypothesis (key) and reference (value) labels
124
+ """
125
+
126
+ # NOTE that this 'uemification' will not be called when
127
+ # 'optimal_mapping' is called from 'compute_components' as it
128
+ # has already been done in 'compute_components'
129
+ if uem:
130
+ reference, hypothesis = self.uemify(reference, hypothesis, uem=uem)
131
+
132
+ # call hungarian mapper
133
+ return self.mapper_(hypothesis, reference)
134
+
135
+ def compute_components(self,
136
+ reference: Annotation,
137
+ hypothesis: Annotation,
138
+ uem: Optional[Timeline] = None,
139
+ **kwargs) -> Details:
140
+ # crop reference and hypothesis to evaluated regions (uem)
141
+ # remove collars around reference segment boundaries
142
+ # remove overlap regions (if requested)
143
+ reference, hypothesis, uem = self.uemify(
144
+ reference, hypothesis, uem=uem,
145
+ collar=self.collar, skip_overlap=self.skip_overlap,
146
+ returns_uem=True)
147
+ # NOTE that this 'uemification' must be done here because it
148
+ # might have an impact on the search for the optimal mapping.
149
+
150
+ # make sure reference only contains string labels ('A', 'B', ...)
151
+ reference = reference.rename_labels(generator='string')
152
+
153
+ # make sure hypothesis only contains integer labels (1, 2, ...)
154
+ hypothesis = hypothesis.rename_labels(generator='int')
155
+
156
+ # optimal (int --> str) mapping
157
+ mapping = self.optimal_mapping(reference, hypothesis)
158
+
159
+ # compute identification error rate based on mapped hypothesis
160
+ # NOTE that collar is set to 0.0 because 'uemify' has already
161
+ # been applied (same reason for setting skip_overlap to False)
162
+ mapped = hypothesis.rename_labels(mapping=mapping)
163
+ return super(DiarizationErrorRate, self) \
164
+ .compute_components(reference, mapped, uem=uem,
165
+ collar=0.0, skip_overlap=False,
166
+ **kwargs)
167
+
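
A toy end-to-end run of `DiarizationErrorRate` (segments and labels are made up; a perfect diarization yields 0.0 after the optimal label mapping, and a warning is emitted because no `uem` is provided):

>>> from pyannote_audio_utils.core import Annotation, Segment
>>> from pyannote_audio_utils.metrics.diarization import DiarizationErrorRate
>>> reference = Annotation(uri="toy")
>>> reference[Segment(0, 10)] = "alice"
>>> reference[Segment(12, 20)] = "bob"
>>> hypothesis = Annotation(uri="toy")
>>> hypothesis[Segment(0, 10)] = "spk1"
>>> hypothesis[Segment(12, 20)] = "spk2"
>>> metric = DiarizationErrorRate()
>>> metric(reference, hypothesis) # doctest: +SKIP
0.0
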
ailia-models/code/pyannote_audio_utils/metrics/identification.py ADDED
@@ -0,0 +1,274 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2012-2019 CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+ from typing import Optional
29
+
30
+ from pyannote_audio_utils.core import Annotation, Timeline
31
+
32
+ from .base import BaseMetric
33
+ from .base import Precision, PRECISION_RETRIEVED, PRECISION_RELEVANT_RETRIEVED
34
+ from .base import Recall, RECALL_RELEVANT, RECALL_RELEVANT_RETRIEVED
35
+ from .matcher import LabelMatcher, \
36
+ MATCH_TOTAL, MATCH_CORRECT, MATCH_CONFUSION, \
37
+ MATCH_MISSED_DETECTION, MATCH_FALSE_ALARM
38
+ from .types import MetricComponents, Details
39
+ from .utils import UEMSupportMixin
40
+
41
+ # TODO: can't we put these as class attributes?
42
+ IER_TOTAL = MATCH_TOTAL
43
+ IER_CORRECT = MATCH_CORRECT
44
+ IER_CONFUSION = MATCH_CONFUSION
45
+ IER_FALSE_ALARM = MATCH_FALSE_ALARM
46
+ IER_MISS = MATCH_MISSED_DETECTION
47
+ IER_NAME = 'identification error rate'
48
+
49
+
50
+ class IdentificationErrorRate(UEMSupportMixin, BaseMetric):
51
+ """Identification error rate
52
+
53
+ ``ier = (wc x confusion + wf x false_alarm + wm x miss) / total``
54
+
55
+ where
56
+ - `confusion` is the total confusion duration in seconds
57
+ - `false_alarm` is the total duration (in seconds) of hypothesis tracks not matched by any reference track
58
+ - `miss` is the total duration (in seconds) of reference tracks not matched by any hypothesis track
59
+ - `total` is the total duration of all tracks
60
+ - wc, wf and wm are optional weights (default to 1)
61
+
62
+ Parameters
63
+ ----------
64
+ collar : float, optional
65
+ Duration (in seconds) of collars removed from evaluation around
66
+ boundaries of reference segments.
67
+ skip_overlap : bool, optional
68
+ Set to True to not evaluate overlap regions.
69
+ Defaults to False (i.e. keep overlap regions).
70
+ confusion, miss, false_alarm: float, optional
71
+ Optional weights for confusion, miss and false alarm respectively.
72
+ Default to 1. (no weight)
73
+ """
74
+
75
+ @classmethod
76
+ def metric_name(cls) -> str:
77
+ return IER_NAME
78
+
79
+ @classmethod
80
+ def metric_components(cls) -> MetricComponents:
81
+ return [
82
+ IER_TOTAL,
83
+ IER_CORRECT,
84
+ IER_FALSE_ALARM, IER_MISS,
85
+ IER_CONFUSION]
86
+
87
+ def __init__(self,
88
+ confusion: float = 1.,
89
+ miss: float = 1.,
90
+ false_alarm: float = 1.,
91
+ collar: float = 0.,
92
+ skip_overlap: bool = False,
93
+ **kwargs):
94
+
95
+ super().__init__(**kwargs)
96
+ self.matcher_ = LabelMatcher()
97
+ self.confusion = confusion
98
+ self.miss = miss
99
+ self.false_alarm = false_alarm
100
+ self.collar = collar
101
+ self.skip_overlap = skip_overlap
102
+
103
+ def compute_components(self,
104
+ reference: Annotation,
105
+ hypothesis: Annotation,
106
+ uem: Optional[Timeline] = None,
107
+ collar: Optional[float] = None,
108
+ skip_overlap: Optional[bool] = None,
109
+ **kwargs) -> Details:
110
+ """
111
+
112
+ Parameters
113
+ ----------
114
+ collar : float, optional
115
+ Override self.collar
116
+ skip_overlap : bool, optional
117
+ Override self.skip_overlap
118
+
119
+ See also
120
+ --------
121
+ :class:`pyannote_audio_utils.metrics.diarization.DiarizationErrorRate` uses these
122
+ two options in its `compute_components` method.
123
+
124
+ """
125
+
126
+ detail = self.init_components()
127
+
128
+ if collar is None:
129
+ collar = self.collar
130
+ if skip_overlap is None:
131
+ skip_overlap = self.skip_overlap
132
+
133
+ R, H, common_timeline = self.uemify(
134
+ reference, hypothesis, uem=uem,
135
+ collar=collar, skip_overlap=skip_overlap,
136
+ returns_timeline=True)
137
+
138
+ # loop on all segments
139
+ for segment in common_timeline:
140
+ # segment duration
141
+ duration = segment.duration
142
+
143
+ # list of IDs in reference segment
144
+ r = R.get_labels(segment, unique=False)
145
+
146
+ # list of IDs in hypothesis segment
147
+ h = H.get_labels(segment, unique=False)
148
+
149
+ counts, _ = self.matcher_(r, h)
150
+
151
+ detail[IER_TOTAL] += duration * counts[IER_TOTAL]
152
+ detail[IER_CORRECT] += duration * counts[IER_CORRECT]
153
+ detail[IER_CONFUSION] += duration * counts[IER_CONFUSION]
154
+ detail[IER_MISS] += duration * counts[IER_MISS]
155
+ detail[IER_FALSE_ALARM] += duration * counts[IER_FALSE_ALARM]
156
+
157
+ return detail
158
+
159
+ def compute_metric(self, detail: Details) -> float:
160
+
161
+ numerator = 1. * (
162
+ self.confusion * detail[IER_CONFUSION] +
163
+ self.false_alarm * detail[IER_FALSE_ALARM] +
164
+ self.miss * detail[IER_MISS]
165
+ )
166
+ denominator = 1. * detail[IER_TOTAL]
167
+ if denominator == 0.:
168
+ if numerator == 0:
169
+ return 0.
170
+ else:
171
+ return 1.
172
+ else:
173
+ return numerator / denominator
174
+
175
+
176
+ class IdentificationPrecision(UEMSupportMixin, Precision):
177
+ """Identification Precision
178
+
179
+ Parameters
180
+ ----------
181
+ collar : float, optional
182
+ Duration (in seconds) of collars removed from evaluation around
183
+ boundaries of reference segments.
184
+ skip_overlap : bool, optional
185
+ Set to True to not evaluate overlap regions.
186
+ Defaults to False (i.e. keep overlap regions).
187
+ """
188
+
189
+ def __init__(self, collar: float = 0., skip_overlap: bool = False, **kwargs):
190
+ super().__init__(**kwargs)
191
+ self.collar = collar
192
+ self.skip_overlap = skip_overlap
193
+ self.matcher_ = LabelMatcher()
194
+
195
+ def compute_components(self,
196
+ reference: Annotation,
197
+ hypothesis: Annotation,
198
+ uem: Optional[Timeline] = None,
199
+ **kwargs) -> Details:
200
+ detail = self.init_components()
201
+
202
+ R, H, common_timeline = self.uemify(
203
+ reference, hypothesis, uem=uem,
204
+ collar=self.collar, skip_overlap=self.skip_overlap,
205
+ returns_timeline=True)
206
+
207
+ # loop on all segments
208
+ for segment in common_timeline:
209
+ # segment duration
210
+ duration = segment.duration
211
+
212
+ # list of IDs in reference segment
213
+ r = R.get_labels(segment, unique=False)
214
+
215
+ # list of IDs in hypothesis segment
216
+ h = H.get_labels(segment, unique=False)
217
+
218
+ counts, _ = self.matcher_(r, h)
219
+
220
+ detail[PRECISION_RETRIEVED] += duration * len(h)
221
+ detail[PRECISION_RELEVANT_RETRIEVED] += \
222
+ duration * counts[IER_CORRECT]
223
+
224
+ return detail
225
+
226
+
227
+ class IdentificationRecall(UEMSupportMixin, Recall):
228
+ """Identification Recall
229
+
230
+ Parameters
231
+ ----------
232
+ collar : float, optional
233
+ Duration (in seconds) of collars removed from evaluation around
234
+ boundaries of reference segments.
235
+ skip_overlap : bool, optional
236
+ Set to True to not evaluate overlap regions.
237
+ Defaults to False (i.e. keep overlap regions).
238
+ """
239
+
240
+ def __init__(self, collar: float = 0., skip_overlap: bool = False, **kwargs):
241
+ super().__init__(**kwargs)
242
+ self.collar = collar
243
+ self.skip_overlap = skip_overlap
244
+ self.matcher_ = LabelMatcher()
245
+
246
+ def compute_components(self,
247
+ reference: Annotation,
248
+ hypothesis: Annotation,
249
+ uem: Optional[Timeline] = None,
250
+ **kwargs) -> Details:
251
+ detail = self.init_components()
252
+
253
+ R, H, common_timeline = self.uemify(
254
+ reference, hypothesis, uem=uem,
255
+ collar=self.collar, skip_overlap=self.skip_overlap,
256
+ returns_timeline=True)
257
+
258
+ # loop on all segments
259
+ for segment in common_timeline:
260
+ # segment duration
261
+ duration = segment.duration
262
+
263
+ # list of IDs in reference segment
264
+ r = R.get_labels(segment, unique=False)
265
+
266
+ # list of IDs in hypothesis segment
267
+ h = H.get_labels(segment, unique=False)
268
+
269
+ counts, _ = self.matcher_(r, h)
270
+
271
+ detail[RECALL_RELEVANT] += duration * counts[IER_TOTAL]
272
+ detail[RECALL_RELEVANT_RETRIEVED] += duration * counts[IER_CORRECT]
273
+
274
+ return detail
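
The three metrics in this file share the same call signature, so they can be accumulated side by side (a sketch; `reference` and `hypothesis` are `Annotation` instances, as in the docstrings above):

>>> from pyannote_audio_utils.metrics.identification import (
...     IdentificationErrorRate, IdentificationPrecision, IdentificationRecall)
>>> ier = IdentificationErrorRate(collar=0.25)
>>> precision, recall = IdentificationPrecision(), IdentificationRecall()
>>> ier(reference, hypothesis)        # doctest: +SKIP
>>> precision(reference, hypothesis)  # doctest: +SKIP
>>> recall(reference, hypothesis)     # doctest: +SKIP
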
ailia-models/code/pyannote_audio_utils/metrics/matcher.py ADDED
@@ -0,0 +1,192 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2012-2019 CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+ from typing import Dict, Tuple, Iterable, List, TYPE_CHECKING
29
+
30
+ import numpy as np
31
+ from pyannote_audio_utils.core import Annotation
32
+ from scipy.optimize import linear_sum_assignment
33
+
34
+ if TYPE_CHECKING:
35
+ from pyannote_audio_utils.core.utils.types import Label
36
+
37
+ MATCH_CORRECT = 'correct'
38
+ MATCH_CONFUSION = 'confusion'
39
+ MATCH_MISSED_DETECTION = 'missed detection'
40
+ MATCH_FALSE_ALARM = 'false alarm'
41
+ MATCH_TOTAL = 'total'
42
+
43
+
44
+ class LabelMatcher:
45
+ """
46
+ ID matcher base class mixin.
47
+
48
+ All ID matcher classes must inherit from this class and implement
49
+ .match() -- ie return True if two IDs match and False
50
+ otherwise.
51
+ """
52
+
53
+ def match(self, rlabel: 'Label', hlabel: 'Label') -> bool:
54
+ """
55
+ Parameters
56
+ ----------
57
+ rlabel :
58
+ Reference label
59
+ hlabel :
60
+ Hypothesis label
61
+
62
+ Returns
63
+ -------
64
+ match : bool
65
+ True if labels match, False otherwise.
66
+
67
+ """
68
+ # Two IDs match if they are equal to each other
69
+ return rlabel == hlabel
70
+
71
+ def __call__(self, rlabels: Iterable['Label'], hlabels: Iterable['Label']) \
72
+ -> Tuple[Dict[str, int],
73
+ Dict[str, List['Label']]]:
74
+ """
75
+
76
+ Parameters
77
+ ----------
78
+ rlabels, hlabels : iterable
79
+ Reference and hypothesis labels
80
+
81
+ Returns
82
+ -------
83
+ counts : dict
84
+ details : dict
85
+
86
+ """
87
+
88
+ # counts and details
89
+ counts = {
90
+ MATCH_CORRECT: 0,
91
+ MATCH_CONFUSION: 0,
92
+ MATCH_MISSED_DETECTION: 0,
93
+ MATCH_FALSE_ALARM: 0,
94
+ MATCH_TOTAL: 0
95
+ }
96
+
97
+ details = {
98
+ MATCH_CORRECT: [],
99
+ MATCH_CONFUSION: [],
100
+ MATCH_MISSED_DETECTION: [],
101
+ MATCH_FALSE_ALARM: []
102
+ }
103
+ # this is to make sure rlabels and hlabels are lists
104
+ # as we will access them later by index
105
+ rlabels = list(rlabels)
106
+ hlabels = list(hlabels)
107
+
108
+ NR = len(rlabels)
109
+ NH = len(hlabels)
110
+ N = max(NR, NH)
111
+
112
+ # corner case
113
+ if N == 0:
114
+ return counts, details
115
+
116
+ # initialize match matrix
117
+ # with True if labels match and False otherwise
118
+ match = np.zeros((N, N), dtype=bool)
119
+ for r, rlabel in enumerate(rlabels):
120
+ for h, hlabel in enumerate(hlabels):
121
+ match[r, h] = self.match(rlabel, hlabel)
122
+
123
+ # find one-to-one mapping that maximize total number of matches
124
+ # using the Hungarian algorithm and computes error accordingly
125
+ for r, h in zip(*linear_sum_assignment(~match)):
126
+
127
+ # hypothesis label is matched with a nonexistent reference label
128
+ # ==> this is a false alarm
129
+ if r >= NR:
130
+ counts[MATCH_FALSE_ALARM] += 1
131
+ details[MATCH_FALSE_ALARM].append(hlabels[h])
132
+
133
+ # reference label is matched with a nonexistent hypothesis label
134
+ # ==> this is a missed detection
135
+ elif h >= NH:
136
+ counts[MATCH_MISSED_DETECTION] += 1
137
+ details[MATCH_MISSED_DETECTION].append(rlabels[r])
138
+
139
+ # reference and hypothesis labels match
140
+ # ==> this is a correct detection
141
+ elif match[r, h]:
142
+ counts[MATCH_CORRECT] += 1
143
+ details[MATCH_CORRECT].append((rlabels[r], hlabels[h]))
144
+
145
+ # reference and hypothesis do not match
146
+ # ==> this is a confusion
147
+ else:
148
+ counts[MATCH_CONFUSION] += 1
149
+ details[MATCH_CONFUSION].append((rlabels[r], hlabels[h]))
150
+
151
+ counts[MATCH_TOTAL] += NR
152
+
153
+ # returns counts and details
154
+ return counts, details
155
+
156
+
157
+ class HungarianMapper:
158
+
159
+ def __call__(self, A: Annotation, B: Annotation) -> Dict['Label', 'Label']:
160
+ mapping = {}
161
+
162
+ cooccurrence = A * B
163
+ a_labels, b_labels = A.labels(), B.labels()
164
+
165
+ for a, b in zip(*linear_sum_assignment(-cooccurrence)):
166
+ if cooccurrence[a, b] > 0:
167
+ mapping[a_labels[a]] = b_labels[b]
168
+
169
+ return mapping
170
+
171
+
172
+ class GreedyMapper:
173
+
174
+ def __call__(self, A: Annotation, B: Annotation) -> Dict['Label', 'Label']:
175
+ mapping = {}
176
+
177
+ cooccurrence = A * B
178
+ Na, Nb = cooccurrence.shape
179
+ a_labels, b_labels = A.labels(), B.labels()
180
+
181
+ for i in range(min(Na, Nb)):
182
+ a, b = np.unravel_index(np.argmax(cooccurrence), (Na, Nb))
183
+
184
+ if cooccurrence[a, b] > 0:
185
+ mapping[a_labels[a]] = b_labels[b]
186
+ cooccurrence[a, :] = 0.
187
+ cooccurrence[:, b] = 0.
188
+ continue
189
+
190
+ break
191
+
192
+ return mapping
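
A small worked example of `LabelMatcher` on made-up label lists: 'alice' is correctly retrieved, 'bob' is confused with one of the two extra hypothesis labels, the remaining hypothesis label counts as a false alarm, and 'total' counts reference labels only:

>>> from pyannote_audio_utils.metrics.matcher import LabelMatcher
>>> matcher = LabelMatcher()
>>> counts, details = matcher(["alice", "bob"], ["alice", "carol", "dave"])
>>> (counts["correct"], counts["confusion"], counts["false alarm"], counts["total"])
(1, 1, 1, 2)
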
ailia-models/code/pyannote_audio_utils/metrics/types.py ADDED
@@ -0,0 +1,7 @@
1
+ from typing import Dict, List, Literal
2
+
3
+
4
+ MetricComponent = str
5
+ CalibrationMethod = Literal["isotonic", "sigmoid"]
6
+ MetricComponents = List[MetricComponent]
7
+ Details = Dict[MetricComponent, float]
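
These aliases only name the shapes passed around by the metrics above; for instance (a sketch):

>>> from pyannote_audio_utils.metrics.types import Details
>>> components: Details = {"total": 10.0, "correct": 9.0}
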
ailia-models/code/pyannote_audio_utils/metrics/utils.py ADDED
@@ -0,0 +1,225 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2012-2019 CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+
29
+ import warnings
30
+ from typing import Optional, Tuple, Union
31
+
32
+ from pyannote_audio_utils.core import Timeline, Segment, Annotation
33
+
34
+
35
+ class UEMSupportMixin:
36
+ """Provides 'uemify' method with optional (à la NIST) collar"""
37
+
38
+ def extrude(self,
39
+ uem: Timeline,
40
+ reference: Annotation,
41
+ collar: float = 0.0,
42
+ skip_overlap: bool = False) -> Timeline:
43
+ """Extrude reference boundary collars from uem
44
+
45
+ reference |----| |--------------| |-------------|
46
+ uem |---------------------| |-------------------------------|
47
+ extruded |--| |--| |---| |-----| |-| |-----| |-----------| |-----|
48
+
49
+ Parameters
50
+ ----------
51
+ uem : Timeline
52
+ Evaluation map.
53
+ reference : Annotation
54
+ Reference annotation.
55
+ collar : float, optional
56
+ When provided, set the duration of collars centered around
57
+ reference segment boundaries that are extruded from both reference
58
+ and hypothesis. Defaults to 0. (i.e. no collar).
59
+ skip_overlap : bool, optional
60
+ Set to True to not evaluate overlap regions.
61
+ Defaults to False (i.e. keep overlap regions).
62
+
63
+ Returns
64
+ -------
65
+ extruded_uem : Timeline
66
+ """
67
+
68
+ if collar == 0. and not skip_overlap:
69
+ return uem
70
+
71
+ collars, overlap_regions = [], []
72
+
73
+ # build list of collars if needed
74
+ if collar > 0.:
75
+ # iterate over all segments in reference
76
+ for segment in reference.itersegments():
77
+ # add collar centered on start time
78
+ t = segment.start
79
+ collars.append(Segment(t - .5 * collar, t + .5 * collar))
80
+
81
+ # add collar centered on end time
82
+ t = segment.end
83
+ collars.append(Segment(t - .5 * collar, t + .5 * collar))
84
+
85
+ # build list of overlap regions if needed
86
+ if skip_overlap:
87
+ # iterate over pair of intersecting segments
88
+ for (segment1, track1), (segment2, track2) in reference.co_iter(reference):
89
+ if segment1 == segment2 and track1 == track2:
90
+ continue
91
+ # add their intersection
92
+ overlap_regions.append(segment1 & segment2)
93
+
94
+ segments = collars + overlap_regions
95
+
96
+ return Timeline(segments=segments).support().gaps(support=uem)
97
+
98
+ def common_timeline(self, reference: Annotation, hypothesis: Annotation) \
99
+ -> Timeline:
100
+ """Return timeline common to both reference and hypothesis
101
+
102
+ reference |--------| |------------| |---------| |----|
103
+ hypothesis |--------------| |------| |----------------|
104
+ timeline |--|-----|----|---|-|------| |-|---------|----| |----|
105
+
106
+ Parameters
107
+ ----------
108
+ reference : Annotation
109
+ hypothesis : Annotation
110
+
111
+ Returns
112
+ -------
113
+ timeline : Timeline
114
+ """
115
+ timeline = reference.get_timeline(copy=True)
116
+ timeline.update(hypothesis.get_timeline(copy=False))
117
+ return timeline.segmentation()
118
+
119
+ def project(self, annotation: Annotation, timeline: Timeline) -> Annotation:
120
+ """Project annotation onto timeline segments
121
+
122
+ reference |__A__| |__B__|
123
+ |____C____|
124
+
125
+ timeline |---|---|---| |---|
126
+
127
+ projection |_A_|_A_|_C_| |_B_|
128
+ |_C_|
129
+
130
+ Parameters
131
+ ----------
132
+ annotation : Annotation
133
+ timeline : Timeline
134
+
135
+ Returns
136
+ -------
137
+ projection : Annotation
138
+ """
139
+ projection = annotation.empty()
140
+ timeline_ = annotation.get_timeline(copy=False)
141
+ for segment_, segment in timeline_.co_iter(timeline):
142
+ for track_ in annotation.get_tracks(segment_):
143
+ track = projection.new_track(segment, candidate=track_)
144
+ projection[segment, track] = annotation[segment_, track_]
145
+ return projection
146
+
147
+ def uemify(self,
148
+ reference: Annotation,
149
+ hypothesis: Annotation,
150
+ uem: Optional[Timeline] = None,
151
+ collar: float = 0.,
152
+ skip_overlap: bool = False,
153
+ returns_uem: bool = False,
154
+ returns_timeline: bool = False) \
155
+ -> Union[
156
+ Tuple[Annotation, Annotation],
157
+ Tuple[Annotation, Annotation, Timeline],
158
+ Tuple[Annotation, Annotation, Timeline, Timeline],
159
+ ]:
160
+ """Crop 'reference' and 'hypothesis' to 'uem' support
161
+
162
+ Parameters
163
+ ----------
164
+ reference, hypothesis : Annotation
165
+ Reference and hypothesis annotations.
166
+ uem : Timeline, optional
167
+ Evaluation map.
168
+ collar : float, optional
169
+ When provided, set the duration of collars centered around
170
+ reference segment boundaries that are extruded from both reference
171
+ and hypothesis. Defaults to 0. (i.e. no collar).
172
+ skip_overlap : bool, optional
173
+ Set to True to not evaluate overlap regions.
174
+ Defaults to False (i.e. keep overlap regions).
175
+ returns_uem : bool, optional
176
+ Set to True to return extruded uem as well.
177
+ Defaults to False (i.e. only return reference and hypothesis)
178
+ returns_timeline : bool, optional
179
+ Set to True to oversegment reference and hypothesis so that they
180
+ share the same internal timeline.
181
+
182
+ Returns
183
+ -------
184
+ reference, hypothesis : Annotation
185
+ Extruded reference and hypothesis annotations
186
+ uem : Timeline
187
+ Extruded uem (returned only when 'returns_uem' is True)
188
+ timeline : Timeline:
189
+ Common timeline (returned only when 'returns_timeline' is True)
190
+ """
191
+
192
+ # when uem is not provided, use the union of reference and hypothesis
193
+ # extents -- and warn the user about that.
194
+ if uem is None:
195
+ r_extent = reference.get_timeline().extent()
196
+ h_extent = hypothesis.get_timeline().extent()
197
+ extent = r_extent | h_extent
198
+ uem = Timeline(segments=[extent] if extent else [],
199
+ uri=reference.uri)
200
+ warnings.warn(
201
+ "'uem' was approximated by the union of 'reference' "
202
+ "and 'hypothesis' extents.")
203
+
204
+ # extrude collars (and overlap regions) from uem
205
+ uem = self.extrude(uem, reference, collar=collar,
206
+ skip_overlap=skip_overlap)
207
+
208
+ # extrude regions outside of uem
209
+ reference = reference.crop(uem, mode='intersection')
210
+ hypothesis = hypothesis.crop(uem, mode='intersection')
211
+
212
+ # project reference and hypothesis on common timeline
213
+ if returns_timeline:
214
+ timeline = self.common_timeline(reference, hypothesis)
215
+ reference = self.project(reference, timeline)
216
+ hypothesis = self.project(hypothesis, timeline)
217
+
218
+ result = (reference, hypothesis)
219
+ if returns_uem:
220
+ result += (uem,)
221
+
222
+ if returns_timeline:
223
+ result += (timeline,)
224
+
225
+ return result
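
A sketch of `uemify` with all options enabled; the returned tuple grows with the `returns_*` flags (`reference` and `hypothesis` are `Annotation` instances, as above):

>>> from pyannote_audio_utils.metrics.utils import UEMSupportMixin
>>> class WithUEM(UEMSupportMixin):
...     pass
>>> ref, hyp, uem, timeline = WithUEM().uemify(
...     reference, hypothesis, collar=0.5, skip_overlap=True,
...     returns_uem=True, returns_timeline=True)  # doctest: +SKIP
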
ailia-models/code/pyannote_audio_utils/pipeline/__init__.py ADDED
@@ -0,0 +1,37 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2018-2020 CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+
29
+
30
+ # from ._version import get_versions
31
+
32
+ # __version__ = get_versions()["version"]
33
+ # del get_versions
34
+
35
+
36
+ from .pipeline import Pipeline
37
+ # from .optimizer import Optimizer
ailia-models/code/pyannote_audio_utils/pipeline/parameter.py ADDED
@@ -0,0 +1,203 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ # The MIT License (MIT)
5
+
6
+ # Copyright (c) 2018-2020 CNRS
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in
16
+ # all copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ # AUTHORS
27
+ # Hervé BREDIN - http://herve.niderb.fr
28
+ # Hadrien TITEUX - https://github.com/hadware
29
+
30
+
31
+ from typing import Iterable, Any
32
+
33
+ from .pipeline import Pipeline
34
+ from collections.abc import Mapping
35
+
36
+
37
+ class Parameter:
38
+ """Base hyper-parameter"""
39
+
40
+ pass
41
+
42
+
43
+ class Categorical(Parameter):
44
+ """Categorical hyper-parameter
45
+
46
+ The value is sampled from `choices`.
47
+
48
+ Parameters
49
+ ----------
50
+ choices : iterable
51
+ Candidates of hyper-parameter value.
52
+ """
53
+
54
+ def __init__(self, choices: Iterable):
55
+ super().__init__()
56
+ self.choices = list(choices)
57
+
58
+ # def __call__(self, name: str, trial: Trial):
59
+ # return trial.suggest_categorical(name, self.choices)
60
+
61
+
62
+ class DiscreteUniform(Parameter):
63
+ """Discrete uniform hyper-parameter
64
+
65
+ The value is sampled from the range [low, high],
66
+ and the step of discretization is `q`.
67
+
68
+ Parameters
69
+ ----------
70
+ low : `float`
71
+ Lower endpoint of the range of suggested values.
72
+ `low` is included in the range.
73
+ high : `float`
74
+ Upper endpoint of the range of suggested values.
75
+ `high` is included in the range.
76
+ q : `float`
77
+ A step of discretization.
78
+ """
79
+
80
+ def __init__(self, low: float, high: float, q: float):
81
+ super().__init__()
82
+ self.low = float(low)
83
+ self.high = float(high)
84
+ self.q = float(q)
85
+
86
+ # def __call__(self, name: str, trial: Trial):
87
+ # return trial.suggest_discrete_uniform(name, self.low, self.high, self.q)
88
+
89
+
90
+ class Integer(Parameter):
91
+ """Integer hyper-parameter
92
+
93
+ The value is sampled from the integers in [low, high].
94
+
95
+ Parameters
96
+ ----------
97
+ low : `int`
98
+ Lower endpoint of the range of suggested values.
99
+ `low` is included in the range.
100
+ high : `int`
101
+ Upper endpoint of the range of suggested values.
102
+ `high` is included in the range.
103
+ """
104
+
105
+ def __init__(self, low: int, high: int):
106
+ super().__init__()
107
+ self.low = int(low)
108
+ self.high = int(high)
109
+
110
+ # def __call__(self, name: str, trial: Trial):
111
+ # return trial.suggest_int(name, self.low, self.high)
112
+
113
+
114
+ class LogUniform(Parameter):
115
+ """Log-uniform hyper-parameter
116
+
117
+ The value is sampled from the range [low, high) in the log domain.
118
+
119
+ Parameters
120
+ ----------
121
+ low : `float`
122
+ Lower endpoint of the range of suggested values.
123
+ `low` is included in the range.
124
+ high : `float`
125
+ Upper endpoint of the range of suggested values.
126
+ `high` is excluded from the range.
127
+ """
128
+
129
+ def __init__(self, low: float, high: float):
130
+ super().__init__()
131
+ self.low = float(low)
132
+ self.high = float(high)
133
+
134
+ # def __call__(self, name: str, trial: Trial):
135
+ # return trial.suggest_loguniform(name, self.low, self.high)
136
+
137
+
138
+ class Uniform(Parameter):
139
+ """Uniform hyper-parameter
140
+
141
+ The value is sampled from the range [low, high) in the linear domain.
142
+
143
+ Parameters
144
+ ----------
145
+ low : `float`
146
+ Lower endpoint of the range of suggested values.
147
+ `low` is included in the range.
148
+ high : `float`
149
+ Upper endpoint of the range of suggested values.
150
+ `high` is excluded from the range.
151
+ """
152
+
153
+ def __init__(self, low: float, high: float):
154
+ super().__init__()
155
+ self.low = float(low)
156
+ self.high = float(high)
157
+
158
+ # def __call__(self, name: str, trial: Trial):
159
+ # return trial.suggest_uniform(name, self.low, self.high)
160
+
161
+
162
+ class Frozen(Parameter):
163
+ """Frozen hyper-parameter
164
+
165
+ The value is fixed a priori.
166
+
167
+ Parameters
168
+ ----------
169
+ value :
170
+ Fixed value.
171
+ """
172
+
173
+ def __init__(self, value: Any):
174
+ super().__init__()
175
+ self.value = value
176
+
177
+ # def __call__(self, name: str, trial: Trial):
178
+ # return self.value
179
+
180
+
181
+ class ParamDict(Pipeline, Mapping):
182
+ """Dict-like structured hyper-parameter
183
+
184
+ Usage
185
+ -----
186
+ >>> params = ParamDict(param1=Uniform(0.0, 1.0), param2=Uniform(-1.0, 1.0))
187
+ >>> params = ParamDict(**{"param1": Uniform(0.0, 1.0), "param2": Uniform(-1.0, 1.0)})
188
+ """
189
+
190
+ def __init__(self, **params):
191
+ super().__init__()
192
+ self.__params = params
193
+ for param_name, param_value in params.items():
194
+ setattr(self, param_name, param_value)
195
+
196
+ def __len__(self):
197
+ return len(self.__params)
198
+
199
+ def __iter__(self):
200
+ return iter(self.__params)
201
+
202
+ def __getitem__(self, param_name):
203
+ return getattr(self, param_name)
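
A sketch of how these hyper-parameters are meant to be declared on a `Pipeline` subclass (`MyPipeline` is hypothetical; per the `__setattr__` docstring in pipeline.py below, `Parameter` instances are collected into `_parameters`):

>>> from pyannote_audio_utils.pipeline import Pipeline
>>> from pyannote_audio_utils.pipeline.parameter import Uniform, Integer, ParamDict
>>> class MyPipeline(Pipeline):
...     def __init__(self):
...         super().__init__()
...         self.threshold = Uniform(0.0, 1.0)
...         self.num_speakers = Integer(2, 20)
...         self.weights = ParamDict(alpha=Uniform(0.0, 1.0), beta=Uniform(0.0, 1.0))
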
ailia-models/code/pyannote_audio_utils/pipeline/pipeline.py ADDED
@@ -0,0 +1,614 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+# The MIT License (MIT)
+
+# Copyright (c) 2018-2022 CNRS
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# AUTHORS
+# Hervé BREDIN - http://herve.niderb.fr
+
+from typing import Optional, TextIO, Union, Dict, Any
+
+from collections import OrderedDict
+from .typing import PipelineInput
+from .typing import PipelineOutput
+import warnings
+
+
+class Pipeline:
+    """Base tunable pipeline"""
+
+    def __init__(self):
+
+        # un-instantiated parameters (= `Parameter` instances)
+        self._parameters: Dict[str, Parameter] = OrderedDict()
+
+        # instantiated parameters
+        self._instantiated: Dict[str, Any] = OrderedDict()
+
+        # sub-pipelines
+        self._pipelines: Dict[str, Pipeline] = OrderedDict()
+
+        # whether pipeline is currently being optimized
+        self.training = False
+
+    @property
+    def training(self):
+        return self._training
+
+    @training.setter
+    def training(self, training):
+        self._training = training
+        # recursively set sub-pipeline training attribute
+        for _, pipeline in self._pipelines.items():
+            pipeline.training = training
+
+    def __hash__(self):
+        # FIXME -- also keep track of (sub)pipeline attributes
+        frozen = self.parameters(frozen=True)
+        return hash(tuple(sorted(self._flatten(frozen).items())))
+
+    def __getattr__(self, name):
+        """(Advanced) attribute getter"""
+
+        # in case `name` corresponds to an instantiated parameter value, return it
+        if "_instantiated" in self.__dict__:
+            _instantiated = self.__dict__["_instantiated"]
+            if name in _instantiated:
+                return _instantiated[name]
+
+        # in case `name` corresponds to a parameter, return it
+        if "_parameters" in self.__dict__:
+            _parameters = self.__dict__["_parameters"]
+            if name in _parameters:
+                return _parameters[name]
+
+        # in case `name` corresponds to a sub-pipeline, return it
+        if "_pipelines" in self.__dict__:
+            _pipelines = self.__dict__["_pipelines"]
+            if name in _pipelines:
+                return _pipelines[name]
+
+        msg = "'{}' object has no attribute '{}'".format(type(self).__name__, name)
+        raise AttributeError(msg)
+
+    def __setattr__(self, name, value):
+        """(Advanced) attribute setter
+
+        If `value` is an instance of `Parameter`, store it in `_parameters`;
+        elif `value` is an instance of `Pipeline`, store it in `_pipelines`;
+        elif `value` isn't an instance of `Parameter` and `name` is in
+        `_parameters`, store `value` in `_instantiated`.
+        """
+
+        # imported here to avoid circular import
+        from .parameter import Parameter
+
+        def remove_from(*dicts):
+            for d in dicts:
+                if name in d:
+                    del d[name]
+
+        _parameters = self.__dict__.get("_parameters")
+        _instantiated = self.__dict__.get("_instantiated")
+        _pipelines = self.__dict__.get("_pipelines")
+
+        # if `value` is an instance of `Parameter`, store it in `_parameters`
+        if isinstance(value, Parameter):
+            if _parameters is None:
+                msg = "cannot assign hyper-parameters before Pipeline.__init__() call"
+                raise AttributeError(msg)
+            remove_from(self.__dict__, _instantiated, _pipelines)
+            _parameters[name] = value
+            return
+
+        # add/update one sub-pipeline
+        if isinstance(value, Pipeline):
+            if _pipelines is None:
+                msg = "cannot assign sub-pipelines before Pipeline.__init__() call"
+                raise AttributeError(msg)
+            remove_from(self.__dict__, _parameters, _instantiated)
+            _pipelines[name] = value
+            return
+
+        # store instantiated parameter value
+        if _parameters is not None and name in _parameters:
+            _instantiated[name] = value
+            return
+
+        object.__setattr__(self, name, value)
+
+    def __delattr__(self, name):
+
+        if name in self._parameters:
+            del self._parameters[name]
+
+        elif name in self._instantiated:
+            del self._instantiated[name]
+
+        elif name in self._pipelines:
+            del self._pipelines[name]
+
+        else:
+            object.__delattr__(self, name)
+
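
The attribute routing implemented by `__setattr__`/`__getattr__` above is the core mechanism of this class: the same `pipeline.name = value` syntax declares a hyper-parameter, attaches a sub-pipeline, or records an instantiated value, depending on the type of `value`. A short sketch (hypothetical subclasses, not part of this commit):

```python
from pyannote_audio_utils.pipeline.pipeline import Pipeline
from pyannote_audio_utils.pipeline.parameter import Uniform

class Child(Pipeline):       # hypothetical, for illustration only
    pass

class MyPipeline(Pipeline):  # hypothetical, for illustration only
    def __init__(self):
        super().__init__()
        self.threshold = Uniform(0.0, 1.0)  # Parameter instance -> self._parameters
        self.child = Child()                # Pipeline instance  -> self._pipelines

p = MyPipeline()
p.threshold = 0.7    # plain value for a declared parameter -> self._instantiated
print(p.threshold)   # 0.7 (__getattr__ checks _instantiated before _parameters)
```
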
+    def _flattened_parameters(
+        self, frozen: Optional[bool] = False, instantiated: Optional[bool] = False
+    ) -> dict:
+        """Get flattened dictionary of parameters
+
+        Parameters
+        ----------
+        frozen : `bool`, optional
+            Only return values of frozen parameters.
+        instantiated : `bool`, optional
+            Only return values of instantiated parameters.
+
+        Returns
+        -------
+        params : `dict`
+            Flattened dictionary of parameters.
+        """
+
+        # imported here to avoid circular imports
+        from .parameter import Frozen
+
+        if frozen and instantiated:
+            msg = "one must choose between `frozen` and `instantiated`."
+            raise ValueError(msg)
+
+        # initialize dictionary with root parameters
+        if instantiated:
+            params = dict(self._instantiated)
+
+        elif frozen:
+            params = {
+                n: p.value for n, p in self._parameters.items() if isinstance(p, Frozen)
+            }
+
+        else:
+            params = dict(self._parameters)
+
+        # recursively add sub-pipeline parameters
+        for pipeline_name, pipeline in self._pipelines.items():
+            pipeline_params = pipeline._flattened_parameters(
+                frozen=frozen, instantiated=instantiated
+            )
+            for name, value in pipeline_params.items():
+                params[f"{pipeline_name}>{name}"] = value
+
+        return params
+
+    def _flatten(self, nested_params: dict) -> dict:
+        """Convert nested dictionary to flattened dictionary
+
+        For instance, a nested dictionary like this one:
+
+        ~~~~~~~~~~~~~~~~~~~~~
+        param: value1
+        pipeline:
+            param: value2
+            subpipeline:
+                param: value3
+        ~~~~~~~~~~~~~~~~~~~~~
+
+        becomes the following flattened dictionary:
+
+        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+        param : value1
+        pipeline>param : value2
+        pipeline>subpipeline>param : value3
+        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+        Parameters
+        ----------
+        nested_params : `dict`
+
+        Returns
+        -------
+        flattened_params : `dict`
+        """
+        flattened_params = dict()
+        for name, value in nested_params.items():
+            if isinstance(value, dict):
+                for subname, subvalue in self._flatten(value).items():
+                    flattened_params[f"{name}>{subname}"] = subvalue
+            else:
+                flattened_params[name] = value
+        return flattened_params
+
+    def _unflatten(self, flattened_params: dict) -> dict:
+        """Convert flattened dictionary to nested dictionary
+
+        For instance, a flattened dictionary like this one:
+
+        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+        param : value1
+        pipeline>param : value2
+        pipeline>subpipeline>param : value3
+        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+        becomes the following nested dictionary:
+
+        ~~~~~~~~~~~~~~~~~~~~~
+        param: value1
+        pipeline:
+            param: value2
+            subpipeline:
+                param: value3
+        ~~~~~~~~~~~~~~~~~~~~~
+
+        Parameters
+        ----------
+        flattened_params : `dict`
+
+        Returns
+        -------
+        nested_params : `dict`
+        """
+
+        nested_params = {}
+
+        pipeline_params = {name: {} for name in self._pipelines}
+        for name, value in flattened_params.items():
+            # if `name` contains multiple ">"-separated tokens,
+            # it is a sub-pipeline parameter
+            tokens = name.split(">")
+            if len(tokens) > 1:
+                # read sub-pipeline name
+                pipeline_name = tokens[0]
+                # read parameter name
+                param_name = ">".join(tokens[1:])
+                # update sub-pipeline flattened dictionary
+                pipeline_params[pipeline_name][param_name] = value
+
+            # otherwise, it is an actual parameter of this pipeline
+            else:
+                # store it as such
+                nested_params[name] = value
+
+        # recursively unflatten sub-pipeline flattened dictionaries
+        for name, pipeline in self._pipelines.items():
+            nested_params[name] = pipeline._unflatten(pipeline_params[name])
+
+        return nested_params
+
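
To make the `>`-separated naming scheme concrete, here is a quick sketch (not part of the commit). Note that `_flatten` never consults `self`, whereas `_unflatten` relies on `self._pipelines`, so the inverse direction only works on a pipeline that actually declares the matching sub-pipelines:

```python
from pyannote_audio_utils.pipeline.pipeline import Pipeline

nested = {"param": 0.1, "pipeline": {"param": 0.2, "subpipeline": {"param": 0.3}}}

# _flatten never reads self, so a bare Pipeline instance is enough here
flat = Pipeline()._flatten(nested)
assert flat == {
    "param": 0.1,
    "pipeline>param": 0.2,
    "pipeline>subpipeline>param": 0.3,
}
```
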
+    def parameters(
+        self,
+        trial=None,
+        frozen: Optional[bool] = False,
+        instantiated: Optional[bool] = False,
+    ) -> dict:
+        """Returns nested dictionary of (optionally instantiated) parameters.
+
+        For a pipeline with one `param`, one sub-pipeline with its own param
+        and its own sub-pipeline, it returns something like:
+
+        ~~~~~~~~~~~~~~~~~~~~~
+        param: value1
+        pipeline:
+            param: value2
+            subpipeline:
+                param: value3
+        ~~~~~~~~~~~~~~~~~~~~~
+
+        Parameters
+        ----------
+        trial : `Trial`, optional
+            When provided, use trial to suggest new parameter values
+            and return them.
+        frozen : `bool`, optional
+            Return frozen parameter values.
+        instantiated : `bool`, optional
+            Return instantiated parameter values.
+
+        Returns
+        -------
+        params : `dict`
+            Nested dictionary of parameters. See above for the actual format.
+        """
+
+        if (instantiated or frozen) and trial is not None:
+            msg = "One must choose between `trial`, `instantiated`, or `frozen`"
+            raise ValueError(msg)
+
+        # get flattened dictionary of uninstantiated parameters
+        params = self._flattened_parameters(frozen=frozen, instantiated=instantiated)
+
+        if trial is not None:
+            # use provided `trial` to suggest values for parameters
+            params = {name: param(name, trial) for name, param in params.items()}
+
+        # un-flatten flattened dictionary
+        return self._unflatten(params)
+
+    def initialize(self):
+        """Instantiate root pipeline with current set of parameters"""
+        pass
+
+    # def freeze(self, params: dict) -> "Pipeline":
+    #     """Recursively freeze pipeline parameters
+
+    #     Parameters
+    #     ----------
+    #     params : `dict`
+    #         Nested dictionary of parameters.
+
+    #     Returns
+    #     -------
+    #     self : `Pipeline`
+    #         Pipeline.
+    #     """
+
+    #     # imported here to avoid circular imports
+    #     from .parameter import Frozen
+
+    #     for name, value in params.items():
+
+    #         # recursively freeze sub-pipeline parameters
+    #         if name in self._pipelines:
+    #             if not isinstance(value, dict):
+    #                 msg = (
+    #                     f"only parameters of '{name}' pipeline can "
+    #                     f"be frozen (not the whole pipeline)"
+    #                 )
+    #                 raise ValueError(msg)
+    #             self._pipelines[name].freeze(value)
+    #             continue
+
+    #         # freeze parameter value
+    #         if name in self._parameters:
+    #             setattr(self, name, Frozen(value))
+    #             continue
+
+    #         msg = f"parameter '{name}' does not exist"
+    #         raise ValueError(msg)
+
+    #     return self
+
+    def instantiate(self, params: dict) -> "Pipeline":
+        """Recursively instantiate all pipelines
+
+        Parameters
+        ----------
+        params : `dict`
+            Nested dictionary of parameters.
+
+        Returns
+        -------
+        self : `Pipeline`
+            Instantiated pipeline.
+        """
+
+        # imported here to avoid circular imports
+        from .parameter import Frozen
+
+        for name, value in params.items():
+
+            # recursively call `instantiate` with sub-pipelines
+            if name in self._pipelines:
+                if not isinstance(value, dict):
+                    msg = (
+                        f"only parameters of '{name}' pipeline can "
+                        f"be instantiated (not the whole pipeline)"
+                    )
+                    raise ValueError(msg)
+                self._pipelines[name].instantiate(value)
+                continue
+
+            # instantiate parameter value
+            if name in self._parameters:
+                param = getattr(self, name)
+                # frozen parameters keep their frozen value: the provided one is ignored
+                if isinstance(param, Frozen) and param.value != value:
+                    msg = (
+                        f"Parameter '{name}' is frozen: using its frozen value "
+                        f"({param.value}) instead of the one provided ({value})."
+                    )
+                    warnings.warn(msg)
+                    value = param.value
+                setattr(self, name, value)
+                continue
+
+            msg = f"parameter '{name}' does not exist"
+            raise ValueError(msg)
+
+        self.initialize()
+
+        return self
+
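
Putting `parameters()` and `instantiate()` together, here is a minimal end-to-end sketch of the declare-then-instantiate life cycle (hypothetical classes, not part of this commit), including the frozen-parameter warning path:

```python
import warnings

from pyannote_audio_utils.pipeline.pipeline import Pipeline
from pyannote_audio_utils.pipeline.parameter import Frozen, Uniform

class Binarizer(Pipeline):   # hypothetical sub-pipeline
    def __init__(self):
        super().__init__()
        self.onset = Uniform(0.0, 1.0)

class MyPipeline(Pipeline):  # hypothetical root pipeline
    def __init__(self):
        super().__init__()
        self.threshold = Frozen(0.5)
        self.binarize = Binarizer()

p = MyPipeline()
p.instantiate({"threshold": 0.5, "binarize": {"onset": 0.7}})
assert p.instantiated
assert p.binarize.onset == 0.7

# a conflicting value for a Frozen parameter is ignored, with a warning
q = MyPipeline()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    q.instantiate({"threshold": 0.9, "binarize": {"onset": 0.7}})
assert q.threshold == 0.5 and len(caught) == 1  # the frozen value wins
```
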
+    @property
+    def instantiated(self):
+        """Whether pipeline has been instantiated (and therefore can be applied)"""
+        parameters = set(self._flatten(self.parameters()))
+        instantiated = set(self._flatten(self.parameters(instantiated=True)))
+        return parameters == instantiated
+
+    # def dump_params(
+    #     self,
+    #     params_yml: Path,
+    #     params: Optional[dict] = None,
+    #     loss: Optional[float] = None,
+    # ) -> str:
+    #     """Dump parameters to disk
+
+    #     Parameters
+    #     ----------
+    #     params_yml : `Path`
+    #         Path to YAML file.
+    #     params : `dict`, optional
+    #         Nested parameters. Defaults to pipeline current parameters.
+    #     loss : `float`, optional
+    #         Loss value. Defaults to not write loss to file.
+
+    #     Returns
+    #     -------
+    #     content : `str`
+    #         Content written in `params_yml`.
+    #     """
+    #     # use instantiated parameters when `params` is not provided
+    #     if params is None:
+    #         params = self.parameters(instantiated=True)
+
+    #     content = {"params": params}
+    #     if loss is not None:
+    #         content["loss"] = loss
+
+    #     # format as valid YAML
+    #     content_yml = yaml.dump(content, default_flow_style=False)
+
+    #     # (safely) dump YAML content
+    #     with FileLock(params_yml.with_suffix(".lock")):
+    #         with open(params_yml, mode="w") as fp:
+    #             fp.write(content_yml)
+
+    #     return content_yml
+
+    # def load_params(self, params_yml: Path) -> "Pipeline":
+    #     """Instantiate pipeline using parameters from disk
+
+    #     Parameters
+    #     ----------
+    #     params_yml : `Path`
+    #         Path to YAML file.
+
+    #     Returns
+    #     -------
+    #     self : `Pipeline`
+    #         Instantiated pipeline.
+    #     """
+
+    #     with open(params_yml, mode="r") as fp:
+    #         params = yaml.load(fp, Loader=yaml.SafeLoader)
+    #     return self.instantiate(params["params"])
+
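
For reference, the YAML document that the commented-out `dump_params` would produce has a simple two-key layout; a sketch with made-up values:

```python
import yaml

# made-up values, mirroring the {"params": ..., "loss": ...} layout above
print(yaml.dump({"params": {"threshold": 0.5}, "loss": 0.12}, default_flow_style=False))
# loss: 0.12
# params:
#   threshold: 0.5
```
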
+    def __call__(self, input: PipelineInput) -> PipelineOutput:
+        """Apply pipeline on input and return its output"""
+        raise NotImplementedError
+
+    # def get_metric(self) -> "pyannote.metrics.base.BaseMetric":
+    #     """Return new metric (from pyannote.metrics)
+
+    #     When this method is implemented, the returned metric is used as a
+    #     replacement for the loss method below.
+
+    #     Returns
+    #     -------
+    #     metric : `pyannote.metrics.base.BaseMetric`
+    #     """
+    #     raise NotImplementedError()
+
+    # def get_direction(self) -> Direction:
+    #     return "minimize"
+
+    # def loss(self, input: PipelineInput, output: PipelineOutput) -> float:
+    #     """Compute loss for given input/output pair
+
+    #     Parameters
+    #     ----------
+    #     input : object
+    #         Pipeline input.
+    #     output : object
+    #         Pipeline output.
+
+    #     Returns
+    #     -------
+    #     loss : `float`
+    #         Loss value.
+    #     """
+    #     raise NotImplementedError()
+
+    # @property
+    # def write_format(self):
+    #     return "rttm"
+
+    # def write(self, file: TextIO, output: PipelineOutput):
+    #     """Write pipeline output to file
+
+    #     Parameters
+    #     ----------
+    #     file : file object
+    #     output : object
+    #         Pipeline output.
+    #     """
+
+    #     return getattr(self, f"write_{self.write_format}")(file, output)
+
+    # def write_rttm(self, file: TextIO, output: Union[Timeline, Annotation]):
+    #     """Write pipeline output to "rttm" file
+
+    #     Parameters
+    #     ----------
+    #     file : file object
+    #     output : `pyannote.core.Timeline` or `pyannote.core.Annotation`
+    #         Pipeline output.
+    #     """
+
+    #     if isinstance(output, Timeline):
+    #         output = output.to_annotation(generator="string")
+
+    #     if isinstance(output, Annotation):
+    #         for s, t, l in output.itertracks(yield_label=True):
+    #             line = (
+    #                 f"SPEAKER {output.uri} 1 {s.start:.3f} {s.duration:.3f} "
+    #                 f"<NA> <NA> {l} <NA> <NA>\n"
+    #             )
+    #             file.write(line)
+    #         return
+
+    #     msg = (
+    #         f'Dumping {output.__class__.__name__} instances to "rttm" files '
+    #         f"is not supported."
+    #     )
+    #     raise NotImplementedError(msg)
+
+    # def write_txt(self, file: TextIO, output: Union[Timeline, Annotation]):
+    #     """Write pipeline output to "txt" file
+
+    #     Parameters
+    #     ----------
+    #     file : file object
+    #     output : `pyannote.core.Timeline` or `pyannote.core.Annotation`
+    #         Pipeline output.
+    #     """
+
+    #     if isinstance(output, Timeline):
+    #         for s in output:
+    #             line = f"{output.uri} {s.start:.3f} {s.end:.3f}\n"
+    #             file.write(line)
+    #         return
+
+    #     if isinstance(output, Annotation):
+    #         for s, t, l in output.itertracks(yield_label=True):
+    #             line = f"{output.uri} {s.start:.3f} {s.end:.3f} {t} {l}\n"
+    #             file.write(line)
+    #         return
+
+    #     msg = (
+    #         f'Dumping {output.__class__.__name__} instances to "txt" files '
+    #         f"is not supported."
+    #     )
+    #     raise NotImplementedError(msg)
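
Although the writers above are commented out in this vendored copy, the RTTM line layout their f-strings encode is easy to preview (illustrative values only):

```python
# illustrative values only; mirrors the f-string in write_rttm above
uri, start, duration, label = "my_audio", 0.5, 2.3, "speaker_A"
line = f"SPEAKER {uri} 1 {start:.3f} {duration:.3f} <NA> <NA> {label} <NA> <NA>"
print(line)  # SPEAKER my_audio 1 0.500 2.300 <NA> <NA> speaker_A <NA> <NA>
```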