# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
import html
import os
import string
import warnings
from typing import List, Optional, Union

try:
    import ftfy
except ImportError:
    ftfy = None

import torch

# Silence fork-related warnings from the HuggingFace tokenizers library when it
# is used together with multi-process data loading.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

DEFAULT_CONTEXT_LENGTH = 77


def basic_clean(text):
    # ftfy is optional; if it is not installed, fall back to HTML unescaping only.
    if ftfy is not None:
        text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = " ".join(text.split())
    text = text.strip()
    return text


def _clean_canonicalize(x):
    return canonicalize_text(basic_clean(x))


def _clean_lower(x):
    return whitespace_clean(basic_clean(x)).lower()


def _clean_whitespace(x):
    return whitespace_clean(basic_clean(x))


def get_clean_fn(type: str):
    if type == "canonicalize":
        return _clean_canonicalize
    elif type == "lower":
        return _clean_lower
    elif type == "whitespace":
        return _clean_whitespace
    else:
        assert False, f"Invalid clean function ({type})."
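

# Illustrative examples for canonicalize_text (added for clarity; not part of the
# original OpenCLIP code):
#   canonicalize_text("A_photo, of a Cat!") -> "a photo of a cat"
#   canonicalize_text("A photo of a {}.", keep_punctuation_exact_string="{}")
#       -> "a photo of a {}"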
def canonicalize_text(
    text,
    *,
    keep_punctuation_exact_string=None,
    trans_punctuation: dict = str.maketrans("", "", string.punctuation),
):
    """Returns canonicalized `text` (lowercase and punctuation removed).

    If `keep_punctuation_exact_string` is given, that exact substring (e.g. the
    "{}" placeholder of a prompt template) is preserved while surrounding
    punctuation is still removed.
    """
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(trans_punctuation)
            for part in text.split(keep_punctuation_exact_string)
        )
    else:
        text = text.translate(trans_punctuation)
    text = text.lower()
    text = " ".join(text.split())
    return text.strip()


class HFTokenizer:
    """HuggingFace tokenizer wrapper with support for custom tokenization modes."""

    def __init__(
        self,
        tokenizer_name: str,
        context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
        clean: str = "whitespace",
        strip_sep_token: bool = False,
        language: Optional[str] = None,
        cache_dir: Optional[str] = None,
        tokenizer_mode: Optional[str] = None,
        **kwargs,
    ):
        self.tokenizer_mode = tokenizer_mode or ""
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)
        self.strip_sep_token = strip_sep_token

        # Imported lazily so transformers is only required when the wrapper is used.
        from transformers import AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name, cache_dir=cache_dir, **kwargs
        )

        # Some tokenizers (e.g. NLLB/M2M-style multilingual ones) expose a hook for
        # setting source-language special tokens; keep a reference if available.
        set_lang_fn = getattr(self.tokenizer, "set_src_lang_special_tokens", None)
        if callable(set_lang_fn):
            self.set_lang_fn = set_lang_fn
        if language is not None:
            self.set_language(language)

    def save_pretrained(self, dest):
        self.tokenizer.save_pretrained(dest)

    def __call__(
        self, texts: Union[str, List[str]], context_length: Optional[int] = None
    ) -> torch.Tensor:
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, (
            "Please set a valid context length in class init or call."
        )

        texts = [self.clean_fn(text) for text in texts]
        if self.tokenizer_mode == "clips":
            return self._clips_tokenize(texts, context_length)
        else:
            output = self.tokenizer(
                texts,
                return_tensors="pt",
                max_length=context_length,
                padding="max_length",
                truncation=True,
            )
            input_ids = output.input_ids
            if self.strip_sep_token:
                # Zero out SEP token ids so the separator token is effectively
                # stripped from the padded sequences.
                input_ids = torch.where(
                    input_ids == self.tokenizer.sep_token_id,
                    torch.zeros_like(input_ids),
                    input_ids,
                )
            return input_ids

    def set_language(self, src_lang):
        if hasattr(self, "set_lang_fn"):
            self.set_lang_fn(src_lang)
        else:
            warnings.warn("Cannot set language for the tokenizer.")
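
    # The "clips" tokenization mode (see _clips_tokenize below) lays out sequences as
    #   [BOS] token_1 ... token_k [EOS] [PAD] ... [PAD] [CLS]
    # with a fixed total length of `context_length` and CLS always in the last slot.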

    def _clips_tokenize(self, texts: List[str], context_length: int) -> torch.Tensor:
        # Tokenize without special tokens, padding, or truncation; those are
        # handled manually below.
        encoded_outputs = self.tokenizer(
            texts,
            add_special_tokens=False,
            padding=False,
            truncation=False,
            return_tensors=None,
        )

        # Truncate to leave room for BOS, EOS, and the trailing CLS token, then
        # wrap each sequence in BOS/EOS.
        encoded = []
        for tokens in encoded_outputs["input_ids"]:
            tokens = tokens[: context_length - 3]
            tokens = (
                [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
            )
            encoded.append(tokens)

        # Pad each sequence to context_length - 1 and append CLS in the final slot.
        result = torch.zeros(len(encoded), context_length, dtype=torch.long)
        for i, tokens in enumerate(encoded):
            padded_tokens = self._pad_and_add_class_token(
                tokens,
                max_length=context_length,
                pad_token_id=self.tokenizer.pad_token_id,
                cls_token_id=self.tokenizer.cls_token_id,
            )
            result[i, : len(padded_tokens)] = torch.tensor(padded_tokens)
        return result

    def _pad_and_add_class_token(
        self,
        tokens: List[int],
        max_length: int,
        pad_token_id: int = 0,
        cls_token_id: int = 101,
    ) -> List[int]:
        # Truncate or pad to max_length - 1, then append the CLS token so the
        # returned sequence is exactly max_length tokens long.
        if len(tokens) > max_length - 1:
            tokens = tokens[: max_length - 1]
        if len(tokens) < max_length - 1:
            tokens = tokens + [pad_token_id] * (max_length - 1 - len(tokens))
        tokens = tokens + [cls_token_id]
        return tokens
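

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). "bert-base-uncased" is just an
    # example checkpoint name; any HuggingFace tokenizer that is cached locally
    # or downloadable would work the same way.
    tokenizer = HFTokenizer("bert-base-uncased", context_length=77, clean="lower")
    ids = tokenizer(["A photo of a cat.", "A photo of a dog."])
    print(ids.shape)  # expected: torch.Size([2, 77])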