import torch
import torch.nn as nn
class MultiHeadSelfAttention(nn.Module):
    """Residual multi-head self-attention block with post-LayerNorm.

    Wraps ``nn.MultiheadAttention`` (batch-first) in the classic
    transformer sub-layer pattern: ``LayerNorm(x + Dropout(Attn(x)))``.

    Args:
        embed_dim: feature dimension of the input/output.
        num_heads: number of attention heads (must divide ``embed_dim``).
        dropout: dropout probability used both inside the attention
            weights and on the residual branch.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.3):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Self-attend over ``x`` of shape (batch, seq, embed_dim).

        Returns a tensor of the same shape.
        """
        attended, _ = self.attention(x, x, x)
        return self.layer_norm(x + self.dropout(attended))
class MHSA_GRU(nn.Module):
    """Multi-head self-attention + GRU model producing a sigmoid score.

    Pipeline: linear projection -> two MHSA blocks -> GRU ->
    one more MHSA block -> last time-step -> MLP head
    (Linear/BatchNorm/ReLU x2) -> Linear -> sigmoid.

    Args:
        input_dim: size of each input feature vector.
        hidden_dim: internal feature width shared by the attention
            blocks, the GRU, and the head. Must be divisible by
            ``num_heads``.
        num_heads: number of attention heads per MHSA block.
        num_gru_layers: number of stacked GRU layers.
        dropout: dropout probability used throughout.
    """

    def __init__(self, input_dim, hidden_dim=256, num_heads=8, num_gru_layers=2, dropout=0.3):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.input_projection = nn.Linear(input_dim, hidden_dim)
        self.mhsa1 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
        self.mhsa2 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
        self.gru = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_gru_layers,
            batch_first=True,
            # Inter-layer dropout is only meaningful with >1 stacked layers;
            # nn.GRU warns if dropout > 0 with a single layer.
            dropout=dropout if num_gru_layers > 1 else 0,
            bidirectional=False,
        )
        self.mhsa3 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc2 = nn.Linear(hidden_dim // 2, hidden_dim // 4)
        self.fc3 = nn.Linear(hidden_dim // 4, 1)
        self.bn1 = nn.BatchNorm1d(hidden_dim // 2)
        self.bn2 = nn.BatchNorm1d(hidden_dim // 4)

    def forward(self, x):
        """Score a batch of feature vectors.

        Args:
            x: tensor of shape (batch, input_dim) — a single feature
               vector per sample (a length-1 sequence is synthesized
               internally via ``unsqueeze(1)``).

        Returns:
            Tensor of shape (batch, 1) with values in (0, 1).

        Note: BatchNorm1d in the head requires batch size > 1 when the
        module is in training mode.
        """
        x = self.input_projection(x)
        # Treat each sample as a length-1 sequence: (batch, 1, hidden_dim).
        # NOTE(review): with seq_len == 1 the attention blocks reduce to a
        # per-token transform — confirm this matches the intended design.
        x = x.unsqueeze(1)
        x = self.mhsa1(x)
        x = self.mhsa2(x)
        gru_out, _ = self.gru(x)  # final hidden state is unused
        x = self.mhsa3(gru_out)
        x = x[:, -1, :]  # keep the last time-step: (batch, hidden_dim)
        x = self.dropout(x)
        x = torch.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        x = torch.relu(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        return torch.sigmoid(self.fc3(x))