# CoLMbo / encoder / mha.py
# Uploaded by massabaali: "Upload CoLMbo model weights and code" (commit f55a095, verified)
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional additive score penalty.

    Besides the standard Q/K/V/output projections, ``forward`` accepts a
    ``prob_phn`` tensor. When ``lambda_val > 0``, that tensor is transposed
    over its last two axes, scaled by ``lambda_val``, and subtracted from the
    attention scores of every head before masking and softmax.

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Explicit check instead of `assert` so it is not stripped under `python -O`.
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.d_model = d_model  # Model's dimension
        self.num_heads = num_heads  # Number of attention heads
        self.d_k = d_model // num_heads  # Per-head dimension of queries, keys and values
        # Linear layers for transforming inputs
        self.W_q = nn.Linear(d_model, d_model)  # Query transformation
        self.W_k = nn.Linear(d_model, d_model)  # Key transformation
        self.W_v = nn.Linear(d_model, d_model)  # Value transformation
        self.W_o = nn.Linear(d_model, d_model)  # Output transformation

    def scaled_dot_product_attention(self, Q, K, V, prob_phn=None, mask=None, lambda_val=None):
        """Attend per head: softmax(Q K^T / sqrt(d_k) - penalty, masked) @ V.

        Args:
            Q, K, V: Head-split tensors of shape ``(batch, num_heads, len, d_k)``
                (as produced by ``split_heads``).
            prob_phn: Optional tensor broadcastable to ``(batch, len_k, len_q)``.
                Note that it is transposed over its last two axes before being
                subtracted. Only used when ``lambda_val > 0``. (Name suggests
                per-phoneme probabilities -- confirm with the caller.)
            mask: Optional tensor broadcastable to ``(batch, len_q, len_k)``;
                score positions where ``mask == 0`` are filled with -1e9.
            lambda_val: Optional penalty weight; ``None`` or ``<= 0`` disables
                the penalty.

        Returns:
            ``(output, mask)``: ``output`` has the shape of ``V``'s head-split
            layout with ``len_q`` rows; ``mask`` is the input mask, returned
            unchanged (not the attention probabilities).

        Raises:
            ValueError: If ``lambda_val > 0`` but ``prob_phn`` is ``None``.
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Score arithmetic, masking and softmax run in at least float32: the
        # -1e9 fill overflows float16 and a low-precision softmax is fragile.
        if attn_scores.dtype in (torch.float16, torch.bfloat16):
            attn_scores = attn_scores.float()

        if lambda_val is not None and lambda_val > 0:
            if prob_phn is None:
                raise ValueError("prob_phn is required when lambda_val > 0")
            # unsqueeze(1) adds a head axis; broadcasting applies the same
            # penalty to every head without materializing copies.
            attn_scores = attn_scores - lambda_val * prob_phn.unsqueeze(1).transpose(-2, -1)

        attn_mask = mask
        if mask is not None:
            # Head axis added for broadcasting over all heads.
            attn_scores = attn_scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)

        attn_probs = torch.softmax(attn_scores, dim=-1)
        # Match V's dtype so the matmul never sees mixed dtypes.
        output = torch.matmul(attn_probs.to(V.dtype), V)
        return output, attn_mask

    def split_heads(self, x):
        """Reshape ``(batch, seq, d_model)`` into ``(batch, num_heads, seq, d_k)``."""
        batch_size, seq_length, _ = x.size()
        # reshape (not view) so non-contiguous inputs are accepted too.
        return x.reshape(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of ``split_heads``: ``(batch, num_heads, seq, d_k)`` -> ``(batch, seq, d_model)``."""
        batch_size, _, seq_length, _ = x.size()
        return x.transpose(1, 2).reshape(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, prob_phn=None, mask=None, lambda_val=None):
        """Project inputs, attend per head, merge heads and project the result.

        Args:
            Q, K, V: Tensors of shape ``(batch, len, d_model)``.
            prob_phn, mask, lambda_val: Forwarded to
                ``scaled_dot_product_attention``; see that method.

        Returns:
            ``(output, mask)``: ``output`` is ``(batch, len_q, d_model)``;
            ``mask`` is the input mask, returned unchanged.
        """
        # Apply linear transformations and split heads
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        # Perform scaled dot-product attention
        attn_output, attn_mask = self.scaled_dot_product_attention(Q, K, V, prob_phn, mask, lambda_val)
        # Combine heads and apply output transformation
        output = self.W_o(self.combine_heads(attn_output))
        return output, attn_mask