File size: 5,425 Bytes
3f42614 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 | """
ByteTokenizer — Native byte-level tokenizer for GLADIUS machine code corpus.
No BPE. No subword. No merges. Raw bytes in, raw bytes out.
256 byte tokens (0x00-0xFF) + 3 specials (PAD=256, BOS=257, EOS=258).
The CPU's native language doesn't need human linguistic scaffolding.
"""
import json
from pathlib import Path
from typing import List, Union
class ByteTokenizer:
"""Byte-level tokenizer: one token per byte value (0-255) + specials."""
PAD_ID = 256
BOS_ID = 257
EOS_ID = 258
VOCAB_SIZE = 259 # 256 bytes + 3 specials
def __init__(self):
self._special_tokens = {
"<PAD>": self.PAD_ID,
"<BOS>": self.BOS_ID,
"<EOS>": self.EOS_ID,
}
self._id_to_special = {v: k for k, v in self._special_tokens.items()}
@property
def vocab_size(self) -> int:
return self.VOCAB_SIZE
def encode(self, data: Union[bytes, bytearray], add_bos: bool = True, add_eos: bool = True) -> List[int]:
"""Encode raw bytes to token IDs.
Args:
data: Raw bytes to encode.
add_bos: Prepend BOS token.
add_eos: Append EOS token.
Returns:
List of integer token IDs.
"""
tokens = []
if add_bos:
tokens.append(self.BOS_ID)
tokens.extend(int(b) for b in data)
if add_eos:
tokens.append(self.EOS_ID)
return tokens
def decode(self, token_ids: List[int], strip_special: bool = True) -> bytes:
"""Decode token IDs back to raw bytes.
Args:
token_ids: List of integer token IDs.
strip_special: Remove PAD/BOS/EOS from output.
Returns:
Raw bytes.
"""
result = bytearray()
for tid in token_ids:
if strip_special and tid in (self.PAD_ID, self.BOS_ID, self.EOS_ID):
continue
if 0 <= tid <= 255:
result.append(tid)
# IDs outside 0-258 are silently skipped
return bytes(result)
def encode_hex(self, hex_string: str, add_bos: bool = True, add_eos: bool = True) -> List[int]:
"""Encode a hex string (e.g., '4889e548...' or '48 89 e5 48...') to tokens."""
clean = hex_string.replace(" ", "").replace("\n", "")
data = bytes.fromhex(clean)
return self.encode(data, add_bos=add_bos, add_eos=add_eos)
def decode_hex(self, token_ids: List[int], strip_special: bool = True) -> str:
"""Decode token IDs to hex string."""
raw = self.decode(token_ids, strip_special=strip_special)
return raw.hex()
def pad(self, token_ids: List[int], max_len: int) -> List[int]:
"""Pad or truncate to max_len."""
if len(token_ids) >= max_len:
return token_ids[:max_len]
return token_ids + [self.PAD_ID] * (max_len - len(token_ids))
def save(self, path: Union[str, Path]):
"""Save tokenizer config as JSON."""
config = {
"type": "ByteTokenizer",
"vocab_size": self.VOCAB_SIZE,
"byte_range": [0, 255],
"special_tokens": self._special_tokens,
"description": "Native byte-level tokenizer for machine code. 256 byte values + PAD/BOS/EOS."
}
Path(path).write_text(json.dumps(config, indent=2))
@classmethod
def load(cls, path: Union[str, Path]) -> "ByteTokenizer":
"""Load tokenizer from JSON config (validates, returns fresh instance)."""
config = json.loads(Path(path).read_text())
assert config["type"] == "ByteTokenizer", f"Wrong tokenizer type: {config['type']}"
assert config["vocab_size"] == cls.VOCAB_SIZE, f"Vocab size mismatch: {config['vocab_size']}"
return cls()
def __repr__(self):
return f"ByteTokenizer(vocab_size={self.VOCAB_SIZE}, bytes=0-255, specials=PAD/BOS/EOS)"
# Quick self-test
if __name__ == "__main__":
tok = ByteTokenizer()
# Test basic encode/decode roundtrip
test_bytes = bytes([0x48, 0x89, 0xe5, 0x48, 0x83, 0xec, 0x10]) # mov rbp,rsp; sub rsp,0x10
encoded = tok.encode(test_bytes)
decoded = tok.decode(encoded)
assert decoded == test_bytes, f"Roundtrip failed: {decoded.hex()} != {test_bytes.hex()}"
# Test hex encode/decode
hex_encoded = tok.encode_hex("4889e54883ec10")
assert tok.decode(hex_encoded) == test_bytes
# Test special tokens
assert encoded[0] == tok.BOS_ID
assert encoded[-1] == tok.EOS_ID
assert len(encoded) == len(test_bytes) + 2
# Test padding
padded = tok.pad(encoded, 20)
assert len(padded) == 20
assert padded[-1] == tok.PAD_ID
# Test save/load
import tempfile, os
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f:
tok.save(f.name)
tok2 = ByteTokenizer.load(f.name)
assert tok2.vocab_size == tok.vocab_size
os.unlink(f.name)
# Test all 256 byte values roundtrip
all_bytes = bytes(range(256))
assert tok.decode(tok.encode(all_bytes)) == all_bytes
print(f"✅ ByteTokenizer: {tok}")
print(f" Roundtrip test: {test_bytes.hex()} → {encoded} → {decoded.hex()}")
print(f" All 256 byte values roundtrip: PASS")
print(f" Vocab: 256 bytes + PAD({tok.PAD_ID}) + BOS({tok.BOS_ID}) + EOS({tok.EOS_ID}) = {tok.VOCAB_SIZE}")
|