# NOTE: non-Python page chrome (Hugging Face Spaces header, file size, commit
# hash, line-number gutter) removed from the top of this scraped file.
import torch
class ConvEncoder(torch.nn.Module):
    """Four-layer 1-D convolutional encoder.

    Maps a (samples, 1, L) signal to (samples, output_dim, L - 4*(kernel_size-1))
    feature maps; with the defaults (L=2048, kernel_size=7) the output length is 2024.
    Convolutions are unpadded, so each layer shortens the sequence by kernel_size - 1.
    """

    def __init__(self, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
        super().__init__()
        if output_dim is None:
            output_dim = hidden_dim
        # NOTE: attribute names intentionally run 4 -> 1; kept unchanged so that
        # previously saved state_dict checkpoints remain loadable.
        self.conv4 = torch.nn.Conv1d(1, hidden_dim, kernel_size)
        self.conv3 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
        self.conv2 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
        self.conv1 = torch.nn.Conv1d(hidden_dim, output_dim, kernel_size)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, feature):
        """Encode (samples, 1, L) -> (samples, output_dim, L - 4*(kernel_size-1))."""
        # Hidden layers: conv -> dropout -> ReLU.
        for layer in (self.conv4, self.conv3, self.conv2):
            feature = torch.relu(self.dropout(layer(feature)))
        # Final projection: dropout but no activation, matching the original design.
        return self.dropout(self.conv1(feature))
class ConvDecoder(torch.nn.Module):
    """Four-layer 1-D transposed-convolution decoder.

    Maps (samples, input_dim, L) features back to a single-channel signal of
    shape (samples, 1, L + 4*(kernel_size-1)) — the mirror of ConvEncoder.

    Fix: the ``dropout`` argument was previously accepted but silently ignored.
    It is now applied after each hidden transposed convolution, mirroring
    ConvEncoder. The default ``dropout=0`` is an identity, so existing callers
    see exactly the original behavior; Dropout holds no parameters, so saved
    state_dicts remain compatible.
    """

    def __init__(self, input_dim=None, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
        super().__init__()
        if output_dim is None:
            output_dim = hidden_dim
        # NOTE(review): output_dim is computed but never used — the final layer
        # always emits one channel (the reconstructed spectrum). Kept so the
        # signature stays backward compatible; confirm intent before wiring it in.
        self.convTranspose1 = torch.nn.ConvTranspose1d(input_dim, hidden_dim, kernel_size)
        self.convTranspose2 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
        self.convTranspose3 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
        self.convTranspose4 = torch.nn.ConvTranspose1d(hidden_dim, 1, kernel_size)
        self.dropout = torch.nn.Dropout(dropout)  # identity when dropout == 0

    def forward(self, feature):
        """Decode (samples, input_dim, L) -> (samples, 1, L + 4*(kernel_size-1))."""
        # Hidden layers: transposed conv -> dropout -> ReLU.
        for layer in (self.convTranspose1, self.convTranspose2, self.convTranspose3):
            feature = torch.relu(self.dropout(layer(feature)))
        # Output layer: no dropout/activation on the reconstructed signal.
        return self.convTranspose4(feature)
class ResponseHead(torch.nn.Module):
    """MLP head: (samples, input_dim) -> (samples, output_length).

    Linear layers with GELU between consecutive pairs and no activation after
    the last layer.

    Fix: the default ``hidden_dims=[128]`` was a mutable default argument;
    replaced with the tuple ``(128,)``. Lists passed by callers still work,
    and the resulting Sequential layout (and state_dict keys) is unchanged.
    """

    def __init__(self, input_dim, output_length, hidden_dims=(128,)):
        super().__init__()
        # Full dimension chain: input -> hidden... -> output.
        dims = [input_dim, *hidden_dims, output_length]
        layers = []
        for i, (dim_in, dim_out) in enumerate(zip(dims, dims[1:])):
            if i > 0:
                # GELU only between Linear layers, never before the first.
                layers.append(torch.nn.GELU())
            layers.append(torch.nn.Linear(dim_in, dim_out))
        self.response_head = torch.nn.Sequential(*layers)

    def forward(self, feature):
        """Apply the MLP to pooled features of shape (samples, input_dim)."""
        return self.response_head(feature)
class ShimNetWithSCRF(torch.nn.Module):
    """Denoising network with an attention-pooled response head.

    Pipeline: ConvEncoder extracts per-position features; a learned query
    attends over positions to pool a global feature; that pooled feature feeds
    (a) a ResponseHead predicting a response vector and (b) is broadcast back,
    concatenated with the per-position features, and decoded into a denoised
    spectrum of the original length.

    NOTE(review): parameter names ``rensponse_length`` / ``resnponse_head_dims``
    are misspelled but kept — renaming would break keyword-argument callers.
    ``resnponse_head_dims=[128]`` is also a mutable default argument (it is
    never mutated here, so it is benign in practice).
    """

    def __init__(self,
        encoder_hidden_dims=64,
        encoder_dropout=0,
        bottleneck_dim=64,
        rensponse_length=61,
        resnponse_head_dims=[128],
        decoder_hidden_dims=64
    ):
        super().__init__()
        self.encoder = ConvEncoder(hidden_dim=encoder_hidden_dims, output_dim=bottleneck_dim, dropout=encoder_dropout)
        # Learned attention query of shape (1, 1, bottleneck_dim).
        self.query = torch.nn.Parameter(torch.empty(1, 1, bottleneck_dim))
        torch.nn.init.xavier_normal_(self.query)
        # Decoder consumes per-position features concatenated with the
        # broadcast global feature, hence 2 * bottleneck_dim input channels.
        self.decoder = ConvDecoder(input_dim=2*bottleneck_dim, hidden_dim=decoder_hidden_dims)
        self.rensponse_length = rensponse_length
        self.response_head = ResponseHead(bottleneck_dim, rensponse_length, resnponse_head_dims)

    def forward(self, feature):  # (samples, 1, 2048) with default sizes
        # Four unpadded k=7 convs trim 24 positions: 2048 -> 2024.
        feature = self.encoder(feature)  # (samples, bottleneck_dim, 2024)
        energy = self.query @ feature  # (samples, 1, 2024)
        weight = torch.nn.functional.softmax(energy, 2)  # attention over positions, (samples, 1, 2024)
        global_features = feature @ weight.transpose(1, 2)  # attention-pooled, (samples, bottleneck_dim, 1)
        response = self.response_head(global_features.squeeze(-1))  # (samples, rensponse_length)
        # Broadcast pooled feature along the position axis, then stack on channels.
        feature, global_features = torch.broadcast_tensors(feature, global_features)  # both (samples, bottleneck_dim, 2024)
        feature = torch.cat([feature, global_features], 1)  # (samples, 2*bottleneck_dim, 2024)
        # Four transposed convs restore the 24 trimmed positions: 2024 -> 2048.
        denoised_spectrum = self.decoder(feature)  # (samples, 1, 2048)
        return {
            'denoised': denoised_spectrum,
            'response': response,
            'attention': weight.squeeze(1)  # (samples, 2024)
        }
class Predictor:
    """Thin inference wrapper: denoises a single 1-D spectrum with a trained model.

    Fix: the model is now switched to eval mode after (optionally) loading
    weights. Previously the model stayed in whatever mode it was constructed
    in, so Dropout layers could remain active during inference.
    """

    def __init__(self, model=None, weights_file=None):
        self.model = model
        if weights_file is not None:
            # weights_only=True avoids executing arbitrary pickled code from the file.
            self.model.load_state_dict(torch.load(weights_file, map_location='cpu', weights_only=True))
        if isinstance(self.model, torch.nn.Module):
            # Inference must run in eval mode (disables Dropout, etc.).
            self.model.eval()

    def __call__(self, nsf_frq):
        """Denoise a 1-D tensor; returns the model's 'denoised' output as 1-D."""
        with torch.no_grad():
            # Add (batch, channel) dims for the model, strip them from the result.
            msf_frq = self.model(nsf_frq[None, None])["denoised"]
            return msf_frq[0, 0]