File size: 5,016 Bytes
0376b63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import torch.nn as nn
import torch.nn.init as init
import warnings
import torch.nn.functional as F
# Ignore every warning category, from every module, as a side effect of
# importing this file.
# NOTE(review): this is process-wide and also hides PyTorch deprecation /
# user warnings raised elsewhere; consider a narrower, category-specific filter.
warnings.simplefilter("ignore")

class CrossAttentionBlock(nn.Module):
    """Post-norm residual cross-attention block.

    Computes ``LayerNorm(query + Dropout(MultiheadAttention(query, key, value)))``.
    Tensors are batch-first, ``(batch, seq_len, embed_dim)``. ``key``/``value``
    may have a different sequence length than ``query``. The result has the
    same shape as ``query``.

    Args:
        embed_dim: Channel dimension of query/key/value (PyTorch requires it to
            be divisible by ``num_heads``).
        num_heads: Number of attention heads.
        dropout: Dropout probability, used both inside the attention module and
            on the attention output before the residual addition.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.2):
        super().__init__()
        # Attribute names determine state_dict keys; keep them stable.
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,  # inputs are (batch, seq_len, channels)
        )
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key``/``value`` and return the normalized residual."""
        # MultiheadAttention returns (output, weights); only the output is used.
        attended = self.attention(query, key, value)[0]
        residual = query + self.dropout(attended)
        return self.layer_norm(residual)

class CAFN(nn.Module):
    """Cross-attention fusion classifier over two 1-D input streams.

    Stream 1 (``x1``) goes through a 1-channel convolutional stem and stream 2
    (``x2``) through a 3-channel stem. Both stems feed the *same*
    ``FeatureExtractor`` instance (``self.Residual``), so its weights are shared
    between streams. The extractor's 64-channel intermediate map is treated as
    a token sequence. The two sequences cross-attend to each other, are
    concatenated channel-wise, summarized by a bidirectional GRU, and classified
    by an MLP head.

    Args:
        input_dim: Unused; kept for interface compatibility.
        num_classes: Number of output logits.
        hidden_size: Unused. NOTE(review): the GRU hidden size is fixed at 64
            below regardless of this argument (whose default is 128). Honoring
            it would change the default architecture and break existing
            checkpoints, so it is left as-is.
    """

    def __init__(self, input_dim=46, num_classes=4, hidden_size=128):
        super(CAFN, self).__init__()
        # Stem for stream 1: single-channel input -> 32 channels.
        self.conv_layer11 = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2)
        )
        # Stem for stream 2: 3-channel input -> 32 channels.
        self.conv_layer12 = nn.Sequential(
            nn.Conv1d(in_channels=3, out_channels=32, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2)
        )
        # A single extractor is applied to both streams (shared weights).
        self.Residual = FeatureExtractor()
        # Must equal the channel count of FeatureExtractor's intermediate output.
        self.embed_dim = 64
        self.num_heads = 8
        self.cross_attn_1_to_2 = CrossAttentionBlock(self.embed_dim, self.num_heads)
        self.cross_attn_2_to_1 = CrossAttentionBlock(self.embed_dim, self.num_heads)
        # Fixed GRU width; the `hidden_size` argument is intentionally not used
        # (see class docstring).
        self.hidden_size = 64
        self.biGRU = nn.GRU(
            input_size=self.embed_dim * 2,  # concatenation of both fused streams
            hidden_size=self.hidden_size,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )
        mlp_input_dim = self.hidden_size * 2  # forward + backward final states
        mlp_hidden_dim = 64
        self.mlp_head = nn.Sequential(
            nn.Linear(mlp_input_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.35),
            nn.Linear(mlp_hidden_dim, num_classes)
        )
        self.apply(self.init_weights)

    def init_weights(self, m):
        """Initialize one submodule; used with ``self.apply``.

        Exact ``nn.Conv1d``/``nn.Linear`` modules get Xavier-uniform weights
        and zero biases. GRUs get Xavier-uniform input weights, orthogonal
        recurrent weights and zero biases.

        The exact-type comparison (rather than ``isinstance``) leaves subclasses
        at their own initialization. For example, in current PyTorch the
        ``out_proj`` of ``nn.MultiheadAttention`` is a ``Linear`` subclass.
        """
        if type(m) == nn.Conv1d or type(m) == nn.Linear:
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0.0)
        elif type(m) == nn.GRU:
            for name, param in m.named_parameters():
                if 'weight_ih' in name:
                    init.xavier_uniform_(param.data)
                elif 'weight_hh' in name:
                    init.orthogonal_(param.data)
                elif 'bias' in name:
                    param.data.fill_(0)

    def forward(self, x1, x2):
        """Classify a pair of input streams.

        Args:
            x1: Stream-1 batch. Presumably ``(B, N1)``, since a channel axis is
                inserted before the 1-channel stem.
            x2: Stream-2 batch. Presumably ``(B, N2, 3)``, since it is
                transposed to channels-first before the 3-channel stem.

        Returns:
            Logits of shape ``(B, num_classes)``.

        Both streams must reduce to the same sequence length after their stem
        and the feature extractor; otherwise the channel-wise concatenation
        raises.
        """
        # Follow the module's own weights instead of a hard-coded "cuda:0".
        # This keeps CPU models working on CUDA hosts, supports any GPU, and
        # works inside DataParallel replicas, which receive inputs already on
        # their device. Attribute access is used rather than self.parameters()
        # because replicas expose weights as plain attributes.
        device = self.conv_layer11[0].weight.device

        x1 = x1.to(device).unsqueeze(1)  # (B, 1, N1)
        x1_conv = self.conv_layer11(x1)
        # Only the extractor's intermediate map is used; its final output is discarded.
        _, w1 = self.Residual(x1_conv)  # (B, 64, L1)
        x2 = x2.to(device).transpose(1, 2)  # (B, 3, N2)
        x2_conv = self.conv_layer12(x2)
        _, w2 = self.Residual(x2_conv)  # (B, 64, L2)
        w1_p = w1.permute(0, 2, 1)  # (B, L1, 64)
        w2_p = w2.permute(0, 2, 1)  # (B, L2, 64)

        fused_w1 = self.cross_attn_1_to_2(query=w1_p, key=w2_p, value=w2_p)  # (B, L1, 64)
        fused_w2 = self.cross_attn_2_to_1(query=w2_p, key=w1_p, value=w1_p)  # (B, L2, 64)
        x = torch.cat((fused_w1, fused_w2), dim=2)  # (B, L, 128); requires L1 == L2

        # Re-compact GRU weights into a contiguous buffer (matters after
        # device moves / DataParallel replication).
        self.biGRU.flatten_parameters()
        output, _ = self.biGRU(x)  # (B, L, hidden_size * 2)
        # The forward direction finishes at the last step; the backward
        # direction finishes at step 0.
        forward_out = output[:, -1, :self.hidden_size]
        backward_out = output[:, 0, self.hidden_size:]

        x = torch.cat((forward_out, backward_out), dim=1)  # (B, hidden_size * 2)
        return self.mlp_head(x)

class FeatureExtractor(nn.Module):
    """Three-stage 1-D convolutional feature extractor.

    Each stage is ``Conv1d(kernel_size=3) -> ReLU -> Dropout(0.2) ->
    MaxPool1d(2)``, with channels going 32 -> 32 -> 64 -> 64. Without padding,
    each stage maps a sequence length ``L`` to ``(L - 2) // 2``.

    Args:
        input_dim: Unused; kept for interface compatibility.
        num_classes: Unused; kept for interface compatibility.
    """

    def __init__(self, input_dim=46, num_classes=4):
        super().__init__()
        # Attribute names and the Sequential layout inside them define the
        # state_dict keys ("conv_layerN.0.weight", ...); keep them stable.
        self.conv_layer1 = self._stage(32, 32)
        self.conv_layer2 = self._stage(32, 64)
        self.conv_layer3 = self._stage(64, 64)
        self.apply(self.init_weights)

    @staticmethod
    def _stage(in_channels, out_channels):
        """Build one Conv1d -> ReLU -> Dropout -> MaxPool1d stage."""
        return nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool1d(kernel_size=2),
        )

    def init_weights(self, m):
        """Xavier-uniform weights and zero biases for exact Conv1d/Linear modules."""
        if type(m) not in (nn.Conv1d, nn.Linear):
            return
        init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0.0)

    def forward(self, x):
        """Run all three stages.

        Args:
            x: Input with 32 channels, channels-first for ``Conv1d``.

        Returns:
            A tuple ``(final, intermediate)``: the output of ``conv_layer3``
            and the 64-channel output of ``conv_layer2`` it was computed from.
        """
        intermediate = self.conv_layer2(self.conv_layer1(x))
        final = self.conv_layer3(intermediate)
        return final, intermediate