File size: 4,580 Bytes
3270dae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""SentencePiece tokenizer wrapper for HuggingFace compatibility."""

from typing import Optional, List, Union


class SentencePieceTokenizerWrapper:
    """Adapter exposing a SentencePiece processor through a HuggingFace-like API.

    The wrapper mirrors the small subset of the HuggingFace tokenizer
    interface used for encoding/decoding: calling the instance, ``encode``,
    ``encode_plus`` and ``decode``, plus the usual ``*_token_id`` attributes.
    """

    def __init__(self, sp_processor):
        """Store the processor and cache its vocabulary size and special ids.

        Args:
            sp_processor: sentencepiece.SentencePieceProcessor instance (any
                object providing ``vocab_size``, ``pad_id``, ``eos_id``,
                ``bos_id``, ``unk_id``, ``encode`` and ``decode`` works).
        """
        self.sp = sp_processor
        self.vocab_size = self.sp.vocab_size()
        # NOTE(review): SentencePiece presumably reports -1 for special ids
        # that are disabled in the model (pad often is) -- confirm before
        # feeding padded ids to an embedding layer.
        self.pad_token_id = self.sp.pad_id()
        self.eos_token_id = self.sp.eos_id()
        self.bos_token_id = self.sp.bos_id()
        self.unk_token_id = self.sp.unk_id()

    def __call__(self, text: Union[str, List[str]], **kwargs) -> dict:
        """Tokenize a string or a batch of strings.

        Args:
            text: A single string, or an iterable of strings (batch).
            **kwargs: Recognised keys:
                max_length (Optional[int]): Target length. When set, every
                    sequence is padded *and* cut to exactly this length, even
                    if ``padding``/``truncation`` are not given.
                padding: Any truthy value enables padding; without
                    ``max_length`` sequences are padded to the longest one in
                    the batch.
                truncation (bool): Cut sequences to ``max_length`` right
                    after encoding. Defaults to False.
                return_attention_mask (bool): Include ``attention_mask`` in
                    the result. Defaults to True.

        Returns:
            Dict with ``input_ids`` and, optionally, ``attention_mask``. For a
            single string the values are flat lists; for a batch they are
            lists of lists. ``attention_mask`` always has the same length as
            the corresponding ``input_ids`` entry.
        """
        # Handle both a single string and a batch of strings.
        is_single = isinstance(text, str)
        texts = [text] if is_single else text

        max_length: Optional[int] = kwargs.get('max_length', None)
        padding = kwargs.get('padding', None)
        truncation = kwargs.get('truncation', False)
        return_attention_mask = kwargs.get('return_attention_mask', True)

        all_input_ids = []
        for t in texts:
            tokens = self.sp.encode(t, out_type=int)
            if truncation and max_length and len(tokens) > max_length:
                tokens = tokens[:max_length]
            all_input_ids.append(tokens)

        if padding or max_length:
            if max_length:
                target_length = max_length
            else:
                # default=0 keeps an empty batch from raising in max().
                target_length = max(
                    (len(ids) for ids in all_input_ids), default=0
                )

            batch_ids = []
            batch_masks = []
            for ids in all_input_ids:
                # Cut first, then derive both the padding and the mask from
                # the kept length so ids and mask can never disagree in size
                # (previously the mask used the untruncated length).
                kept = ids[:target_length]
                pad_length = target_length - len(kept)
                batch_ids.append(kept + [self.pad_token_id] * pad_length)
                batch_masks.append([1] * len(kept) + [0] * pad_length)
        else:
            batch_ids = all_input_ids
            batch_masks = [[1] * len(ids) for ids in all_input_ids]

        result = {"input_ids": batch_ids[0] if is_single else batch_ids}
        if return_attention_mask:
            result["attention_mask"] = (
                batch_masks[0] if is_single else batch_masks
            )
        return result

    def encode(self, text, return_tensors: Optional[str] = None, **kwargs):
        """Encode text to token IDs.

        Args:
            text: A single string, or a batch of strings.
            return_tensors: ``"pt"`` to get a ``torch.long`` tensor of shape
                ``(1, seq_len)``; any other value returns plain lists.
            **kwargs: Forwarded to ``__call__``.

        Returns:
            The ``input_ids`` produced by ``__call__``, or a torch tensor when
            ``return_tensors == "pt"``. In the tensor case only the first
            sequence of a batch is used.
        """
        result = self(text, **kwargs)
        input_ids = result["input_ids"]

        if return_tensors == "pt":
            import torch
            # Collapse a batch to its first sequence; the emptiness guard
            # avoids an IndexError when the text encodes to no tokens.
            if input_ids and isinstance(input_ids[0], list):
                input_ids = input_ids[0]
            return torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)

        return input_ids

    def encode_plus(self, text, **kwargs):
        """Encode text with additional information (HuggingFace compatibility).

        Equivalent to calling the wrapper directly; returns the same dict.
        """
        return self(text, **kwargs)

    def decode(self, token_ids, skip_special_tokens: bool = False, **kwargs):
        """Decode token IDs to text.

        Args:
            token_ids: A list/tuple of ids, a nested (batched) list/tuple of
                which only the first sequence is decoded, a single int, or any
                object with ``tolist()`` (e.g. a torch tensor).
            skip_special_tokens: When True, pad/bos/eos ids are removed before
                decoding. The unknown id is kept so unknown pieces stay
                visible in the output.
            **kwargs: Accepted for HuggingFace compatibility and ignored.

        Returns:
            The string returned by the underlying processor's ``decode``.
        """
        if hasattr(token_ids, 'tolist'):  # Handle torch tensors
            token_ids = token_ids.tolist()

        # A 0-dim tensor (or a caller) may hand over one bare id.
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        # Batched input: decode the first sequence only.
        if (
            isinstance(token_ids, (list, tuple))
            and len(token_ids) > 0
            and isinstance(token_ids[0], (list, tuple))
        ):
            token_ids = token_ids[0]

        # Ensure it's a list of ints
        if not isinstance(token_ids, list):
            token_ids = [int(t) for t in token_ids]

        if skip_special_tokens:
            special_ids = {
                self.pad_token_id,
                self.bos_token_id,
                self.eos_token_id,
            }
            token_ids = [t for t in token_ids if t not in special_ids]

        return self.sp.decode(token_ids)