File size: 2,482 Bytes
e0552b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import re
import pandas as pd
from typing import Optional
from koja_diffuser.tokenizer.special import SpecialToken
from io import BytesIO


class JapaneseTokenizer:
    series: pd.Series
    token_pattern: re.Pattern[str]

    def __init__(self, filepath: Optional[str | BytesIO] = None):
        self.token_pattern = re.compile(
            r".[ぁぃぅぇぉゃゅょっゎァィゥェォャュョッヮヵヶ]?"
        )
        if filepath is None:
            self.series = pd.Series()
        else:
            self.series = pd.read_parquet(filepath).squeeze()

    def train(self, frame: list[str], output="./dist/ja_token.parquet"):
        tokens_dict: dict[str, int] = {}
        next_id = SpecialToken.next_id
        for text in frame:
            tokens = self.tokenize(text)
            for token in tokens:
                if token not in tokens_dict:
                    tokens_dict[token] = next_id
                    next_id += 1
        df = pd.Series(tokens_dict).to_frame()
        df.to_parquet(output)

    def tokenize(self, text: str) -> list[str]:
        return self.token_pattern.findall(text)

    def id_to_token(self, i: int) -> str:
        if i == SpecialToken.sep:
            return " "
        result = self.series[self.series == i].index
        key = result[0] if not result.empty else ""
        return key

    def encode(self, text: str, *, add_eos=False, max_len=0) -> list[int]:
        tokens = self.tokenize(text)
        encoded = []

        for token in tokens:
            if token == " ":
                token_id = SpecialToken.sep
            else:
                token_id = int(self.series.get(token, SpecialToken.unk))
            encoded.append(token_id)

        if add_eos:
            if max_len > 0 and len(encoded) >= max_len:
                encoded = encoded[: max_len - 1]
            encoded.append(SpecialToken.eos)

        if max_len > 0:
            if len(encoded) > max_len:
                encoded = encoded[:max_len]
                raise ("impossible")
            else:
                pad_len = max_len - len(encoded)
                encoded.extend([SpecialToken.pad] * pad_len)

        return encoded

    def decode(self, ids: list[int]) -> str:
        tokens = []
        for t in ids:
            if t == SpecialToken.eos:
                break
            tokens.append(t)
        return "".join([self.id_to_token(i) for i in tokens])

    def __len__(self):
        return self.series.max() + 1