import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import warnings

warnings.filterwarnings('ignore')
class CrossAttentionBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.2):
        super(CrossAttentionBlock, self).__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True  # (Batch, Seq_Len, Channels)
        )
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        attn_output, _ = self.attention(query, key, value)
        output = query + self.dropout(attn_output)
        output = self.layer_norm(output)
        return output
class CAFN(nn.Module):
    def __init__(self, input_dim=46, num_classes=4, hidden_size=128):
        super(CAFN, self).__init__()
        # Branch 1: 1-D convolution over the single-channel input
        self.conv_layer11 = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2)
        )
        # Branch 2: 1-D convolution over the three-channel input
        self.conv_layer12 = nn.Sequential(
            nn.Conv1d(in_channels=3, out_channels=32, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2)
        )
        # Deeper feature extractor shared by both branches
        self.Residual = FeatureExtractor()
        self.embed_dim = 64
        self.num_heads = 8
        self.cross_attn_1_to_2 = CrossAttentionBlock(self.embed_dim, self.num_heads)
        self.cross_attn_2_to_1 = CrossAttentionBlock(self.embed_dim, self.num_heads)
        self.hidden_size = 64
        self.biGRU = nn.GRU(
            input_size=self.embed_dim * 2,
            hidden_size=self.hidden_size,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )
        mlp_input_dim = self.hidden_size * 2
        mlp_hidden_dim = 64
        self.mlp_head = nn.Sequential(
            nn.Linear(mlp_input_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.35),
            nn.Linear(mlp_hidden_dim, num_classes)
        )
        self.apply(self.init_weights)

    def init_weights(self, m):
        if type(m) == nn.Conv1d or type(m) == nn.Linear:
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0.0)
        elif type(m) == nn.GRU:
            for name, param in m.named_parameters():
                if 'weight_ih' in name:
                    init.xavier_uniform_(param.data)
                elif 'weight_hh' in name:
                    init.orthogonal_(param.data)
                elif 'bias' in name:
                    param.data.fill_(0)

    def forward(self, x1, x2):
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Branch 1: (B, input_dim) -> (B, 1, input_dim)
        x1 = x1.to(device).unsqueeze(1)
        x1_conv = self.conv_layer11(x1)
        _, w1 = self.Residual(x1_conv)  # w1 shape: (B, 64, L1)
        # Branch 2: (B, L, 3) -> (B, 3, L)
        x2 = x2.to(device).transpose(1, 2)
        x2_conv = self.conv_layer12(x2)
        _, w2 = self.Residual(x2_conv)  # w2 shape: (B, 64, L2)
        w1_p = w1.permute(0, 2, 1)  # Shape: (B, L, 64)
        w2_p = w2.permute(0, 2, 1)  # Shape: (B, L, 64)
        # Bidirectional cross-attention fusion between the two branches
        fused_w1 = self.cross_attn_1_to_2(query=w1_p, key=w2_p, value=w2_p)  # Shape: (B, L, 64)
        fused_w2 = self.cross_attn_2_to_1(query=w2_p, key=w1_p, value=w1_p)  # Shape: (B, L, 64)
        x = torch.cat((fused_w1, fused_w2), dim=2)  # Shape: (B, L, 128)
        self.biGRU.flatten_parameters()
        output, _ = self.biGRU(x)  # output shape: (B, L, hidden_size * 2)
        # Last step of the forward direction and first step of the backward direction
        forward_out = output[:, -1, :self.hidden_size]
        backward_out = output[:, 0, self.hidden_size:]
        x = torch.cat((forward_out, backward_out), dim=1)
        x = self.mlp_head(x)
        return x
class FeatureExtractor(nn.Module):
    def __init__(self, input_dim=46, num_classes=4):
        super(FeatureExtractor, self).__init__()
        self.conv_layer1 = nn.Sequential(
            nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool1d(kernel_size=2)
        )
        self.conv_layer2 = nn.Sequential(
            nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool1d(kernel_size=2)
        )
        self.conv_layer3 = nn.Sequential(
            nn.Conv1d(in_channels=64, out_channels=64, kernel_size=3),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool1d(kernel_size=2)
        )
        self.apply(self.init_weights)

    def init_weights(self, m):
        if type(m) == nn.Conv1d or type(m) == nn.Linear:
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0.0)

    def forward(self, x):
        x1 = self.conv_layer1(x)
        x2 = self.conv_layer2(x1)
        w1 = x2
        x3 = self.conv_layer3(x2)
        return x3, w1