File size: 988 Bytes
d138de8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""
tokenizer.py — thin wrapper around a HuggingFace tokenizer for TMT.
"""
from __future__ import annotations

from typing import List, Union

from transformers import AutoTokenizer


class TMTTokenizer:
    """Wrap a HuggingFace tokenizer behind a consistent TMT interface.

    Attributes:
        hf: The tokenizer returned by ``AutoTokenizer.from_pretrained``.
        vocab_size: ``len(self.hf)`` captured at construction time, i.e. the
            base vocabulary plus any added tokens (including the ``[PAD]``
            token ``__init__`` may add). It is not refreshed if more tokens
            are later added to ``hf`` directly.
    """

    def __init__(self, name: str = "gpt2") -> None:
        """Load the pretrained tokenizer *name* and make sure it can pad.

        Args:
            name: Model id or local path passed to
                ``AutoTokenizer.from_pretrained``.
        """
        self.hf = AutoTokenizer.from_pretrained(name)
        if self.hf.pad_token is None:
            # encode() always requests padding, which needs a pad token.
            # Adding a brand-new "[PAD]" token grows the vocabulary, so a model
            # paired with this tokenizer presumably needs its embedding table
            # sized from vocab_size below.
            self.hf.add_special_tokens({"pad_token": "[PAD]"})
        # len() includes added tokens such as [PAD]; hf.vocab_size does not.
        self.vocab_size = len(self.hf)

    def encode(
        self,
        text: Union[str, List[str]],
        max_length: int = 1024,
        *,
        padding: Union[bool, str] = "max_length",
    ) -> dict:
        """Tokenize *text* into PyTorch tensors.

        Args:
            text: A single string or a batch of strings.
            max_length: Longer sequences are truncated to this many tokens.
                With the default ``padding`` every sequence is also padded
                up to it, giving fixed-shape output.
            padding: Padding strategy forwarded to the HuggingFace tokenizer.
                ``"max_length"`` (default) pads to ``max_length``;
                ``"longest"`` / ``True`` pads only to the longest sequence in
                the batch, avoiding up to ``max_length`` pad tokens per row.

        Returns:
            The tokenizer's ``BatchEncoding`` (mapping-like) holding ``"pt"``
            tensors, typically ``input_ids`` and ``attention_mask``.
        """
        return self.hf(
            text,
            return_tensors="pt",
            padding=padding,
            truncation=True,
            max_length=max_length,
        )

    def decode(self, token_ids) -> str:
        """Convert a single sequence of token ids back into text.

        Special tokens (including the pad token) are dropped from the output.
        For a batch of sequences use ``decode_batch``.
        """
        return self.hf.decode(token_ids, skip_special_tokens=True)

    def decode_batch(self, sequences) -> List[str]:
        """Decode a batch of id sequences, e.g. ``encode(...)["input_ids"]``.

        Mirrors ``encode``'s support for a list of strings; special tokens
        (including padding) are dropped from every returned string.
        """
        return self.hf.batch_decode(sequences, skip_special_tokens=True)

    def __repr__(self) -> str:
        # type(self).__name__ keeps the repr accurate for subclasses.
        return f"{type(self).__name__}(vocab={self.vocab_size})"