File size: 11,114 Bytes
4890f37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
"""
PyTorch-compatible chamber lookup for H4 ChamberTree.

Provides a bridge between PyTorch tensors (gradient-tracked) and the
numpy-based H4ChamberTree (discrete, non-differentiable). The key trick:

  - ChamberTree does fast O(log t) filtering to find top-k candidate keys
  - We return candidate indices back to PyTorch
  - Attention scores are computed only over candidates (differentiable)
  - Gradients flow through Q/K projections and scores, not through the tree

This gives O(k) attention per query where k << t.

If the compiled Rust backend (h4_rust) is available, RustChamberIndex provides
a much faster implementation. Falls back to pure-Python ChamberIndex otherwise.
"""

import numpy as np
import torch
from typing import List, Tuple, Optional
import sys
import os

# Rust backend detection — optional, graceful fallback to Python
try:
    import h4_rust
    RUST_AVAILABLE = True
except ImportError:
    RUST_AVAILABLE = False

# Add parent to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from h4_polytopic_attention import H4ChamberTree, build_coxeter_chambers, generate_600_cell_vertices


class ChamberIndex:
    """
    Manages a set of H4ChamberTrees (one per head) and provides
    batch top-k candidate lookup compatible with PyTorch autograd.

    Keys and queries are detached and converted to numpy before they touch
    the trees, so only plain integer indices are handed back to the caller.
    A per-head list of the inserted keys is kept in insertion order; the
    position of a key in that list is the index returned by `query_topk`.
    """

    def __init__(self, n_heads: int, simple_roots: np.ndarray):
        """
        Args:
            n_heads: number of attention heads (one tree is built per head)
            simple_roots: array of simple roots passed to each H4ChamberTree
        """
        self.n_heads = n_heads
        self.simple_roots = simple_roots
        self.trees = [H4ChamberTree(simple_roots) for _ in range(n_heads)]
        self._keys_by_head = [[] for _ in range(n_heads)]  # track insertion order

    def reset(self):
        """Clear all trees and rebuild."""
        self.trees = [H4ChamberTree(self.simple_roots) for _ in range(self.n_heads)]
        self._keys_by_head = [[] for _ in range(self.n_heads)]

    def _insert_one(self, h: int, key: np.ndarray) -> None:
        """Insert a single key into head `h` at the next free position."""
        # The position is derived from what is already stored, so repeated
        # bulk_insert / insert_keys calls keep producing unique indices.
        t = len(self._keys_by_head[h])
        # Use position index as both value and timestamp
        self.trees[h].insert(key, np.array([t], dtype=np.float64), t)
        self._keys_by_head[h].append(key.copy())

    def _visible_keys(self, h: int, causal_mask_pos: Optional[int]) -> np.ndarray:
        """Return the stored keys of head `h` with index <= causal_mask_pos."""
        stored = self._keys_by_head[h]
        if causal_mask_pos is None:
            n_visible = len(stored)
        else:
            # Clamp at zero: a negative stop would slice from the end of the
            # list and leak keys that the causal mask is meant to hide.
            n_visible = max(causal_mask_pos + 1, 0)
        return np.array(stored[:n_visible])

    def insert_keys(self, keys: torch.Tensor):
        """
        Insert keys for all heads at current timestep.

        Args:
            keys: (n_heads, 4) tensor of key vectors to insert
        """
        keys_np = keys.detach().cpu().numpy()
        for h in range(self.n_heads):
            self._insert_one(h, keys_np[h])

    def bulk_insert(self, keys: torch.Tensor):
        """
        Insert a full sequence of keys for all heads.

        Positions continue from the keys already stored, so this may be
        called more than once (or mixed with `insert_keys`) without
        producing duplicate position indices.

        Args:
            keys: (seq_len, n_heads, 4) tensor of key vectors
        """
        keys_np = keys.detach().cpu().numpy()
        for step in keys_np:
            for h in range(self.n_heads):
                self._insert_one(h, step[h])

    def query_topk(
        self,
        queries: torch.Tensor,
        k: int,
        causal_mask_pos: Optional[int] = None,
    ) -> List[List[List[int]]]:
        """
        For each query, find top-k candidate key indices using ChamberTree.

        Args:
            queries: (n_queries, n_heads, 4) tensor of query vectors
            k: number of candidates per query per head
            causal_mask_pos: if set, only return candidates with index <= this value

        Returns:
            List of shape [n_queries][n_heads][<=k] containing key indices.
            These indices can be used to gather from the full key/value tensors.
        """
        n_queries = queries.shape[0]
        queries_np = queries.detach().cpu().numpy()
        results = []
        # The fallback key matrix depends only on the head and the causal
        # limit, so it is built at most once per head for the whole call.
        fallback_keys = {}

        for q_idx in range(n_queries):
            head_results = []
            for h in range(self.n_heads):
                query = queries_np[q_idx, h]
                # Query tree for top candidates
                # Request more than k since some may be filtered by causal mask
                tree_results = self.trees[h].query_max_dot(query, k=k * 2)

                indices = []
                for _score, value, timestamp in tree_results:
                    t_idx = int(value[0]) if len(value) > 0 else timestamp
                    if causal_mask_pos is not None and t_idx > causal_mask_pos:
                        continue
                    indices.append(t_idx)
                    if len(indices) >= k:
                        break

                # If tree didn't return enough, fall back to scanning
                if len(indices) < k and self._keys_by_head[h]:
                    if h not in fallback_keys:
                        fallback_keys[h] = self._visible_keys(h, causal_mask_pos)
                    all_keys = fallback_keys[h]
                    if len(all_keys) > 0:
                        dots = all_keys @ query
                        sorted_idx = np.argsort(-dots)
                        existing = set(indices)
                        for idx in sorted_idx:
                            idx = int(idx)
                            if idx not in existing:
                                indices.append(idx)
                                existing.add(idx)
                            if len(indices) >= k:
                                break

                head_results.append(indices)
            results.append(head_results)

        return results


def compute_chamber_ids(keys: torch.Tensor, simple_roots: torch.Tensor) -> torch.Tensor:
    """
    Map each key to the integer ID of the chamber it falls in.

    The ID packs the signs of the key's dot products with the first four
    simple roots into a 4-bit number (root 0 is the most significant bit;
    a non-negative dot product sets the bit). The result is an integer
    tensor, so no gradient flows through it; it is intended for logging
    chamber utilization.

    Args:
        keys: (..., 4) tensor of key vectors
        simple_roots: (4, 4) tensor of H4 simple roots

    Returns:
        (...,) tensor of integer chamber IDs (0-15 for 4-bit sign pattern)
    """
    # One projection per root along the last axis: (..., 4)
    projections = torch.matmul(keys, simple_roots.T)
    bits = projections.ge(0).to(torch.long)
    # Weighted sum of the sign bits, most significant bit first.
    bit_weights = (8, 4, 2, 1)
    return sum(bits[..., column] * weight for column, weight in enumerate(bit_weights))


def chamber_utilization(chamber_ids: torch.Tensor, n_chambers: int = 16) -> dict:
    """
    Summarize how evenly a batch of chamber IDs is spread over the chambers.

    IDs outside the range [0, n_chambers) are not counted.

    Args:
        chamber_ids: tensor of chamber IDs of any shape (it is flattened)
        n_chambers: number of chambers to tally

    Returns:
        Dict with 'counts' (per-chamber), 'entropy' (Shannon entropy),
        and 'max_ratio' (max/mean ratio, 1.0 = perfectly uniform).
        When nothing was counted, 'entropy' and 'max_ratio' are both 0.0.
    """
    ids = chamber_ids.flatten()
    # Compare every ID against every chamber number at once and tally the
    # matches per chamber; the row sums form a long tensor on the input device.
    chamber_numbers = torch.arange(n_chambers, device=chamber_ids.device)
    counts = (ids.unsqueeze(0) == chamber_numbers.unsqueeze(1)).sum(dim=1)

    total = counts.sum().float()
    if total == 0:
        return {'counts': counts, 'entropy': 0.0, 'max_ratio': 0.0}

    probs = counts.float() / total
    # Shannon entropy (nats); empty chambers contribute zero rather than nan.
    safe_log = torch.where(probs > 0, probs.log(), torch.zeros_like(probs))
    entropy = -(probs * safe_log).sum().item()

    mean_count = total / n_chambers
    if mean_count > 0:
        max_ratio = (counts.max().float() / mean_count).item()
    else:
        max_ratio = 0.0

    return {'counts': counts, 'entropy': entropy, 'max_ratio': max_ratio}


class RustChamberIndex:
    """
    Rust-accelerated chamber index using h4_rust compiled backend.
    API-compatible with ChamberIndex for drop-in replacement.

    All heavy computation (dot products, sorting, chamber indexing) runs
    in compiled Rust via PyO3/numpy, typically 10-100x faster than Python.

    This class only stores the keys per head in insertion order; every
    query hands the visible keys to `h4_rust.query_topk`.
    """

    def __init__(self, n_heads: int, simple_roots: np.ndarray):
        """
        Args:
            n_heads: number of attention heads
            simple_roots: (4, 4) numpy array of H4 simple roots

        Raises:
            ImportError: if the h4_rust extension could not be imported.
        """
        if not RUST_AVAILABLE:
            raise ImportError("h4_rust is not available. Install with: cd rust && maturin develop --release")
        self.n_heads = n_heads
        self.simple_roots = simple_roots  # (4, 4) numpy array
        self._keys_by_head = [[] for _ in range(n_heads)]  # list of (4,) arrays per head

    def reset(self):
        """Clear all stored keys."""
        self._keys_by_head = [[] for _ in range(self.n_heads)]

    def insert_keys(self, keys: torch.Tensor):
        """
        Insert keys for all heads at current timestep.

        Args:
            keys: (n_heads, 4) tensor of key vectors to insert
        """
        keys_np = keys.detach().cpu().numpy()
        for h in range(self.n_heads):
            self._keys_by_head[h].append(keys_np[h].copy())

    def bulk_insert(self, keys: torch.Tensor):
        """
        Insert a full sequence of keys for all heads.

        Args:
            keys: (seq_len, n_heads, 4) tensor of key vectors
        """
        keys_np = keys.detach().cpu().numpy()
        for step in keys_np:
            for h in range(self.n_heads):
                self._keys_by_head[h].append(step[h].copy())

    def _visible_key_matrix(self, h: int, causal_mask_pos: Optional[int]) -> Optional[np.ndarray]:
        """Return head `h`'s keys with index <= causal_mask_pos as float64, or None if there are none."""
        stored = self._keys_by_head[h]
        n_visible = len(stored)
        if causal_mask_pos is not None:
            n_visible = min(n_visible, causal_mask_pos + 1)
        # `<= 0` rather than `== 0`: a causal position below -1 must yield
        # no keys, not a negative slice bound that selects from the end.
        if n_visible <= 0:
            return None
        return np.array(stored[:n_visible], dtype=np.float64)

    def query_topk(
        self,
        queries: torch.Tensor,
        k: int,
        causal_mask_pos: Optional[int] = None,
    ) -> List[List[List[int]]]:
        """
        For each query, find top-k candidate key indices using Rust backend.

        Args:
            queries: (n_queries, n_heads, 4) tensor of query vectors
            k: number of candidates per query per head
            causal_mask_pos: if set, only consider keys with index <= this value

        Returns:
            List of shape [n_queries][n_heads][<=k] containing key indices.
            A head with no visible keys yields an empty list.
        """
        n_queries = queries.shape[0]
        queries_np = queries.detach().cpu().numpy()

        # The visible keys are the same for every query in this call, so the
        # per-head matrices are built once instead of once per query.
        # NOTE(review): assumes h4_rust.query_topk does not mutate its inputs.
        head_keys = [
            self._visible_key_matrix(h, causal_mask_pos) for h in range(self.n_heads)
        ]

        results = []
        for q_idx in range(n_queries):
            head_results = []
            for h, keys_arr in enumerate(head_keys):
                if keys_arr is None:
                    head_results.append([])
                    continue

                # Keep the leading axis so the backend receives a (1, 4) batch.
                query_arr = queries_np[q_idx, h:h+1].astype(np.float64)

                actual_k = min(k, keys_arr.shape[0])
                indices = h4_rust.query_topk(keys_arr, query_arr, actual_k)
                # indices is (1, actual_k), extract the list and filter -1s
                head_results.append([int(i) for i in indices[0] if i >= 0])

            results.append(head_results)

        return results


def get_chamber_index(n_heads: int, simple_roots: np.ndarray, prefer_rust: bool = True):
    """
    Build a chamber index, choosing the backend at call time.

    The Rust-backed index is used only when the caller prefers it and the
    h4_rust extension was imported successfully; in every other case the
    pure-Python index is returned.

    Args:
        n_heads: number of attention heads
        simple_roots: (4, 4) numpy array of H4 simple roots
        prefer_rust: if True (default), use Rust backend when available

    Returns:
        ChamberIndex or RustChamberIndex instance
    """
    use_rust = prefer_rust and RUST_AVAILABLE
    backend = RustChamberIndex if use_rust else ChamberIndex
    return backend(n_heads, simple_roots)