|
|
import faiss |
|
|
import numpy as np |
|
|
|
|
|
def generate_faiss_index(embeddings):
    """Build a flat L2 FAISS index over the given embeddings.

    Args:
        embeddings: Array-like of vectors laid out one per row, i.e. shape
            ``(n, d)``. It is converted to a C-contiguous ``float32`` array
            before being added to the index.

    Returns:
        A ``faiss.IndexFlatL2`` whose dimension is ``d`` (taken from the
        input) and which contains every row of ``embeddings``.

    Raises:
        ValueError: If ``embeddings`` is not two-dimensional.
    """
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(
            "embeddings must be a 2-D array of shape (n, d); "
            f"got shape {vectors.shape}"
        )

    # Take the dimension from the data instead of assuming 768, so that
    # embeddings from models with a different width can be indexed too.
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index
|
|
|
|
|
def load_faiss_index_to_gpu(index, device=0):
    """Copy a CPU FAISS index onto a GPU.

    Args:
        index: The CPU-side FAISS index to transfer.
        device: Id of the GPU to copy the index to. Defaults to ``0``,
            which is the device the function always used previously.

    Returns:
        The index object returned by ``faiss.index_cpu_to_gpu``.
    """
    # NOTE(review): the resources object is only held in this local; the
    # faiss Python wrapper presumably keeps it referenced from the returned
    # index so it outlives this call -- confirm for the faiss version in use.
    resources = faiss.StandardGpuResources()
    return faiss.index_cpu_to_gpu(resources, device, index)
|
|
|
|
|
def query_faiss_index(query_embedding, gpu_index, k=1):
    """Search an index for the nearest neighbours of a single query vector.

    Args:
        query_embedding: Array-like holding one query vector. It is
            converted to ``float32`` and reshaped to a single row, so any
            input is treated as exactly one query.
        gpu_index: Index object exposing ``search(queries, k)``.
        k: Number of nearest neighbours to retrieve. Defaults to ``1``,
            matching the previous fixed behaviour.

    Returns:
        A tuple ``(indices, distances)`` as produced by ``gpu_index.search``
        for the one-row query batch. Note the order: indices come first,
        which is the reverse of the order ``search`` itself returns them.
    """
    query = np.ascontiguousarray(query_embedding, dtype=np.float32)
    # search() takes a batch of queries; present the vector as a 1-row batch.
    query_batch = query.reshape(1, -1)
    distances, indices = gpu_index.search(query_batch, k)
    return indices, distances