| from typing import Literal, List | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_community.embeddings import OpenAIEmbeddings | |
| from langchain_core.retrievers import BaseRetriever | |
| from langchain_core.callbacks import CallbackManagerForRetrieverRun | |
| from langchain_core.documents import Document | |
# Closed set of supported retrieval strategies accepted by get_retriever().
RetrievalMode = Literal["mmr", "similarity", "hybrid"]
def get_vectorstore(persist_dir: str) -> Chroma:
    """Open (or create) the Chroma vector store persisted at *persist_dir*.

    Uses OpenAI embeddings as the store's embedding function; credentials
    are resolved by OpenAIEmbeddings itself (e.g. from the environment).
    """
    return Chroma(
        persist_directory=persist_dir,
        embedding_function=OpenAIEmbeddings(),
    )
class HybridRetriever(BaseRetriever):
    """Retriever that merges dense similarity hits with MMR hits.

    The two result lists are interleaved round-robin and de-duplicated on
    (source, page_content) before truncating to ``top_k``, so that MMR's
    diversity-oriented picks actually appear in the final result set.
    """

    db: Chroma  # backing Chroma vector store
    top_k: int  # number of documents to return per query

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Return up to ``top_k`` unique documents for *query*."""
        dense = self.db.similarity_search(query, k=self.top_k * 2)
        mmr = self.db.max_marginal_relevance_search(
            query,
            k=self.top_k,
            fetch_k=self.top_k * 3,
        )

        # BUG FIX: the previous code scanned `dense + mmr` front-to-back and
        # stopped at top_k. Since `dense` alone holds 2*top_k documents, the
        # MMR results could never reach the output (barring duplicates),
        # making the "hybrid" mode pure similarity search. Interleave the two
        # lists so both strategies contribute to the final ranking.
        interleaved: List[Document] = []
        for i in range(max(len(dense), len(mmr))):
            if i < len(dense):
                interleaved.append(dense[i])
            if i < len(mmr):
                interleaved.append(mmr[i])

        docs: List[Document] = []
        seen = set()
        for d in interleaved:
            # The same chunk can come back from both searches; de-duplicate
            # on its origin plus exact content.
            key = (d.metadata.get("source"), d.page_content)
            if key in seen:
                continue
            seen.add(key)
            docs.append(d)
            if len(docs) >= self.top_k:
                break
        return docs
def get_retriever(
    persist_dir: str,
    top_k: int,
    retrieval_mode: RetrievalMode = "hybrid",
):
    """Build a retriever over the Chroma store persisted at *persist_dir*.

    Args:
        persist_dir: Directory of the persisted Chroma collection.
        top_k: Number of documents each query should return.
        retrieval_mode: ``"hybrid"`` (dense + MMR merge), ``"similarity"``,
            or ``"mmr"``. Matched case-insensitively.

    Returns:
        A LangChain retriever configured for the requested mode.

    Raises:
        ValueError: If *retrieval_mode* is not one of the supported modes.
    """
    db = get_vectorstore(persist_dir=persist_dir)
    mode = retrieval_mode.lower()
    if mode == "hybrid":
        return HybridRetriever(db=db, top_k=top_k)
    if mode == "similarity":
        return db.as_retriever(
            search_type="similarity",
            search_kwargs={"k": top_k},
        )
    if mode == "mmr":
        return db.as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": top_k,
                # fetch_k must exceed k so MMR has extra candidates to
                # diversify over; top_k + 2 guards very small top_k values.
                "fetch_k": max(top_k * 3, top_k + 2),
            },
        )
    # Previously any unrecognized mode silently fell through to the MMR
    # branch, masking caller typos; fail loudly instead.
    raise ValueError(f"Unknown retrieval_mode: {retrieval_mode!r}")