import math
from typing import List, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn


# RoPE-related functions:
def rope_rotate_half(x: Tensor) -> Tensor:
    # x:   [ x0  x1  x2  x3  x4  x5]
    # out: [-x3 -x4 -x5  x0  x1  x2]
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
    # x:   [..., D], e.g. [x0, x1, x2, x3, x4, x5]
    # sin: [..., D], e.g. [sin0, sin1, sin2, sin0, sin1, sin2]
    # cos: [..., D], e.g. [cos0, cos1, cos2, cos0, cos1, cos2]
    return (x * cos) + (rope_rotate_half(x) * sin)


class SelfAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        proj_drop: float = 0.0,
        device=None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, device=device)
        self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device)
        self.proj_drop = nn.Dropout(proj_drop)

    def apply_rope(self, q: Tensor, k: Tensor, rope: Tensor | Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        # All operations use the dtype of rope; the outputs are cast back to the dtypes of q and k.
        q_dtype = q.dtype
        k_dtype = k.dtype
        sin, cos = rope
        rope_dtype = sin.dtype
        q = q.to(dtype=rope_dtype)
        k = k.to(dtype=rope_dtype)
        N = q.shape[-2]
        # The first `prefix` tokens (e.g. CLS / register tokens) carry no spatial
        # position and are left unrotated; RoPE is applied only to the remaining tokens.
        prefix = N - sin.shape[-2]
        assert prefix >= 0
        q_prefix = q[:, :, :prefix, :]
        q = rope_apply(q[:, :, prefix:, :], sin, cos)  # [B, head, hw, D//head]
        q = torch.cat((q_prefix, q), dim=-2)  # [B, head, N, D//head]
        k_prefix = k[:, :, :prefix, :]
        k = rope_apply(k[:, :, prefix:, :], sin, cos)  # [B, head, hw, D//head]
        k = torch.cat((k_prefix, k), dim=-2)  # [B, head, N, D//head]
        q = q.to(dtype=q_dtype)
        k = k.to(dtype=k_dtype)
        return q, k

    def forward(
        self,
        x: Tensor,
        attn_mask: Tensor | None = None,
        rope: Tensor | Tuple[Tensor, Tensor] | None = None,
    ) -> Tensor:
        # attn_mask: broadcastable to [B, num_heads, L, S] or [B, 1, 1, S]; True entries are attended.
        qkv = self.qkv(x)
        attn_v = self.compute_attention(qkv=qkv, attn_mask=attn_mask, rope=rope)
        x = self.proj(attn_v)
        x = self.proj_drop(x)
        return x

    def compute_attention(
        self,
        qkv: Tensor,
        attn_mask: Tensor | None = None,
        rope: Tensor | Tuple[Tensor, Tensor] | None = None,
    ) -> Tensor:
        B, N, _ = qkv.shape
        C = self.qkv.in_features
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = torch.unbind(qkv, 2)  # each [B, N, num_heads, C // num_heads]
        q, k, v = [t.transpose(1, 2) for t in [q, k, v]]  # each [B, num_heads, N, C // num_heads]
        if rope is not None:
            q, k = self.apply_rope(q, k, rope)
        # attn_mask follows PyTorch SDPA semantics; boolean True entries are attended.
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        x = x.transpose(1, 2)
        return x.reshape([B, N, C])
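

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The RoPE table construction below
# is an assumption made for demonstration, not the scheme of any particular
# model: it builds a simple 1D sin/cos table of shape [N, head_dim] whose two
# halves repeat the same frequencies, matching the layout rope_rotate_half and
# rope_apply expect. With sin/cos covering all N positions, prefix == 0 and
# every token is rotated.
if __name__ == "__main__":
    B, N, dim, num_heads = 2, 16, 64, 8
    head_dim = dim // num_heads

    # Hypothetical 1D RoPE table: positions 0..N-1, head_dim // 2 frequencies,
    # each frequency repeated across the two halves of the head dimension.
    pos = torch.arange(N, dtype=torch.float32)[:, None]  # [N, 1]
    inv_freq = 1.0 / (10_000.0 ** (torch.arange(head_dim // 2, dtype=torch.float32) / (head_dim // 2)))
    angles = pos * inv_freq[None, :]  # [N, head_dim // 2]
    angles = torch.cat([angles, angles], dim=-1)  # [N, head_dim]
    sin, cos = angles.sin(), angles.cos()

    attn = SelfAttention(dim=dim, num_heads=num_heads)
    x = torch.randn(B, N, dim)
    out = attn(x, rope=(sin, cos))  # sin/cos broadcast over [B, num_heads, N, head_dim]
    print(out.shape)  # torch.Size([2, 16, 64])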