grapheneaffiliates committed on
Commit
4890f37
·
verified ·
1 Parent(s): 9095704

Upload python/utils/chamber_index.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. python/utils/chamber_index.py +304 -0
python/utils/chamber_index.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch-compatible chamber lookup for H4 ChamberTree.
3
+
4
+ Provides a bridge between PyTorch tensors (gradient-tracked) and the
5
+ numpy-based H4ChamberTree (discrete, non-differentiable). The key trick:
6
+
7
+ - ChamberTree does fast O(log t) filtering to find top-k candidate keys
8
+ - We return candidate indices back to PyTorch
9
+ - Attention scores are computed only over candidates (differentiable)
10
+ - Gradients flow through Q/K projections and scores, not through the tree
11
+
12
+ This gives O(k) attention per query where k << t.
13
+
14
+ If the compiled Rust backend (h4_rust) is available, RustChamberIndex provides
15
+ a much faster implementation. Falls back to pure-Python ChamberIndex otherwise.
16
+ """
17
+
18
+ import numpy as np
19
+ import torch
20
+ from typing import List, Tuple, Optional
21
+ import sys
22
+ import os
23
+
24
+ # Rust backend detection — optional, graceful fallback to Python
25
+ try:
26
+ import h4_rust
27
+ RUST_AVAILABLE = True
28
+ except ImportError:
29
+ RUST_AVAILABLE = False
30
+
31
+ # Add parent to path for imports
32
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
33
+ from h4_polytopic_attention import H4ChamberTree, build_coxeter_chambers, generate_600_cell_vertices
34
+
35
+
36
class ChamberIndex:
    """
    Manages a set of H4ChamberTrees (one per head) and provides
    batch top-k candidate lookup compatible with PyTorch autograd.

    Every inserted key is also kept in ``_keys_by_head`` in insertion order,
    so that a key's position in that list is the index handed to the tree
    and later returned by ``query_topk``.
    """

    def __init__(self, n_heads: int, simple_roots: np.ndarray):
        self.n_heads = n_heads
        self.simple_roots = simple_roots
        self.trees = [H4ChamberTree(simple_roots) for _ in range(n_heads)]
        self._keys_by_head = [[] for _ in range(n_heads)]  # track insertion order

    def reset(self):
        """Discard all stored keys and replace every tree with a fresh one."""
        self.trees = [H4ChamberTree(self.simple_roots) for _ in range(self.n_heads)]
        self._keys_by_head = [[] for _ in range(self.n_heads)]

    def _next_position(self) -> int:
        """Return the position index the next inserted key will receive."""
        # All heads grow in lockstep, so head 0 is representative. With zero
        # heads there is nothing to index and 0 is a harmless default.
        return len(self._keys_by_head[0]) if self._keys_by_head else 0

    def insert_keys(self, keys: torch.Tensor):
        """
        Insert keys for all heads at current timestep.

        Args:
            keys: (n_heads, 4) tensor of key vectors to insert
        """
        keys_np = keys.detach().cpu().numpy()
        t = self._next_position()
        for h in range(self.n_heads):
            key = keys_np[h]
            # Use position index as both value and timestamp
            self.trees[h].insert(key, np.array([t], dtype=np.float64), t)
            self._keys_by_head[h].append(key.copy())

    def bulk_insert(self, keys: torch.Tensor):
        """
        Insert a full sequence of keys for all heads.

        Positions continue from whatever is already stored, so calling this
        after ``insert_keys`` or an earlier ``bulk_insert`` keeps the indices
        held by the trees aligned with ``_keys_by_head``.

        Args:
            keys: (seq_len, n_heads, 4) tensor of key vectors
        """
        keys_np = keys.detach().cpu().numpy()
        base = self._next_position()
        for offset in range(keys_np.shape[0]):
            t = base + offset
            for h in range(self.n_heads):
                key = keys_np[offset, h]
                self.trees[h].insert(key, np.array([t], dtype=np.float64), t)
                self._keys_by_head[h].append(key.copy())

    def _visible_keys(self, head: int, causal_mask_pos: Optional[int]):
        """Return the stored keys of ``head`` allowed by the causal mask, or None if there are none."""
        stored = self._keys_by_head[head]
        n_visible = len(stored)
        if causal_mask_pos is not None:
            # Clamp so that a position below 0 hides every key instead of
            # turning into a negative (from-the-end) slice bound.
            n_visible = max(0, min(n_visible, causal_mask_pos + 1))
        if n_visible == 0:
            return None
        return np.array(stored[:n_visible])

    @staticmethod
    def _fill_from_scan(indices: List[int], keys_arr, query: np.ndarray, k: int) -> None:
        """Top up ``indices`` in place to ``k`` entries by exact dot-product ranking over ``keys_arr``."""
        if keys_arr is None:
            return
        seen = set(indices)
        for idx in np.argsort(-(keys_arr @ query)):
            idx = int(idx)
            if idx in seen:
                continue
            indices.append(idx)
            seen.add(idx)
            if len(indices) >= k:
                break

    def query_topk(
        self,
        queries: torch.Tensor,
        k: int,
        causal_mask_pos: Optional[int] = None,
    ) -> List[List[List[int]]]:
        """
        For each query, find top-k candidate key indices using ChamberTree.

        Candidates come from the tree first; if fewer than k survive the
        causal filter, the remainder is filled by an exact dot-product scan
        over the stored keys.

        Args:
            queries: (n_queries, n_heads, 4) tensor of query vectors
            k: number of candidates per query per head; k <= 0 yields empty lists
            causal_mask_pos: if set, only return candidates with index <= this value

        Returns:
            List of shape [n_queries][n_heads][<=k] containing key indices.
            These indices can be used to gather from the full key/value tensors.
        """
        n_queries = queries.shape[0]
        if k <= 0:
            return [[[] for _ in range(self.n_heads)] for _ in range(n_queries)]

        queries_np = queries.detach().cpu().numpy()
        # The visible key matrix depends only on the head, so build it at most
        # once per head per call and only if the fallback is actually needed.
        visible_cache = {}
        results = []

        for q_idx in range(n_queries):
            head_results = []
            for h in range(self.n_heads):
                query = queries_np[q_idx, h]
                indices = []
                # Request more than k since some may be filtered by causal mask
                for _score, value, timestamp in self.trees[h].query_max_dot(query, k=k * 2):
                    t_idx = int(value[0]) if len(value) > 0 else timestamp
                    if causal_mask_pos is not None and t_idx > causal_mask_pos:
                        continue
                    indices.append(t_idx)
                    if len(indices) >= k:
                        break

                # If tree didn't return enough, fall back to scanning
                if len(indices) < k:
                    if h not in visible_cache:
                        visible_cache[h] = self._visible_keys(h, causal_mask_pos)
                    self._fill_from_scan(indices, visible_cache[h], query, k)

                head_results.append(indices)
            results.append(head_results)

        return results
141
+
142
+
143
def compute_chamber_ids(keys: torch.Tensor, simple_roots: torch.Tensor) -> torch.Tensor:
    """
    Compute chamber IDs for a batch of keys (differentiable w.r.t. nothing,
    but useful for logging chamber utilization).

    Each root contributes one bit: 1 when the key's dot product with that
    root is >= 0, else 0. The first root is the most significant bit.

    Args:
        keys: (..., d) tensor of key vectors (d = 4 for H4)
        simple_roots: (n_roots, d) tensor of simple roots ((4, 4) for H4).
            It may differ from ``keys`` in dtype or device; both operands are
            brought to a common dtype on ``keys.device`` before the product.

    Returns:
        (...,) int64 tensor of chamber IDs in [0, 2**n_roots), i.e. 0-15 for
        the 4-bit H4 sign pattern.
    """
    # Mixed operands (e.g. float32 keys with float64 roots converted from
    # numpy) would make the matmul raise, so align them first.
    common = torch.promote_types(keys.dtype, simple_roots.dtype)
    roots = simple_roots.to(device=keys.device, dtype=common)
    # Dot products with all roots: (..., n_roots)
    dots = keys.to(common) @ roots.T
    # Sign pattern -> one bit per root
    signs = (dots >= 0).long()
    n_roots = roots.shape[0]
    # Bit weights, most significant first: [8, 4, 2, 1] for four roots.
    weights = torch.tensor(
        [1 << (n_roots - 1 - i) for i in range(n_roots)],
        dtype=torch.long,
        device=signs.device,
    )
    return (signs * weights).sum(dim=-1)
161
+
162
+
163
def chamber_utilization(chamber_ids: torch.Tensor, n_chambers: int = 16) -> dict:
    """
    Summarize how evenly a set of chamber IDs is spread over the chambers.

    IDs outside ``[0, n_chambers)`` match no chamber and are not counted.

    Args:
        chamber_ids: tensor of chamber IDs, any shape
        n_chambers: number of chambers to tally (default 16)

    Returns:
        Dict with 'counts' (per-chamber int64 tensor), 'entropy' (Shannon
        entropy in nats, float) and 'max_ratio' (max/mean count ratio,
        1.0 = perfectly uniform). Entropy and ratio are 0.0 when nothing
        was counted.
    """
    flat_ids = chamber_ids.flatten()
    chamber_axis = torch.arange(n_chambers, device=chamber_ids.device)
    # One-shot tally: compare every id against every chamber, then sum the
    # matches per chamber column.
    counts = (flat_ids.unsqueeze(1) == chamber_axis).sum(dim=0)

    total = counts.sum().float()
    if total == 0:
        return {'counts': counts, 'entropy': 0.0, 'max_ratio': 0.0}

    probs = counts.float() / total
    # Empty chambers contribute 0 to the entropy rather than 0 * log(0).
    masked_log = torch.where(probs > 0, torch.log(probs), torch.zeros_like(probs))
    entropy = -(probs * masked_log).sum().item()

    # total > 0 here, so the mean count is strictly positive.
    mean_count = total / n_chambers
    max_ratio = (counts.max().float() / mean_count).item()

    return {'counts': counts, 'entropy': entropy, 'max_ratio': max_ratio}
193
+
194
+
195
class RustChamberIndex:
    """
    Rust-accelerated chamber index using h4_rust compiled backend.
    API-compatible with ChamberIndex for drop-in replacement.

    Keys are stored per head in insertion order; ranking is delegated to
    ``h4_rust.query_topk``.
    """

    def __init__(self, n_heads: int, simple_roots: np.ndarray):
        if not RUST_AVAILABLE:
            raise ImportError("h4_rust is not available. Install with: cd rust && maturin develop --release")
        self.n_heads = n_heads
        self.simple_roots = simple_roots  # (4, 4) numpy array
        self._keys_by_head = [[] for _ in range(n_heads)]  # list of (4,) arrays per head

    def reset(self):
        """Clear all stored keys."""
        self._keys_by_head = [[] for _ in range(self.n_heads)]

    def insert_keys(self, keys: torch.Tensor):
        """
        Insert keys for all heads at current timestep.

        Args:
            keys: (n_heads, 4) tensor of key vectors to insert
        """
        keys_np = keys.detach().cpu().numpy()
        for h in range(self.n_heads):
            self._keys_by_head[h].append(keys_np[h].copy())

    def bulk_insert(self, keys: torch.Tensor):
        """
        Insert a full sequence of keys for all heads.

        Args:
            keys: (seq_len, n_heads, 4) tensor of key vectors
        """
        keys_np = keys.detach().cpu().numpy()
        for step in keys_np:
            for h in range(self.n_heads):
                self._keys_by_head[h].append(step[h].copy())

    def query_topk(
        self,
        queries: torch.Tensor,
        k: int,
        causal_mask_pos: Optional[int] = None,
    ) -> List[List[List[int]]]:
        """
        For each query, find top-k candidate key indices using Rust backend.

        Args:
            queries: (n_queries, n_heads, 4) tensor of query vectors
            k: number of candidates per query per head; k <= 0 yields empty lists
            causal_mask_pos: if set, only consider keys with index <= this
                value; a value below 0 hides every key

        Returns:
            List of shape [n_queries][n_heads][<=k] containing key indices.
        """
        n_queries = queries.shape[0]
        if n_queries == 0:
            return []
        queries_np = queries.detach().cpu().numpy()

        # The visible key matrix and the effective k depend only on the head,
        # not on the query, so prepare them once per head.
        prepared = []
        for h in range(self.n_heads):
            stored = self._keys_by_head[h]
            effective_n = len(stored)
            if causal_mask_pos is not None:
                # Apply causal mask: only use keys up to causal_mask_pos
                effective_n = min(effective_n, causal_mask_pos + 1)
            if effective_n <= 0 or k <= 0:
                # No visible keys (including causal_mask_pos < 0, which would
                # otherwise become a negative slice bound) or nothing requested.
                prepared.append(None)
                continue
            keys_arr = np.array(stored[:effective_n], dtype=np.float64)
            prepared.append((keys_arr, min(k, effective_n)))

        results = []
        for q_idx in range(n_queries):
            head_results = []
            for h, entry in enumerate(prepared):
                if entry is None:
                    head_results.append([])
                    continue
                keys_arr, actual_k = entry
                query_arr = queries_np[q_idx, h:h + 1].astype(np.float64)
                indices = h4_rust.query_topk(keys_arr, query_arr, actual_k)
                # indices is (1, actual_k), extract the list and filter -1s
                head_results.append([int(i) for i in indices[0] if i >= 0])
            results.append(head_results)

        return results
288
+
289
+
290
def get_chamber_index(n_heads: int, simple_roots: np.ndarray, prefer_rust: bool = True):
    """
    Build a chamber index, choosing the Rust-backed class when allowed.

    Args:
        n_heads: number of attention heads
        simple_roots: (4, 4) numpy array of H4 simple roots
        prefer_rust: if True (default), use Rust backend when available

    Returns:
        A RustChamberIndex when ``prefer_rust`` is set and the h4_rust
        backend was imported successfully, otherwise a ChamberIndex.
    """
    use_rust = prefer_rust and RUST_AVAILABLE
    index_cls = RustChamberIndex if use_rust else ChamberIndex
    return index_cls(n_heads, simple_roots)