| """ |
| H4 Document Encoder — Encode text into E8 lattice memory for geometric retrieval. |
| |
| Each document chunk becomes an 8D embedding stored in a Voronoi cell. |
| The encoding uses golden-angle spiral placement based on token content, |
| ensuring that retrieval and attention share the same geometric space |
| through the E8→H4 projection (cos(π/5) = φ/2). |
| |
| No separate embedding model needed. The same geometry handles both. |
| """ |
|
|
| import math |
| import numpy as np |
| from typing import List, Tuple, Dict, Optional |
| from dataclasses import dataclass |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) |
| from h4_polytopic_attention import E8LatticeIndex, PHI, PHI_INV |
|
|
|
|
@dataclass
class Chunk:
    """One contiguous slice of a document held in E8 lattice memory.

    Instances are plain records; the encoder creates one per chunk and
    returns them from retrieval.
    """

    # Text of the chunk (the encoder rebuilds it from ``token_ids``).
    text: str
    # Identifier of the document this slice was cut from.
    doc_id: str
    # 0-based position of this slice among its document's chunks.
    chunk_idx: int
    # Vocabulary ids of the tokens making up the slice.
    token_ids: List[int]
|
|
|
|
class H4DocumentEncoder:
    """
    Encode documents into E8 lattice memory for retrieval.

    Each chunk becomes an 8D embedding stored in an E8 Voronoi cell.
    The encoding uses golden-angle spiral placement of tokens, ensuring
    geometric consistency with H4 attention.

    The 8D embedding (unit-normalized as a whole) captures:
    - dims 0-3: semantic content (position-weighted mean of the first four
      coordinates of each token's spiral embedding)
    - dims 4-7: structural features (chunk position within its document,
      chunk length relative to ``chunk_size``)
    """

    def __init__(self, stoi: Dict[str, int], chunk_size: int = 256, overlap: int = 64):
        """
        Args:
            stoi: string-to-index vocabulary mapping. Text is tokenized one
                character at a time; characters missing from ``stoi`` are
                encoded as token id 0.
            chunk_size: tokens per chunk; must be positive.
            overlap: tokens shared by consecutive chunks; must satisfy
                ``0 <= overlap < chunk_size``.

        Raises:
            ValueError: if ``chunk_size``/``overlap`` are out of range. A
                non-positive stride (``chunk_size - overlap``) would make
                chunking loop forever on any document longer than one chunk.
        """
        # Validate before building anything (including the lattice) so a bad
        # configuration fails fast instead of hanging in _chunk_tokens.
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if not 0 <= overlap < chunk_size:
            raise ValueError(
                f"overlap must satisfy 0 <= overlap < chunk_size ({chunk_size}), "
                f"got {overlap}"
            )

        self.stoi = stoi
        self.itos = {v: k for k, v in stoi.items()}
        self.vocab_size = len(stoi)
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.lattice = E8LatticeIndex(max_cell_size=240)
        self.chunks: List[Chunk] = []
        # Embedding inserted into the lattice for each chunk. It is kept
        # parallel to self.chunks (index == lattice address), so retrieval
        # can return exactly what was stored.
        self._chunk_embeddings: List[np.ndarray] = []
        self._address_counter = 0

        # Per-token 8D embeddings, built once and indexed by token id.
        self._token_embeddings = self._build_token_embeddings()

    def _build_token_embeddings(self) -> np.ndarray:
        """Build a unit-norm 8D embedding per token id using a golden-angle spiral.

        Row ``i`` holds four (cos, sin) pairs; pair ``d`` uses the angle
        ``i * 2π * φ⁻¹ * φ^(-d/4)``, so successive ids advance by the golden
        angle at a different rate in each pair.

        The table has a row for every id in ``range(vocab_size)``. It also
        covers the largest id in ``stoi`` if the vocabulary is not contiguous,
        and always has at least row 0, because unknown characters map to id 0.
        """
        max_id = max(self.stoi.values(), default=-1)
        n_rows = max(self.vocab_size, max_id + 1, 1)
        embs = np.zeros((n_rows, 8))
        for i in range(n_rows):
            for d in range(4):
                angle = i * 2 * math.pi * PHI_INV * (PHI ** (-d / 4))
                embs[i, 2 * d] = math.cos(angle)
                embs[i, 2 * d + 1] = math.sin(angle)
            norm = np.linalg.norm(embs[i])
            if norm > 1e-10:
                embs[i] /= norm
        return embs

    def _text_to_tokens(self, text: str) -> List[int]:
        """Convert text to token IDs, one per character (unknown characters -> 0)."""
        return [self.stoi.get(c, 0) for c in text]

    def _chunk_tokens(self, token_ids: List[int]) -> List[List[int]]:
        """Split a token list into windows of ``chunk_size`` tokens.

        Consecutive windows start ``chunk_size - overlap`` tokens apart. The
        last window may be shorter, and splitting stops once a window reaches
        the end of the input. Empty input yields no chunks.

        Raises:
            ValueError: if the stride is not positive (this can only happen
                if ``chunk_size``/``overlap`` were changed after construction).
        """
        step = self.chunk_size - self.overlap
        if step <= 0:
            raise ValueError(
                f"chunk_size ({self.chunk_size}) must exceed overlap ({self.overlap})"
            )
        n = len(token_ids)
        chunks = []
        for start in range(0, n, step):
            end = min(start + self.chunk_size, n)
            chunks.append(token_ids[start:end])
            if end == n:
                break
        return chunks

    def _embed_chunk(self, token_ids: List[int], chunk_idx: int, n_chunks: int) -> np.ndarray:
        """
        Compute the unit-norm 8D embedding for a chunk.

        Combines token content features (dims 0-3) with positional features
        (dims 4-7). A single-chunk document (or a query, which is embedded as
        chunk 0 of 1) gets position fraction 0.5.
        """
        emb = np.zeros(8)

        # dims 0-3: weighted mean of the tokens' first four coordinates.
        # Weights rise linearly from 1.0 toward 1.0 + 0.5/φ across the chunk,
        # so later tokens count slightly more.
        if token_ids:
            token_embs = self._token_embeddings[token_ids]
            n = len(token_ids)
            weights = 1.0 + 0.5 * PHI_INV * (np.arange(n) / n)
            weights /= weights.sum()
            emb[:4] = token_embs[:, :4].T @ weights

        # dims 4-7: two golden-angle rotations encoding where the chunk sits
        # in its document and how full it is relative to chunk_size.
        pos_frac = chunk_idx / (n_chunks - 1) if n_chunks > 1 else 0.5
        len_frac = len(token_ids) / self.chunk_size
        angle1 = pos_frac * 2 * math.pi * PHI_INV
        angle2 = len_frac * math.pi * PHI_INV
        emb[4] = math.cos(angle1)
        emb[5] = math.sin(angle1)
        emb[6] = math.cos(angle2)
        emb[7] = math.sin(angle2)

        norm = np.linalg.norm(emb)
        if norm > 1e-10:
            emb /= norm
        return emb

    def encode_document(self, text: str, doc_id: str):
        """
        Chunk a document, encode each chunk as an 8D vector, and store it in the E8 lattice.

        Args:
            text: document text
            doc_id: identifier for the document
        """
        token_ids = self._text_to_tokens(text)
        token_chunks = self._chunk_tokens(token_ids)
        n_chunks = len(token_chunks)

        for i, chunk_tokens in enumerate(token_chunks):
            embedding = self._embed_chunk(chunk_tokens, i, n_chunks)
            chunk_text = ''.join(self.itos.get(t, '?') for t in chunk_tokens)
            chunk = Chunk(
                text=chunk_text,
                doc_id=doc_id,
                chunk_idx=i,
                token_ids=chunk_tokens,
            )

            # The lattice address doubles as the index into self.chunks.
            # Only commit the chunk and advance the counter after the insert
            # succeeds, so a failed insert cannot desynchronize addresses
            # from chunk indices.
            address = self._address_counter
            self.lattice.insert(embedding, value=float(address), address=address)
            self.chunks.append(chunk)
            self._chunk_embeddings.append(embedding.copy())
            self._address_counter += 1

    def _query(self, query_text: str, k: int) -> List[Tuple[int, float]]:
        """Embed ``query_text`` and look up its ``k`` nearest lattice entries.

        Returns (chunk index, squared distance) pairs in the order produced by
        ``lattice.query_nearest``. Addresses that do not map to a stored chunk
        are skipped.
        """
        query_tokens = self._text_to_tokens(query_text)
        query_emb = self._embed_chunk(query_tokens, 0, 1)
        results = self.lattice.query_nearest(query_emb, k=k, search_neighbors=True)

        hits = []
        for dist_sq, _value, addr in results:
            idx = int(addr)
            # Reject negatives too: Python would silently index from the end.
            if 0 <= idx < len(self.chunks):
                hits.append((idx, dist_sq))
        return hits

    def _chunk_embedding(self, idx: int) -> np.ndarray:
        """Return the 8D embedding that was inserted into the lattice for chunk ``idx``."""
        stored = getattr(self, '_chunk_embeddings', None)
        if stored is not None and idx < len(stored):
            return stored[idx]
        # Fallback for chunks that have no cached embedding (e.g. appended to
        # self.chunks without going through encode_document): recompute it.
        chunk = self.chunks[idx]
        n_chunks = sum(1 for c in self.chunks if c.doc_id == chunk.doc_id)
        return self._embed_chunk(chunk.token_ids, chunk.chunk_idx, n_chunks)

    def retrieve(self, query_text: str, k: int = 5) -> List[Tuple["Chunk", float]]:
        """
        Encode the query as an 8D vector and find its k nearest chunks in the E8 lattice.

        Returns:
            List of (chunk, dist_sq) tuples, where dist_sq is the squared E8
            distance reported by the lattice, in the lattice's result order
            (presumably nearest first).
        """
        return [(self.chunks[idx], dist_sq) for idx, dist_sq in self._query(query_text, k)]

    def retrieve_with_h4(
        self, query_text: str, k: int = 5
    ) -> Tuple[List["Chunk"], np.ndarray, np.ndarray, np.ndarray]:
        """
        Retrieve chunks AND return their H4 projections for attention.

        Returns:
            chunks: list of Chunk objects
            h4_keys: (m, 4) array of 4D projections for direct attention use
            e8_embeddings: (m, 8) array of the embeddings stored in the lattice
            distances: (m,) array of squared E8 distances to the query
        where m <= k is the number of hits that map to stored chunks.
        """
        chunks = []
        h4_keys = []
        e8_embs = []
        distances = []

        for idx, dist_sq in self._query(query_text, k):
            # Use the embedding actually inserted for this chunk. Recomputing
            # it from doc_id counts is wrong when a doc_id is encoded more
            # than once, and it is also O(total chunks) per hit.
            emb = self._chunk_embedding(idx)
            chunks.append(self.chunks[idx])
            distances.append(dist_sq)
            e8_embs.append(emb)
            h4_keys.append(self.lattice.project_to_h4(emb))

        return (
            chunks,
            np.array(h4_keys) if h4_keys else np.zeros((0, 4)),
            np.array(e8_embs) if e8_embs else np.zeros((0, 8)),
            np.array(distances) if distances else np.zeros(0),
        )

    def stats(self) -> Dict:
        """Return encoder statistics (chunk/document counts, lattice occupancy, mean chunk length)."""
        lattice_stats = self.lattice.stats()
        return {
            'n_chunks': len(self.chunks),
            'n_documents': len({c.doc_id for c in self.chunks}),
            'lattice_cells': lattice_stats.get('occupied_cells', 0),
            'lattice_utilization': lattice_stats.get('utilization', 0),
            'avg_chunk_len': (
                sum(len(c.token_ids) for c in self.chunks) / len(self.chunks)
                if self.chunks else 0
            ),
        }
|
|