File size: 3,545 Bytes
c1e438c
9a90c86
 
c1e438c
c2bdc6b
9a90c86
fc38e18
 
 
 
c1e438c
9a90c86
 
 
c2bdc6b
9a90c86
 
 
c2bdc6b
9a90c86
c2bdc6b
c1e438c
 
c2bdc6b
 
c1e438c
fa5f48b
 
 
 
 
 
 
 
 
 
c1e438c
9a90c86
c2bdc6b
9a90c86
 
c2bdc6b
 
 
 
fa5f48b
 
 
 
 
 
 
 
 
c2bdc6b
9a90c86
c1e438c
b4566a1
d8ba801
 
 
 
1a24d78
 
 
 
 
 
 
 
 
 
 
b4566a1
c2bdc6b
 
 
c1e438c
 
c2bdc6b
c1e438c
 
c2bdc6b
 
 
c1e438c
c2bdc6b
 
7e25fba
c2bdc6b
7e25fba
c2bdc6b
7e25fba
c2bdc6b
c1e438c
c2bdc6b
 
c1e438c
c2bdc6b
1a24d78
 
 
 
 
 
 
 
9a90c86
1a24d78
 
 
 
9a90c86
c2bdc6b
9a90c86
c2bdc6b
 
 
f34a16b
c2bdc6b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import json
import sentencepiece as spm


class HybridTokenizer:
    """SentencePiece-backed tokenizer with an optional JSON vocab.

    Encoding and decoding are delegated entirely to a SentencePiece model.
    The JSON vocab, when present, is used ONLY to report ``vocab_size``
    when no SentencePiece model is loaded (size alignment with a model's
    embedding table) — it is never consulted for encoding.
    """

    PAD_TOKEN = "<pad>"
    UNK_TOKEN = "<unk>"
    BOS_TOKEN = "<bos>"
    EOS_TOKEN = "<eos>"

    def __init__(self, sp_model_path=None, vocab_path=None):
        """Load the SentencePiece model and/or JSON vocab if the paths exist.

        Args:
            sp_model_path: Path to a trained SentencePiece ``.model`` file.
                If missing or None, the tokenizer cannot encode/decode.
            vocab_path: Path to a JSON token->id mapping, used only for
                ``vocab_size`` when no SentencePiece model is available.
        """
        self.sp = spm.SentencePieceProcessor()
        self.has_sp = False

        if sp_model_path and os.path.exists(sp_model_path):
            self.sp.Load(sp_model_path)
            self.has_sp = True
            print(f"[TOKENIZER] Loaded SentencePiece model")

        # Load vocab ONLY for size alignment (not encoding)
        self.vocab = {}
        if vocab_path and os.path.exists(vocab_path):
            with open(vocab_path, "r", encoding="utf-8") as f:
                self.vocab = json.load(f)

        # Special-token ids from the model; conventional slots otherwise.
        self.pad_id = self.sp.pad_id() if self.has_sp else 0
        self.unk_id = self.sp.unk_id() if self.has_sp else 1
        self.bos_id = self.sp.bos_id() if self.has_sp else 2
        self.eos_id = self.sp.eos_id() if self.has_sp else 3

        # SentencePiece reports -1 for special tokens absent from the model.
        # pad/unk must always be usable ids, so they fall back to 0;
        # bos/eos become None, meaning "do not emit this token at all".
        if self.pad_id < 0: self.pad_id = 0
        if self.unk_id < 0: self.unk_id = 0
        if self.bos_id < 0: self.bos_id = None
        if self.eos_id < 0: self.eos_id = None

    # ---------------------------
    # ENCODE (PURE SP)
    # ---------------------------
    def encode(self, text, max_len=512):
        """Encode *text* into a list of token ids of length <= max_len.

        BOS/EOS are added when the model defines them, and every id is
        clamped into the valid vocab range via ``_sanitize_ids``.

        Raises:
            RuntimeError: if no SentencePiece model is loaded.
        """
        if not self.has_sp:
            raise RuntimeError("SentencePiece model not loaded")

        # BUGFIX: reserve room for BOS/EOS BEFORE truncating. Previously
        # the raw ids were cut to max_len and the specials appended
        # afterwards, yielding sequences of up to max_len + 2 — and a
        # later pad(ids, max_len) would then slice off the EOS token.
        reserve = (self.bos_id is not None) + (self.eos_id is not None)

        ids = self.sp.encode(text, out_type=int)
        ids = ids[:max(0, max_len - reserve)]

        if self.bos_id is not None:
            ids = [self.bos_id] + ids

        if self.eos_id is not None:
            ids = ids + [self.eos_id]

        return self._sanitize_ids(ids)

    def safe_encode(self, text, max_len=512):
        """Encode *text*, falling back to [BOS?, UNK, EOS?] on any error.

        Never raises; the fallback guarantees a non-empty, valid sequence
        so downstream batching code cannot crash on a bad sample.
        """
        try:
            return self.encode(text, max_len=max_len)
        except Exception as e:
            print(f"[TOKENIZER ERROR] {e}")
            fallback = []

            if self.bos_id is not None:
                fallback.append(self.bos_id)

            fallback.append(self.unk_id if self.unk_id is not None else 0)

            if self.eos_id is not None:
                fallback.append(self.eos_id)

            return fallback

    # ---------------------------
    # PAD
    # ---------------------------
    def pad(self, ids, max_len):
        """Right-pad *ids* with pad_id to max_len, or truncate to max_len."""
        if len(ids) < max_len:
            return ids + [self.pad_id] * (max_len - len(ids))
        return ids[:max_len]

    # ---------------------------
    # DECODE (PURE SP)
    # ---------------------------
    def decode(self, ids):
        """Decode *ids* to text, skipping PAD/BOS and stopping at EOS."""
        # remove special tokens
        cleaned = []
        for i in ids:
            if i in {self.pad_id, self.bos_id}:
                continue
            if i == self.eos_id:
                break
            cleaned.append(i)

        if not cleaned:
            return ""

        return self.sp.decode(cleaned)

    # ---------------------------
    # SAFETY: ID SANITIZATION
    # ---------------------------
    def _sanitize_ids(self, ids):
        """Replace any non-int or out-of-range id with a valid UNK id."""
        vocab_size = self.vocab_size

        # fallback UNK (must be valid)
        unk = self.unk_id if (self.unk_id is not None and self.unk_id >= 0) else 0

        return [
            i if (isinstance(i, int) and 0 <= i < vocab_size) else unk
            for i in ids
        ]

    # ---------------------------
    # VOCAB SIZE (CRITICAL)
    # ---------------------------
    @property
    def vocab_size(self):
        """Model piece count when SP is loaded, else the JSON vocab size."""
        if self.has_sp:
            return self.sp.GetPieceSize()
        return len(self.vocab)