"""DocumentRetriever — dense similarity over prose chunks.

For unstructured sources only (PDF / DOCX / TXT); tabular chunks (csv / xlsx)
are filtered out. Backed by PGVector via raw SQL to avoid LangChain ORM /
asyncpg type-mapping issues (id UUID vs String mismatch, jsonb_path_match
asyncpg binding quirks).
Collection `documents`. Methods: cosine | manhattan.
"""
import functools
import math
from langchain_openai import AzureOpenAIEmbeddings
from sqlalchemy import text
from src.config.settings import settings
from src.db.postgres.connection import _pgvector_engine
from src.middlewares.logging import get_logger
from src.retrieval.base import BaseRetriever, RetrievalResult
logger = get_logger("document_retriever")

# Change this one line to switch retrieval method
# Options: "cosine" | "manhattan"
_RETRIEVAL_METHOD = "cosine"

# Chunks whose cmetadata.data.file_type is one of these are never returned.
_TABULAR_TYPES = {"csv", "xlsx"}
_COLLECTION_NAME = "documents"


def _build_similarity_sql(distance_operator: str):
    """Build the top-k similarity query for a pgvector distance operator.

    Both retrieval methods share every clause except the operator, so the
    query comes from one template to keep the filters in sync.

    Tabular chunks (``cmetadata.data.file_type`` in ``_TABULAR_TYPES``) are
    excluded in SQL, *before* ``LIMIT``. Filtering them afterwards lets them
    crowd prose chunks out of the top-k window. A missing ``file_type`` is
    treated as ``''``, so such chunks are kept.

    Args:
        distance_operator: pgvector operator placed between the stored
            embedding and the query vector (e.g. ``<=>`` or ``<+>``).

    Returns:
        A SQLAlchemy ``text()`` clause with binds ``:embedding``,
        ``:collection``, ``:user_id`` and ``:k``.
    """
    tabular_filter = ""
    if _TABULAR_TYPES:  # ``NOT IN ()`` would be a SQL syntax error
        # Values are module constants, not user input, so inlining is safe
        # and avoids asyncpg array-binding quirks.
        quoted = ", ".join(f"'{t}'" for t in sorted(_TABULAR_TYPES))
        tabular_filter = (
            "AND COALESCE(lpe.cmetadata->'data'->>'file_type', '') "
            f"NOT IN ({quoted})"
        )
    return text(f"""
        SELECT
            lpe.document,
            lpe.cmetadata,
            lpe.embedding {distance_operator} CAST(:embedding AS vector) AS distance
        FROM langchain_pg_embedding lpe
        JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
        WHERE lpc.name = :collection
          AND lpe.cmetadata->>'user_id' = :user_id
          AND lpe.cmetadata->>'source_type' = 'document'
          {tabular_filter}
        ORDER BY distance ASC
        LIMIT :k
    """)


# pgvector: <=> is cosine distance, <+> is L1 (manhattan) distance.
_COSINE_SQL = _build_similarity_sql("<=>")
_MANHATTAN_SQL = _build_similarity_sql("<+>")
@functools.lru_cache(maxsize=None)
def _get_embeddings() -> AzureOpenAIEmbeddings:
    """Create the Azure OpenAI embeddings client on first use, then reuse it.

    The client is configured from the ``azureai_*_embedding`` fields of
    ``settings``. The result is memoised, so every call after the first
    returns the same instance.
    """
    client_kwargs = {
        "azure_deployment": settings.azureai_deployment_name_embedding,
        "openai_api_version": settings.azureai_api_version_embedding,
        "azure_endpoint": settings.azureai_endpoint_url_embedding,
        "api_key": settings.azureai_api_key_embedding,
    }
    return AzureOpenAIEmbeddings(**client_kwargs)
class DocumentRetriever(BaseRetriever):
    """Dense-similarity retriever over a user's prose (non-tabular) chunks."""

    @staticmethod
    def _is_tabular(metadata: "dict | None") -> bool:
        """Return True if a chunk's ``data.file_type`` is in ``_TABULAR_TYPES``.

        Missing or malformed metadata (``None``, a non-dict ``data`` entry,
        or a non-string ``file_type``) counts as non-tabular instead of raising.
        """
        if not isinstance(metadata, dict):
            return False
        data = metadata.get("data")
        if not isinstance(data, dict):
            return False
        file_type = data.get("file_type", "")
        return isinstance(file_type, str) and file_type in _TABULAR_TYPES

    async def retrieve(
        self, query: str, user_id: str, k: int = 5
    ) -> list[RetrievalResult]:
        """Return up to ``k`` prose chunks most similar to ``query``.

        Args:
            query: Free-text query, embedded with the Azure OpenAI client.
            user_id: Only chunks whose ``cmetadata.user_id`` equals this
                value are searched.
            k: Maximum number of results. ``k <= 0`` returns an empty list.

        Returns:
            Results in ascending distance order. ``score`` is the raw distance
            for the configured method, so lower means more similar. Chunks
            whose ``cmetadata.data.file_type`` is tabular are excluded.

        Raises:
            ValueError: If ``_RETRIEVAL_METHOD`` is not a known method, or the
                query embedding contains NaN/Infinity.
        """
        if k <= 0:
            # Nothing requested: skip the embedding call and never send a
            # zero/negative LIMIT to Postgres.
            return []

        sql_by_method = {"cosine": _COSINE_SQL, "manhattan": _MANHATTAN_SQL}
        sql = sql_by_method.get(_RETRIEVAL_METHOD)
        if sql is None:
            # Fail loudly rather than silently falling back to another metric.
            raise ValueError(f"Unknown retrieval method: {_RETRIEVAL_METHOD!r}")

        query_vector = await _get_embeddings().aembed_query(query)
        if not all(math.isfinite(v) for v in query_vector):
            raise ValueError("Embedding vector contains NaN or Infinity values.")
        # Text literal such as "[0.1,0.2,...]"; the SQL casts it to vector.
        vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"

        # Over-fetch so tabular rows dropped below can be replaced. The initial
        # margin is only a guess, so it is doubled until k prose chunks are
        # found or the database returns fewer rows than requested.
        fetch_k = k + len(_TABULAR_TYPES)
        logger.info("retrieve called", user_id=user_id, collection=_COLLECTION_NAME, fetch_k=fetch_k)

        async with _pgvector_engine.connect() as conn:
            while True:
                result = await conn.execute(sql, {
                    "embedding": vector_str,
                    "collection": _COLLECTION_NAME,
                    "user_id": user_id,
                    "k": fetch_k,
                })
                rows = result.fetchall()
                logger.info("raw rows from db", row_count=len(rows))
                kept = [row for row in rows if not self._is_tabular(row.cmetadata)]
                if len(kept) >= k or len(rows) < fetch_k:
                    break
                fetch_k *= 2

        results = [
            RetrievalResult(
                content=row.document,
                metadata=row.cmetadata,
                score=float(row.distance),
                source_type="document",
            )
            for row in kept[:k]
        ]
        logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
        return results


document_retriever = DocumentRetriever()