File size: 6,569 Bytes
708f4a3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
"""
Double-Array Trie (DAT) Compiler for Crayon.
Compiles a sorted vocabulary list into a highly compressed, cache-local binary format (.dat).
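Algorithm (for a state s, an input symbol c and a target state t):
- Transition: t = Base[s] + c
- Validity:   Check[t] == s (slot t belongs to parent state s)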
"""
import array
import struct
import sys
from collections import deque
from typing import List, Tuple, Dict
class DATBuilder:
    """Build a Double-Array Trie (DAT) from a token list and serialize it.

    For a state ``s`` and a code point ``c`` the transition target is
    ``t = base[s] + c``; the transition exists when ``check[t] == s``.
    State 0 is the root.

    The arrays and the ``output`` mapping are never reset, so use a fresh
    builder for every call to :meth:`build`.
    """

    def __init__(self):
        initial_size = 1024
        # base[s]: offset added to a code point to reach a child slot of s.
        self.base = array.array('i', [0] * initial_size)
        # check[t]: parent state owning slot t. Children of the root store 0,
        # which is also the fill value of unused slots, so `check` alone
        # cannot tell whether a slot is free.
        self.check = array.array('i', [0] * initial_size)
        # used[t]: 1 once slot t is allocated (one byte per slot). This is
        # the source of truth for occupancy during construction.
        self.used = array.array('b', [0] * initial_size)
        self.used[0] = 1  # The root permanently occupies slot 0.
        self.size = initial_size
        self.max_idx = 0
        # Lowest slot index that may still be free; only ever moves forward
        # because slots are never released.
        self._first_free = 1
        # Token ID mapping: state_index -> token_id
        self.output = {}

    def _resize(self, new_size):
        """Grow ``base``, ``check`` and ``used`` to ``new_size`` slots.

        Does nothing when ``new_size`` is not larger than the current size.
        New slots are zero-filled.
        """
        grow = new_size - self.size
        if grow <= 0:
            return
        padding = [0] * grow
        self.base.extend(padding)
        self.check.extend(padding)
        self.used.extend(padding)
        self.size = new_size

    def _find_base(self, children_keys: List[int]) -> int:
        """Return the smallest base ``b >= 1`` whose child slots are all free.

        A slot ``b + c`` is free when ``used[b + c]`` is 0. Occupancy must be
        read from ``used`` rather than ``check``: children of the root have
        ``check == 0`` and would otherwise be mistaken for free slots and
        overwritten.

        Args:
            children_keys: Code points of the outgoing transitions.

        Returns:
            The chosen base offset, or 1 when ``children_keys`` is empty.
        """
        if not children_keys:
            return 1  # Leaf

        # Skip the prefix that is already fully allocated.
        free = self._first_free
        while free < self.size and self.used[free]:
            free += 1
        self._first_free = free

        lowest = min(children_keys)
        highest = max(children_keys)
        # Any b with b + lowest < free would hit an occupied slot, so starting
        # here yields the same answer as scanning from 1.
        b = max(1, free - lowest)
        while True:
            if b + highest >= self.size:
                self._resize(b + highest + 256)
            if not any(self.used[b + k] for k in children_keys):
                return b
            b += 1

    def build(self, tokens: List[str]) -> bytes:
        """Build the Double-Array Trie for ``tokens`` and return its bytes.

        Each token's ID is its index in ``tokens``; if a token occurs more
        than once the last index wins. Transitions are keyed by
        ``ord(char)``. Progress is printed to stdout.

        Args:
            tokens: Vocabulary entries.

        Returns:
            The serialized trie as produced by :meth:`_serialize`.
        """
        # 1. Build a plain dict-based trie as the intermediate representation.
        trie = {'id': -1, 'children': {}}
        for token_id, token in enumerate(tokens):
            node = trie
            for char in token:
                key = ord(char)
                children = node['children']
                if key not in children:
                    children[key] = {'id': -1, 'children': {}}
                node = children[key]
            node['id'] = token_id

        # 2. Convert to the double-array layout breadth-first.
        # Queue entries are (trie_node, dat_state_index); the root is state 0.
        queue = deque([(trie, 0)])
        self.base[0] = 1

        processed_count = 0
        while queue:
            node, state = queue.popleft()

            # A state may be terminal and still have children ("apple",
            # "apples"), so token IDs live in a separate mapping that is
            # written out as the `terminals` array.
            if node['id'] != -1:
                self.output[state] = node['id']

            children = node['children']
            if not children:
                continue

            sorted_keys = sorted(children)
            base_offset = self._find_base(sorted_keys)
            self.base[state] = base_offset

            for k in sorted_keys:
                next_state = base_offset + k
                self.check[next_state] = state
                self.used[next_state] = 1
                queue.append((children[k], next_state))
            # Keys are sorted, so the last child has the highest slot index.
            self.max_idx = max(self.max_idx, base_offset + sorted_keys[-1])

            processed_count += 1
            if processed_count % 1000 == 0:
                print(f"Compiled {processed_count} states...", end='\r')

        print(f"\nDAT Construction Complete. {self.max_idx} states.")
        return self._serialize()

    def _serialize(self) -> bytes:
        """Pack the trie into its binary form.

        Format (all integers little-endian):
        [HEADER: 12 bytes]
        - Magic: "CRYN" (4)
        - Version: 1 (4)
        - Size: int (4)
        [BODY]
        - Base: int32 * size
        - Check: int32 * size
        - Terminals: int32 * size (token ID per state, -1 if not terminal)

        Returns:
            Header followed by the three arrays, each truncated to
            ``max_idx + 1`` entries.
        """
        final_size = self.max_idx + 1

        terminals = array.array('i', [-1] * final_size)
        for state, token_id in self.output.items():
            if state < final_size:
                terminals[state] = token_id

        header = struct.pack('<4sII', b'CRYN', 1, final_size)

        # Slicing copies, so swapping bytes below never touches self.base or
        # self.check.
        sections = [self.base[:final_size], self.check[:final_size], terminals]
        if sys.byteorder != 'little':
            # tobytes() emits native order; the header is little-endian, so
            # keep the body consistent with it.
            for section in sections:
                section.byteswap()

        print(f"Serialized Size: {(final_size * 12 + 12) / 1024 / 1024:.2f} MB")
        return header + b''.join(section.tobytes() for section in sections)
def compile_dat(tokens: List[str], output_path: str):
    """Compile ``tokens`` into a Double-Array Trie and write it to disk.

    Args:
        tokens: Vocabulary entries handed to ``DATBuilder.build``.
        output_path: Destination file; it is opened in binary write mode, so
            an existing file is overwritten.
    """
    payload = DATBuilder().build(tokens)
    with open(output_path, 'wb') as out_file:
        out_file.write(payload)
    print(f"Saved: {output_path}")
|