from tokenizers.models import WordLevel
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import Split
from tokenizers import Regex
from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding
import torch
class ChembertaTokenizer:
    """Minimal HF-style wrapper around a word-level ``tokenizers.Tokenizer``
    for SMILES / reaction strings (ChemBERTa-style vocabulary).

    Provides ``__call__``/``encode`` returning a ``BatchEncoding`` with
    ``input_ids`` and ``attention_mask``, plus ``convert_ids_to_tokens``
    and ``decode`` helpers that accept tensors, batches, single
    sequences, or single ids.
    """

    def __init__(self, vocab_file):
        """Build the tokenizer from a word-level vocabulary file.

        Args:
            vocab_file: path to a vocab file loadable by
                ``WordLevel.from_file``; must contain the tokens
                ``[UNK]``, ``[CLS]``, ``[SEP]`` and ``[PAD]``.
        """
        self.tokenizer = Tokenizer(
            WordLevel.from_file(
                vocab_file,
                unk_token='[UNK]'
            ))
        # SMILES-aware splitting: bracket atoms first, then two-char
        # tokens (Cl, Br), the reaction arrow '>>', a backslash, and
        # finally any single character.
        self.tokenizer.pre_tokenizer = Split(
            pattern=Regex(r"\[(.*?)\]|Cl|Br|>>|\\|.*?"),
            behavior='isolated'
        )
        # Treat special tokens as atomic units during encoding so the
        # pre-tokenizer does not split them apart.
        # (The original comment here said "Disable padding", which was
        # wrong — padding is configured per-call in encode().)
        self.tokenizer.encode_special_tokens = True
        # Ids filtered out when skip_special_tokens=True in the decode
        # helpers below.
        self.special_token_ids = {
            self.tokenizer.token_to_id(tok)
            for tok in ('[CLS]', '[SEP]', '[PAD]', '[UNK]')
        }
        self.tokenizer.post_processor = TemplateProcessing(
            single='[CLS] $A [SEP]',
            pair='[CLS] $A [SEP] $B:1 [SEP]:1',
            special_tokens=[
                ('[CLS]', self.tokenizer.token_to_id('[CLS]')),
                ('[SEP]', self.tokenizer.token_to_id('[SEP]'))
            ]
        )

    def encode(self, inputs, padding=None, truncation=False,
               max_length=None, return_tensors=None):
        """Encode one string or a list of strings.

        Args:
            inputs: a single string or a list of strings.
            padding: truthy enables padding to ``max_length`` (or to the
                longest sequence when ``max_length`` is None).
            truncation: truthy enables truncation to ``max_length``.
            max_length: target length for padding/truncation.
            return_tensors: ``'pt'`` for torch tensors, else plain lists.

        Returns:
            ``BatchEncoding`` with ``input_ids`` and ``attention_mask``;
            a single input is wrapped as a batch of size 1.
        """
        # Padding/truncation are (re)configured on every call because
        # the underlying Tokenizer holds them as global state.
        if padding:
            self.tokenizer.enable_padding(pad_id=self.tokenizer.token_to_id('[PAD]'),
                                          pad_token='[PAD]', length=max_length)
        else:
            self.tokenizer.no_padding()
        if truncation:
            self.tokenizer.enable_truncation(max_length=max_length)
        else:
            self.tokenizer.no_truncation()
        tensor_type = 'pt' if return_tensors == 'pt' else None
        # Wrap a single sequence into a batch of size 1 so the returned
        # BatchEncoding always carries a leading batch dimension.
        if isinstance(inputs, list):
            enc = self.tokenizer.encode_batch(inputs)
        else:
            enc = [self.tokenizer.encode(inputs)]
        data = {
            "input_ids": [e.ids for e in enc],
            "attention_mask": [e.attention_mask for e in enc]
        }
        return BatchEncoding(data=data, encoding=enc, tensor_type=tensor_type)

    def __call__(self, inputs, padding=None, truncation=False,
                 max_length=None, return_tensors=None):
        """Alias for :meth:`encode` so the object is callable like a
        HuggingFace tokenizer."""
        return self.encode(inputs, padding=padding, truncation=truncation,
                           max_length=max_length, return_tensors=return_tensors)

    def _to_tokens(self, ids, skip_special_tokens, per_sequence):
        """Shared dispatch for convert_ids_to_tokens/decode.

        Accepts a torch tensor, a batch (list of lists), a single
        sequence (list of ints), or a single int.  A tensor whose
        leading (batch) dimension is 1 is squeezed down to a single
        sequence, mirroring the original per-method logic.

        Args:
            ids: ids in any of the accepted shapes.
            skip_special_tokens: drop ids in ``self.special_token_ids``.
            per_sequence: callable applied to each decoded token list
                (identity for token lists, ``''.join`` for strings).
        """
        def handle(seq):
            if skip_special_tokens:
                seq = [i for i in seq if i not in self.special_token_ids]
            return per_sequence([self.tokenizer.id_to_token(i) for i in seq])

        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
            # Guard with isinstance: a 0-d tensor's tolist() is a plain
            # int, on which len() would raise TypeError (original bug).
            if isinstance(ids, list) and len(ids) == 1:
                ids = ids[0]
        if isinstance(ids, list) and ids and isinstance(ids[0], list):
            return [handle(seq) for seq in ids]
        if isinstance(ids, list):
            return handle(ids)
        if isinstance(ids, int):
            return self.tokenizer.id_to_token(ids)
        # Unsupported input type: fall through to None, matching the
        # original implicit behaviour.
        return None

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """Map ids to token strings.

        Returns a list of lists for a batch, a list for a single
        sequence, or a single token string for a single int.
        """
        return self._to_tokens(ids, skip_special_tokens, lambda toks: toks)

    def decode(self, ids, skip_special_tokens=False):
        """Map ids back to a string (tokens concatenated without
        separators — appropriate for SMILES).

        Returns a list of strings for a batch, a string for a single
        sequence, or a single token string for a single int.
        """
        return self._to_tokens(ids, skip_special_tokens, ''.join)