File size: 1,246 Bytes
85ba398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
from fairseq.data.audio import (
    AudioTransform,
    CompositeAudioTransform,
    import_transforms,
    register_audio_transform,
)


class AudioDatasetTransform(AudioTransform):
    pass


AUDIO_DATASET_TRANSFORM_REGISTRY = {}
AUDIO_DATASET_TRANSFORM_CLASS_NAMES = set()


def get_audio_dataset_transform(name):
    return AUDIO_DATASET_TRANSFORM_REGISTRY[name]


def register_audio_dataset_transform(name):
    return register_audio_transform(
        name,
        AudioDatasetTransform,
        AUDIO_DATASET_TRANSFORM_REGISTRY,
        AUDIO_DATASET_TRANSFORM_CLASS_NAMES,
    )


import_transforms(os.path.dirname(__file__), "dataset")


class CompositeAudioDatasetTransform(CompositeAudioTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        return super()._from_config_dict(
            cls,
            "dataset",
            get_audio_dataset_transform,
            CompositeAudioDatasetTransform,
            config,
            return_empty=True,
        )

    def get_transform(self, cls):
        for t in self.transforms:
            if isinstance(t, cls):
                return t
        return None

    def has_transform(self, cls):
        return self.get_transform(cls) is not None