| """
|
| Byte-Level Tokenizer V6.1.2 - Compression-First Learning
|
| No vocabulary, no language rules - just bytes
|
| """
|
|
|
| import torch
|
| from typing import List, Dict, Union, Optional
|
| import numpy as np
|
|
|
|
|
| class ByteTokenizerV6:
|
| """
|
| Pure byte-level tokenizer
|
| - No vocabulary needed (bytes are 0-255)
|
| - No language-specific rules
|
| - Model learns all patterns from data
|
| """
|
|
|
| def __init__(self, max_seq_len: int = 64):
|
| """Initialize byte tokenizer"""
|
|
|
| self.max_seq_len = max_seq_len
|
|
|
|
|
| self.PAD = 256
|
| self.BOS = 257
|
| self.EOS = 258
|
| self.MASK = 259
|
|
|
|
|
| self.vocab_size = 260
|
|
|
| print(f"Byte tokenizer initialized (vocab_size={self.vocab_size})")
|
|
|
| def encode(self, text: str, add_special_tokens: bool = True) -> Dict:
|
| """
|
| Encode text to byte IDs
|
|
|
| Args:
|
| text: Input text
|
| add_special_tokens: Whether to add BOS/EOS
|
|
|
| Returns:
|
| dict with 'input_ids', 'attention_mask', 'length'
|
| """
|
|
|
| byte_sequence = list(text.encode('utf-8'))
|
|
|
|
|
| max_len = self.max_seq_len - 2 if add_special_tokens else self.max_seq_len
|
| if len(byte_sequence) > max_len:
|
| byte_sequence = byte_sequence[:max_len]
|
|
|
|
|
| if add_special_tokens:
|
| input_ids = [self.BOS] + byte_sequence + [self.EOS]
|
| else:
|
| input_ids = byte_sequence
|
|
|
|
|
| attention_mask = [1] * len(input_ids)
|
|
|
| return {
|
| 'input_ids': input_ids,
|
| 'attention_mask': attention_mask,
|
| 'length': len(input_ids)
|
| }
|
|
|
| def encode_batch(self, texts: List[str], add_special_tokens: bool = True) -> Dict:
|
| """
|
| Encode multiple texts with padding
|
|
|
| Args:
|
| texts: List of input texts
|
| add_special_tokens: Whether to add special tokens
|
|
|
| Returns:
|
| Batched tensors with padding
|
| """
|
| encoded_texts = []
|
| max_length = 0
|
|
|
|
|
| for text in texts:
|
| encoded = self.encode(text, add_special_tokens)
|
| encoded_texts.append(encoded)
|
| max_length = max(max_length, encoded['length'])
|
|
|
|
|
| max_length = min(max_length, self.max_seq_len)
|
|
|
|
|
| batch_size = len(texts)
|
| input_ids = np.full((batch_size, max_length), self.PAD, dtype=np.int64)
|
| attention_mask = np.zeros((batch_size, max_length), dtype=np.float32)
|
|
|
|
|
| for i, encoded in enumerate(encoded_texts):
|
| seq_len = min(encoded['length'], max_length)
|
| input_ids[i, :seq_len] = encoded['input_ids'][:seq_len]
|
| attention_mask[i, :seq_len] = 1.0
|
|
|
| return {
|
| 'input_ids': torch.tensor(input_ids, dtype=torch.long),
|
| 'attention_mask': torch.tensor(attention_mask, dtype=torch.float32),
|
| 'lengths': torch.tensor([e['length'] for e in encoded_texts], dtype=torch.long)
|
| }
|
|
|
| def decode(self, input_ids: Union[List[int], torch.Tensor, np.ndarray],
|
| skip_special_tokens: bool = True) -> str:
|
| """
|
| Decode byte IDs back to text
|
|
|
| Args:
|
| input_ids: Byte ID sequence
|
| skip_special_tokens: Whether to skip special tokens
|
|
|
| Returns:
|
| Decoded text string
|
| """
|
|
|
| if isinstance(input_ids, torch.Tensor):
|
| input_ids = input_ids.cpu().numpy().tolist()
|
| elif isinstance(input_ids, np.ndarray):
|
| input_ids = input_ids.tolist()
|
|
|
|
|
| if skip_special_tokens:
|
|
|
| input_ids = [b for b in input_ids if 0 <= b <= 255]
|
| else:
|
|
|
| processed = []
|
| for b in input_ids:
|
| if b == self.PAD:
|
| continue
|
| elif b == self.BOS:
|
| processed.append(ord('['))
|
| elif b == self.EOS:
|
| processed.append(ord(']'))
|
| elif b == self.MASK:
|
| processed.append(ord('*'))
|
| elif 0 <= b <= 255:
|
| processed.append(b)
|
| input_ids = processed
|
|
|
|
|
| if not input_ids:
|
| return ""
|
|
|
| try:
|
|
|
| valid_bytes = []
|
| i = 0
|
| while i < len(input_ids):
|
| b = input_ids[i]
|
| if b < 128:
|
| valid_bytes.append(b)
|
| i += 1
|
| elif 192 <= b < 224:
|
| if i + 1 < len(input_ids) and 128 <= input_ids[i+1] < 192:
|
| valid_bytes.extend(input_ids[i:i+2])
|
| i += 2
|
| else:
|
| i += 1
|
| elif 224 <= b < 240:
|
| if i + 2 < len(input_ids) and all(128 <= input_ids[j] < 192 for j in range(i+1, min(i+3, len(input_ids)))):
|
| valid_bytes.extend(input_ids[i:i+3])
|
| i += 3
|
| else:
|
| i += 1
|
| elif 240 <= b < 248:
|
| if i + 3 < len(input_ids) and all(128 <= input_ids[j] < 192 for j in range(i+1, min(i+4, len(input_ids)))):
|
| valid_bytes.extend(input_ids[i:i+4])
|
| i += 4
|
| else:
|
| i += 1
|
| else:
|
| i += 1
|
|
|
|
|
| if valid_bytes:
|
| byte_array = bytes(valid_bytes)
|
| text = byte_array.decode('utf-8', errors='replace')
|
| return text
|
| else:
|
| return ""
|
| except Exception as e:
|
|
|
| return "".join([chr(b) if b < 128 else '' for b in input_ids])
|
|
|
| def decode_batch(self, input_ids: torch.Tensor, skip_special_tokens: bool = True) -> List[str]:
|
| """
|
| Decode a batch of byte sequences
|
|
|
| Args:
|
| input_ids: Batch of byte IDs (batch_size, seq_len)
|
| skip_special_tokens: Whether to skip special tokens
|
|
|
| Returns:
|
| List of decoded texts
|
| """
|
| texts = []
|
| for i in range(input_ids.shape[0]):
|
| text = self.decode(input_ids[i], skip_special_tokens)
|
| texts.append(text)
|
| return texts
|
|
|
| def tokenize(self, text: str) -> List[int]:
|
| """
|
| Simple tokenization to byte IDs (no special tokens)
|
|
|
| Args:
|
| text: Input text
|
|
|
| Returns:
|
| List of byte IDs
|
| """
|
| return list(text.encode('utf-8'))
|
|
|
| def detokenize(self, byte_ids: List[int]) -> str:
|
| """
|
| Simple detokenization from byte IDs
|
|
|
| Args:
|
| byte_ids: List of byte IDs
|
|
|
| Returns:
|
| Decoded text
|
| """
|
| try:
|
| return bytes(byte_ids).decode('utf-8', errors='replace')
|
| except:
|
| return "".join([chr(b) if b < 128 else '?' for b in byte_ids])
|
|
|
| def get_vocab_size(self) -> int:
|
| """Get vocabulary size"""
|
| return self.vocab_size
|
|
|
| def get_special_tokens(self) -> Dict[str, int]:
|
| """Get special token IDs"""
|
| return {
|
| 'pad_id': self.PAD,
|
| 'bos_id': self.BOS,
|
| 'eos_id': self.EOS,
|
| 'mask_id': self.MASK
|
| }
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: round-trip a handful of multilingual samples through the
    # tokenizer, first one at a time and then as a padded batch.
    demo_tokenizer = ByteTokenizerV6()

    samples = [
        "Hello World!",
        "안녕하세요",
        "你好世界",
        "こんにちは",
        "مرحبا بالعالم",
        "Здравствуй мир"
    ]

    rule = "=" * 50

    print(rule)
    print("Single Text Encoding/Decoding Test")
    print(rule)

    for sample in samples:
        print(f"\nOriginal: {sample}")

        # Encode with the default BOS/EOS wrapping.
        result = demo_tokenizer.encode(sample)
        ids = result['input_ids']
        print(f"Encoded length: {result['length']}")
        print(f"First 10 bytes: {ids[:10]}")

        # Decode again and compare with the input.
        restored = demo_tokenizer.decode(ids)
        print(f"Decoded: {restored}")
        print(f"Match: {restored == sample}")

    print("\n" + rule)
    print("Batch Encoding/Decoding Test")
    print(rule)

    # Encode all samples at once into padded tensors.
    batch = demo_tokenizer.encode_batch(samples)
    print(f"Batch shape: {batch['input_ids'].shape}")
    print(f"Attention mask shape: {batch['attention_mask'].shape}")

    # Decode every row and compare it with its source text.
    restored_batch = demo_tokenizer.decode_batch(batch['input_ids'])
    print("\nBatch decoding results:")
    for source, restored in zip(samples, restored_batch):
        print(f"Original: {source}")
        print(f"Decoded: {restored}")
        print(f"Match: {source == restored}")
        print()