File size: 17,450 Bytes
3f42614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
#!/usr/bin/env python3
"""
GLADIUS v4.0 — Bytecode Tokenizer

A structural tokenizer for bytecode. Not BPE. Not byte-level. 
Direct opcode-to-token mapping with argument quantization.

Design:
  Token 0-255:    Reserved for byte-level (machine code curriculum, future)
  Token 256-258:  PAD, BOS, EOS
  Token 259-511:  Opcodes (253 slots; the built-in table lists 267 opcodes, more than fit here)
  Token 512-639:  Numeric arguments (128 quantized buckets)
  Token 640-671:  Register/local indices (0-31)
  Token 672-687:  Comparison operators (16 slots)
  Token 688-703:  Task markers (NEXT, STACK, EXEC, FILL, etc.), difficulty markers D1-D5,
                  and stack names (python, wasm, evm)
  Token 704-767:  Structural tokens ('(', ')', ',', '|', ':', '=', 'o=', 'a=', ...)
  Token 768+:     Overflow / future expansion

Total vocab: ~768 active tokens (fits easily in 32K embedding space)

Encoding is O(1) per token — table lookup, no iteration, no probability.
Decoding is deterministic — token → exact opcode string (numeric arguments
decode to their bucket's representative value, so numbers are approximate).
"""

import math
import re
from typing import List, Optional, Dict, Tuple
from pathlib import Path


class BytecodeTokenizer:
    """Structural tokenizer for Python/WASM/EVM bytecode.

    Every vocabulary item (opcode, comparison operator, task marker,
    structural token) is assigned a fixed id from a reserved range in
    ``__init__``; numeric arguments are quantized into 128 buckets that
    start at ``NUMERIC_START``.
    """

    # Ranges
    BYTE_START = 0        # 0-255: byte-level (reserved, never emitted here)
    PAD = 256
    BOS = 257
    EOS = 258
    OPCODE_START = 259    # opcode slots run from here up to NUMERIC_START - 1
    NUMERIC_START = 512   # quantized numeric values (128 buckets)
    REGISTER_START = 640  # local/register indices 0-31 (decoded, not currently emitted)
    COMPARE_START = 672   # comparison operators
    TASK_START = 688      # task markers, difficulty markers, stack names
    STRUCT_START = 704    # structural tokens
    OVERFLOW_START = 768  # opcodes that do not fit below NUMERIC_START

    # Numeric quantization (see _quantize_number): integers -> 128 buckets
    #   Bucket 0-63:    exact values -32 to 31
    #   Bucket 64-91:   32..255 and -256..-33, 8-value steps (sign not kept)
    #   Bucket 92-106:  256..4095 and -4095..-257, 256-value steps (107 unused)
    #   Bucket 108-127: |v| >= 4096, one bucket per power of two, capped at 127

    # One opcode, optionally followed by a parenthesized argument list, ended
    # by whitespace or end of string: OPCODE, OPCODE(a), OPCODE(a,b).
    _OPS_PATTERN = re.compile(r'(\S+?)(?:\(([^)]*)\))?(?:\s|$)')

    COMPARE_OPS = ['<', '>', '==', '!=', '<=', '>=',
                   'lt_s', 'lt_u', 'gt_s', 'gt_u', 'le_s', 'le_u',
                   'ge_s', 'ge_u', 'eq', 'ne']

    TASK_MARKERS = ['NEXT', 'STACK', 'EXEC', 'FILL', 'TRACE', 'OUT', 'QED',
                    'D1', 'D2', 'D3', 'D4', 'D5',
                    'python', 'wasm', 'evm']

    STRUCT_TOKENS = ['(', ')', ',', '|', ':', '=',
                     'o=', 'a=',  # WASM memory args
                     'computed',  # EXEC result placeholder
                     'SEPARATOR']

    def __init__(self, vocab_file: Optional[str] = None):
        """Build the token tables.

        Args:
            vocab_file: Optional path to a UTF-8 text file with one opcode per
                line. Its opcodes come first (in file order), followed by any
                built-in opcodes it did not list. If the path is None or does
                not exist, only the built-in vocabulary is used.
        """
        self._opcode_to_id: Dict[str, int] = {}
        self._id_to_opcode: Dict[int, str] = {}
        self._compare_to_id: Dict[str, int] = {}
        self._task_to_id: Dict[str, int] = {}
        self._struct_to_id: Dict[str, int] = {}

        # Build opcode vocabulary
        opcodes = self._get_builtin_opcodes()
        if vocab_file and Path(vocab_file).exists():
            with open(vocab_file, encoding='utf-8') as f:
                file_ops = [line.strip() for line in f if line.strip()]
            # File ops take precedence for ordering.
            opcodes = file_ops + opcodes
        # Keep first occurrence only, so every opcode maps to exactly one id.
        opcodes = list(dict.fromkeys(opcodes))

        # Only (NUMERIC_START - OPCODE_START) slots exist before the numeric
        # range. Extra opcodes go to the overflow region rather than reusing
        # numeric-bucket ids (which would make those ids ambiguous).
        capacity = self.NUMERIC_START - self.OPCODE_START
        for i, op in enumerate(opcodes):
            if i < capacity:
                tid = self.OPCODE_START + i
            else:
                tid = self.OVERFLOW_START + (i - capacity)
            self._opcode_to_id[op] = tid
            self._id_to_opcode[tid] = op

        # Comparison operators
        for i, cmp in enumerate(self.COMPARE_OPS):
            tid = self.COMPARE_START + i
            self._compare_to_id[cmp] = tid
            self._id_to_opcode[tid] = f"CMP:{cmp}"

        # Task markers
        for i, task in enumerate(self.TASK_MARKERS):
            tid = self.TASK_START + i
            self._task_to_id[task] = tid
            self._id_to_opcode[tid] = f"TASK:{task}"

        # Structural tokens
        for i, st in enumerate(self.STRUCT_TOKENS):
            tid = self.STRUCT_START + i
            self._struct_to_id[st] = tid
            self._id_to_opcode[tid] = f"STRUCT:{st}"

        # Must cover every assigned id, including overflow opcodes.
        self.vocab_size = max(
            self.STRUCT_START + len(self.STRUCT_TOKENS) + 1,
            max(self._id_to_opcode) + 1,
        )
        self._num_opcodes = len(opcodes)

    def _get_builtin_opcodes(self) -> List[str]:
        """Return the built-in opcode vocabulary covering Python/WASM/EVM."""
        return [
            # Python (CPython dis)
            'LOAD_CONST', 'LOAD_FAST', 'LOAD_GLOBAL', 'LOAD_NAME',
            'LOAD_ATTR', 'LOAD_DEREF', 'LOAD_CLOSURE',
            'STORE_FAST', 'STORE_GLOBAL', 'STORE_NAME', 'STORE_ATTR',
            'STORE_DEREF', 'STORE_SUBSCR',
            'BINARY_ADD', 'BINARY_SUBTRACT', 'BINARY_MULTIPLY',
            'BINARY_TRUE_DIVIDE', 'BINARY_FLOOR_DIVIDE', 'BINARY_MODULO',
            'BINARY_POWER', 'BINARY_AND', 'BINARY_OR', 'BINARY_XOR',
            'BINARY_LSHIFT', 'BINARY_RSHIFT', 'BINARY_SUBSCR',
            'UNARY_POSITIVE', 'UNARY_NEGATIVE', 'UNARY_NOT', 'UNARY_INVERT',
            'COMPARE_OP',
            'JUMP_ABSOLUTE', 'JUMP_FORWARD',
            'POP_JUMP_IF_TRUE', 'POP_JUMP_IF_FALSE',
            'JUMP_IF_TRUE_OR_POP', 'JUMP_IF_FALSE_OR_POP',
            'CALL_FUNCTION', 'CALL_FUNCTION_KW', 'CALL_METHOD',
            'BUILD_TUPLE', 'BUILD_LIST', 'BUILD_SET', 'BUILD_MAP',
            'BUILD_CONST_KEY_MAP', 'BUILD_STRING', 'BUILD_SLICE',
            'LIST_APPEND', 'SET_ADD', 'MAP_ADD',
            'POP_TOP', 'ROT_TWO', 'ROT_THREE', 'ROT_FOUR',
            'DUP_TOP', 'DUP_TOP_TWO',
            'RETURN_VALUE', 'GET_ITER', 'FOR_ITER', 'GET_YIELD_FROM_ITER',
            'NOP', 'MAKE_FUNCTION', 'SETUP_LOOP', 'POP_BLOCK',
            'SETUP_EXCEPT', 'SETUP_FINALLY', 'RAISE_VARARGS',
            'IMPORT_NAME', 'IMPORT_FROM', 'UNPACK_SEQUENCE',
            # WASM
            'i32.const', 'i64.const', 'f32.const', 'f64.const',
            'i32.add', 'i32.sub', 'i32.mul', 'i32.div_s', 'i32.div_u',
            'i32.rem_s', 'i32.rem_u', 'i32.and', 'i32.or', 'i32.xor',
            'i32.shl', 'i32.shr_s', 'i32.shr_u', 'i32.rotl', 'i32.rotr',
            'i32.clz', 'i32.ctz', 'i32.popcnt',
            'i64.add', 'i64.sub', 'i64.mul', 'i64.div_s',
            'i64.and', 'i64.or', 'i64.xor',
            'f32.add', 'f32.sub', 'f32.mul', 'f32.div',
            'f32.sqrt', 'f32.min', 'f32.max', 'f32.abs', 'f32.neg',
            'f64.add', 'f64.sub', 'f64.mul', 'f64.div',
            'f64.sqrt', 'f64.min', 'f64.max', 'f64.abs', 'f64.neg',
            'i32.eqz', 'i32.eq', 'i32.ne', 'i32.lt_s', 'i32.lt_u',
            'i32.gt_s', 'i32.gt_u', 'i32.le_s', 'i32.le_u',
            'i32.ge_s', 'i32.ge_u',
            'i64.eqz', 'i64.eq', 'i64.ne',
            'f32.eq', 'f32.ne', 'f32.lt', 'f32.gt',
            'f64.eq', 'f64.ne', 'f64.lt', 'f64.gt',
            'i32.load', 'i64.load', 'f32.load', 'f64.load',
            'i32.store', 'i64.store', 'f32.store', 'f64.store',
            'i32.load8_s', 'i32.load8_u', 'i32.load16_s', 'i32.load16_u',
            'memory.size', 'memory.grow',
            'block', 'loop', 'if', 'else', 'end',
            'br', 'br_if', 'br_table', 'return',
            'call', 'call_indirect',
            'local.get', 'local.set', 'local.tee',
            'global.get', 'global.set',
            'i32.wrap_i64', 'i64.extend_i32_s', 'i64.extend_i32_u',
            'f32.convert_i32_s', 'f64.convert_i32_s',
            'i32.trunc_f32_s', 'i32.trunc_f64_s',
            'f32.demote_f64', 'f64.promote_f32',
            'i32.reinterpret_f32', 'f32.reinterpret_i32',
            'drop', 'select', 'nop', 'unreachable',
            # EVM
            'PUSH1', 'PUSH2', 'PUSH32', 'POP',
            'DUP1', 'DUP2', 'DUP3', 'DUP4',
            'SWAP1', 'SWAP2', 'SWAP3', 'SWAP4',
            'ADD', 'MUL', 'SUB', 'DIV', 'SDIV', 'MOD', 'SMOD',
            'ADDMOD', 'MULMOD', 'EXP', 'SIGNEXTEND',
            'LT', 'GT', 'SLT', 'SGT', 'EQ', 'ISZERO',
            'AND', 'OR', 'XOR', 'NOT', 'BYTE', 'SHL', 'SHR', 'SAR',
            'SHA3',
            'MLOAD', 'MSTORE', 'MSTORE8', 'MSIZE',
            'SLOAD', 'SSTORE',
            'JUMP', 'JUMPI', 'JUMPDEST', 'STOP', 'RETURN', 'REVERT',
            'ADDRESS', 'BALANCE', 'ORIGIN', 'CALLER', 'CALLVALUE',
            'CALLDATALOAD', 'CALLDATASIZE', 'CALLDATACOPY',
            'CODESIZE', 'CODECOPY', 'GASPRICE', 'RETURNDATASIZE',
            'RETURNDATACOPY', 'BLOCKHASH', 'COINBASE', 'TIMESTAMP',
            'NUMBER', 'DIFFICULTY', 'GASLIMIT', 'CHAINID', 'SELFBALANCE',
            'GAS',
            'LOG0', 'LOG1', 'LOG2', 'LOG3', 'LOG4',
            'CALL', 'DELEGATECALL', 'STATICCALL', 'CREATE', 'CREATE2',
            'SELFDESTRUCT',
        ]

    def _quantize_number(self, val: float) -> int:
        """Quantize a number to a bucket index in 0-127.

        Floats are truncated toward zero first. +/-infinity saturates to the
        top bucket (127), like any magnitude >= 2**31.

        Raises:
            ValueError: if ``val`` is NaN.
        """
        if isinstance(val, float):
            if math.isnan(val):
                raise ValueError("cannot quantize NaN")
            if math.isinf(val):
                return 127
        v = int(val)

        # Exact range: -32 to 31 -> buckets 0-63
        if -32 <= v <= 31:
            return v + 32

        # Medium range: 32..255 / -256..-33 -> buckets 64-91 (8-value steps)
        if 32 <= v <= 255:
            return 64 + min(27, (v - 32) // 8)
        if -256 <= v < -32:
            return 64 + min(27, (-v - 33) // 8)

        magnitude = abs(v)  # >= 256 from here on

        # Large range: 256-4095 -> buckets 92-106 (256-value steps)
        if magnitude <= 4095:
            return 92 + min(15, (magnitude - 256) // 256)

        # Huge range: 4096+ -> buckets 108-127, one per power of two.
        # bit_length() - 13 == floor(log2(magnitude / 4096)) computed exactly,
        # avoiding the float overflow that very large ints would cause.
        return 108 + min(19, magnitude.bit_length() - 13)

    def _dequantize_number(self, bucket: int) -> int:
        """Return a representative (non-negative for buckets >= 64) value for a bucket.

        Buckets outside 0-127 decode to 0.
        """
        if 0 <= bucket <= 63:
            return bucket - 32
        if 64 <= bucket <= 91:
            return 32 + (bucket - 64) * 8
        if 92 <= bucket <= 107:
            return 256 + (bucket - 92) * 256
        if 108 <= bucket <= 127:
            return 4096 * (2 ** (bucket - 108))
        return 0

    def encode_line(self, line: str) -> List[int]:
        """Encode one corpus line into token ids.

        Input format: ``D{n}|{stack}|{opcode_sequence}|{task}`` where
        ``{task}`` is ``MARKER`` or ``MARKER:content``.

        Returns ``[BOS, ...token_ids..., EOS]``. A line with fewer than four
        ``|``-separated fields encodes to ``[BOS, EOS]``; fields after the
        fourth are ignored. Unrecognized markers are skipped.
        """
        tokens = [self.BOS]

        parts = line.split('|')
        if len(parts) < 4:
            return tokens + [self.EOS]

        sep = self._struct_to_id.get('|', self.STRUCT_START + 3)

        # Difficulty marker, then stack name (both live in the task table)
        for field in (parts[0].strip(), parts[1].strip()):
            if field in self._task_to_id:
                tokens.append(self._task_to_id[field])

        # Opcode sequence, bracketed by separators
        tokens.append(sep)
        tokens.extend(self._encode_ops(parts[2].strip()))
        tokens.append(sep)

        # Task marker and optional content
        task_marker, has_content, task_content = parts[3].strip().partition(':')
        if task_marker in self._task_to_id:
            tokens.append(self._task_to_id[task_marker])

        if has_content:
            task_content = task_content.strip()
            # Task content is either opcodes (NEXT/FILL) or a value (STACK/EXEC)
            if task_marker in ('NEXT', 'FILL'):
                tokens.extend(self._encode_ops(task_content))
            elif task_marker == 'STACK':
                try:
                    val = int(task_content)
                except ValueError:
                    pass
                else:
                    tokens.append(self.NUMERIC_START + self._quantize_number(val))
            elif task_marker == 'EXEC':
                if task_content in self._struct_to_id:
                    tokens.append(self._struct_to_id[task_content])

        tokens.append(self.EOS)
        return tokens

    def _encode_ops(self, ops_str: str) -> List[int]:
        """Encode a whitespace-separated sequence of ``OPCODE`` / ``OPCODE(args)``.

        Unknown opcodes are skipped together with their arguments.
        """
        tokens: List[int] = []
        for match in self._OPS_PATTERN.finditer(ops_str):
            opcode, args = match.group(1), match.group(2)

            tid = self._opcode_to_id.get(opcode)
            if tid is None:
                # Unknown opcode — skip (shouldn't happen with our corpus)
                continue
            tokens.append(tid)

            if args:
                for arg in args.split(','):
                    tokens.extend(self._encode_arg(arg.strip()))

        return tokens

    def _encode_arg(self, arg: str) -> List[int]:
        """Encode a single opcode argument to zero or more token ids.

        Tried in order:
          1. comparison operator -> one compare token;
          2. ``key=value`` (e.g. WASM ``o=``/``a=``) -> the ``key=`` struct
             token if known, plus the quantized integer value if it parses;
          3. ``0x`` hex literal -> one quantized numeric token;
          4. decimal/float literal (including ``inf``) -> one quantized
             numeric token.
        Anything else, including ``nan``, yields ``[]``.
        """
        if arg in self._compare_to_id:
            return [self._compare_to_id[arg]]

        tokens: List[int] = []

        # key=value (WASM memory args)
        if '=' in arg and not arg.startswith('0x'):
            key, val = arg.split('=', 1)
            key_tok = key + '='
            if key_tok in self._struct_to_id:
                tokens.append(self._struct_to_id[key_tok])
            try:
                v = int(val)
            except ValueError:
                return tokens
            tokens.append(self.NUMERIC_START + self._quantize_number(v))
            return tokens

        # Hex literal
        if arg.startswith('0x'):
            try:
                v = int(arg, 16)
            except ValueError:
                pass
            else:
                return [self.NUMERIC_START + self._quantize_number(v)]

        # Decimal / float literal
        try:
            bucket = self._quantize_number(float(arg))
        except ValueError:
            # Not numeric, or NaN: nothing to emit.
            return tokens
        # NOTE(review): REGISTER_START tokens are never produced here. The
        # previous int()-based register fallback sat after this float() parse,
        # and every string int() accepts is also accepted by float(), so it
        # was unreachable. Routing register operands there is a design choice.
        return [self.NUMERIC_START + bucket]

    def decode(self, token_ids: List[int]) -> str:
        """Decode token ids to a space-separated, human-readable string.

        PAD is dropped; numeric tokens render as ``#<representative value>``
        (approximate), register tokens as ``r<index>``, unknown ids as ``?<id>``.
        """
        parts = []
        for tid in token_ids:
            if tid == self.BOS:
                parts.append('<BOS>')
            elif tid == self.EOS:
                parts.append('<EOS>')
            elif tid == self.PAD:
                continue
            elif tid in self._id_to_opcode:
                parts.append(self._id_to_opcode[tid])
            elif self.NUMERIC_START <= tid < self.REGISTER_START:
                val = self._dequantize_number(tid - self.NUMERIC_START)
                parts.append(f"#{val}")
            elif self.REGISTER_START <= tid < self.COMPARE_START:
                parts.append(f"r{tid - self.REGISTER_START}")
            else:
                parts.append(f"?{tid}")
        return ' '.join(parts)

    def stats(self) -> Dict[str, int]:
        """Return vocabulary size and per-category token counts."""
        return {
            'vocab_size': self.vocab_size,
            'num_opcodes': self._num_opcodes,
            'num_comparisons': len(self._compare_to_id),
            'num_task_markers': len(self._task_to_id),
            'num_struct_tokens': len(self._struct_to_id),
            'numeric_buckets': 128,
            'register_slots': 32,
        }


def main():
    """Smoke-test the tokenizer and print the results.

    Prints vocabulary statistics, encodes and decodes three sample lines,
    and, if ``corpus/bytecode.txt`` exists under the parent of this file's
    directory, reports line and token counts for that corpus.
    """
    tok = BytecodeTokenizer()
    print("BytecodeTokenizer initialized:")
    for k, v in tok.stats().items():
        print(f"  {k}: {v}")

    # Test encoding
    test_lines = [
        "D1|python|LOAD_CONST(42) LOAD_CONST(10) BINARY_ADD RETURN_VALUE|NEXT:RETURN_VALUE",
        "D2|wasm|local.get(1) i32.const(54) i32.lt_u if i32.const(60)|STACK:3",
        "D3|evm|PUSH1(0x00) CALLDATALOAD PUSH2(0x3a8a) EQ PUSH1(0x29) JUMPI|NEXT:STOP",
    ]

    print("\nEncoding tests:")
    for line in test_lines:
        ids = tok.encode_line(line)
        decoded = tok.decode(ids)
        print(f"\n  Input:   {line[:80]}...")
        print(f"  Tokens:  {ids}")
        print(f"  Length:  {len(ids)}")
        print(f"  Decoded: {decoded[:80]}...")

    # Encode the full corpus (if present) in a single pass over the file.
    corpus_path = Path(__file__).parent.parent / 'corpus' / 'bytecode.txt'
    if not corpus_path.exists():
        return

    print(f"\nCorpus encoding test ({corpus_path}):")
    total = 0
    total_tokens = 0
    empty = 0
    raw_chars = 0  # every character read, including newlines and blank lines
    with open(corpus_path, encoding='utf-8') as f:
        for raw_line in f:
            raw_chars += len(raw_line)
            line = raw_line.strip()
            if not line:
                continue
            ids = tok.encode_line(line)
            total += 1
            total_tokens += len(ids)
            if len(ids) <= 2:  # only BOS+EOS
                empty += 1
    print(f"  Lines: {total}")
    print(f"  Total tokens: {total_tokens}")
    print(f"  Avg tokens/line: {total_tokens / max(1, total):.1f}")
    print(f"  Empty (BOS+EOS only): {empty}")
    # Tokens per raw character; max(1, ...) guards an empty corpus file.
    print(f"  Compression vs raw chars: {total_tokens / max(1, raw_chars):.2f}x")


if __name__ == '__main__':
    main()