Test / src /rag /retriever.py
Архипов Дмитрий
test
565e754
raw
history blame contribute delete
829 Bytes
from langchain_core.runnables import RunnableLambda
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from src.db_utils.qdrant_utils import qdrant_search
class Retriever:
    """Embed a text query and fetch matching documents via ``qdrant_search``.

    The retrieval step is wrapped in a LangChain ``RunnableLambda`` and exposed
    as ``self.chain``, so it can be composed into a larger runnable pipeline
    (input: query string, output: numbered context string).
    """

    def __init__(self, embed_model_name: str, embed_index_name: str):
        """Build the embedding model and the runnable wrapper.

        Args:
            embed_model_name: Model name handed to ``HuggingFaceEmbeddings``
                as ``model_name``.
            embed_index_name: First argument passed to ``qdrant_search``
                (presumably the Qdrant collection name — confirm in
                ``src.db_utils.qdrant_utils``).
        """
        # Ask the encoder to normalize every embedding it produces.
        encode_options = {"normalize_embeddings": True}
        self.embed_model = HuggingFaceEmbeddings(
            model_name=embed_model_name,
            encode_kwargs=encode_options,
        )
        self.embed_index_name = embed_index_name
        # Runnable view of the retrieval step: str query -> str context.
        self.chain = RunnableLambda(self._retrieve)

    def _retrieve(self, query: str) -> str:
        """Return the search hits for *query* as a numbered, newline-joined string.

        Each hit is rendered as ``"<rank>) <text>"`` with ranks starting at 1,
        in the order returned by ``qdrant_search``. Returns ``""`` when the
        search yields no points.

        Raises:
            KeyError: if a returned point's payload has no ``'text'`` key.
        """
        index_name = self.embed_index_name
        query_vector = self.embed_model.embed_query(query)
        response = qdrant_search(index_name, query_vector)

        numbered_lines = [
            f"{rank}) {point.payload['text']}"
            for rank, point in enumerate(response.points, start=1)
        ]
        return "\n".join(numbered_lines)