ishaq101 commited on
Commit
61c746f
·
1 Parent(s): 0e9263a

[KM-582][DED][AI] Fix Retrieval in Agentic Service

Browse files

ticket: https://bukittechnology.atlassian.net/browse/KM-582
fix: replace LangChain ORM retrieval with raw SQL and fix pgvector collection name
- Rewrite DocumentRetriever.retrieve() using raw SQL cosine/manhattan
queries instead of LangChain PGVector ORM, bypassing asyncpg type-mapping
issues (id String vs UUID column, jsonb_path_match binding quirks)
- Fix _COLLECTION_NAME from "document_embeddings" to "documents" to match
the collection name set by the Golang ingestion service
- Fix collection_name in vector_store.py to match consistently
- Fix Redis chat cache to store {response, sources} dict so cached replies
also populate message_sources table
- Add cache management endpoints: DELETE /chat/cache, /chat/cache/room/{id},
/retrieval/cache/{user_id}
- Invalidate retrieval cache automatically after document processing
- Update intent_router prompt: route topical/knowledge questions to
unstructured even without explicit document mention; prefer unstructured
when ambiguous; add Indonesian few-shot examples
- Fix logging level from WARNING to INFO so structured logs are visible
- Add page_label: null to non-PDF chunk metadata for consistency
- Add diagnostic logging in retrieve() to expose collection, user_id,
and raw row count per call

src/agents/chat_handler.py CHANGED
@@ -170,6 +170,7 @@ class ChatHandler:
170
  sources = _build_sources(
171
  decision.source_hint, user_id, query_result, raw_chunks
172
  )
 
173
  yield {"event": "sources", "data": json.dumps(sources)}
174
 
175
  # ---- 3. Stream answer ----------------------------------------
 
170
  sources = _build_sources(
171
  decision.source_hint, user_id, query_result, raw_chunks
172
  )
173
+ logger.info("built sources", source_hint=decision.source_hint, sources_count=len(sources), raw_chunks_count=len(raw_chunks) if raw_chunks else 0)
174
  yield {"event": "sources", "data": json.dumps(sources)}
175
 
176
  # ---- 3. Stream answer ----------------------------------------
src/api/v1/chat.py CHANGED
@@ -42,15 +42,19 @@ class ChatRequest(BaseModel):
42
  message: str
43
 
44
 
45
- async def get_cached_response(redis, cache_key: str) -> Optional[str]:
46
  cached = await redis.get(cache_key)
47
  if cached:
48
- return json.loads(cached)
 
 
 
 
49
  return None
50
 
51
 
52
- async def cache_response(redis, cache_key: str, response: str):
53
- await redis.setex(cache_key, 86400, json.dumps(response))
54
 
55
 
56
  async def load_history(db: AsyncSession, room_id: str, limit: int = 10) -> list:
@@ -91,6 +95,34 @@ async def save_messages(
91
  await db.commit()
92
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  @router.post("/chat/stream")
95
  @log_execution(logger)
96
  async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
@@ -107,13 +139,17 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
107
 
108
  # Redis cache hit
109
  cached = await get_cached_response(redis, cache_key)
 
110
  if cached:
111
  logger.info("Returning cached response")
 
 
 
112
 
113
  async def stream_cached():
114
- yield {"event": "sources", "data": json.dumps([])}
115
- for i in range(0, len(cached), 50):
116
- yield {"event": "chunk", "data": cached[i:i + 50]}
117
  yield {"event": "done", "data": ""}
118
 
119
  return EventSourceResponse(stream_cached())
@@ -122,7 +158,7 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
122
  # Fast intent: greetings/farewells bypass LLM entirely
123
  direct = _fast_intent(request.message)
124
  if direct:
125
- await cache_response(redis, cache_key, direct)
126
  await save_messages(db, request.room_id, request.message, direct, sources=[])
127
 
128
  async def stream_direct():
@@ -136,6 +172,7 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
136
  handler = ChatHandler()
137
 
138
  async def stream_response():
 
139
  full_response = ""
140
  sources: List[Dict[str, Any]] = []
141
  async for event in handler.handle(request.message, request.user_id, history):
@@ -149,8 +186,12 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
149
  full_response += event["data"]
150
  yield event
151
  elif event["event"] == "done":
152
- await cache_response(redis, cache_key, full_response)
153
- await save_messages(db, request.room_id, request.message, full_response, sources=sources)
 
 
 
 
154
  yield event
155
  elif event["event"] == "error":
156
  yield event
 
42
  message: str
43
 
44
 
45
+ async def get_cached_response(redis, cache_key: str) -> Optional[dict]:
46
  cached = await redis.get(cache_key)
47
  if cached:
48
+ data = json.loads(cached)
49
+ if isinstance(data, dict) and "response" in data:
50
+ return data
51
+ # legacy: plain string cached before this change
52
+ return {"response": data, "sources": []}
53
  return None
54
 
55
 
56
+ async def cache_response(redis, cache_key: str, response: str, sources: list):
57
+ await redis.setex(cache_key, 86400, json.dumps({"response": response, "sources": sources}))
58
 
59
 
60
  async def load_history(db: AsyncSession, room_id: str, limit: int = 10) -> list:
 
95
  await db.commit()
96
 
97
 
98
+ @router.delete("/chat/cache")
99
+ async def clear_chat_cache(room_id: str, message: str):
100
+ """Delete the Redis cache entry for a specific room + message pair."""
101
+ redis = await get_redis()
102
+ cache_key = f"{settings.redis_prefix}chat:{room_id}:{message}"
103
+ deleted = await redis.delete(cache_key)
104
+ return {"deleted": deleted > 0, "cache_key": cache_key}
105
+
106
+
107
+ @router.delete("/chat/cache/room/{room_id}")
108
+ async def clear_room_cache(room_id: str):
109
+ """Delete all Redis cache entries for a room."""
110
+ redis = await get_redis()
111
+ pattern = f"{settings.redis_prefix}chat:{room_id}:*"
112
+ keys = await redis.keys(pattern)
113
+ if keys:
114
+ await redis.delete(*keys)
115
+ return {"deleted_count": len(keys), "room_id": room_id}
116
+
117
+
118
+ @router.delete("/retrieval/cache/{user_id}")
119
+ async def clear_retrieval_cache(user_id: str):
120
+ """Delete all cached retrieval results for a user. Call this after uploading/processing new documents."""
121
+ from src.retrieval.router import retrieval_router
122
+ deleted = await retrieval_router.invalidate_cache(user_id)
123
+ return {"deleted_count": deleted, "user_id": user_id}
124
+
125
+
126
  @router.post("/chat/stream")
127
  @log_execution(logger)
128
  async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
 
139
 
140
  # Redis cache hit
141
  cached = await get_cached_response(redis, cache_key)
142
+ logger.info("cache check", cache_key=cache_key, cache_hit=cached is not None)
143
  if cached:
144
  logger.info("Returning cached response")
145
+ cached_text = cached["response"]
146
+ cached_sources = cached["sources"]
147
+ await save_messages(db, request.room_id, request.message, cached_text, sources=cached_sources)
148
 
149
  async def stream_cached():
150
+ yield {"event": "sources", "data": json.dumps(cached_sources)}
151
+ for i in range(0, len(cached_text), 50):
152
+ yield {"event": "chunk", "data": cached_text[i:i + 50]}
153
  yield {"event": "done", "data": ""}
154
 
155
  return EventSourceResponse(stream_cached())
 
158
  # Fast intent: greetings/farewells bypass LLM entirely
159
  direct = _fast_intent(request.message)
160
  if direct:
161
+ await cache_response(redis, cache_key, direct, sources=[])
162
  await save_messages(db, request.room_id, request.message, direct, sources=[])
163
 
164
  async def stream_direct():
 
172
  handler = ChatHandler()
173
 
174
  async def stream_response():
175
+ logger.info("stream_response started", room_id=request.room_id, user_id=request.user_id)
176
  full_response = ""
177
  sources: List[Dict[str, Any]] = []
178
  async for event in handler.handle(request.message, request.user_id, history):
 
186
  full_response += event["data"]
187
  yield event
188
  elif event["event"] == "done":
189
+ await cache_response(redis, cache_key, full_response, sources=sources)
190
+ logger.info("saving messages", sources_count=len(sources), sources=sources)
191
+ try:
192
+ await save_messages(db, request.room_id, request.message, full_response, sources=sources)
193
+ except Exception as e:
194
+ logger.error("save_messages failed", room_id=request.room_id, error=str(e))
195
  yield event
196
  elif event["event"] == "error":
197
  yield event
src/api/v1/document.py CHANGED
@@ -114,5 +114,8 @@ async def process_document(
114
  except Exception as e:
115
  logger.error("catalog ingestion failed after process", document_id=document_id, error=str(e))
116
 
 
 
 
117
  return {"status": "success", "message": "Document processed successfully", "data": data}
118
 
 
114
  except Exception as e:
115
  logger.error("catalog ingestion failed after process", document_id=document_id, error=str(e))
116
 
117
+ from src.retrieval.router import retrieval_router
118
+ await retrieval_router.invalidate_cache(user_id)
119
+
120
  return {"status": "success", "message": "Document processed successfully", "data": data}
121
 
src/config/prompts/intent_router.md CHANGED
@@ -7,16 +7,16 @@ Return three fields:
7
  - **`needs_search`** — `true` if we must look at the user's data to answer; `false` for greetings, farewells, off-topic chitchat, or meta questions about the assistant itself.
8
  - **`source_hint`** — one of:
9
  - `chat` — no data lookup needed (greetings, farewells, generic small talk).
10
- - `unstructured` — the user is asking about the **content** of an uploaded document (PDF / DOCX / TXT).
11
  - `structured` — the user is asking a **data question** answerable from a database or a tabular file (CSV / XLSX / Parquet). This includes counts, sums, top-N, filters, comparisons, trends, joins across registered structured sources.
12
  - **`rewritten_query`** — a **standalone** version of the user's question that incorporates necessary context from history. If the original message is already standalone, return it unchanged. If `needs_search` is `false`, leave this empty/null.
13
 
14
  ## Routing rules
15
 
16
- 1. If the message is a pure greeting / farewell / thanks / "how are you" / "what can you do" → `chat` + `needs_search=false`.
17
- 2. If the message references content that lives in a registered DB or uploaded tabular file (sales numbers, customer counts, order trends, sheet rows, table columns) → `structured` + `needs_search=true`.
18
- 3. If the message asks about prose content (a section of a PDF, what a memo says, a quote from a document) `unstructured` + `needs_search=true`.
19
- 4. If the message is ambiguous between structured and unstructured, prefer `structured` the planner can fall back if the catalog has nothing relevant.
20
  5. Cross-source comparison ("compare DB sales to the customers.csv file") → `structured`. The planner sees both source types in one prompt and can correlate.
21
 
22
  ## Rewriting follow-ups
@@ -53,6 +53,22 @@ User: "Top 5 customers by revenue this year"
53
  → needs_search=true, source_hint="structured",
54
  rewritten_query="Top 5 customers by revenue this year"
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  History: assistant: "Pro Plan Annual led at $487,200 in April."
57
  User: "And in March?"
58
  → needs_search=true, source_hint="structured",
@@ -61,6 +77,6 @@ User: "And in March?"
61
 
62
  ## Constraints
63
 
64
- - Do not invent data. If you don't know whether a topic exists in the user's data, route to `structured` and let the planner decide.
65
  - Do not refuse — refusal happens later in guardrails. Just classify.
66
  - One JSON object as output; no prose, no markdown.
 
7
  - **`needs_search`** — `true` if we must look at the user's data to answer; `false` for greetings, farewells, off-topic chitchat, or meta questions about the assistant itself.
8
  - **`source_hint`** — one of:
9
  - `chat` — no data lookup needed (greetings, farewells, generic small talk).
10
+ - `unstructured` — the user is asking about a topic, concept, feature, or factual knowledge that may exist in uploaded documents (PDF / DOCX / TXT). The user does not need to explicitly mention a document.
11
  - `structured` — the user is asking a **data question** answerable from a database or a tabular file (CSV / XLSX / Parquet). This includes counts, sums, top-N, filters, comparisons, trends, joins across registered structured sources.
12
  - **`rewritten_query`** — a **standalone** version of the user's question that incorporates necessary context from history. If the original message is already standalone, return it unchanged. If `needs_search` is `false`, leave this empty/null.
13
 
14
  ## Routing rules
15
 
16
+ 1. If the message is ONLY a pure greeting / farewell / thanks / "how are you" / "what can you do" / compliment with no factual question → `chat` + `needs_search=false`.
17
+ 2. If the message asks a data question answerable from a database or tabular file (counts, sums, top-N, filters, comparisons, trends, sheet rows, table columns) → `structured` + `needs_search=true`.
18
+ 3. If the message asks about a topic, concept, feature, explanation, summary, or factual knowledge even without explicitly mentioning a document route to `unstructured` + `needs_search=true`. The user may have uploaded relevant documents covering that topic.
19
+ 4. If ambiguous between structured and unstructured prefer `unstructured`. Only prefer `structured` if there are clear signals of tabular/numeric data questions.
20
  5. Cross-source comparison ("compare DB sales to the customers.csv file") → `structured`. The planner sees both source types in one prompt and can correlate.
21
 
22
  ## Rewriting follow-ups
 
53
  → needs_search=true, source_hint="structured",
54
  rewritten_query="Top 5 customers by revenue this year"
55
 
56
+ User: "apa key feature dari iot connectivity?"
57
+ → needs_search=true, source_hint="unstructured",
58
+ rewritten_query="What are the key features of IoT connectivity?"
59
+
60
+ User: "jelaskan tentang machine learning"
61
+ → needs_search=true, source_hint="unstructured",
62
+ rewritten_query="Explain machine learning"
63
+
64
+ User: "bagaimana cara kerja neural network?"
65
+ → needs_search=true, source_hint="unstructured",
66
+ rewritten_query="How does a neural network work?"
67
+
68
+ User: "what is the main purpose of this system?"
69
+ → needs_search=true, source_hint="unstructured",
70
+ rewritten_query="What is the main purpose of this system?"
71
+
72
  History: assistant: "Pro Plan Annual led at $487,200 in April."
73
  User: "And in March?"
74
  → needs_search=true, source_hint="structured",
 
77
 
78
  ## Constraints
79
 
80
+ - Do not invent data. If the question is factual or knowledge-based (not clearly tabular), route to `unstructured` and let the retriever decide. Only route to `structured` if the question clearly involves counts, sums, filters, or trends from tabular sources.
81
  - Do not refuse — refusal happens later in guardrails. Just classify.
82
  - One JSON object as output; no prose, no markdown.
src/db/postgres/vector_store.py CHANGED
@@ -19,7 +19,7 @@ embeddings = AzureOpenAIEmbeddings(
19
  vector_store = PGVector(
20
  embeddings=embeddings,
21
  connection=_pgvector_engine,
22
- collection_name="document_embeddings",
23
  use_jsonb=True,
24
  async_mode=True,
25
  create_extension=False, # Extension pre-created in init_db.py (avoids multi-statement asyncpg bug)
 
19
  vector_store = PGVector(
20
  embeddings=embeddings,
21
  connection=_pgvector_engine,
22
+ collection_name="documents",
23
  use_jsonb=True,
24
  async_mode=True,
25
  create_extension=False, # Extension pre-created in init_db.py (avoids multi-statement asyncpg bug)
src/knowledge/processing_service.py CHANGED
@@ -59,6 +59,7 @@ class KnowledgeProcessingService:
59
  "filename": db_doc.filename,
60
  "file_type": db_doc.file_type,
61
  "chunk_index": i,
 
62
  },
63
  }
64
  )
 
59
  "filename": db_doc.filename,
60
  "file_type": db_doc.file_type,
61
  "chunk_index": i,
62
+ "page_label": None,
63
  },
64
  }
65
  )
src/middlewares/logging.py CHANGED
@@ -9,7 +9,7 @@ import time
9
 
10
  def configure_logging():
11
  """Configure structured logging."""
12
- logging.basicConfig(level=logging.WARNING)
13
  logging.getLogger("tabular_executor").setLevel(logging.INFO)
14
  structlog.configure(
15
  processors=[
 
9
 
10
  def configure_logging():
11
  """Configure structured logging."""
12
+ logging.basicConfig(level=logging.INFO)
13
  logging.getLogger("tabular_executor").setLevel(logging.INFO)
14
  structlog.configure(
15
  processors=[
src/retrieval/document.py CHANGED
@@ -1,68 +1,44 @@
1
- """DocumentRetriever — dense similarity over prose chunks (Cu).
2
 
3
- For unstructured sources only (PDF / DOCX / TXT). Backed by PGVector with
4
- collection `document_embeddings`. Methods: MMR, cosine, euclidean, etc.
 
 
5
  """
6
 
7
  import functools
8
  import math
9
 
10
- from langchain_postgres import PGVector
11
- from langchain_postgres.vectorstores import DistanceStrategy
12
  from langchain_openai import AzureOpenAIEmbeddings
13
  from sqlalchemy import text
14
 
15
  from src.config.settings import settings
16
  from src.db.postgres.connection import _pgvector_engine
17
- from src.db.postgres.vector_store import get_vector_store
18
  from src.middlewares.logging import get_logger
19
  from src.retrieval.base import BaseRetriever, RetrievalResult
20
 
21
  logger = get_logger("document_retriever")
22
 
23
  # Change this one line to switch retrieval method
24
- # Options: "mmr" | "cosine" | "euclidean" | "inner_product" | "manhattan"
25
- _RETRIEVAL_METHOD = "mmr"
26
 
27
  _TABULAR_TYPES = {"csv", "xlsx"}
28
- _FETCH_K = 20
29
- _LAMBDA_MULT = 0.5
30
- _COLLECTION_NAME = "document_embeddings"
31
 
32
- @functools.cache
33
- def _get_embeddings() -> AzureOpenAIEmbeddings:
34
- return AzureOpenAIEmbeddings(
35
- azure_deployment=settings.azureai_deployment_name_embedding,
36
- openai_api_version=settings.azureai_api_version_embedding,
37
- azure_endpoint=settings.azureai_endpoint_url_embedding,
38
- api_key=settings.azureai_api_key_embedding,
39
- )
40
-
41
-
42
- @functools.cache
43
- def _get_euclidean_store() -> PGVector:
44
- return PGVector(
45
- embeddings=_get_embeddings(),
46
- connection=_pgvector_engine,
47
- collection_name=_COLLECTION_NAME,
48
- distance_strategy=DistanceStrategy.EUCLIDEAN,
49
- use_jsonb=True,
50
- async_mode=True,
51
- create_extension=False,
52
- )
53
-
54
-
55
- @functools.cache
56
- def _get_ip_store() -> PGVector:
57
- return PGVector(
58
- embeddings=_get_embeddings(),
59
- connection=_pgvector_engine,
60
- collection_name=_COLLECTION_NAME,
61
- distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
62
- use_jsonb=True,
63
- async_mode=True,
64
- create_extension=False,
65
- )
66
 
67
  _MANHATTAN_SQL = text("""
68
  SELECT
@@ -79,71 +55,32 @@ _MANHATTAN_SQL = text("""
79
  """)
80
 
81
 
82
- class DocumentRetriever(BaseRetriever):
83
- def __init__(self) -> None:
84
- self.vector_store = get_vector_store()
 
 
 
 
 
 
85
 
 
86
  async def retrieve(
87
  self, query: str, user_id: str, k: int = 5
88
- ) -> list[RetrievalResult]:
89
- filter_ = {"user_id": user_id, "source_type": "document"}
90
- fetch_k = k + len(_TABULAR_TYPES)
91
-
92
- if _RETRIEVAL_METHOD == "manhattan":
93
- return await self._retrieve_manhattan(query, user_id, k, fetch_k)
94
-
95
- if _RETRIEVAL_METHOD == "mmr":
96
- docs = await self.vector_store.amax_marginal_relevance_search(
97
- query=query,
98
- k=fetch_k,
99
- fetch_k=_FETCH_K,
100
- lambda_mult=_LAMBDA_MULT,
101
- filter=filter_,
102
- )
103
- cosine = await self.vector_store.asimilarity_search_with_score(
104
- query=query, k=fetch_k, filter=filter_,
105
- )
106
- score_map = {doc.page_content: score for doc, score in cosine}
107
- docs_with_scores = [(doc, score_map.get(doc.page_content, 0.0)) for doc in docs]
108
- elif _RETRIEVAL_METHOD == "euclidean":
109
- docs_with_scores = await _get_euclidean_store().asimilarity_search_with_score(
110
- query=query, k=fetch_k, filter=filter_,
111
- )
112
- elif _RETRIEVAL_METHOD == "inner_product":
113
- docs_with_scores = await _get_ip_store().asimilarity_search_with_score(
114
- query=query, k=fetch_k, filter=filter_,
115
- )
116
- else: # cosine
117
- docs_with_scores = await self.vector_store.asimilarity_search_with_score(
118
- query=query, k=fetch_k, filter=filter_,
119
- )
120
-
121
- results = []
122
- for doc, score in docs_with_scores:
123
- file_type = doc.metadata.get("data", {}).get("file_type", "")
124
- if file_type not in _TABULAR_TYPES:
125
- results.append(RetrievalResult(
126
- content=doc.page_content,
127
- metadata=doc.metadata,
128
- score=score,
129
- source_type="document",
130
- ))
131
- if len(results) == k:
132
- break
133
-
134
- logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
135
- return results
136
-
137
- async def _retrieve_manhattan(
138
- self, query: str, user_id: str, k: int, fetch_k: int
139
  ) -> list[RetrievalResult]:
140
  query_vector = await _get_embeddings().aembed_query(query)
141
  if not all(math.isfinite(v) for v in query_vector):
142
  raise ValueError("Embedding vector contains NaN or Infinity values.")
143
  vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"
 
 
 
 
 
144
 
145
  async with _pgvector_engine.connect() as conn:
146
- result = await conn.execute(_MANHATTAN_SQL, {
147
  "embedding": vector_str,
148
  "collection": _COLLECTION_NAME,
149
  "user_id": user_id,
@@ -151,6 +88,8 @@ class DocumentRetriever(BaseRetriever):
151
  })
152
  rows = result.fetchall()
153
 
 
 
154
  results = []
155
  for row in rows:
156
  file_type = row.cmetadata.get("data", {}).get("file_type", "")
@@ -164,7 +103,7 @@ class DocumentRetriever(BaseRetriever):
164
  if len(results) == k:
165
  break
166
 
167
- logger.info("retrieved chunks", method="manhattan", count=len(results))
168
  return results
169
 
170
 
 
1
+ """DocumentRetriever — dense similarity over prose chunks.
2
 
3
+ For unstructured sources only (PDF / DOCX / TXT). Backed by PGVector via
4
+ raw SQL to avoid LangChain ORM / asyncpg type-mapping issues (id UUID vs
5
+ String mismatch, jsonb_path_match asyncpg binding quirks).
6
+ Collection `document_embeddings`. Methods: cosine | manhattan.
7
  """
8
 
9
  import functools
10
  import math
11
 
 
 
12
  from langchain_openai import AzureOpenAIEmbeddings
13
  from sqlalchemy import text
14
 
15
  from src.config.settings import settings
16
  from src.db.postgres.connection import _pgvector_engine
 
17
  from src.middlewares.logging import get_logger
18
  from src.retrieval.base import BaseRetriever, RetrievalResult
19
 
20
  logger = get_logger("document_retriever")
21
 
22
  # Change this one line to switch retrieval method
23
+ # Options: "cosine" | "manhattan"
24
+ _RETRIEVAL_METHOD = "cosine"
25
 
26
  _TABULAR_TYPES = {"csv", "xlsx"}
27
+ _COLLECTION_NAME = "documents"
 
 
28
 
29
+ _COSINE_SQL = text("""
30
+ SELECT
31
+ lpe.document,
32
+ lpe.cmetadata,
33
+ lpe.embedding <=> CAST(:embedding AS vector) AS distance
34
+ FROM langchain_pg_embedding lpe
35
+ JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
36
+ WHERE lpc.name = :collection
37
+ AND lpe.cmetadata->>'user_id' = :user_id
38
+ AND lpe.cmetadata->>'source_type' = 'document'
39
+ ORDER BY distance ASC
40
+ LIMIT :k
41
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  _MANHATTAN_SQL = text("""
44
  SELECT
 
55
  """)
56
 
57
 
58
+ @functools.cache
59
+ def _get_embeddings() -> AzureOpenAIEmbeddings:
60
+ return AzureOpenAIEmbeddings(
61
+ azure_deployment=settings.azureai_deployment_name_embedding,
62
+ openai_api_version=settings.azureai_api_version_embedding,
63
+ azure_endpoint=settings.azureai_endpoint_url_embedding,
64
+ api_key=settings.azureai_api_key_embedding,
65
+ )
66
+
67
 
68
+ class DocumentRetriever(BaseRetriever):
69
  async def retrieve(
70
  self, query: str, user_id: str, k: int = 5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  ) -> list[RetrievalResult]:
72
  query_vector = await _get_embeddings().aembed_query(query)
73
  if not all(math.isfinite(v) for v in query_vector):
74
  raise ValueError("Embedding vector contains NaN or Infinity values.")
75
  vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"
76
+ fetch_k = k + len(_TABULAR_TYPES)
77
+
78
+ sql = _COSINE_SQL if _RETRIEVAL_METHOD == "cosine" else _MANHATTAN_SQL
79
+
80
+ logger.info("retrieve called", user_id=user_id, collection=_COLLECTION_NAME, fetch_k=fetch_k)
81
 
82
  async with _pgvector_engine.connect() as conn:
83
+ result = await conn.execute(sql, {
84
  "embedding": vector_str,
85
  "collection": _COLLECTION_NAME,
86
  "user_id": user_id,
 
88
  })
89
  rows = result.fetchall()
90
 
91
+ logger.info("raw rows from db", row_count=len(rows))
92
+
93
  results = []
94
  for row in rows:
95
  file_type = row.cmetadata.get("data", {}).get("file_type", "")
 
103
  if len(results) == k:
104
  break
105
 
106
+ logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
107
  return results
108
 
109