|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
######### |
|
|
Protocols |
|
|
######### |
|
|
|
|
|
""" |
|
|
|
|
|
import collections
import collections.abc
import itertools
import threading
import warnings
from typing import Union, Dict, Iterator, Callable, Any, Text, Optional

from typing import Literal
|
|
|
|
|
|
|
|
|
|
|
# Names of the subsets a protocol may define.
Subset = Literal["train", "development", "test"]

# Maps each subset name to the short prefix used by legacy `{prefix}_iter`
# methods (e.g. `trn_iter`), which `Protocol.subset_helper` falls back to.
LEGACY_SUBSET_MAPPING = {"train": "trn", "development": "dev", "test": "tst"}

Scope = Literal["file", "database", "global"]

# A preprocessor takes a protocol file and returns the value of one key.
Preprocessor = Callable[["ProtocolFile"], Any]

# Preprocessors indexed by the key they compute.
Preprocessors = Dict[Text, Preprocessor]
|
|
|
|
|
|
|
|
class ProtocolFile(collections.abc.MutableMapping):
    """Protocol file with lazy preprocessors

    This is a dict-like data structure where some values may depend on other
    values, and are only computed if/when requested. Once computed, they are
    cached and never recomputed again.

    Parameters
    ----------
    precomputed : dict or ProtocolFile
        Regular dictionary with precomputed values. When a `ProtocolFile` is
        provided, its already computed values are copied and its pending lazy
        keys are inherited.
    lazy : dict, optional
        Dictionary describing how lazy value needs to be computed.
        Values are callable expecting a dictionary as input and returning the
        computed value.

    """

    def __init__(
        self,
        precomputed: Union[Dict, "ProtocolFile"],
        lazy: Optional[Dict] = None,
    ):
        if lazy is None:
            lazy = dict()

        if isinstance(precomputed, ProtocolFile):
            # start from a snapshot of the values `precomputed` already holds
            self._store: Dict = abs(precomputed)

            # keys that are lazy in both `precomputed` and `lazy`: evaluate
            # the inherited definition now so that its value lands in the
            # store before the new definition takes over the key.
            for key in set(precomputed.lazy) & set(lazy):
                self._store[key] = precomputed[key]

            # remaining inherited lazy keys, overridden by the new ones
            compound_lazy = dict(precomputed.lazy)
            compound_lazy.update(lazy)
            self.lazy: Dict = compound_lazy

        else:
            # copy both mappings so that later mutations of this file do not
            # leak into the dictionaries provided by the caller
            self._store = dict(precomputed)
            self.lazy = dict(lazy)

        # re-entrant lock: a lazy callable reads other keys of the same file
        # (from the same thread) while its own key is being evaluated
        self.lock_ = threading.RLock()

        # number of evaluations currently in progress for each lazy key.
        # used by __getitem__ to break self-referencing lazy keys.
        self.evaluating_ = collections.Counter()

    def __getstate__(self):
        """Return instance state without the (non-picklable) lock"""
        d = dict(self.__dict__)
        del d["lock_"]
        return d

    def __setstate__(self, d):
        """Restore instance state and create a fresh lock"""
        self.__dict__.update(d)
        self.lock_ = threading.RLock()

    def __abs__(self):
        """Return a copy of already computed values (lazy keys are left out)"""
        with self.lock_:
            return dict(self._store)

    def __getitem__(self, key):
        """Return value for `key`, computing (and caching) it when lazy

        Raises
        ------
        KeyError
            When `key` is neither precomputed nor lazy, or when a lazy
            callable reads its own key and no precomputed value exists.
        """
        with self.lock_:

            # the second condition is False when this very key is already
            # being evaluated (i.e. the lazy callable reads its own key):
            # fall through to the stored value instead of recursing forever.
            if key in self.lazy and self.evaluating_[key] == 0:

                self.evaluating_[key] += 1
                try:
                    value = self.lazy[key](self)
                finally:
                    # always release the marker: if the callable raised, the
                    # key must remain computable on a later access instead of
                    # being stuck in "evaluating" state forever.
                    self.evaluating_[key] -= 1

                # `pop` (not `del`) because the callable may have assigned
                # the key itself, which already removed it from `lazy`.
                self.lazy.pop(key, None)

                if key in self._store and value != self._store[key]:
                    msg = 'Existing precomputed key "{key}" has been modified by a preprocessor.'
                    warnings.warn(msg.format(key=key))

                # cache the computed value
                self._store[key] = value

            return self._store[key]

    def __setitem__(self, key, value):
        """Set `key` to `value`, discarding any pending lazy definition"""
        with self.lock_:
            if key in self.lazy:
                del self.lazy[key]
            self._store[key] = value

    def __delitem__(self, key):
        """Remove `key`, whether it is already computed, still lazy, or both

        Raises
        ------
        KeyError
            When `key` is neither precomputed nor lazy.
        """
        with self.lock_:
            was_lazy = key in self.lazy
            if was_lazy:
                del self.lazy[key]

            if key in self._store:
                del self._store[key]
            elif not was_lazy:
                # a key that only existed as a lazy definition has been
                # removed successfully above and must not raise.
                raise KeyError(key)

    def __iter__(self):
        """Iterate over computed keys first, then over pending lazy keys"""
        # snapshot keys under the lock, but do not hold it while yielding
        with self.lock_:
            store_keys = list(self._store)
        for key in store_keys:
            yield key

        with self.lock_:
            lazy_keys = list(self.lazy)
        for key in lazy_keys:
            # keys that are both stored and lazy were yielded above
            if key in self._store:
                continue
            yield key

    def __len__(self):
        """Return number of distinct (computed or lazy) keys"""
        with self.lock_:
            return len(set(self._store) | set(self.lazy))

    def files(self) -> Iterator["ProtocolFile"]:
        """Iterate over all files

        When `current_file` refers to only one file,
        yield it and return.
        When `current_file` refers to a list of file (i.e. 'uri' is a list),
        yield each file separately.

        Raises
        ------
        ValueError
            When a list-valued key does not have as many items as 'uri'.

        Examples
        --------
        >>> current_file = ProtocolFile({
        ...     'uri': 'my_uri',
        ...     'database': 'my_database'})
        >>> for file in current_file.files():
        ...     print(file['uri'], file['database'])
        my_uri my_database

        >>> current_file = ProtocolFile({
        ...     'uri': ['my_uri1', 'my_uri2', 'my_uri3'],
        ...     'database': 'my_database'})
        >>> for file in current_file.files():
        ...     print(file['uri'], file['database'])
        my_uri1 my_database
        my_uri2 my_database
        my_uri3 my_database

        """

        uris = self["uri"]
        if not isinstance(uris, list):
            yield self
            return

        n_uris = len(uris)

        # build one iterable per key: list values are consumed item by item,
        # any other value is repeated for every file. only already computed
        # values are split; pending lazy keys are handed over as is below.
        precomputed = {"uri": uris}
        for key, value in abs(self).items():

            if key == "uri":
                continue

            if not isinstance(value, list):
                precomputed[key] = itertools.repeat(value)

            else:
                if len(value) != n_uris:
                    msg = (
                        f'Mismatch between number of "uris" ({n_uris}) '
                        f'and number of "{key}" ({len(value)}).'
                    )
                    raise ValueError(msg)
                precomputed[key] = value

        keys = list(precomputed.keys())
        # zip stops at the end of 'uri' (the `repeat` iterables are infinite)
        for values in zip(*precomputed.values()):
            precomputed_one = dict(zip(keys, values))
            yield ProtocolFile(precomputed_one, self.lazy)
|
|
|
|
|
|
|
|
class Protocol:
    """Experimental protocol

    An experimental protocol usually defines three subsets: a training subset,
    a development subset, and a test subset.

    An experimental protocol can be defined programmatically by creating a
    class that inherits from Protocol and implements at least
    one of `train_iter`, `development_iter` and `test_iter` methods:

        >>> class MyProtocol(Protocol):
        ...     def train_iter(self) -> Iterator[Dict]:
        ...         yield {"uri": "filename1", "any_other_key": "..."}
        ...         yield {"uri": "filename2", "any_other_key": "..."}

    `{subset}_iter` should return an iterator of dictionaries with
        - "uri" key (mandatory) that provides a unique file identifier (usually
          the filename),
        - any other key that the protocol may provide.

    It can then be used in Python like this:

        >>> protocol = MyProtocol()
        >>> for file in protocol.train():
        ...    print(file["uri"])
        filename1
        filename2

    An experimental protocol can also be defined using `pyannote_audio_utils.database`
    configuration file, whose (configurable) path defaults to "~/database.yml".

    ~~~ Content of ~/database.yml ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    Protocols:
      MyDatabase:
        Protocol:
          MyProtocol:
            train:
                uri: /path/to/collection.lst
                any_other_key: ... # see custom loader documentation
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    where "/path/to/collection.lst" contains the list of identifiers of the
    files in the collection:

    ~~~ Content of "/path/to/collection.lst" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    filename1
    filename2
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    It can then be used in Python like this:

        >>> from pyannote_audio_utils.database import registry
        >>> protocol = registry.get_protocol('MyDatabase.Protocol.MyProtocol')
        >>> for file in protocol.train():
        ...    print(file["uri"])
        filename1
        filename2

    This class is usually inherited from, but can be used directly.

    Parameters
    ----------
    preprocessors : dict
        Preprocess protocol files so that `file[key] = preprocessors[key](file)`
        for each key in `preprocessors`. In case `preprocessors[key]` is not
        callable, it should be a string containing placeholders for `file` keys
        (e.g. {'audio': '/path/to/{uri}.wav'})
    """

    def __init__(self, preprocessors: Optional["Preprocessors"] = None):
        super().__init__()

        if preprocessors is None:
            preprocessors = dict()

        self.preprocessors = dict()
        for key, preprocessor in preprocessors.items():

            if callable(preprocessor):
                self.preprocessors[key] = preprocessor

            elif isinstance(preprocessor, str):
                # built by a factory so that each function is bound to its
                # own template (a closure defined directly in this loop would
                # see the template of the last iteration only)
                self.preprocessors[key] = self._template_preprocessor(
                    str(preprocessor)
                )

            else:
                msg = f'"{key}" preprocessor is neither a callable nor a string.'
                raise ValueError(msg)

    @staticmethod
    def _template_preprocessor(template: str) -> Callable[[Any], str]:
        """Build a preprocessor that fills `template` with values of a file"""

        def func(current_file):
            # `format_map` only looks up the keys used by the template.
            # `format(**current_file)` would expand every key of the file,
            # i.e. force the evaluation of all its lazy keys, including the
            # one being computed right now.
            return template.format_map(current_file)

        return func

    def preprocess(
        self, current_file: Union[Dict, "ProtocolFile"]
    ) -> "ProtocolFile":
        """Wrap `current_file` into a ProtocolFile with lazy preprocessors"""
        return ProtocolFile(current_file, lazy=self.preprocessors)

    def __str__(self):
        doc = self.__doc__
        if doc is None:
            # a subclass defined without a docstring has `__doc__` set to
            # None, and returning None from __str__ raises TypeError
            return super().__str__()
        return doc

    def train_iter(self) -> Iterator[Union[Dict, "ProtocolFile"]]:
        """Iterate over files in the training subset"""
        raise NotImplementedError()

    def development_iter(self) -> Iterator[Union[Dict, "ProtocolFile"]]:
        """Iterate over files in the development subset"""
        raise NotImplementedError()

    def test_iter(self) -> Iterator[Union[Dict, "ProtocolFile"]]:
        """Iterate over files in the test subset"""
        raise NotImplementedError()

    def subset_helper(self, subset: "Subset") -> Iterator["ProtocolFile"]:
        """Iterate over preprocessed files of `subset`

        Uses `{subset}_iter` when available and falls back to its legacy
        name (see LEGACY_SUBSET_MAPPING) otherwise. As this is a generator,
        errors are only raised once iteration starts.

        Raises
        ------
        NotImplementedError
            When the protocol implements neither of those methods.
        """

        try:
            files = getattr(self, f"{subset}_iter")()
        except (AttributeError, NotImplementedError):
            # fall back to the legacy method name (e.g. `trn_iter`)
            subset_legacy = LEGACY_SUBSET_MAPPING[subset]
            try:
                files = getattr(self, f"{subset_legacy}_iter")()
            except AttributeError:
                msg = f"Protocol does not implement a {subset} subset."
                raise NotImplementedError(msg)

        for file in files:
            yield self.preprocess(file)

    def train(self) -> Iterator["ProtocolFile"]:
        """Iterate over preprocessed files of the training subset"""
        return self.subset_helper("train")

    def development(self) -> Iterator["ProtocolFile"]:
        """Iterate over preprocessed files of the development subset"""
        return self.subset_helper("development")

    def test(self) -> Iterator["ProtocolFile"]:
        """Iterate over preprocessed files of the test subset"""
        return self.subset_helper("test")

    def _iter_if_implemented(self, method: str) -> Iterator[Any]:
        """Iterate over `self.{method}()`, stopping silently when it is not implemented"""
        try:
            for file in getattr(self, method)():
                yield file
        except (AttributeError, NotImplementedError):
            # best effort: a subset that is not implemented (or that fails
            # with one of those errors while iterating) ends here
            return

    def files(self) -> Iterator["ProtocolFile"]:
        """Iterate over all files in `protocol`

        Each file is yielded only once, even when it belongs to several
        subsets: files are deduplicated on the identifier returned by
        `get_unique_identifier`.
        """

        # imported here rather than at module level
        # NOTE(review): presumably to avoid a circular import -- confirm
        from pyannote_audio_utils.database.util import get_unique_identifier

        yielded_uris = set()

        for method in [
            "development",
            "development_enrolment",
            "development_trial",
            "test",
            "test_enrolment",
            "test_trial",
            "train",
            "train_enrolment",
            "train_trial",
        ]:

            if not hasattr(self, method):
                continue

            for current_file in self._iter_if_implemented(method):

                # skip entries that do not describe a file
                if "uri" not in current_file:
                    continue

                # an entry may refer to several files ('uri' is a list)
                for current_file_ in current_file.files():

                    uri = get_unique_identifier(current_file_)
                    if uri in yielded_uris:
                        continue

                    yield current_file_

                    yielded_uris.add(uri)
|
|
|