"""Document retriever — handles PDF, DOCX, TXT chunks (source_type="document", non-tabular)."""

from collections.abc import Iterable
from typing import Any

from langchain_postgres import PGVector
from langchain_postgres.vectorstores import DistanceStrategy
from langchain_openai import AzureOpenAIEmbeddings
from sqlalchemy import text

from src.config.settings import settings
from src.db.postgres.connection import _pgvector_engine
from src.db.postgres.vector_store import get_vector_store
from src.middlewares.logging import get_logger
from src.rag.base import BaseRetriever, RetrievalResult

logger = get_logger("document_retriever")

# Change this one line to switch retrieval method.
# Options: "mmr" | "cosine" | "euclidean" | "inner_product" | "manhattan"
# Any other value falls through to the default-store ("cosine") branch.
_RETRIEVAL_METHOD = "mmr"

# File types whose chunks are dropped from the results of this retriever.
_TABULAR_TYPES = {"csv", "xlsx"}

# Minimum size of the MMR candidate pool, and the MMR relevance/diversity
# trade-off passed as ``lambda_mult``.
_FETCH_K = 20
_LAMBDA_MULT = 0.5

_COLLECTION_NAME = "document_embeddings"

_embeddings = AzureOpenAIEmbeddings(
    azure_deployment=settings.azureai_deployment_name_embedding,
    openai_api_version=settings.azureai_api_version_embedding,
    azure_endpoint=settings.azureai_endpoint_url_embedding,
    api_key=settings.azureai_api_key_embedding,
)

# Extra stores over the same collection; they differ only in distance strategy.
_euclidean_store = PGVector(
    embeddings=_embeddings,
    connection=_pgvector_engine,
    collection_name=_COLLECTION_NAME,
    distance_strategy=DistanceStrategy.EUCLIDEAN,
    use_jsonb=True,
    async_mode=True,
    create_extension=False,
)

_ip_store = PGVector(
    embeddings=_embeddings,
    connection=_pgvector_engine,
    collection_name=_COLLECTION_NAME,
    distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
    use_jsonb=True,
    async_mode=True,
    create_extension=False,
)

# Raw SQL for the "manhattan" method: orders rows by the `<+>` operator and
# applies the same user_id / source_type restriction as the other methods.
_MANHATTAN_SQL = text("""
    SELECT
        lpe.document,
        lpe.cmetadata,
        lpe.embedding <+> CAST(:embedding AS vector) AS distance
    FROM langchain_pg_embedding lpe
    JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
    WHERE lpc.name = :collection
      AND lpe.cmetadata->>'user_id' = :user_id
      AND lpe.cmetadata->>'source_type' = 'document'
    ORDER BY distance ASC
    LIMIT :k
""")


def _file_type_of(metadata: Any) -> str:
    """Return ``metadata["data"]["file_type"]``, or ``""`` if any level is absent.

    Tolerates ``None`` / non-dict values at either level so a chunk with
    incomplete metadata is treated as non-tabular instead of raising.
    """
    if not isinstance(metadata, dict):
        return ""
    data = metadata.get("data")
    if not isinstance(data, dict):
        return ""
    return data.get("file_type", "")


def _build_results(
    candidates: Iterable[tuple[str, Any, Any]], k: int
) -> list[RetrievalResult]:
    """Convert ``(content, metadata, score)`` triples into at most *k* results.

    Candidates whose file type is in ``_TABULAR_TYPES`` are skipped; input
    order is preserved. Returns an empty list when ``k <= 0``.
    """
    results: list[RetrievalResult] = []
    for content, metadata, score in candidates:
        # Checked before appending so that k <= 0 yields nothing.
        if len(results) >= k:
            break
        if _file_type_of(metadata) in _TABULAR_TYPES:
            continue
        results.append(RetrievalResult(
            content=content,
            metadata=metadata,
            score=score,
            source_type="document",
        ))
    return results


class DocumentRetriever(BaseRetriever):
    """Retrieve non-tabular document chunks for a user from the vector store."""

    def __init__(self) -> None:
        self.vector_store = get_vector_store()

    async def retrieve(
        self, query: str, user_id: str, k: int = 5
    ) -> list[RetrievalResult]:
        """Return up to *k* document chunks relevant to *query* for *user_id*.

        The search strategy is selected by the module-level
        ``_RETRIEVAL_METHOD``. ``k + len(_TABULAR_TYPES)`` rows are requested
        so that a few tabular chunks can be discarded without falling short.

        Args:
            query: Natural-language query text.
            user_id: Only chunks whose metadata ``user_id`` matches are searched.
            k: Maximum number of results; values <= 0 yield an empty list.

        Returns:
            Results in the order produced by the search. ``score`` is the value
            reported by the store (presumably a distance, lower = closer —
            confirm against consumers of ``RetrievalResult``).
        """
        if k <= 0:
            # Nothing requested; also avoids passing a non-positive limit on.
            return []

        filter_ = {"user_id": user_id, "source_type": "document"}
        fetch_k = k + len(_TABULAR_TYPES)

        if _RETRIEVAL_METHOD == "manhattan":
            return await self._retrieve_manhattan(query, user_id, k, fetch_k)

        if _RETRIEVAL_METHOD == "mmr":
            # MMR can only pick from its candidate pool, so the pool must be
            # at least as large as the number of documents requested.
            pool_size = max(_FETCH_K, fetch_k)
            docs = await self.vector_store.amax_marginal_relevance_search(
                query=query,
                k=fetch_k,
                fetch_k=pool_size,
                lambda_mult=_LAMBDA_MULT,
                filter=filter_,
            )
            # The MMR call returns no scores. Look them up over the same pool
            # size, so every MMR pick is expected to have a real score rather
            # than the fallback below.
            scored = await self.vector_store.asimilarity_search_with_score(
                query=query, k=pool_size, filter=filter_,
            )
            score_by_content = {doc.page_content: score for doc, score in scored}
            docs_with_scores = [
                # 0.0 is kept as a last-resort fallback for an unmatched doc.
                (doc, score_by_content.get(doc.page_content, 0.0))
                for doc in docs
            ]
        elif _RETRIEVAL_METHOD == "euclidean":
            docs_with_scores = await _euclidean_store.asimilarity_search_with_score(
                query=query, k=fetch_k, filter=filter_,
            )
        elif _RETRIEVAL_METHOD == "inner_product":
            docs_with_scores = await _ip_store.asimilarity_search_with_score(
                query=query, k=fetch_k, filter=filter_,
            )
        else:  # cosine
            docs_with_scores = await self.vector_store.asimilarity_search_with_score(
                query=query, k=fetch_k, filter=filter_,
            )

        results = _build_results(
            ((doc.page_content, doc.metadata, score) for doc, score in docs_with_scores),
            k,
        )
        logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
        return results

    async def _retrieve_manhattan(
        self, query: str, user_id: str, k: int, fetch_k: int
    ) -> list[RetrievalResult]:
        """Run the raw-SQL ``<+>`` search and return up to *k* non-tabular chunks.

        Args:
            query: Query text; embedded with the module-level ``_embeddings``.
            user_id: Bound to the ``user_id`` metadata filter in the SQL.
            k: Maximum number of results to return.
            fetch_k: Row limit sent to the database (over-fetch for filtering).

        Returns:
            Results ordered by ascending distance, with ``score`` set to that
            distance as a float.
        """
        query_vector = await _embeddings.aembed_query(query)
        # Text form of the vector; the SQL casts the bound string to `vector`.
        vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(_MANHATTAN_SQL, {
                "embedding": vector_str,
                "collection": _COLLECTION_NAME,
                "user_id": user_id,
                "k": fetch_k,
            })
            # Read the rows before the connection is released.
            rows = result.fetchall()

        results = _build_results(
            ((row.document, row.cmetadata, float(row.distance)) for row in rows),
            k,
        )
        logger.info("retrieved chunks", method="manhattan", count=len(results))
        return results


document_retriever = DocumentRetriever()