File size: 2,932 Bytes
0956ad7
 
 
 
707b7ba
0956ad7
 
707b7ba
 
 
 
 
 
 
 
 
 
 
 
0956ad7
 
 
 
707b7ba
 
 
 
 
0956ad7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
707b7ba
0956ad7
 
707b7ba
0956ad7
 
707b7ba
0956ad7
707b7ba
 
 
0956ad7
 
 
707b7ba
 
 
0956ad7
707b7ba
0956ad7
707b7ba
 
0956ad7
707b7ba
 
0956ad7
 
707b7ba
0956ad7
 
707b7ba
0956ad7
 
707b7ba
0956ad7
 
 
 
 
 
 
 
 
707b7ba
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import json
import os
import torch
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers import PreTrainedTokenizer, AutoTokenizer

class STLTokenizer(PreTrainedTokenizer):
    """Greedy longest-match tokenizer driven by a JSON vocabulary file.

    The vocabulary is a JSON object mapping token strings to integer ids.
    Input text is wrapped in the BOS/EOS tokens, every space is replaced by
    ``@``, and the result is segmented left to right by always taking the
    longest vocabulary entry that starts at the current position.
    """

    model_type = "stl_encoder"

    # Longest substring, in characters, that is tried as a candidate token.
    _MAX_TOKEN_CHARS = 50

    def __init__(
        self,
        vocab_file="vocab.json",
        unk_token="unk",
        pad_token="pad",
        bos_token="/s",
        eos_token="s",
        model_max_length=512,
        **kwargs
    ):
        """Load the vocabulary and initialise the base tokenizer.

        Args:
            vocab_file: Name of the JSON vocabulary file. It is looked up
                next to this module first; if it is not there it is
                downloaded from the ``saracandu/stlenc`` Hub repository; if
                the download fails, the value is opened exactly as given.
            unk_token: Token emitted for characters with no vocabulary match.
            pad_token: Padding token, forwarded to the base class.
            bos_token: Token prepended to every text before segmentation.
            eos_token: Token appended to every text before segmentation.
            model_max_length: Forwarded to the base class.
            **kwargs: Forwarded unchanged to the base class.

        Raises:
            OSError: If the resolved vocabulary file cannot be opened.
            json.JSONDecodeError: If the vocabulary file is not valid JSON.
        """
        current_dir = os.path.dirname(__file__)
        full_vocab_path = os.path.join(current_dir, vocab_file)

        if not os.path.exists(full_vocab_path):
            from huggingface_hub import hf_hub_download
            try:
                full_vocab_path = hf_hub_download("saracandu/stlenc", vocab_file)
            except Exception:
                # Best-effort fallback: any download failure leaves us with
                # the caller-supplied path. Unlike a bare ``except``, this
                # lets KeyboardInterrupt and SystemExit propagate.
                full_vocab_path = vocab_file

        with open(full_vocab_path, "r", encoding="utf-8") as f:
            self.vocab = json.load(f)

        self.id_to_token = {v: k for k, v in self.vocab.items()}

        # NOTE(review): the vocabulary is assigned before the base
        # initialiser runs; the base class may query it during
        # construction -- confirm before reordering.
        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            model_max_length=model_max_length,
            **kwargs
        )

    @property
    def vocab_size(self) -> int:
        """Number of entries in the loaded vocabulary."""
        return len(self.vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a shallow copy of the token-to-id mapping."""
        return dict(self.vocab)

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into vocabulary tokens by greedy longest match.

        The text is wrapped as ``"<bos> <text> <eos>"`` and every space is
        replaced by ``@`` before matching. At each position the longest
        substring of at most ``_MAX_TOKEN_CHARS`` characters that is a
        vocabulary key is consumed; when nothing matches, the unknown token
        is emitted and a single character is skipped.

        Args:
            text: Raw input string.

        Returns:
            The list of matched tokens, in order.
        """
        text = f'{self.bos_token} {text} {self.eos_token}'.replace(' ', '@')

        tokens = []
        length = len(text)
        i = 0
        while i < length:
            best_match = None
            # Try the longest candidate first so the first hit is the best.
            for j in range(min(i + self._MAX_TOKEN_CHARS, length), i, -1):
                subtoken = text[i:j]
                if subtoken in self.vocab:
                    best_match = subtoken
                    break

            if best_match is None:
                tokens.append(self.unk_token)
                i += 1
            else:
                tokens.append(best_match)
                i += len(best_match)
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Return the id of ``token``, or the unknown token's id if absent."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index: int) -> str:
        """Return the token for ``index``, or the unknown token if absent."""
        return self.id_to_token.get(index, self.unk_token)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocabulary as JSON into ``save_directory``.

        Args:
            save_directory: Target directory; created if it does not exist.
            filename_prefix: Optional string prepended directly (with no
                separator) to ``vocab.json``.

        Returns:
            A one-element tuple holding the path of the written file.
        """
        # exist_ok avoids failing when the directory is created between a
        # separate existence check and the makedirs call.
        os.makedirs(save_directory, exist_ok=True)

        prefix = filename_prefix if filename_prefix is not None else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, indent=2, ensure_ascii=False)

        return (vocab_file,)

# Best-effort registration of STLTokenizer with AutoTokenizer under the key
# "stl_encoder". Any exception raised by the call is deliberately ignored so
# that importing this module does not fail because of it.
# NOTE(review): presumably this guards against the tokenizer already being
# registered (e.g. on re-import) -- confirm, and consider narrowing the
# exception type once the expected failure is known.
try:
    AutoTokenizer.register("stl_encoder", STLTokenizer)
except Exception:
    pass