import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dim must be divisible by num heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Single fused projection producing queries, keys, and values at once.
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        B, T, C = x.shape  # batch size, sequence length, embedding dim
        qkv = self.qkv_proj(x)
        # Split the fused projection per head: (B, T, num_heads, 3 * head_dim),
        # then move the head axis forward: (B, num_heads, T, 3 * head_dim).
        qkv = qkv.reshape(B, T, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)  # each (B, num_heads, T, head_dim)
        # Scaled dot-product attention: scores have shape (B, num_heads, T, T).
        attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # Positions where mask == 0 are excluded from attention.
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = attn_weights @ v
        # Merge heads back into a single embedding dimension: (B, T, embed_dim).
        attn_output = attn_output.transpose(1, 2).reshape(B, T, C)
        output = self.out_proj(attn_output)
        return output
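

# Minimal usage sketch. The dimensions (embed_dim=64, num_heads=8, batch=2,
# seq_len=10) and the causal mask below are illustrative assumptions, not part
# of the class above; the mask is shaped (1, 1, T, T) so it broadcasts across
# batch and heads inside masked_fill.
if __name__ == "__main__":
    B, T, D = 2, 10, 64
    mha = MultiHeadAttention(embed_dim=D, num_heads=8)
    x = torch.randn(B, T, D)
    # Lower-triangular causal mask: position t may attend to positions <= t.
    causal_mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)
    out = mha(x, mask=causal_mask)
    print(out.shape)  # torch.Size([2, 10, 64])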