File size: 3,021 Bytes
79a81b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Auto-generated simple loader for the morpheme tokenizer
import json
import re
import unicodedata
from collections import defaultdict

SPECIAL_TOKENS = ['[PAD]', '[UNK]', '[BOS]', '[EOS]']

class TrieNode(defaultdict):
    """Character-trie node: child lookup creates missing branches on demand.

    `end` is True when the path from the root to this node spells a
    complete vocabulary entry.
    """

    def __init__(self):
        self.end = False
        super().__init__(TrieNode)

def dict_to_node(d):
    """Rebuild a TrieNode tree from its JSON-dict serialization.

    The serialized form stores the end-of-word flag under the reserved
    key "#end"; every other key is a child edge.
    """
    node = TrieNode()
    node.end = d.get("#end", False)
    for key, child in d.items():
        if key != "#end":
            node[key] = dict_to_node(child)
    return node

def normalize_text(s):
    """Return *s* NFC-normalized with runs of whitespace collapsed to one space.

    Leading/trailing whitespace is stripped.  (Previously this used an
    inline ``__import__("re")`` hack; ``re`` is now imported normally at
    module level.)
    """
    s = unicodedata.normalize("NFC", s)
    return re.sub(r"\s+", " ", s).strip()

def is_dev_char(ch):
    """True if *ch* lies in a Devanagari Unicode block.

    Covers Devanagari (U+0900-097F), Devanagari Extended (U+A8E0-A8FF)
    and Vedic Extensions (U+1CD0-1CFF).
    """
    cp = ord(ch)
    ranges = ((0x0900, 0x097F), (0xA8E0, 0xA8FF), (0x1CD0, 0x1CFF))
    return any(lo <= cp <= hi for lo, hi in ranges)

PUNCT_CHARS = set(list("।॥,;:—-–()[]{}\"'‘’“”…!?|/\\·•*^`~"))

def longest_match_tokenize(text, trie, unk_token="[UNK]"):
    """Greedy longest-match tokenization of *text* against a character trie.

    Behavior (unchanged from the original):
      - whitespace separates tokens and is dropped;
      - a non-Devanagari, non-punctuation run (Latin, digits, ...) is
        emitted verbatim as one token;
      - punctuation characters are emitted one at a time;
      - a Devanagari span is matched greedily against *trie*; if no vocab
        entry starts at the current position, the whole Devanagari run is
        skipped and a single *unk_token* is emitted.

    Fix: the original wrapped the UNK append in an always-false
    ``"#FORCE#" == "never"`` guard, computing a ``cand`` slice that was
    never used.  The dead branch and unused slice are removed; the emitted
    tokens are identical.
    """
    text = normalize_text(text)
    out = []
    i, n = 0, len(text)
    while i < n:
        ch = text[i]
        if ch.isspace():
            i += 1
            continue
        if not is_dev_char(ch) and ch not in PUNCT_CHARS:
            # Foreign-script / numeric run: take it whole.
            j = i + 1
            while (j < n and not is_dev_char(text[j])
                   and text[j] not in PUNCT_CHARS and not text[j].isspace()):
                j += 1
            out.append(text[i:j])
            i = j
            continue
        if ch in PUNCT_CHARS:
            out.append(ch)
            i += 1
            continue
        # Devanagari: walk the trie as far as it matches, remembering the
        # last position that completed a vocabulary entry.
        node, j, last = trie, i, -1
        while j < n and text[j] in node:
            node = node[text[j]]
            j += 1
            if node.end:
                last = j
        if last != -1:
            out.append(text[i:last])
            i = last
        else:
            # No entry starts here: skip the rest of the Devanagari run
            # and emit a single UNK for it.
            j = i + 1
            while (j < n and is_dev_char(text[j])
                   and text[j] not in PUNCT_CHARS and not text[j].isspace()):
                j += 1
            out.append(unk_token)
            i = j
    return [t for t in out if t != ""]

class MorphemeTokenizer:
    """Morpheme tokenizer backed by a JSON vocabulary and a serialized trie.

    Loads ``vocab.json`` (token<->id maps) and ``trie.json`` (longest-match
    trie) from *path* and exposes tokenize / encode / decode helpers.
    """

    def __init__(self, path):
        with open(path + "/vocab.json", "r", encoding="utf-8") as fh:
            vocab = json.load(fh)
        self.tok2id = vocab["tok2id"]
        # JSON object keys are always strings; restore integer ids.
        self.id2tok = {int(idx): tok for idx, tok in vocab["id2tok"].items()}
        with open(path + "/trie.json", "r", encoding="utf-8") as fh:
            self.trie = dict_to_node(json.load(fh))
        self.bos_token = "[BOS]"
        self.eos_token = "[EOS]"
        self.unk_token = "[UNK]"
        self.pad_token = "[PAD]"

    def tokenize(self, text):
        """Split *text* into morpheme tokens via greedy longest trie match."""
        return longest_match_tokenize(text, self.trie, self.unk_token)

    def encode(self, text, add_special_tokens=True):
        """Return token ids for *text*, optionally wrapped in BOS/EOS."""
        tokens = self.tokenize(text)
        if add_special_tokens:
            tokens = [self.bos_token, *tokens, self.eos_token]
        return [self.tok2id.get(t, self.tok2id[self.unk_token])
                for t in tokens]

    def decode(self, ids, skip_special_tokens=True):
        """Map *ids* back to token strings and concatenate them.

        Unknown ids become the UNK token; special tokens are dropped
        unless *skip_special_tokens* is False.
        """
        tokens = [self.id2tok.get(i, self.unk_token) for i in ids]
        if skip_special_tokens:
            tokens = [t for t in tokens if t not in SPECIAL_TOKENS]
        return "".join(tokens)