import torch
from torch import nn

from tools import apply_rope  # local helper that applies the rotary position embedding (RoPE)
class GroupedQueryAttention(nn.Module):
    def __init__(
        self, d_in, d_out, num_heads,
        num_kv_groups,
        dtype=None
    ):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Keys and values are projected to num_kv_groups * head_dim dimensions
        # (not d_out), so each K/V head is shared by a group of query heads
        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups  # query heads per K/V group

        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

    def forward(self, x, mask, cos, sin):
        b, num_tokens, d_in = x.shape

        queries = self.W_query(x)  # (b, num_tokens, d_out)
        keys = self.W_key(x)       # (b, num_tokens, num_kv_groups * head_dim)
        values = self.W_value(x)   # (b, num_tokens, num_kv_groups * head_dim)

        # Split the last dimension into separate heads
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)

        # Move the head dimension in front of the token dimension
        keys = keys.transpose(1, 2)        # (b, num_kv_groups, num_tokens, head_dim)
        values = values.transpose(1, 2)    # (b, num_kv_groups, num_tokens, head_dim)
        queries = queries.transpose(1, 2)  # (b, num_heads, num_tokens, head_dim)

        # Apply rotary position embeddings to keys and queries
        keys = apply_rope(keys, cos, sin)
        queries = apply_rope(queries, cos, sin)

        # Expand keys and values to match the number of query heads:
        # (b, num_kv_groups, num_tokens, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)
        # For example, with group_size 2 the K/V groups [K1, K2] become
        # [K1, K1, K2, K2], so each K/V head serves group_size query heads
        # Compute dot-product attention scores
        attn_scores = queries @ keys.transpose(2, 3)  # (b, num_heads, num_tokens, num_tokens)

        # Block future positions (mask is True where attention is not allowed)
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)

        # Scale by sqrt(head_dim) before the softmax
        assert keys.shape[-1] == self.head_dim
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        # Weighted sum over values, then bring tokens back in front of heads
        context_vec = (attn_weights @ values).transpose(1, 2)  # (b, num_tokens, num_heads, head_dim)

        # Merge the heads back into d_out and apply the output projection
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec
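

# ---------------------------------------------------------------------------
# Usage sketch with assumptions labeled. The `tools` module is not shown
# above, so `rope_params` and the `apply_rope` defined below are illustrative
# stand-ins assuming the standard rotate-half RoPE formulation; they are not
# necessarily what `tools` implements. Defining `apply_rope` here shadows the
# imported name, which keeps the sketch self-contained. The mask follows the
# convention `forward` expects: True marks blocked (future) positions.
# ---------------------------------------------------------------------------

def rope_params(head_dim, context_len, theta_base=10_000.0):
    # Hypothetical helper: inverse-frequency angles for each dimension pair
    inv_freq = 1.0 / theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim)
    angles = torch.arange(context_len).float()[:, None] * inv_freq[None, :]
    angles = torch.cat([angles, angles], dim=-1)  # (context_len, head_dim)
    return torch.cos(angles), torch.sin(angles)

def apply_rope(x, cos, sin):
    # Hypothetical stand-in for tools.apply_rope (rotate-half formulation);
    # x has shape (b, num_heads, num_tokens, head_dim)
    num_tokens, head_dim = x.shape[2], x.shape[3]
    x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]
    rotated = torch.cat((-x2, x1), dim=-1)
    cos = cos[:num_tokens][None, None, :, :]
    sin = sin[:num_tokens][None, None, :, :]
    return x * cos + rotated * sin

if __name__ == "__main__":
    b, num_tokens, d_in, d_out, num_heads = 2, 8, 64, 64, 4
    gqa = GroupedQueryAttention(d_in, d_out, num_heads, num_kv_groups=2)
    cos, sin = rope_params(head_dim=d_out // num_heads, context_len=num_tokens)
    # Causal mask: True above the diagonal blocks attention to future tokens
    mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)
    out = gqa(torch.randn(b, num_tokens, d_in), mask, cos, sin)
    print(out.shape)  # torch.Size([2, 8, 64])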