"""Vector store construction and retrieval helpers (src/vectorstore.py)."""
from typing import Literal, List
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
# Retrieval strategies accepted by get_retriever(): plain similarity search,
# maximal-marginal-relevance re-ranking, or a fused "hybrid" of the two.
RetrievalMode = Literal["mmr", "similarity", "hybrid"]
def get_vectorstore(persist_dir: str) -> Chroma:
    """Open (or create) the persisted Chroma collection at *persist_dir*.

    Documents are embedded with ``OpenAIEmbeddings``; the store is backed
    by the on-disk directory so repeated calls reuse the same collection.
    """
    return Chroma(
        persist_directory=persist_dir,
        embedding_function=OpenAIEmbeddings(),
    )
class HybridRetriever(BaseRetriever):
    """Retriever that fuses plain similarity search with MMR re-ranking.

    Candidates from both strategies are interleaved round-robin and then
    de-duplicated, so the final top-k mixes the highest-similarity hits
    with the more diverse MMR picks.
    """

    # Backing Chroma vector store to query.
    db: Chroma
    # Number of documents to return from each retrieval.
    top_k: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Return up to ``top_k`` unique documents relevant to *query*.

        Fetches an over-sized candidate pool from each strategy, merges
        the two ranked lists alternately, and drops duplicates keyed on
        (source, page_content).
        """
        dense = self.db.similarity_search(query, k=self.top_k * 2)
        mmr = self.db.max_marginal_relevance_search(
            query,
            k=self.top_k,
            fetch_k=self.top_k * 3,
        )

        # BUG FIX: the previous implementation appended all dense hits
        # first and stopped at top_k, so the MMR results were unreachable
        # whenever the dense pool (2 * top_k docs) alone filled the quota
        # — "hybrid" silently degenerated to similarity search.  A
        # round-robin interleave lets both strategies contribute.
        interleaved: List[Document] = []
        for i in range(max(len(dense), len(mmr))):
            if i < len(dense):
                interleaved.append(dense[i])
            if i < len(mmr):
                interleaved.append(mmr[i])

        docs: List[Document] = []
        seen = set()
        for d in interleaved:
            key = (d.metadata.get("source"), d.page_content)
            if key in seen:
                continue  # same chunk surfaced by both strategies
            seen.add(key)
            docs.append(d)
            if len(docs) >= self.top_k:
                break
        return docs
def get_retriever(
    persist_dir: str,
    top_k: int,
    retrieval_mode: RetrievalMode = "hybrid"
):
    """Build a retriever over the Chroma store persisted at *persist_dir*.

    Args:
        persist_dir: directory holding the persisted Chroma collection.
        top_k: number of documents each retrieval should return.
        retrieval_mode: "hybrid" (default), "similarity", or "mmr".
            Matching is case-insensitive; any other value falls through
            to the MMR configuration.
    """
    store = get_vectorstore(persist_dir=persist_dir)
    normalized = retrieval_mode.lower()

    if normalized == "hybrid":
        return HybridRetriever(db=store, top_k=top_k)

    if normalized == "similarity":
        search_type = "similarity"
        search_kwargs = {"k": top_k}
    else:
        # MMR needs a larger candidate pool than it returns; keep the
        # fetch size strictly above k even for tiny top_k values.
        search_type = "mmr"
        search_kwargs = {
            "k": top_k,
            "fetch_k": max(top_k * 3, top_k + 2),
        }
    return store.as_retriever(search_type=search_type, search_kwargs=search_kwargs)