import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention block with a residual connection and layer norm."""

    def __init__(self, embed_dim, num_heads, dropout=0.3):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Self-attention: query, key, and value are all the input sequence.
        attn_output, _ = self.attention(x, x, x)
        # Post-norm residual: add the attention output back onto the input.
        x = self.layer_norm(x + self.dropout(attn_output))
        return x


class MHSA_GRU(nn.Module):
    """Multi-head self-attention combined with a GRU, ending in a sigmoid
    head for binary classification."""

    def __init__(self, input_dim, hidden_dim=256, num_heads=8,
                 num_gru_layers=2, dropout=0.3):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Project raw features into the model's hidden dimension.
        self.input_projection = nn.Linear(input_dim, hidden_dim)

        # Two attention blocks before the GRU, one after.
        self.mhsa1 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
        self.mhsa2 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
        self.gru = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_gru_layers,
            batch_first=True,
            # nn.GRU only applies dropout between stacked layers.
            dropout=dropout if num_gru_layers > 1 else 0,
            bidirectional=False,
        )
        self.mhsa3 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)

        # Classification head: hidden_dim -> hidden_dim/2 -> hidden_dim/4 -> 1.
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc2 = nn.Linear(hidden_dim // 2, hidden_dim // 4)
        self.fc3 = nn.Linear(hidden_dim // 4, 1)
        self.bn1 = nn.BatchNorm1d(hidden_dim // 2)
        self.bn2 = nn.BatchNorm1d(hidden_dim // 4)

    def forward(self, x):
        # x: (batch, input_dim) -> (batch, hidden_dim)
        x = self.input_projection(x)
        # Add a sequence dimension of length 1 so the attention and GRU
        # layers can treat each sample as a one-step sequence:
        # (batch, 1, hidden_dim).
        x = x.unsqueeze(1)

        x = self.mhsa1(x)
        x = self.mhsa2(x)
        gru_out, _ = self.gru(x)
        x = self.mhsa3(gru_out)

        # Take the last time step: (batch, hidden_dim).
        x = x[:, -1, :]

        x = self.dropout(x)
        x = torch.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        x = torch.relu(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        x = self.fc3(x)
        # Sigmoid output: probability for the positive class.
        return torch.sigmoid(x)
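
# Minimal usage sketch (not part of the original module): runs one forward
# pass on random data to verify shapes. input_dim=32 and batch_size=4 are
# illustrative assumptions, not values from the source. The model is put in
# eval() mode so BatchNorm1d uses running statistics rather than requiring
# a multi-sample training batch.
if __name__ == "__main__":
    model = MHSA_GRU(input_dim=32)
    model.eval()
    dummy = torch.randn(4, 32)  # (batch, input_dim)
    with torch.no_grad():
        probs = model(dummy)
    print(probs.shape)  # expected: torch.Size([4, 1]), values in (0, 1)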