| """Rotary Position Embedding (RoPE) implementation."""
|
|
|
| import torch
|
| import torch.nn as nn
|
| import math
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary position embeddings (RoPE).

    Precomputes per-dimension inverse frequencies as a geometric series and,
    on ``forward``, produces an angle table of shape ``(seq_len, dim)`` that
    callers combine with ``cos``/``sin`` to rotate query/key vectors.
    """

    def __init__(self, dim, scale=40, base=10000):
        """Initialize the embedding table parameters.

        Args:
            dim: Embedding dimension; must be even (frequencies are computed
                for ``dim // 2`` pairs and then duplicated).
            scale: Positions are divided by this factor before computing
                angles (position-interpolation style stretching).
                NOTE(review): the default of 40 looks project-specific —
                confirm against training configuration.
            base: Base of the frequency geometric progression. The default
                10000 matches the original RoPE formulation.

        Raises:
            ValueError: If ``dim`` is odd.
        """
        super().__init__()
        # Raise instead of assert: asserts are stripped under `python -O`,
        # which would silently allow an odd dim and break chunk-based rotation.
        if dim % 2 != 0:
            raise ValueError(f"Dimension must be even for rotary embeddings, got {dim}")
        self.dim = dim
        self.scale = scale

        # inv_freq[j] = base^(-2j/dim) for j in [0, dim/2): a geometric
        # progression of rotation frequencies over half the dimensions.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len, device):
        """Generate rotary angles for a sequence.

        Args:
            seq_len: Number of positions to generate.
            device: Device on which to allocate the position indices.

        Returns:
            Tensor of shape ``(seq_len, dim)``: the outer product of scaled
            positions and inverse frequencies, duplicated along the last dim
            so it aligns with the (first-half, second-half) rotation layout.
        """
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq) / self.scale
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)
|
|
|
|
|
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second.

    Maps ``(a, b)`` laid out along the final axis to ``(-b, a)``, the
    90-degree rotation used by rotary position embeddings.
    """
    halves = x.chunk(2, dim=-1)
    return torch.cat((halves[1].neg(), halves[0]), dim=-1)
|
|
|
|
|
def apply_rotary(x, cos, sin):
    """Apply rotary position embeddings to ``x``.

    Only the leading ``min(x_width, cos_width)`` features of the last
    dimension are rotated; any remaining features pass through unchanged
    (partial-rotary layout). ``cos``/``sin`` must broadcast against the
    rotated slice of ``x``.
    """
    # Width of the rotated region: cos/sin may cover fewer (partial rotary)
    # or more (shared cache) dims than x.
    width = min(x.shape[-1], cos.shape[-1])
    cos_w = cos[..., :width]
    sin_w = sin[..., :width]

    rotated = x[..., :width]
    passthrough = x[..., width:]

    # Inline rotate-half: (a, b) -> (-b, a) on the rotated slice.
    front, back = rotated.chunk(2, dim=-1)
    spun = torch.cat((-back, front), dim=-1)
    rotated = rotated * cos_w + spun * sin_w

    if passthrough.shape[-1] == 0:
        return rotated
    return torch.cat([rotated, passthrough], dim=-1)