Spaces:
Sleeping
Sleeping
| import torch | |
| from collections import OrderedDict | |
| # class ConvEncoder(torch.nn.Module): | |
| # def __init__(self, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7): | |
| # super().__init__() | |
| # if output_dim is None: | |
| # output_dim = hidden_dim | |
| # self.conv4 = torch.nn.Conv1d(1, hidden_dim, kernel_size) | |
| # self.conv3 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size) | |
| # self.conv2 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size) | |
| # self.conv1 = torch.nn.Conv1d(hidden_dim, output_dim, kernel_size) | |
| # self.dropout = torch.nn.Dropout(dropout) | |
| # def forward(self, feature): #(samples, 1, 2048) | |
| # feature = self.dropout(self.conv4(feature)) #(samples, 64, 2042) | |
| # feature = feature.relu() | |
| # feature = self.dropout(self.conv3(feature)) #(samples, 64, 2036) | |
| # feature = feature.relu() | |
| # feature = self.dropout(self.conv2(feature)) #(samples, 64, 2030) | |
| # feature = feature.relu() | |
| # feature = self.dropout(self.conv1(feature)) #(samples, 64, 2024) | |
| # return feature | |
| # class ConvDecoder(torch.nn.Module): | |
| # def __init__(self, input_dim=None, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7): | |
| # super().__init__() | |
| # if output_dim is None: | |
| # output_dim = hidden_dim | |
| # self.convTranspose1 = torch.nn.ConvTranspose1d(input_dim, hidden_dim, kernel_size) | |
| # self.convTranspose2 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size) | |
| # self.convTranspose3 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size) | |
| # self.convTranspose4 = torch.nn.ConvTranspose1d(hidden_dim, 1, kernel_size) | |
| # def forward(self, feature): #(samples, 1, 2048) | |
| # feature = self.convTranspose1(feature) #(samples, 64, 2030) | |
| # feature = feature.relu() | |
| # feature = self.convTranspose2(feature) #(samples, 64, 2036) | |
| # feature = feature.relu() | |
| # feature = self.convTranspose3(feature) #(samples, 64, 2042) | |
| # feature = feature.relu() | |
| # feature = self.convTranspose4(feature) | |
| # return feature | |
def get_activation(activation_name: str) -> torch.nn.Module:
    """Build a new activation module from its short name.

    Supported names are ``"relu"``, ``"gelu"``, ``"leaky_relu"``, ``"tanh"``
    and ``"sigmoid"``. Every call constructs a fresh module using that
    module's default constructor arguments.

    Args:
        activation_name: Short name of the activation to build.

    Returns:
        A newly constructed ``torch.nn.Module`` for the requested activation.

    Raises:
        ValueError: If ``activation_name`` is not one of the supported names.
    """
    # Map names to classes (not instances) so each call returns a new module.
    factories = {
        "relu": torch.nn.ReLU,
        "gelu": torch.nn.GELU,
        "leaky_relu": torch.nn.LeakyReLU,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
    }
    factory = factories.get(activation_name)
    if factory is None:
        raise ValueError(f"Unsupported activation function: {activation_name}")
    return factory()
class ConvEncoder(torch.nn.Module):
    """Four stacked ``Conv1d`` layers with activation and dropout in between.

    The layers are stored in ``self.net`` (a ``torch.nn.Sequential``) in the
    order conv, activation, dropout, repeated three times, followed by the
    output convolution at index 9. No padding argument is passed to the
    convolutions, so each one uses ``Conv1d``'s default padding.

    Args:
        hidden_dim: Channel count of the three inner convolutions.
        output_dim: Channel count of the last convolution; falls back to
            ``hidden_dim`` when ``None``.
        input_dim: Channel count of the input.
        dropout: Probability passed to every ``Dropout`` layer.
        kernel_size: Kernel size shared by all convolutions.
        activation: Name understood by ``get_activation``.
        last_activation: Whether an activation follows the last convolution.
    """

    def __init__(self, hidden_dim=64, output_dim=None, input_dim=1, dropout=0, kernel_size=7, activation="relu", last_activation=True):
        super().__init__()
        final_dim = hidden_dim if output_dim is None else output_dim

        # Channel sizes seen by the first three (conv, activation, dropout) groups.
        channels = [input_dim, hidden_dim, hidden_dim, hidden_dim]
        modules = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            modules += [
                torch.nn.Conv1d(c_in, c_out, kernel_size),
                get_activation(activation),
                torch.nn.Dropout(dropout),
            ]
        modules.append(torch.nn.Conv1d(hidden_dim, final_dim, kernel_size))

        if last_activation:
            modules.append(get_activation(activation))
        # NOTE(review): the source's indentation was lost, so it is ambiguous
        # whether this final dropout belongs inside the `if` above. It is kept
        # unconditional here, which matches the legacy encoder (dropout after
        # every convolution) - confirm against the original file.
        modules.append(torch.nn.Dropout(dropout))

        self.net = torch.nn.Sequential(*modules)

    def forward(self, feature):
        """Run ``feature`` through the sequential stack and return the result."""
        return self.net(feature)
class ConvDecoder(torch.nn.Module):
    """Four stacked ``ConvTranspose1d`` layers with activation and dropout.

    The layers are stored in ``self.net`` (a ``torch.nn.Sequential``) in the
    order transposed conv, activation, dropout, repeated three times, followed
    by the output transposed convolution at index 9.

    Args:
        input_dim: Channel count of the input; falls back to ``hidden_dim``
            when ``None``.
        hidden_dim: Channel count of the inner layers.
        output_dim: Channel count produced by the last layer.
        dropout: Probability passed to every ``Dropout`` layer.
        kernel_size: Kernel size shared by all transposed convolutions.
        activation: Name understood by ``get_activation``.
        last_bias: Whether the last transposed convolution has a bias term.
        last_activation: Whether an activation follows the last layer.
    """

    def __init__(self, input_dim=None, hidden_dim=64, output_dim=1, dropout=0, kernel_size=7, activation="relu", last_bias=True, last_activation=False):
        super().__init__()
        first_dim = hidden_dim if input_dim is None else input_dim

        # Channel sizes seen by the first three (conv, activation, dropout) groups.
        channels = [first_dim, hidden_dim, hidden_dim, hidden_dim]
        modules = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            modules += [
                torch.nn.ConvTranspose1d(c_in, c_out, kernel_size),
                get_activation(activation),
                torch.nn.Dropout(dropout),
            ]
        modules.append(
            torch.nn.ConvTranspose1d(hidden_dim, output_dim, kernel_size, bias=last_bias)
        )

        if last_activation:
            modules.append(get_activation(activation))
        # NOTE(review): the source's indentation was lost, so it is ambiguous
        # whether this final dropout belongs inside the `if` above. It is kept
        # unconditional here, mirroring ConvEncoder - confirm against the
        # original file.
        modules.append(torch.nn.Dropout(dropout))

        self.net = torch.nn.Sequential(*modules)

    def forward(self, feature):
        """Run ``feature`` through the sequential stack and return the result."""
        return self.net(feature)
class ConvMLP(torch.nn.Module):
    """MLP applied independently at every sequence position.

    Built from ``Conv1d`` layers with ``kernel_size=1``, with an activation
    between consecutive convolutions and none after the last one.

    Args:
        input_dim: Channel count of the input.
        output_dim: Channel count of the output.
        hidden_dims: Channel counts of the hidden layers. Any sequence (list,
            tuple, ...) is accepted; an empty sequence yields a single
            convolution from ``input_dim`` to ``output_dim``.
        activation: Name understood by ``get_activation``.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=(128, 64), activation="relu"):
        super().__init__()
        # Unpacking accepts tuples as well as lists and never aliases the
        # caller's (or the default) sequence.
        dims = [input_dim, *hidden_dims, output_dim]
        layers = []
        for index, (dim_in, dim_out) in enumerate(zip(dims[:-1], dims[1:])):
            if index > 0:
                # Activations sit between convolutions only.
                layers.append(get_activation(activation))
            layers.append(torch.nn.Conv1d(dim_in, dim_out, kernel_size=1))
        self.mlp = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stack to ``x`` and return the result."""
        return self.mlp(x)
class MLP(torch.nn.Module):
    """Fully connected network of ``Linear`` layers.

    An activation is placed between consecutive linear layers and none after
    the last one.

    Args:
        input_dim: Size of the input features.
        output_dim: Size of the output features.
        hidden_dims: Sizes of the hidden layers. Any sequence (list, tuple,
            ...) is accepted; an empty sequence yields a single linear layer
            from ``input_dim`` to ``output_dim``.
        activation: Name understood by ``get_activation``.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=(128, 64), activation="relu"):
        super().__init__()
        # Unpacking accepts tuples as well as lists and never aliases the
        # caller's (or the default) sequence.
        dims = [input_dim, *hidden_dims, output_dim]
        layers = []
        for index, (dim_in, dim_out) in enumerate(zip(dims[:-1], dims[1:])):
            if index > 0:
                # Activations sit between linear layers only.
                layers.append(get_activation(activation))
            layers.append(torch.nn.Linear(dim_in, dim_out))
        self.mlp = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stack to ``x`` and return the result."""
        return self.mlp(x)
class ResponseHead(torch.nn.Module):
    """Stack of ``Linear`` layers with ``GELU`` between them.

    Maps a feature vector of size ``input_dim`` to a vector of length
    ``output_length``. No activation follows the last linear layer.

    Args:
        input_dim: Size of the input features.
        output_length: Size of the output vector.
        hidden_dims: Sizes of the hidden layers. Any sequence (list, tuple,
            ...) is accepted; an empty sequence yields a single linear layer.
    """

    def __init__(self, input_dim, output_length, hidden_dims=(128,)):
        super().__init__()
        # Unpacking accepts tuples as well as lists and never aliases the
        # caller's (or the default) sequence.
        dims = [input_dim, *hidden_dims, output_length]
        layers = []
        for index, (dim_in, dim_out) in enumerate(zip(dims[:-1], dims[1:])):
            if index > 0:
                # GELU sits between linear layers only.
                layers.append(torch.nn.GELU())
            layers.append(torch.nn.Linear(dim_in, dim_out))
        self.response_head = torch.nn.Sequential(*layers)

    def forward(self, feature):
        """Apply the stack to ``feature`` and return the result."""
        return self.response_head(feature)
class ShimNetWithSCRF(torch.nn.Module):
    """Encoder / attention-pooling / decoder model with a response head.

    The encoder output is pooled over the sequence axis with a single learnable
    query. The pooled vector feeds the response head and is also concatenated
    (repeated along the sequence axis) to the encoder output before decoding.

    Args:
        encoder_hidden_dims: ``hidden_dim`` of the ``ConvEncoder``.
        encoder_dropout: Dropout probability of the ``ConvEncoder``.
        bottleneck_dim: Channel count of the encoder output and of the query.
        rensponse_length: Output length of the response head (the parameter
            name keeps its historical spelling for compatibility).
        resnponse_head_dims: Hidden sizes of the response head (historical
            spelling kept). Any sequence is accepted.
        decoder_hidden_dims: ``hidden_dim`` of the ``ConvDecoder``.
    """

    # Older checkpoints stored each convolution under its own attribute name.
    # These tables translate those names to the matching index inside the
    # ``net`` Sequential of ConvEncoder / ConvDecoder.
    EncoderLegacyNameMapping = {
        "conv4": "net.0",
        "conv3": "net.3",
        "conv2": "net.6",
        "conv1": "net.9",
    }
    DecoderLegacyNameMapping = {
        "convTranspose1": "net.0",
        "convTranspose2": "net.3",
        "convTranspose3": "net.6",
        "convTranspose4": "net.9",
    }

    def __init__(self,
                 encoder_hidden_dims=64,
                 encoder_dropout=0,
                 bottleneck_dim=64,
                 rensponse_length=61,
                 resnponse_head_dims=(128,),
                 decoder_hidden_dims=64
                 ):
        super().__init__()
        self.encoder = ConvEncoder(hidden_dim=encoder_hidden_dims, output_dim=bottleneck_dim, dropout=encoder_dropout, last_activation=False)
        self.query = torch.nn.Parameter(torch.empty(1, 1, bottleneck_dim))
        torch.nn.init.xavier_normal_(self.query)
        # The decoder sees local and pooled features stacked on the channel axis.
        self.decoder = ConvDecoder(input_dim=2*bottleneck_dim, hidden_dim=decoder_hidden_dims, last_activation=False)
        self.rensponse_length = rensponse_length
        # list(...) keeps this working with a ResponseHead that concatenates lists.
        self.response_head = ResponseHead(bottleneck_dim, rensponse_length, list(resnponse_head_dims))

    def _rename_legacy_key(self, key):
        """Return ``key`` with a legacy encoder/decoder layer name translated."""
        parts = key.split(".")
        if len(parts) < 2:
            return key
        if parts[0] == "encoder":
            mapping = self.EncoderLegacyNameMapping
        elif parts[0] == "decoder":
            mapping = self.DecoderLegacyNameMapping
        else:
            return key
        if parts[1] not in mapping:
            return key
        parts[1] = mapping[parts[1]]
        return ".".join(parts)

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load ``state_dict``, translating legacy layer names first.

        Args:
            state_dict: Mapping of parameter names to tensors; keys may use
                either the current or the legacy layer names.
            strict: Forwarded to ``torch.nn.Module.load_state_dict``.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``torch.nn.Module.load_state_dict``.

        Returns:
            Whatever ``torch.nn.Module.load_state_dict`` returns, so callers
            can inspect missing and unexpected keys when ``strict=False``.
        """
        renamed = OrderedDict(
            (self._rename_legacy_key(key), value) for key, value in state_dict.items()
        )
        return super().load_state_dict(renamed, strict=strict, **kwargs)

    def forward(self, feature):  # (samples, 1, L_in)
        """Denoise ``feature`` and predict the response.

        Returns:
            A dict with keys ``'denoised'`` (decoder output), ``'response'``
            (response head output) and ``'attention'`` (softmax weights over
            the encoder's sequence axis).
        """
        feature = self.encoder(feature)  # (samples, bottleneck_dim, L)
        energy = self.query @ feature  # (samples, 1, L)
        weight = torch.nn.functional.softmax(energy, 2)  # (samples, 1, L)
        # Attention-weighted sum over the sequence axis.
        global_features = feature @ weight.transpose(1, 2)  # (samples, bottleneck_dim, 1)
        response = self.response_head(global_features.squeeze(-1))
        # Repeat the pooled vector at every position so it can be concatenated.
        global_features = global_features.expand_as(feature)  # (samples, bottleneck_dim, L)
        feature = torch.cat([feature, global_features], 1)  # (samples, 2*bottleneck_dim, L)
        denoised_spectrum = self.decoder(feature)
        return {
            'denoised': denoised_spectrum,
            'response': response,
            'attention': weight.squeeze(1)
        }
class KVAttention(torch.nn.Module):
    """Multi-head attention pooling over the sequence axis with a learnable query.

    Keys are obtained by passing the input through ``k_processor``. The
    attention weights are then used to average the *unprocessed* input, split
    into ``num_heads`` channel groups. ``v_processor`` is applied afterwards to
    the flattened pooled vector.

    Args:
        kv_dim: Per-head key dimension; ``k_processor`` must output
            ``num_heads * kv_dim`` channels.
        num_heads: Number of attention heads.
        k_processor: Module producing keys from the input; identity if ``None``.
        v_processor: Module applied to the pooled features; identity if ``None``.
    """

    def __init__(self,
                 kv_dim=64,
                 num_heads=4,
                 k_processor=None,
                 v_processor=None,
                 ):
        super().__init__()
        self.k_processor = torch.nn.Identity() if k_processor is None else k_processor
        self.v_processor = torch.nn.Identity() if v_processor is None else v_processor
        self.kv_dim = kv_dim
        self.num_heads = num_heads
        # One learnable query vector per head, shared across the batch.
        self.query = torch.nn.Parameter(torch.empty(1, num_heads, kv_dim))
        torch.nn.init.xavier_normal_(self.query)

    def forward(self, feature):  # (samples, input_dim, seq_len)
        """Pool ``feature`` over its last axis.

        Returns:
            A tuple ``(global_features, weight)``: the pooled features after
            ``v_processor``, and the attention weights of shape
            ``(samples, num_heads, seq_len)``.
        """
        n_samples, length = feature.shape[0], feature.shape[-1]

        # Split key channels into heads: (samples, num_heads, kv_dim, seq_len)
        keys = self.k_processor(feature).view(n_samples, self.num_heads, self.kv_dim, length)
        queries = self.query.expand(n_samples, -1, -1)  # (samples, num_heads, kv_dim)

        scores = torch.einsum('bhd,bhdl->bhl', queries, keys)  # (samples, num_heads, seq_len)
        weight = torch.nn.functional.softmax(scores, dim=2)  # (samples, num_heads, seq_len)

        # Split the raw input channels into heads and average them with the weights.
        per_head_input = feature.view(n_samples, self.num_heads, -1, length)
        pooled = torch.einsum('bhl,bhdl->bhd', weight, per_head_input)  # (samples, num_heads, head_dim)
        pooled = pooled.reshape(n_samples, -1)  # (samples, num_heads * head_dim)

        return self.v_processor(pooled), weight
class ShimnetModular(torch.nn.Module):
    """Composable encoder / attention / decoder model built from given parts.

    Args:
        encoder: Maps the input to per-position features.
        decoder: Maps the concatenated local and global features to the output.
        response_head: Maps the pooled (global) features to the response.
        attention_module: Maps encoder features to a tuple
            ``(global_features, weight)``.
        local_feature_processor: Applied to encoder features to obtain the
            per-position features passed to the decoder.
        global_feature_processor: Applied to the pooled features to obtain the
            global features passed to the decoder.
    """

    def __init__(self,
                 encoder,
                 decoder,
                 response_head,
                 attention_module,
                 local_feature_processor,
                 global_feature_processor
                 ):
        super().__init__()
        self.encoder = encoder
        self.attention_module = attention_module
        self.decoder = decoder
        self.response_head = response_head
        self.local_feature_processor = local_feature_processor
        self.global_feature_processor = global_feature_processor

    def forward(self, feature):  # (samples, 1, seq_len_in)
        """Run the full pipeline.

        Returns:
            A dict with keys ``'denoised'`` (decoder output), ``'response'``
            (response head output) and ``'attention'`` (attention weights
            summed over axis 1 of ``weight``).
        """
        feature = self.encoder(feature)  # (samples, encoder_features_dim, seq_len)
        local_features = self.local_feature_processor(feature)  # (samples, local_features_dim, seq_len)
        global_features, weight = self.attention_module(feature)
        # squeeze(-1) only drops a trailing axis of size 1; it leaves a
        # (samples, dim) tensor unchanged unless dim == 1.
        response = self.response_head(global_features.squeeze(-1))
        global_features_for_decoding = self.global_feature_processor(global_features).unsqueeze(-1)  # (samples, global_features_dim, 1)
        # Repeat the global vector along batch/sequence only. Broadcasting the
        # channel axis as well would fail when local and global channel counts
        # differ, and would duplicate channels when one of them is 1.
        global_features_for_decoding = global_features_for_decoding.expand(
            local_features.shape[0], -1, local_features.shape[-1]
        )  # (samples, global_features_dim, seq_len)
        feature = torch.cat([local_features, global_features_for_decoding], 1)  # (samples, local_features_dim + global_features_dim, seq_len)
        denoised_spectrum = self.decoder(feature)
        return {
            'denoised': denoised_spectrum,
            'response': response,
            'attention': weight.sum(1)  # (samples, seq_len) for a (samples, num_heads, seq_len) weight
        }
class Predictor:
    """Callable wrapper that runs a model on a single unbatched input.

    Args:
        model: Object called with the batched input; its result is indexed
            with ``"denoised"``.
        weights_file: Optional path passed to ``torch.load``; when given, the
            loaded state dict is passed to ``model.load_state_dict``.

    Note:
        The model's train/eval mode is not changed here; callers that need
        deterministic dropout behaviour must call ``model.eval()`` themselves.
    """

    def __init__(self, model=None, weights_file=None):
        self.model = model
        if weights_file is None:
            return
        # weights_only=True restricts unpickling to tensors and plain containers.
        state = torch.load(weights_file, map_location='cpu', weights_only=True)
        self.model.load_state_dict(state)

    @torch.no_grad()
    def __call__(self, nsf_frq):
        """Return the ``"denoised"`` output for ``nsf_frq`` with the added axes removed."""
        # Prepend two singleton axes before calling the model.
        batched = nsf_frq[None, None]
        denoised = self.model(batched)["denoised"]
        return denoised[0, 0]
if __name__ == "__main__":
    # Smoke test: build every component of ShimnetModular, push a random batch
    # through them one at a time while printing shapes, then run the assembled
    # model on the same batch.

    # --- hyper-parameters -------------------------------------------------
    encoder_hidden_dims = 64
    encoder_dropout = 0
    encoder_features_dim = 128
    local_features_dim = 64
    attention_kv_dim = 32
    attention_num_heads = 8
    global_features_hidden_dim = 256
    global_features_dim = 64
    response_length = 81

    # --- components -------------------------------------------------------
    encoder = ConvEncoder(
        hidden_dim=encoder_hidden_dims,
        output_dim=encoder_features_dim,
        dropout=encoder_dropout,
    )
    local_feature_processor = ConvMLP(encoder_features_dim, local_features_dim, hidden_dims=[256, 128])
    attention = KVAttention(
        kv_dim=attention_kv_dim,
        num_heads=attention_num_heads,
        k_processor=ConvMLP(encoder_features_dim, attention_kv_dim*attention_num_heads, hidden_dims=[512, 256]),
        v_processor=MLP(encoder_features_dim, global_features_hidden_dim, hidden_dims=[512, 256]),
    )
    global_feature_processor = MLP(global_features_hidden_dim, global_features_dim, hidden_dims=[512, 256])
    response_head = MLP(global_features_hidden_dim, response_length, hidden_dims=[512, 256])
    decoder = ConvDecoder(input_dim=local_features_dim + global_features_dim, hidden_dim=64)

    # --- step by step -----------------------------------------------------
    inputs = torch.randn(2, 1, 2048)

    feature = encoder(inputs)  # (samples, encoder_features_dim, seq_len), seq_len != seq_len_in
    print(f"Encoder output shape: {feature.shape}")

    local_features = local_feature_processor(feature)  # (samples, local_features_dim, seq_len)
    print(f"Local features shape: {local_features.shape}")

    global_features, weight = attention(feature)  # (samples, global_features_hidden_dim), (samples, num_heads, seq_len)
    print(f"Global features shape: {global_features.shape}")
    print(f"Attention weights shape: {weight.shape}")

    response = response_head(global_features)  # (samples, response_length)
    print(f"Response shape: {response.shape}")

    global_features_for_decoding = global_feature_processor(global_features).unsqueeze(-1)  # (samples, global_features_dim, 1)
    # Both tensors have 64 channels here, so broadcasting only expands seq_len.
    local_features, global_features_for_decoding = torch.broadcast_tensors(
        local_features, global_features_for_decoding
    )
    feature = torch.cat([local_features, global_features_for_decoding], 1)  # (samples, local_features_dim + global_features_dim, seq_len)
    denoised_spectrum = decoder(feature)
    print("=" * 80)

    # --- assemble model ---------------------------------------------------
    model = ShimnetModular(
        encoder=encoder,
        decoder=decoder,
        response_head=response_head,
        attention_module=attention,
        local_feature_processor=local_feature_processor,
        global_feature_processor=global_feature_processor,
    )
    for k, v in model(inputs).items():
        print(f"{k}: {v.shape}")