File size: 3,867 Bytes
1bc2162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/env python3
# coding: utf-8
# @Author  : Xinhao Mei @CVSSP, University of Surrey
# @E-mail  : x.mei@surrey.ac.uk
from abc import ABC

import torch
import yaml
from transformers.modeling_outputs import BaseModelOutput
from models.cnns import Cnn10, Cnn14, ResNet38
from models.htsat import HTSAT_Swin_Transformer
from transformers import PreTrainedModel
from models.audio_encoder_config import AudioEncoderConfig


class AudioEncoderModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a CNN or HTSAT_Swin_Transformer audio encoder.

    The backbone is chosen by ``config.model_arch``:

    * ``"cnn"``: one of ``ResNet38`` / ``Cnn14`` / ``Cnn10``, selected by
      ``config.model_name``; ``audio_width`` is set to 2048.
    * ``"transformer"``: an ``HTSAT_Swin_Transformer``; ``audio_width`` is set
      to 768.

    If ``config.pretrained`` is truthy, weights are loaded from checkpoints under
    ``pretrained_models/audio_encoder/`` (paths relative to the working directory).
    If ``config.freeze`` is truthy, every backbone parameter whose name does not
    contain ``"fc1"`` has ``requires_grad`` set to False.

    Raises:
        NotImplementedError: if ``config.model_arch`` is unknown, or if it is
            ``"cnn"`` and ``config.model_name`` is not a supported CNN.
    """
    config_class = AudioEncoderConfig

    def __init__(self, config):
        super().__init__(config)

        if config.model_arch == "cnn":
            cnn_backbones = {
                'ResNet38': ResNet38,
                'Cnn14': Cnn14,
                'Cnn10': Cnn10,
            }
            backbone_cls = cnn_backbones.get(config.model_name)
            if backbone_cls is None:
                # Fail fast: otherwise ``audio_enc`` would be left unset and the
                # error would only surface later as an AttributeError.
                raise NotImplementedError(
                    'No such CNN audio encoder: {}.'.format(config.model_name))
            self.audio_enc = backbone_cls(config)
            if config.pretrained:
                self._load_pretrained_cnn(config.model_name)
            # NOTE(review): presumably the feature width of the CNN output — confirm.
            self.audio_width = 2048
        elif config.model_arch == "transformer":
            self.audio_enc = HTSAT_Swin_Transformer(
                spec_size=256,
                patch_size=4,
                patch_stride=(4, 4),
                num_classes=527,
                embed_dim=96,
                depths=[2, 2, 6, 2],
                num_heads=[4, 8, 16, 32],
                window_size=8,
                config=config
            )
            if config.pretrained:
                self._load_pretrained_htsat()
            # NOTE(review): presumably the feature width of the HTSAT output — confirm.
            self.audio_width = 768
        else:
            raise NotImplementedError('No such audio encoder network.')

        if config.freeze:
            # Freeze the whole backbone except parameters whose name contains "fc1".
            for name, param in self.audio_enc.named_parameters():
                if "fc1" not in name:
                    param.requires_grad = False

    def _load_pretrained_cnn(self, model_name):
        """Overwrite ``self.audio_enc`` weights with those from ``<model_name>.pth``.

        Checkpoint entries whose key contains ``'fc'`` or starts with ``'spec'``
        or ``'logmel'`` are skipped, so the encoder keeps its current values for
        them. Loading is strict, so any other checkpoint key that the encoder
        does not have raises from ``load_state_dict``.
        """
        checkpoint_path = 'pretrained_models/audio_encoder/{}.pth'.format(model_name)
        # weights_only=False unpickles arbitrary objects: only point this at
        # trusted checkpoint files.
        pretrained_cnn = torch.load(checkpoint_path, map_location='cpu',
                                    weights_only=False)['model']
        dict_new = self.audio_enc.state_dict().copy()
        for key, value in pretrained_cnn.items():
            if 'fc' in key or key.startswith(('spec', 'logmel')):
                continue
            dict_new[key] = value
        self.audio_enc.load_state_dict(dict_new)

    def _load_pretrained_htsat(self):
        """Load ``HTSAT.ckpt`` weights into ``self.audio_enc`` (non-strict).

        Keys starting with ``'sed_model'`` (other than the spectrogram/logmel
        extractor entries) have their first 10 characters, i.e. the
        ``'sed_model.'`` prefix, stripped so they line up with the encoder's own
        parameter names. Keys that still do not match are ignored because
        loading uses ``strict=False``.
        """
        # weights_only=False unpickles arbitrary objects: only point this at
        # trusted checkpoint files.
        audio_ckpt = torch.load("pretrained_models/audio_encoder/HTSAT.ckpt",
                                map_location="cpu", weights_only=False)["state_dict"]
        prefix_len = len('sed_model.')
        # Snapshot the keys: the dict is mutated inside the loop.
        for key in list(audio_ckpt.keys()):
            if key.startswith('sed_model') and ('spectrogram_extractor' not in key
                                                and 'logmel_extractor' not in key):
                audio_ckpt[key[prefix_len:]] = audio_ckpt.pop(key)
        self.audio_enc.load_state_dict(audio_ckpt, strict=False)

    def forward(self, input_ids,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
                ):
        """Encode ``input_ids`` with the audio backbone.

        Args:
            input_ids: input passed straight to ``self.audio_enc``.
                NOTE(review): presumably a batch of raw audio or features —
                confirm against the backbone.
            output_attentions: accepted but ignored.
            output_hidden_states: accepted but ignored.
            return_dict: if True, wrap the result in ``BaseModelOutput``;
                otherwise return a 1-tuple.

        Returns:
            ``BaseModelOutput`` with ``last_hidden_state`` set to the backbone
            output (``hidden_states`` and ``attentions`` are None), or
            ``(audio_embeds,)`` when ``return_dict`` is False.
        """
        audio_embeds = self.audio_enc(input_ids)
        if not return_dict:
            return (audio_embeds, )
        return BaseModelOutput(last_hidden_state=audio_embeds)


if __name__ == '__main__':
    # Smoke test: build the encoder from the project settings file and print it.
    import os

    # The settings and pretrained-weight paths are relative to the project root,
    # which is the parent of this file's directory. Anchor to this file rather
    # than to the current working directory, so the script works regardless of
    # where it is launched from.
    project_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
    os.chdir(project_root)
    with open("settings/settings.yaml", "r", encoding="utf-8") as f:
        settings = yaml.safe_load(f)
    config = AudioEncoderConfig(**settings["audio_encoder_args"],
                                audio_args=settings["audio_args"])
    model = AudioEncoderModel(config)
    print(model)