Continual-Mega committed on
Commit
f41cee4
verified
1 Parent(s): 537f272

Upload CLIP/tokenizer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. CLIP/tokenizer.py +186 -0
CLIP/tokenizer.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Any, Union, List
6
+ from pkg_resources import packaging
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+
13
+
14
+ import gzip
15
+ import html
16
+ from functools import lru_cache
17
+
18
+ import ftfy
19
+ import regex as re
20
+
21
+
22
+ @lru_cache()
23
+ def default_bpe():
24
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
25
+
26
+
27
+ @lru_cache()
28
+ def bytes_to_unicode():
29
+ """
30
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
31
+ The reversible bpe codes work on unicode strings.
32
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
33
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
34
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
35
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
36
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
37
+ """
38
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("隆"), ord("卢")+1))+list(range(ord("庐"), ord("每")+1))
39
+ cs = bs[:]
40
+ n = 0
41
+ for b in range(2**8):
42
+ if b not in bs:
43
+ bs.append(b)
44
+ cs.append(2**8+n)
45
+ n += 1
46
+ cs = [chr(n) for n in cs]
47
+ return dict(zip(bs, cs))
48
+
49
+
50
+ def get_pairs(word):
51
+ """Return set of symbol pairs in a word.
52
+ Word is represented as tuple of symbols (symbols being variable-length strings).
53
+ """
54
+ pairs = set()
55
+ prev_char = word[0]
56
+ for char in word[1:]:
57
+ pairs.add((prev_char, char))
58
+ prev_char = char
59
+ return pairs
60
+
61
+
62
+ def basic_clean(text):
63
+ text = ftfy.fix_text(text)
64
+ text = html.unescape(html.unescape(text))
65
+ return text.strip()
66
+
67
+
68
+ def whitespace_clean(text):
69
+ text = re.sub(r'\s+', ' ', text)
70
+ text = text.strip()
71
+ return text
72
+
73
+
74
+ class SimpleTokenizer(object):
75
+ def __init__(self, bpe_path: str = default_bpe()):
76
+ self.byte_encoder = bytes_to_unicode()
77
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
78
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
79
+ merges = merges[1:49152-256-2+1]
80
+ merges = [tuple(merge.split()) for merge in merges]
81
+ vocab = list(bytes_to_unicode().values())
82
+ vocab = vocab + [v+'</w>' for v in vocab]
83
+ for merge in merges:
84
+ vocab.append(''.join(merge))
85
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
86
+ self.encoder = dict(zip(vocab, range(len(vocab))))
87
+ self.decoder = {v: k for k, v in self.encoder.items()}
88
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
89
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
90
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
91
+
92
+ def bpe(self, token):
93
+ if token in self.cache:
94
+ return self.cache[token]
95
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
96
+ pairs = get_pairs(word)
97
+
98
+ if not pairs:
99
+ return token+'</w>'
100
+
101
+ while True:
102
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
103
+ if bigram not in self.bpe_ranks:
104
+ break
105
+ first, second = bigram
106
+ new_word = []
107
+ i = 0
108
+ while i < len(word):
109
+ try:
110
+ j = word.index(first, i)
111
+ new_word.extend(word[i:j])
112
+ i = j
113
+ except:
114
+ new_word.extend(word[i:])
115
+ break
116
+
117
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
118
+ new_word.append(first+second)
119
+ i += 2
120
+ else:
121
+ new_word.append(word[i])
122
+ i += 1
123
+ new_word = tuple(new_word)
124
+ word = new_word
125
+ if len(word) == 1:
126
+ break
127
+ else:
128
+ pairs = get_pairs(word)
129
+ word = ' '.join(word)
130
+ self.cache[token] = word
131
+ return word
132
+
133
+ def encode(self, text):
134
+ bpe_tokens = []
135
+ text = whitespace_clean(basic_clean(text)).lower()
136
+ for token in re.findall(self.pat, text):
137
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
138
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
139
+ return bpe_tokens
140
+
141
+ def decode(self, tokens):
142
+ text = ''.join([self.decoder[token] for token in tokens])
143
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
144
+ return text
145
+
146
+
147
+
148
+ _tokenizer = SimpleTokenizer()
149
+
150
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
151
+ """
152
+ Returns the tokenized representation of given input string(s)
153
+ Parameters
154
+ ----------
155
+ texts : Union[str, List[str]]
156
+ An input string or a list of input strings to tokenize
157
+ context_length : int
158
+ The context length to use; all CLIP models use 77 as the context length
159
+ truncate: bool
160
+ Whether to truncate the text in case its encoding is longer than the context length
161
+ Returns
162
+ -------
163
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
164
+ We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
165
+ """
166
+ if isinstance(texts, str):
167
+ texts = [texts]
168
+
169
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
170
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
171
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
172
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
173
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
174
+ else:
175
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
176
+
177
+ for i, tokens in enumerate(all_tokens):
178
+ if len(tokens) > context_length:
179
+ if truncate:
180
+ tokens = tokens[:context_length]
181
+ tokens[-1] = eot_token
182
+ else:
183
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
184
+ result[i, :len(tokens)] = torch.tensor(tokens)
185
+
186
+ return result