sofhiaazzhr committed on
Commit
cf77d20
·
1 Parent(s): 3604994

[KM-455][document] decided on retrieval methods for documents

Browse files
Files changed (1) hide show
  1. src/rag/retrievers/document.py +135 -13
src/rag/retrievers/document.py CHANGED
@@ -1,32 +1,154 @@
1
- """Document retriever — handles PDF, DOCX, TXT chunks (source_type="document", non-tabular).
2
 
3
- TEAMMATE: implement retrieve() below.
4
- Strategy: MMR (amax_marginal_relevance_search) + score threshold to avoid returning
5
- near-identical chunks from the same PDF page.
6
- Filter: source_type="document" AND data->>'file_type' NOT IN ('csv', 'xlsx')
7
- """
8
 
 
 
9
  from src.db.postgres.vector_store import get_vector_store
10
  from src.middlewares.logging import get_logger
11
  from src.rag.base import BaseRetriever, RetrievalResult
12
 
13
  logger = get_logger("document_retriever")
14
 
15
- _SCORE_THRESHOLD = 0.45 # discard chunks with cosine distance above this
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  class DocumentRetriever(BaseRetriever):
19
- def __init__(self):
20
  self.vector_store = get_vector_store()
21
 
22
  async def retrieve(
23
  self, query: str, user_id: str, k: int = 5
24
  ) -> list[RetrievalResult]:
25
- # TODO (teammate): implement MMR retrieval for prose documents
26
- # Filter: {"user_id": user_id, "source_type": "document"}
27
- # then post-filter to exclude file_type in ("csv", "xlsx")
28
- logger.info("document retriever not yet implemented — returning empty")
29
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
 
32
  document_retriever = DocumentRetriever()
 
1
+ """Document retriever — handles PDF, DOCX, TXT chunks (source_type="document", non-tabular)."""
2
 
3
+ from langchain_postgres import PGVector
4
+ from langchain_postgres.vectorstores import DistanceStrategy
5
+ from langchain_openai import AzureOpenAIEmbeddings
6
+ from sqlalchemy import text
 
7
 
8
+ from src.config.settings import settings
9
+ from src.db.postgres.connection import _pgvector_engine
10
  from src.db.postgres.vector_store import get_vector_store
11
  from src.middlewares.logging import get_logger
12
  from src.rag.base import BaseRetriever, RetrievalResult
13
 
14
  logger = get_logger("document_retriever")
15
 
16
# Retrieval strategy used by DocumentRetriever.retrieve().
# Change this one line to switch retrieval method.
# Options: "mmr" | "cosine" | "euclidean" | "inner_product" | "manhattan"
# Any other value ends up in retrieve()'s final `else` branch (cosine).
_RETRIEVAL_METHOD = "mmr"

# file_type values treated as tabular; chunks whose metadata["data"]["file_type"]
# is one of these are dropped after retrieval.
_TABULAR_TYPES = {"csv", "xlsx"}
# Passed as `fetch_k` to the MMR search (size of the candidate pool MMR picks from).
_FETCH_K = 20
# Passed as `lambda_mult` to the MMR search.
# NOTE(review): per LangChain's MMR docs 0 favours diversity and 1 favours
# relevance — confirm against the installed version.
_LAMBDA_MULT = 0.5
# Collection shared by the stores below and by the raw Manhattan query.
_COLLECTION_NAME = "document_embeddings"

# Azure OpenAI embedding client, configured entirely from project settings.
# Shared by both PGVector stores below and used directly to embed the query
# for the Manhattan path.
_embeddings = AzureOpenAIEmbeddings(
    azure_deployment=settings.azureai_deployment_name_embedding,
    openai_api_version=settings.azureai_api_version_embedding,
    azure_endpoint=settings.azureai_endpoint_url_embedding,
    api_key=settings.azureai_api_key_embedding,
)

# Store over the same collection, ranking by Euclidean distance.
# NOTE(review): create_extension=False presumably skips creating the pgvector
# extension at construction time — confirm against langchain_postgres.
_euclidean_store = PGVector(
    embeddings=_embeddings,
    connection=_pgvector_engine,
    collection_name=_COLLECTION_NAME,
    distance_strategy=DistanceStrategy.EUCLIDEAN,
    use_jsonb=True,
    async_mode=True,
    create_extension=False,
)

# Store over the same collection, ranking by maximum inner product.
_ip_store = PGVector(
    embeddings=_embeddings,
    connection=_pgvector_engine,
    collection_name=_COLLECTION_NAME,
    distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
    use_jsonb=True,
    async_mode=True,
    create_extension=False,
)

# Raw query for the "manhattan" method, run directly against the LangChain
# tables instead of through a PGVector store.
# `<+>` is pgvector's L1 (taxicab) distance operator — requires a pgvector
# version that provides it; TODO confirm the deployed version.
# Bind parameters: :embedding (vector literal as text), :collection, :user_id, :k.
# Rows are restricted to the given collection, user and source_type 'document'
# and come back closest-first. Tabular file types are NOT excluded here.
_MANHATTAN_SQL = text("""
    SELECT
        lpe.document,
        lpe.cmetadata,
        lpe.embedding <+> CAST(:embedding AS vector) AS distance
    FROM langchain_pg_embedding lpe
    JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
    WHERE lpc.name = :collection
        AND lpe.cmetadata->>'user_id' = :user_id
        AND lpe.cmetadata->>'source_type' = 'document'
    ORDER BY distance ASC
    LIMIT :k
""")
65
 
66
 
67
class DocumentRetriever(BaseRetriever):
    """Retriever for prose document chunks (source_type="document", non-tabular).

    The ranking strategy is chosen by the module-level ``_RETRIEVAL_METHOD``.
    Every strategy over-fetches slightly, drops chunks whose
    ``metadata["data"]["file_type"]`` is in ``_TABULAR_TYPES`` and returns at
    most ``k`` results in ranking order.
    """

    def __init__(self) -> None:
        # Default store; used by the "mmr" and "cosine" strategies.
        self.vector_store = get_vector_store()

    async def retrieve(
        self, query: str, user_id: str, k: int = 5
    ) -> list[RetrievalResult]:
        """Return up to ``k`` non-tabular document chunks for ``user_id``.

        Args:
            query: Natural-language query to embed and search with.
            user_id: Only chunks whose metadata ``user_id`` matches are considered.
            k: Maximum number of results. Values <= 0 yield an empty list.

        Returns:
            Results in ranking order. ``score`` is whatever the selected
            strategy reports (NOTE(review): presumably a distance, lower =
            closer, for the PGVector-backed strategies — confirm).
        """
        if k <= 0:
            # Without this guard the `len(results) == k` cap never triggers
            # and a negative k can reach SQL as a negative LIMIT.
            return []

        filter_ = {"user_id": user_id, "source_type": "document"}
        # Over-fetch so tabular chunks can be discarded afterwards.
        # NOTE(review): the margin is the number of tabular *file types*, not
        # the number of tabular chunks, so fewer than k results can come back
        # when more than that many tabular chunks rank near the top.
        fetch_k = k + len(_TABULAR_TYPES)

        if _RETRIEVAL_METHOD == "manhattan":
            return await self._retrieve_manhattan(query, user_id, k, fetch_k)

        if _RETRIEVAL_METHOD == "mmr":
            docs_with_scores = await self._search_mmr(query, filter_, fetch_k)
        else:
            if _RETRIEVAL_METHOD == "euclidean":
                store = _euclidean_store
            elif _RETRIEVAL_METHOD == "inner_product":
                store = _ip_store
            else:  # "cosine", and the fallback for any unrecognised value
                store = self.vector_store
            docs_with_scores = await store.asimilarity_search_with_score(
                query=query, k=fetch_k, filter=filter_,
            )

        results = self._to_results(
            ((doc.page_content, doc.metadata, score) for doc, score in docs_with_scores),
            k,
        )
        logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
        return results

    async def _search_mmr(self, query: str, filter_: dict, fetch_k: int) -> list:
        """Run MMR on the default store and pair each pick with its similarity score."""
        # MMR must have at least as many candidates as results requested.
        pool = max(_FETCH_K, fetch_k)
        docs = await self.vector_store.amax_marginal_relevance_search(
            query=query,
            k=fetch_k,
            fetch_k=pool,
            lambda_mult=_LAMBDA_MULT,
            filter=filter_,
        )
        # The MMR call returns documents without scores, so look them up from
        # a plain similarity search. It must span the same candidate pool MMR
        # chose from; a narrower window leaves diverse picks without a score.
        scored_pool = await self.vector_store.asimilarity_search_with_score(
            query=query, k=pool, filter=filter_,
        )
        # Keyed by text: chunks with identical text overwrite one another here.
        score_map = {doc.page_content: score for doc, score in scored_pool}
        # 0.0 remains only as a last-resort default (e.g. rows changed between
        # the two queries).
        return [(doc, score_map.get(doc.page_content, 0.0)) for doc in docs]

    async def _retrieve_manhattan(
        self, query: str, user_id: str, k: int, fetch_k: int
    ) -> list[RetrievalResult]:
        """Rank by L1 distance using the raw ``_MANHATTAN_SQL`` query.

        ``fetch_k`` rows are read from the database; at most ``k`` non-tabular
        ones are returned, each scored with its distance (lower = closer).
        """
        if k <= 0 or fetch_k <= 0:
            return []

        query_vector = await _embeddings.aembed_query(query)
        # Text form of a vector literal, cast to `vector` inside the SQL.
        vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(_MANHATTAN_SQL, {
                "embedding": vector_str,
                "collection": _COLLECTION_NAME,
                "user_id": user_id,
                "k": fetch_k,
            })
            rows = result.fetchall()

        results = self._to_results(
            ((row.document, row.cmetadata, float(row.distance)) for row in rows),
            k,
        )
        logger.info("retrieved chunks", method="manhattan", count=len(results))
        return results

    @staticmethod
    def _is_tabular(metadata: dict) -> bool:
        """Return True when ``metadata["data"]["file_type"]`` is a tabular type."""
        data = metadata.get("data")
        if not isinstance(data, dict):
            # Missing, None or malformed "data": nothing marks the chunk as tabular.
            return False
        return data.get("file_type", "") in _TABULAR_TYPES

    @staticmethod
    def _to_results(candidates, k: int) -> list[RetrievalResult]:
        """Convert ``(content, metadata, score)`` triples to results, skipping tabular chunks and capping at ``k``."""
        results: list[RetrievalResult] = []
        if k <= 0:
            return results
        for content, metadata, score in candidates:
            if DocumentRetriever._is_tabular(metadata):
                continue
            results.append(RetrievalResult(
                content=content,
                metadata=metadata,
                score=score,
                source_type="document",
            ))
            if len(results) == k:
                break
        return results


# Module-level singleton instance of the retriever.
document_retriever = DocumentRetriever()