CRAYON-tokenizer / src /crayon /core /dat_compiler.py
Phase-Technologies's picture
Upload folder using huggingface_hub
708f4a3 verified
"""
Double-Array Trie (DAT) Compiler for Crayon.
Compiles a sorted vocabulary list into a highly compressed, cache-local binary format (.dat).
Algorithm:
- Base[s] + c = t
- Check[t] = s
"""
import struct
import sys
import array
from typing import List, Tuple, Dict
class DATBuilder:
def __init__(self):
# Arrays: base and check.
# Initial size estimate: 2x vocab size * avg length is usually overkill but safe.
# We will resize dynamically.
self.base = array.array('i', [0] * 1024)
self.check = array.array('i', [0] * 1024)
self.used = array.array('b', [0] * 1024) # Bitset for allocation
self.check[0] = 0 # Root check is typically 0
self.size = 1024
self.max_idx = 0
# Token ID mapping
self.output = {} # state_index -> token_id
def _resize(self, new_size):
if new_size <= self.size:
return
# Python arrays scale efficiently
extension = [0] * (new_size - self.size)
self.base.extend(extension)
self.check.extend(extension)
self.used.extend([0] * (new_size - self.size))
self.size = new_size
def _find_base(self, children_keys: List[int]) -> int:
"""Finds a base offset 'b' such that check[b + c] are all empty for each c in children."""
if not children_keys:
return 1 # Leaf
first = children_keys[0]
# Start searching from 1
b = 1
while True:
# First candidate check: base + first_child
pos = b + first
if pos >= self.size:
self._resize(pos + 256)
if self.check[pos] != 0:
# Collision for first child, move forward
b += 1
continue
# Now verify all other children
overlap = False
max_pos = 0
for k in children_keys:
p = b + k
if p >= self.size:
self._resize(p + 256)
max_pos = max(max_pos, p)
if self.check[p] != 0:
overlap = True
break
if not overlap:
return b
b += 1
def build(self, tokens: List[str]) -> bytes:
"""
Builds the Double-Array Trie from sorted tokens.
"""
# 1. Build Standard Trie first (Intermediate representation)
# Dictionary of node -> {char: next_node}
trie = {'id': -1, 'children': {}}
for i, token in enumerate(tokens):
node = trie
for char in token:
key = ord(char)
if key not in node['children']:
node['children'][key] = {'id': -1, 'children': {}}
node = node['children'][key]
node['id'] = i
# 2. Convert to Double-Array via BFS
# Queue: (trie_node, dat_state_index)
queue: List[Tuple[Dict, int]] = [(trie, 0)] # Root is state 0
# Mark root as used
self.base[0] = 1
self._resize(256) # Ensure capacity
processed_count = 0
while queue:
node, state = queue.pop(0)
if node['id'] != -1:
self.output[state] = node['id']
# Mark as terminal in base array?
# Technique: We usually store leaf status by negative base or separate array.
# For Crayon, we want fast token ID retrieval.
# We will store token_id mapping separately OR encode it.
# Let's encode token_id as negative base: base[s] = -token_id - 1
# BUT a node can be both transit and terminal (e.g., "apple", "apples").
# Standard DAT handles this by specific termination char '\0' or separate array.
# To keep it compact: We will use a separate output structure for now
# OR stick to the Crayon specialized TrieNode structure.
# Solution: We will store token_ids in a separate array `terminals` which parallels check/base.
# If terminals[s] != -1, it's a match.
pass
children = node['children']
if not children:
continue
sorted_keys = sorted(children.keys())
# Find a valid base for this state
base_offset = self._find_base(sorted_keys)
self.base[state] = base_offset
# set check and prepare children
for k in sorted_keys:
next_state = base_offset + k
self.check[next_state] = state
self.used[next_state] = 1 # Mark
self.max_idx = max(self.max_idx, next_state)
queue.append((children[k], next_state))
processed_count += 1
if processed_count % 1000 == 0:
print(f"Compiled {processed_count} states...", end='\r')
print(f"\nDAT Construction Complete. {self.max_idx} states.")
return self._serialize()
def _serialize(self) -> bytes:
"""
Format:
[HEADER: 16 bytes]
- Magic: "CRYN" (4)
- Version: 1 (4)
- Size: int (4)
[BODY]
- Base: int32 * size
- Check: int32 * size
- Terminals: int32 * size (Token mapping)
"""
# Optimize size
final_size = self.max_idx + 1
# Build terminals array
terminals = array.array('i', [-1] * final_size)
for state, pid in self.output.items():
if state < final_size:
terminals[state] = pid
header = struct.pack('<4sII', b'CRYN', 1, final_size)
# Slice correct size
final_base = self.base[:final_size]
final_check = self.check[:final_size]
print(f"Serialized Size: {(final_size * 12 + 12) / 1024 / 1024:.2f} MB")
return (
header +
final_base.tobytes() +
final_check.tobytes() +
terminals.tobytes()
)
def compile_dat(tokens: List[str], output_path: str):
builder = DATBuilder()
data = builder.build(tokens)
with open(output_path, 'wb') as f:
f.write(data)
print(f"Saved: {output_path}")