# ogma-micro / embeddings.py
# Antreas's picture
# Enable AutoModel loading
# 08ab6b6 verified
"""Token embeddings, task token embeddings, and RoPE for Ogma."""
from __future__ import annotations
import torch
import torch.nn as nn
from .config import OgmaConfig
# Public API for star-imports. NOTE(review): ``apply_rope`` is a public
# module-level function but is not listed here, so ``from ... import *``
# will not export it (explicit imports still work) — confirm intent.
__all__ = ["TokenEmbedding", "RotaryPositionalEncoding"]
class TokenEmbedding(nn.Module):
    """Token embedding table, optional projection, and prepended task token.

    The main table has ``config.vocab_size + config.n_special_tokens`` rows of
    width ``config.d_embed``; looked-up rows are mapped to ``config.d_model``
    by a linear layer, or passed through unchanged when the widths match.
    A separate 3-row table of width ``config.d_model`` holds the learnable
    task tokens ([QRY], [DOC], [SYM]); one of them is prepended per sequence.
    """

    def __init__(self, config: OgmaConfig) -> None:
        super().__init__()
        self.config = config
        table_rows = config.vocab_size + config.n_special_tokens
        self.embed = nn.Embedding(
            table_rows, config.d_embed, padding_idx=config.pad_id
        )
        # A projection is only needed when the table width differs from d_model.
        needs_projection = config.d_embed != config.d_model
        self.proj = (
            nn.Linear(config.d_embed, config.d_model)
            if needs_projection
            else nn.Identity()
        )
        # Task tokens live directly in d_model space, so they bypass ``proj``.
        self.task_tokens = nn.Embedding(3, config.d_model)

    def forward(
        self,
        token_ids: torch.Tensor,
        task_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed a batch of token ids and prepend each sequence's task token.

        Args:
            token_ids: (B, S) token ids into the main embedding table.
            task_token_ids: (B,) task token ids (4=QRY, 5=DOC, 6=SYM). They are
                offset by ``config.qry_id`` so that id selects task row 0.

        Returns:
            (B, S+1, d_model) tensor: the task embedding at position 0,
            followed by the projected token embeddings.
        """
        tokens = self.proj(self.embed(token_ids))  # (B, S, d_model)
        task_rows = self.task_tokens(task_token_ids - self.config.qry_id)
        # (B, d_model) -> (B, 1, d_model) so it can be concatenated on dim 1.
        return torch.cat([task_rows[:, None], tokens], dim=1)

    def load_pretrained_embeddings(
        self, embeddings: torch.Tensor
    ) -> None:
        """Copy pre-computed token vectors (e.g. teacher PCA) into the table.

        Rows are written starting at index ``config.n_special_tokens``, so the
        special-token rows are left untouched. At most ``config.vocab_size``
        rows are copied; any extra input rows are ignored. A width other than
        ``d_embed`` makes torch raise a shape-mismatch error.

        Args:
            embeddings: (vocab_size, d_embed) tensor.
        """
        offset = self.config.n_special_tokens
        count = min(embeddings.shape[0], self.config.vocab_size)
        with torch.no_grad():
            self.embed.weight[offset : offset + count] = embeddings[:count]
class RotaryPositionalEncoding(nn.Module):
    """Rotary Position Embedding (RoPE) cos/sin cache. Zero trainable parameters.

    ``inv_freq`` is a persistent buffer (saved in the state dict). The
    ``cos_cached`` / ``sin_cached`` tables are non-persistent and are rebuilt
    from ``inv_freq`` whenever a longer sequence is requested or the cached
    values are found to be non-finite.
    """

    def __init__(
        self, dim: int, max_seq_len: int = 512, base: float = 10000.0
    ) -> None:
        """Precompute inverse frequencies and the initial cos/sin cache.

        Args:
            dim: Size of the rotated feature dimension (presumably d_head —
                confirm against the attention module). One frequency is made
                per pair of channels, so ``dim`` is expected to be even.
            max_seq_len: Number of positions to precompute up front.
            base: Frequency base of the geometric progression (default
                10000.0, the value previously hard-coded).
        """
        super().__init__()
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2).float() / dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int) -> None:
        """(Re)register ``cos_cached`` / ``sin_cached`` for ``seq_len`` positions."""
        inv_freq: torch.Tensor = self.inv_freq  # type: ignore[assignment]
        # Compute in at least float32: after ``.half()`` / ``.to(bfloat16)``
        # integer positions would otherwise be rounded (bf16 is exact only up
        # to 256, fp16 up to 2048). Create positions on inv_freq's device so a
        # rebuild on an accelerator-resident module does not mix devices.
        compute_dtype = torch.promote_types(inv_freq.dtype, torch.float32)
        t = torch.arange(seq_len, device=inv_freq.device, dtype=compute_dtype)
        freqs = torch.outer(t, inv_freq.to(compute_dtype))  # (seq_len, dim//2)
        # Store in the buffer dtype, matching what ``.to(dtype)`` would produce.
        cos_cached = freqs.cos().to(inv_freq.dtype)
        sin_cached = freqs.sin().to(inv_freq.dtype)
        # Re-registering an existing buffer name replaces it in place.
        self.register_buffer("cos_cached", cos_cached, persistent=False)
        self.register_buffer("sin_cached", sin_cached, persistent=False)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return cos and sin tables for the sequence length of ``x``.

        Args:
            x: Tensor whose dimension 1 is the sequence length, e.g.
                (B, S, ...); only its shape is read.

        Returns:
            Tuple of (cos, sin), each of shape (S, dim // 2), with the dtype
            and device of the cached buffers.
        """
        seq_len = x.shape[1]
        cos: torch.Tensor = self.cos_cached  # type: ignore[assignment]
        sin: torch.Tensor = self.sin_cached  # type: ignore[assignment]
        # Rebuild when the request outgrows the cache, or when cached values
        # are non-finite (presumably uninitialised non-persistent buffers after
        # deferred model construction, e.g. AutoModel loading — confirm).
        # NOTE: evaluating the tensor as a bool forces a host/device sync.
        if seq_len > cos.shape[0] or not torch.isfinite(cos[:seq_len]).all():
            self._build_cache(max(seq_len, cos.shape[0]))
            cos = self.cos_cached  # type: ignore[assignment]
            sin = self.sin_cached  # type: ignore[assignment]
        return cos[:seq_len], sin[:seq_len]
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors with rotary position embeddings.

    Uses the "rotate-half" layout: the last dimension of each input is split
    into a first half ``a`` and a second half ``b``, and every pair
    ``(a_i, b_i)`` is rotated by the angle whose cosine/sine are supplied,
    giving ``[a*cos - b*sin, b*cos + a*sin]``.

    Args:
        q: (B, n_heads, S, d_head) query tensor.
        k: (B, n_heads, S, d_head) key tensor.
        cos: (S, d_head//2) cosine table.
        sin: (S, d_head//2) sine table.

    Returns:
        ``(q_rot, k_rot)``, rotated copies of ``q`` and ``k``.
    """
    # Prepend two singleton axes: (S, d_head//2) -> (1, 1, S, d_head//2),
    # so the tables broadcast over the batch and head dimensions.
    cos_b = cos[None, None]
    sin_b = sin[None, None]

    def _rot(t: torch.Tensor) -> torch.Tensor:
        half = t.shape[-1] // 2
        a, b = t[..., :half], t[..., half:]
        return torch.cat((a * cos_b - b * sin_b, b * cos_b + a * sin_b), dim=-1)

    return _rot(q), _rot(k)