| |
| |
| |
| |
|
|
|
|
| from transformers import PretrainedConfig |
|
|
|
|
class AudioEncoderConfig(PretrainedConfig):
    r"""
    Configuration container for an Audio Encoder.

    An instance records the arguments used to build the encoder and thereby
    defines the model architecture. The audio encoder can be a PANNs model
    or a HTSAT.

    Args:
        model_arch: Architecture family; must be ``"cnn"`` or
            ``"transformer"``.
        model_name: Concrete backbone; must be one of ``"Cnn10"``,
            ``"Cnn14"``, ``"ResNet38"`` or ``"htsat"``.
        pretrained: Stored unchanged on the config.
        freeze: Stored unchanged on the config.
        spec_augment: Stored unchanged on the config.
        audio_args: Stored unchanged on the config (``None`` by default).
        **kwargs: Passed straight through to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``model_arch`` or ``model_name`` is not one of the
            supported values listed above.

    Note:
        ``hidden_size`` is derived, not supplied: 1024 when ``model_arch`` is
        ``"cnn"`` and 768 otherwise. ``num_labels`` is always set to 0.
        The architecture and the model name are validated independently of
        each other; the pairing itself is not checked.
    """
    model_type = "audio_encoder"

    def __init__(self,
                 model_arch: str = "cnn",
                 model_name: str = "Cnn10",
                 pretrained: bool = True,
                 freeze: bool = False,
                 spec_augment: bool = True,
                 audio_args: dict = None,
                 **kwargs):
        # The base class consumes the generic keyword arguments before any
        # encoder-specific validation takes place.
        super().__init__(**kwargs)

        supported_archs = ("cnn", "transformer")
        if model_arch not in supported_archs:
            raise ValueError(f"Not implemented model type: {model_arch}.")

        supported_names = ("Cnn10", "Cnn14", "ResNet38", "htsat")
        if model_name not in supported_names:
            raise ValueError(f"Not implemented model: {model_name}.")

        # The embedding width depends only on the architecture family.
        if model_arch == "cnn":
            embed_width = 1024
        else:
            embed_width = 768

        # Attributes are assigned in a fixed order so the instance dictionary
        # keeps a stable key ordering.
        self.model_arch = model_arch
        self.model_name = model_name
        self.pretrained = pretrained
        self.freeze = freeze
        self.hidden_size = embed_width
        self.spec_augment = spec_augment
        self.audio_args = audio_args
        self.num_labels = 0
| |