"""Vector-store retrieval helpers.

Builds a persistent Chroma store over OpenAI embeddings and exposes
retrievers for three strategies: plain similarity search, MMR, and a
"hybrid" mode that merges dense similarity hits with MMR hits and
deduplicates them.
"""

from typing import List, Literal

from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

# The retrieval strategies accepted by get_retriever().
RetrievalMode = Literal["mmr", "similarity", "hybrid"]


def get_vectorstore(persist_dir: str) -> Chroma:
    """Open (or create) the persistent Chroma store at *persist_dir*.

    Args:
        persist_dir: Directory where the Chroma collection is persisted.

    Returns:
        A Chroma vector store backed by OpenAIEmbeddings.
    """
    embeddings = OpenAIEmbeddings()
    return Chroma(
        persist_directory=persist_dir,
        embedding_function=embeddings,
    )


class HybridRetriever(BaseRetriever):
    """Retriever that merges dense similarity results with MMR results.

    Dense similarity search is over-fetched (2x ``top_k``) and combined
    with an MMR pass; results are deduplicated on the
    ``(metadata["source"], page_content)`` pair and the first ``top_k``
    unique documents are returned, dense hits first.
    """

    db: Chroma  # backing vector store to query
    top_k: int  # maximum number of documents to return

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Return up to ``top_k`` unique documents relevant to *query*."""
        dense = self.db.similarity_search(query, k=self.top_k * 2)
        mmr = self.db.max_marginal_relevance_search(
            query,
            k=self.top_k,
            fetch_k=self.top_k * 3,
        )

        docs: List[Document] = []
        seen = set()
        # Dense hits take precedence; MMR hits fill in diverse extras.
        for doc in dense + mmr:
            key = (doc.metadata.get("source"), doc.page_content)
            if key in seen:
                continue
            seen.add(key)
            docs.append(doc)
            if len(docs) >= self.top_k:
                break
        return docs


def get_retriever(
    persist_dir: str,
    top_k: int,
    retrieval_mode: RetrievalMode = "hybrid",
) -> BaseRetriever:
    """Build a retriever over the Chroma store at *persist_dir*.

    Args:
        persist_dir: Directory of the persisted Chroma collection.
        top_k: Number of documents each query should return.
        retrieval_mode: One of ``"hybrid"``, ``"similarity"``, or
            ``"mmr"`` (case-insensitive).

    Returns:
        A retriever configured for the requested mode.

    Raises:
        ValueError: If *retrieval_mode* is not a recognized mode.
            (Previously an unknown mode silently fell back to MMR,
            masking caller typos.)
    """
    db = get_vectorstore(persist_dir=persist_dir)
    mode = retrieval_mode.lower()

    if mode == "hybrid":
        return HybridRetriever(db=db, top_k=top_k)
    if mode == "similarity":
        return db.as_retriever(
            search_type="similarity",
            search_kwargs={"k": top_k},
        )
    if mode == "mmr":
        return db.as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": top_k,
                # fetch_k must exceed k so MMR has candidates to diversify.
                "fetch_k": max(top_k * 3, top_k + 2),
            },
        )
    raise ValueError(
        f"Unknown retrieval_mode {retrieval_mode!r}; "
        "expected 'hybrid', 'similarity', or 'mmr'."
    )