#!/usr/bin/env python
# encoding: utf-8

# The MIT License (MIT)

# Copyright (c) 2016-2020 CNRS

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# AUTHORS
# Hervé BREDIN - http://herve.niderb.fr

"""
#########
Protocols
#########

"""

import warnings
import collections
import collections.abc
import threading
import itertools
from typing import Union, Dict, Iterator, Callable, Any, Text, Optional

# try:
from typing import Literal
# except ImportError:
#     from typing_extensions import Literal

Subset = Literal["train", "development", "test"]
LEGACY_SUBSET_MAPPING = {"train": "trn", "development": "dev", "test": "tst"}
Scope = Literal["file", "database", "global"]

Preprocessor = Callable[["ProtocolFile"], Any]
Preprocessors = Dict[Text, Preprocessor]


class ProtocolFile(collections.abc.MutableMapping):
    """Protocol file with lazy preprocessors

    This is a dict-like data structure where some values may depend on other
    values, and are only computed if/when requested. Once computed, they are
    cached and never recomputed again.

    Parameters
    ----------
    precomputed : dict or ProtocolFile
        Regular dictionary with precomputed values. When a `ProtocolFile` is
        provided, its precomputed values are copied and its lazy keys are
        merged with `lazy` (`lazy` takes precedence).
    lazy : dict, optional
        Dictionary describing how lazy value needs to be computed.
        Values are callable expecting a dictionary as input and returning the
        computed value.
    """

    def __init__(
        self,
        precomputed: Union[Dict, "ProtocolFile"],
        lazy: Optional[Dict] = None,
    ):
        if lazy is None:
            lazy = dict()

        if isinstance(precomputed, ProtocolFile):
            # when 'precomputed' is a ProtocolFile, it may already contain
            # lazy keys.

            # we use 'precomputed' precomputed keys as precomputed keys
            self._store: Dict = abs(precomputed)

            # we handle the corner case where the intersection of
            # 'precomputed' lazy keys and 'lazy' keys is not empty. this is
            # currently achieved by "unlazying" the 'precomputed' one (which
            # is probably not the most efficient solution).
            for key in set(precomputed.lazy) & set(lazy):
                self._store[key] = precomputed[key]

            # we use the union of 'precomputed' lazy keys and provided 'lazy'
            # keys as lazy keys
            compound_lazy = dict(precomputed.lazy)
            compound_lazy.update(lazy)
            self.lazy: Dict = compound_lazy

        else:
            # when 'precomputed' is a Dict, we use it directly as precomputed
            # keys and 'lazy' as lazy keys (both are shallow-copied so that
            # this instance never mutates the caller's dictionaries).
            self._store = dict(precomputed)
            self.lazy = dict(lazy)

        # re-entrant lock used below to make ProtocolFile thread-safe
        self.lock_ = threading.RLock()

        # this is needed to avoid infinite recursion
        # when a key is both in precomputed and lazy.
        # keys with evaluating_ > 0 are currently being evaluated
        # and therefore should be taken from precomputed
        self.evaluating_ = collections.Counter()

    # since RLock is not picklable, remove it before pickling...
    def __getstate__(self):
        d = dict(self.__dict__)
        del d["lock_"]
        return d

    # ... and add it back when unpickling
    def __setstate__(self, d):
        self.__dict__.update(d)
        self.lock_ = threading.RLock()

    def __abs__(self):
        """Return a shallow copy of the precomputed (non-lazy) values"""
        with self.lock_:
            return dict(self._store)

    def __getitem__(self, key):
        """Return value for `key`, evaluating its preprocessor if needed

        Raises
        ------
        KeyError
            When `key` is neither precomputed nor lazy.
        """
        with self.lock_:
            if key in self.lazy and self.evaluating_[key] == 0:
                # mark lazy key as being evaluated
                self.evaluating_[key] += 1
                try:
                    # apply preprocessor once...
                    value = self.lazy[key](self)
                finally:
                    # always clear the marker, even when the preprocessor
                    # raises: otherwise the key would stay flagged as "being
                    # evaluated" forever and could never be computed again.
                    self.evaluating_[key] -= 1

                # ... and remove it (`pop` tolerates a preprocessor that
                # already replaced the key through `__setitem__`)
                self.lazy.pop(key, None)

                # warn the user when a precomputed key is modified
                if key in self._store and value != self._store[key]:
                    msg = 'Existing precomputed key "{key}" has been modified by a preprocessor.'
                    warnings.warn(msg.format(key=key))

                # store the output of the lazy computation
                # so that it is available for future access
                self._store[key] = value

            return self._store[key]

    def __setitem__(self, key, value):
        with self.lock_:
            # an explicit value supersedes any pending lazy computation
            if key in self.lazy:
                del self.lazy[key]

            self._store[key] = value

    def __delitem__(self, key):
        """Remove `key`, whether it is precomputed, lazy, or both

        Raises
        ------
        KeyError
            When `key` is neither precomputed nor lazy.
        """
        with self.lock_:
            if key not in self.lazy and key not in self._store:
                raise KeyError(key)

            # a key may live in only one of the two dictionaries (e.g. a lazy
            # key that has not been evaluated yet), so neither removal is
            # allowed to fail on its own.
            self.lazy.pop(key, None)
            self._store.pop(key, None)

    def __iter__(self):
        # snapshot keys under the lock, but yield outside of it so that the
        # lock is never held while the consumer runs arbitrary code.
        with self.lock_:
            store_keys = list(self._store)

        for key in store_keys:
            yield key

        with self.lock_:
            lazy_keys = list(self.lazy)

        for key in lazy_keys:
            # keys that are both precomputed and lazy were already yielded
            if key in self._store:
                continue
            yield key

    def __len__(self):
        with self.lock_:
            return len(set(self._store) | set(self.lazy))

    def files(self) -> Iterator["ProtocolFile"]:
        """Iterate over all files

        When `current_file` refers to only one file,
            yield it and return.
        When `current_file` refers to a list of file (i.e. 'uri' is a list),
            yield each file separately.

        Raises
        ------
        ValueError
            When a precomputed list value does not have the same length as
            the list of uris.

        Examples
        --------
        >>> current_file = ProtocolFile({
        ...     'uri': 'my_uri',
        ...     'database': 'my_database'})
        >>> for file in current_file.files():
        ...     print(file['uri'], file['database'])
        my_uri my_database

        >>> current_file = ProtocolFile({
        ...     'uri': ['my_uri1', 'my_uri2', 'my_uri3'],
        ...     'database': 'my_database'})
        >>> for file in current_file.files():
        ...     print(file['uri'], file['database'])
        my_uri1 my_database
        my_uri2 my_database
        my_uri3 my_database

        """

        uris = self["uri"]
        if not isinstance(uris, list):
            yield self
            return

        n_uris = len(uris)

        # iterate over precomputed keys and make sure each of them provides
        # exactly one value per uri: scalars are repeated for every uri,
        # lists must have one item per uri.
        precomputed = {"uri": uris}
        for key, value in abs(self).items():
            if key == "uri":
                continue

            if not isinstance(value, list):
                precomputed[key] = itertools.repeat(value)

            else:
                if len(value) != n_uris:
                    msg = (
                        f'Mismatch between number of "uris" ({n_uris}) '
                        f'and number of "{key}" ({len(value)}).'
                    )
                    raise ValueError(msg)
                precomputed[key] = value

        keys = list(precomputed.keys())
        for values in zip(*precomputed.values()):
            precomputed_one = dict(zip(keys, values))
            # lazy keys are shared: they get evaluated per yielded file
            yield ProtocolFile(precomputed_one, self.lazy)


class Protocol:
    """Experimental protocol

    An experimental protocol usually defines three subsets: a training subset,
    a development subset, and a test subset.

    An experimental protocol can be defined programmatically by creating a
    class that inherits from Protocol and implements at least
    one of `train_iter`, `development_iter` and `test_iter` methods:

        >>> class MyProtocol(Protocol):
        ...     def train_iter(self) -> Iterator[Dict]:
        ...         yield {"uri": "filename1", "any_other_key": "..."}
        ...         yield {"uri": "filename2", "any_other_key": "..."}

    `{subset}_iter` should return an iterator of dictionaries with
        - "uri" key (mandatory) that provides a unique file identifier (usually
          the filename),
        - any other key that the protocol may provide.

    It can then be used in Python like this:

        >>> protocol = MyProtocol()
        >>> for file in protocol.train():
        ...    print(file["uri"])
        filename1
        filename2

    An experimental protocol can also be defined using `pyannote_audio_utils.database`
    configuration file, whose (configurable) path defaults to "~/database.yml".

    ~~~ Content of ~/database.yml ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    Protocols:
      MyDatabase:
        Protocol:
          MyProtocol:
            train:
                uri: /path/to/collection.lst
                any_other_key: ... # see custom loader documentation
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    where "/path/to/collection.lst" contains the list of identifiers of the
    files in the collection:

    ~~~ Content of "/path/to/collection.lst" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    filename1
    filename2
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    It can then be used in Python like this:

        >>> from pyannote_audio_utils.database import registry
        >>> protocol = registry.get_protocol('MyDatabase.Protocol.MyProtocol')
        >>> for file in protocol.train():
        ...    print(file["uri"])
        filename1
        filename2

    This class is usually inherited from, but can be used directly.

    Parameters
    ----------
    preprocessors : dict
        Preprocess protocol files so that `file[key] = preprocessors[key](file)`
        for each key in `preprocessors`. In case `preprocessors[key]` is not
        callable, it should be a string containing placeholders for `file` keys
        (e.g. {'audio': '/path/to/{uri}.wav'})

    Raises
    ------
    ValueError
        When a preprocessor is neither a callable nor a string.
    """

    def __init__(self, preprocessors: Optional[Preprocessors] = None):
        super().__init__()

        if preprocessors is None:
            preprocessors = dict()

        self.preprocessors = dict()
        for key, preprocessor in preprocessors.items():
            if callable(preprocessor):
                self.preprocessors[key] = preprocessor

            # when `preprocessor` is not callable, it should be a string
            # containing placeholder for item key (e.g. '/path/to/{uri}.wav')
            elif isinstance(preprocessor, str):
                # Each template gets its own bound `format_map`, so there is
                # no closure variable shared between loop iterations (a
                # closure here would make every key use the *last* template).
                # `format_map` also looks up only the placeholders actually
                # present in the template, whereas `format(**current_file)`
                # would read every key of the file, including the lazy key
                # currently being evaluated.
                self.preprocessors[key] = str(preprocessor).format_map

            else:
                msg = f'"{key}" preprocessor is neither a callable nor a string.'
                raise ValueError(msg)

    def preprocess(self, current_file: Union[Dict, ProtocolFile]) -> ProtocolFile:
        """Wrap `current_file` into a ProtocolFile with lazy preprocessors"""
        return ProtocolFile(current_file, lazy=self.preprocessors)

    def __str__(self):
        # `__doc__` is None for subclasses defined without a docstring, and
        # `__str__` must return a string: fall back to the class name.
        doc = self.__doc__
        if doc is None:
            return type(self).__name__
        return doc

    def train_iter(self) -> Iterator[Union[Dict, ProtocolFile]]:
        """Iterate over files in the training subset"""
        raise NotImplementedError()

    def development_iter(self) -> Iterator[Union[Dict, ProtocolFile]]:
        """Iterate over files in the development subset"""
        raise NotImplementedError()

    def test_iter(self) -> Iterator[Union[Dict, ProtocolFile]]:
        """Iterate over files in the test subset"""
        raise NotImplementedError()

    def subset_helper(self, subset: Subset) -> Iterator[ProtocolFile]:
        """Iterate over preprocessed files of a given subset

        Parameters
        ----------
        subset : {"train", "development", "test"}
            Name of the subset.

        Raises
        ------
        NotImplementedError
            When neither `{subset}_iter` nor its legacy counterpart is
            implemented. As this method is a generator, the exception is
            raised when iteration starts.
        """
        try:
            files = getattr(self, f"{subset}_iter")()
        except (AttributeError, NotImplementedError):
            # previous pyannote_audio_utils.database versions used `trn_iter` instead of
            # `train_iter`, `dev_iter` instead of `development_iter`, and
            # `tst_iter` instead of `test_iter`. therefore, we use the legacy
            # version when it is available (and the new one is not).
            subset_legacy = LEGACY_SUBSET_MAPPING[subset]
            try:
                files = getattr(self, f"{subset_legacy}_iter")()
            except AttributeError:
                msg = f"Protocol does not implement a {subset} subset."
                # the intermediate AttributeError is an implementation
                # detail: do not expose it as the cause of the error.
                raise NotImplementedError(msg) from None

        for file in files:
            yield self.preprocess(file)

    def train(self) -> Iterator[ProtocolFile]:
        """Iterate over preprocessed files in the training subset"""
        return self.subset_helper("train")

    def development(self) -> Iterator[ProtocolFile]:
        """Iterate over preprocessed files in the development subset"""
        return self.subset_helper("development")

    def test(self) -> Iterator[ProtocolFile]:
        """Iterate over preprocessed files in the test subset"""
        return self.subset_helper("test")

    def files(self) -> Iterator[ProtocolFile]:
        """Iterate over all files in `protocol`

        Each file is yielded at most once, even when it belongs to several
        subsets. Subsets that are missing or not implemented are skipped.
        """

        # imported here to avoid circular imports
        from pyannote_audio_utils.database.util import get_unique_identifier

        yielded_uris = set()

        for method in [
            "development",
            "development_enrolment",
            "development_trial",
            "test",
            "test_enrolment",
            "test_trial",
            "train",
            "train_enrolment",
            "train_trial",
        ]:
            if not hasattr(self, method):
                continue

            def iterate(method=method):
                # subset methods are generators, so "not implemented" only
                # shows up once iteration starts: stop silently in that case.
                try:
                    for file in getattr(self, method)():
                        yield file
                except (AttributeError, NotImplementedError):
                    return

            for current_file in iterate():
                # skip "files" that do not contain a "uri" entry.
                # this happens for speaker verification trials that contain
                # two nested files "file1" and "file2"
                # see https://github.com/pyannote_audio_utils/pyannote_audio_utils-db-voxceleb/issues/4
                if "uri" not in current_file:
                    continue

                for current_file_ in current_file.files():
                    # corner case when the same file is yielded several times
                    uri = get_unique_identifier(current_file_)
                    if uri in yielded_uris:
                        continue

                    yield current_file_

                    yielded_uris.add(uri)