File size: 3,584 Bytes
c047c1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""Byte-level tokenizer for MicroMixer-1 language model."""


class ByteTokenizer:
    """Byte-level tokenizer with 256 vocabulary size.

    Reserves byte values 0, 1, 2 for special tokens:
    - 0: pad_token
    - 1: bos_token (beginning of sequence)
    - 2: eos_token (end of sequence)

    All other byte values (3-255) represent raw UTF-8 bytes.

    Note:
        ``encode`` does not escape input bytes, so text that itself contains
        the control characters U+0000-U+0002 produces IDs equal to the
        special tokens. ``decode`` drops every such ID, so that text does
        not round-trip exactly.
    """

    def __init__(self) -> None:
        self.vocab_size = 256
        # Special tokens (using byte values that are rare in text)
        self.pad_token_id = 0
        self.bos_token_id = 1  # Beginning of sequence
        self.eos_token_id = 2  # End of sequence

    def encode(self, text: str) -> list[int]:
        """Encode a string to a list of byte IDs.

        Args:
            text: Input string to encode. Any falsy value is treated like
                the empty string.

        Returns:
            The UTF-8 bytes of ``text`` as integers, wrapped with the BOS
            token at the start and the EOS token at the end. Empty input
            yields just ``[bos_token_id, eos_token_id]``.
        """
        if not text:
            return [self.bos_token_id, self.eos_token_id]

        # Iterating a bytes object yields its integer values (0-255).
        return [self.bos_token_id, *text.encode("utf-8"), self.eos_token_id]

    def decode(self, ids: list[int]) -> str:
        """Decode a list of byte IDs back to a string.

        Pad, BOS and EOS IDs are removed wherever they occur in ``ids``
        before the remaining values are interpreted as UTF-8.

        Args:
            ids: List of byte IDs.

        Returns:
            The decoded string. Invalid UTF-8 sequences are replaced with
            U+FFFD rather than raising.

        Raises:
            ValueError: If a non-special ID lies outside ``range(256)``
                (raised by ``bytes``).
        """
        if not ids:
            return ""

        # Build the special-token set once instead of once per element.
        special_ids = (self.pad_token_id, self.bos_token_id, self.eos_token_id)
        payload = bytes([i for i in ids if i not in special_ids])

        # errors="replace" keeps decoding best-effort on partial sequences.
        return payload.decode("utf-8", errors="replace")

    def encode_batch(
        self, texts: list[str], max_length: "int | None" = None, padding: bool = True
    ) -> dict:
        """Encode a batch of texts with optional truncation and padding.

        Args:
            texts: List of strings to encode.
            max_length: Maximum sequence length. Longer sequences are cut
                to their first ``max_length`` IDs. The cut is applied after
                BOS/EOS are added, so a truncated sequence loses its EOS
                token. ``None`` disables truncation.
            padding: Whether to right-pad every sequence with
                ``pad_token_id`` to the longest sequence in the batch.

        Returns:
            Dict with 'input_ids' (list of lists) and 'attention_mask'
            (list of lists). The mask holds 1 for real tokens and 0 for
            padding positions.

        Raises:
            ValueError: If ``max_length`` is negative.
        """
        # A negative bound would slice from the tail (ids[:-1]) rather than
        # limit the length, so reject it explicitly.
        if max_length is not None and max_length < 0:
            raise ValueError(f"max_length must be non-negative, got {max_length}")

        encoded = [self.encode(text) for text in texts]
        if max_length is not None:
            encoded = [ids[:max_length] for ids in encoded]

        # With padding disabled (or an empty batch) the target stays 0, so
        # every pad_len below clamps to 0 and sequences keep their length.
        target_len = 0
        if padding and encoded:
            target_len = max(len(ids) for ids in encoded)

        input_ids = []
        attention_masks = []
        for ids in encoded:
            pad_len = max(target_len - len(ids), 0)
            input_ids.append(ids + [self.pad_token_id] * pad_len)
            attention_masks.append([1] * len(ids) + [0] * pad_len)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_masks,
        }