# Provenance: commit 8c9cc79 — "[KM-455][document] decided methods retrieval for document"
"""Document retriever — handles PDF, DOCX, TXT chunks (source_type="document", non-tabular)."""
from langchain_postgres import PGVector
from langchain_postgres.vectorstores import DistanceStrategy
from langchain_openai import AzureOpenAIEmbeddings
from sqlalchemy import text
from src.config.settings import settings
from src.db.postgres.connection import _pgvector_engine
from src.db.postgres.vector_store import get_vector_store
from src.middlewares.logging import get_logger
from src.rag.base import BaseRetriever, RetrievalResult
logger = get_logger("document_retriever")

# Change this one line to switch retrieval method
# Options: "mmr" | "cosine" | "euclidean" | "inner_product" | "manhattan"
_RETRIEVAL_METHOD = "mmr"

# File types handled by the tabular retriever, not this one. Also doubles as
# the over-fetch margin: retrieve() fetches k + len(_TABULAR_TYPES) candidates
# so the post-filter can drop tabular chunks and still return k results.
_TABULAR_TYPES = {"csv", "xlsx"}
# Size of the candidate pool MMR draws from before selecting its diverse top-k.
_FETCH_K = 20
# MMR relevance/diversity trade-off (passed to lambda_mult; see LangChain docs).
_LAMBDA_MULT = 0.5
# PGVector collection name shared by all stores defined in this module.
_COLLECTION_NAME = "document_embeddings"

# Embedding client used both by the alternate stores below and for the raw-SQL
# Manhattan path. All paths must embed with the same model for comparability.
_embeddings = AzureOpenAIEmbeddings(
    azure_deployment=settings.azureai_deployment_name_embedding,
    openai_api_version=settings.azureai_api_version_embedding,
    azure_endpoint=settings.azureai_endpoint_url_embedding,
    api_key=settings.azureai_api_key_embedding,
)

# Alternate PGVector store over the same collection, scored with Euclidean (L2)
# distance instead of the default store's metric. create_extension=False:
# presumably the pgvector extension is provisioned elsewhere — verify on deploy.
_euclidean_store = PGVector(
    embeddings=_embeddings,
    connection=_pgvector_engine,
    collection_name=_COLLECTION_NAME,
    distance_strategy=DistanceStrategy.EUCLIDEAN,
    use_jsonb=True,
    async_mode=True,
    create_extension=False,
)

# Alternate PGVector store scored with (max) inner product.
_ip_store = PGVector(
    embeddings=_embeddings,
    connection=_pgvector_engine,
    collection_name=_COLLECTION_NAME,
    distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
    use_jsonb=True,
    async_mode=True,
    create_extension=False,
)

# Raw SQL for Manhattan (L1) retrieval: PGVector exposes no L1 strategy, so we
# hit langchain's embedding table directly. `<+>` is pgvector's L1 distance
# operator; lower distance = closer, hence ORDER BY ... ASC.
_MANHATTAN_SQL = text("""
SELECT
lpe.document,
lpe.cmetadata,
lpe.embedding <+> CAST(:embedding AS vector) AS distance
FROM langchain_pg_embedding lpe
JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
WHERE lpc.name = :collection
AND lpe.cmetadata->>'user_id' = :user_id
AND lpe.cmetadata->>'source_type' = 'document'
ORDER BY distance ASC
LIMIT :k
""")
class DocumentRetriever(BaseRetriever):
    """Retrieves non-tabular document chunks (source_type="document") per user.

    The similarity metric is selected by the module-level ``_RETRIEVAL_METHOD``
    switch ("mmr" | "cosine" | "euclidean" | "inner_product" | "manhattan").
    Every path post-filters tabular chunks (csv/xlsx) out of the candidates,
    since those belong to the tabular retriever.
    """

    def __init__(self) -> None:
        # Default application store (used for the mmr and cosine paths).
        self.vector_store = get_vector_store()

    @staticmethod
    def _build_results(candidates, k: int) -> list[RetrievalResult]:
        """Convert (content, metadata, score) triples into RetrievalResults.

        Skips tabular chunks and stops once *k* results are collected. Shared
        by both the vector-store paths and the raw-SQL Manhattan path, which
        previously duplicated this loop verbatim.
        """
        results: list[RetrievalResult] = []
        for content, metadata, score in candidates:
            metadata = metadata or {}  # cmetadata can be NULL in the DB
            file_type = metadata.get("data", {}).get("file_type", "")
            if file_type in _TABULAR_TYPES:
                continue
            results.append(RetrievalResult(
                content=content,
                metadata=metadata,
                score=score,
                source_type="document",
            ))
            if len(results) == k:
                break
        return results

    async def retrieve(
        self, query: str, user_id: str, k: int = 5
    ) -> list[RetrievalResult]:
        """Return up to *k* document chunks for *user_id*.

        Over-fetches by len(_TABULAR_TYPES) so the tabular post-filter can
        drop csv/xlsx chunks and usually still fill k slots. May return fewer
        than k when the candidate pool is mostly tabular.
        """
        filter_ = {"user_id": user_id, "source_type": "document"}
        fetch_k = k + len(_TABULAR_TYPES)

        if _RETRIEVAL_METHOD == "manhattan":
            return await self._retrieve_manhattan(query, user_id, k, fetch_k)

        if _RETRIEVAL_METHOD == "mmr":
            docs = await self.vector_store.amax_marginal_relevance_search(
                query=query,
                k=fetch_k,
                fetch_k=_FETCH_K,
                lambda_mult=_LAMBDA_MULT,
                filter=filter_,
            )
            # MMR does not return scores, so recover them from a plain scored
            # search on the same query. A chunk MMR surfaced that the scored
            # top-fetch_k missed falls back to a score of 0.0.
            cosine = await self.vector_store.asimilarity_search_with_score(
                query=query, k=fetch_k, filter=filter_,
            )
            score_map = {doc.page_content: score for doc, score in cosine}
            docs_with_scores = [
                (doc, score_map.get(doc.page_content, 0.0)) for doc in docs
            ]
        elif _RETRIEVAL_METHOD == "euclidean":
            docs_with_scores = await _euclidean_store.asimilarity_search_with_score(
                query=query, k=fetch_k, filter=filter_,
            )
        elif _RETRIEVAL_METHOD == "inner_product":
            docs_with_scores = await _ip_store.asimilarity_search_with_score(
                query=query, k=fetch_k, filter=filter_,
            )
        else:  # cosine (default store's scoring)
            docs_with_scores = await self.vector_store.asimilarity_search_with_score(
                query=query, k=fetch_k, filter=filter_,
            )

        results = self._build_results(
            ((d.page_content, d.metadata, s) for d, s in docs_with_scores), k,
        )
        logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
        return results

    async def _retrieve_manhattan(
        self, query: str, user_id: str, k: int, fetch_k: int
    ) -> list[RetrievalResult]:
        """Manhattan (L1) retrieval via raw SQL.

        PGVector exposes no L1 distance strategy, so this embeds the query
        locally and runs _MANHATTAN_SQL against the embedding table directly.
        The score of each result is the raw L1 distance (lower = closer).
        """
        query_vector = await _embeddings.aembed_query(query)
        # pgvector accepts the '[x,y,...]' text literal for CAST(... AS vector).
        vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(_MANHATTAN_SQL, {
                "embedding": vector_str,
                "collection": _COLLECTION_NAME,
                "user_id": user_id,
                "k": fetch_k,
            })
            rows = result.fetchall()
        results = self._build_results(
            ((row.document, row.cmetadata, float(row.distance)) for row in rows), k,
        )
        logger.info("retrieved chunks", method="manhattan", count=len(results))
        return results


# Module-level singleton shared by the RAG pipeline.
document_retriever = DocumentRetriever()