"""Token embeddings, task token embeddings, and RoPE for Ogma.""" from __future__ import annotations import torch import torch.nn as nn from .config import OgmaConfig __all__ = ["TokenEmbedding", "RotaryPositionalEncoding"] class TokenEmbedding(nn.Module): """Token embedding with optional linear projection. Loads a vocab_size x d_embed embedding table and projects to d_model. Includes 3 learnable task token embeddings ([QRY], [DOC], [SYM]). """ def __init__(self, config: OgmaConfig) -> None: super().__init__() self.config = config self.embed = nn.Embedding( config.vocab_size + config.n_special_tokens, config.d_embed, padding_idx=config.pad_id, ) if config.d_embed != config.d_model: self.proj = nn.Linear(config.d_embed, config.d_model) else: self.proj = nn.Identity() # type: ignore[assignment] # Task token embeddings are learned separately at d_model self.task_tokens = nn.Embedding(3, config.d_model) def forward( self, token_ids: torch.Tensor, task_token_ids: torch.Tensor, ) -> torch.Tensor: """Embed tokens and prepend task token. Args: token_ids: (B, S) token IDs. task_token_ids: (B,) task token IDs (4=QRY, 5=DOC, 6=SYM). Returns: (B, S+1, d_model) embeddings with task token prepended. """ # Embed and project regular tokens x = self.embed(token_ids) # (B, S, d_embed) x = self.proj(x) # (B, S, d_model) # Get task token embeddings (map 4,5,6 -> 0,1,2) task_idx = task_token_ids - self.config.qry_id # (B,) task_emb = self.task_tokens(task_idx) # (B, d_model) task_emb = task_emb.unsqueeze(1) # (B, 1, d_model) # Prepend task token return torch.cat([task_emb, x], dim=1) # (B, S+1, d_model) def load_pretrained_embeddings( self, embeddings: torch.Tensor ) -> None: """Load pre-computed token embeddings (e.g., from teacher PCA). Args: embeddings: (vocab_size, d_embed) tensor. """ with torch.no_grad(): n = min(embeddings.shape[0], self.config.vocab_size) start = self.config.n_special_tokens self.embed.weight[start : n + start] = embeddings[:n] class RotaryPositionalEncoding(nn.Module): """Rotary Position Embedding (RoPE). Zero trainable parameters.""" def __init__(self, dim: int, max_seq_len: int = 512) -> None: super().__init__() inv_freq = 1.0 / ( 10000.0 ** (torch.arange(0, dim, 2).float() / dim) ) self.register_buffer("inv_freq", inv_freq) self._build_cache(max_seq_len) def _build_cache(self, seq_len: int) -> None: inv_freq: torch.Tensor = self.inv_freq # type: ignore[assignment] t = torch.arange(seq_len, dtype=inv_freq.dtype) freqs = torch.outer(t, inv_freq) cos_cached = freqs.cos() sin_cached = freqs.sin() self.register_buffer("cos_cached", cos_cached, persistent=False) self.register_buffer("sin_cached", sin_cached, persistent=False) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Return cos and sin for sequence length of x. Args: x: (B, S, ...) tensor to determine sequence length. Returns: Tuple of (cos, sin) each of shape (S, d_head//2). """ seq_len = x.shape[1] cos: torch.Tensor = self.cos_cached # type: ignore[assignment] sin: torch.Tensor = self.sin_cached # type: ignore[assignment] if seq_len > cos.shape[0] or not torch.isfinite(cos[:seq_len]).all(): self._build_cache(max(seq_len, cos.shape[0])) cos = self.cos_cached # type: ignore[assignment] sin = self.sin_cached # type: ignore[assignment] return cos[:seq_len], sin[:seq_len] def apply_rope( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """Apply rotary embeddings to query and key tensors. Args: q: (B, n_heads, S, d_head) query tensor. k: (B, n_heads, S, d_head) key tensor. cos: (S, d_head//2) cosine cache. sin: (S, d_head//2) sine cache. Returns: Rotated (q, k) tensors. """ def _rotate(x: torch.Tensor) -> torch.Tensor: x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] cos_exp = cos.unsqueeze(0).unsqueeze(0) # (1, 1, S, d_head//2) sin_exp = sin.unsqueeze(0).unsqueeze(0) return torch.cat( [x1 * cos_exp - x2 * sin_exp, x2 * cos_exp + x1 * sin_exp], dim=-1, ) return _rotate(q), _rotate(k)