File size: 481 Bytes
549c270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# src/utils/faiss_utils.py
import os

import faiss
import numpy as np

def build_faiss_index(embeddings: np.ndarray, use_gpu: bool = False, gpu_id: int = 0):
    """Build an exact (brute-force) inner-product FAISS index over ``embeddings``.

    The index is an ``IndexFlatIP``, so search scores are raw inner products.
    They equal cosine similarity only if the rows of ``embeddings`` were
    L2-normalized beforehand; this function does not normalize.

    Args:
        embeddings: 2-D array of shape ``(n_vectors, dim)``. It is converted to
            a C-contiguous float32 array, which is the layout FAISS reads.
        use_gpu: If True, move the index to a GPU before adding vectors.
            Requires a GPU-enabled FAISS build.
        gpu_id: Device ordinal used when ``use_gpu`` is True. The default of 0
            matches the previously hard-coded device.

    Returns:
        A FAISS index containing all ``n_vectors`` rows. When ``use_gpu`` is
        True this is a GPU index. It must be copied back to CPU (for example
        with ``faiss.index_gpu_to_cpu``) before ``faiss.write_index`` can
        serialize it.

    Raises:
        ValueError: If ``embeddings`` is not two-dimensional.
    """
    embeddings = np.asarray(embeddings)
    if embeddings.ndim != 2:
        raise ValueError(
            "embeddings must be a 2-D array of shape (n_vectors, dim), "
            f"got shape {embeddings.shape}"
        )
    # FAISS consumes the raw buffer, so it needs C-contiguous float32 data.
    # ascontiguousarray skips the copy when the input already qualifies.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    d = vectors.shape[1]
    index = faiss.IndexFlatIP(d)

    if use_gpu:
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, gpu_id, index)

    index.add(vectors)
    return index

def save_faiss_index(index, path):
    """Serialize ``index`` to ``path`` using ``faiss.write_index``.

    GPU indexes, such as those returned by
    ``build_faiss_index(..., use_gpu=True)``, cannot be written directly.
    They are first copied to an equivalent CPU index. The caller's ``index``
    object itself is not modified.

    Args:
        index: A FAISS CPU or GPU index.
        path: Destination file path, as a ``str`` or ``os.PathLike``. Any
            other value is passed to ``faiss.write_index`` unchanged.
    """
    # faiss.GpuIndex only exists in GPU-enabled builds of FAISS.
    gpu_index_cls = getattr(faiss, "GpuIndex", None)
    if gpu_index_cls is not None and isinstance(index, gpu_index_cls):
        index = faiss.index_gpu_to_cpu(index)
    # The SWIG binding expects a plain string filename, not a Path object.
    if isinstance(path, os.PathLike):
        path = os.fspath(path)
    faiss.write_index(index, path)