vanh99 commited on
Commit
f565a64
·
verified ·
1 Parent(s): 278f520

Create tokenizer_base.py

Browse files
Files changed (1) hide show
  1. tokenizer_base.py +146 -0
tokenizer_base.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from abc import ABC, abstractmethod
from itertools import groupby
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
8
+
9
+
10
class CharsetAdapter:
    """Normalize label casing to match the target charset.

    If the charset is entirely lowercase (resp. uppercase), labels are folded
    to that case; mixed-case charsets leave labels untouched.
    NOTE(review): despite the class docstring's broad wording, this adapter
    only adjusts case — it does not strip characters absent from the charset;
    confirm that is intended.
    """

    def __init__(self, target_charset) -> None:
        super().__init__()
        self.charset = target_charset
        # A charset with no cased characters (e.g. digits only) satisfies both
        # tests; the lowercase branch then wins in __call__, which is a no-op
        # for such labels anyway.
        self.lowercase_only = target_charset == target_charset.lower()
        self.uppercase_only = target_charset == target_charset.upper()

    def __call__(self, label):
        if self.lowercase_only:
            return label.lower()
        if self.uppercase_only:
            return label.upper()
        return label
25
+
26
+
27
class BaseTokenizer(ABC):
    """Abstract base class for label tokenizers.

    Builds the vocabulary from ``charset`` plus optional special tokens and
    provides the shared id<->token mapping and greedy decoding; subclasses
    implement :meth:`encode` and :meth:`_filter`.
    """

    def __init__(
        self, charset: str, specials_first: tuple = (), specials_last: tuple = ()
    ) -> None:
        # NOTE(review): ``tuple(charset + "[UNK]")`` splits the concatenated
        # string into single characters, so '[', 'U', 'N', 'K', ']' are added
        # as five separate one-char tokens, not one "[UNK]" token. If any of
        # those characters already occur in ``charset``, ``_stoi`` keeps only
        # the later index while ``len(self)`` still counts the duplicate.
        # Confirm this is intentional before changing it (vocab size affects
        # trained checkpoints).
        self._itos = specials_first + tuple(charset + "[UNK]") + specials_last
        self._stoi = {s: i for i, s in enumerate(self._itos)}

    def __len__(self):
        """Return the vocabulary size, special tokens included."""
        return len(self._itos)

    def _tok2ids(self, tokens: str) -> List[int]:
        """Map each character of ``tokens`` to its integer token id."""
        return [self._stoi[s] for s in tokens]

    def _ids2tok(
        self, token_ids: List[int], join: bool = True
    ) -> Union[str, List[str]]:
        """Map token ids back to tokens.

        Returns a single joined string when ``join`` is True, otherwise the
        list of individual token strings.
        """
        tokens = [self._itos[i] for i in token_ids]
        return "".join(tokens) if join else tokens

    @abstractmethod
    def encode(
        self, labels: List[str], device: Optional[torch.device] = None
    ) -> Tensor:
        """Encode a batch of labels to a representation suitable for the model.

        Args:
            labels: List of labels. Each can be of arbitrary length.
            device: Create tensor on this device.

        Returns:
            Batched tensor representation padded to the max label length. Shape: N, L
        """
        raise NotImplementedError

    @abstractmethod
    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Internal method which performs the necessary filtering prior to decoding."""
        raise NotImplementedError

    def decode(
        self, token_dists: Tensor, raw: bool = False
    ) -> Tuple[List[str], List[Tensor]]:
        """Decode a batch of token distributions.

        Args:
            token_dists: softmax probabilities over the token distribution. Shape: N, L, C
            raw: return unprocessed labels (will return list of list of strings)

        Returns:
            list of string labels (arbitrary length) and
            their corresponding sequence probabilities as a list of Tensors
        """
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            # Greedy decoding: take the most likely token at each position.
            probs, ids = dist.max(-1)
            if not raw:
                probs, ids = self._filter(probs, ids)
            tokens = self._ids2tok(ids, not raw)
            batch_tokens.append(tokens)
            batch_probs.append(probs)
        return batch_tokens, batch_probs
88
+
89
+
90
class Tokenizer(BaseTokenizer):
    """Autoregressive tokenizer with BOS/EOS/PAD special tokens.

    EOS is placed before the charset (so its id is 0); BOS and PAD follow it.
    """

    BOS = "[B]"
    EOS = "[E]"
    PAD = "[P]"

    def __init__(self, charset: str) -> None:
        leading = (self.EOS,)
        trailing = (self.BOS, self.PAD)
        super().__init__(charset, leading, trailing)
        self.eos_id = self._stoi[self.EOS]
        self.bos_id = self._stoi[self.BOS]
        self.pad_id = self._stoi[self.PAD]

    def encode(
        self, labels: List[str], device: Optional[torch.device] = None
    ) -> Tensor:
        """Encode labels as ``[B] ... [E]`` id sequences, right-padded with PAD.

        Args:
            labels: List of labels. Each can be of arbitrary length.
            device: Create tensors on this device.

        Returns:
            Batched long tensor padded to the max label length. Shape: N, L
        """
        sequences = []
        for label in labels:
            ids = [self.bos_id, *self._tok2ids(label), self.eos_id]
            sequences.append(torch.as_tensor(ids, dtype=torch.long, device=device))
        return pad_sequence(sequences, batch_first=True, padding_value=self.pad_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Truncate the greedy output at the first EOS.

        The returned ids exclude EOS, while the returned probs keep the EOS
        position's probability (one extra element).
        """
        id_list = ids.tolist()
        try:
            cutoff = id_list.index(self.eos_id)
        except ValueError:
            # No EOS emitted: keep the whole sequence.
            cutoff = len(id_list)
        return probs[: cutoff + 1], id_list[:cutoff]
125
+
126
+
127
class CTCTokenizer(BaseTokenizer):
    """Tokenizer for CTC-style training/decoding with a single BLANK token."""

    BLANK = "[B]"

    def __init__(self, charset: str) -> None:
        # BLANK is placed first so that blank_id == 0, the conventional
        # blank index for CTC losses.
        super().__init__(charset, specials_first=(self.BLANK,))
        self.blank_id = self._stoi[self.BLANK]

    def encode(
        self, labels: List[str], device: Optional[torch.device] = None
    ) -> Tensor:
        """Encode labels as plain id sequences, right-padded with BLANK.

        Args:
            labels: List of labels. Each can be of arbitrary length.
            device: Create tensors on this device.

        Returns:
            Batched long tensor padded to the max label length. Shape: N, L
        """
        batch = [
            torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device)
            for y in labels
        ]
        return pad_sequence(batch, batch_first=True, padding_value=self.blank_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Standard CTC collapse: merge adjacent repeats, then drop blanks.

        The comprehension over groupby returns [] for an empty id sequence
        (``list(zip(*groupby([])))[0]`` would raise IndexError instead).
        ``probs`` is returned unfiltered, matching the original contract.
        """
        ids = [key for key, _ in groupby(ids.tolist()) if key != self.blank_id]
        return probs, ids