Spaces:
Sleeping
Sleeping
| from abc import ABC, abstractmethod | |
| from typing import Dict, Optional | |
| import importlib | |
| import os | |
| import numpy as np | |
| class AudioTransform(ABC): | |
| def from_config_dict(cls, config: Optional[Dict] = None): | |
| pass | |
| class CompositeAudioTransform(AudioTransform): | |
| def _from_config_dict( | |
| cls, | |
| transform_type, | |
| get_audio_transform, | |
| composite_cls, | |
| config=None, | |
| return_empty=False, | |
| ): | |
| _config = {} if config is None else config | |
| _transforms = _config.get(f"{transform_type}_transforms") | |
| if _transforms is None: | |
| if return_empty: | |
| _transforms = [] | |
| else: | |
| return None | |
| transforms = [ | |
| get_audio_transform(_t).from_config_dict(_config.get(_t)) | |
| for _t in _transforms | |
| ] | |
| return composite_cls(transforms) | |
| def __init__(self, transforms): | |
| self.transforms = [t for t in transforms if t is not None] | |
| def __call__(self, x): | |
| for t in self.transforms: | |
| x = t(x) | |
| return x | |
| def __repr__(self): | |
| format_string = ( | |
| [self.__class__.__name__ + "("] | |
| + [f" {t.__repr__()}" for t in self.transforms] | |
| + [")"] | |
| ) | |
| return "\n".join(format_string) | |
| def register_audio_transform(name, cls_type, registry, class_names): | |
| def register_audio_transform_cls(cls): | |
| if name in registry: | |
| raise ValueError(f"Cannot register duplicate transform ({name})") | |
| if not issubclass(cls, cls_type): | |
| raise ValueError( | |
| f"Transform ({name}: {cls.__name__}) must extend " | |
| f"{cls_type.__name__}" | |
| ) | |
| if cls.__name__ in class_names: | |
| raise ValueError( | |
| f"Cannot register audio transform with duplicate " | |
| f"class name ({cls.__name__})" | |
| ) | |
| registry[name] = cls | |
| class_names.add(cls.__name__) | |
| return cls | |
| return register_audio_transform_cls | |
| def import_transforms(transforms_dir, transform_type): | |
| for file in os.listdir(transforms_dir): | |
| path = os.path.join(transforms_dir, file) | |
| if ( | |
| not file.startswith("_") | |
| and not file.startswith(".") | |
| and (file.endswith(".py") or os.path.isdir(path)) | |
| ): | |
| name = file[: file.find(".py")] if file.endswith(".py") else file | |
| importlib.import_module( | |
| f"fairseq.data.audio.{transform_type}_transforms." + name | |
| ) | |
| # Utility fn for uniform numbers in transforms | |
| def rand_uniform(a, b): | |
| return np.random.uniform() * (b - a) + a | |