"""AAM Diffusion LLM — Rotary Position Encoding (RoPE).

Implements Rotary Position Encoding from Su et al. (2021), "RoFormer:
Enhanced Transformer with Rotary Position Embedding".
RoPE tends to generalize to longer sequences better than learned absolute
positional embeddings. It is applied to the query and key tensors inside the
attention computation rather than added to the token embeddings.
"""
from __future__ import annotations
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
class RotaryPositionEncoding(nn.Module):
    """Rotary Position Encoding (RoPE).

    Applies rotary embeddings to query and key tensors before attention
    computation. The last dimension of each input is split into two halves
    ``(x1, x2)`` and every pair ``(x1[i], x2[i])`` is rotated by the angle
    ``position * base ** (-2 * i / d)``, where ``d`` is the input's last
    dimension. Because this is a true rotation, the dot product of a query
    rotated at position ``m`` with a key rotated at position ``n`` depends on
    the positions only through ``m - n``.

    Args:
        d_model: Dimension the precomputed cos/sin cache is built for. Inputs
            whose last dimension equals ``d_model`` use the cache; inputs with
            a different (even) last dimension get angles computed on the fly
            for their own dimension.
        max_seq_len: Number of positions precomputed at construction time.
            The cache is grown automatically when a longer range is needed.
        base: Frequency base of the geometric angle schedule.
    """

    def __init__(self, d_model: int, max_seq_len: int = 8192, base: float = 10000.0) -> None:
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.base = base
        # One inverse frequency per rotated pair: base ** (-2i / d_model).
        inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2, dtype=torch.float32) / d_model))
        # Non-persistent: fully derived from the constructor arguments.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Precompute cos/sin for max_seq_len positions.
        self._precompute_cache(max_seq_len)

    def _precompute_cache(self, seq_len: int) -> None:
        """(Re)build ``cos_cached``/``sin_cached`` for positions ``0 .. seq_len - 1``.

        Both tables have shape ``(seq_len, 2 * len(inv_freq))`` and are laid
        out as ``cat([angles, angles])``.
        """
        inv_freq = self.inv_freq.to(torch.float32)
        # Build on inv_freq's device: after ``.to(device)`` a CPU ``arange``
        # would make ``torch.outer`` mix devices when the cache is grown.
        t = torch.arange(seq_len, dtype=torch.float32, device=inv_freq.device)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def _compute_cos_sin(
        self, d: int, start: int, end: int, device: torch.device
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute uncached cos/sin tables of width ``d`` for positions ``start .. end - 1``."""
        inv_freq = 1.0 / (self.base ** (torch.arange(0, d, 2, dtype=torch.float32, device=device) / d))
        t = torch.arange(start, end, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        return emb.cos(), emb.sin()

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_len: Optional[int] = None,
        offset: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to query and key.

        Args:
            q: Query tensor (batch, n_heads, seq_len, d_head)
            k: Key tensor (batch, n_heads, seq_len, d_head)
            seq_len: Sequence length (inferred from ``q.shape[2]`` if None)
            offset: Position of the first token (for KV cache)

        Returns:
            Tuple (rotated_q, rotated_k), each with the dtype of its input.

        Raises:
            ValueError: If the head dimension ``q.shape[-1]`` is odd.
        """
        if seq_len is None:
            seq_len = q.shape[2]
        d = q.shape[-1]
        if d % 2:
            raise ValueError(f"RoPE requires an even head dimension, got {d}")
        end = offset + seq_len

        if d == self.d_model:
            # Compare against the cache's real length (not max_seq_len) so a
            # cache that was already grown is reused instead of rebuilt.
            if end > self.cos_cached.shape[0]:
                self._precompute_cache(end)
            cos = self.cos_cached[offset:end]
            sin = self.sin_cached[offset:end]
        else:
            # Head dim differs from the cached width: derive the angle
            # schedule from the actual head dim so each pair is a proper rotation.
            cos, sin = self._compute_cos_sin(d, offset, end, q.device)

        cos = cos.unsqueeze(0).unsqueeze(0)  # (1, 1, seq, d)
        sin = sin.unsqueeze(0).unsqueeze(0)
        q_rot = self._apply_rotation(q, cos, sin)
        k_rot = self._apply_rotation(k, cos, sin)
        return q_rot, k_rot

    @staticmethod
    def _apply_rotation(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Rotate pairs ``(x1[i], x2[i])`` of the two halves of ``x``'s last dim.

        ``cos``/``sin`` are expected in ``cat([angles, angles])`` layout (at
        least ``x.shape[-1]`` wide). Only their first half is read, so both
        members of a pair always share one angle. The result is cast back to
        ``x.dtype`` so half/bfloat16 inputs are not silently upcast.
        """
        half = x.shape[-1] // 2
        x1 = x[..., :half]
        x2 = x[..., half:]
        c = cos[..., :half]
        s = sin[..., :half]
        rotated = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)
        return rotated.to(x.dtype)
def apply_rope_to_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    d_model: int,
    seq_len: int,
    offset: int = 0,
    device: Optional[torch.device] = None,
    base: float = 10000.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Functional RoPE application — use when you don't want a module.

    Angles are derived from the actual head dimension ``d_head = q.shape[-1]``:
    pair ``(x1[i], x2[i])`` at position ``p`` is rotated by
    ``p * base ** (-2 * i / d_head)``.

    Args:
        q: Query tensor (batch, n_heads, seq_len, d_head)
        k: Key tensor (batch, n_heads, seq_len, d_head)
        d_model: Unused; kept for backward compatibility (frequencies come
            from the head dimension of ``q``).
        seq_len: Number of positions to encode; must match the sequence axis
            of ``q``/``k`` for broadcasting.
        offset: Position of the first token (e.g. for a KV cache).
        device: Device to build the cos/sin tables on; defaults to ``q.device``.
        base: Frequency base of the angle schedule.

    Returns:
        Tuple (rotated_q, rotated_k), each with the dtype of its input.

    Raises:
        ValueError: If the head dimension is odd.
    """
    d_head = q.shape[-1]
    if d_head % 2:
        raise ValueError(f"RoPE requires an even head dimension, got {d_head}")
    if device is None:
        # Build the tables where the inputs live, so GPU tensors work without
        # the caller having to pass a device explicitly.
        device = q.device
    half = d_head // 2

    inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2, dtype=torch.float32, device=device) / d_head))
    positions = torch.arange(offset, offset + seq_len, dtype=torch.float32, device=device)
    angles = torch.outer(positions, inv_freq)  # (seq_len, d_head // 2)
    # Both halves of a pair share one angle, so a half-width table suffices.
    cos = angles.cos().unsqueeze(0).unsqueeze(0)
    sin = angles.sin().unsqueeze(0).unsqueeze(0)

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # Rotate each (x1[i], x2[i]) pair; cast back so fp16/bf16 stay as-is.
        x1, x2 = x[..., :half], x[..., half:]
        out = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        return out.to(x.dtype)

    return _rotate(q), _rotate(k)
|