File size: 672 Bytes
a9b94ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import faiss
import numpy as np
import torch


class FaissSemanticIndex:
    """Cosine-similarity search over texts, backed by a flat FAISS index.

    Embeddings are L2-normalized before being stored in an exact
    inner-product index (``faiss.IndexFlatIP``). Because of that, the scores
    returned by :meth:`search` are cosine similarities between the query and
    the stored texts.
    """

    def __init__(self, dim: int):
        """Create an empty index for ``dim``-dimensional embeddings."""
        # Exact (brute-force) inner-product index over `dim`-dimensional vectors.
        self.index = faiss.IndexFlatIP(dim)
        # self.texts[i] is the text whose embedding has FAISS id i.
        self.texts: list[str] = []

    @staticmethod
    def _to_faiss_array(vectors: torch.Tensor) -> np.ndarray:
        """L2-normalize each row of ``vectors``; return a C-contiguous float32 array.

        A 1-D tensor is treated as a single vector (one row). The tensor is
        detached so tensors that require grad can be converted. It is cast to
        float32 because FAISS only accepts float32 input.
        """
        if vectors.dim() == 1:
            vectors = vectors.unsqueeze(0)
        normalized = torch.nn.functional.normalize(vectors.detach().float(), p=2, dim=1)
        return np.ascontiguousarray(normalized.cpu().numpy(), dtype=np.float32)

    def _pair_results(self, scores: np.ndarray, ids: np.ndarray) -> list[tuple[str, float]]:
        """Pair one query's FAISS ids with their texts, skipping ``-1`` padding ids."""
        # FAISS fills missing results with id -1 when fewer than k vectors
        # exist; indexing self.texts with -1 would wrongly return the last text.
        return [
            (self.texts[int(idx)], float(score))
            for score, idx in zip(scores, ids)
            if idx >= 0
        ]

    def add(self, embeddings: torch.Tensor, texts: list[str]):
        """Add embeddings and their corresponding texts to the index.

        Args:
            embeddings: One embedding per row (or a single 1-D embedding).
            texts: The text for each embedding, in the same order.

        Raises:
            ValueError: If the number of embeddings differs from ``len(texts)``.
                Nothing is added in that case, so ids and texts stay aligned.
        """
        vectors = self._to_faiss_array(embeddings)
        texts = list(texts)
        if vectors.shape[0] != len(texts):
            raise ValueError(
                f"got {vectors.shape[0]} embeddings but {len(texts)} texts"
            )
        self.index.add(vectors)
        self.texts.extend(texts)

    def search(self, query: torch.Tensor, k: int = 1) -> list[tuple[str, float]]:
        """Return up to ``k`` ``(text, cosine_score)`` pairs for the query.

        Args:
            query: A single 1-D query embedding, or a 2-D batch. For a batch,
                only the first row's results are returned.
            k: Maximum number of neighbours to return.

        Returns:
            Matches in the order FAISS returns them (highest score first).
            Fewer than ``k`` pairs are returned when the index holds fewer
            than ``k`` vectors.
        """
        scores, indices = self.index.search(self._to_faiss_array(query), k)
        return self._pair_results(scores[0], indices[0])