Toun / models /audio_encoder_config.py
babaTEEpe's picture
Upload 11 files
1bc2162 verified
#!/usr/bin/env python3
# coding: utf-8
# @Author : Xinhao Mei @CVSSP, University of Surrey
# @E-mail : x.mei@surrey.ac.uk
from transformers import PretrainedConfig
class AudioEncoderConfig(PretrainedConfig):
    r"""
    Configuration container for an Audio Encoder.

    Stores the arguments that define the audio encoder's architecture and is
    used to instantiate an Audio Encoder from them. The audio encoder can be
    a PANNs model or a HTSAT.

    Args:
        model_arch: Architecture family; must be ``"cnn"`` or
            ``"transformer"``.
        model_name: Concrete model; must be one of ``"Cnn10"``, ``"Cnn14"``,
            ``"ResNet38"`` or ``"htsat"``.
        pretrained: Stored unchanged on the config (presumably controls
            loading of pretrained weights -- confirm against the encoder).
        freeze: Stored unchanged on the config (presumably controls freezing
            of encoder parameters -- confirm against the encoder).
        spec_augment: Stored unchanged on the config (presumably toggles
            spectrogram augmentation -- confirm against the encoder).
        audio_args: Stored unchanged on the config; ``None`` by default.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``model_arch`` or ``model_name`` is not one of the
            accepted values listed above.
    """

    model_type = "audio_encoder"

    def __init__(self,
                 model_arch: str = "cnn",
                 model_name: str = "Cnn10",
                 pretrained: bool = True,
                 freeze: bool = False,
                 spec_augment: bool = True,
                 audio_args: dict = None,
                 **kwargs):
        # The base class consumes the generic config options first; the
        # encoder-specific arguments are only checked afterwards.
        super().__init__(**kwargs)

        known_archs = ("cnn", "transformer")
        known_models = ("Cnn10", "Cnn14", "ResNet38", "htsat")

        if model_arch not in known_archs:
            raise ValueError(f"Not implemented model type: {model_arch}.")
        if model_name not in known_models:
            raise ValueError(f"Not implemented model: {model_name}.")

        self.model_arch = model_arch
        self.model_name = model_name
        self.pretrained = pretrained
        self.freeze = freeze

        # The width depends only on the architecture family: 1024 for "cnn",
        # 768 for the remaining accepted value ("transformer").
        if model_arch == "cnn":
            self.hidden_size = 1024
        else:
            self.hidden_size = 768

        self.spec_augment = spec_augment
        self.audio_args = audio_args
        # Always reset to 0, overriding any `num_labels` passed via kwargs.
        self.num_labels = 0