"""DocumentRetriever — dense similarity search over prose chunks.

For unstructured sources only (PDF / DOCX / TXT). Backed by PGVector via
raw SQL to avoid LangChain ORM / asyncpg type-mapping issues (id UUID vs
String mismatch, jsonb_path_match asyncpg binding quirks).
Collection `document_embeddings`. Distance methods: cosine | manhattan.
"""
|
|
| import functools |
| import math |
|
|
| from langchain_openai import AzureOpenAIEmbeddings |
| from sqlalchemy import text |
|
|
| from src.config.settings import settings |
| from src.db.postgres.connection import _pgvector_engine |
| from src.middlewares.logging import get_logger |
| from src.retrieval.base import BaseRetriever, RetrievalResult |
|
|
| logger = get_logger("document_retriever") |
|
|
| |
| |
# Distance metric used by DocumentRetriever.retrieve: "cosine" selects
# _COSINE_SQL; any other value falls through to _MANHATTAN_SQL.
_RETRIEVAL_METHOD = "cosine"

# File types treated as tabular. Chunks whose metadata reports one of these
# are dropped from document results. Frozen so this shared module-level
# constant cannot be mutated by accident.
_TABULAR_TYPES = frozenset({"csv", "xlsx"})

# PGVector collection name bound to `:collection` in the similarity queries.
# NOTE(review): the module docstring names the collection
# `document_embeddings` — confirm which name is correct.
_COLLECTION_NAME = "documents"
|
|
# Shared body of the similarity queries. Only the pgvector distance operator
# differs between metrics, so both statements are rendered from this single
# template and cannot drift apart.
#
# Bind parameters:
#   :embedding  - query vector as a pgvector text literal, e.g. "[0.1,0.2]"
#   :collection - langchain_pg_collection.name to search
#   :user_id    - compared against cmetadata->>'user_id'
#   :k          - maximum number of rows to return
#
# `{distance_op}` is filled from the hard-coded operators below, never from
# caller input.
_SIMILARITY_SQL_TEMPLATE = """
    SELECT
        lpe.document,
        lpe.cmetadata,
        lpe.embedding {distance_op} CAST(:embedding AS vector) AS distance
    FROM langchain_pg_embedding lpe
    JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
    WHERE lpc.name = :collection
      AND lpe.cmetadata->>'user_id' = :user_id
      AND lpe.cmetadata->>'source_type' = 'document'
    ORDER BY distance ASC
    LIMIT :k
"""

# `<=>` is pgvector's cosine-distance operator.
_COSINE_SQL = text(_SIMILARITY_SQL_TEMPLATE.format(distance_op="<=>"))

# `<+>` is pgvector's L1 (taxicab / Manhattan) distance operator.
_MANHATTAN_SQL = text(_SIMILARITY_SQL_TEMPLATE.format(distance_op="<+>"))
|
|
|
|
@functools.cache
def _get_embeddings() -> AzureOpenAIEmbeddings:
    """Return the Azure OpenAI embeddings client, building it on first use.

    ``functools.cache`` memoises this zero-argument call, so every caller
    receives the same client instance. Connection details are read from
    ``settings`` at the moment the client is first constructed.
    """
    client_options = {
        "azure_deployment": settings.azureai_deployment_name_embedding,
        "openai_api_version": settings.azureai_api_version_embedding,
        "azure_endpoint": settings.azureai_endpoint_url_embedding,
        "api_key": settings.azureai_api_key_embedding,
    }
    return AzureOpenAIEmbeddings(**client_options)
|
|
|
|
class DocumentRetriever(BaseRetriever):
    """Dense similarity retriever over a user's prose document chunks.

    Embeds the query with the Azure OpenAI embeddings client and ranks rows
    of ``langchain_pg_embedding`` by pgvector distance using raw SQL.
    """

    async def retrieve(
        self, query: str, user_id: str, k: int = 5
    ) -> list[RetrievalResult]:
        """Return up to ``k`` non-tabular chunks closest to ``query``.

        Args:
            query: Text to embed and search with.
            user_id: Only rows whose metadata ``user_id`` matches are searched.
            k: Maximum number of results. Values below 1 yield an empty list.

        Returns:
            Results in ascending distance order. ``score`` carries the raw
            distance to the query, so smaller values are closer matches.

        Raises:
            ValueError: If the embedding contains NaN or Infinity values.
            Errors raised by the embeddings client or the database driver
            are not caught here and propagate unchanged.
        """
        # A non-positive k can never satisfy the stop condition in the loop
        # below (so surplus rows would be returned) and a sufficiently
        # negative one reaches PostgreSQL as a negative LIMIT. There is
        # nothing to fetch, so skip the embedding call as well.
        if k <= 0:
            return []

        query_vector = await _get_embeddings().aembed_query(query)
        if not all(math.isfinite(v) for v in query_vector):
            raise ValueError("Embedding vector contains NaN or Infinity values.")
        # Text form "[v1,v2,...]" that the SQL casts with CAST(... AS vector).
        vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"

        # Over-fetch slightly, presumably as headroom for rows dropped by the
        # tabular filter below.
        # NOTE(review): the margin is the number of tabular *types*, not of
        # tabular rows; if more rows than that are filtered out, fewer than
        # k results are returned.
        fetch_k = k + len(_TABULAR_TYPES)

        sql = _COSINE_SQL if _RETRIEVAL_METHOD == "cosine" else _MANHATTAN_SQL

        logger.info(
            "retrieve called",
            user_id=user_id,
            collection=_COLLECTION_NAME,
            fetch_k=fetch_k,
        )

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(
                sql,
                {
                    "embedding": vector_str,
                    "collection": _COLLECTION_NAME,
                    "user_id": user_id,
                    "k": fetch_k,
                },
            )
            rows = result.fetchall()

        logger.info("raw rows from db", row_count=len(rows))

        results: list[RetrievalResult] = []
        for row in rows:
            if self._is_tabular(row.cmetadata):
                continue
            results.append(
                RetrievalResult(
                    content=row.document,
                    metadata=row.cmetadata,
                    score=float(row.distance),
                    source_type="document",
                )
            )
            if len(results) >= k:
                break

        logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
        return results

    @staticmethod
    def _is_tabular(metadata: dict) -> bool:
        """Return True when ``metadata["data"]["file_type"]`` is a tabular type.

        A missing ``data`` entry, a ``data`` value that is not a dict (for
        example JSON ``null``), or a non-string ``file_type`` counts as
        non-tabular instead of raising on the nested lookup.
        """
        data = metadata.get("data")
        if not isinstance(data, dict):
            return False
        file_type = data.get("file_type")
        return isinstance(file_type, str) and file_type in _TABULAR_TYPES
|
|
|
|
# Module-level retriever instance, created when this module is imported.
document_retriever = DocumentRetriever()
|
|