philipp-zettl commited on
Commit
9eba86c
·
verified ·
1 Parent(s): da365b2

Add vrom_hub/hnsw.py

Browse files
Files changed (1) hide show
  1. vrom_hub/hnsw.py +348 -0
vrom_hub/hnsw.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pure-Python HNSW index builder.
3
+
4
+ Produces index.json files that are 100% compatible with the Rust/WASM
5
+ VectorDB.load(json) method. Mirrors the exact serde serialization format
6
+ of the vecdb-wasm crate.
7
+
8
+ Key invariants:
9
+ - Node ID = index in the nodes array
10
+ - neighbors[i] = connections at layer i (length = node.max_layer + 1)
11
+ - Layer 0 uses m_max0 (= 2*m) max neighbors; higher layers use m
12
+ - metric variants are PascalCase: "Cosine", "Euclidean", "DotProduct"
13
+ - metadata is a JSON string (not an object)
14
+ - entry_point is null for empty indexes
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import heapq
20
+ import json
21
+ import logging
22
+ import math
23
+ import random
24
+ from dataclasses import dataclass, field
25
+ from typing import Optional
26
+
27
+ import numpy as np
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ @dataclass
33
+ class HnswConfig:
34
+ """HNSW algorithm parameters — mirrors Rust HnswConfig exactly."""
35
+ m: int = 16
36
+ m_max0: int = 32
37
+ ef_construction: int = 128
38
+ ef_search: int = 40
39
+ level_multiplier: float = 0.0 # computed from m if 0
40
+ metric: str = "Cosine"
41
+
42
+ def __post_init__(self):
43
+ if self.m_max0 == 0:
44
+ self.m_max0 = 2 * self.m
45
+ if self.level_multiplier == 0.0:
46
+ self.level_multiplier = 1.0 / math.log(self.m)
47
+
48
+ def to_dict(self) -> dict:
49
+ return {
50
+ "m": self.m,
51
+ "m_max0": self.m_max0,
52
+ "ef_construction": self.ef_construction,
53
+ "ef_search": self.ef_search,
54
+ "level_multiplier": self.level_multiplier,
55
+ "metric": self.metric,
56
+ }
57
+
58
+
59
+ @dataclass
60
+ class HnswNode:
61
+ """A single node in the HNSW graph — mirrors Rust HnswNode exactly."""
62
+ vector: list[float] # len = dim
63
+ neighbors: list[list[int]] # neighbors[layer] = [node_ids...]
64
+ max_layer: int # highest layer this node is in
65
+ metadata: Optional[str] # JSON string or None
66
+
67
+ def to_dict(self) -> dict:
68
+ return {
69
+ "vector": self.vector,
70
+ "neighbors": self.neighbors,
71
+ "max_layer": self.max_layer,
72
+ "metadata": self.metadata,
73
+ }
74
+
75
+
76
+ class HnswIndex:
77
+ """
78
+ Top-level HNSW index — serializes to the exact JSON format
79
+ that VectorDB.load() expects.
80
+ """
81
+
82
+ def __init__(self, config: HnswConfig, dim: int):
83
+ self.config = config
84
+ self.dim = dim
85
+ self.nodes: list[HnswNode] = []
86
+ self.entry_point: Optional[int] = None
87
+ self.max_layer: int = 0
88
+
89
+ def to_dict(self) -> dict:
90
+ return {
91
+ "config": self.config.to_dict(),
92
+ "nodes": [n.to_dict() for n in self.nodes],
93
+ "entry_point": self.entry_point,
94
+ "max_layer": self.max_layer,
95
+ "dim": self.dim,
96
+ }
97
+
98
+ def to_json(self, indent: int | None = None) -> str:
99
+ return json.dumps(self.to_dict(), indent=indent)
100
+
101
+
102
+ def _cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
103
+ """1 - cosine_similarity. For pre-normalized vectors, this is 1 - dot(a,b)."""
104
+ dot = float(np.dot(a, b))
105
+ return 1.0 - dot
106
+
107
+
108
+ def _euclidean_distance(a: np.ndarray, b: np.ndarray) -> float:
109
+ return float(np.linalg.norm(a - b))
110
+
111
+
112
+ def _dot_product_distance(a: np.ndarray, b: np.ndarray) -> float:
113
+ """Negative dot product (lower = more similar)."""
114
+ return -float(np.dot(a, b))
115
+
116
+
117
+ def _get_distance_fn(metric: str):
118
+ if metric == "Cosine":
119
+ return _cosine_distance
120
+ elif metric == "Euclidean":
121
+ return _euclidean_distance
122
+ elif metric == "DotProduct":
123
+ return _dot_product_distance
124
+ else:
125
+ raise ValueError(f"Unknown metric: {metric}")
126
+
127
+
128
+ class HnswBuilder:
129
+ """
130
+ Builds an HNSW index from vectors + metadata.
131
+
132
+ This is a faithful Python implementation of the HNSW insertion algorithm
133
+ that produces a graph loadable by the Rust VectorDB.load() method.
134
+ """
135
+
136
+ def __init__(self, config: HnswConfig | None = None, dim: int = 384):
137
+ self.config = config or HnswConfig()
138
+ self.dim = dim
139
+ self.index = HnswIndex(self.config, dim)
140
+ self._vectors: list[np.ndarray] = [] # numpy cache for fast distance
141
+ self._dist_fn = _get_distance_fn(self.config.metric)
142
+ self._rng = random.Random(42)
143
+
144
+ def _random_layer(self) -> int:
145
+ """Sample a random layer using the HNSW level multiplier."""
146
+ r = self._rng.random()
147
+ return int(-math.log(r) * self.config.level_multiplier)
148
+
149
+ def _distance(self, a_id: int, b_vec: np.ndarray) -> float:
150
+ """Distance between stored node a_id and query vector b_vec."""
151
+ return self._dist_fn(self._vectors[a_id], b_vec)
152
+
153
+ def _search_layer(
154
+ self, query: np.ndarray, entry_id: int, ef: int, layer: int
155
+ ) -> list[tuple[float, int]]:
156
+ """
157
+ Search a single layer of the HNSW graph.
158
+ Returns up to ef nearest (distance, node_id) pairs.
159
+ """
160
+ entry_dist = self._distance(entry_id, query)
161
+ candidates = [(entry_dist, entry_id)] # min-heap
162
+ results = [(-entry_dist, entry_id)] # max-heap (negative for max)
163
+ visited = {entry_id}
164
+
165
+ while candidates:
166
+ c_dist, c_id = heapq.heappop(candidates)
167
+
168
+ # Furthest in results
169
+ f_dist = -results[0][0]
170
+ if c_dist > f_dist:
171
+ break
172
+
173
+ # Explore neighbors at this layer
174
+ node = self.index.nodes[c_id]
175
+ if layer < len(node.neighbors):
176
+ for neighbor_id in node.neighbors[layer]:
177
+ if neighbor_id in visited:
178
+ continue
179
+ visited.add(neighbor_id)
180
+
181
+ n_dist = self._distance(neighbor_id, query)
182
+ f_dist = -results[0][0]
183
+
184
+ if n_dist < f_dist or len(results) < ef:
185
+ heapq.heappush(candidates, (n_dist, neighbor_id))
186
+ heapq.heappush(results, (-n_dist, neighbor_id))
187
+ if len(results) > ef:
188
+ heapq.heappop(results)
189
+
190
+ # Convert results (stored as negative distances)
191
+ return [(abs(d), nid) for d, nid in results]
192
+
193
+ def _select_neighbors_simple(
194
+ self, candidates: list[tuple[float, int]], m: int
195
+ ) -> list[int]:
196
+ """Select the M nearest neighbors from candidates."""
197
+ candidates.sort(key=lambda x: x[0])
198
+ return [nid for _, nid in candidates[:m]]
199
+
200
+ def add(self, vector: np.ndarray, metadata: str | None = None) -> int:
201
+ """
202
+ Insert a vector into the HNSW index.
203
+
204
+ Args:
205
+ vector: numpy array of shape (dim,)
206
+ metadata: Optional JSON string metadata
207
+
208
+ Returns:
209
+ The node ID (= index in nodes array)
210
+ """
211
+ assert vector.shape == (self.dim,), f"Expected ({self.dim},), got {vector.shape}"
212
+ node_id = len(self.index.nodes)
213
+ node_layer = self._random_layer()
214
+
215
+ # Create node with empty neighbor lists
216
+ node = HnswNode(
217
+ vector=vector.tolist(),
218
+ neighbors=[[] for _ in range(node_layer + 1)],
219
+ max_layer=node_layer,
220
+ metadata=metadata,
221
+ )
222
+ self.index.nodes.append(node)
223
+ self._vectors.append(vector.copy())
224
+
225
+ # First node — set as entry point
226
+ if self.index.entry_point is None:
227
+ self.index.entry_point = node_id
228
+ self.index.max_layer = node_layer
229
+ return node_id
230
+
231
+ # Traverse from top to the node's layer, greedily
232
+ ep_id = self.index.entry_point
233
+ current_layer = self.index.max_layer
234
+
235
+ # Phase 1: Greedy descent from top to node_layer + 1
236
+ while current_layer > node_layer:
237
+ results = self._search_layer(vector, ep_id, ef=1, layer=current_layer)
238
+ if results:
239
+ ep_id = min(results, key=lambda x: x[0])[1]
240
+ current_layer -= 1
241
+
242
+ # Phase 2: Search and connect at each layer from min(node_layer, max_layer) down to 0
243
+ for lc in range(min(node_layer, self.index.max_layer), -1, -1):
244
+ results = self._search_layer(
245
+ vector, ep_id, ef=self.config.ef_construction, layer=lc
246
+ )
247
+
248
+ # Select neighbors
249
+ m_for_layer = self.config.m_max0 if lc == 0 else self.config.m
250
+ neighbors = self._select_neighbors_simple(results, m_for_layer)
251
+
252
+ # Set this node's neighbors at layer lc
253
+ node.neighbors[lc] = neighbors
254
+
255
+ # Add bidirectional connections
256
+ for neighbor_id in neighbors:
257
+ neighbor_node = self.index.nodes[neighbor_id]
258
+ # Ensure neighbor has enough layers
259
+ while len(neighbor_node.neighbors) <= lc:
260
+ neighbor_node.neighbors.append([])
261
+
262
+ neighbor_node.neighbors[lc].append(node_id)
263
+
264
+ # Shrink if over capacity
265
+ max_conn = self.config.m_max0 if lc == 0 else self.config.m
266
+ if len(neighbor_node.neighbors[lc]) > max_conn:
267
+ # Keep only the closest
268
+ n_vec = self._vectors[neighbor_id]
269
+ scored = [
270
+ (self._dist_fn(self._vectors[nid], n_vec), nid)
271
+ for nid in neighbor_node.neighbors[lc]
272
+ ]
273
+ neighbor_node.neighbors[lc] = self._select_neighbors_simple(
274
+ scored, max_conn
275
+ )
276
+
277
+ # Update entry point for next layer
278
+ if results:
279
+ ep_id = min(results, key=lambda x: x[0])[1]
280
+
281
+ # Update global entry point if new node is higher
282
+ if node_layer > self.index.max_layer:
283
+ self.index.entry_point = node_id
284
+ self.index.max_layer = node_layer
285
+
286
+ return node_id
287
+
288
+ def build(
289
+ self,
290
+ vectors: np.ndarray,
291
+ metadatas: list[str | None] | None = None,
292
+ ) -> HnswIndex:
293
+ """
294
+ Build an entire HNSW index from a batch of vectors.
295
+
296
+ Args:
297
+ vectors: np.ndarray of shape (n, dim)
298
+ metadatas: Optional list of JSON string metadata, one per vector
299
+
300
+ Returns:
301
+ The built HnswIndex
302
+ """
303
+ n = vectors.shape[0]
304
+ if metadatas is None:
305
+ metadatas = [None] * n
306
+
307
+ assert len(metadatas) == n, f"Mismatch: {n} vectors, {len(metadatas)} metadatas"
308
+ assert vectors.shape[1] == self.dim, f"Expected dim={self.dim}, got {vectors.shape[1]}"
309
+
310
+ logger.info(f"Building HNSW index: {n} vectors, dim={self.dim}, m={self.config.m}")
311
+
312
+ for i in range(n):
313
+ self.add(vectors[i], metadatas[i])
314
+ if (i + 1) % 100 == 0 or i == n - 1:
315
+ logger.info(f" Indexed {i + 1}/{n} vectors (max_layer={self.index.max_layer})")
316
+
317
+ logger.info(
318
+ f"HNSW index built: {n} nodes, max_layer={self.index.max_layer}, "
319
+ f"entry_point={self.index.entry_point}"
320
+ )
321
+ return self.index
322
+
323
+ def search(self, query: np.ndarray, k: int = 5) -> list[tuple[float, int]]:
324
+ """
325
+ Search the index for the k nearest neighbors.
326
+
327
+ Returns:
328
+ List of (distance, node_id) tuples, sorted by distance (ascending).
329
+ """
330
+ if self.index.entry_point is None:
331
+ return []
332
+
333
+ ep_id = self.index.entry_point
334
+ current_layer = self.index.max_layer
335
+
336
+ # Greedy descent to layer 0
337
+ while current_layer > 0:
338
+ results = self._search_layer(query, ep_id, ef=1, layer=current_layer)
339
+ if results:
340
+ ep_id = min(results, key=lambda x: x[0])[1]
341
+ current_layer -= 1
342
+
343
+ # Search layer 0 with ef_search
344
+ results = self._search_layer(
345
+ query, ep_id, ef=max(k, self.config.ef_search), layer=0
346
+ )
347
+ results.sort(key=lambda x: x[0])
348
+ return results[:k]