Sentence Similarity
ONNX
Safetensors
English
ogma
embeddings
dense-retrieval
matryoshka
rag
agents
mteb
semantic-search
text-embeddings
text-embedding
vector-search
document-retrieval
similarity-search
classification
clustering
edge-ai
on-device
local-inference
efficient-ai
rag-retrieval
custom_code
Eval Results (legacy)
File size: 4,878 Bytes
"""Token embeddings, task token embeddings, and RoPE for Ogma."""
from __future__ import annotations
import torch
import torch.nn as nn
from .config import OgmaConfig
__all__ = ["TokenEmbedding", "RotaryPositionalEncoding"]
class TokenEmbedding(nn.Module):
    """Vocabulary lookup with an optional projection and a task-token prefix.

    The lookup table holds ``vocab_size + n_special_tokens`` rows of width
    ``d_embed``. When ``d_embed`` differs from ``d_model`` a linear layer maps
    the looked-up vectors to ``d_model``; otherwise they pass through
    unchanged. Three task-token vectors ([QRY], [DOC], [SYM]) live in their
    own table, already at ``d_model`` width.
    """

    def __init__(self, config: OgmaConfig) -> None:
        super().__init__()
        self.config = config

        # Special tokens share the table with the regular vocabulary.
        table_rows = config.vocab_size + config.n_special_tokens
        self.embed = nn.Embedding(
            table_rows,
            config.d_embed,
            padding_idx=config.pad_id,
        )

        # Only pay for a projection when the two widths actually differ.
        if config.d_embed == config.d_model:
            projection: nn.Module = nn.Identity()
        else:
            projection = nn.Linear(config.d_embed, config.d_model)
        self.proj = projection

        # Task tokens bypass the projection, so they are stored at d_model.
        self.task_tokens = nn.Embedding(3, config.d_model)

    def forward(
        self,
        token_ids: torch.Tensor,
        task_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed a batch of token IDs and put the task vector in front.

        Args:
            token_ids: (B, S) token IDs.
            task_token_ids: (B,) task token IDs (4=QRY, 5=DOC, 6=SYM).

        Returns:
            (B, S+1, d_model) embeddings with task token prepended.
        """
        # (B, S) -> (B, S, d_embed) -> (B, S, d_model)
        hidden = self.proj(self.embed(token_ids))

        # Task IDs are offsets from qry_id into the 3-row task table.
        task_rows = task_token_ids - self.config.qry_id
        prefix = self.task_tokens(task_rows).unsqueeze(1)  # (B, 1, d_model)

        return torch.cat((prefix, hidden), dim=1)

    def load_pretrained_embeddings(
        self, embeddings: torch.Tensor
    ) -> None:
        """Copy pre-computed token vectors into the lookup table.

        Rows are written starting at index ``n_special_tokens``, so the
        leading special-token rows are left as they are. At most
        ``vocab_size`` rows are copied; any extra input rows are ignored.

        Args:
            embeddings: (vocab_size, d_embed) tensor.
        """
        offset = self.config.n_special_tokens
        n_rows = min(embeddings.shape[0], self.config.vocab_size)
        # In-place write to a parameter, so autograd must be off.
        with torch.no_grad():
            self.embed.weight[offset : offset + n_rows] = embeddings[:n_rows]
class RotaryPositionalEncoding(nn.Module):
    """Rotary Position Embedding (RoPE). Zero trainable parameters.

    Holds the inverse-frequency vector as a persistent buffer and the
    cos/sin lookup tables as non-persistent buffers, so the tables are not
    written to the state dict and are regenerated from ``inv_freq`` when
    needed.
    """

    def __init__(self, dim: int, max_seq_len: int = 512) -> None:
        """Create the frequency vector and the initial cos/sin tables.

        Args:
            dim: Width being rotated; the tables get ``ceil(dim / 2)`` columns.
            max_seq_len: Number of positions to precompute up front.
        """
        super().__init__()
        inv_freq = 1.0 / (
            10000.0 ** (torch.arange(0, dim, 2).float() / dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int) -> None:
        """(Re)compute the cos/sin tables for positions ``0 .. seq_len - 1``."""
        inv_freq: torch.Tensor = self.inv_freq  # type: ignore[assignment]
        # Positions must live on the same device as inv_freq; otherwise a
        # rebuild after the module was moved off the CPU fails in
        # torch.outer with a device mismatch.
        t = torch.arange(
            seq_len, dtype=inv_freq.dtype, device=inv_freq.device
        )
        freqs = torch.outer(t, inv_freq)
        # Re-registering an existing buffer name replaces the old tensor.
        self.register_buffer("cos_cached", freqs.cos(), persistent=False)
        self.register_buffer("sin_cached", freqs.sin(), persistent=False)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return cos and sin for sequence length of x.

        The tables are rebuilt when ``x`` is longer than what is cached, or
        when the cached slice of either table contains a non-finite value.

        Args:
            x: (B, S, ...) tensor to determine sequence length.

        Returns:
            Tuple of (cos, sin) each of shape (S, d_head//2).
        """
        seq_len = x.shape[1]
        cos: torch.Tensor = self.cos_cached  # type: ignore[assignment]
        sin: torch.Tensor = self.sin_cached  # type: ignore[assignment]

        too_short = seq_len > cos.shape[0]
        # Validate both tables in one reduction: a corrupt sin table is just
        # as harmful as a corrupt cos table. Skipped when already too short.
        if too_short or not bool(
            (torch.isfinite(cos[:seq_len]) & torch.isfinite(sin[:seq_len])).all()
        ):
            # Never shrink the cache; only grow it or refresh it in place.
            self._build_cache(max(seq_len, cos.shape[0]))
            cos = self.cos_cached  # type: ignore[assignment]
            sin = self.sin_cached  # type: ignore[assignment]
        return cos[:seq_len], sin[:seq_len]
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors with the given cos/sin tables.

    The last axis of each tensor is split into a lower and an upper half;
    the pair (lower, upper) is rotated by the per-position angle encoded in
    ``cos`` / ``sin`` and the halves are concatenated back together.

    Args:
        q: (B, n_heads, S, d_head) query tensor.
        k: (B, n_heads, S, d_head) key tensor.
        cos: (S, d_head//2) cosine cache.
        sin: (S, d_head//2) sine cache.

    Returns:
        Rotated (q, k) tensors.
    """
    # Add two leading axes once so the tables broadcast as (1, 1, S, d_head//2).
    cos_b = cos[None, None]
    sin_b = sin[None, None]

    def _rotate_halves(t: torch.Tensor) -> torch.Tensor:
        half = t.shape[-1] // 2
        lower = t[..., :half]
        upper = t[..., half:]
        rotated_lower = lower * cos_b - upper * sin_b
        rotated_upper = upper * cos_b + lower * sin_b
        return torch.cat((rotated_lower, rotated_upper), dim=-1)

    return _rotate_halves(q), _rotate_halves(k)
|