Sentence Similarity
ONNX
Safetensors
English
ogma
embeddings
dense-retrieval
matryoshka
rag
agents
mteb
semantic-search
text-embeddings
text-embedding
vector-search
document-retrieval
similarity-search
classification
clustering
edge-ai
on-device
local-inference
efficient-ai
rag-retrieval
custom_code
Eval Results (legacy)
Enable AutoModel loading
Browse files- ogma_model.py +40 -33
ogma_model.py
CHANGED
|
@@ -5,17 +5,12 @@ from __future__ import annotations
|
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
| 7 |
import torch.nn.functional as F
|
|
|
|
| 8 |
|
| 9 |
from .config import OgmaConfig, TaskToken, VariantType
|
| 10 |
from .embeddings import TokenEmbedding
|
| 11 |
from .pooling import create_pooling
|
| 12 |
-
from .
|
| 13 |
-
from .variants.deep_narrow import DeepNarrowVariant
|
| 14 |
-
from .variants.linear_attention import LinearAttentionVariant
|
| 15 |
-
from .variants.mlp_mixer import MLPMixerVariant
|
| 16 |
-
from .variants.transformer import TransformerVariant
|
| 17 |
-
from .variants.transformer_resa import TransformerReSAVariant
|
| 18 |
-
from .variants.gla import GLAVariant
|
| 19 |
|
| 20 |
__all__ = ["OgmaModel"]
|
| 21 |
|
|
@@ -23,25 +18,13 @@ MAX_PARAMS = 10_000_000
|
|
| 23 |
|
| 24 |
|
| 25 |
def _build_variant(config: OgmaConfig) -> nn.Module:
|
| 26 |
-
"""Instantiate the
|
| 27 |
-
if config.variant =
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
elif config.variant == VariantType.LINEAR_ATTENTION:
|
| 34 |
-
return LinearAttentionVariant(config)
|
| 35 |
-
elif config.variant == VariantType.MLP_MIXER:
|
| 36 |
-
return MLPMixerVariant(config)
|
| 37 |
-
elif config.variant == VariantType.TRANSFORMER_RESA:
|
| 38 |
-
return TransformerReSAVariant(config)
|
| 39 |
-
elif config.variant == VariantType.GLA:
|
| 40 |
-
return GLAVariant(config)
|
| 41 |
-
raise ValueError(f"Unknown variant: {config.variant}")
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
class OgmaModel(nn.Module):
|
| 45 |
"""Ogma embedding model.
|
| 46 |
|
| 47 |
Wraps any architecture variant with shared embedding, pooling, and
|
|
@@ -49,8 +32,14 @@ class OgmaModel(nn.Module):
|
|
| 49 |
Matryoshka-compatible at configured sub-dimensions.
|
| 50 |
"""
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def __init__(self, config: OgmaConfig) -> None:
|
| 53 |
-
super().__init__()
|
| 54 |
self.config = config
|
| 55 |
self.embedding = TokenEmbedding(config)
|
| 56 |
self.variant = _build_variant(config)
|
|
@@ -71,20 +60,37 @@ class OgmaModel(nn.Module):
|
|
| 71 |
|
| 72 |
def forward(
|
| 73 |
self,
|
| 74 |
-
|
| 75 |
-
attention_mask: torch.Tensor,
|
| 76 |
-
task_token_ids: torch.Tensor,
|
|
|
|
|
|
|
| 77 |
) -> torch.Tensor:
|
| 78 |
"""Forward pass producing L2-normalized embeddings.
|
| 79 |
|
| 80 |
Args:
|
| 81 |
-
|
| 82 |
attention_mask: (B, S) attention mask (1=valid, 0=pad).
|
| 83 |
task_token_ids: (B,) task token IDs (4=QRY, 5=DOC, 6=SYM).
|
|
|
|
| 84 |
|
| 85 |
Returns:
|
| 86 |
(B, d_output) L2-normalized embeddings.
|
| 87 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
# Embed tokens with task token prepended -> (B, S+1, d_model)
|
| 89 |
x = self.embedding(token_ids, task_token_ids)
|
| 90 |
|
|
@@ -130,7 +136,7 @@ class OgmaModel(nn.Module):
|
|
| 130 |
device=token_ids.device,
|
| 131 |
dtype=torch.long,
|
| 132 |
)
|
| 133 |
-
return self.forward(token_ids, attention_mask, task_ids)
|
| 134 |
|
| 135 |
def param_count(self) -> int:
|
| 136 |
"""Count total trainable parameters."""
|
|
@@ -147,7 +153,8 @@ class OgmaModel(nn.Module):
|
|
| 147 |
def from_config(cls, config: OgmaConfig) -> OgmaModel:
|
| 148 |
"""Factory method to build a model from config."""
|
| 149 |
model = cls(config)
|
| 150 |
-
model.
|
|
|
|
| 151 |
return model
|
| 152 |
|
| 153 |
@classmethod
|
|
|
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
| 7 |
import torch.nn.functional as F
|
| 8 |
+
from transformers import PreTrainedModel
|
| 9 |
|
| 10 |
from .config import OgmaConfig, TaskToken, VariantType
|
| 11 |
from .embeddings import TokenEmbedding
|
| 12 |
from .pooling import create_pooling
|
| 13 |
+
from .transformer import TransformerVariant
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
__all__ = ["OgmaModel"]
|
| 16 |
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
def _build_variant(config: OgmaConfig) -> nn.Module:
|
| 21 |
+
"""Instantiate the released Ogma architecture variant."""
|
| 22 |
+
if config.variant != VariantType.TRANSFORMER:
|
| 23 |
+
raise ValueError(f"This HF release supports transformer checkpoints, got {config.variant}")
|
| 24 |
+
return TransformerVariant(config)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class OgmaModel(PreTrainedModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
"""Ogma embedding model.
|
| 29 |
|
| 30 |
Wraps any architecture variant with shared embedding, pooling, and
|
|
|
|
| 32 |
Matryoshka-compatible at configured sub-dimensions.
|
| 33 |
"""
|
| 34 |
|
| 35 |
+
config_class = OgmaConfig
|
| 36 |
+
base_model_prefix = "ogma"
|
| 37 |
+
supports_gradient_checkpointing = False
|
| 38 |
+
_tied_weights_keys: list[str] = []
|
| 39 |
+
all_tied_weights_keys: dict[str, str] = {}
|
| 40 |
+
|
| 41 |
def __init__(self, config: OgmaConfig) -> None:
|
| 42 |
+
super().__init__(config)
|
| 43 |
self.config = config
|
| 44 |
self.embedding = TokenEmbedding(config)
|
| 45 |
self.variant = _build_variant(config)
|
|
|
|
| 60 |
|
| 61 |
def forward(
|
| 62 |
self,
|
| 63 |
+
input_ids: torch.Tensor | None = None,
|
| 64 |
+
attention_mask: torch.Tensor | None = None,
|
| 65 |
+
task_token_ids: torch.Tensor | None = None,
|
| 66 |
+
token_ids: torch.Tensor | None = None,
|
| 67 |
+
**_: object,
|
| 68 |
) -> torch.Tensor:
|
| 69 |
"""Forward pass producing L2-normalized embeddings.
|
| 70 |
|
| 71 |
Args:
|
| 72 |
+
input_ids: (B, S) token IDs, Hugging Face style.
|
| 73 |
attention_mask: (B, S) attention mask (1=valid, 0=pad).
|
| 74 |
task_token_ids: (B,) task token IDs (4=QRY, 5=DOC, 6=SYM).
|
| 75 |
+
token_ids: Backward-compatible alias for input_ids.
|
| 76 |
|
| 77 |
Returns:
|
| 78 |
(B, d_output) L2-normalized embeddings.
|
| 79 |
"""
|
| 80 |
+
if input_ids is None:
|
| 81 |
+
input_ids = token_ids
|
| 82 |
+
if input_ids is None:
|
| 83 |
+
raise ValueError("input_ids or token_ids must be provided")
|
| 84 |
+
if attention_mask is None:
|
| 85 |
+
attention_mask = torch.ones_like(input_ids)
|
| 86 |
+
if task_token_ids is None:
|
| 87 |
+
task_token_ids = torch.full(
|
| 88 |
+
(input_ids.shape[0],),
|
| 89 |
+
self.config.sym_id,
|
| 90 |
+
device=input_ids.device,
|
| 91 |
+
dtype=torch.long,
|
| 92 |
+
)
|
| 93 |
+
token_ids = input_ids
|
| 94 |
# Embed tokens with task token prepended -> (B, S+1, d_model)
|
| 95 |
x = self.embedding(token_ids, task_token_ids)
|
| 96 |
|
|
|
|
| 136 |
device=token_ids.device,
|
| 137 |
dtype=torch.long,
|
| 138 |
)
|
| 139 |
+
return self.forward(input_ids=token_ids, attention_mask=attention_mask, task_token_ids=task_ids)
|
| 140 |
|
| 141 |
def param_count(self) -> int:
|
| 142 |
"""Count total trainable parameters."""
|
|
|
|
| 153 |
def from_config(cls, config: OgmaConfig) -> OgmaModel:
|
| 154 |
"""Factory method to build a model from config."""
|
| 155 |
model = cls(config)
|
| 156 |
+
if model.param_count() < MAX_PARAMS:
|
| 157 |
+
model.assert_param_budget()
|
| 158 |
return model
|
| 159 |
|
| 160 |
@classmethod
|