File size: 1,202 Bytes
0116d50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""Simple semantic search over review embeddings."""

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Sequence, Tuple

import numpy as np
from sklearn.neighbors import NearestNeighbors
from sentence_transformers import SentenceTransformer


@dataclass
class QueryEngine:
    """Semantic search over precomputed document embeddings.

    A query string is encoded with a ``SentenceTransformer`` model and matched
    against ``embeddings`` by cosine distance through scikit-learn's
    ``NearestNeighbors``.

    Attributes:
        embeddings: Array the neighbour index is fitted on; its length must
            equal ``len(documents)``.
        documents: Texts returned by :meth:`search`; ``documents[i]`` is the
            text paired with ``embeddings[i]``.
        model_name: Identifier passed to ``SentenceTransformer`` to encode
            queries (presumably the model that produced ``embeddings`` --
            confirm with the caller).
        top_k: Maximum number of results returned by :meth:`search`.

    Raises:
        ValueError: If ``documents`` and ``embeddings`` differ in length.
    """

    embeddings: np.ndarray
    documents: Sequence[str]
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    top_k: int = 5

    def __post_init__(self) -> None:
        """Validate the inputs, load the encoder and fit the neighbour index."""
        if len(self.documents) != len(self.embeddings):
            raise ValueError("Embeddings and documents must be aligned")
        # Encoder used to embed incoming queries.
        self.model = SentenceTransformer(self.model_name)
        # Cosine metric, so distances returned by the index are 1 - similarity.
        self.index = NearestNeighbors(metric="cosine")
        self.index.fit(self.embeddings)

    def search(self, query: str) -> List[Tuple[str, float]]:
        """Return the documents closest to *query* with their similarity.

        At most ``top_k`` results are returned. When fewer documents are
        indexed than ``top_k``, every document is returned instead of the
        index raising for an oversized neighbour count.

        Args:
            query: Free-text query to encode and look up.

        Returns:
            ``(document, similarity)`` pairs in the order the index returns
            them, where ``similarity`` is ``1 - cosine distance``.

        Raises:
            ValueError: If ``top_k`` is smaller than 1.
        """
        if self.top_k < 1:
            raise ValueError("top_k must be a positive integer")
        # kneighbors rejects n_neighbors larger than the number of fitted
        # samples, so never ask for more neighbours than there are documents.
        n_neighbors = min(self.top_k, len(self.documents))
        query_emb = self.model.encode([query])
        distances, indices = self.index.kneighbors(
            query_emb, n_neighbors=n_neighbors
        )
        # Only one query was encoded, so row 0 holds all of its neighbours.
        return [
            (self.documents[idx], float(1 - dist))
            for dist, idx in zip(distances[0], indices[0])
        ]


# Public API of this module: a star-import exposes only QueryEngine.
__all__ = ["QueryEngine"]