pls-rag / modules /retriever.py
m97j's picture
Initial codes commit
3c9754e
raw
history blame contribute delete
527 Bytes
# rag/modules/retriever.py
import numpy as np
from config import TOP_K
# Module-level holder for the in-memory FAISS index. It starts as None, is
# replaced by set_index(), and retrieve_ids() raises RuntimeError while it is None.
_index = None  # in-memory FAISS index
def set_index(index_obj):
    """Install *index_obj* as this module's in-memory search index.

    Whatever object was stored before (including ``None``) is replaced
    unconditionally; no validation is performed on *index_obj*.
    """
    global _index
    _index = index_obj
def has_index() -> bool:
    """Report whether the module-level index slot currently holds an object."""
    if _index is None:
        return False
    return True
def retrieve_ids(query_embedding: list[float], top_k: "int | None" = None) -> list[int]:
    """Return the ids of the stored vectors nearest to *query_embedding*.

    Args:
        query_embedding: The query vector. It is converted to ``float32`` and
            reshaped to a single-row ``(1, dim)`` matrix before searching.
        top_k: Maximum number of ids to return. When omitted, the configured
            ``TOP_K`` value is used.

    Returns:
        Up to ``top_k`` ids, in the order produced by the index's ``search``.
        Negative ids are dropped: FAISS pads its label array with ``-1`` when
        the index holds fewer than ``top_k`` vectors, and a ``-1`` passed on to
        a Python list lookup would silently select the last document.

    Raises:
        RuntimeError: If no index has been loaded.
        ValueError: If ``top_k`` is less than 1.
    """
    # Read the global once so the None check and the search use the same object.
    index = _index
    if index is None:
        raise RuntimeError("FAISS index is not loaded in memory.")

    k = TOP_K if top_k is None else int(top_k)
    if k < 1:
        raise ValueError(f"top_k must be >= 1, got {k}")

    # FAISS expects a float32 matrix of shape (n_queries, dim); we send one query.
    q = np.array(query_embedding, dtype="float32").reshape(1, -1)
    _, idx = index.search(q, k)
    return [int(i) for i in idx[0] if i >= 0]