File size: 1,479 Bytes
1bc2162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env python3
# coding: utf-8
# @Author  : Xinhao Mei @CVSSP, University of Surrey
# @E-mail  : x.mei@surrey.ac.uk


from transformers import PretrainedConfig


class AudioEncoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an Audio Encoder. It is used to instantiate an
    an Audio Encoder according to the specified arguments, defining the model architecture.
    The audio encoder can be a PANNs model or a HTSAT.
    """
    model_type = "audio_encoder"
    
    def __init__(self,
                 model_arch: str = "cnn",
                 model_name: str = "Cnn10",
                 pretrained: bool = True,
                 freeze: bool = False,
                 spec_augment: bool = True,
                 audio_args: dict = None,
                 **kwargs):
        super(AudioEncoderConfig, self).__init__(**kwargs)
        if model_arch not in ["cnn", "transformer"]:
            raise ValueError(f"Not implemented model type: {model_arch}.")
        if model_name not in ["Cnn10", "Cnn14", "ResNet38", "htsat"]:
            raise ValueError(f"Not implemented model: {model_name}.")

        self.model_arch = model_arch
        self.model_name = model_name
        self.pretrained = pretrained
        self.freeze = freeze
        self.hidden_size = 1024 if model_arch == "cnn" else 768
        self.spec_augment = spec_augment
        self.audio_args = audio_args
        self.num_labels = 0