ishaq101's picture
[KM-582][DED][AI] Fix Retrieval in Agentic Service
61c746f
"""DocumentRetriever — dense similarity over prose chunks.
For unstructured sources only (PDF / DOCX / TXT). Backed by PGVector via
raw SQL to avoid LangChain ORM / asyncpg type-mapping issues (id UUID vs
String mismatch, jsonb_path_match asyncpg binding quirks).
Collection `documents` (see `_COLLECTION_NAME`). Methods: cosine | manhattan.
"""
import functools
import math
from langchain_openai import AzureOpenAIEmbeddings
from sqlalchemy import text
from src.config.settings import settings
from src.db.postgres.connection import _pgvector_engine
from src.middlewares.logging import get_logger
from src.retrieval.base import BaseRetriever, RetrievalResult
# Logger for this module, obtained under the name "document_retriever".
logger = get_logger("document_retriever")
# Change this one line to switch retrieval method
# Options: "cosine" | "manhattan"
# NOTE: retrieve() only checks for "cosine"; any other value selects the
# Manhattan query.
_RETRIEVAL_METHOD = "cosine"
# File types treated as tabular. retrieve() drops rows whose
# cmetadata["data"]["file_type"] is one of these, and over-fetches by
# len(_TABULAR_TYPES) rows to leave some room for that filtering.
_TABULAR_TYPES = {"csv", "xlsx"}
# Value bound to :collection, i.e. matched against langchain_pg_collection.name.
_COLLECTION_NAME = "documents"
# Nearest-neighbour query ranked with pgvector's `<=>` operator (cosine
# distance per the pgvector documentation).
# Bind parameters:
#   :embedding  - text form of the query vector, cast to `vector` in SQL
#   :collection - collection name matched against langchain_pg_collection.name
#   :user_id    - compared with the text value of cmetadata->>'user_id'
#   :k          - LIMIT on the number of rows returned
# Only rows whose cmetadata source_type is 'document' are considered, and rows
# are returned in ascending distance order (closest first).
_COSINE_SQL = text("""
SELECT
lpe.document,
lpe.cmetadata,
lpe.embedding <=> CAST(:embedding AS vector) AS distance
FROM langchain_pg_embedding lpe
JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
WHERE lpc.name = :collection
AND lpe.cmetadata->>'user_id' = :user_id
AND lpe.cmetadata->>'source_type' = 'document'
ORDER BY distance ASC
LIMIT :k
""")
# Same query shape as the cosine variant, but ranked with pgvector's `<+>`
# operator (L1 / taxicab distance per the pgvector documentation).
# NOTE(review): `<+>` is only available in newer pgvector releases - confirm
# the deployed extension version supports it before switching methods.
# Bind parameters: :embedding, :collection, :user_id, :k (row limit).
# Rows are returned in ascending distance order (closest first).
_MANHATTAN_SQL = text("""
SELECT
lpe.document,
lpe.cmetadata,
lpe.embedding <+> CAST(:embedding AS vector) AS distance
FROM langchain_pg_embedding lpe
JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
WHERE lpc.name = :collection
AND lpe.cmetadata->>'user_id' = :user_id
AND lpe.cmetadata->>'source_type' = 'document'
ORDER BY distance ASC
LIMIT :k
""")
@functools.cache
def _get_embeddings() -> AzureOpenAIEmbeddings:
    """Build the Azure OpenAI embeddings client from application settings.

    The function takes no arguments and is wrapped in ``functools.cache``, so
    the client is constructed on the first call and the same object is handed
    back on every later call.
    """
    client_options = {
        "azure_deployment": settings.azureai_deployment_name_embedding,
        "openai_api_version": settings.azureai_api_version_embedding,
        "azure_endpoint": settings.azureai_endpoint_url_embedding,
        "api_key": settings.azureai_api_key_embedding,
    }
    return AzureOpenAIEmbeddings(**client_options)
class DocumentRetriever(BaseRetriever):
    """Retrieve prose document chunks by vector distance to a query embedding.

    Runs one of the module's raw-SQL distance queries (cosine or Manhattan,
    selected by ``_RETRIEVAL_METHOD``), restricted to the caller's ``user_id``,
    and drops chunks whose metadata marks them as a tabular file type.
    """

    @staticmethod
    def _file_type_of(metadata) -> str:
        """Return ``metadata["data"]["file_type"]``, or ``""`` when unavailable."""
        data = metadata.get("data") if isinstance(metadata, dict) else None
        if not isinstance(data, dict):
            # Covers a missing key as well as an explicit JSON null or scalar,
            # which a plain `.get("data", {})` default does not protect against.
            return ""
        file_type = data.get("file_type")
        return file_type if isinstance(file_type, str) else ""

    async def retrieve(
        self, query: str, user_id: str, k: int = 5
    ) -> list[RetrievalResult]:
        """Return up to ``k`` non-tabular chunks closest to ``query``.

        Args:
            query: Text that is embedded and used as the search vector.
            user_id: Only rows whose ``cmetadata->>'user_id'`` equals this
                value are searched.
            k: Maximum number of results. A value of zero or less returns an
                empty list without calling the embedding service or database.

        Returns:
            Results in ascending distance order. ``score`` carries the raw
            distance computed by the SQL query, so smaller means closer.
            Fewer than ``k`` items may be returned.

        Raises:
            ValueError: If the query embedding contains NaN or infinity.
        """
        if k <= 0:
            # Nothing was requested. Also avoids sending a negative LIMIT to
            # Postgres and returning rows when k == 0.
            return []

        query_vector = await _get_embeddings().aembed_query(query)
        if not all(math.isfinite(v) for v in query_vector):
            raise ValueError("Embedding vector contains NaN or Infinity values.")

        # Text form of the vector; the SQL casts it with CAST(:embedding AS vector).
        vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"

        # Over-fetch slightly so rows dropped by the tabular filter below do not
        # necessarily shrink the result under k. The margin is small, so fewer
        # than k results are still possible.
        fetch_k = k + len(_TABULAR_TYPES)

        # Any value other than "cosine" selects the Manhattan query.
        sql = _COSINE_SQL if _RETRIEVAL_METHOD == "cosine" else _MANHATTAN_SQL

        logger.info("retrieve called", user_id=user_id, collection=_COLLECTION_NAME, fetch_k=fetch_k)
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {
                "embedding": vector_str,
                "collection": _COLLECTION_NAME,
                "user_id": user_id,
                "k": fetch_k,
            })
            rows = result.fetchall()
        logger.info("raw rows from db", row_count=len(rows))

        results: list[RetrievalResult] = []
        for row in rows:
            if self._file_type_of(row.cmetadata) in _TABULAR_TYPES:
                continue
            results.append(RetrievalResult(
                content=row.document,
                metadata=row.cmetadata,
                score=float(row.distance),
                source_type="document",
            ))
            if len(results) >= k:
                break

        logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
        return results
# Module-level DocumentRetriever instance, created when this module is imported.
document_retriever = DocumentRetriever()