Sentence Similarity
ONNX
Safetensors
English
ogma
embeddings
dense-retrieval
matryoshka
rag
agents
mteb
semantic-search
text-embeddings
text-embedding
vector-search
document-retrieval
similarity-search
classification
clustering
edge-ai
on-device
local-inference
efficient-ai
rag-retrieval
custom_code
Eval Results (legacy)
| """Token embeddings, task token embeddings, and RoPE for Ogma.""" | |
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
| from .config import OgmaConfig | |
| __all__ = ["TokenEmbedding", "RotaryPositionalEncoding"] | |
class TokenEmbedding(nn.Module):
    """Token embedding with optional linear projection.

    Holds a ``(vocab_size + n_special_tokens, d_embed)`` embedding table and
    projects looked-up vectors to ``d_model`` (the projection is an identity
    when the two widths already match). Also owns 3 learnable task token
    embeddings ([QRY], [DOC], [SYM]) that are stored directly at ``d_model``.
    """

    def __init__(self, config: OgmaConfig) -> None:
        """Build the embedding table, projection and task-token table.

        Args:
            config: Model configuration. This class reads ``vocab_size``,
                ``n_special_tokens``, ``d_embed``, ``d_model``, ``pad_id``
                and (in ``forward``) ``qry_id`` from it.
        """
        super().__init__()
        self.config = config
        # Special tokens share the table with the regular vocabulary, so the
        # table has room for both.
        self.embed = nn.Embedding(
            config.vocab_size + config.n_special_tokens,
            config.d_embed,
            padding_idx=config.pad_id,
        )
        if config.d_embed != config.d_model:
            self.proj = nn.Linear(config.d_embed, config.d_model)
        else:
            self.proj = nn.Identity()  # type: ignore[assignment]
        # Task token embeddings are learned separately at d_model
        self.task_tokens = nn.Embedding(3, config.d_model)

    def forward(
        self,
        token_ids: torch.Tensor,
        task_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed tokens and prepend task token.

        Args:
            token_ids: (B, S) token IDs.
            task_token_ids: (B,) task token IDs (4=QRY, 5=DOC, 6=SYM).

        Returns:
            (B, S+1, d_model) embeddings with task token prepended.
        """
        # Embed and project regular tokens
        x = self.embed(token_ids)  # (B, S, d_embed)
        x = self.proj(x)  # (B, S, d_model)
        # Task IDs are offset by config.qry_id so that the first task token
        # lands on row 0 of the 3-row task table (map 4,5,6 -> 0,1,2).
        task_idx = task_token_ids - self.config.qry_id  # (B,)
        task_emb = self.task_tokens(task_idx)  # (B, d_model)
        task_emb = task_emb.unsqueeze(1)  # (B, 1, d_model)
        # Prepend task token
        return torch.cat([task_emb, x], dim=1)  # (B, S+1, d_model)

    def load_pretrained_embeddings(
        self, embeddings: torch.Tensor
    ) -> None:
        """Load pre-computed token embeddings (e.g., from teacher PCA).

        Rows are written after the special-token rows, i.e. starting at row
        ``config.n_special_tokens``. At most ``config.vocab_size`` rows are
        copied; any extra rows in ``embeddings`` are ignored.

        Args:
            embeddings: (vocab_size, d_embed) tensor.

        Raises:
            ValueError: If ``embeddings`` is not 2-D or its second dimension
                differs from the embedding table width.
        """
        # Reject mis-shaped input up front: slice assignment would otherwise
        # broadcast e.g. a (V, 1) tensor across every embedding column and
        # silently corrupt the table.
        width = self.embed.embedding_dim
        if embeddings.dim() != 2 or embeddings.shape[1] != width:
            raise ValueError(
                "embeddings must have shape (n, "
                f"{width}), got {tuple(embeddings.shape)}"
            )
        with torch.no_grad():
            n = min(embeddings.shape[0], self.config.vocab_size)
            start = self.config.n_special_tokens
            self.embed.weight[start : n + start] = embeddings[:n]
class RotaryPositionalEncoding(nn.Module):
    """Rotary Position Embedding (RoPE). Zero trainable parameters.

    Stores the inverse frequencies as a persistent buffer and keeps cosine /
    sine tables for positions ``0 .. max_seq_len - 1`` as non-persistent
    buffers. The tables are rebuilt on demand when a longer sequence arrives.
    """

    def __init__(self, dim: int, max_seq_len: int = 512) -> None:
        """Create the frequency buffer and the initial cos/sin cache.

        Args:
            dim: Width being rotated; one frequency is generated for every
                second index in ``range(0, dim, 2)``.
            max_seq_len: Number of positions to precompute.
        """
        super().__init__()
        inv_freq = 1.0 / (
            10000.0 ** (torch.arange(0, dim, 2).float() / dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int) -> None:
        """(Re)compute the cos/sin tables for ``seq_len`` positions."""
        inv_freq: torch.Tensor = self.inv_freq  # type: ignore[assignment]
        # Build positions on the same device as inv_freq: after the module
        # has been moved off the CPU, a CPU-side arange would make the outer
        # product below fail with a device mismatch.
        t = torch.arange(
            seq_len, dtype=inv_freq.dtype, device=inv_freq.device
        )
        freqs = torch.outer(t, inv_freq)  # (seq_len, len(inv_freq))
        cos_cached = freqs.cos()
        sin_cached = freqs.sin()
        # Non-persistent: the tables are derived data and stay out of the
        # state dict.
        self.register_buffer("cos_cached", cos_cached, persistent=False)
        self.register_buffer("sin_cached", sin_cached, persistent=False)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return cos and sin for sequence length of x.

        Args:
            x: (B, S, ...) tensor to determine sequence length.

        Returns:
            Tuple of (cos, sin) each of shape (S, d_head//2).
        """
        seq_len = x.shape[1]
        cos: torch.Tensor = self.cos_cached  # type: ignore[assignment]
        sin: torch.Tensor = self.sin_cached  # type: ignore[assignment]
        # Rebuild when the cache is too short, or when the slice we are about
        # to return contains non-finite values.
        # NOTE(review): the finiteness check presumably guards against
        # non-persistent buffers left uninitialized after loading - confirm.
        if seq_len > cos.shape[0] or not torch.isfinite(cos[:seq_len]).all():
            self._build_cache(max(seq_len, cos.shape[0]))
            cos = self.cos_cached  # type: ignore[assignment]
            sin = self.sin_cached  # type: ignore[assignment]
        return cos[:seq_len], sin[:seq_len]
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors with precomputed rotary tables.

    The last dimension of each input is split into a lower and an upper
    half; every (lower, upper) pair is rotated by the angle whose cosine and
    sine are given for that position.

    Args:
        q: (B, n_heads, S, d_head) query tensor.
        k: (B, n_heads, S, d_head) key tensor.
        cos: (S, d_head//2) cosine cache.
        sin: (S, d_head//2) sine cache.

    Returns:
        Rotated (q, k) tensors.
    """
    # Two leading singleton axes so the tables broadcast over batch and heads.
    cos_b = cos[None, None]  # (1, 1, S, d_head//2)
    sin_b = sin[None, None]

    rotated = []
    for tensor in (q, k):
        split_at = tensor.shape[-1] // 2
        lower = tensor[..., :split_at]
        upper = tensor[..., split_at:]
        first_half = lower * cos_b - upper * sin_b
        second_half = upper * cos_b + lower * sin_b
        rotated.append(torch.cat([first_half, second_half], dim=-1))

    q_rot, k_rot = rotated
    return q_rot, k_rot