File size: 6,569 Bytes
708f4a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193

"""
Double-Array Trie (DAT) Compiler for Crayon.
Compiles a sorted vocabulary list into a highly compressed, cache-local binary format (.dat).

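Algorithm:
- Transition: from state s on character code c, the next state is t = base[s] + c.
- Validity: that transition exists only if check[t] == s.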
"""

import array
import struct
import sys
from collections import deque
from typing import List, Tuple, Dict

class DATBuilder:
    """Compile a token list into a serialized Double-Array Trie (DAT).

    State 0 is the root. A transition from state ``s`` on character code
    ``c`` (``ord(char)``) leads to ``t = base[s] + c`` and is valid only when
    ``check[t] == s``. A token's id is its index in the list passed to
    :meth:`build`. Ids are kept in ``output`` and written out by
    :meth:`_serialize` as a parallel ``terminals`` array. This lets a state
    be both a match and a transit node (e.g. "apple" and "apples").

    :meth:`build` does not reset any builder state, so create a fresh
    ``DATBuilder`` for every build.
    """

    def __init__(self):
        # Parallel arrays indexed by slot/state, grown on demand by _resize().
        self.base = array.array('i', [0] * 1024)
        self.check = array.array('i', [0] * 1024)
        # Slot occupancy flags (1 = slot holds a state). This, not check[],
        # is the authoritative "is this slot taken?" test. Children of the
        # root store check == 0, which is also the value of every unused slot.
        self.used = array.array('b', [0] * 1024)
        self.check[0] = 0  # Root check is typically 0
        self.size = 1024  # current capacity of base/check/used
        self.max_idx = 0  # highest slot index assigned to a state so far

        # Token ID mapping
        self.output: Dict[int, int] = {}  # state_index -> token_id

    def _resize(self, new_size):
        """Grow base/check/used to ``new_size`` zero-filled slots (never shrinks)."""
        if new_size <= self.size:
            return
        padding = [0] * (new_size - self.size)
        self.base.extend(padding)
        self.check.extend(padding)
        self.used.extend(padding)
        self.size = new_size

    def _find_base(self, children_keys: List[int]) -> int:
        """Return the smallest base ``b >= 1`` whose slots ``b + c`` are all free.

        Args:
            children_keys: character codes of the children to place.

        Returns:
            The chosen base offset, or 1 when ``children_keys`` is empty.
            The arrays are grown so that every slot ``b + c`` is addressable.

        NOTE: the search restarts at ``b = 1`` on every call, which can be
        slow for large vocabularies.
        """
        if not children_keys:
            return 1  # Leaf
        highest = max(children_keys)
        b = 1
        while True:
            if b + highest >= self.size:
                self._resize(b + highest + 256)
            # Occupancy must come from `used`: a check value of 0 may belong
            # to a real child of the root rather than an empty slot.
            if not any(self.used[b + k] for k in children_keys):
                return b
            b += 1

    def build(self, tokens: List[str]) -> bytes:
        """Build the DAT for ``tokens`` and return its serialized bytes.

        ``tokens[i]`` is assigned token id ``i``. If a token occurs more than
        once, the index of its last occurrence wins. Input order does not
        affect correctness, because children are placed in character-code
        order. Progress messages are printed to stdout.
        """
        # 1. Build an intermediate pointer-based trie:
        #    node = {'id': token_id or -1, 'children': {char_code: node}}
        trie = {'id': -1, 'children': {}}
        for i, token in enumerate(tokens):
            node = trie
            for char in token:
                key = ord(char)
                if key not in node['children']:
                    node['children'][key] = {'id': -1, 'children': {}}
                node = node['children'][key]
            node['id'] = i

        # 2. Convert to double-array form breadth-first.
        #    Queue items: (trie_node, dat_state_index); root is state 0.
        queue: "deque[Tuple[Dict, int]]" = deque([(trie, 0)])
        self.base[0] = 1
        self.used[0] = 1  # the root occupies slot 0

        processed_count = 0
        while queue:
            node, state = queue.popleft()

            if node['id'] != -1:
                # Token ids live in a side table (emitted as `terminals`).
                self.output[state] = node['id']

            children = node['children']
            if not children:
                continue

            sorted_keys = sorted(children)
            base_offset = self._find_base(sorted_keys)
            self.base[state] = base_offset

            for k in sorted_keys:
                next_state = base_offset + k
                self.check[next_state] = state
                self.used[next_state] = 1
                self.max_idx = max(self.max_idx, next_state)
                queue.append((children[k], next_state))

            processed_count += 1
            if processed_count % 1000 == 0:
                print(f"Compiled {processed_count} states...", end='\r')

        print(f"\nDAT Construction Complete. {self.max_idx} states.")
        return self._serialize()

    def _serialize(self) -> bytes:
        """
        Format (all fields little-endian):
        [HEADER: 12 bytes]
          - Magic: "CRYN" (4)
          - Version: 1 (uint32)
          - Size N: number of slots (uint32)
        [BODY] three int32 arrays of N entries each:
          - Base
          - Check (0 for unused slots)
          - Terminals: token id, or -1 if the state ends no token

        NOTE(review): unused slots and children of the root both carry
        check == 0, so a reader cannot tell them apart by ``check`` alone
        when stepping out of the root. Confirm how the reader handles this.
        Typecode 'i' is assumed to be 4 bytes (true on mainstream platforms).
        """
        final_size = self.max_idx + 1

        terminals = array.array('i', [-1] * final_size)
        for state, pid in self.output.items():
            if state < final_size:
                terminals[state] = pid

        header = struct.pack('<4sII', b'CRYN', 1, final_size)

        # Slicing copies, so byteswapping below never touches self.base/check.
        final_base = self.base[:final_size]
        final_check = self.check[:final_size]
        if sys.byteorder == 'big':
            # array.tobytes() emits native order; match the little-endian header.
            for arr in (final_base, final_check, terminals):
                arr.byteswap()

        print(f"Serialized Size: {(final_size * 12 + 12) / 1024 / 1024:.2f} MB")

        return b''.join((
            header,
            final_base.tobytes(),
            final_check.tobytes(),
            terminals.tobytes(),
        ))

def compile_dat(tokens: List[str], output_path: str):
    """Compile ``tokens`` into a DAT blob and write it to ``output_path``.

    The file is opened in binary write mode, replacing any existing content.
    A confirmation line naming the path is printed once the write completes.
    """
    blob = DATBuilder().build(tokens)
    with open(output_path, mode='wb') as out_file:
        out_file.write(blob)
    print(f"Saved: {output_path}")