import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dim must be divisible by num heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Single fused projection producing queries, keys, and values in one matmul.
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
    def forward(self, x, mask=None):
        B, T, C = x.shape  # batch size, sequence length, embedding dim
        qkv = self.qkv_proj(x)  # (B, T, 3 * C)
        # Split the fused projection per head: each head owns a contiguous
        # block of 3 * head_dim features, which chunk() below separates into
        # q, k, v of shape (B, num_heads, T, head_dim) each. This interleaves
        # q/k/v per head rather than concatenating [q | k | v]; since the
        # projection is learned end to end, either layout trains equivalently.
        qkv = qkv.reshape(B, T, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # (B, num_heads, T, 3 * head_dim)
        q, k, v = qkv.chunk(3, dim=-1)
        # Scaled dot-product attention: (B, num_heads, T, T) score matrix.
        attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # mask must broadcast to (B, num_heads, T, T); zeros mark
            # positions that may not be attended to.
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = attn_weights @ v  # (B, num_heads, T, head_dim)
        # Merge heads back into a single (B, T, C) tensor; reshape (not view)
        # because the transpose makes the tensor non-contiguous.
        attn_output = attn_output.transpose(1, 2).reshape(B, T, C)
        output = self.out_proj(attn_output)
        return output
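
# A minimal smoke test of the module above; the sizes (embed_dim=64,
# num_heads=8, batch of 2, sequence length 16) are illustrative choices,
# not values taken from the original code. The causal mask built with
# torch.tril keeps each position from attending to later positions.
if __name__ == "__main__":
    mha = MultiHeadAttention(embed_dim=64, num_heads=8)
    x = torch.randn(2, 16, 64)  # (batch, seq_len, embed_dim)
    # Lower-triangular causal mask, broadcast over batch and head dims.
    causal_mask = torch.tril(torch.ones(16, 16)).view(1, 1, 16, 16)
    out = mha(x, mask=causal_mask)
    print(out.shape)  # expected: torch.Size([2, 16, 64])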