from typing import Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn


def rope_rotate_half(x: Tensor) -> Tensor:
    """Split the last dim into halves and rotate: (x1, x2) -> (-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
    """Apply rotary position embeddings (RoPE) to x given precomputed sin/cos tables."""
    return (x * cos) + (rope_rotate_half(x) * sin)


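# The sin/cos tables consumed by rope_apply are assumed to be precomputed by
# the caller. Below is a minimal sketch of how such tables could be built for
# a 1D sequence; this helper is hypothetical (not part of the original module)
# and the base frequency of 10_000 is an assumption following the RoPE paper.
def build_rope_sin_cos(seq_len: int, head_dim: int, base: float = 10_000.0) -> Tuple[Tensor, Tensor]:
    # One rotation frequency per channel pair; the angles are duplicated so
    # the tables broadcast over the full head dimension, matching the
    # half-split convention of rope_rotate_half.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]  # (seq_len, head_dim // 2)
    angles = torch.cat([angles, angles], dim=-1)  # (seq_len, head_dim)
    return angles.sin(), angles.cos()

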
class SelfAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        proj_drop: float = 0.0,
        device=None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Kept for reference: equals the default 1/sqrt(head_dim) scaling that
        # scaled_dot_product_attention applies in compute_attention.
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, device=device)
        self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device)
        self.proj_drop = nn.Dropout(proj_drop)

    def apply_rope(self, q: Tensor, k: Tensor, rope: Tensor | Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        # The RoPE tables may live in a higher-precision dtype (e.g. float32);
        # rotate in that dtype, then cast back to the inputs' dtypes.
        q_dtype = q.dtype
        k_dtype = k.dtype
        sin, cos = rope
        rope_dtype = sin.dtype
        q = q.to(dtype=rope_dtype)
        k = k.to(dtype=rope_dtype)
        # Tokens without spatial positions (e.g. CLS or register tokens) sit
        # at the front of the sequence and are left unrotated; only the
        # trailing sin.shape[-2] tokens receive RoPE.
        N = q.shape[-2]
        prefix = N - sin.shape[-2]
        assert prefix >= 0
        q_prefix = q[:, :, :prefix, :]
        q = rope_apply(q[:, :, prefix:, :], sin, cos)
        q = torch.cat((q_prefix, q), dim=-2)
        k_prefix = k[:, :, :prefix, :]
        k = rope_apply(k[:, :, prefix:, :], sin, cos)
        k = torch.cat((k_prefix, k), dim=-2)
        q = q.to(dtype=q_dtype)
        k = k.to(dtype=k_dtype)
        return q, k

    def forward(self, x: Tensor, attn_mask: Tensor | None = None, rope: Tensor | Tuple[Tensor, Tensor] | None = None) -> Tensor:
        qkv = self.qkv(x)
        attn_v = self.compute_attention(qkv=qkv, attn_mask=attn_mask, rope=rope)
        x = self.proj(attn_v)
        x = self.proj_drop(x)
        return x

    def compute_attention(self, qkv: Tensor, attn_mask: Tensor | None = None, rope: Tensor | Tuple[Tensor, Tensor] | None = None) -> Tensor:
        B, N, _ = qkv.shape
        C = self.qkv.in_features

        # (B, N, 3 * C) -> (B, N, 3, heads, C // heads) -> three (B, heads, N, C // heads) tensors.
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = torch.unbind(qkv, 2)
        q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
        if rope is not None:
            q, k = self.apply_rope(q, k, rope)

        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        x = x.transpose(1, 2)
        return x.reshape([B, N, C])
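

# Minimal usage sketch (illustrative, not part of the original module): one
# forward pass with a single unrotated prefix token, using the hypothetical
# build_rope_sin_cos helper defined above.
if __name__ == "__main__":
    B, N, dim, num_heads = 2, 17, 64, 8
    attn = SelfAttention(dim, num_heads=num_heads)
    x = torch.randn(B, N, dim)
    # The tables cover only the last N - 1 tokens, so apply_rope leaves the
    # first token's query/key unrotated (the `prefix` path).
    sin, cos = build_rope_sin_cos(N - 1, dim // num_heads)
    out = attn(x, rope=(sin, cos))
    print(out.shape)  # torch.Size([2, 17, 64])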