| import torch |
|
|
|
|
class Rotary(torch.nn.Module):
    """Rotary position embedding (RoPE) cos/sin tables for a packed qkv stack.

    ``forward`` returns ``(cos, sin)`` tables of shape ``(1, seq_len, 3, 1, D)``,
    where ``D == dim`` for even ``dim``. In slot 2 of the third axis,
    ``cos = 1`` and ``sin = 0``. That makes the rotation an identity on the
    third stacked component, presumably ``v`` in a ``(q, k, v)`` stack; confirm
    against callers.

    Args:
        dim: Size of the rotated feature axis. Assumed even: an odd value yields
            tables ``dim + 1`` wide, because each frequency is stored twice.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim, base=10_000):
        super().__init__()
        # One inverse frequency per channel pair: base ** (-2i / dim).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Lazily built tables. They are rebuilt when the sequence length or the
        # device of the input changes.
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        """Return ``(cos, sin)`` tables for the sequence length of ``x``.

        Args:
            x: Any tensor whose ``seq_dim`` axis is the sequence length. Only
                its shape and device are read.
            seq_dim: Axis of ``x`` holding the sequence length.

        Returns:
            Tuple ``(cos, sin)`` of tensors on ``x.device``. The same cached
            tensor objects are returned until the sequence length or device
            changes, so callers must not modify them in place.
        """
        seq_len = x.shape[seq_dim]
        # The tables are plain attributes, not registered buffers, so
        # Module.to() does not move them. Check the device as well as the
        # length, or stale tables on the old device would be returned.
        stale = (
            seq_len != self.seq_len_cached
            or self.cos_cached is None
            or self.cos_cached.device != x.device
        )
        if stale:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
            # Each frequency appears twice, matching the half-split rotation.
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            # Lay out as (1, seq_len, 3, 1, D) and copy across the 3 slots.
            self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
            self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)

            # cos = 1, sin = 0 turns the rotation into an identity on slot 2.
            self.cos_cached[:, :, 2, :, :].fill_(1.0)
            self.sin_cached[:, :, 2, :, :].fill_(0.0)

        return self.cos_cached, self.sin_cached
|
|
|
|
def rotate_half(x):
    """Swap the two halves of the last axis, negating the half moved to the front.

    The last axis is split at ``size // 2`` into ``[a, b]``, and the result is
    ``[-b, a]``. For an odd size, the second half is the longer one.
    """
    split = x.shape[-1] // 2
    front = x[..., :split]
    back = x[..., split:]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
def _apply_rotary_pos_emb_native(qkv, cos, sin):
    """Apply the rotary embedding in eager PyTorch, without TorchScript or flash-attn.

    Computes ``qkv * cos + rotate_half(qkv) * sin`` elementwise, with
    broadcasting. All operations are out-of-place, so a new tensor is returned
    and ``qkv`` is left untouched.
    """
    unrotated = qkv * cos
    rotated = rotate_half(qkv) * sin
    return unrotated + rotated
|
|
|
|
@torch.jit.script
def _apply_rotary_pos_emb_torchscript(qkv, cos, sin):
    """TorchScript-compiled variant of the eager rotary application.

    It uses the same formula, ``qkv * cos + rotate_half(qkv) * sin``. The
    decorator compiles this function, and the ``rotate_half`` it calls, when
    the module is imported.
    """
    scaled = qkv * cos
    return scaled + rotate_half(qkv) * sin
|
|
|
|
def apply_rotary_pos_emb(qkv, cos, sin):
    """Apply rotary position embedding to a packed qkv tensor.

    The fused flash-attn kernel ``apply_rotary_emb_qkv_`` is tried only when
    ``qkv`` is on a CUDA device and the kernel is importable. Otherwise, or if
    the kernel raises ``RuntimeError``, the eager implementation is used.

    Args:
        qkv: Packed tensor, presumably ``(batch, seq_len, 3, heads, dim)``.
            Confirm with callers.
        cos, sin: 5-D tables, presumably those built by ``Rotary.forward``.
            The flash path reads only ``[0, :, 0, 0, :D // 2]`` of them.

    Returns:
        The rotated tensor.

    NOTE(review): judging by the trailing underscore, the flash-attn path
    presumably rotates ``qkv`` in place and returns it. The eager path always
    returns a new tensor. Callers should not rely on either aliasing behavior.
    """
    # The flash-attn rotary kernel targets CUDA devices. Attempting it on
    # CPU/other tensors can fail with errors not caught below, so those tensors
    # go straight to the eager path.
    if qkv.is_cuda:
        try:
            from flash_attn.layers.rotary import apply_rotary_emb_qkv_
        except (ImportError, AttributeError):
            apply_rotary_emb_qkv_ = None

        if apply_rotary_emb_qkv_ is not None:
            # The flash kernel takes (seq_len, D / 2) tables: one entry per
            # frequency, with no qkv/head axes. It receives no slot-2 table,
            # so presumably the kernel itself leaves the third component
            # alone; verify.
            half = cos.shape[-1] // 2
            cos_flash = cos[0, :, 0, 0, :half]
            sin_flash = sin[0, :, 0, 0, :half]
            try:
                return apply_rotary_emb_qkv_(qkv, cos_flash, sin_flash)
            except RuntimeError:
                # Kept from the original behavior: fall back if the fused
                # kernel cannot run, e.g. an unsupported GPU or build.
                pass

    return _apply_rotary_pos_emb_native(qkv, cos, sin)
|
|