File size: 5,156 Bytes
64b4096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch

class ConvEncoder(torch.nn.Module):
    """Four unpadded 1-D convolutions turning a single-channel signal into a feature map.

    Every ``Conv1d`` uses PyTorch's defaults of stride 1 and no padding, so each
    layer trims ``kernel_size - 1`` samples. With ``kernel_size=7`` an input of
    length L leaves the encoder with length ``L - 24``.

    The layers are numbered in reverse order of application (``conv4`` runs
    first). The attribute names are kept as-is because they form the
    state_dict keys.

    Args:
        hidden_dim: channels produced by the first three convolutions.
        output_dim: channels produced by the last convolution; ``None`` means
            ``hidden_dim``.
        dropout: probability of the single Dropout module applied after every
            convolution, including the last one.
        kernel_size: kernel width shared by all four convolutions.
    """

    def __init__(self, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
        super().__init__()
        final_channels = hidden_dim if output_dim is None else output_dim
        # Construction order is kept identical so random initialisation matches.
        self.conv4 = torch.nn.Conv1d(1, hidden_dim, kernel_size)
        self.conv3 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
        self.conv2 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
        self.conv1 = torch.nn.Conv1d(hidden_dim, final_channels, kernel_size)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, feature):
        """Encode a ``(samples, 1, L)`` tensor.

        Returns a ``(samples, output_dim, L - 4 * (kernel_size - 1))`` tensor.
        The last layer gets dropout but no ReLU.
        """
        for layer in (self.conv4, self.conv3, self.conv2):
            feature = self.dropout(layer(feature)).relu()
        return self.dropout(self.conv1(feature))

class ConvDecoder(torch.nn.Module):
    """Four stacked 1-D transposed convolutions mapping a feature map back to one channel.

    Each ``ConvTranspose1d`` (stride 1, no padding) lengthens the signal by
    ``kernel_size - 1`` samples. Four layers therefore add
    ``4 * (kernel_size - 1)`` samples, e.g. 2024 -> 2048 for ``kernel_size=7``.

    Args:
        input_dim: channels of the incoming feature map; ``None`` means
            ``hidden_dim``.
        hidden_dim: channels of the three intermediate layers.
        output_dim: currently unused. The last layer always emits exactly one
            channel; the argument is kept for signature compatibility.
        dropout: currently unused; no dropout is applied in the decoder.
        kernel_size: kernel width shared by all four layers.
    """

    def __init__(self, input_dim=None, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
        super().__init__()
        # ConvTranspose1d cannot take None channels, so the advertised default
        # previously raised TypeError; fall back to hidden_dim instead.
        if input_dim is None:
            input_dim = hidden_dim
        # NOTE(review): output_dim and dropout are intentionally ignored. Wiring
        # them in would change outputs/checkpoints of existing models.
        self.convTranspose1 = torch.nn.ConvTranspose1d(input_dim, hidden_dim, kernel_size)
        self.convTranspose2 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
        self.convTranspose3 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
        self.convTranspose4 = torch.nn.ConvTranspose1d(hidden_dim, 1, kernel_size)

    def forward(self, feature):
        """Decode a ``(samples, input_dim, L)`` tensor to ``(samples, 1, L + 4 * (kernel_size - 1))``.

        The shape comments below use ``kernel_size=7``, ``hidden_dim=64``
        and an input of ``(samples, 128, 2024)`` as an example.
        """
        feature = self.convTranspose1(feature)                         # (samples, 64, 2030)
        feature = feature.relu()
        feature = self.convTranspose2(feature)                         # (samples, 64, 2036)
        feature = feature.relu()
        feature = self.convTranspose3(feature)                         # (samples, 64, 2042)
        feature = feature.relu()
        feature = self.convTranspose4(feature)                         # (samples,  1, 2048)
        return feature

class ResponseHead(torch.nn.Module):
    """Small MLP: Linear layers with GELU activations in between (none after the last).

    Args:
        input_dim: size of the input feature vector.
        output_length: size of the output vector.
        hidden_dims: sizes of the hidden layers, as any sequence of ints. An
            empty sequence yields a single ``Linear(input_dim, output_length)``.
            A tuple default avoids the shared-mutable-default pitfall.
    """

    def __init__(self, input_dim, output_length, hidden_dims=(128,)):
        super().__init__()
        # Unpacking accepts lists and tuples alike; ``[input_dim] + hidden_dims``
        # failed for tuples.
        response_head_dims = [input_dim, *hidden_dims, output_length]
        response_head_layers = [torch.nn.Linear(response_head_dims[0], response_head_dims[1])]
        for dims_in, dims_out in zip(response_head_dims[1:-1], response_head_dims[2:]):
            response_head_layers.extend([
                torch.nn.GELU(),
                torch.nn.Linear(dims_in, dims_out),
            ])
        # Attribute name is part of the state_dict keys; keep it.
        self.response_head = torch.nn.Sequential(*response_head_layers)

    def forward(self, feature):
        """Apply the MLP to ``feature`` (last dimension must equal ``input_dim``)."""
        return self.response_head(feature)

class ShimNetWithSCRF(torch.nn.Module):
    """Conv encoder/decoder with an attention-pooled global feature and a response head.

    The encoder output is pooled into one global vector by softmax attention
    against a learned query. That vector feeds a ``ResponseHead``. It is also
    tiled along the length axis and concatenated with the local features
    before decoding.

    Args:
        encoder_hidden_dims: hidden channels of the ConvEncoder.
        encoder_dropout: dropout probability used inside the ConvEncoder.
        bottleneck_dim: channels of the encoder output and size of the query.
        rensponse_length: output size of the response head (the misspelled
            name is kept for backward compatibility).
        resnponse_head_dims: hidden layer sizes of the response head; ``None``
            means ``[128]`` (misspelled name kept for compatibility).
        decoder_hidden_dims: hidden channels of the ConvDecoder.
    """

    def __init__(
        self,
        encoder_hidden_dims=64,
        encoder_dropout=0,
        bottleneck_dim=64,
        rensponse_length=61,
        resnponse_head_dims=None,
        decoder_hidden_dims=64,
    ):
        super().__init__()
        # None sentinel instead of a shared mutable [128] default; list() also
        # lets callers pass any sequence (e.g. a tuple).
        head_dims = [128] if resnponse_head_dims is None else list(resnponse_head_dims)

        # Construction order is unchanged so random initialisation matches.
        self.encoder = ConvEncoder(hidden_dim=encoder_hidden_dims, output_dim=bottleneck_dim, dropout=encoder_dropout)
        self.query = torch.nn.Parameter(torch.empty(1, 1, bottleneck_dim))
        torch.nn.init.xavier_normal_(self.query)

        # Decoder sees local features concatenated with the tiled global feature.
        self.decoder = ConvDecoder(input_dim=2 * bottleneck_dim, hidden_dim=decoder_hidden_dims)

        self.rensponse_length = rensponse_length
        self.response_head = ResponseHead(bottleneck_dim, rensponse_length, head_dims)

    def forward(self, feature):
        """Run the network on a ``(samples, 1, L)`` tensor.

        Returns a dict with:
            ``'denoised'``: decoder output, ``(samples, 1, L)`` for the default kernels.
            ``'response'``: response-head output, ``(samples, rensponse_length)``.
            ``'attention'``: softmax attention weights over encoder positions.

        Shape comments below assume default hyper-parameters and L = 2048.
        """
        feature = self.encoder(feature)                                # (samples,  64, 2024)
        # Attention pooling: score each position against the learned query,
        # softmax over the length axis, then take the weighted sum of features.
        energy = self.query @ feature                                  # (samples,   1, 2024)
        weight = torch.nn.functional.softmax(energy, 2)                # (samples,   1, 2024)
        global_features = feature @ weight.transpose(1, 2)             # (samples,  64,    1)

        response = self.response_head(global_features.squeeze(-1))     # (samples, rensponse_length)

        # Tile the pooled vector along the length axis and stack it with the local features.
        feature, global_features = torch.broadcast_tensors(feature, global_features)  # both (samples, 64, 2024)
        feature = torch.cat([feature, global_features], 1)             # (samples, 128, 2024)
        denoised_spectrum = self.decoder(feature)                      # (samples,   1, 2048)

        return {
            'denoised': denoised_spectrum,
            'response': response,
            'attention': weight.squeeze(1),                            # (samples, 2024)
        }

class Predictor:
    """Convenience wrapper that runs a model on one unbatched input.

    Args:
        model: a ``torch.nn.Module`` whose forward returns a mapping that
            contains a ``"denoised"`` tensor.
        weights_file: optional path to a state_dict. It is loaded onto CPU
            with ``weights_only=True`` and applied to ``model``.

    Raises:
        ValueError: if ``weights_file`` is given but ``model`` is None.
    """

    def __init__(self, model=None, weights_file=None):
        self.model = model
        if weights_file is not None:
            if model is None:
                raise ValueError("weights_file was given but model is None")
            state_dict = torch.load(weights_file, map_location='cpu', weights_only=True)
            self.model.load_state_dict(state_dict)

    def __call__(self, nsf_frq):
        """Return ``model(nsf_frq[None, None])["denoised"][0, 0]``, computed without gradients.

        Two leading singleton axes are added before the call and removed
        afterwards. A 1-D spectrum is presumably expected; confirm with callers.

        The model is evaluated in eval mode so dropout layers are inactive.
        ``no_grad`` alone does not disable them. The model's previous
        train/eval state is restored afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                msf_frq = self.model(nsf_frq[None, None])["denoised"]
        finally:
            self.model.train(was_training)
        return msf_frq[0, 0]