h4-polytopic-attention / python /utils /chamber_index.py
grapheneaffiliates's picture
Upload python/utils/chamber_index.py with huggingface_hub
4890f37 verified
"""
PyTorch-compatible chamber lookup for H4 ChamberTree.
Provides a bridge between PyTorch tensors (gradient-tracked) and the
numpy-based H4ChamberTree (discrete, non-differentiable). The key trick:
- ChamberTree does fast O(log t) filtering to find top-k candidate keys
- We return candidate indices back to PyTorch
- Attention scores are computed only over candidates (differentiable)
- Gradients flow through Q/K projections and scores, not through the tree
This gives O(k) attention per query where k << t.
If the compiled Rust backend (h4_rust) is available, RustChamberIndex provides
a much faster implementation. Falls back to pure-Python ChamberIndex otherwise.
"""
import numpy as np
import torch
from typing import List, Tuple, Optional
import sys
import os
# Rust backend detection — optional, graceful fallback to Python
try:
import h4_rust
RUST_AVAILABLE = True
except ImportError:
RUST_AVAILABLE = False
# Add parent to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from h4_polytopic_attention import H4ChamberTree, build_coxeter_chambers, generate_600_cell_vertices
class ChamberIndex:
"""
Manages a set of H4ChamberTrees (one per head) and provides
batch top-k candidate lookup compatible with PyTorch autograd.
"""
def __init__(self, n_heads: int, simple_roots: np.ndarray):
self.n_heads = n_heads
self.simple_roots = simple_roots
self.trees = [H4ChamberTree(simple_roots) for _ in range(n_heads)]
self._keys_by_head = [[] for _ in range(n_heads)] # track insertion order
def reset(self):
"""Clear all trees and rebuild."""
self.trees = [H4ChamberTree(self.simple_roots) for _ in range(self.n_heads)]
self._keys_by_head = [[] for _ in range(self.n_heads)]
def insert_keys(self, keys: torch.Tensor):
"""
Insert keys for all heads at current timestep.
Args:
keys: (n_heads, 4) tensor of key vectors to insert
"""
keys_np = keys.detach().cpu().numpy()
t = len(self._keys_by_head[0]) # current position index
for h in range(self.n_heads):
key = keys_np[h]
# Use position index as both value and timestamp
self.trees[h].insert(key, np.array([t], dtype=np.float64), t)
self._keys_by_head[h].append(key.copy())
def bulk_insert(self, keys: torch.Tensor):
"""
Insert a full sequence of keys for all heads.
Args:
keys: (seq_len, n_heads, 4) tensor of key vectors
"""
seq_len = keys.shape[0]
keys_np = keys.detach().cpu().numpy()
for t in range(seq_len):
for h in range(self.n_heads):
key = keys_np[t, h]
self.trees[h].insert(key, np.array([t], dtype=np.float64), t)
self._keys_by_head[h].append(key.copy())
def query_topk(
self,
queries: torch.Tensor,
k: int,
causal_mask_pos: Optional[int] = None,
) -> List[List[List[int]]]:
"""
For each query, find top-k candidate key indices using ChamberTree.
Args:
queries: (n_queries, n_heads, 4) tensor of query vectors
k: number of candidates per query per head
causal_mask_pos: if set, only return candidates with index <= this value
Returns:
List of shape [n_queries][n_heads][<=k] containing key indices.
These indices can be used to gather from the full key/value tensors.
"""
n_queries = queries.shape[0]
queries_np = queries.detach().cpu().numpy()
results = []
for q_idx in range(n_queries):
head_results = []
for h in range(self.n_heads):
query = queries_np[q_idx, h]
# Query tree for top candidates
# Request more than k since some may be filtered by causal mask
tree_results = self.trees[h].query_max_dot(query, k=k * 2)
indices = []
for score, value, timestamp in tree_results:
t_idx = int(value[0]) if len(value) > 0 else timestamp
if causal_mask_pos is not None and t_idx > causal_mask_pos:
continue
indices.append(t_idx)
if len(indices) >= k:
break
# If tree didn't return enough, fall back to scanning
if len(indices) < k and len(self._keys_by_head[h]) > 0:
max_pos = causal_mask_pos if causal_mask_pos is not None else len(self._keys_by_head[h]) - 1
all_keys = np.array(self._keys_by_head[h][:max_pos + 1])
if len(all_keys) > 0:
dots = all_keys @ query
sorted_idx = np.argsort(-dots)
existing = set(indices)
for idx in sorted_idx:
if idx not in existing:
indices.append(int(idx))
existing.add(int(idx))
if len(indices) >= k:
break
head_results.append(indices)
results.append(head_results)
return results
def compute_chamber_ids(keys: torch.Tensor, simple_roots: torch.Tensor) -> torch.Tensor:
"""
Compute chamber IDs for a batch of keys (differentiable w.r.t. nothing,
but useful for logging chamber utilization).
Args:
keys: (..., 4) tensor of key vectors
simple_roots: (4, 4) tensor of H4 simple roots
Returns:
(...,) tensor of integer chamber IDs (0-15 for 4-bit sign pattern)
"""
# Dot products with all 4 roots: (..., 4)
dots = keys @ simple_roots.T
# Sign pattern → 4-bit chamber ID
signs = (dots >= 0).long()
ids = signs[..., 0] * 8 + signs[..., 1] * 4 + signs[..., 2] * 2 + signs[..., 3]
return ids
def chamber_utilization(chamber_ids: torch.Tensor, n_chambers: int = 16) -> dict:
"""
Compute chamber utilization statistics.
Returns:
Dict with 'counts' (per-chamber), 'entropy' (Shannon entropy),
and 'max_ratio' (max/mean ratio, 1.0 = perfectly uniform).
"""
counts = torch.zeros(n_chambers, dtype=torch.long, device=chamber_ids.device)
flat = chamber_ids.flatten()
for i in range(n_chambers):
counts[i] = (flat == i).sum()
total = counts.sum().float()
if total == 0:
return {'counts': counts, 'entropy': 0.0, 'max_ratio': 0.0}
probs = counts.float() / total
# Shannon entropy (nats)
log_probs = torch.where(probs > 0, torch.log(probs), torch.zeros_like(probs))
entropy = -(probs * log_probs).sum().item()
mean_count = total / n_chambers
max_ratio = (counts.max().float() / mean_count).item() if mean_count > 0 else 0.0
return {
'counts': counts,
'entropy': entropy,
'max_ratio': max_ratio,
}
class RustChamberIndex:
"""
Rust-accelerated chamber index using h4_rust compiled backend.
API-compatible with ChamberIndex for drop-in replacement.
All heavy computation (dot products, sorting, chamber indexing) runs
in compiled Rust via PyO3/numpy, typically 10-100x faster than Python.
"""
def __init__(self, n_heads: int, simple_roots: np.ndarray):
if not RUST_AVAILABLE:
raise ImportError("h4_rust is not available. Install with: cd rust && maturin develop --release")
self.n_heads = n_heads
self.simple_roots = simple_roots # (4, 4) numpy array
self._keys_by_head = [[] for _ in range(n_heads)] # list of (4,) arrays per head
def reset(self):
"""Clear all stored keys."""
self._keys_by_head = [[] for _ in range(self.n_heads)]
def insert_keys(self, keys: torch.Tensor):
"""
Insert keys for all heads at current timestep.
Args:
keys: (n_heads, 4) tensor of key vectors to insert
"""
keys_np = keys.detach().cpu().numpy()
for h in range(self.n_heads):
self._keys_by_head[h].append(keys_np[h].copy())
def bulk_insert(self, keys: torch.Tensor):
"""
Insert a full sequence of keys for all heads.
Args:
keys: (seq_len, n_heads, 4) tensor of key vectors
"""
keys_np = keys.detach().cpu().numpy()
seq_len = keys_np.shape[0]
for t in range(seq_len):
for h in range(self.n_heads):
self._keys_by_head[h].append(keys_np[t, h].copy())
def query_topk(
self,
queries: torch.Tensor,
k: int,
causal_mask_pos: Optional[int] = None,
) -> List[List[List[int]]]:
"""
For each query, find top-k candidate key indices using Rust backend.
Args:
queries: (n_queries, n_heads, 4) tensor of query vectors
k: number of candidates per query per head
causal_mask_pos: if set, only consider keys with index <= this value
Returns:
List of shape [n_queries][n_heads][<=k] containing key indices.
"""
n_queries = queries.shape[0]
queries_np = queries.detach().cpu().numpy()
results = []
for q_idx in range(n_queries):
head_results = []
for h in range(self.n_heads):
n_keys = len(self._keys_by_head[h])
if n_keys == 0:
head_results.append([])
continue
# Apply causal mask: only use keys up to causal_mask_pos
max_pos = causal_mask_pos if causal_mask_pos is not None else n_keys - 1
effective_n = min(n_keys, max_pos + 1)
if effective_n == 0:
head_results.append([])
continue
keys_arr = np.array(self._keys_by_head[h][:effective_n], dtype=np.float64)
query_arr = queries_np[q_idx, h:h+1].astype(np.float64)
actual_k = min(k, effective_n)
indices = h4_rust.query_topk(keys_arr, query_arr, actual_k)
# indices is (1, actual_k), extract the list and filter -1s
idx_list = [int(i) for i in indices[0] if i >= 0]
head_results.append(idx_list)
results.append(head_results)
return results
def get_chamber_index(n_heads: int, simple_roots: np.ndarray, prefer_rust: bool = True):
"""
Factory function: returns RustChamberIndex if available, else ChamberIndex.
Args:
n_heads: number of attention heads
simple_roots: (4, 4) numpy array of H4 simple roots
prefer_rust: if True (default), use Rust backend when available
Returns:
ChamberIndex or RustChamberIndex instance
"""
if prefer_rust and RUST_AVAILABLE:
return RustChamberIndex(n_heads, simple_roots)
return ChamberIndex(n_heads, simple_roots)