| """Document retriever — handles PDF, DOCX, TXT chunks (source_type="document", non-tabular).""" |
|
|
| from langchain_postgres import PGVector |
| from langchain_postgres.vectorstores import DistanceStrategy |
| from langchain_openai import AzureOpenAIEmbeddings |
| from sqlalchemy import text |
|
|
| from src.config.settings import settings |
| from src.db.postgres.connection import _pgvector_engine |
| from src.db.postgres.vector_store import get_vector_store |
| from src.middlewares.logging import get_logger |
| from src.rag.base import BaseRetriever, RetrievalResult |
|
|
| logger = get_logger("document_retriever") |
|
|
| |
| |
# Similarity strategy switch: "mmr", "euclidean", "inner_product",
# "manhattan"; any other value falls through to the default store's
# similarity search (cosine).
_RETRIEVAL_METHOD = "mmr"


# File types that share source_type="document" in the store but are
# tabular and therefore excluded from this retriever's results.
_TABULAR_TYPES = {"csv", "xlsx"}
# Size of the candidate pool MMR re-ranks over.
_FETCH_K = 20
# MMR relevance/diversity trade-off (1.0 = pure relevance, 0.0 = max diversity).
_LAMBDA_MULT = 0.5
# pgvector collection (langchain_pg_collection.name) holding document embeddings.
_COLLECTION_NAME = "document_embeddings"
|
|
# Shared Azure OpenAI embedding client, configured from application settings.
# Used both by the PGVector stores below and directly for the raw-SQL
# Manhattan path (which embeds the query itself).
_embeddings = AzureOpenAIEmbeddings(
    azure_deployment=settings.azureai_deployment_name_embedding,
    openai_api_version=settings.azureai_api_version_embedding,
    azure_endpoint=settings.azureai_endpoint_url_embedding,
    api_key=settings.azureai_api_key_embedding,
)
|
|
# Alternate view of the same collection scored by Euclidean (L2) distance.
# create_extension=False: the pgvector extension is assumed to already exist.
_euclidean_store = PGVector(
    embeddings=_embeddings,
    connection=_pgvector_engine,
    collection_name=_COLLECTION_NAME,
    distance_strategy=DistanceStrategy.EUCLIDEAN,
    use_jsonb=True,
    async_mode=True,
    create_extension=False,
)
|
|
# Alternate view of the same collection scored by (max) inner product.
_ip_store = PGVector(
    embeddings=_embeddings,
    connection=_pgvector_engine,
    collection_name=_COLLECTION_NAME,
    distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
    use_jsonb=True,
    async_mode=True,
    create_extension=False,
)
|
|
# Raw query for Manhattan (L1) ranking: `<+>` is pgvector's L1-distance
# operator (available in pgvector >= 0.7.0), which the LangChain PGVector
# wrapper does not expose as a DistanceStrategy. Filters mirror the
# metadata filter used on the store paths (user_id + source_type).
_MANHATTAN_SQL = text("""
SELECT
lpe.document,
lpe.cmetadata,
lpe.embedding <+> CAST(:embedding AS vector) AS distance
FROM langchain_pg_embedding lpe
JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
WHERE lpc.name = :collection
AND lpe.cmetadata->>'user_id' = :user_id
AND lpe.cmetadata->>'source_type' = 'document'
ORDER BY distance ASC
LIMIT :k
""")
|
|
|
|
class DocumentRetriever(BaseRetriever):
    """Retrieves non-tabular document chunks (source_type="document") for a user.

    The similarity metric is chosen by the module-level ``_RETRIEVAL_METHOD``
    switch. Results are post-filtered to drop tabular file types
    (``_TABULAR_TYPES``), which are served by a separate retriever.
    """

    def __init__(self) -> None:
        self.vector_store = get_vector_store()

    async def retrieve(
        self, query: str, user_id: str, k: int = 5
    ) -> list[RetrievalResult]:
        """Return up to *k* document chunks relevant to *query*.

        Args:
            query: Natural-language search query.
            user_id: Owner of the documents, enforced via the metadata filter.
            k: Maximum number of results to return.
        """
        filter_ = {"user_id": user_id, "source_type": "document"}
        # Over-fetch slightly so dropping tabular hits can still leave k results.
        fetch_k = k + len(_TABULAR_TYPES)

        if _RETRIEVAL_METHOD == "manhattan":
            return await self._retrieve_manhattan(query, user_id, k, fetch_k)

        if _RETRIEVAL_METHOD == "mmr":
            docs = await self.vector_store.amax_marginal_relevance_search(
                query=query,
                k=fetch_k,
                fetch_k=_FETCH_K,
                lambda_mult=_LAMBDA_MULT,
                filter=filter_,
            )
            # MMR returns no scores, so recover them from a plain similarity
            # search. Query the full _FETCH_K candidate pool: MMR selects
            # from the top _FETCH_K candidates, so a smaller lookup (the
            # previous k=fetch_k) missed most of them and silently assigned
            # score 0.0 to the missing docs.
            scored = await self.vector_store.asimilarity_search_with_score(
                query=query, k=_FETCH_K, filter=filter_,
            )
            # NOTE: keyed on page_content — duplicate chunk texts collide,
            # which is acceptable since identical text scores (near-)identically.
            score_map = {doc.page_content: score for doc, score in scored}
            docs_with_scores = [
                (doc, score_map.get(doc.page_content, 0.0)) for doc in docs
            ]
        elif _RETRIEVAL_METHOD == "euclidean":
            docs_with_scores = await _euclidean_store.asimilarity_search_with_score(
                query=query, k=fetch_k, filter=filter_,
            )
        elif _RETRIEVAL_METHOD == "inner_product":
            docs_with_scores = await _ip_store.asimilarity_search_with_score(
                query=query, k=fetch_k, filter=filter_,
            )
        else:
            # Fallback: the default store's native distance (cosine).
            docs_with_scores = await self.vector_store.asimilarity_search_with_score(
                query=query, k=fetch_k, filter=filter_,
            )

        results = self._build_results(
            (
                (doc.page_content, doc.metadata, score)
                for doc, score in docs_with_scores
            ),
            k,
        )
        logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
        return results

    async def _retrieve_manhattan(
        self, query: str, user_id: str, k: int, fetch_k: int
    ) -> list[RetrievalResult]:
        """Raw-SQL path using pgvector's ``<+>`` (L1/Manhattan) operator,
        which the LangChain PGVector store does not expose as a strategy."""
        query_vector = await _embeddings.aembed_query(query)
        # pgvector accepts the textual "[v1,v2,...]" literal for the CAST.
        vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(_MANHATTAN_SQL, {
                "embedding": vector_str,
                "collection": _COLLECTION_NAME,
                "user_id": user_id,
                "k": fetch_k,
            })
            rows = result.fetchall()

        results = self._build_results(
            ((row.document, row.cmetadata, float(row.distance)) for row in rows),
            k,
        )
        logger.info("retrieved chunks", method="manhattan", count=len(results))
        return results

    @staticmethod
    def _build_results(candidates, k: int) -> list[RetrievalResult]:
        """Turn (content, metadata, score) tuples into RetrievalResults,
        skipping tabular file types and stopping once *k* results are kept.

        Shared by the vector-store and raw-SQL paths, which previously
        duplicated this filtering loop verbatim.
        """
        results: list[RetrievalResult] = []
        for content, metadata, score in candidates:
            file_type = metadata.get("data", {}).get("file_type", "")
            if file_type in _TABULAR_TYPES:
                continue
            results.append(RetrievalResult(
                content=content,
                metadata=metadata,
                score=score,
                source_type="document",
            ))
            if len(results) == k:
                break
        return results
|
|
|
|
# Module-level singleton instance shared by the application.
document_retriever = DocumentRetriever()
|
|