| """Rotary positional embedding helpers.""" |
|
|
| from __future__ import annotations |
|
|
| import torch |
|
|
|
|
def _scaled_positions(seq_len: int, scaling_factor: float, device: torch.device) -> torch.Tensor:
    """Return position indices ``0..seq_len-1``, compressed when scaling is requested.

    Args:
        seq_len: Number of positions to generate.
        scaling_factor: Divisor applied uniformly to every position. Only
            values strictly greater than 1.0 have an effect.
        device: Device on which the position tensor is created.

    Returns:
        A 1-D float32 tensor of length ``seq_len``.
    """
    indices = torch.arange(seq_len, device=device, dtype=torch.float32)
    # Factors of 1.0 or below (and NaN) leave the raw positions untouched.
    if not scaling_factor > 1.0:
        return indices
    return indices / scaling_factor
|
|
|
|
def build_rope_cache(
    seq_len: int,
    head_dim: int,
    base_frequency: int = 500_000,
    scaling_factor: float = 1.0,
    device: torch.device | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Precompute cosine and sine tables for RoPE.

    Args:
        seq_len: Number of positions to tabulate.
        head_dim: Size of one attention head; must be even because features
            are rotated in pairs.
        base_frequency: Base of the geometric frequency progression; must be
            positive.
        scaling_factor: Position scaling factor forwarded to
            ``_scaled_positions``.
        device: Device for the tables. ``None`` selects the CPU.

    Returns:
        ``(cos, sin)``, each a float32 tensor of shape
        ``(seq_len, head_dim // 2)``.

    Raises:
        ValueError: If ``head_dim`` is odd or ``base_frequency`` is not
            positive.
    """
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even for RoPE.")
    # A zero or negative base produces inf/NaN inverse frequencies; fail loudly
    # instead of returning a silently corrupted table.
    if base_frequency <= 0:
        raise ValueError("base_frequency must be positive for RoPE.")
    # Compare against None explicitly: `device or ...` would also discard
    # falsy-but-valid device specs such as the legacy integer ordinal 0.
    if device is None:
        device = torch.device("cpu")
    positions = _scaled_positions(seq_len, scaling_factor, device)
    # One exponent per feature pair: 0/head_dim, 2/head_dim, 4/head_dim, ...
    exponents = torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (base_frequency ** exponents)
    # Rotation angle for every (position, feature pair) combination.
    freqs = torch.outer(positions, inv_freq)
    return torch.cos(freqs), torch.sin(freqs)
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map every adjacent pair ``(a, b)`` of the last dimension to ``(-b, a)``.

    Args:
        x: Tensor whose last dimension holds interleaved feature pairs
            ``(x[..., 0], x[..., 1]), (x[..., 2], x[..., 3]), ...``.

    Returns:
        A tensor with the same shape as ``x`` in which each pair has been
        swapped and the element moved to the front negated.
    """
    first_of_pair = x[..., 0::2]
    second_of_pair = x[..., 1::2]
    # Stack along a new trailing axis so each row is one (-b, a) pair, then
    # merge that axis back to restore the interleaved layout.
    swapped_pairs = torch.stack([-second_of_pair, first_of_pair], dim=-1)
    return swapped_pairs.flatten(start_dim=-2)
|
|
|
|
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to query and key tensors.

    The last dimension of ``q``/``k`` is treated as interleaved pairs
    ``(x[2i], x[2i + 1])``. The pair at sequence position ``p`` is rotated
    using ``cos[p, i]`` and ``sin[p, i]``.

    Args:
        q: Query tensor of shape ``(..., seq_len, head_dim)``.
        k: Key tensor with exactly the same shape as ``q``.
        cos: Cosine table of shape ``(cache_len, head_dim // 2)``, indexed by
            position along dim 0. Rows beyond ``seq_len`` are ignored.
        sin: Sine table with the same layout as ``cos``.

    Returns:
        ``(q_out, k_out)``, each with the same shape as ``q``/``k``. The
        result dtype follows the usual type promotion between the inputs and
        the tables; the tables are not moved to the inputs' device.

    Raises:
        ValueError: If ``q`` and ``k`` differ in shape, or if either table
            holds fewer than ``seq_len`` positions.
    """
    if q.shape != k.shape:
        raise ValueError("q and k must share the same shape for RoPE application.")
    seq_len = q.size(-2)
    # Slicing a too-short table would not fail by itself: a one-row table would
    # silently broadcast over every position, and other lengths would surface
    # as an opaque broadcasting error further down.
    for table_name, table in (("cos", cos), ("sin", sin)):
        if table.size(0) < seq_len:
            raise ValueError(
                f"{table_name} table covers {table.size(0)} positions "
                f"but the input has {seq_len}."
            )
    # Duplicate each column so both members of a feature pair share one angle.
    # The resulting (seq_len, head_dim) tables broadcast against any number of
    # leading dimensions, so the outputs keep the rank of the inputs.
    cos = cos[:seq_len].repeat_interleave(2, dim=-1)
    sin = sin[:seq_len].repeat_interleave(2, dim=-1)
    q_out = (q * cos) + (rotate_half(q) * sin)
    k_out = (k * cos) + (rotate_half(k) * sin)
    return q_out, k_out
|
|