|
|
import os |
|
|
from fairseq.data.audio import ( |
|
|
AudioTransform, |
|
|
CompositeAudioTransform, |
|
|
import_transforms, |
|
|
register_audio_transform, |
|
|
) |
|
|
|
|
|
|
|
|
class AudioDatasetTransform(AudioTransform):
    """Marker subclass of ``AudioTransform`` that adds no behaviour of its own.

    This is the class that ``register_audio_dataset_transform`` hands to
    ``register_audio_transform`` together with the dataset-transform registry.
    """
|
|
|
|
|
|
|
|
# Lookup table read by ``get_audio_dataset_transform`` (indexed by name) and
# passed to ``register_audio_transform`` by ``register_audio_dataset_transform``.
# NOTE(review): values are presumably the registered transform classes — confirm.
AUDIO_DATASET_TRANSFORM_REGISTRY: dict = {}

# Companion set passed to ``register_audio_transform`` alongside the registry.
# NOTE(review): presumably holds registered class names to catch duplicates — confirm.
AUDIO_DATASET_TRANSFORM_CLASS_NAMES: set = set()
|
|
|
|
|
|
|
|
def get_audio_dataset_transform(name):
    """Return the entry stored under ``name`` in the dataset-transform registry.

    Args:
        name: Key to look up in ``AUDIO_DATASET_TRANSFORM_REGISTRY``.

    Returns:
        Whatever object was stored in the registry under ``name``.

    Raises:
        KeyError: If ``name`` is not present in the registry.
    """
    registered = AUDIO_DATASET_TRANSFORM_REGISTRY[name]
    return registered
|
|
|
|
|
|
|
|
def register_audio_dataset_transform(name):
    """Delegate registration of ``name`` to ``register_audio_transform``.

    The shared helper is called with the dataset-specific base class
    (``AudioDatasetTransform``), registry dict and class-name set.

    Args:
        name: Name forwarded unchanged to ``register_audio_transform``.

    Returns:
        The value returned by ``register_audio_transform``.
        NOTE(review): presumably a class decorator — confirm against
        ``fairseq.data.audio``.
    """
    registration = register_audio_transform(
        name,
        AudioDatasetTransform,
        AUDIO_DATASET_TRANSFORM_REGISTRY,
        AUDIO_DATASET_TRANSFORM_CLASS_NAMES,
    )
    return registration
|
|
|
|
|
|
|
|
# Runs at import time with this file's directory and the "dataset" tag.
# NOTE(review): presumably imports the transform modules located next to this
# file so their registration code executes — confirm against import_transforms.
import_transforms(
    os.path.dirname(__file__),
    "dataset",
)
|
|
|
|
|
|
|
|
class CompositeAudioDatasetTransform(CompositeAudioTransform):
    """Composite transform specialised for the "dataset" transform family.

    Adds lookup helpers over ``self.transforms`` and a config-driven
    constructor that defers to the parent's ``_from_config_dict``.
    """

    @classmethod
    def from_config_dict(cls, config=None):
        """Build an instance from ``config`` through the parent helper.

        Args:
            config: Configuration object forwarded unchanged to
                ``CompositeAudioTransform._from_config_dict``; may be ``None``.

        Returns:
            Whatever the parent helper returns.
            NOTE(review): presumably a ``CompositeAudioDatasetTransform``,
            possibly with no transforms since ``return_empty=True`` — confirm
            against the parent implementation.
        """
        # ``cls`` is forwarded explicitly as the first positional argument.
        # NOTE(review): this suggests the parent helper is not bound as a
        # classmethod — confirm before changing the call shape.
        parent_builder = super()._from_config_dict
        return parent_builder(
            cls,
            "dataset",
            get_audio_dataset_transform,
            CompositeAudioDatasetTransform,
            config,
            return_empty=True,
        )

    def get_transform(self, cls):
        """Return the first transform that is an instance of ``cls``.

        Args:
            cls: Class (or tuple of classes) accepted by ``isinstance``.

        Returns:
            The first matching element of ``self.transforms`` in iteration
            order, or ``None`` when nothing matches.
        """
        matching = (t for t in self.transforms if isinstance(t, cls))
        return next(matching, None)

    def has_transform(self, cls):
        """Return ``True`` if ``get_transform(cls)`` finds a transform."""
        found = self.get_transform(cls)
        return found is not None
|
|
|