vigneshwar234 commited on
Commit
d138de8
·
verified ·
1 Parent(s): 7bafa44

Add source: tmt/data/tokenizer.py

Browse files
Files changed (1) hide show
  1. tmt/data/tokenizer.py +33 -0
tmt/data/tokenizer.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ tokenizer.py — thin wrapper around HuggingFace tokenizer for TMT.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from typing import List, Union
7
+
8
+ from transformers import AutoTokenizer
9
+
10
+
11
+ class TMTTokenizer:
12
+ """Wraps a HuggingFace tokenizer with a consistent TMT interface."""
13
+
14
+ def __init__(self, name: str = "gpt2") -> None:
15
+ self.hf = AutoTokenizer.from_pretrained(name)
16
+ if self.hf.pad_token is None:
17
+ self.hf.add_special_tokens({"pad_token": "[PAD]"})
18
+ self.vocab_size = len(self.hf)
19
+
20
+ def encode(self, text: Union[str, List[str]], max_length: int = 1024) -> dict:
21
+ return self.hf(
22
+ text,
23
+ return_tensors="pt",
24
+ padding="max_length",
25
+ truncation=True,
26
+ max_length=max_length,
27
+ )
28
+
29
+ def decode(self, token_ids) -> str:
30
+ return self.hf.decode(token_ids, skip_special_tokens=True)
31
+
32
+ def __repr__(self) -> str:
33
+ return f"TMTTokenizer(vocab={self.vocab_size})"