grapheneaffiliates's picture
Upload python/rag/encoder.py with huggingface_hub
e161c3a verified
"""
H4 Document Encoder — Encode text into E8 lattice memory for geometric retrieval.
Each document chunk becomes an 8D embedding stored in a Voronoi cell.
The encoding uses golden-angle spiral placement based on token content,
ensuring that retrieval and attention share the same geometric space
through the E8→H4 projection (cos(π/5) = φ/2).
No separate embedding model needed. The same geometry handles both.
"""
import math
import numpy as np
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from h4_polytopic_attention import E8LatticeIndex, PHI, PHI_INV
@dataclass
class Chunk:
"""A document chunk stored in E8 lattice memory."""
text: str
doc_id: str
chunk_idx: int
token_ids: List[int]
class H4DocumentEncoder:
"""
Encode documents into E8 lattice memory for retrieval.
Each chunk becomes an 8D embedding stored in an E8 Voronoi cell.
The encoding uses golden-angle spiral placement based on token
frequencies, ensuring geometric consistency with H4 attention.
The 8D embedding captures:
- dims 0-3: semantic content (weighted token frequency features)
- dims 4-7: positional/structural features (chunk position, doc features)
"""
def __init__(self, stoi: Dict[str, int], chunk_size: int = 256, overlap: int = 64):
"""
Args:
stoi: string-to-index vocabulary mapping
chunk_size: tokens per chunk
overlap: overlap between consecutive chunks
"""
self.stoi = stoi
self.itos = {v: k for k, v in stoi.items()}
self.vocab_size = len(stoi)
self.chunk_size = chunk_size
self.overlap = overlap
self.lattice = E8LatticeIndex(max_cell_size=240)
self.chunks: List[Chunk] = []
self._address_counter = 0
# Precompute per-token 8D embeddings using golden-angle spiral
self._token_embeddings = self._build_token_embeddings()
def _build_token_embeddings(self) -> np.ndarray:
"""Build 8D embeddings for each token using golden-angle spiral."""
embs = np.zeros((self.vocab_size, 8))
for i in range(self.vocab_size):
# Golden-angle placement in 8D: pairs of (cos, sin) at φ-scaled frequencies
for d in range(4):
angle = i * 2 * math.pi * PHI_INV * (PHI ** (-d / 4))
embs[i, 2 * d] = math.cos(angle)
embs[i, 2 * d + 1] = math.sin(angle)
# Normalize
norm = np.linalg.norm(embs[i])
if norm > 1e-10:
embs[i] /= norm
return embs
def _text_to_tokens(self, text: str) -> List[int]:
"""Convert text to token IDs using the vocabulary."""
return [self.stoi.get(c, 0) for c in text]
def _chunk_tokens(self, token_ids: List[int]) -> List[List[int]]:
"""Split token list into overlapping chunks."""
chunks = []
start = 0
while start < len(token_ids):
end = min(start + self.chunk_size, len(token_ids))
chunks.append(token_ids[start:end])
start += self.chunk_size - self.overlap
if end == len(token_ids):
break
return chunks
def _embed_chunk(self, token_ids: List[int], chunk_idx: int, n_chunks: int) -> np.ndarray:
"""
Compute 8D embedding for a chunk.
Combines token frequency features (dims 0-3) with
positional features (dims 4-7).
"""
emb = np.zeros(8)
# Semantic: weighted average of token embeddings
if token_ids:
token_embs = self._token_embeddings[token_ids] # (n_tokens, 8)
# Weight by position in chunk (later tokens slightly higher weight)
weights = np.array([1.0 + 0.5 * PHI_INV * (i / len(token_ids))
for i in range(len(token_ids))])
weights /= weights.sum()
emb[:4] = (token_embs[:, :4].T @ weights)
# Positional: chunk position and length features
pos_frac = chunk_idx / max(n_chunks - 1, 1) if n_chunks > 1 else 0.5
len_frac = len(token_ids) / self.chunk_size
angle1 = pos_frac * 2 * math.pi * PHI_INV
angle2 = len_frac * math.pi * PHI_INV
emb[4] = math.cos(angle1)
emb[5] = math.sin(angle1)
emb[6] = math.cos(angle2)
emb[7] = math.sin(angle2)
# Normalize to unit sphere
norm = np.linalg.norm(emb)
if norm > 1e-10:
emb /= norm
return emb
def encode_document(self, text: str, doc_id: str):
"""
Chunk document, encode each chunk as 8D vector, store in E8 lattice.
Args:
text: document text
doc_id: unique identifier for the document
"""
token_ids = self._text_to_tokens(text)
token_chunks = self._chunk_tokens(token_ids)
n_chunks = len(token_chunks)
for i, chunk_tokens in enumerate(token_chunks):
# Compute 8D embedding
embedding = self._embed_chunk(chunk_tokens, i, n_chunks)
# Store in E8 lattice
address = self._address_counter
self._address_counter += 1
self.lattice.insert(embedding, value=float(address), address=address)
# Store chunk metadata
chunk_text = ''.join(self.itos.get(t, '?') for t in chunk_tokens)
self.chunks.append(Chunk(
text=chunk_text,
doc_id=doc_id,
chunk_idx=i,
token_ids=chunk_tokens,
))
def retrieve(self, query_text: str, k: int = 5) -> List[Tuple[Chunk, float]]:
"""
Encode query as 8D vector, find k nearest chunks in E8 lattice.
Returns:
List of (chunk, distance) tuples sorted by E8 distance.
"""
query_tokens = self._text_to_tokens(query_text)
query_emb = self._embed_chunk(query_tokens, 0, 1)
results = self.lattice.query_nearest(query_emb, k=k, search_neighbors=True)
retrieved = []
for dist_sq, value, addr in results:
idx = int(addr)
if idx < len(self.chunks):
retrieved.append((self.chunks[idx], dist_sq))
return retrieved
def retrieve_with_h4(self, query_text: str, k: int = 5):
"""
Retrieve chunks AND return their H4 projections for attention.
Returns:
chunks: list of Chunk objects
h4_keys: (k, 4) array — 4D projections for direct attention use
e8_embeddings: (k, 8) array — full 8D embeddings
distances: (k,) array — E8 distances to query
"""
query_tokens = self._text_to_tokens(query_text)
query_emb = self._embed_chunk(query_tokens, 0, 1)
results = self.lattice.query_nearest(query_emb, k=k, search_neighbors=True)
chunks = []
h4_keys = []
e8_embs = []
distances = []
for dist_sq, value, addr in results:
idx = int(addr)
if idx < len(self.chunks):
chunk = self.chunks[idx]
chunks.append(chunk)
distances.append(dist_sq)
# Reconstruct 8D embedding
emb = self._embed_chunk(
chunk.token_ids, chunk.chunk_idx,
sum(1 for c in self.chunks if c.doc_id == chunk.doc_id)
)
e8_embs.append(emb)
# Project to H4 via E8→H4 projection
h4_key = self.lattice.project_to_h4(emb)
h4_keys.append(h4_key)
return (
chunks,
np.array(h4_keys) if h4_keys else np.zeros((0, 4)),
np.array(e8_embs) if e8_embs else np.zeros((0, 8)),
np.array(distances) if distances else np.zeros(0),
)
def stats(self) -> Dict:
"""Return encoder statistics."""
lattice_stats = self.lattice.stats()
return {
'n_chunks': len(self.chunks),
'n_documents': len(set(c.doc_id for c in self.chunks)),
'lattice_cells': lattice_stats.get('occupied_cells', 0),
'lattice_utilization': lattice_stats.get('utilization', 0),
'avg_chunk_len': (
sum(len(c.token_ids) for c in self.chunks) / len(self.chunks)
if self.chunks else 0
),
}