| |
|
| | import html
|
| | import string
|
| |
|
| | import ftfy
|
| | import regex as re
|
| | from transformers import AutoTokenizer
|
| |
|
| | __all__ = ['HuggingfaceTokenizer']
|
| |
|
| |
|
def basic_clean(text):
    """Repair broken text encodings with ftfy and decode HTML entities.

    Entities are unescaped twice to handle double-escaped input
    (e.g. ``&amp;amp;`` -> ``&``); surrounding whitespace is trimmed.
    """
    cleaned = html.unescape(html.unescape(ftfy.fix_text(text)))
    return cleaned.strip()
|
| |
|
| |
|
def whitespace_clean(text):
    """Collapse every run of whitespace into a single space and trim the ends."""
    return re.sub(r'\s+', ' ', text).strip()
|
| |
|
| |
|
def canonicalize(text, keep_punctuation_exact_string=None):
    """Lowercase *text*, drop punctuation/underscores, and collapse whitespace.

    If ``keep_punctuation_exact_string`` is given, that exact substring is
    preserved while punctuation is stripped from the segments around it.
    """
    # Build the deletion table once instead of per-segment.
    strip_punct = str.maketrans('', '', string.punctuation)
    text = text.replace('_', ' ')
    if keep_punctuation_exact_string:
        segments = text.split(keep_punctuation_exact_string)
        text = keep_punctuation_exact_string.join(
            segment.translate(strip_punct) for segment in segments)
    else:
        text = text.translate(strip_punct)
    # Lowercasing does not affect whitespace, so the order here is safe.
    return re.sub(r'\s+', ' ', text.lower()).strip()
|
| |
|
| |
|
class HuggingfaceTokenizer:
    """Thin wrapper around ``transformers.AutoTokenizer`` that optionally
    cleans input text and returns padded/truncated token-id tensors.

    Args:
        name: Model name or path for ``AutoTokenizer.from_pretrained``.
        seq_len: If set, pad and truncate every sequence to this length.
        clean: Optional cleaning mode; one of ``None``, ``'whitespace'``,
            ``'lower'`` or ``'canonicalize'``.
        **kwargs: Forwarded to ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        # Validate explicitly instead of `assert`, which is silently
        # stripped when Python runs with -O / -OO.
        if clean not in (None, 'whitespace', 'lower', 'canonicalize'):
            raise ValueError(f'Unsupported clean mode: {clean!r}')
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        """Tokenize a string or list of strings.

        Returns ``input_ids`` as a PyTorch tensor; pass ``return_mask=True``
        to receive ``(input_ids, attention_mask)`` instead.
        """
        return_mask = kwargs.pop('return_mask', False)

        # arguments
        _kwargs = {'return_tensors': 'pt'}
        if self.seq_len is not None:
            _kwargs.update({
                'padding': 'max_length',
                'truncation': True,
                'max_length': self.seq_len
            })
        # Caller-supplied kwargs override the defaults above.
        _kwargs.update(**kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        # Dispatch on the configured cleaning mode; invalid modes were
        # already rejected in __init__, and None skips cleaning upstream.
        if self.clean == 'whitespace':
            text = whitespace_clean(basic_clean(text))
        elif self.clean == 'lower':
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == 'canonicalize':
            text = canonicalize(basic_clean(text))
        return text
|
| |
|