File size: 3,739 Bytes
61c746f
52999bc
61c746f
 
 
 
6bff5d9
 
 
c93ec90
 
8802920
 
52999bc
8802920
 
52999bc
6bff5d9
52999bc
 
 
8802920
61c746f
 
8802920
 
61c746f
8802920
61c746f
 
 
 
 
 
 
 
 
 
 
 
 
8802920
 
 
 
 
 
 
 
 
 
 
 
 
 
52999bc
 
61c746f
 
 
 
 
 
 
 
 
52999bc
61c746f
52999bc
 
8802920
6bff5d9
c93ec90
 
8802920
61c746f
 
 
 
 
8802920
 
61c746f
8802920
 
 
 
 
 
 
61c746f
 
8802920
 
 
 
 
 
 
 
 
 
 
 
 
61c746f
8802920
52999bc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""DocumentRetriever — dense similarity over prose chunks.

For unstructured sources only (PDF / DOCX / TXT). Backed by PGVector via
raw SQL to avoid LangChain ORM / asyncpg type-mapping issues (id UUID vs
String mismatch, jsonb_path_match asyncpg binding quirks).
Collection `documents` (see `_COLLECTION_NAME`). Methods: cosine | manhattan.
"""

import functools
import math

from langchain_openai import AzureOpenAIEmbeddings
from sqlalchemy import text

from src.config.settings import settings
from src.db.postgres.connection import _pgvector_engine
from src.middlewares.logging import get_logger
from src.retrieval.base import BaseRetriever, RetrievalResult

# Module-scoped logger for this retriever.
logger = get_logger("document_retriever")

# Distance metric used by DocumentRetriever.retrieve; edit this single value
# to switch metrics. Supported values: "cosine" | "manhattan".
_RETRIEVAL_METHOD = "cosine"

# `file_type` metadata values that retrieve() drops from its results.
_TABULAR_TYPES = {"csv", "xlsx"}
# Value matched against langchain_pg_collection.name in the queries below.
_COLLECTION_NAME = "documents"

# Shared query skeleton. `{operator}` is the only part that differs between
# metrics; everything else (user / source_type filter, ascending-distance
# ordering, LIMIT) is common to both statements.
_DISTANCE_SQL_TEMPLATE = """
    SELECT
        lpe.document,
        lpe.cmetadata,
        lpe.embedding {operator} CAST(:embedding AS vector) AS distance
    FROM langchain_pg_embedding lpe
    JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
    WHERE lpc.name = :collection
      AND lpe.cmetadata->>'user_id' = :user_id
      AND lpe.cmetadata->>'source_type' = 'document'
    ORDER BY distance ASC
    LIMIT :k
"""

# `<=>` is the pgvector cosine-distance operator.
_COSINE_SQL = text(_DISTANCE_SQL_TEMPLATE.format(operator="<=>"))

# `<+>` is the pgvector L1 (taxicab / Manhattan) distance operator.
_MANHATTAN_SQL = text(_DISTANCE_SQL_TEMPLATE.format(operator="<+>"))


@functools.cache
def _get_embeddings() -> AzureOpenAIEmbeddings:
    """Return the Azure OpenAI embeddings client, building it on first use.

    The function takes no arguments, so ``functools.cache`` hands back the
    same client object on every call after the first. All connection
    parameters are read from ``settings`` at construction time.
    """
    client_options = {
        "azure_deployment": settings.azureai_deployment_name_embedding,
        "openai_api_version": settings.azureai_api_version_embedding,
        "azure_endpoint": settings.azureai_endpoint_url_embedding,
        "api_key": settings.azureai_api_key_embedding,
    }
    return AzureOpenAIEmbeddings(**client_options)


class DocumentRetriever(BaseRetriever):
    """Retriever that ranks a user's stored chunks by vector distance.

    Queries the ``langchain_pg_embedding`` / ``langchain_pg_collection``
    tables with raw SQL and drops rows whose ``file_type`` metadata is one of
    ``_TABULAR_TYPES``.
    """

    async def retrieve(
        self, query: str, user_id: str, k: int = 5
    ) -> list[RetrievalResult]:
        """Return up to ``k`` non-tabular chunks closest to ``query``.

        Args:
            query: Text to embed and search with.
            user_id: Only rows whose ``cmetadata->>'user_id'`` equals this
                value are considered.
            k: Maximum number of results. Values ``<= 0`` yield an empty
                list without calling the embedding service or the database.

        Returns:
            Results in ascending distance order. ``score`` is the raw
            distance reported by the query (smaller means closer), not a
            similarity.

        Raises:
            ValueError: If ``_RETRIEVAL_METHOD`` is not a supported metric,
                or if the embedding contains NaN / Infinity values.
        """
        # Nothing to return; also keeps a negative LIMIT from reaching the DB.
        if k <= 0:
            return []

        # Fail loudly on a misconfigured metric instead of silently running
        # the Manhattan query for every value other than "cosine".
        if _RETRIEVAL_METHOD == "cosine":
            sql = _COSINE_SQL
        elif _RETRIEVAL_METHOD == "manhattan":
            sql = _MANHATTAN_SQL
        else:
            raise ValueError(
                f"Unsupported _RETRIEVAL_METHOD {_RETRIEVAL_METHOD!r}; "
                "expected 'cosine' or 'manhattan'."
            )

        query_vector = await _get_embeddings().aembed_query(query)
        if not all(math.isfinite(v) for v in query_vector):
            raise ValueError("Embedding vector contains NaN or Infinity values.")
        # The vector is sent as a text literal and CAST to `vector` in SQL.
        vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"

        # Over-fetch by a small fixed margin so a few tabular rows can be
        # discarded below without shrinking the result.
        # NOTE(review): the margin is len(_TABULAR_TYPES), not the number of
        # tabular rows, so more tabular hits than that among the fetched rows
        # still yields fewer than k results.
        fetch_k = k + len(_TABULAR_TYPES)

        logger.info("retrieve called", user_id=user_id, collection=_COLLECTION_NAME, fetch_k=fetch_k)

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {
                "embedding": vector_str,
                "collection": _COLLECTION_NAME,
                "user_id": user_id,
                "k": fetch_k,
            })
            rows = result.fetchall()

        logger.info("raw rows from db", row_count=len(rows))

        results: list[RetrievalResult] = []
        for row in rows:
            # Tolerate metadata where `data` is missing, null, or not an
            # object, and a `file_type` that is not a string; such rows are
            # treated as non-tabular rather than raising.
            metadata = row.cmetadata if isinstance(row.cmetadata, dict) else {}
            data = metadata.get("data")
            file_type = data.get("file_type", "") if isinstance(data, dict) else ""
            if isinstance(file_type, str) and file_type in _TABULAR_TYPES:
                continue

            results.append(RetrievalResult(
                content=row.document,
                metadata=row.cmetadata,
                score=float(row.distance),
                source_type="document",
            ))
            if len(results) == k:
                break

        logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
        return results


# Module-level DocumentRetriever instance, created when this module is imported.
document_retriever = DocumentRetriever()