File size: 6,056 Bytes
6bff5d9 52999bc 6bff5d9 c93ec90 8802920 52999bc 8802920 52999bc 6bff5d9 52999bc 8802920 6bff5d9 8802920 52999bc 8802920 52999bc 8802920 6bff5d9 8802920 6bff5d9 8802920 6bff5d9 c93ec90 8802920 52999bc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 | """DocumentRetriever — dense similarity over prose chunks (Cu).
For unstructured sources only (PDF / DOCX / TXT). Backed by PGVector with the
collection `document_embeddings`. Supported methods (chosen via `_RETRIEVAL_METHOD`):
MMR, cosine, Euclidean, inner product, and Manhattan.
"""
import functools
import math
from langchain_postgres import PGVector
from langchain_postgres.vectorstores import DistanceStrategy
from langchain_openai import AzureOpenAIEmbeddings
from sqlalchemy import text
from src.config.settings import settings
from src.db.postgres.connection import _pgvector_engine
from src.db.postgres.vector_store import get_vector_store
from src.middlewares.logging import get_logger
from src.retrieval.base import BaseRetriever, RetrievalResult
logger = get_logger("document_retriever")

# Retrieval method used by DocumentRetriever.retrieve; change this one line to switch.
# Options: "mmr" | "cosine" | "euclidean" | "inner_product" | "manhattan"
_RETRIEVAL_METHOD = "mmr"
_SUPPORTED_RETRIEVAL_METHODS = frozenset(
    {"mmr", "cosine", "euclidean", "inner_product", "manhattan"}
)
if _RETRIEVAL_METHOD not in _SUPPORTED_RETRIEVAL_METHODS:
    # Fail fast at import: an unrecognised value would otherwise silently fall
    # through to the cosine branch while being logged under the wrong name.
    raise ValueError(
        f"Unsupported _RETRIEVAL_METHOD {_RETRIEVAL_METHOD!r}; "
        f"expected one of {sorted(_SUPPORTED_RETRIEVAL_METHODS)}"
    )

# Chunks whose metadata["data"]["file_type"] is one of these are never returned.
# frozenset: this is a read-only lookup table; `in` and len() behave as for a set.
_TABULAR_TYPES = frozenset({"csv", "xlsx"})
# MMR candidate pool: how many nearest neighbours MMR re-ranks for diversity.
_FETCH_K = 20
# MMR relevance/diversity trade-off (LangChain: 1.0 = minimum diversity, 0.0 = maximum).
_LAMBDA_MULT = 0.5
# PGVector collection searched by the Euclidean / inner-product stores and the Manhattan SQL.
_COLLECTION_NAME = "document_embeddings"
@functools.lru_cache(maxsize=None)
def _get_embeddings() -> AzureOpenAIEmbeddings:
    """Return the lazily built, memoised Azure OpenAI embeddings client.

    The client is configured from the ``azureai_*_embedding`` settings on the
    first call. Later calls return the same instance, so the stores built in
    this module and the Manhattan path share one client.
    """
    client_kwargs = {
        "azure_deployment": settings.azureai_deployment_name_embedding,
        "openai_api_version": settings.azureai_api_version_embedding,
        "azure_endpoint": settings.azureai_endpoint_url_embedding,
        "api_key": settings.azureai_api_key_embedding,
    }
    return AzureOpenAIEmbeddings(**client_kwargs)
def _make_store(distance_strategy: DistanceStrategy) -> PGVector:
    """Build an async PGVector handle on ``_COLLECTION_NAME`` ranked by *distance_strategy*.

    Metadata is stored/filtered as JSONB. ``create_extension=False`` means the
    ``vector`` extension is not created from here (presumably it is provisioned
    separately — confirm).
    """
    return PGVector(
        embeddings=_get_embeddings(),
        connection=_pgvector_engine,
        collection_name=_COLLECTION_NAME,
        distance_strategy=distance_strategy,
        use_jsonb=True,
        async_mode=True,
        create_extension=False,
    )


@functools.cache
def _get_euclidean_store() -> PGVector:
    """Return the memoised store that ranks chunks by Euclidean (L2) distance."""
    return _make_store(DistanceStrategy.EUCLIDEAN)


@functools.cache
def _get_ip_store() -> PGVector:
    """Return the memoised store that ranks chunks by maximum inner product."""
    return _make_store(DistanceStrategy.MAX_INNER_PRODUCT)
# Raw SQL for the "manhattan" retrieval method: ranks chunks by L1 distance
# using pgvector's `<+>` operator, closest first.
# NOTE(review): `<+>` requires pgvector >= 0.7 — confirm the server version.
# Restricts to the named collection, the caller's user_id and
# source_type = 'document', mirroring the metadata filter used by the other methods.
# Bind params: :embedding (a "[v1,v2,...]" text literal cast to vector),
# :collection, :user_id, :k (row limit).
_MANHATTAN_SQL = text("""
SELECT
lpe.document,
lpe.cmetadata,
lpe.embedding <+> CAST(:embedding AS vector) AS distance
FROM langchain_pg_embedding lpe
JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
WHERE lpc.name = :collection
AND lpe.cmetadata->>'user_id' = :user_id
AND lpe.cmetadata->>'source_type' = 'document'
ORDER BY distance ASC
LIMIT :k
""")
def _is_tabular(metadata: dict) -> bool:
    """Return True when a chunk's metadata marks it as coming from a tabular file.

    Reads ``metadata["data"]["file_type"]``. A missing or null ``data`` entry is
    treated as non-tabular rather than raising.
    """
    data = metadata.get("data") or {}
    return data.get("file_type", "") in _TABULAR_TYPES


def _collect_results(
    items: list[tuple[str, dict, float]], k: int
) -> list[RetrievalResult]:
    """Turn ``(content, metadata, score)`` rows into at most *k* document results.

    Rows from tabular files are skipped. Input order is preserved, and
    ``k <= 0`` yields an empty list.
    """
    results = []
    for content, metadata, score in items:
        if len(results) >= k:
            break
        if _is_tabular(metadata):
            continue
        results.append(
            RetrievalResult(
                content=content,
                metadata=metadata,
                score=score,
                source_type="document",
            )
        )
    return results


class DocumentRetriever(BaseRetriever):
    """Dense retriever over prose (non-tabular) document chunks.

    The similarity method is chosen module-wide by ``_RETRIEVAL_METHOD``.
    Every method restricts candidates to the caller's ``user_id`` and to
    ``source_type == "document"``. It then drops chunks whose
    ``metadata["data"]["file_type"]`` is in ``_TABULAR_TYPES``.
    """

    def __init__(self) -> None:
        self.vector_store = get_vector_store()

    async def retrieve(
        self, query: str, user_id: str, k: int = 5
    ) -> list[RetrievalResult]:
        """Return up to *k* non-tabular document chunks most similar to *query*.

        Args:
            query: Free-text query to embed and search with.
            user_id: Only chunks whose metadata ``user_id`` equals this are searched.
            k: Maximum number of results; ``k <= 0`` returns an empty list.

        Returns:
            Results in the order produced by the selected method. ``score`` is the
            method's raw value: the Manhattan path reports the L1 distance (lower
            is closer); the PGVector paths pass the store's score through unchanged.
        """
        if k <= 0:
            return []

        filter_ = {"user_id": user_id, "source_type": "document"}
        # Over-fetch a little so that discarding tabular chunks afterwards
        # still tends to leave k results.
        fetch_k = k + len(_TABULAR_TYPES)

        if _RETRIEVAL_METHOD == "manhattan":
            return await self._retrieve_manhattan(query, user_id, k, fetch_k)

        if _RETRIEVAL_METHOD == "mmr":
            docs_with_scores = await self._mmr_with_scores(query, fetch_k, filter_)
        elif _RETRIEVAL_METHOD == "euclidean":
            docs_with_scores = await _get_euclidean_store().asimilarity_search_with_score(
                query=query, k=fetch_k, filter=filter_,
            )
        elif _RETRIEVAL_METHOD == "inner_product":
            docs_with_scores = await _get_ip_store().asimilarity_search_with_score(
                query=query, k=fetch_k, filter=filter_,
            )
        else:  # "cosine"; also reached for any unrecognised _RETRIEVAL_METHOD
            docs_with_scores = await self.vector_store.asimilarity_search_with_score(
                query=query, k=fetch_k, filter=filter_,
            )

        results = _collect_results(
            [(doc.page_content, doc.metadata, score) for doc, score in docs_with_scores],
            k,
        )
        logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
        return results

    async def _mmr_with_scores(
        self, query: str, k: int, filter_: dict
    ) -> list[tuple]:
        """Run MMR and attach a similarity score to each selected document.

        MMR returns documents without scores. A plain similarity search over the
        same candidate pool supplies them, matched by ``page_content``.
        """
        # MMR can never return more than its candidate pool, so grow the pool
        # when the caller asks for more than _FETCH_K results.
        pool_size = max(_FETCH_K, k)
        docs = await self.vector_store.amax_marginal_relevance_search(
            query=query,
            k=k,
            fetch_k=pool_size,
            lambda_mult=_LAMBDA_MULT,
            filter=filter_,
        )
        # Score the whole candidate pool, not just the top k. MMR may select
        # candidates ranked below k, and those would otherwise fall back to 0.0
        # (which, for a distance, reads as a perfect match).
        scored = await self.vector_store.asimilarity_search_with_score(
            query=query, k=pool_size, filter=filter_,
        )
        score_by_content = {doc.page_content: score for doc, score in scored}
        return [(doc, score_by_content.get(doc.page_content, 0.0)) for doc in docs]

    async def _retrieve_manhattan(
        self, query: str, user_id: str, k: int, fetch_k: int
    ) -> list[RetrievalResult]:
        """Rank chunks by L1 (Manhattan) distance using raw SQL.

        PGVector's DistanceStrategy has no Manhattan option, so the query is
        embedded directly and ranked with ``_MANHATTAN_SQL``. At most *fetch_k*
        rows are read, and at most *k* non-tabular results are returned.

        Raises:
            ValueError: if the embedding contains NaN or +/-Infinity, which
                cannot be sent as a pgvector literal.
        """
        query_vector = await _get_embeddings().aembed_query(query)
        if not all(math.isfinite(v) for v in query_vector):
            raise ValueError("Embedding vector contains NaN or Infinity values.")
        # pgvector text input format "[v1,v2,...]"; the SQL casts it to `vector`.
        vector_literal = "[" + ",".join(map(str, query_vector)) + "]"

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(
                _MANHATTAN_SQL,
                {
                    "embedding": vector_literal,
                    "collection": _COLLECTION_NAME,
                    "user_id": user_id,
                    "k": fetch_k,
                },
            )
            rows = result.fetchall()

        results = _collect_results(
            # A NULL jsonb cmetadata is treated as empty rather than crashing.
            [(row.document, row.cmetadata or {}, float(row.distance)) for row in rows],
            k,
        )
        logger.info("retrieved chunks", method="manhattan", count=len(results))
        return results


document_retriever = DocumentRetriever()
|