import json
import re
import unicodedata
from collections import defaultdict

SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[BOS]", "[EOS]"]


class TrieNode(defaultdict):
    """Trie node: missing children are created on access; `end` marks the end of a vocabulary entry."""

    def __init__(self):
        super().__init__(TrieNode)
        self.end = False


def dict_to_node(d):
    """Rebuild a TrieNode tree from its JSON form: children are keyed by
    single characters, and "#end": true marks a complete vocabulary entry."""
    node = TrieNode()
    node.end = d.get("#end", False)
    for k, v in d.items():
        if k == "#end":
            continue
        node[k] = dict_to_node(v)
    return node
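

# For reference, a sketch of the inverse serialization (an assumption: the
# trie-building side is not shown in this file). It produces exactly the
# JSON shape that dict_to_node consumes.
def node_to_dict(node):
    d = {k: node_to_dict(v) for k, v in node.items()}
    if node.end:
        d["#end"] = True
    return d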


def normalize_text(s):
    """NFC-normalize and collapse all whitespace runs to single spaces."""
    s = unicodedata.normalize("NFC", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s


def is_dev_char(ch):
    """True if ch is in the Devanagari, Devanagari Extended, or Vedic Extensions blocks."""
    cp = ord(ch)
    for a, b in [(0x0900, 0x097F), (0xA8E0, 0xA8FF), (0x1CD0, 0x1CFF)]:
        if a <= cp <= b:
            return True
    return False


# Danda/double danda plus common punctuation, split off as single-character tokens.
PUNCT_CHARS = set("।॥,;:–-—()[]{}\"'""''…!?|/\\·•*^`~")


def longest_match_tokenize(text, trie, unk_token="[UNK]"):
    """Greedy longest-match tokenization against the trie.

    Whitespace is dropped, punctuation is split into single-character tokens,
    non-Devanagari runs pass through unchanged, and Devanagari runs with no
    vocabulary match are collapsed into unk_token.
    """
    text = normalize_text(text)
    out = []
    i, n = 0, len(text)
    while i < n:
        ch = text[i]
        if ch.isspace():
            i += 1
            continue
        # Pass through a run of non-Devanagari, non-punctuation characters as one token.
        if not is_dev_char(ch) and ch not in PUNCT_CHARS:
            j = i + 1
            while j < n and not is_dev_char(text[j]) and text[j] not in PUNCT_CHARS and not text[j].isspace():
                j += 1
            out.append(text[i:j])
            i = j
            continue
        # Punctuation becomes a single-character token.
        if ch in PUNCT_CHARS:
            out.append(ch)
            i += 1
            continue
        # Walk the trie as far as it matches, remembering the last complete entry.
        node, j, last = trie, i, -1
        while j < n and text[j] in node:
            node = node[text[j]]
            j += 1
            if node.end:
                last = j
        if last != -1:
            out.append(text[i:last])
            i = last
        else:
            # No vocabulary entry starts here: skip the whole Devanagari run
            # and emit a single unk_token for it.
            j = i + 1
            while j < n and is_dev_char(text[j]) and text[j] not in PUNCT_CHARS and not text[j].isspace():
                j += 1
            out.append(unk_token)
            i = j
    return [t for t in out if t != ""]
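

# Illustration (hypothetical two-entry vocabulary): greedy matching prefers
# the longest entry, so "जान" wins over its prefix "जा":
#   trie = dict_to_node({"ज": {"ा": {"#end": True, "न": {"#end": True}}}})
#   longest_match_tokenize("जानजा", trie)  # -> ["जान", "जा"]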


class MorphemeTokenizer:
    """Loads vocab.json (token/id maps) and trie.json (a serialized trie) from a directory."""

    def __init__(self, path):
        with open(path + "/vocab.json", "r", encoding="utf-8") as f:
            vv = json.load(f)
        self.tok2id = vv["tok2id"]
        self.id2tok = {int(k): v for k, v in vv["id2tok"].items()}
        with open(path + "/trie.json", "r", encoding="utf-8") as f:
            self.trie = dict_to_node(json.load(f))
        self.bos_token = "[BOS]"
        self.eos_token = "[EOS]"
        self.unk_token = "[UNK]"
        self.pad_token = "[PAD]"

    def tokenize(self, text):
        return longest_match_tokenize(text, self.trie, self.unk_token)

    def encode(self, text, add_special_tokens=True):
        toks = self.tokenize(text)
        if add_special_tokens:
            toks = [self.bos_token] + toks + [self.eos_token]
        return [self.tok2id.get(t, self.tok2id[self.unk_token]) for t in toks]

    def decode(self, ids, skip_special_tokens=True):
        # Note: tokenize drops whitespace, so decoding cannot restore word spacing.
        toks = [self.id2tok.get(i, self.unk_token) for i in ids]
        if skip_special_tokens:
            toks = [t for t in toks if t not in SPECIAL_TOKENS]
        return "".join(toks)