| import faiss | |
| import numpy as np | |
| import torch | |
class FaissSemanticIndex:
    """Text lookup by cosine similarity on top of a flat FAISS index.

    Every vector is L2-normalized before it is stored or used as a query, so
    the inner product computed by ``faiss.IndexFlatIP`` equals the cosine
    similarity. ``self.texts[i]`` is the text that belongs to the i-th vector
    added to ``self.index``; the two are always extended together.
    """

    def __init__(self, dim: int):
        """Create an empty index for embeddings of width ``dim``."""
        self.index = faiss.IndexFlatIP(dim)
        self.texts = []

    @staticmethod
    def _to_unit_rows(vectors: torch.Tensor) -> np.ndarray:
        """Return ``vectors`` as an L2-normalized, C-contiguous float32 matrix.

        A 1-D tensor is treated as a single row.

        Raises:
            ValueError: if ``vectors`` is neither 1-D nor 2-D.
        """
        if vectors.dim() == 1:
            vectors = vectors.unsqueeze(0)
        if vectors.dim() != 2:
            raise ValueError(
                f"expected a 1-D or 2-D tensor, got {vectors.dim()} dimensions"
            )
        # detach(): .numpy() refuses tensors that are part of an autograd graph.
        # float(): FAISS only accepts float32, and half/bfloat16/double inputs
        # would otherwise reach it unchanged (or fail in .numpy()).
        unit = torch.nn.functional.normalize(vectors.detach().float(), p=2, dim=1)
        return np.ascontiguousarray(unit.cpu().numpy(), dtype=np.float32)

    def add(self, embeddings: torch.Tensor, texts: list[str]) -> None:
        """Store ``embeddings`` together with their ``texts``.

        Args:
            embeddings: one row per text (a 1-D tensor counts as one row).
            texts: the texts described by the rows, in the same order.

        Raises:
            ValueError: if the number of rows differs from ``len(texts)`` or
                ``embeddings`` has more than two dimensions.
        """
        rows = self._to_unit_rows(embeddings)
        # Validate before mutating anything: a count mismatch would shift every
        # later id -> text lookup without raising.
        if rows.shape[0] != len(texts):
            raise ValueError(
                f"got {rows.shape[0]} embeddings for {len(texts)} texts"
            )
        self.index.add(rows)
        self.texts.extend(texts)

    def search(self, query: torch.Tensor, k: int = 1) -> list[tuple[str, float]]:
        """Return up to ``k`` ``(text, cosine_similarity)`` pairs, best first.

        Only the first row of ``query`` is answered; a 1-D tensor is treated
        as a single query. Fewer than ``k`` pairs are returned when the index
        holds fewer than ``k`` vectors.

        Raises:
            ValueError: if ``k`` is smaller than 1 or ``query`` has more than
                two dimensions.
        """
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")
        rows = self._to_unit_rows(query)
        scores, indices = self.index.search(rows, k)
        # FAISS pads missing neighbours with id -1; indexing self.texts with it
        # would silently return the last stored text.
        return [
            (self.texts[idx], float(score))
            for idx, score in zip(indices[0], scores[0])
            if idx >= 0
        ]