"""DocumentRetriever — dense similarity over prose chunks. For unstructured sources only (PDF / DOCX / TXT). Backed by PGVector via raw SQL to avoid LangChain ORM / asyncpg type-mapping issues (id UUID vs String mismatch, jsonb_path_match asyncpg binding quirks). Collection `document_embeddings`. Methods: cosine | manhattan. """ import functools import math from langchain_openai import AzureOpenAIEmbeddings from sqlalchemy import text from src.config.settings import settings from src.db.postgres.connection import _pgvector_engine from src.middlewares.logging import get_logger from src.retrieval.base import BaseRetriever, RetrievalResult logger = get_logger("document_retriever") # Change this one line to switch retrieval method # Options: "cosine" | "manhattan" _RETRIEVAL_METHOD = "cosine" _TABULAR_TYPES = {"csv", "xlsx"} _COLLECTION_NAME = "documents" _COSINE_SQL = text(""" SELECT lpe.document, lpe.cmetadata, lpe.embedding <=> CAST(:embedding AS vector) AS distance FROM langchain_pg_embedding lpe JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid WHERE lpc.name = :collection AND lpe.cmetadata->>'user_id' = :user_id AND lpe.cmetadata->>'source_type' = 'document' ORDER BY distance ASC LIMIT :k """) _MANHATTAN_SQL = text(""" SELECT lpe.document, lpe.cmetadata, lpe.embedding <+> CAST(:embedding AS vector) AS distance FROM langchain_pg_embedding lpe JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid WHERE lpc.name = :collection AND lpe.cmetadata->>'user_id' = :user_id AND lpe.cmetadata->>'source_type' = 'document' ORDER BY distance ASC LIMIT :k """) @functools.cache def _get_embeddings() -> AzureOpenAIEmbeddings: return AzureOpenAIEmbeddings( azure_deployment=settings.azureai_deployment_name_embedding, openai_api_version=settings.azureai_api_version_embedding, azure_endpoint=settings.azureai_endpoint_url_embedding, api_key=settings.azureai_api_key_embedding, ) class DocumentRetriever(BaseRetriever): async def retrieve( self, query: str, user_id: str, k: int = 5 ) -> list[RetrievalResult]: query_vector = await _get_embeddings().aembed_query(query) if not all(math.isfinite(v) for v in query_vector): raise ValueError("Embedding vector contains NaN or Infinity values.") vector_str = "[" + ",".join(str(v) for v in query_vector) + "]" fetch_k = k + len(_TABULAR_TYPES) sql = _COSINE_SQL if _RETRIEVAL_METHOD == "cosine" else _MANHATTAN_SQL logger.info("retrieve called", user_id=user_id, collection=_COLLECTION_NAME, fetch_k=fetch_k) async with _pgvector_engine.connect() as conn: result = await conn.execute(sql, { "embedding": vector_str, "collection": _COLLECTION_NAME, "user_id": user_id, "k": fetch_k, }) rows = result.fetchall() logger.info("raw rows from db", row_count=len(rows)) results = [] for row in rows: file_type = row.cmetadata.get("data", {}).get("file_type", "") if file_type not in _TABULAR_TYPES: results.append(RetrievalResult( content=row.document, metadata=row.cmetadata, score=float(row.distance), source_type="document", )) if len(results) == k: break logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results)) return results document_retriever = DocumentRetriever()