"""

Tokenizer wrapper to handle different tokenizer types with a consistent interface

"""
import os
import logging

import torch

logger = logging.getLogger(__name__)

# Set memory optimization flag
os.environ["LOW_MEMORY_MODE"] = "1"

class TokenizerWrapper:
    """A wrapper for tokenizers with common functionality for GPT-2 and BERT models"""
    
    def __init__(self, model_name="gpt2", use_fast=True, *args, **kwargs):
        self.model_name = model_name
        self.use_fast = use_fast
        self.tokenizer = None
        self._initialize_tokenizer()
        
        # Special-token defaults; prefer the loaded tokenizer's own values when available
        self.eos_token = getattr(self.tokenizer, "eos_token", None) or "</s>"
        self.pad_token = getattr(self.tokenizer, "pad_token", None) or "[PAD]"
        self.unk_token = getattr(self.tokenizer, "unk_token", None) or "[UNK]"
        self.mask_token = getattr(self.tokenizer, "mask_token", None) or "[MASK]"
        self.bos_token = getattr(self.tokenizer, "bos_token", None) or "<s>"
        
        # Ensure pad_token is always set (critical for GPT-2)
        self._ensure_pad_token()
        
        logger.info(f"Initialized TokenizerWrapper with {model_name}")
    
    def _initialize_tokenizer(self):
        """Initialize the actual tokenizer with graceful fallbacks"""
        try:
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                use_fast=self.use_fast
            )
            logger.info(f"Successfully loaded {self.model_name} tokenizer")
        except Exception as e:
            logger.warning(f"Error loading {self.model_name} tokenizer: {e}")
            try:
                # Fallback to GPT-2
                from transformers import AutoTokenizer
                self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
                logger.info("Loaded fallback GPT-2 tokenizer")
            except Exception as e2:
                logger.error(f"Failed to load fallback tokenizer: {e2}")
                # Create minimal placeholder
                self.tokenizer = MinimalTokenizer()
                logger.warning("Using minimal placeholder tokenizer")
    
    def _ensure_pad_token(self):
        """Ensure the pad_token is set (especially important for GPT-2)"""
        if not self.tokenizer:
            return
            
        if not hasattr(self.tokenizer, 'pad_token') or self.tokenizer.pad_token is None:
            # GPT-2 doesn't have a pad_token by default, use eos_token instead
            if hasattr(self.tokenizer, 'eos_token') and self.tokenizer.eos_token:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.pad_token = self.tokenizer.pad_token
                logger.info(f"Set pad_token to eos_token: {self.pad_token}")
            else:
                # Last resort
                self.tokenizer.pad_token = "[PAD]"
                self.pad_token = "[PAD]"
                logger.info("Set default pad_token: [PAD]")
    
    @property
    def vocab_size(self) -> int:
        """Get the vocabulary size of the tokenizer"""
        if hasattr(self.tokenizer, 'vocab_size'):
            return self.tokenizer.vocab_size
        elif hasattr(self.tokenizer, 'get_vocab'):
            return len(self.tokenizer.get_vocab())
        return 50257  # Default GPT-2 vocab size
    
    @property
    def pad_token_id(self) -> int:
        """Get pad token ID with fallback"""
        if hasattr(self.tokenizer, 'pad_token_id') and self.tokenizer.pad_token_id is not None:
            return self.tokenizer.pad_token_id
        elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
            return self.tokenizer.eos_token_id
        return 0  # Last resort fallback
    
    @property
    def eos_token_id(self) -> int:
        """Get EOS token ID with fallback"""
        if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
            return self.tokenizer.eos_token_id
        return 50256  # Default for GPT-2
    
    def __call__(self, text, *args, **kwargs):
        """Delegate to the actual tokenizer"""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            # Create minimal output compatible with model expectations
            if isinstance(text, str):
                # Single string input
                dummy_ids = torch.ones((1, 10), dtype=torch.long)
                return {"input_ids": dummy_ids, "attention_mask": torch.ones_like(dummy_ids)}
            # Batch input
            batch_size = len(text) if isinstance(text, list) else 1
            dummy_ids = torch.ones((batch_size, 10), dtype=torch.long)
            return {"input_ids": dummy_ids, "attention_mask": torch.ones_like(dummy_ids)}
            
        return self.tokenizer(text, *args, **kwargs)
    
    def encode(self, text, *args, **kwargs):
        """Encode text to token IDs"""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            if isinstance(text, str):
                return [1] * 10  # Return minimal dummy encoding
            return [[1] * 10 for _ in text]  # Batch of dummy encodings
            
        return self.tokenizer.encode(text, *args, **kwargs)
    
    def decode(self, token_ids, *args, **kwargs):
        """Decode token IDs to text"""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            return "Error: Tokenizer not initialized"
            
        return self.tokenizer.decode(token_ids, *args, **kwargs)
    
    def batch_decode(self, sequences, *args, **kwargs):
        """Decode multiple sequences"""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            return ["Error: Tokenizer not initialized"] * len(sequences)
            
        return self.tokenizer.batch_decode(sequences, *args, **kwargs)
    
    def __getattr__(self, name):
        """Delegate to the underlying tokenizer for missing attributes"""
        if name == "tokenizer":
            # Guard against infinite recursion if accessed before __init__ has run
            raise AttributeError(name)
        if self.tokenizer is not None and hasattr(self.tokenizer, name):
            return getattr(self.tokenizer, name)
        raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
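
# Illustrative usage sketch (assumes the `transformers` package is available and
# the "gpt2" weights can be fetched; otherwise the MinimalTokenizer fallback is
# used and the decoded text is only placeholder tokens):
#
#     wrapper = TokenizerWrapper(model_name="gpt2")
#     batch = wrapper(["hello world"], return_tensors="pt", padding=True)
#     ids = batch["input_ids"]            # shape: (1, sequence_length)
#     print(wrapper.decode(ids[0], skip_special_tokens=True))  # round-tripped text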


class MinimalTokenizer:
    """Minimal tokenizer implementation for fallback"""
    def __init__(self):
        self.pad_token = "[PAD]"
        self.pad_token_id = 0
        self.eos_token = "</s>"
        self.eos_token_id = 1
        self.bos_token = "<s>"
        self.bos_token_id = 2
        self.unk_token = "[UNK]"
        self.unk_token_id = 3
        self.vocab_size = 50257  # Standard GPT-2 vocab size
        logger.warning("Using minimal placeholder tokenizer with no actual encoding/decoding capability")
    
    def __call__(self, text, return_tensors=None, padding=False, truncation=False, max_length=None, *args, **kwargs):
        """Minimal tokenize implementation"""
        # Simple word-splitting tokenizer
        if isinstance(text, str):
            # Handle single string
            tokens = text.split()[:max_length] if max_length else text.split()
            input_ids = [i % (self.vocab_size - 4) + 4 for i in range(len(tokens))]
            if padding and max_length:
                pad_length = max(0, max_length - len(input_ids))
                input_ids = input_ids + [self.pad_token_id] * pad_length
        else:
            # Handle list of strings
            results = []
            max_len = 0
            for t in text:
                tokens = t.split()[:max_length] if max_length else t.split()
                ids = [i % (self.vocab_size - 4) + 4 for i in range(len(tokens))]
                results.append(ids)
                max_len = max(max_len, len(ids))
            
            # Pad if needed
            if padding:
                results = [ids + [self.pad_token_id] * (max_len - len(ids)) for ids in results]
            
            input_ids = results
        
        # Convert to tensors if requested (torch is imported at module level)
        if return_tensors == "pt":
            if isinstance(input_ids[0], list):
                input_ids = torch.tensor(input_ids, dtype=torch.long)
            else:
                input_ids = torch.tensor([input_ids], dtype=torch.long)
            # Mask out padded positions (pad_token_id is 0, real ids start at 4)
            attention_mask = (input_ids != self.pad_token_id).long()
            return {"input_ids": input_ids, "attention_mask": attention_mask}

        # Return plain lists for compatibility, one mask entry per token
        if isinstance(input_ids[0], list):
            attention_mask = [[int(t != self.pad_token_id) for t in ids] for ids in input_ids]
        else:
            attention_mask = [int(t != self.pad_token_id) for t in input_ids]
        return {"input_ids": input_ids, "attention_mask": attention_mask}
    
    def encode(self, text, add_special_tokens=True, *args, **kwargs):
        """Minimal encode implementation"""
        if isinstance(text, str):
            tokens = text.split()
            return [i % (self.vocab_size - 4) + 4 for i in range(len(tokens))]
        return [[i % (self.vocab_size - 4) + 4 for i in range(len(t.split()))] for t in text]
    
    def decode(self, token_ids, skip_special_tokens=True, *args, **kwargs):
        """Minimal decode implementation"""
        return " ".join(["token" + str(i) for i in token_ids])
    
    def batch_decode(self, sequences, skip_special_tokens=True, *args, **kwargs):
        """Minimal batch decode implementation"""
        return [self.decode(seq, skip_special_tokens=skip_special_tokens) for seq in sequences]
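
# Example of the placeholder's output shape (illustrative only; the ids are
# positional rather than content-based, so this class is purely a stand-in):
#
#     MinimalTokenizer()("a b c", return_tensors="pt")
#     # -> {"input_ids": tensor([[4, 5, 6]]),
#     #     "attention_mask": tensor([[1, 1, 1]])}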


def get_tokenizer(model_name="gpt2", use_fast=True):
    """Create a tokenizer for the specified model"""
    # First check registry
    try:
        from service_registry import registry, TOKENIZER
        if registry.has(TOKENIZER):
            logger.info("Retrieved tokenizer from registry")
            return registry.get(TOKENIZER)
    except ImportError:
        pass
    
    # Create a new tokenizer
    return TokenizerWrapper(model_name=model_name, use_fast=use_fast)
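

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the library API):
    # builds a tokenizer via the registry-aware factory and round-trips a short
    # string. Falls back to MinimalTokenizer if `transformers` is unavailable,
    # in which case the decoded text is made of placeholder tokens.
    logging.basicConfig(level=logging.INFO)
    tok = get_tokenizer("gpt2")
    ids = tok.encode("Hello, tokenizer wrapper!")
    print("encoded:", ids)
    print("decoded:", tok.decode(ids))
    print("vocab size:", tok.vocab_size)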