bobs24's picture
track product with LFS
68f893f
raw
history blame contribute delete
949 Bytes
import faiss
import numpy as np
class FaissIndex:
    """Cosine-similarity search over image embeddings backed by a FAISS index.

    Vectors are L2-normalized before insertion and before querying, so the
    inner product computed by ``faiss.IndexFlatIP`` equals cosine similarity.
    Row ``j`` of the indexed embeddings is labelled by ``self.image_map[j]``.
    """

    def __init__(self, dim):
        """Create an empty inner-product index for ``dim``-dimensional vectors.

        Args:
            dim: Dimensionality of the embeddings that will be indexed.
        """
        self.index = faiss.IndexFlatIP(dim)
        # Position j holds the id of the j-th vector stored in self.index.
        self.image_map = []

    def build(self, embeddings, image_ids):
        """(Re)build the index from ``embeddings`` labelled by ``image_ids``.

        Previously indexed vectors are discarded first, so the FAISS rows and
        ``image_map`` always stay aligned even if ``build`` is called again.

        Args:
            embeddings: Array-like of shape (n, dim); a single 1-D vector is
                treated as one row.
            image_ids: Indexable sequence of n ids; ``image_ids[j]`` labels
                row j of ``embeddings``.

        Raises:
            ValueError: If the number of ids differs from the number of rows.
        """
        matrix = np.atleast_2d(np.asarray(embeddings))
        ids = list(image_ids)  # defensive copy: caller mutations can't desync us
        if len(ids) != matrix.shape[0]:
            raise ValueError(
                f"got {len(ids)} image ids for {matrix.shape[0]} embeddings"
            )

        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        # The epsilon keeps all-zero rows from dividing by zero (they stay zero).
        normalized = matrix / (norms + 1e-10)

        # Without reset(), a second build would append rows while image_map is
        # replaced, mapping search hits to the wrong ids.
        self.index.reset()
        self.index.add(np.ascontiguousarray(normalized, dtype='float32'))
        self.image_map = ids

    def search(self, query_vector, threshold=0.8, k=50):
        """Return ``(image_id, similarity)`` pairs at or above ``threshold``.

        Args:
            query_vector: Query embedding of length ``dim``; a ``(1, dim)``
                array is also accepted.
            threshold: Minimum cosine similarity for a hit to be returned.
            k: Maximum number of nearest neighbours requested from FAISS.

        Returns:
            List of ``(image_id, similarity)`` tuples, highest similarity first.
        """
        query = np.asarray(query_vector).reshape(1, -1)
        query_norm = np.linalg.norm(query)
        # A zero query is left as-is rather than divided by zero.
        if query_norm > 0:
            query = query / query_norm
        query = np.ascontiguousarray(query, dtype='float32')

        similarities, indices = self.index.search(query, k)

        results = []
        for idx, sim in zip(indices[0], similarities[0]):
            # FAISS pads with label -1 when fewer than k vectors are indexed;
            # indexing image_map with -1 would silently return the last id.
            if idx < 0:
                continue
            if sim >= threshold:
                results.append((self.image_map[idx], sim))
        results.sort(key=lambda pair: -pair[1])
        return results