import math
from typing import List, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor, nn
# RoPE-related functions:
def rope_rotate_half(x: Tensor) -> Tensor:
    # x:   [ x0  x1  x2  x3  x4  x5]
    # out: [-x3 -x4 -x5  x0  x1  x2]
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
    # x:   [..., D], e.g. [x0, x1, x2, x3, x4, x5]
    # sin: [..., D], e.g. [sin0, sin1, sin2, sin0, sin1, sin2]
    # cos: [..., D], e.g. [cos0, cos1, cos2, cos0, cos1, cos2]
    return (x * cos) + (rope_rotate_half(x) * sin)
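

# Illustrative sketch (not part of the original module): one way to construct the
# sin/cos tables consumed by rope_apply above, for a simple 1D sequence of positions.
# The channel layout matches the "rotate-half" convention used by rope_rotate_half:
# the angle vector of length D/2 is repeated so the first and second halves of the
# head dimension share the same frequencies. The function name, the base of 10_000,
# and the 1D-position assumption are illustrative choices, not taken from this file.
def build_rope_sincos_1d(
    num_positions: int,
    head_dim: int,
    base: float = 10_000.0,
    dtype: torch.dtype = torch.float32,
) -> Tuple[Tensor, Tensor]:
    assert head_dim % 2 == 0
    # One frequency per channel pair: freqs[i] = base ** (-2i / head_dim)
    freqs = base ** (-torch.arange(0, head_dim, 2, dtype=dtype) / head_dim)  # [D/2]
    positions = torch.arange(num_positions, dtype=dtype)  # [N]
    angles = positions[:, None] * freqs[None, :]  # [N, D/2]
    # Repeat so sin/cos follow the [sin0 .. sin_{D/2-1}, sin0 .. sin_{D/2-1}] layout
    # expected by rope_apply.
    angles = torch.cat([angles, angles], dim=-1)  # [N, D]
    return angles.sin(), angles.cos()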


class SelfAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        proj_drop: float = 0.0,
        device=None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Kept for reference; scaled_dot_product_attention already applies the default
        # 1/sqrt(head_dim) scaling, so self.scale is not passed explicitly below.
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, device=device)
        self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device)
        self.proj_drop = nn.Dropout(proj_drop)

    def apply_rope(
        self, q: Tensor, k: Tensor, rope: Tensor | Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tensor]:
        # All operations use the dtype of rope; the outputs are cast back to the dtypes of q and k.
        q_dtype = q.dtype
        k_dtype = k.dtype
        sin, cos = rope
        rope_dtype = sin.dtype
        q = q.to(dtype=rope_dtype)
        k = k.to(dtype=rope_dtype)
        # RoPE is applied only to the last sin.shape[-2] (spatial) tokens; any leading
        # prefix tokens (e.g. a class token) are passed through unrotated.
        N = q.shape[-2]
        prefix = N - sin.shape[-2]
        assert prefix >= 0
        q_prefix = q[:, :, :prefix, :]
        q = rope_apply(q[:, :, prefix:, :], sin, cos)  # [B, head, hw, D//head]
        q = torch.cat((q_prefix, q), dim=-2)  # [B, head, N, D//head]
        k_prefix = k[:, :, :prefix, :]
        k = rope_apply(k[:, :, prefix:, :], sin, cos)  # [B, head, hw, D//head]
        k = torch.cat((k_prefix, k), dim=-2)  # [B, head, N, D//head]
        q = q.to(dtype=q_dtype)
        k = k.to(dtype=k_dtype)
        return q, k

    def forward(
        self,
        x: Tensor,
        attn_mask: Tensor | None = None,
        rope: Tensor | Tuple[Tensor, Tensor] | None = None,
    ) -> Tensor:
        # attn_mask: broadcastable to [B, num_heads, L, S] or [B, 1, 1, S]; True entries are attended.
        qkv = self.qkv(x)
        attn_v = self.compute_attention(qkv=qkv, attn_mask=attn_mask, rope=rope)
        x = self.proj(attn_v)
        x = self.proj_drop(x)
        return x

    def compute_attention(
        self,
        qkv: Tensor,
        attn_mask: Tensor | None = None,
        rope: Tensor | Tuple[Tensor, Tensor] | None = None,
    ) -> Tensor:
        B, N, _ = qkv.shape
        C = self.qkv.in_features
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = torch.unbind(qkv, 2)  # each [B, N, num_heads, C // num_heads]
        q, k, v = [t.transpose(1, 2) for t in [q, k, v]]  # each [B, num_heads, N, C // num_heads]
        if rope is not None:
            q, k = self.apply_rope(q, k, rope)
        # attn_mask follows PyTorch SDPA semantics; boolean True entries are attended.
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        x = x.transpose(1, 2)
        return x.reshape([B, N, C])
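

# Minimal usage sketch (illustrative, not part of the original file): self-attention
# over one prefix (class) token plus a 14x14 grid of patch tokens, with and without
# RoPE and with a boolean padding mask. build_rope_sincos_1d is the hypothetical
# helper sketched above; the shapes and hyperparameters are arbitrary examples.
if __name__ == "__main__":
    B, num_heads, dim = 2, 8, 256
    num_prefix, hw = 1, 14 * 14
    attn = SelfAttention(dim, num_heads=num_heads)
    x = torch.randn(B, num_prefix + hw, dim)

    # Plain attention over the full sequence.
    out = attn(x)

    # RoPE is applied to the hw patch tokens only; the prefix token is skipped inside
    # apply_rope because sin/cos cover hw positions rather than the full sequence.
    sin, cos = build_rope_sincos_1d(hw, dim // num_heads)
    out_rope = attn(x, rope=(sin, cos))

    # Boolean padding mask, broadcastable to [B, num_heads, L, S]: True = attend.
    attn_mask = torch.ones(B, 1, 1, num_prefix + hw, dtype=torch.bool)
    attn_mask[..., -5:] = False  # pretend the last 5 tokens are padding
    out_masked = attn(x, attn_mask=attn_mask, rope=(sin, cos))

    print(out.shape, out_rope.shape, out_masked.shape)  # each torch.Size([2, 197, 256])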