File size: 2,080 Bytes
8019be0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 | import torch
class Rotary(torch.nn.Module):
    """Precompute and cache rotary position-embedding cos/sin tables.

    ``forward`` returns ``(cos, sin)`` tensors of shape
    ``(1, seq_len, 3, 1, 2 * ceil(dim / 2))`` laid out to broadcast against a
    packed qkv tensor (batch, seq_len, qkv, head, dim). Index 2 of the qkv
    axis (v) is filled with cos=1 / sin=0 so the rotation leaves v unchanged.

    Args:
        dim: rotary dimension; presumably the per-head size and expected to be
            even (an odd value yields tables of width ``dim + 1``).
        base: frequency base; ``inv_freq[i] = 1 / base ** (2 * i / dim)``.
    """

    def __init__(self, dim, base=10_000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        # Device the cached tables live on. It is part of the cache key because
        # cos_cached/sin_cached are plain attributes (not buffers) and are not
        # moved by ``module.to(...)``.
        self.device_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        """Return ``(cos, sin)`` tables for ``x.shape[seq_dim]`` positions.

        Tables are rebuilt only when the sequence length or ``x.device``
        changes; otherwise the previously cached tensor objects are returned.
        The position axis of the returned tables is always dim 1, regardless
        of ``seq_dim``. Tables use the dtype of the ``inv_freq`` buffer.
        """
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached or x.device != self.device_cached:
            # Bring inv_freq to x's device so this also works when the module
            # was not moved alongside the inputs.
            inv_freq = self.inv_freq.to(x.device)
            t = torch.arange(seq_len, device=x.device).type_as(inv_freq)
            freqs = torch.outer(t, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            # dims are: batch, seq_len, qkv, head, dim
            cos = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            sin = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            # This makes the transformation on v an identity.
            cos[:, :, 2, :, :].fill_(1.0)
            sin[:, :, 2, :, :].fill_(0.0)
            # Commit the cache only after the tables were built successfully,
            # so a failure above cannot leave a key pointing at stale tables.
            self.cos_cached, self.sin_cached = cos, sin
            self.seq_len_cached = seq_len
            self.device_cached = x.device
        return self.cos_cached, self.sin_cached
def rotate_half(x):
    """Rotate the two halves of the last dimension: ``[a, b] -> [-b, a]``.

    The split point is ``x.shape[-1] // 2``, so for an odd-sized last
    dimension the second half holds the extra element.

    Must stay TorchScript-compatible: ``_apply_rotary_pos_emb_torchscript``
    is decorated with ``@torch.jit.script`` and calls this function.
    """
    half = x.shape[-1] // 2
    lower = x[..., :half]
    upper = x[..., half:]
    return torch.cat([-upper, lower], dim=-1)
def _apply_rotary_pos_emb_native(qkv, cos, sin):
    """Apply rotary position embedding in eager PyTorch (no JIT compilation).

    Computes ``qkv * cos + rotate_half(qkv) * sin`` with ordinary broadcasting
    tensor ops, so a new tensor is returned and ``qkv`` is not modified.
    """
    rotated = rotate_half(qkv)
    return (qkv * cos) + (rotated * sin)
@torch.jit.script
def _apply_rotary_pos_emb_torchscript(qkv, cos, sin):
    """TorchScript-compiled variant of the rotary embedding formula.

    Computes ``qkv * cos + rotate_half(qkv) * sin``. The decorator compiles
    this function (and, recursively, ``rotate_half``) when the module is
    imported. It is not called by ``apply_rotary_pos_emb`` in this file.
    """
    rotated = rotate_half(qkv)
    return qkv * cos + rotated * sin
# Cache for the optional flash-attn rotary module. The import is attempted
# once, so a missing package does not cause a fresh (failing) import search
# on every call.
_FLASH_ROTARY_CACHE = {}


def _load_flash_rotary():
    """Return ``flash_attn.layers.rotary`` if importable, else ``None`` (cached)."""
    if "module" not in _FLASH_ROTARY_CACHE:
        try:
            import flash_attn.layers.rotary as flash_rotary
        except ImportError:
            flash_rotary = None
        _FLASH_ROTARY_CACHE["module"] = flash_rotary
    return _FLASH_ROTARY_CACHE["module"]


def apply_rotary_pos_emb(qkv, cos, sin):
    """Apply rotary position embedding to a packed qkv tensor.

    Uses flash-attn's ``apply_rotary_emb_qkv_`` when it is available.
    Otherwise, or when that call raises ImportError/AttributeError/RuntimeError,
    it falls back to the eager PyTorch implementation.

    ``cos``/``sin`` are indexed as 5-D tables. The flash path receives
    ``cos[0, :, 0, 0, :D // 2]`` (the first half of the last axis D); this
    assumes tables laid out like ``Rotary.forward``'s output.

    NOTE(review): the trailing underscore in ``apply_rotary_emb_qkv_`` suggests
    the flash path updates ``qkv`` in place, while the fallback returns a new
    tensor and leaves ``qkv`` untouched. Always use the return value.
    """
    apply_flash = getattr(_load_flash_rotary(), "apply_rotary_emb_qkv_", None)
    if apply_flash is not None:
        try:
            cos_flash = cos[0, :, 0, 0, : cos.shape[-1] // 2]
            sin_flash = sin[0, :, 0, 0, : sin.shape[-1] // 2]
            return apply_flash(qkv, cos_flash, sin_flash)
        except (ImportError, AttributeError, RuntimeError):
            # Deliberate best-effort: if the flash kernel cannot run for
            # these inputs, use the native implementation below instead.
            pass
    # Use native implementation without TorchScript due to compatibility issues
    return _apply_rotary_pos_emb_native(qkv, cos, sin)
|