Antreas commited on
Commit
c6d0c2e
·
verified ·
1 Parent(s): e673a10

Enable AutoModel loading

Browse files
Files changed (1) hide show
  1. ogma_model.py +40 -33
ogma_model.py CHANGED
@@ -5,17 +5,12 @@ from __future__ import annotations
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
 
8
 
9
  from .config import OgmaConfig, TaskToken, VariantType
10
  from .embeddings import TokenEmbedding
11
  from .pooling import create_pooling
12
- from .variants.conv import ConvVariant
13
- from .variants.deep_narrow import DeepNarrowVariant
14
- from .variants.linear_attention import LinearAttentionVariant
15
- from .variants.mlp_mixer import MLPMixerVariant
16
- from .variants.transformer import TransformerVariant
17
- from .variants.transformer_resa import TransformerReSAVariant
18
- from .variants.gla import GLAVariant
19
 
20
  __all__ = ["OgmaModel"]
21
 
@@ -23,25 +18,13 @@ MAX_PARAMS = 10_000_000
23
 
24
 
25
  def _build_variant(config: OgmaConfig) -> nn.Module:
26
- """Instantiate the appropriate architecture variant."""
27
- if config.variant == VariantType.TRANSFORMER:
28
- return TransformerVariant(config)
29
- elif config.variant == VariantType.DEEP_NARROW:
30
- return DeepNarrowVariant(config)
31
- elif config.variant == VariantType.CONV:
32
- return ConvVariant(config)
33
- elif config.variant == VariantType.LINEAR_ATTENTION:
34
- return LinearAttentionVariant(config)
35
- elif config.variant == VariantType.MLP_MIXER:
36
- return MLPMixerVariant(config)
37
- elif config.variant == VariantType.TRANSFORMER_RESA:
38
- return TransformerReSAVariant(config)
39
- elif config.variant == VariantType.GLA:
40
- return GLAVariant(config)
41
- raise ValueError(f"Unknown variant: {config.variant}")
42
-
43
-
44
- class OgmaModel(nn.Module):
45
  """Ogma embedding model.
46
 
47
  Wraps any architecture variant with shared embedding, pooling, and
@@ -49,8 +32,14 @@ class OgmaModel(nn.Module):
49
  Matryoshka-compatible at configured sub-dimensions.
50
  """
51
 
 
 
 
 
 
 
52
  def __init__(self, config: OgmaConfig) -> None:
53
- super().__init__()
54
  self.config = config
55
  self.embedding = TokenEmbedding(config)
56
  self.variant = _build_variant(config)
@@ -71,20 +60,37 @@ class OgmaModel(nn.Module):
71
 
72
  def forward(
73
  self,
74
- token_ids: torch.Tensor,
75
- attention_mask: torch.Tensor,
76
- task_token_ids: torch.Tensor,
 
 
77
  ) -> torch.Tensor:
78
  """Forward pass producing L2-normalized embeddings.
79
 
80
  Args:
81
- token_ids: (B, S) token IDs.
82
  attention_mask: (B, S) attention mask (1=valid, 0=pad).
83
  task_token_ids: (B,) task token IDs (4=QRY, 5=DOC, 6=SYM).
 
84
 
85
  Returns:
86
  (B, d_output) L2-normalized embeddings.
87
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  # Embed tokens with task token prepended -> (B, S+1, d_model)
89
  x = self.embedding(token_ids, task_token_ids)
90
 
@@ -130,7 +136,7 @@ class OgmaModel(nn.Module):
130
  device=token_ids.device,
131
  dtype=torch.long,
132
  )
133
- return self.forward(token_ids, attention_mask, task_ids)
134
 
135
  def param_count(self) -> int:
136
  """Count total trainable parameters."""
@@ -147,7 +153,8 @@ class OgmaModel(nn.Module):
147
  def from_config(cls, config: OgmaConfig) -> OgmaModel:
148
  """Factory method to build a model from config."""
149
  model = cls(config)
150
- model.assert_param_budget()
 
151
  return model
152
 
153
  @classmethod
 
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
+ from transformers import PreTrainedModel
9
 
10
  from .config import OgmaConfig, TaskToken, VariantType
11
  from .embeddings import TokenEmbedding
12
  from .pooling import create_pooling
13
+ from .transformer import TransformerVariant
 
 
 
 
 
 
14
 
15
  __all__ = ["OgmaModel"]
16
 
 
18
 
19
 
20
  def _build_variant(config: OgmaConfig) -> nn.Module:
21
+ """Instantiate the released Ogma architecture variant."""
22
+ if config.variant != VariantType.TRANSFORMER:
23
+ raise ValueError(f"This HF release supports transformer checkpoints, got {config.variant}")
24
+ return TransformerVariant(config)
25
+
26
+
27
+ class OgmaModel(PreTrainedModel):
 
 
 
 
 
 
 
 
 
 
 
 
28
  """Ogma embedding model.
29
 
30
  Wraps any architecture variant with shared embedding, pooling, and
 
32
  Matryoshka-compatible at configured sub-dimensions.
33
  """
34
 
35
+ config_class = OgmaConfig
36
+ base_model_prefix = "ogma"
37
+ supports_gradient_checkpointing = False
38
+ _tied_weights_keys: list[str] = []
39
+ all_tied_weights_keys: dict[str, str] = {}
40
+
41
  def __init__(self, config: OgmaConfig) -> None:
42
+ super().__init__(config)
43
  self.config = config
44
  self.embedding = TokenEmbedding(config)
45
  self.variant = _build_variant(config)
 
60
 
61
  def forward(
62
  self,
63
+ input_ids: torch.Tensor | None = None,
64
+ attention_mask: torch.Tensor | None = None,
65
+ task_token_ids: torch.Tensor | None = None,
66
+ token_ids: torch.Tensor | None = None,
67
+ **_: object,
68
  ) -> torch.Tensor:
69
  """Forward pass producing L2-normalized embeddings.
70
 
71
  Args:
72
+ input_ids: (B, S) token IDs, Hugging Face style.
73
  attention_mask: (B, S) attention mask (1=valid, 0=pad).
74
  task_token_ids: (B,) task token IDs (4=QRY, 5=DOC, 6=SYM).
75
+ token_ids: Backward-compatible alias for input_ids.
76
 
77
  Returns:
78
  (B, d_output) L2-normalized embeddings.
79
  """
80
+ if input_ids is None:
81
+ input_ids = token_ids
82
+ if input_ids is None:
83
+ raise ValueError("input_ids or token_ids must be provided")
84
+ if attention_mask is None:
85
+ attention_mask = torch.ones_like(input_ids)
86
+ if task_token_ids is None:
87
+ task_token_ids = torch.full(
88
+ (input_ids.shape[0],),
89
+ self.config.sym_id,
90
+ device=input_ids.device,
91
+ dtype=torch.long,
92
+ )
93
+ token_ids = input_ids
94
  # Embed tokens with task token prepended -> (B, S+1, d_model)
95
  x = self.embedding(token_ids, task_token_ids)
96
 
 
136
  device=token_ids.device,
137
  dtype=torch.long,
138
  )
139
+ return self.forward(input_ids=token_ids, attention_mask=attention_mask, task_token_ids=task_ids)
140
 
141
  def param_count(self) -> int:
142
  """Count total trainable parameters."""
 
153
  def from_config(cls, config: OgmaConfig) -> OgmaModel:
154
  """Factory method to build a model from config."""
155
  model = cls(config)
156
+ if model.param_count() < MAX_PARAMS:
157
+ model.assert_param_budget()
158
  return model
159
 
160
  @classmethod