import torch
import torch.nn as nn
import math


class MultiHeadAttention(nn.Module):
    """Multi-head attention with an optional phoneme-probability bias on the attention scores."""

    def __init__(self, d_model, num_heads):
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model           # model (embedding) dimension
        self.num_heads = num_heads       # number of attention heads
        self.d_k = d_model // num_heads  # per-head dimension

        # Learned projections for queries, keys, values, and the output.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, prob_phn=None, mask=None, lambda_val=None):
        # Raw attention scores, scaled by sqrt(d_k).
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Optionally penalize the scores with the phoneme probabilities, weighted by lambda_val.
        if prob_phn is not None and lambda_val is not None and lambda_val > 0:
            prob_phn = prob_phn.unsqueeze(1)                        # add a head dimension
            prob_phn = prob_phn.expand(-1, self.num_heads, -1, -1)  # broadcast over heads
            attn_scores = attn_scores - lambda_val * prob_phn.transpose(-2, -1)

        attn_mask = mask
        if mask is not None:
            # Broadcast the mask over heads and block out masked positions.
            mask = mask.unsqueeze(1)
            mask = mask.expand(-1, self.num_heads, -1, -1)
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        attn_probs = torch.softmax(attn_scores, dim=-1).float()
        output = torch.matmul(attn_probs, V)
        return output, attn_mask

    def split_heads(self, x):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        # (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, prob_phn=None, mask=None, lambda_val=None):
        # Project the inputs and split them into heads.
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # Attention with the optional phoneme-probability bias and padding mask.
        attn_output, attn_mask = self.scaled_dot_product_attention(Q, K, V, prob_phn, mask, lambda_val)

        # Merge the heads back together and apply the output projection.
        output = self.W_o(self.combine_heads(attn_output))
        return output, attn_mask
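

# Minimal usage sketch, not part of the original module. The shapes below are
# assumptions: prob_phn is taken to be a (batch, seq_len, seq_len) phoneme-probability
# map and mask a (batch, seq_len, seq_len) attention mask (1 = attend, 0 = block);
# the class itself does not fix these shapes beyond what broadcasting requires.
if __name__ == "__main__":
    batch, seq_len, d_model, num_heads = 2, 5, 16, 4
    mha = MultiHeadAttention(d_model, num_heads)

    x = torch.randn(batch, seq_len, d_model)
    prob_phn = torch.rand(batch, seq_len, seq_len)  # hypothetical phoneme-probability map
    mask = torch.ones(batch, seq_len, seq_len)      # no positions blocked in this sketch

    out, attn_mask = mha(x, x, x, prob_phn=prob_phn, mask=mask, lambda_val=0.5)
    print(out.shape)  # expected: torch.Size([2, 5, 16])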