niobures's picture
Pyannote (models, models_onnx)
8c838e7 verified
#!/usr/bin/env python
# encoding: utf-8
# The MIT License (MIT)
# Copyright (c) 2016-2020 CNRS
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# AUTHORS
# Hervé BREDIN - http://herve.niderb.fr
"""
#########
Protocols
#########
"""
import warnings
import collections
import threading
import itertools
from typing import Union, Dict, Iterator, Callable, Any, Text, Optional
# try:
from typing import Literal
# except ImportError:
# from typing_extensions import Literal
# Names of the three subsets a protocol may define.
Subset = Literal["train", "development", "test"]
# Maps each subset name to the abbreviated prefix of the legacy `{prefix}_iter`
# method names (`trn_iter`, `dev_iter`, `tst_iter`).
LEGACY_SUBSET_MAPPING = {"train": "trn", "development": "dev", "test": "tst"}
# NOTE(review): `Scope` is not used in the visible part of this module --
# presumably consumed elsewhere; confirm before relying on its meaning.
Scope = Literal["file", "database", "global"]
# A preprocessor is a callable taking a ProtocolFile and returning the value
# to be stored under one key of that file.
Preprocessor = Callable[["ProtocolFile"], Any]
# Mapping from key name to the preprocessor computing that key.
Preprocessors = Dict[Text, Preprocessor]
class ProtocolFile(collections.abc.MutableMapping):
    """Protocol file with lazy preprocessors

    This is a dict-like data structure where some values may depend on other
    values, and are only computed if/when requested. Once computed, they are
    cached and never recomputed again.

    Parameters
    ----------
    precomputed : dict or ProtocolFile
        Regular dictionary with precomputed values. When a ProtocolFile is
        provided, its precomputed values and its (not yet evaluated) lazy
        keys are both inherited.
    lazy : dict, optional
        Dictionary describing how lazy value needs to be computed.
        Values are callable expecting a dictionary as input and returning the
        computed value.
    """

    def __init__(
        self, precomputed: Union[Dict, "ProtocolFile"], lazy: Optional[Dict] = None
    ):
        if lazy is None:
            lazy = dict()

        if isinstance(precomputed, ProtocolFile):
            # when 'precomputed' is a ProtocolFile, it may already contain lazy keys.

            # we use 'precomputed' precomputed keys as precomputed keys
            self._store: Dict = abs(precomputed)

            # we handle the corner case where the intersection of 'precomputed' lazy keys
            # and 'lazy' keys is not empty. this is currently achieved by "unlazying" the
            # 'precomputed' one (which is probably not the most efficient solution).
            for key in set(precomputed.lazy) & set(lazy):
                self._store[key] = precomputed[key]

            # we use the union of 'precomputed' lazy keys and provided 'lazy' keys as lazy keys
            compound_lazy = dict(precomputed.lazy)
            compound_lazy.update(lazy)
            self.lazy: Dict = compound_lazy

        else:
            # when 'precomputed' is a Dict, we use it directly as precomputed keys
            # and 'lazy' as lazy keys.
            self._store = dict(precomputed)
            self.lazy = dict(lazy)

        # re-entrant lock used below to make ProtocolFile thread-safe
        self.lock_ = threading.RLock()

        # this is needed to avoid infinite recursion
        # when a key is both in precomputed and lazy.
        # keys with evaluating_ > 0 are currently being evaluated
        # and therefore should be taken from precomputed
        self.evaluating_ = collections.Counter()

    # since RLock is not pickable, remove it before pickling...
    def __getstate__(self):
        d = dict(self.__dict__)
        del d["lock_"]
        return d

    # ... and add it back when unpickling
    def __setstate__(self, d):
        self.__dict__.update(d)
        self.lock_ = threading.RLock()

    def __abs__(self):
        """Return a (shallow) copy of the precomputed values only."""
        with self.lock_:
            return dict(self._store)

    def __getitem__(self, key):
        """Return the value of `key`, computing (and caching) it if it is lazy.

        Raises
        ------
        KeyError
            If `key` is neither precomputed nor lazy, or if a lazy key is
            requested by its own preprocessor and has no precomputed value.
        """
        with self.lock_:
            if key in self.lazy and self.evaluating_[key] == 0:
                modified = False

                # mark lazy key as being evaluated
                self.evaluating_.update([key])
                try:
                    value = self.lazy[key](self)

                    # remember whether a precomputed value is being replaced
                    modified = key in self._store and value != self._store[key]

                    # apply preprocessor only once: remove it and store its
                    # output so that it is available for future access.
                    # (pop() rather than del: the preprocessor itself may have
                    # assigned the key, which already removes the lazy entry)
                    self.lazy.pop(key, None)
                    self._store[key] = value
                finally:
                    # lazy evaluation is finished for key, whether it succeeded
                    # or not: a failing preprocessor must not leave the key
                    # marked as "being evaluated" forever, otherwise it could
                    # never be retried.
                    self.evaluating_.subtract([key])

                # warn the user when a precomputed key is modified
                if modified:
                    msg = 'Existing precomputed key "{key}" has been modified by a preprocessor.'
                    warnings.warn(msg.format(key=key))

            return self._store[key]

    def __setitem__(self, key, value):
        """Store `value` as a precomputed value, discarding any lazy definition."""
        with self.lock_:
            if key in self.lazy:
                del self.lazy[key]

            self._store[key] = value

    def __delitem__(self, key):
        """Remove `key`, be it precomputed, lazy, or both.

        Raises
        ------
        KeyError
            If `key` is neither precomputed nor lazy.
        """
        with self.lock_:
            was_lazy = key in self.lazy
            if was_lazy:
                del self.lazy[key]

            if key in self._store:
                del self._store[key]
            elif not was_lazy:
                raise KeyError(key)

    def __iter__(self):
        """Iterate over precomputed keys first, then over lazy-only keys."""
        # keys are snapshotted up front so that (1) the lock is not held while
        # the consumer runs, and (2) lazy keys that get evaluated as a side
        # effect of the iteration are still yielded exactly once.
        with self.lock_:
            keys = list(self._store)
            keys.extend(key for key in self.lazy if key not in self._store)
        return iter(keys)

    def __len__(self):
        with self.lock_:
            return len(set(self._store) | set(self.lazy))

    def files(self) -> Iterator["ProtocolFile"]:
        """Iterate over all files

        When `current_file` refers to only one file,
        yield it and return.

        When `current_file` refers to a list of file (i.e. 'uri' is a list),
        yield each file separately.

        Raises
        ------
        ValueError
            If a precomputed list-valued key does not have exactly one item
            per uri.

        Examples
        --------
        >>> current_file = ProtocolFile({
        ...     'uri': 'my_uri',
        ...     'database': 'my_database'})
        >>> for file in current_file.files():
        ...     print(file['uri'], file['database'])
        my_uri my_database

        >>> current_file = ProtocolFile({
        ...     'uri': ['my_uri1', 'my_uri2', 'my_uri3'],
        ...     'database': 'my_database'})
        >>> for file in current_file.files():
        ...     print(file['uri'], file['database'])
        my_uri1 my_database
        my_uri2 my_database
        my_uri3 my_database
        """

        uris = self["uri"]
        if not isinstance(uris, list):
            yield self
            return

        n_uris = len(uris)

        # iterate over precomputed keys and make sure each of them provides
        # exactly one value per uri: scalar values are repeated for every
        # uri, list values must have the same length as 'uri'.
        precomputed = {"uri": uris}
        for key, value in abs(self).items():
            if key == "uri":
                continue

            if not isinstance(value, list):
                precomputed[key] = itertools.repeat(value)
                continue

            if len(value) != n_uris:
                msg = (
                    f'Mismatch between number of "uris" ({n_uris}) '
                    f'and number of "{key}" ({len(value)}).'
                )
                raise ValueError(msg)
            precomputed[key] = value

        keys = list(precomputed.keys())
        # zip() stops at the shortest input, i.e. the (finite) list of uris
        for values in zip(*precomputed.values()):
            precomputed_one = dict(zip(keys, values))
            # lazy keys are shared: they get evaluated per file, on demand
            yield ProtocolFile(precomputed_one, self.lazy)
class Protocol:
    """Experimental protocol

    An experimental protocol usually defines three subsets: a training subset,
    a development subset, and a test subset.

    An experimental protocol can be defined programmatically by creating a
    class that inherits from Protocol and implements at least
    one of `train_iter`, `development_iter` and `test_iter` methods:

        >>> class MyProtocol(Protocol):
        ...     def train_iter(self) -> Iterator[Dict]:
        ...         yield {"uri": "filename1", "any_other_key": "..."}
        ...         yield {"uri": "filename2", "any_other_key": "..."}

    `{subset}_iter` should return an iterator of dictionaries with
        - "uri" key (mandatory) that provides a unique file identifier (usually
          the filename),
        - any other key that the protocol may provide.

    It can then be used in Python like this:

        >>> protocol = MyProtocol()
        >>> for file in protocol.train():
        ...     print(file["uri"])
        filename1
        filename2

    An experimental protocol can also be defined using `pyannote_audio_utils.database`
    configuration file, whose (configurable) path defaults to "~/database.yml".

    ~~~ Content of ~/database.yml ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    Protocols:
      MyDatabase:
        Protocol:
          MyProtocol:
            train:
                uri: /path/to/collection.lst
                any_other_key: ... # see custom loader documentation
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    where "/path/to/collection.lst" contains the list of identifiers of the
    files in the collection:

    ~~~ Content of "/path/to/collection.lst" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    filename1
    filename2
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    It can then be used in Python like this:

        >>> from pyannote_audio_utils.database import registry
        >>> protocol = registry.get_protocol('MyDatabase.Protocol.MyProtocol')
        >>> for file in protocol.train():
        ...     print(file["uri"])
        filename1
        filename2

    This class is usually inherited from, but can be used directly.

    Parameters
    ----------
    preprocessors : dict
        Preprocess protocol files so that `file[key] = preprocessors[key](file)`
        for each key in `preprocessors`. In case `preprocessors[key]` is not
        callable, it should be a string containing placeholders for `file` keys
        (e.g. {'audio': '/path/to/{uri}.wav'})

    Raises
    ------
    ValueError
        If a preprocessor is neither a callable nor a string.
    """

    def __init__(self, preprocessors: Optional[Preprocessors] = None):
        super().__init__()

        if preprocessors is None:
            preprocessors = dict()

        self.preprocessors = dict()
        for key, preprocessor in preprocessors.items():
            if callable(preprocessor):
                self.preprocessors[key] = preprocessor

            # when `preprocessor` is not callable, it should be a string
            # containing placeholder for item key (e.g. '/path/to/{uri}.wav')
            elif isinstance(preprocessor, str):
                self.preprocessors[key] = self._template_preprocessor(preprocessor)

            else:
                msg = f'"{key}" preprocessor is neither a callable nor a string.'
                raise ValueError(msg)

    @staticmethod
    def _template_preprocessor(template: str) -> Callable[["ProtocolFile"], Any]:
        """Build a preprocessor that fills `template` with values of the file."""

        # `template` is bound per call of this factory: defining the closure
        # directly inside the loop of `__init__` would make every template
        # preprocessor share the loop variable, hence all of them would render
        # the *last* template.
        def func(current_file):
            # format_map() only looks up the placeholders that are actually
            # used, instead of unpacking (and therefore evaluating) every key
            # of the file like format(**current_file) does.
            return template.format_map(current_file)

        return func

    def preprocess(self, current_file: Union[Dict, ProtocolFile]) -> ProtocolFile:
        """Wrap `current_file` into a ProtocolFile with lazy preprocessors"""
        return ProtocolFile(current_file, lazy=self.preprocessors)

    def __str__(self):
        doc = self.__doc__
        if doc is None:
            # `__doc__` is not inherited: a subclass without a docstring would
            # make `str()` fail with "__str__ returned non-string".
            return super().__str__()
        return doc

    def train_iter(self) -> Iterator[Union[Dict, ProtocolFile]]:
        """Iterate over files in the training subset"""
        raise NotImplementedError()

    def development_iter(self) -> Iterator[Union[Dict, ProtocolFile]]:
        """Iterate over files in the development subset"""
        raise NotImplementedError()

    def test_iter(self) -> Iterator[Union[Dict, ProtocolFile]]:
        """Iterate over files in the test subset"""
        raise NotImplementedError()

    def subset_helper(self, subset: Subset) -> Iterator[ProtocolFile]:
        """Iterate over preprocessed files of `subset`

        Raises
        ------
        NotImplementedError
            If the protocol implements neither `{subset}_iter` nor its legacy
            counterpart.
        """
        try:
            files = getattr(self, f"{subset}_iter")()
        except (AttributeError, NotImplementedError):
            # previous pyannote_audio_utils.database versions used `trn_iter` instead of
            # `train_iter`, `dev_iter` instead of `development_iter`, and
            # `tst_iter` instead of `test_iter`. therefore, we use the legacy
            # version when it is available (and the new one is not).
            subset_legacy = LEGACY_SUBSET_MAPPING[subset]
            try:
                files = getattr(self, f"{subset_legacy}_iter")()
            except AttributeError:
                msg = f"Protocol does not implement a {subset} subset."
                raise NotImplementedError(msg) from None

        for file in files:
            yield self.preprocess(file)

    def train(self) -> Iterator[ProtocolFile]:
        """Iterate over preprocessed files of the training subset"""
        return self.subset_helper("train")

    def development(self) -> Iterator[ProtocolFile]:
        """Iterate over preprocessed files of the development subset"""
        return self.subset_helper("development")

    def test(self) -> Iterator[ProtocolFile]:
        """Iterate over preprocessed files of the test subset"""
        return self.subset_helper("test")

    def files(self) -> Iterator[ProtocolFile]:
        """Iterate over all files in `protocol`

        Every available subset method is visited, and each file is yielded
        only once (based on its unique identifier).
        """

        # imported here to avoid circular imports
        from pyannote_audio_utils.database.util import get_unique_identifier

        yielded_uris = set()

        for method in (
            "development",
            "development_enrolment",
            "development_trial",
            "test",
            "test_enrolment",
            "test_trial",
            "train",
            "train_enrolment",
            "train_trial",
        ):
            if not hasattr(self, method):
                continue

            # subsets that exist as methods but are not actually implemented
            # are silently skipped. this generator is consumed right away,
            # within the same iteration of the enclosing loop.
            def iterate():
                try:
                    for file in getattr(self, method)():
                        yield file
                except (AttributeError, NotImplementedError):
                    return

            for current_file in iterate():
                # skip "files" that do not contain a "uri" entry.
                # this happens for speaker verification trials that contain
                # two nested files "file1" and "file2"
                # see https://github.com/pyannote_audio_utils/pyannote_audio_utils-db-voxceleb/issues/4
                if "uri" not in current_file:
                    continue

                for current_file_ in current_file.files():
                    # corner case when the same file is yielded several times
                    uri = get_unique_identifier(current_file_)
                    if uri in yielded_uris:
                        continue

                    yield current_file_

                    yielded_uris.add(uri)