"""
H4 Document Encoder — Encode text into E8 lattice memory for geometric retrieval.
Each document chunk becomes an 8D embedding stored in a Voronoi cell.
The encoding uses golden-angle spiral placement based on token content,
ensuring that retrieval and attention share the same geometric space
through the E8→H4 projection (cos(π/5) = φ/2).
No separate embedding model needed. The same geometry handles both.
"""
import math
import numpy as np
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from h4_polytopic_attention import E8LatticeIndex, PHI, PHI_INV
@dataclass
class Chunk:
    """A document chunk stored in E8 lattice memory."""
    # Decoded text of the chunk (reconstructed from token ids via itos).
    text: str
    # Identifier of the source document this chunk came from.
    doc_id: str
    # Zero-based position of the chunk within its document.
    chunk_idx: int
    # Vocabulary ids of the chunk's tokens.
    token_ids: List[int]
class H4DocumentEncoder:
    """
    Encode documents into E8 lattice memory for retrieval.

    Each chunk becomes an 8D embedding stored in an E8 Voronoi cell.
    The encoding uses golden-angle spiral placement based on token
    frequencies, ensuring geometric consistency with H4 attention.

    The 8D embedding captures:
      - dims 0-3: semantic content (position-weighted token-embedding average)
      - dims 4-7: positional/structural features (chunk position, chunk length)
    """

    def __init__(self, stoi: Dict[str, int], chunk_size: int = 256, overlap: int = 64):
        """
        Args:
            stoi: string-to-index vocabulary mapping
            chunk_size: tokens per chunk
            overlap: overlap between consecutive chunks; must be smaller than
                chunk_size, otherwise the chunking window cannot advance
                (_chunk_tokens raises ValueError in that case)
        """
        self.stoi = stoi
        self.itos = {v: k for k, v in stoi.items()}
        self.vocab_size = len(stoi)
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.lattice = E8LatticeIndex(max_cell_size=240)
        self.chunks: List[Chunk] = []
        # Next lattice address; it doubles as the index of the next chunk in
        # self.chunks, which is how retrieve() maps addresses back to metadata.
        self._address_counter = 0
        # Precompute per-token 8D embeddings once; reused for every chunk.
        self._token_embeddings = self._build_token_embeddings()

    def _build_token_embeddings(self) -> np.ndarray:
        """Build unit-norm 8D embeddings for each token id.

        Tokens are placed on a golden-angle spiral: each of the four
        (cos, sin) coordinate pairs rotates at a phi-scaled frequency, so
        successive token ids are spread quasi-uniformly over the sphere.

        Returns:
            (vocab_size, 8) array of unit-norm embeddings.
        """
        embs = np.zeros((self.vocab_size, 8))
        ids = np.arange(self.vocab_size)
        for d in range(4):
            # Same per-token angle as the original scalar loop, vectorized.
            angles = ids * 2 * math.pi * PHI_INV * (PHI ** (-d / 4))
            embs[:, 2 * d] = np.cos(angles)
            embs[:, 2 * d + 1] = np.sin(angles)
        # Each row has norm 2 by construction (four unit (cos, sin) pairs);
        # the zero-guard is kept purely defensively.
        norms = np.linalg.norm(embs, axis=1, keepdims=True)
        np.divide(embs, norms, out=embs, where=norms > 1e-10)
        return embs

    def _text_to_tokens(self, text: str) -> List[int]:
        """Map each character to its vocab id; unknown characters map to 0."""
        return [self.stoi.get(c, 0) for c in text]

    def _chunk_tokens(self, token_ids: List[int]) -> List[List[int]]:
        """Split a token list into overlapping windows of chunk_size tokens.

        Raises:
            ValueError: if overlap >= chunk_size — the window step would be
                <= 0 and the original loop never terminated in that case.
        """
        step = self.chunk_size - self.overlap
        if step <= 0:
            raise ValueError(
                f"overlap ({self.overlap}) must be smaller than "
                f"chunk_size ({self.chunk_size})"
            )
        chunks: List[List[int]] = []
        start = 0
        n = len(token_ids)
        while start < n:
            end = min(start + self.chunk_size, n)
            chunks.append(token_ids[start:end])
            if end == n:
                break
            start += step
        return chunks

    def _embed_chunk(self, token_ids: List[int], chunk_idx: int, n_chunks: int) -> np.ndarray:
        """
        Compute the unit-norm 8D embedding for a chunk.

        Dims 0-3 are a position-weighted average of the tokens' first four
        embedding coordinates; dims 4-7 encode the chunk's position within
        the document and its length as points on phi-scaled circles.
        """
        emb = np.zeros(8)
        if token_ids:
            token_embs = self._token_embeddings[token_ids]  # (n_tokens, 8)
            # Later tokens get up to 0.5 * phi^-1 extra weight over the first.
            n = len(token_ids)
            weights = 1.0 + 0.5 * PHI_INV * (np.arange(n) / n)
            weights /= weights.sum()
            emb[:4] = token_embs[:, :4].T @ weights
        # Positional features: fraction through the document, and chunk fill.
        pos_frac = chunk_idx / max(n_chunks - 1, 1) if n_chunks > 1 else 0.5
        # NOTE(review): for queries longer than chunk_size this exceeds 1.0
        # (retrieve() embeds the whole query as one chunk) — confirm intended.
        len_frac = len(token_ids) / self.chunk_size
        angle1 = pos_frac * 2 * math.pi * PHI_INV
        angle2 = len_frac * math.pi * PHI_INV
        emb[4] = math.cos(angle1)
        emb[5] = math.sin(angle1)
        emb[6] = math.cos(angle2)
        emb[7] = math.sin(angle2)
        # Normalize to the unit sphere (guard empty/degenerate vectors).
        norm = np.linalg.norm(emb)
        if norm > 1e-10:
            emb /= norm
        return emb

    def encode_document(self, text: str, doc_id: str):
        """
        Chunk a document, embed each chunk as an 8D vector, store in E8 lattice.

        Args:
            text: document text
            doc_id: unique identifier for the document
        """
        token_ids = self._text_to_tokens(text)
        token_chunks = self._chunk_tokens(token_ids)
        n_chunks = len(token_chunks)
        for i, chunk_tokens in enumerate(token_chunks):
            embedding = self._embed_chunk(chunk_tokens, i, n_chunks)
            # The lattice address is the chunk's global index in self.chunks.
            address = self._address_counter
            self._address_counter += 1
            self.lattice.insert(embedding, value=float(address), address=address)
            chunk_text = ''.join(self.itos.get(t, '?') for t in chunk_tokens)
            self.chunks.append(Chunk(
                text=chunk_text,
                doc_id=doc_id,
                chunk_idx=i,
                token_ids=chunk_tokens,
            ))

    def retrieve(self, query_text: str, k: int = 5) -> List[Tuple[Chunk, float]]:
        """
        Embed the query as an 8D vector, find the k nearest chunks in the lattice.

        Returns:
            List of (chunk, distance) tuples sorted by E8 distance.
            NOTE(review): the local name dist_sq suggests the lattice returns
            *squared* distances — confirm against E8LatticeIndex.query_nearest.
        """
        query_tokens = self._text_to_tokens(query_text)
        # A query is embedded as a lone chunk (chunk 0 of 1).
        query_emb = self._embed_chunk(query_tokens, 0, 1)
        results = self.lattice.query_nearest(query_emb, k=k, search_neighbors=True)
        retrieved = []
        for dist_sq, value, addr in results:
            idx = int(addr)
            # Skip addresses that outrun the chunk metadata list.
            if idx < len(self.chunks):
                retrieved.append((self.chunks[idx], dist_sq))
        return retrieved

    def retrieve_with_h4(self, query_text: str, k: int = 5):
        """
        Retrieve chunks AND return their H4 projections for attention.

        Returns:
            chunks: list of Chunk objects
            h4_keys: (k, 4) array — 4D projections for direct attention use
            e8_embeddings: (k, 8) array — full 8D embeddings
            distances: (k,) array — E8 distances to query
        """
        query_tokens = self._text_to_tokens(query_text)
        query_emb = self._embed_chunk(query_tokens, 0, 1)
        results = self.lattice.query_nearest(query_emb, k=k, search_neighbors=True)
        # Count chunks per document once, instead of rescanning self.chunks
        # for every retrieved result (was O(results x chunks)).
        doc_counts: Dict[str, int] = {}
        for c in self.chunks:
            doc_counts[c.doc_id] = doc_counts.get(c.doc_id, 0) + 1
        chunks = []
        h4_keys = []
        e8_embs = []
        distances = []
        for dist_sq, value, addr in results:
            idx = int(addr)
            if idx >= len(self.chunks):
                continue
            chunk = self.chunks[idx]
            chunks.append(chunk)
            distances.append(dist_sq)
            # Re-derive the stored 8D embedding from the chunk metadata.
            # NOTE(review): if the same doc_id is encoded more than once,
            # doc_counts may differ from the n_chunks used at insert time,
            # changing the positional dims — confirm doc_ids are unique.
            emb = self._embed_chunk(chunk.token_ids, chunk.chunk_idx,
                                    doc_counts.get(chunk.doc_id, 1))
            e8_embs.append(emb)
            # Project to H4 via the E8 -> H4 projection.
            h4_keys.append(self.lattice.project_to_h4(emb))
        return (
            chunks,
            np.array(h4_keys) if h4_keys else np.zeros((0, 4)),
            np.array(e8_embs) if e8_embs else np.zeros((0, 8)),
            np.array(distances) if distances else np.zeros(0),
        )

    def stats(self) -> Dict:
        """Return summary statistics for the encoder and its lattice."""
        lattice_stats = self.lattice.stats()
        return {
            'n_chunks': len(self.chunks),
            'n_documents': len(set(c.doc_id for c in self.chunks)),
            'lattice_cells': lattice_stats.get('occupied_cells', 0),
            'lattice_utilization': lattice_stats.get('utilization', 0),
            'avg_chunk_len': (
                sum(len(c.token_ids) for c in self.chunks) / len(self.chunks)
                if self.chunks else 0
            ),
        }