Antreas committed on
Commit
1e6ff58
·
1 Parent(s): d66fb10

Add OgmaTokenizerFast + model.embed() high-level API

Browse files

- tokenization_ogma.py: PreTrainedTokenizerFast subclass that shifts
content token ids by N_SPECIAL=7 (training-matched tokenization,
no manual offset needed by callers)
- tokenizer_config.json: wires AutoTokenizer to OgmaTokenizerFast
- ogma_model.py: adds TaskToken class attr + embed(texts, task=) method
so callers need only AutoModel + AutoTokenizer, no sys.modules digging

Files changed (3) hide show
  1. ogma_model.py +50 -0
  2. tokenization_ogma.py +43 -0
  3. tokenizer_config.json +15 -12
ogma_model.py CHANGED
@@ -138,6 +138,56 @@ class OgmaModel(PreTrainedModel):
138
  )
139
  return self.forward(input_ids=token_ids, attention_mask=attention_mask, task_token_ids=task_ids)
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  def param_count(self) -> int:
142
  """Count total trainable parameters."""
143
  return sum(p.numel() for p in self.parameters() if p.requires_grad)
 
138
  )
139
  return self.forward(input_ids=token_ids, attention_mask=attention_mask, task_token_ids=task_ids)
140
 
141
+ # TaskToken re-exported as a class attribute for clean external access
142
+ TaskToken = TaskToken # noqa: F821 (imported at module top)
143
+
144
@torch.no_grad()
def embed(
    self,
    texts,
    task: str = "sym",
    tokenizer=None,
    batch_size: int = 32,
    max_length: int = 1024,
) -> "torch.Tensor":
    """Encode raw text into embeddings in mini-batches.

    Each batch is tokenized (padded, truncated to ``max_length``), moved to
    the model's device, and passed to ``self.encode``; the per-batch outputs
    are concatenated along dim 0. Gradients are disabled by the decorator.

    Args:
        texts: A single string or an iterable of strings to encode.
            Non-list iterables (tuples, generators) are materialised once.
        task: "qry" / "doc" / "sym" (case-insensitive), looked up by name in
            ``self.TaskToken``, or an already-resolved TaskToken member.
        tokenizer: Tokenizer to use. If None, one is loaded from
            ``self.name_or_path`` with ``AutoTokenizer`` on the first call
            and cached on the instance for later calls.
        batch_size: Texts per forward pass; must be >= 1.
        max_length: Token cap per text (default 1024).

    Returns:
        Tensor with one row per input text, on the same device as the
        model. Rows are whatever ``self.encode`` returns (documented by the
        original author as L2-normalized, d_output-wide embeddings).

    Raises:
        ValueError: If ``texts`` is empty or ``batch_size`` is < 1.
        KeyError: If ``task`` is a string that is not a TaskToken name.

    NOTE(review): this method does not switch the model to eval mode; if the
    model contains dropout, callers presumably want ``model.eval()`` first.
    """
    texts = [texts] if isinstance(texts, str) else list(texts)
    # torch.cat() rejects an empty list, and range() rejects a zero step;
    # fail early with a message that names the real problem.
    if not texts:
        raise ValueError("embed() requires at least one text, got an empty input")
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    if isinstance(task, str):
        task = self.TaskToken[task.upper()]

    if tokenizer is None:
        # Loading a tokenizer is expensive; reuse the one from a prior call.
        tokenizer = getattr(self, "_embed_tokenizer", None)
        if tokenizer is None:
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(
                self.name_or_path, trust_remote_code=True
            )
            self._embed_tokenizer = tokenizer

    device = next(self.parameters()).device
    outs = []
    for start in range(0, len(texts), batch_size):
        enc = tokenizer(
            texts[start : start + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(device)
        outs.append(self.encode(enc["input_ids"], enc["attention_mask"], task=task))
    return torch.cat(outs, dim=0)
189
+
190
+
191
  def param_count(self) -> int:
192
  """Count total trainable parameters."""
193
  return sum(p.numel() for p in self.parameters() if p.requires_grad)
tokenization_ogma.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OgmaTokenizerFast — wraps PreTrainedTokenizerFast, shifts token ids by
2
+ N_SPECIAL so they align with Ogma's embedding table.
3
+
4
+ Ogma reserved vocab ids (0-6):
5
+ 0 <pad> 1 <unk> 2 [CLS] 3 [SEP] 4 [MASK] 5 [DOC] 6 [SYM]
6
+ Regular SentencePiece tokens start at 7.
7
+
8
+ The tokenizer post-processor already adds [CLS] / [SEP] around every input.
9
+ This wrapper shifts ALL content positions (attention_mask == 1) up by
10
+ N_SPECIAL (7) so that raw [CLS] id 2 -> 9, raw [SEP] id 3 -> 10, and content tokens land where the
11
+ model was trained to see them. Padding positions (attention_mask == 0) stay
12
+ at 0 (Ogma pad id).
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import torch
17
+ from transformers import PreTrainedTokenizerFast
18
+ from transformers.tokenization_utils_base import BatchEncoding
19
+
20
+ __all__ = ["OgmaTokenizerFast"]
21
+
22
+ N_SPECIAL = 7
23
+
24
+
25
class OgmaTokenizerFast(PreTrainedTokenizerFast):
    """Fast tokenizer whose ``__call__`` offsets attended token ids by ``N_SPECIAL``.

    Only ``__call__`` is overridden: every position whose attention mask is
    non-zero has ``N_SPECIAL`` added to its id, while masked (padding)
    positions keep the id produced by the base tokenizer.

    NOTE(review): no other method (e.g. ``decode`` or ``encode``) is
    overridden here, so ids returned by ``__call__`` are presumably not
    directly usable with the inherited decoding helpers -- confirm.
    """

    N_SPECIAL = N_SPECIAL

    # Cap applied when ``model_max_length`` is missing or left "unset".
    _DEFAULT_MAX_LENGTH = 1024
    # Values at or above this are treated as "unset".
    # NOTE(review): transformers is believed to use a very large integer as
    # its "no limit" sentinel, and passing that as an explicit ``max_length``
    # may overflow in the tokenizers backend -- confirm against the library.
    _UNSET_LENGTH_THRESHOLD = int(1e20)

    def _shift(self, ids, mask):
        """Return ``ids`` with ``N_SPECIAL`` added wherever ``mask`` is non-zero.

        Args:
            ids: Token ids as a torch tensor, an array exposing ``astype``
                (e.g. numpy), a flat list (single un-batched input), or a
                list of lists (batched input).
            mask: Attention mask with the same container type and shape.

        Returns:
            Shifted ids in the same container type as ``ids``.
        """
        if isinstance(ids, torch.Tensor):
            return ids + self.N_SPECIAL * mask.long()
        if hasattr(ids, "astype"):
            # Array output (return_tensors="np"): stay vectorised and keep
            # the array type instead of degrading to nested Python lists.
            return ids + self.N_SPECIAL * (mask != 0).astype(ids.dtype)
        if ids and not isinstance(ids[0], (list, tuple)):
            # A single text without return_tensors yields one flat id list.
            return [i + self.N_SPECIAL if m else i for i, m in zip(ids, mask)]
        return [
            [i + self.N_SPECIAL if m else i for i, m in zip(row_i, row_m)]
            for row_i, row_m in zip(ids, mask)
        ]

    def __call__(self, *args, **kwargs) -> BatchEncoding:
        """Tokenize via the base class, then shift attended ids by ``N_SPECIAL``.

        Unless the caller overrides them, ``padding`` and ``truncation``
        default to True and ``max_length`` defaults to ``model_max_length``
        (or 1024 when that is missing or effectively unlimited).

        The shift is applied only when the encoding contains both
        ``input_ids`` and ``attention_mask``; otherwise the base encoding is
        returned unchanged.
        """
        limit = self.model_max_length
        if not limit or limit >= self._UNSET_LENGTH_THRESHOLD:
            limit = self._DEFAULT_MAX_LENGTH
        kwargs.setdefault("padding", True)
        kwargs.setdefault("truncation", True)
        kwargs.setdefault("max_length", limit)
        enc = super().__call__(*args, **kwargs)
        if "input_ids" in enc and "attention_mask" in enc:
            enc["input_ids"] = self._shift(enc["input_ids"], enc["attention_mask"])
        return enc
tokenizer_config.json CHANGED
@@ -1,17 +1,20 @@
1
  {
2
- "add_prefix_space": true,
3
- "backend": "tokenizers",
4
- "bos_token": "[CLS]",
 
 
 
 
 
 
 
 
5
  "cls_token": "[CLS]",
6
- "do_lower_case": true,
 
7
  "eos_token": "[SEP]",
8
- "is_local": false,
9
- "keep_accents": false,
10
  "mask_token": "[MASK]",
11
- "model_max_length": 512,
12
- "pad_token": "<pad>",
13
- "sep_token": "[SEP]",
14
- "tokenizer_class": "AlbertTokenizer",
15
- "trim_offsets": true,
16
- "unk_token": "<unk>"
17
  }
 
1
  {
2
+ "tokenizer_class": "OgmaTokenizerFast",
3
+ "auto_map": {
4
+ "AutoTokenizer": [
5
+ null,
6
+ "tokenization_ogma.OgmaTokenizerFast"
7
+ ]
8
+ },
9
+ "model_max_length": 1024,
10
+ "padding_side": "right",
11
+ "pad_token": "<pad>",
12
+ "unk_token": "<unk>",
13
  "cls_token": "[CLS]",
14
+ "sep_token": "[SEP]",
15
+ "bos_token": "[CLS]",
16
  "eos_token": "[SEP]",
 
 
17
  "mask_token": "[MASK]",
18
+ "do_lower_case": true,
19
+ "backend": "tokenizers"
 
 
 
 
20
  }