Sentence Similarity
ONNX
Safetensors
English
ogma
embeddings
dense-retrieval
matryoshka
rag
agents
mteb
semantic-search
text-embeddings
text-embedding
vector-search
document-retrieval
similarity-search
classification
clustering
edge-ai
on-device
local-inference
efficient-ai
rag-retrieval
custom_code
| """Pooling strategies for Ogma.""" | |
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .config import OgmaConfig, PoolingType | |
| __all__ = [ | |
| "create_pooling", | |
| "TaskTokenPooling", | |
| "LatentAttentionPooling", | |
| "MeanPooling", | |
| ] | |
| def create_pooling(config: OgmaConfig) -> nn.Module: | |
| """Factory for pooling layers.""" | |
| if config.pooling == PoolingType.TASK_TOKEN: | |
| return TaskTokenPooling() | |
| elif config.pooling == PoolingType.LATENT_ATTENTION: | |
| return LatentAttentionPooling(config.d_model) | |
| elif config.pooling == PoolingType.MEAN: | |
| return MeanPooling() | |
| raise ValueError(f"Unknown pooling type: {config.pooling}") | |
| class TaskTokenPooling(nn.Module): | |
| """Use the output at position 0 (task token) as the sentence embedding.""" | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| attention_mask: torch.Tensor | None = None, | |
| ) -> torch.Tensor: | |
| """Extract task token output. | |
| Args: | |
| x: (B, S, D) sequence outputs. | |
| attention_mask: unused, for interface compatibility. | |
| Returns: | |
| (B, D) pooled output. | |
| """ | |
| return x[:, 0, :] | |
| class LatentAttentionPooling(nn.Module): | |
| """Learned query vector attends over all token outputs.""" | |
| def __init__(self, d_model: int) -> None: | |
| super().__init__() | |
| self.query = nn.Parameter(torch.randn(d_model)) | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| attention_mask: torch.Tensor | None = None, | |
| ) -> torch.Tensor: | |
| """Attend over sequence with learned query. | |
| Args: | |
| x: (B, S, D) sequence outputs. | |
| attention_mask: (B, S) mask where 1=valid, 0=pad. | |
| Returns: | |
| (B, D) pooled output. | |
| """ | |
| # (B, S) | |
| scores = torch.matmul(x, self.query) / (x.shape[-1] ** 0.5) | |
| if attention_mask is not None: | |
| scores = scores.masked_fill(attention_mask == 0, float("-inf")) | |
| weights = F.softmax(scores, dim=-1) # (B, S) | |
| return torch.bmm(weights.unsqueeze(1), x).squeeze(1) # (B, D) | |
| class MeanPooling(nn.Module): | |
| """Average all token outputs (excluding padding).""" | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| attention_mask: torch.Tensor | None = None, | |
| ) -> torch.Tensor: | |
| """Mean pool over valid tokens. | |
| Args: | |
| x: (B, S, D) sequence outputs. | |
| attention_mask: (B, S) mask where 1=valid, 0=pad. | |
| Returns: | |
| (B, D) pooled output. | |
| """ | |
| if attention_mask is None: | |
| return x.mean(dim=1) | |
| mask = attention_mask.unsqueeze(-1).float() # (B, S, 1) | |
| return (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9) | |
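For reference, a minimal usage sketch. It assumes `OgmaConfig` can be constructed directly with the `pooling` and `d_model` fields that `create_pooling` reads, and that the package is importable as `ogma`; the actual constructor signature and import path may differ.

```python
import torch

from ogma.config import OgmaConfig, PoolingType  # assumed import path
from ogma.pooling import create_pooling

# Hypothetical construction; OgmaConfig's real signature may differ.
config = OgmaConfig(pooling=PoolingType.MEAN, d_model=768)
pooling = create_pooling(config)

hidden = torch.randn(2, 16, 768)            # (B, S, D) encoder outputs
mask = torch.ones(2, 16, dtype=torch.long)  # 1 = valid token, 0 = pad
mask[1, 10:] = 0                            # second sequence: 10 real tokens

embeddings = pooling(hidden, attention_mask=mask)  # (2, 768)
```

With `PoolingType.MEAN`, the padded positions contribute nothing to the average. Swapping in `PoolingType.TASK_TOKEN` or `PoolingType.LATENT_ATTENTION` changes only the factory branch, since all three modules share the same forward signature.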