# sage/model/rope.py
# sage002's picture
# feat: rewrite SAGE 1B architecture and replace legacy repo contents
# ef18673 verified
"""Rotary positional embedding helpers."""
from __future__ import annotations
import torch
def _scaled_positions(seq_len: int, scaling_factor: float, device: torch.device) -> torch.Tensor:
"""Apply a simple YaRN-style position scaling factor."""
positions = torch.arange(seq_len, device=device, dtype=torch.float32)
if scaling_factor > 1.0:
positions = positions / scaling_factor
return positions
def build_rope_cache(
    seq_len: int,
    head_dim: int,
    base_frequency: int = 500_000,
    scaling_factor: float = 1.0,
    device: torch.device | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the cosine and sine lookup tables used by RoPE.

    Args:
        seq_len: Number of positions to precompute.
        head_dim: Per-head feature size; must be even because features are
            rotated in pairs.
        base_frequency: Base of the geometric frequency progression.
        scaling_factor: Position scaling forwarded to ``_scaled_positions``.
        device: Device for the tables; falls back to CPU when not given.

    Returns:
        ``(cos, sin)``, each a float32 tensor of shape
        ``(seq_len, head_dim // 2)``.

    Raises:
        ValueError: If ``head_dim`` is odd.
    """
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even for RoPE.")

    target = device if device else torch.device("cpu")
    positions = _scaled_positions(seq_len, scaling_factor, target)

    # One inverse frequency per feature pair: base ** (-2i / head_dim).
    exponents = torch.arange(0, head_dim, 2, device=target, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (base_frequency ** exponents)

    # angles[p, i] = position p * inverse frequency i
    angles = torch.outer(positions, inv_freq)
    return torch.cos(angles), torch.sin(angles)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map each adjacent pair ``(a, b)`` of the last dimension to ``(-b, a)``.

    The last dimension is treated as interleaved pairs
    ``[a0, b0, a1, b1, ...]`` and the result has the same shape as ``x``.

    Args:
        x: Tensor whose last dimension has even length.

    Returns:
        A tensor shaped like ``x`` holding ``[-b0, a0, -b1, a1, ...]``.
    """
    first = x[..., 0::2]
    second = x[..., 1::2]
    # Stacking on a new trailing axis and flattening re-interleaves the pairs.
    paired = torch.stack([second.neg(), first], dim=-1)
    return torch.flatten(paired, start_dim=-2)
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to query and key tensors.

    The first ``seq_len`` rows of the tables are used, where ``seq_len`` is
    the second-to-last dimension of ``q``; positions therefore always start
    at 0. Each table is reshaped to ``(1, 1, seq_len, 2 * cos.size(-1))`` and
    broadcast against ``q`` and ``k``.

    Args:
        q: Query tensor; second-to-last dim is the sequence, last dim is the
            head dimension (presumably ``(batch, heads, seq, head_dim)`` --
            verify against callers).
        k: Key tensor with exactly the same shape as ``q``.
        cos: Cosine table indexed by position along dim 0, with one column
            per feature pair (as returned by ``build_rope_cache``).
        sin: Sine table laid out like ``cos``.

    Returns:
        ``(q_out, k_out)`` with the rotation applied.

    Raises:
        ValueError: If ``q`` and ``k`` differ in shape, or if the tables hold
            fewer positions than the input sequence length.
    """
    if q.shape != k.shape:
        raise ValueError("q and k must share the same shape for RoPE application.")

    seq_len = q.size(-2)
    cached = min(cos.size(0), sin.size(0))
    if cached < seq_len:
        # Slicing a too-short table silently truncates; a single-row table
        # would then broadcast position 0 over every token without an error.
        raise ValueError(
            f"RoPE cache holds {cached} positions but the input sequence length is {seq_len}."
        )

    # Duplicate every frequency column so it lines up with the interleaved
    # (even, odd) feature pairs that rotate_half operates on.
    cos = cos[:seq_len].unsqueeze(0).unsqueeze(0).repeat_interleave(2, dim=-1)
    sin = sin[:seq_len].unsqueeze(0).unsqueeze(0).repeat_interleave(2, dim=-1)

    # NOTE(review): no dtype cast is applied here; the output dtype follows
    # PyTorch type promotion between q/k and the tables.
    q_out = (q * cos) + (rotate_half(q) * sin)
    k_out = (k * cos) + (rotate_half(k) * sin)
    return q_out, k_out