File size: 1,748 Bytes
c2cd532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from typing import Any, Dict, List, Optional, Union
import json

class HajoTextTokenizer:
    """Greedy, vocabulary-driven tokenizer for text labels.

    The vocabulary is a JSON list of strings; a token's id is its index in
    that list. The last entry is reserved as the "unknown" token and is
    rendered as ``'?'`` when decoding, whatever the file stored there.

    Ids with special meaning in :meth:`encode` / :meth:`decode`:

    * ``0`` -- separator: skipped by ``decode`` and inserted by ``encode``
      between two identical consecutive tokens, so that the
      repeat-collapsing in ``decode`` does not merge them.
    * ``1`` -- ``decode`` stops at its first occurrence.
    * ``-100`` -- ignored by ``decode`` (presumably label padding; confirm
      against the training code).
    """

    # Ids are shifted by this offset while merging, so that no merged id
    # equals the initial ``last_id`` sentinel (0) in ``encode``.
    # NOTE: ``self.unk`` is stored in this offset form.
    _ID_OFFSET = 1000
    _SEPARATOR_ID = 0
    _STOP_ID = 1
    _IGNORE_ID = -100

    def __init__(self, config_file: str):
        """Load the vocabulary from the JSON list in *config_file*.

        Args:
            config_file: Path to a UTF-8 JSON file containing a list of
                token strings. Its last entry is reserved for "unknown".
        """
        # JSON text is UTF-8; do not depend on the platform locale encoding.
        with open(config_file, 'rt', encoding='utf-8') as f:
            self.all_tokens = json.load(f)
        # Offset id of the last vocabulary slot, used for unknown characters.
        self.unk = self._ID_OFFSET + len(self.all_tokens) - 1
        self.all_tokens[self.unk - self._ID_OFFSET] = '?'
        # Entries eligible for substring matching (the unknown slot excluded).
        self.valid_tokens = self.all_tokens[:-1]

    @staticmethod
    def _normalize(sentence: str) -> str:
        """Lower-case, map 'ß' to 'ss', turn hyphens into spaces, collapse space runs."""
        # Lower-case first so the capital sharp s ('ẞ' -> 'ß') also becomes 'ss'.
        text = sentence.lower().replace('ß', 'ss').replace('-', ' ')
        # Collapse space runs of any length; a fixed number of replace()
        # passes leaves runs of five or more spaces partially collapsed.
        while '  ' in text:
            text = text.replace('  ', ' ')
        return text

    def _merge_tokens(self, units: List[Union[str, int]]) -> List[Union[str, int]]:
        """Replace vocabulary substrings in *units* with offset ids, in place.

        Entries are applied in vocabulary order, so an earlier entry takes
        priority over a later one regardless of length. Characters that no
        entry covers remain as ``str`` items.
        """
        offset = self._ID_OFFSET
        for tok_id, tok in enumerate(self.valid_tokens):
            if not tok:
                continue  # an empty entry would "match" at every position
            pattern = list(tok)
            width = len(pattern)
            pos = 0
            # Left-to-right, non-overlapping scan. Replaced slots hold ints,
            # which never compare equal to characters, so they are not re-matched.
            while pos + width <= len(units):
                if units[pos:pos + width] == pattern:
                    units[pos:pos + width] = [offset + tok_id]
                pos += 1
        return units

    def encode(self, sentence: str) -> List[int]:
        """Convert *sentence* into a list of vocabulary ids.

        The text is normalized, then vocabulary entries are substituted
        greedily in vocabulary order. Characters not covered by any entry
        map to the unknown id, and consecutive unknowns collapse into one.
        A ``0`` separator is inserted between two identical consecutive
        known tokens.

        Args:
            sentence: Raw text to encode.

        Returns:
            Indices into ``self.all_tokens``.
        """
        units = self._merge_tokens(list(self._normalize(sentence)))
        offset = self._ID_OFFSET
        out: List[int] = []
        last_id = 0  # sentinel: offset ids are >= 1000, so never equal at start
        for unit in units:
            tok_id = self.unk if isinstance(unit, str) else unit
            if tok_id == last_id:
                if tok_id == self.unk:
                    continue  # collapse runs of unknown characters
                # Keep genuine repeats apart so decode() does not merge them.
                out.append(self._SEPARATOR_ID)
            last_id = tok_id
            out.append(tok_id - offset)
        return out

    def decode(self, label_ids) -> str:
        """Convert an iterable of vocabulary ids back into text.

        Ids ``0`` and ``-100`` are skipped and reset the repeat tracker.
        Consecutive duplicate ids are emitted once, and decoding stops at
        the first ``1``.

        Args:
            label_ids: Iterable of integer ids (indices into ``self.all_tokens``).

        Returns:
            The decoded string.
        """
        pieces = []
        last_id = 0
        for i in label_ids:
            if i == self._SEPARATOR_ID or i == self._IGNORE_ID:
                last_id = i
                continue
            if i == self._STOP_ID:
                break
            if i != last_id:
                pieces.append(self.all_tokens[i])
                last_id = i
        return ''.join(pieces)