import torch
from tokenizers import Regex, Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Split
from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding

class ChembertaTokenizer:
    """Word-level SMILES tokenizer with a ChemBERTa-style [CLS]/[SEP] template."""

    def __init__(self, vocab_file):
        # Word-level model backed by a token -> id vocabulary file.
        self.tokenizer = Tokenizer(
            WordLevel.from_file(vocab_file, unk_token='[UNK]')
        )
        # Split SMILES into tokens: bracketed atoms, the two-letter halogens
        # Cl and Br, the reaction arrow '>>', backslash bonds, and any
        # remaining characters.
        self.tokenizer.pre_tokenizer = Split(
            pattern=Regex(r"\[(.*?)\]|Cl|Br|>>|\\|.*?"),
            behavior='isolated'
        )
        # Special tokens appearing in raw input text are tokenized as plain
        # text instead of being matched as special tokens.
        self.tokenizer.encode_special_tokens = True
        self.special_token_ids = {
            self.tokenizer.token_to_id('[CLS]'),
            self.tokenizer.token_to_id('[SEP]'),
            self.tokenizer.token_to_id('[PAD]'),
            self.tokenizer.token_to_id('[UNK]')
        }

        # Wrap single sequences as "[CLS] A [SEP]" and pairs as
        # "[CLS] A [SEP] B [SEP]", with the second segment given type id 1.
        self.tokenizer.post_processor = TemplateProcessing(
            single='[CLS] $A [SEP]',
            pair='[CLS] $A [SEP] $B:1 [SEP]:1',
            special_tokens=[
                ('[CLS]', self.tokenizer.token_to_id('[CLS]')),
                ('[SEP]', self.tokenizer.token_to_id('[SEP]'))
            ]
        )

    def encode(self, inputs, padding=None, truncation=False,
               max_length=None, return_tensors=None):
        # Padding: pad to max_length when it is given, otherwise to the
        # longest sequence in the batch.
        if padding:
            self.tokenizer.enable_padding(pad_id=self.tokenizer.token_to_id('[PAD]'),
                                          pad_token='[PAD]', length=max_length)
        else:
            self.tokenizer.no_padding()

        # Truncation to max_length, if requested.
        if truncation:
            self.tokenizer.enable_truncation(max_length=max_length)
        else:
            self.tokenizer.no_truncation()

        tensor_type = 'pt' if return_tensors == 'pt' else None

        # A list is treated as a batch; a single string is wrapped into a
        # batch of size 1 so both paths return a BatchEncoding.
        if isinstance(inputs, list):
            enc = self.tokenizer.encode_batch(inputs)
        else:
            enc = [self.tokenizer.encode(inputs)]

        data = {
            "input_ids": [e.ids for e in enc],
            "attention_mask": [e.attention_mask for e in enc]
        }
        return BatchEncoding(data=data, encoding=enc, tensor_type=tensor_type)

    def __call__(self, inputs, padding=None, truncation=False,
                 max_length=None, return_tensors=None):
        return self.encode(inputs, padding=padding, truncation=truncation,
                           max_length=max_length, return_tensors=return_tensors)
        
    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        def _decode_sequence(seq):
            if skip_special_tokens:
                seq = [idx for idx in seq if idx not in self.special_token_ids]
            return [self.tokenizer.id_to_token(idx) for idx in seq]

        # Tensors are converted to (nested) lists; a batch of size 1 is
        # unwrapped into a single sequence.
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
            if len(ids) == 1:
                ids = ids[0]

        # 1) batch: list of lists
        if isinstance(ids, list) and len(ids) > 0 and isinstance(ids[0], list):
            return [_decode_sequence(seq) for seq in ids]

        # 2) single sequence: list of ints
        if isinstance(ids, list):
            return _decode_sequence(ids)

        # 3) single id
        if isinstance(ids, int):
            return self.tokenizer.id_to_token(ids)

    def decode(self, ids, skip_special_tokens=False):
        def _decode_sequence(seq):
            if skip_special_tokens:
                seq = [idx for idx in seq if idx not in self.special_token_ids]
            return ''.join(self.tokenizer.id_to_token(idx) for idx in seq)

        # Tensors are converted to (nested) lists; a batch of size 1 is
        # unwrapped into a single sequence.
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
            if len(ids) == 1:
                ids = ids[0]

        # 1) batch: list of lists
        if isinstance(ids, list) and len(ids) > 0 and isinstance(ids[0], list):
            return [_decode_sequence(seq) for seq in ids]

        # 2) single sequence: list of ints
        if isinstance(ids, list):
            return _decode_sequence(ids)

        # 3) single id
        if isinstance(ids, int):
            return self.tokenizer.id_to_token(ids)
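

# --- Minimal usage sketch (illustrative, not part of the class) ---
# Assumes a word-level vocab JSON (here "vocab.json", a placeholder name)
# that contains the [CLS], [SEP], [PAD], and [UNK] tokens; the SMILES
# strings below are arbitrary examples.
if __name__ == "__main__":
    tokenizer = ChembertaTokenizer("vocab.json")

    # Batch of SMILES, padded to the longest sequence, returned as tensors.
    batch = tokenizer(["CCO", "c1ccccc1Cl"], padding=True, return_tensors="pt")
    print(batch["input_ids"].shape)  # torch.Size([2, longest_length])

    # Round-trip a single row back to tokens / a SMILES string.
    ids = batch["input_ids"][0]
    print(tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True))
    print(tokenizer.decode(ids, skip_special_tokens=True))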