| """ |
| PyTorch-compatible chamber lookup for H4 ChamberTree. |
| |
| Provides a bridge between PyTorch tensors (gradient-tracked) and the |
| numpy-based H4ChamberTree (discrete, non-differentiable). The key trick: |
| |
| - ChamberTree does fast O(log t) filtering to find top-k candidate keys |
| - We return candidate indices back to PyTorch |
| - Attention scores are computed only over candidates (differentiable) |
| - Gradients flow through Q/K projections and scores, not through the tree |
| |
| This gives O(k) attention per query where k << t. |
| |
| If the compiled Rust backend (h4_rust) is available, RustChamberIndex provides |
| a much faster implementation. Falls back to pure-Python ChamberIndex otherwise. |
| """ |
|
|
| import numpy as np |
| import torch |
| from typing import List, Tuple, Optional |
| import sys |
| import os |
|
|
| |
| try: |
| import h4_rust |
| RUST_AVAILABLE = True |
| except ImportError: |
| RUST_AVAILABLE = False |
|
|
| |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from h4_polytopic_attention import H4ChamberTree, build_coxeter_chambers, generate_600_cell_vertices |
|
|
|
|
class ChamberIndex:
    """
    Manages a set of H4ChamberTrees (one per head) and provides
    batch top-k candidate lookup compatible with PyTorch autograd.

    Each key is stored twice per head: in that head's H4ChamberTree (queried
    first, via ``query_max_dot``) and in a plain list used for an exact
    brute-force fallback whenever the tree yields fewer than ``k`` admissible
    candidates. The payload stored with each key is its global insertion
    position, so the indices returned by ``query_topk`` address the full
    key/value sequence no matter how the keys were inserted.

    Input tensors are only read through ``detach()``; no gradients flow
    through this class.
    """

    def __init__(self, n_heads: int, simple_roots: np.ndarray):
        self.n_heads = n_heads
        self.simple_roots = simple_roots
        self.trees = [H4ChamberTree(simple_roots) for _ in range(n_heads)]
        self._keys_by_head = [[] for _ in range(n_heads)]

    def reset(self):
        """Discard all stored keys by rebuilding empty trees and key lists."""
        self.trees = [H4ChamberTree(self.simple_roots) for _ in range(self.n_heads)]
        self._keys_by_head = [[] for _ in range(self.n_heads)]

    def _num_inserted(self) -> int:
        """Return how many timesteps have been inserted so far (heads advance in lockstep)."""
        return len(self._keys_by_head[0]) if self._keys_by_head else 0

    def _insert_one(self, h: int, key: np.ndarray, t: int) -> None:
        """Insert one key for head ``h`` at global sequence position ``t``."""
        # The payload carries the position so tree hits can be mapped back to
        # an index into the full key/value tensors.
        self.trees[h].insert(key, np.array([t], dtype=np.float64), t)
        self._keys_by_head[h].append(key.copy())

    def insert_keys(self, keys: torch.Tensor):
        """
        Insert keys for all heads at the next timestep.

        Args:
            keys: (n_heads, 4) tensor of key vectors to insert
        """
        keys_np = keys.detach().cpu().numpy()
        t = self._num_inserted()
        for h in range(self.n_heads):
            self._insert_one(h, keys_np[h], t)

    def bulk_insert(self, keys: torch.Tensor):
        """
        Insert a full sequence of keys for all heads.

        Positions continue from the keys already stored, so mixing
        ``bulk_insert`` and ``insert_keys`` (or calling ``bulk_insert`` more
        than once) keeps tree payloads consistent with list positions.

        Args:
            keys: (seq_len, n_heads, 4) tensor of key vectors
        """
        keys_np = keys.detach().cpu().numpy()
        offset = self._num_inserted()
        for t in range(keys_np.shape[0]):
            for h in range(self.n_heads):
                self._insert_one(h, keys_np[t, h], offset + t)

    def _admissible_keys(self, h: int, causal_mask_pos: Optional[int]) -> np.ndarray:
        """Stack head ``h``'s keys with index <= causal_mask_pos (all keys if None)."""
        keys = self._keys_by_head[h]
        if causal_mask_pos is None:
            n_allowed = len(keys)
        else:
            # Clamp so a negative position admits nothing rather than turning
            # into a negative slice that would pull in keys from the end.
            n_allowed = max(0, min(len(keys), causal_mask_pos + 1))
        return np.array(keys[:n_allowed])

    def _tree_candidates(
        self, h: int, query: np.ndarray, k: int, causal_mask_pos: Optional[int]
    ) -> List[int]:
        """Collect up to ``k`` causally admissible indices from head ``h``'s tree."""
        indices = []
        # Over-fetch (2k) because hits past causal_mask_pos are discarded below.
        for _score, value, timestamp in self.trees[h].query_max_dot(query, k=k * 2):
            t_idx = int(value[0]) if len(value) > 0 else timestamp
            if causal_mask_pos is not None and t_idx > causal_mask_pos:
                continue
            indices.append(t_idx)
            if len(indices) >= k:
                break
        return indices

    @staticmethod
    def _fill_by_brute_force(
        indices: List[int], all_keys: np.ndarray, query: np.ndarray, k: int
    ) -> None:
        """Append the highest-dot-product unseen indices to ``indices`` until it holds ``k``."""
        if len(all_keys) == 0:
            return
        existing = set(indices)
        for idx in np.argsort(-(all_keys @ query)):
            idx = int(idx)
            if idx in existing:
                continue
            indices.append(idx)
            existing.add(idx)
            if len(indices) >= k:
                break

    def query_topk(
        self,
        queries: torch.Tensor,
        k: int,
        causal_mask_pos: Optional[int] = None,
    ) -> List[List[List[int]]]:
        """
        For each query, find top-k candidate key indices using ChamberTree.

        Args:
            queries: (n_queries, n_heads, 4) tensor of query vectors
            k: number of candidates per query per head; k <= 0 yields empty lists
            causal_mask_pos: if set, only return candidates with index <= this
                value (a negative value admits no keys)

        Returns:
            List of shape [n_queries][n_heads][<=k] containing key indices.
            These indices can be used to gather from the full key/value tensors.
        """
        n_queries = queries.shape[0]
        if k <= 0:
            return [[[] for _ in range(self.n_heads)] for _ in range(n_queries)]

        queries_np = queries.detach().cpu().numpy()
        # Brute-force key matrices depend only on the head, so build each one
        # lazily and at most once per call instead of once per query.
        fallback_keys = {}
        results = []
        for q_idx in range(n_queries):
            head_results = []
            for h in range(self.n_heads):
                query = queries_np[q_idx, h]
                indices = self._tree_candidates(h, query, k, causal_mask_pos)
                if len(indices) < k:
                    if h not in fallback_keys:
                        fallback_keys[h] = self._admissible_keys(h, causal_mask_pos)
                    self._fill_by_brute_force(indices, fallback_keys[h], query, k)
                head_results.append(indices)
            results.append(head_results)
        return results
|
|
|
|
def compute_chamber_ids(keys: torch.Tensor, simple_roots: torch.Tensor) -> torch.Tensor:
    """
    Map each key to the sign pattern of its projections onto the simple roots.

    The sign test is discrete, so the result carries no gradient; it is meant
    for logging chamber utilization.

    Args:
        keys: (..., 4) tensor of key vectors
        simple_roots: (4, 4) tensor of H4 simple roots

    Returns:
        (...,) long tensor of chamber IDs in 0-15. Root 0 supplies the most
        significant bit, and a projection of exactly 0 counts as the
        non-negative side (bit set).
    """
    projections = keys @ simple_roots.T
    bits = (projections >= 0).long()
    # Horner accumulation: ((b0*2 + b1)*2 + b2)*2 + b3 == 8*b0 + 4*b1 + 2*b2 + b3.
    chamber = bits[..., 0]
    for j in range(1, 4):
        chamber = chamber * 2 + bits[..., j]
    return chamber
|
|
|
|
def chamber_utilization(chamber_ids: torch.Tensor, n_chambers: int = 16) -> dict:
    """
    Compute chamber utilization statistics.

    Args:
        chamber_ids: tensor of any shape holding chamber IDs (e.g. from
            ``compute_chamber_ids``). Values outside ``[0, n_chambers)`` and,
            for floating tensors, non-integral values are not counted.
        n_chambers: number of histogram bins (16 for 4-bit sign patterns).

    Returns:
        Dict with
          'counts'    -- (n_chambers,) long tensor of per-chamber counts,
          'entropy'   -- Shannon entropy of the count distribution in nats
                         (natural log, so at most ln(n_chambers)),
          'max_ratio' -- max count / mean count (1.0 = perfectly uniform).
        When no ID is counted, 'entropy' and 'max_ratio' are both 0.0.
    """
    flat = chamber_ids.flatten()
    # Only IDs that name an actual chamber are counted; filtering first also
    # keeps bincount from rejecting negatives or growing past n_chambers.
    valid = (flat >= 0) & (flat < n_chambers)
    if flat.is_floating_point():
        valid &= flat == torch.floor(flat)
    # One pass over the data instead of one full scan per chamber.
    counts = torch.bincount(flat[valid].long(), minlength=n_chambers)

    total = counts.sum().float()
    if total == 0:
        return {'counts': counts, 'entropy': 0.0, 'max_ratio': 0.0}

    probs = counts.float() / total
    # 0 * log(0) is taken as 0; masking avoids 0 * -inf = nan.
    log_probs = torch.where(probs > 0, torch.log(probs), torch.zeros_like(probs))
    entropy = -(probs * log_probs).sum().item()

    # total > 0 here, so n_chambers > 0 and mean_count > 0.
    mean_count = total / n_chambers
    max_ratio = (counts.max().float() / mean_count).item()

    return {
        'counts': counts,
        'entropy': entropy,
        'max_ratio': max_ratio,
    }
|
|
|
|
class RustChamberIndex:
    """
    Rust-accelerated chamber index using the h4_rust compiled backend.
    API-compatible with ChamberIndex for drop-in replacement.

    Keys are kept per head as plain numpy vectors. Each ``query_topk`` call
    hands the causally admissible keys to ``h4_rust.query_topk``, which runs
    the heavy work (dot products, sorting) in compiled Rust via PyO3/numpy.
    Returned indices are positions in the inserted key sequence.
    """

    def __init__(self, n_heads: int, simple_roots: np.ndarray):
        if not RUST_AVAILABLE:
            raise ImportError("h4_rust is not available. Install with: cd rust && maturin develop --release")
        self.n_heads = n_heads
        # Stored for interface parity with ChamberIndex; not read by this class.
        self.simple_roots = simple_roots
        self._keys_by_head = [[] for _ in range(n_heads)]

    def reset(self):
        """Clear all stored keys."""
        self._keys_by_head = [[] for _ in range(self.n_heads)]

    def insert_keys(self, keys: torch.Tensor):
        """
        Insert keys for all heads at the next timestep.

        Args:
            keys: (n_heads, 4) tensor of key vectors to insert
        """
        keys_np = keys.detach().cpu().numpy()
        for h in range(self.n_heads):
            self._keys_by_head[h].append(keys_np[h].copy())

    def bulk_insert(self, keys: torch.Tensor):
        """
        Insert a full sequence of keys for all heads.

        Args:
            keys: (seq_len, n_heads, 4) tensor of key vectors
        """
        keys_np = keys.detach().cpu().numpy()
        for t in range(keys_np.shape[0]):
            for h in range(self.n_heads):
                self._keys_by_head[h].append(keys_np[t, h].copy())

    def _admissible_keys(self, h: int, causal_mask_pos: Optional[int]) -> Optional[np.ndarray]:
        """
        Stack head ``h``'s keys with index <= causal_mask_pos (all if None) as float64.

        Returns None when no key is admissible.
        """
        keys = self._keys_by_head[h]
        if causal_mask_pos is None:
            n_allowed = len(keys)
        else:
            # Clamp so a negative position admits nothing rather than turning
            # into a negative slice (future keys) and a negative k for Rust.
            n_allowed = max(0, min(len(keys), causal_mask_pos + 1))
        if n_allowed == 0:
            return None
        return np.array(keys[:n_allowed], dtype=np.float64)

    def query_topk(
        self,
        queries: torch.Tensor,
        k: int,
        causal_mask_pos: Optional[int] = None,
    ) -> List[List[List[int]]]:
        """
        For each query, find top-k candidate key indices using Rust backend.

        Args:
            queries: (n_queries, n_heads, 4) tensor of query vectors
            k: number of candidates per query per head; k <= 0 yields empty lists
            causal_mask_pos: if set, only consider keys with index <= this value
                (a negative value admits no keys)

        Returns:
            List of shape [n_queries][n_heads][<=k] containing key indices.
        """
        n_queries = queries.shape[0]
        if k <= 0:
            return [[[] for _ in range(self.n_heads)] for _ in range(n_queries)]

        queries_np = queries.detach().cpu().numpy()
        # The admissible key set depends on the head and causal_mask_pos but
        # not on the query, so build each head's matrix once per call.
        head_keys = [self._admissible_keys(h, causal_mask_pos) for h in range(self.n_heads)]

        results = []
        for q_idx in range(n_queries):
            head_results = []
            for h, keys_arr in enumerate(head_keys):
                if keys_arr is None:
                    head_results.append([])
                    continue
                # (1, 4) slice: the backend is called with a single-row query batch.
                query_arr = queries_np[q_idx, h:h + 1].astype(np.float64)
                actual_k = min(k, keys_arr.shape[0])
                indices = h4_rust.query_topk(keys_arr, query_arr, actual_k)
                # Negative entries are dropped (presumably padding from the backend).
                head_results.append([int(i) for i in indices[0] if i >= 0])
            results.append(head_results)
        return results
|
|
|
|
def get_chamber_index(n_heads: int, simple_roots: np.ndarray, prefer_rust: bool = True):
    """
    Build a chamber index, choosing the backend at call time.

    RustChamberIndex is used only when the caller allows it (``prefer_rust``,
    True by default) and the h4_rust extension imported successfully
    (``RUST_AVAILABLE``); every other case falls back to ChamberIndex.

    Args:
        n_heads: number of attention heads
        simple_roots: (4, 4) numpy array of H4 simple roots
        prefer_rust: if True (default), use Rust backend when available

    Returns:
        ChamberIndex or RustChamberIndex instance
    """
    backend = RustChamberIndex if (prefer_rust and RUST_AVAILABLE) else ChamberIndex
    return backend(n_heads, simple_roots)
|
|