# tevr-token-entropy-predictor-de / text_tokenizer.py
# Uploaded by fxtentacle (commit c2cd532)
from typing import Any, Dict, List, Optional, Union
import json
class HajoTextTokenizer:
    """Greedy multi-character tokenizer (TEVR-style, German ASR).

    Loads a JSON array of token strings; the LAST entry is overwritten
    with '?' and serves as the unknown-token placeholder.  Internally,
    encode() works on a mixed list of characters and ``1000 + token_index``
    ids so that already-matched spans can no longer match later tokens.
    """

    def __init__(self, config_file: str):
        """Load the token list from *config_file* (a JSON array of strings)."""
        # Explicit UTF-8: the token set contains non-ASCII German characters
        # (see the 'ß' handling in encode()), so decoding must not depend on
        # the platform's locale default codec.
        with open(config_file, 'rt', encoding='utf-8') as f:
            self.all_tokens = json.load(f)
        # Internal id of the unknown token = 1000 + last index.
        self.unk = 1000 + len(self.all_tokens) - 1
        self.all_tokens[self.unk - 1000] = '?'  # last slot becomes the unknown marker
        # Tokens eligible for matching; the unknown placeholder is excluded.
        self.valid_tokens = self.all_tokens[:-1]

    def encode(self, sentence):
        """Encode *sentence* into a list of token indices (CTC-style).

        Repeated token indices are separated by a blank (0); characters not
        covered by any token map to the unknown index, and runs of unknowns
        collapse to a single entry.
        """
        # Normalization: ß -> ss, dashes -> spaces, lowercase.
        # NOTE(review): the two .replace(' ',' ') calls are no-ops as
        # written; they look like garbled double-space collapses — confirm
        # against the upstream file before "fixing" them.
        sentence = sentence.replace('ß', 'ss').replace('-', ' ').replace(' ', ' ').replace(' ', ' ').lower()
        sentence = list(sentence)
        # Greedy replacement in token-list order: earlier tokens win, and a
        # span already replaced by an int id can never match a later token.
        for tokid, tok in enumerate(self.valid_tokens):
            tlen = len(tok)
            ltok = list(tok)
            for off in range(len(sentence) - tlen + 1):
                # The range is fixed before replacements shrink the list;
                # stale trailing offsets just yield short slices that can
                # never equal the (non-empty) ltok, so this stays safe.
                if sentence[off:off + tlen] == ltok:
                    prefix = sentence[:off]
                    suffix = sentence[off + tlen:]
                    sentence = prefix + [1000 + tokid] + suffix
        out = []
        last_id = 0
        for t in sentence:
            if isinstance(t, str):
                t = self.unk  # leftover character covered by no token
            if t == last_id:
                if t == self.unk:
                    continue  # collapse runs of unknowns into one entry
                out.append(0)  # CTC blank between genuine repeats
            last_id = t
            out.append(t - 1000)
        return out

    def decode(self, label_ids):
        """Decode *label_ids* back to text (CTC collapse).

        Skips blanks (0) and the ignore-index (-100), stops at index 1
        (presumably an end-of-sequence token — not defined in this file),
        and drops immediate repeats.
        """
        out = ''
        last_id = 0
        for i in label_ids:
            if i == 0 or i == -100:
                # Remember the blank so a repeated index AFTER a blank is
                # emitted again (standard CTC behavior).
                last_id = i
                continue
            if i == 1:
                break  # index 1 terminates decoding
            if i != last_id:
                out += self.all_tokens[i]
                last_id = i
        return out