"""Byte-level tokenizer for MicroMixer-1 language model.""" class ByteTokenizer: """Byte-level tokenizer with 256 vocabulary size. Reserves byte values 0, 1, 2 for special tokens: - 0: pad_token - 1: bos_token (beginning of sequence) - 2: eos_token (end of sequence) All other byte values (3-255) represent raw UTF-8 bytes. """ def __init__(self): self.vocab_size = 256 # Special tokens (using byte values that are rare in text) self.pad_token_id = 0 self.bos_token_id = 1 # Beginning of sequence self.eos_token_id = 2 # End of sequence def encode(self, text: str) -> list[int]: """Encode string to list of byte IDs. Args: text: Input string to encode Returns: List of byte IDs with BOS token at start and EOS token at end """ if not text: # Empty string: just return BOS + EOS return [self.bos_token_id, self.eos_token_id] # Convert text to UTF-8 bytes text_bytes = text.encode("utf-8") # Map each byte to its integer value (0-255) byte_ids = [b for b in text_bytes] # Add BOS at start, EOS at end return [self.bos_token_id] + byte_ids + [self.eos_token_id] def decode(self, ids: list[int]) -> str: """Decode list of byte IDs back to string. Args: ids: List of byte IDs Returns: Decoded string """ if not ids: return "" # Filter out special tokens (pad, bos, eos) byte_values = [ b for b in ids if b not in (self.pad_token_id, self.bos_token_id, self.eos_token_id) ] if not byte_values: return "" # Convert byte values back to bytes object byte_data = bytes(byte_values) # Decode UTF-8 bytes to string, handle errors gracefully return byte_data.decode("utf-8", errors="replace") def encode_batch( self, texts: list[str], max_length: int = None, padding: bool = True ) -> dict: """Encode a batch of texts with optional padding. Args: texts: List of strings to encode max_length: Maximum sequence length (truncation if specified) padding: Whether to pad sequences to longest in batch Returns: Dict with 'input_ids' (list of lists) and 'attention_mask' (list of lists) """ # Encode each text encoded = [self.encode(text) for text in texts] # Get sequence lengths before padding/truncation lengths = [len(ids) for ids in encoded] # Truncate if max_length specified if max_length is not None: encoded = [ids[:max_length] for ids in encoded] # Pad to longest sequence if padding=True if padding and encoded: max_seq_len = max(len(ids) for ids in encoded) pad_token_id = self.pad_token_id padded = [] attention_masks = [] for ids in encoded: pad_len = max_seq_len - len(ids) padded.append(ids + [pad_token_id] * pad_len) attention_masks.append([1] * len(ids) + [0] * pad_len) return { "input_ids": padded, "attention_mask": attention_masks, } # No padding attention_masks = [[1] * len(ids) for ids in encoded] return { "input_ids": encoded, "attention_mask": attention_masks, }