File size: 8,536 Bytes
e161c3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
"""
H4 Document Encoder — Encode text into E8 lattice memory for geometric retrieval.

Each document chunk becomes an 8D embedding stored in a Voronoi cell.
The encoding uses golden-angle spiral placement based on token content,
ensuring that retrieval and attention share the same geometric space
through the E8→H4 projection (cos(π/5) = φ/2).

No separate embedding model needed. The same geometry handles both.
"""

import math
import numpy as np
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from h4_polytopic_attention import E8LatticeIndex, PHI, PHI_INV


@dataclass
class Chunk:
    """A document chunk stored in E8 lattice memory.

    Created by H4DocumentEncoder.encode_document; its lattice address is
    the chunk's index into the encoder's `chunks` list.
    """
    text: str              # chunk text, reconstructed from token_ids via the vocabulary
    doc_id: str            # identifier of the source document this chunk came from
    chunk_idx: int         # position of this chunk within its document (0-based)
    token_ids: List[int]   # token ids making up the chunk (character-level vocab)


class H4DocumentEncoder:
    """
    Encode documents into E8 lattice memory for retrieval.

    Each chunk becomes an 8D embedding stored in an E8 Voronoi cell.
    The encoding uses golden-angle spiral placement based on token
    frequencies, ensuring geometric consistency with H4 attention.

    The 8D embedding captures:
    - dims 0-3: semantic content (weighted token frequency features)
    - dims 4-7: positional/structural features (chunk position, doc features)
    """

    def __init__(self, stoi: Dict[str, int], chunk_size: int = 256, overlap: int = 64):
        """
        Args:
            stoi: string-to-index vocabulary mapping
            chunk_size: tokens per chunk (must be positive)
            overlap: overlap between consecutive chunks;
                must satisfy 0 <= overlap < chunk_size

        Raises:
            ValueError: if chunk_size is not positive, or overlap is
                negative or >= chunk_size.
        """
        # Validate up front: overlap >= chunk_size makes the chunking step
        # (chunk_size - overlap) non-positive, so _chunk_tokens would loop
        # forever on any document longer than one chunk.
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if not 0 <= overlap < chunk_size:
            raise ValueError(
                f"overlap must satisfy 0 <= overlap < chunk_size; "
                f"got overlap={overlap}, chunk_size={chunk_size}"
            )

        self.stoi = stoi
        self.itos = {v: k for k, v in stoi.items()}
        self.vocab_size = len(stoi)
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.lattice = E8LatticeIndex(max_cell_size=240)
        self.chunks: List["Chunk"] = []
        # Monotonic address for lattice entries; doubles as index into self.chunks.
        self._address_counter = 0

        # Precompute per-token 8D embeddings using golden-angle spiral
        self._token_embeddings = self._build_token_embeddings()

    def _build_token_embeddings(self) -> np.ndarray:
        """Build 8D embeddings for each token using golden-angle spiral.

        Token i gets four (cos, sin) pairs at φ-scaled frequencies:
        angle(i, d) = i * 2π * φ⁻¹ * φ^(-d/4) for d in 0..3, placed at
        dims (2d, 2d+1). Rows are normalized to unit length.

        Returns:
            (vocab_size, 8) array of unit-norm token embeddings.
        """
        # Vectorized over (token, frequency-pair): identical values to a
        # per-token loop, but one NumPy pass instead of O(vocab * 4) Python.
        idx = np.arange(self.vocab_size, dtype=float)[:, None]          # (V, 1)
        dims = np.arange(4, dtype=float)[None, :]                       # (1, 4)
        angles = idx * (2.0 * math.pi * PHI_INV) * PHI ** (-dims / 4)   # (V, 4)

        embs = np.empty((self.vocab_size, 8))
        embs[:, 0::2] = np.cos(angles)
        embs[:, 1::2] = np.sin(angles)

        # Normalize rows; rows with (near-)zero norm are left as-is.
        norms = np.linalg.norm(embs, axis=1, keepdims=True)
        np.divide(embs, norms, out=embs, where=norms > 1e-10)
        return embs

    def _text_to_tokens(self, text: str) -> List[int]:
        """Convert text to token IDs; unknown characters map to id 0."""
        return [self.stoi.get(c, 0) for c in text]

    def _chunk_tokens(self, token_ids: List[int]) -> List[List[int]]:
        """Split token list into overlapping chunks.

        Consecutive chunks overlap by `self.overlap` tokens; the step
        (chunk_size - overlap) is guaranteed positive by __init__.
        """
        chunks = []
        start = 0
        while start < len(token_ids):
            end = min(start + self.chunk_size, len(token_ids))
            chunks.append(token_ids[start:end])
            start += self.chunk_size - self.overlap
            if end == len(token_ids):
                break
        return chunks

    def _embed_chunk(self, token_ids: List[int], chunk_idx: int, n_chunks: int) -> np.ndarray:
        """
        Compute 8D embedding for a chunk.

        Combines token frequency features (dims 0-3) with
        positional features (dims 4-7).

        Args:
            token_ids: the chunk's token ids (may be empty)
            chunk_idx: position of the chunk within its document
            n_chunks: total chunks in the document (controls pos_frac)

        Returns:
            (8,) unit-norm embedding (zero vector for an empty chunk
            would stay zero, but chunks are never empty in practice).
        """
        emb = np.zeros(8)

        # Semantic: weighted average of token embeddings
        if token_ids:
            token_embs = self._token_embeddings[token_ids]  # (n_tokens, 8)
            n = len(token_ids)
            # Weight by position in chunk (later tokens slightly higher weight)
            weights = 1.0 + 0.5 * PHI_INV * (np.arange(n) / n)
            weights /= weights.sum()
            emb[:4] = token_embs[:, :4].T @ weights

        # Positional: chunk position and length features.
        # A single-chunk document sits at the midpoint (0.5).
        pos_frac = chunk_idx / max(n_chunks - 1, 1) if n_chunks > 1 else 0.5
        len_frac = len(token_ids) / self.chunk_size
        angle1 = pos_frac * 2 * math.pi * PHI_INV
        angle2 = len_frac * math.pi * PHI_INV
        emb[4] = math.cos(angle1)
        emb[5] = math.sin(angle1)
        emb[6] = math.cos(angle2)
        emb[7] = math.sin(angle2)

        # Normalize to unit sphere
        norm = np.linalg.norm(emb)
        if norm > 1e-10:
            emb /= norm
        return emb

    def _embed_query(self, query_text: str) -> np.ndarray:
        """Encode a query string as a single standalone-chunk embedding."""
        return self._embed_chunk(self._text_to_tokens(query_text), 0, 1)

    def encode_document(self, text: str, doc_id: str):
        """
        Chunk document, encode each chunk as 8D vector, store in E8 lattice.

        Args:
            text: document text
            doc_id: unique identifier for the document
        """
        token_ids = self._text_to_tokens(text)
        token_chunks = self._chunk_tokens(token_ids)
        n_chunks = len(token_chunks)

        for i, chunk_tokens in enumerate(token_chunks):
            # Compute 8D embedding
            embedding = self._embed_chunk(chunk_tokens, i, n_chunks)

            # Store in E8 lattice; the address indexes into self.chunks.
            address = self._address_counter
            self._address_counter += 1
            self.lattice.insert(embedding, value=float(address), address=address)

            # Store chunk metadata
            chunk_text = ''.join(self.itos.get(t, '?') for t in chunk_tokens)
            self.chunks.append(Chunk(
                text=chunk_text,
                doc_id=doc_id,
                chunk_idx=i,
                token_ids=chunk_tokens,
            ))

    def retrieve(self, query_text: str, k: int = 5) -> List[Tuple["Chunk", float]]:
        """
        Encode query as 8D vector, find k nearest chunks in E8 lattice.

        Returns:
            List of (chunk, distance) tuples sorted by E8 distance.
        """
        query_emb = self._embed_query(query_text)

        results = self.lattice.query_nearest(query_emb, k=k, search_neighbors=True)

        retrieved = []
        for dist_sq, value, addr in results:
            idx = int(addr)
            # Guard against stale lattice addresses beyond current chunk list.
            if idx < len(self.chunks):
                retrieved.append((self.chunks[idx], dist_sq))

        return retrieved

    def retrieve_with_h4(self, query_text: str, k: int = 5):
        """
        Retrieve chunks AND return their H4 projections for attention.

        Returns:
            chunks: list of Chunk objects
            h4_keys: (k, 4) array — 4D projections for direct attention use
            e8_embeddings: (k, 8) array — full 8D embeddings
            distances: (k,) array — E8 distances to query
        """
        query_emb = self._embed_query(query_text)

        results = self.lattice.query_nearest(query_emb, k=k, search_neighbors=True)

        # Precompute chunks-per-document once instead of rescanning
        # self.chunks for every retrieved result (was O(k * n_chunks)).
        doc_counts: Dict[str, int] = {}
        for c in self.chunks:
            doc_counts[c.doc_id] = doc_counts.get(c.doc_id, 0) + 1

        chunks = []
        h4_keys = []
        e8_embs = []
        distances = []

        for dist_sq, value, addr in results:
            idx = int(addr)
            if idx < len(self.chunks):
                chunk = self.chunks[idx]
                chunks.append(chunk)
                distances.append(dist_sq)

                # Reconstruct 8D embedding (embeddings are not stored;
                # recomputed deterministically from chunk content).
                emb = self._embed_chunk(
                    chunk.token_ids, chunk.chunk_idx, doc_counts[chunk.doc_id]
                )
                e8_embs.append(emb)

                # Project to H4 via E8→H4 projection
                h4_key = self.lattice.project_to_h4(emb)
                h4_keys.append(h4_key)

        return (
            chunks,
            np.array(h4_keys) if h4_keys else np.zeros((0, 4)),
            np.array(e8_embs) if e8_embs else np.zeros((0, 8)),
            np.array(distances) if distances else np.zeros(0),
        )

    def stats(self) -> Dict:
        """Return encoder statistics (chunk/document counts, lattice usage)."""
        lattice_stats = self.lattice.stats()
        return {
            'n_chunks': len(self.chunks),
            'n_documents': len(set(c.doc_id for c in self.chunks)),
            'lattice_cells': lattice_stats.get('occupied_cells', 0),
            'lattice_utilization': lattice_stats.get('utilization', 0),
            'avg_chunk_len': (
                sum(len(c.token_ids) for c in self.chunks) / len(self.chunks)
                if self.chunks else 0
            ),
        }