Sentence Similarity
ONNX
Safetensors
English
ogma
embeddings
dense-retrieval
matryoshka
rag
agents
mteb
semantic-search
text-embeddings
text-embedding
vector-search
document-retrieval
similarity-search
classification
clustering
edge-ai
on-device
local-inference
efficient-ai
rag-retrieval
custom_code
Eval Results (legacy)
File size: 2,796 Bytes
"""Pooling strategies for Ogma."""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from .config import OgmaConfig, PoolingType
# Names exported by ``from <this module> import *``: the factory first,
# followed by the three pooling layers it can build.
__all__ = [
    "create_pooling",
    "TaskTokenPooling",
    "LatentAttentionPooling",
    "MeanPooling",
]
def create_pooling(config: OgmaConfig) -> nn.Module:
    """Build the pooling layer selected by ``config.pooling``.

    Args:
        config: model configuration; ``config.pooling`` picks the strategy and
            ``config.d_model`` sizes the latent-attention query.

    Returns:
        A freshly constructed pooling module.

    Raises:
        ValueError: if ``config.pooling`` matches none of the known types.
    """
    kind = config.pooling

    if kind == PoolingType.TASK_TOKEN:
        return TaskTokenPooling()

    if kind == PoolingType.LATENT_ATTENTION:
        # The only strategy with learned parameters, hence the width argument.
        return LatentAttentionPooling(config.d_model)

    if kind == PoolingType.MEAN:
        return MeanPooling()

    raise ValueError(f"Unknown pooling type: {kind}")
class TaskTokenPooling(nn.Module):
    """Sentence embedding taken from the first sequence position (task token)."""

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the output at sequence index 0 for every batch element.

        Args:
            x: (B, S, D) sequence outputs.
            attention_mask: ignored; accepted only so all pooling layers share
                one call signature.

        Returns:
            (B, D) pooled output.
        """
        task_token_output = x[:, 0, :]
        return task_token_output
class LatentAttentionPooling(nn.Module):
    """Pool a sequence with a single learned query vector.

    Every token output is scored by its dot product with ``query`` (scaled by
    ``sqrt(D)``); the softmax of those scores weights the sum of token outputs.
    """

    def __init__(self, d_model: int) -> None:
        """Create the learnable query.

        Args:
            d_model: size of the last dimension of the sequence outputs.
        """
        super().__init__()
        # One learnable (D,) query, initialised from a standard normal.
        self.query = nn.Parameter(torch.randn(d_model))

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Attend over the sequence with the learned query.

        Args:
            x: (B, S, D) sequence outputs.
            attention_mask: (B, S) mask where 1=valid, 0=pad. When omitted,
                every position takes part in the softmax.

        Returns:
            (B, D) pooled output. A row whose mask contains no valid position
            pools to all zeros instead of NaN.
        """
        # (B, S): scaled dot product of each token with the query.
        scores = torch.matmul(x, self.query) / (x.shape[-1] ** 0.5)

        if attention_mask is None:
            weights = F.softmax(scores, dim=-1)  # (B, S)
        else:
            is_pad = attention_mask == 0  # (B, S), True at padding
            row_is_empty = is_pad.all(dim=-1, keepdim=True)  # (B, 1)

            scores = scores.masked_fill(is_pad, float("-inf"))
            # softmax over a row of only -inf is NaN (in both the forward and
            # the backward pass), so give fully padded rows finite dummy
            # scores before normalising.
            scores = scores.masked_fill(row_is_empty, 0.0)
            weights = F.softmax(scores, dim=-1)  # (B, S)
            # No-op for rows with a valid token (their padding weights are
            # already exactly 0); turns the dummy weights of fully padded
            # rows into zeros so those rows pool to a zero vector.
            weights = weights.masked_fill(is_pad, 0.0)

        return torch.bmm(weights.unsqueeze(1), x).squeeze(1)  # (B, D)
class MeanPooling(nn.Module):
    """Mean of the token outputs, with padded positions left out."""

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Average the sequence over its valid positions.

        Args:
            x: (B, S, D) sequence outputs.
            attention_mask: (B, S) mask where 1=valid, 0=pad. Without a mask
                the plain mean over the sequence dimension is returned.

        Returns:
            (B, D) pooled output.
        """
        if attention_mask is not None:
            # (B, S, 1) float weights that broadcast across the feature axis.
            weights = attention_mask.unsqueeze(-1).float()
            summed = (x * weights).sum(dim=1)
            # Clamp so a row with no valid token divides by 1e-9 rather than 0.
            counts = weights.sum(dim=1).clamp(min=1e-9)
            return summed / counts

        return x.mean(dim=1)
|