[KM-582][DED][AI] Fix Retrieval in Agentic Service
Browse filesticket: https://bukittechnology.atlassian.net/browse/KM-582
fix: replace LangChain ORM retrieval with raw SQL and fix pgvector collection name
- Rewrite DocumentRetriever.retrieve() using raw SQL cosine/manhattan
queries instead of LangChain PGVector ORM, bypassing asyncpg type-mapping
issues (id String vs UUID column, jsonb_path_match binding quirks)
- Fix _COLLECTION_NAME from "document_embeddings" to "documents" to match
the collection name set by the Golang ingestion service
- Fix collection_name in vector_store.py to match consistently
- Fix Redis chat cache to store {response, sources} dict so cached replies
also populate message_sources table
- Add cache management endpoints: DELETE /chat/cache, /chat/cache/room/{id},
/retrieval/cache/{user_id}
- Invalidate retrieval cache automatically after document processing
- Update intent_router prompt: route topical/knowledge questions to
unstructured even without explicit document mention; prefer unstructured
when ambiguous; add Indonesian few-shot examples
- Fix logging level from WARNING to INFO so structured logs are visible
- Add page_label: null to non-PDF chunk metadata for consistency
- Add diagnostic logging in retrieve() to expose collection, user_id,
and raw row count per call
- src/agents/chat_handler.py +1 -0
- src/api/v1/chat.py +51 -10
- src/api/v1/document.py +3 -0
- src/config/prompts/intent_router.md +22 -6
- src/db/postgres/vector_store.py +1 -1
- src/knowledge/processing_service.py +1 -0
- src/middlewares/logging.py +1 -1
- src/retrieval/document.py +40 -101
|
@@ -170,6 +170,7 @@ class ChatHandler:
|
|
| 170 |
sources = _build_sources(
|
| 171 |
decision.source_hint, user_id, query_result, raw_chunks
|
| 172 |
)
|
|
|
|
| 173 |
yield {"event": "sources", "data": json.dumps(sources)}
|
| 174 |
|
| 175 |
# ---- 3. Stream answer ----------------------------------------
|
|
|
|
| 170 |
sources = _build_sources(
|
| 171 |
decision.source_hint, user_id, query_result, raw_chunks
|
| 172 |
)
|
| 173 |
+
logger.info("built sources", source_hint=decision.source_hint, sources_count=len(sources), raw_chunks_count=len(raw_chunks) if raw_chunks else 0)
|
| 174 |
yield {"event": "sources", "data": json.dumps(sources)}
|
| 175 |
|
| 176 |
# ---- 3. Stream answer ----------------------------------------
|
|
@@ -42,15 +42,19 @@ class ChatRequest(BaseModel):
|
|
| 42 |
message: str
|
| 43 |
|
| 44 |
|
| 45 |
-
async def get_cached_response(redis, cache_key: str) -> Optional[
|
| 46 |
cached = await redis.get(cache_key)
|
| 47 |
if cached:
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
return None
|
| 50 |
|
| 51 |
|
| 52 |
-
async def cache_response(redis, cache_key: str, response: str):
|
| 53 |
-
await redis.setex(cache_key, 86400, json.dumps(response))
|
| 54 |
|
| 55 |
|
| 56 |
async def load_history(db: AsyncSession, room_id: str, limit: int = 10) -> list:
|
|
@@ -91,6 +95,34 @@ async def save_messages(
|
|
| 91 |
await db.commit()
|
| 92 |
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
@router.post("/chat/stream")
|
| 95 |
@log_execution(logger)
|
| 96 |
async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
|
|
@@ -107,13 +139,17 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
|
|
| 107 |
|
| 108 |
# Redis cache hit
|
| 109 |
cached = await get_cached_response(redis, cache_key)
|
|
|
|
| 110 |
if cached:
|
| 111 |
logger.info("Returning cached response")
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
async def stream_cached():
|
| 114 |
-
yield {"event": "sources", "data": json.dumps(
|
| 115 |
-
for i in range(0, len(
|
| 116 |
-
yield {"event": "chunk", "data":
|
| 117 |
yield {"event": "done", "data": ""}
|
| 118 |
|
| 119 |
return EventSourceResponse(stream_cached())
|
|
@@ -122,7 +158,7 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
|
|
| 122 |
# Fast intent: greetings/farewells bypass LLM entirely
|
| 123 |
direct = _fast_intent(request.message)
|
| 124 |
if direct:
|
| 125 |
-
await cache_response(redis, cache_key, direct)
|
| 126 |
await save_messages(db, request.room_id, request.message, direct, sources=[])
|
| 127 |
|
| 128 |
async def stream_direct():
|
|
@@ -136,6 +172,7 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
|
|
| 136 |
handler = ChatHandler()
|
| 137 |
|
| 138 |
async def stream_response():
|
|
|
|
| 139 |
full_response = ""
|
| 140 |
sources: List[Dict[str, Any]] = []
|
| 141 |
async for event in handler.handle(request.message, request.user_id, history):
|
|
@@ -149,8 +186,12 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
|
|
| 149 |
full_response += event["data"]
|
| 150 |
yield event
|
| 151 |
elif event["event"] == "done":
|
| 152 |
-
await cache_response(redis, cache_key, full_response)
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
yield event
|
| 155 |
elif event["event"] == "error":
|
| 156 |
yield event
|
|
|
|
| 42 |
message: str
|
| 43 |
|
| 44 |
|
| 45 |
+
async def get_cached_response(redis, cache_key: str) -> Optional[dict]:
|
| 46 |
cached = await redis.get(cache_key)
|
| 47 |
if cached:
|
| 48 |
+
data = json.loads(cached)
|
| 49 |
+
if isinstance(data, dict) and "response" in data:
|
| 50 |
+
return data
|
| 51 |
+
# legacy: plain string cached before this change
|
| 52 |
+
return {"response": data, "sources": []}
|
| 53 |
return None
|
| 54 |
|
| 55 |
|
| 56 |
+
async def cache_response(redis, cache_key: str, response: str, sources: list):
|
| 57 |
+
await redis.setex(cache_key, 86400, json.dumps({"response": response, "sources": sources}))
|
| 58 |
|
| 59 |
|
| 60 |
async def load_history(db: AsyncSession, room_id: str, limit: int = 10) -> list:
|
|
|
|
| 95 |
await db.commit()
|
| 96 |
|
| 97 |
|
| 98 |
+
@router.delete("/chat/cache")
|
| 99 |
+
async def clear_chat_cache(room_id: str, message: str):
|
| 100 |
+
"""Delete the Redis cache entry for a specific room + message pair."""
|
| 101 |
+
redis = await get_redis()
|
| 102 |
+
cache_key = f"{settings.redis_prefix}chat:{room_id}:{message}"
|
| 103 |
+
deleted = await redis.delete(cache_key)
|
| 104 |
+
return {"deleted": deleted > 0, "cache_key": cache_key}
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@router.delete("/chat/cache/room/{room_id}")
|
| 108 |
+
async def clear_room_cache(room_id: str):
|
| 109 |
+
"""Delete all Redis cache entries for a room."""
|
| 110 |
+
redis = await get_redis()
|
| 111 |
+
pattern = f"{settings.redis_prefix}chat:{room_id}:*"
|
| 112 |
+
keys = await redis.keys(pattern)
|
| 113 |
+
if keys:
|
| 114 |
+
await redis.delete(*keys)
|
| 115 |
+
return {"deleted_count": len(keys), "room_id": room_id}
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@router.delete("/retrieval/cache/{user_id}")
|
| 119 |
+
async def clear_retrieval_cache(user_id: str):
|
| 120 |
+
"""Delete all cached retrieval results for a user. Call this after uploading/processing new documents."""
|
| 121 |
+
from src.retrieval.router import retrieval_router
|
| 122 |
+
deleted = await retrieval_router.invalidate_cache(user_id)
|
| 123 |
+
return {"deleted_count": deleted, "user_id": user_id}
|
| 124 |
+
|
| 125 |
+
|
| 126 |
@router.post("/chat/stream")
|
| 127 |
@log_execution(logger)
|
| 128 |
async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
|
|
|
|
| 139 |
|
| 140 |
# Redis cache hit
|
| 141 |
cached = await get_cached_response(redis, cache_key)
|
| 142 |
+
logger.info("cache check", cache_key=cache_key, cache_hit=cached is not None)
|
| 143 |
if cached:
|
| 144 |
logger.info("Returning cached response")
|
| 145 |
+
cached_text = cached["response"]
|
| 146 |
+
cached_sources = cached["sources"]
|
| 147 |
+
await save_messages(db, request.room_id, request.message, cached_text, sources=cached_sources)
|
| 148 |
|
| 149 |
async def stream_cached():
|
| 150 |
+
yield {"event": "sources", "data": json.dumps(cached_sources)}
|
| 151 |
+
for i in range(0, len(cached_text), 50):
|
| 152 |
+
yield {"event": "chunk", "data": cached_text[i:i + 50]}
|
| 153 |
yield {"event": "done", "data": ""}
|
| 154 |
|
| 155 |
return EventSourceResponse(stream_cached())
|
|
|
|
| 158 |
# Fast intent: greetings/farewells bypass LLM entirely
|
| 159 |
direct = _fast_intent(request.message)
|
| 160 |
if direct:
|
| 161 |
+
await cache_response(redis, cache_key, direct, sources=[])
|
| 162 |
await save_messages(db, request.room_id, request.message, direct, sources=[])
|
| 163 |
|
| 164 |
async def stream_direct():
|
|
|
|
| 172 |
handler = ChatHandler()
|
| 173 |
|
| 174 |
async def stream_response():
|
| 175 |
+
logger.info("stream_response started", room_id=request.room_id, user_id=request.user_id)
|
| 176 |
full_response = ""
|
| 177 |
sources: List[Dict[str, Any]] = []
|
| 178 |
async for event in handler.handle(request.message, request.user_id, history):
|
|
|
|
| 186 |
full_response += event["data"]
|
| 187 |
yield event
|
| 188 |
elif event["event"] == "done":
|
| 189 |
+
await cache_response(redis, cache_key, full_response, sources=sources)
|
| 190 |
+
logger.info("saving messages", sources_count=len(sources), sources=sources)
|
| 191 |
+
try:
|
| 192 |
+
await save_messages(db, request.room_id, request.message, full_response, sources=sources)
|
| 193 |
+
except Exception as e:
|
| 194 |
+
logger.error("save_messages failed", room_id=request.room_id, error=str(e))
|
| 195 |
yield event
|
| 196 |
elif event["event"] == "error":
|
| 197 |
yield event
|
|
@@ -114,5 +114,8 @@ async def process_document(
|
|
| 114 |
except Exception as e:
|
| 115 |
logger.error("catalog ingestion failed after process", document_id=document_id, error=str(e))
|
| 116 |
|
|
|
|
|
|
|
|
|
|
| 117 |
return {"status": "success", "message": "Document processed successfully", "data": data}
|
| 118 |
|
|
|
|
| 114 |
except Exception as e:
|
| 115 |
logger.error("catalog ingestion failed after process", document_id=document_id, error=str(e))
|
| 116 |
|
| 117 |
+
from src.retrieval.router import retrieval_router
|
| 118 |
+
await retrieval_router.invalidate_cache(user_id)
|
| 119 |
+
|
| 120 |
return {"status": "success", "message": "Document processed successfully", "data": data}
|
| 121 |
|
|
@@ -7,16 +7,16 @@ Return three fields:
|
|
| 7 |
- **`needs_search`** — `true` if we must look at the user's data to answer; `false` for greetings, farewells, off-topic chitchat, or meta questions about the assistant itself.
|
| 8 |
- **`source_hint`** — one of:
|
| 9 |
- `chat` — no data lookup needed (greetings, farewells, generic small talk).
|
| 10 |
-
- `unstructured` — the user is asking about
|
| 11 |
- `structured` — the user is asking a **data question** answerable from a database or a tabular file (CSV / XLSX / Parquet). This includes counts, sums, top-N, filters, comparisons, trends, joins across registered structured sources.
|
| 12 |
- **`rewritten_query`** — a **standalone** version of the user's question that incorporates necessary context from history. If the original message is already standalone, return it unchanged. If `needs_search` is `false`, leave this empty/null.
|
| 13 |
|
| 14 |
## Routing rules
|
| 15 |
|
| 16 |
-
1. If the message is a pure greeting / farewell / thanks / "how are you" / "what can you do" → `chat` + `needs_search=false`.
|
| 17 |
-
2. If the message
|
| 18 |
-
3. If the message asks about
|
| 19 |
-
4. If
|
| 20 |
5. Cross-source comparison ("compare DB sales to the customers.csv file") → `structured`. The planner sees both source types in one prompt and can correlate.
|
| 21 |
|
| 22 |
## Rewriting follow-ups
|
|
@@ -53,6 +53,22 @@ User: "Top 5 customers by revenue this year"
|
|
| 53 |
→ needs_search=true, source_hint="structured",
|
| 54 |
rewritten_query="Top 5 customers by revenue this year"
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
History: assistant: "Pro Plan Annual led at $487,200 in April."
|
| 57 |
User: "And in March?"
|
| 58 |
→ needs_search=true, source_hint="structured",
|
|
@@ -61,6 +77,6 @@ User: "And in March?"
|
|
| 61 |
|
| 62 |
## Constraints
|
| 63 |
|
| 64 |
-
- Do not invent data. If
|
| 65 |
- Do not refuse — refusal happens later in guardrails. Just classify.
|
| 66 |
- One JSON object as output; no prose, no markdown.
|
|
|
|
| 7 |
- **`needs_search`** — `true` if we must look at the user's data to answer; `false` for greetings, farewells, off-topic chitchat, or meta questions about the assistant itself.
|
| 8 |
- **`source_hint`** — one of:
|
| 9 |
- `chat` — no data lookup needed (greetings, farewells, generic small talk).
|
| 10 |
+
- `unstructured` — the user is asking about a topic, concept, feature, or factual knowledge that may exist in uploaded documents (PDF / DOCX / TXT). The user does not need to explicitly mention a document.
|
| 11 |
- `structured` — the user is asking a **data question** answerable from a database or a tabular file (CSV / XLSX / Parquet). This includes counts, sums, top-N, filters, comparisons, trends, joins across registered structured sources.
|
| 12 |
- **`rewritten_query`** — a **standalone** version of the user's question that incorporates necessary context from history. If the original message is already standalone, return it unchanged. If `needs_search` is `false`, leave this empty/null.
|
| 13 |
|
| 14 |
## Routing rules
|
| 15 |
|
| 16 |
+
1. If the message is ONLY a pure greeting / farewell / thanks / "how are you" / "what can you do" / compliment with no factual question → `chat` + `needs_search=false`.
|
| 17 |
+
2. If the message asks a data question answerable from a database or tabular file (counts, sums, top-N, filters, comparisons, trends, sheet rows, table columns) → `structured` + `needs_search=true`.
|
| 18 |
+
3. If the message asks about a topic, concept, feature, explanation, summary, or factual knowledge — even without explicitly mentioning a document — route to `unstructured` + `needs_search=true`. The user may have uploaded relevant documents covering that topic.
|
| 19 |
+
4. If ambiguous between structured and unstructured → prefer `unstructured`. Only prefer `structured` if there are clear signals of tabular/numeric data questions.
|
| 20 |
5. Cross-source comparison ("compare DB sales to the customers.csv file") → `structured`. The planner sees both source types in one prompt and can correlate.
|
| 21 |
|
| 22 |
## Rewriting follow-ups
|
|
|
|
| 53 |
→ needs_search=true, source_hint="structured",
|
| 54 |
rewritten_query="Top 5 customers by revenue this year"
|
| 55 |
|
| 56 |
+
User: "apa key feature dari iot connectivity?"
|
| 57 |
+
→ needs_search=true, source_hint="unstructured",
|
| 58 |
+
rewritten_query="What are the key features of IoT connectivity?"
|
| 59 |
+
|
| 60 |
+
User: "jelaskan tentang machine learning"
|
| 61 |
+
→ needs_search=true, source_hint="unstructured",
|
| 62 |
+
rewritten_query="Explain machine learning"
|
| 63 |
+
|
| 64 |
+
User: "bagaimana cara kerja neural network?"
|
| 65 |
+
→ needs_search=true, source_hint="unstructured",
|
| 66 |
+
rewritten_query="How does a neural network work?"
|
| 67 |
+
|
| 68 |
+
User: "what is the main purpose of this system?"
|
| 69 |
+
→ needs_search=true, source_hint="unstructured",
|
| 70 |
+
rewritten_query="What is the main purpose of this system?"
|
| 71 |
+
|
| 72 |
History: assistant: "Pro Plan Annual led at $487,200 in April."
|
| 73 |
User: "And in March?"
|
| 74 |
→ needs_search=true, source_hint="structured",
|
|
|
|
| 77 |
|
| 78 |
## Constraints
|
| 79 |
|
| 80 |
+
- Do not invent data. If the question is factual or knowledge-based (not clearly tabular), route to `unstructured` and let the retriever decide. Only route to `structured` if the question clearly involves counts, sums, filters, or trends from tabular sources.
|
| 81 |
- Do not refuse — refusal happens later in guardrails. Just classify.
|
| 82 |
- One JSON object as output; no prose, no markdown.
|
|
@@ -19,7 +19,7 @@ embeddings = AzureOpenAIEmbeddings(
|
|
| 19 |
vector_store = PGVector(
|
| 20 |
embeddings=embeddings,
|
| 21 |
connection=_pgvector_engine,
|
| 22 |
-
collection_name="
|
| 23 |
use_jsonb=True,
|
| 24 |
async_mode=True,
|
| 25 |
create_extension=False, # Extension pre-created in init_db.py (avoids multi-statement asyncpg bug)
|
|
|
|
| 19 |
vector_store = PGVector(
|
| 20 |
embeddings=embeddings,
|
| 21 |
connection=_pgvector_engine,
|
| 22 |
+
collection_name="documents",
|
| 23 |
use_jsonb=True,
|
| 24 |
async_mode=True,
|
| 25 |
create_extension=False, # Extension pre-created in init_db.py (avoids multi-statement asyncpg bug)
|
|
@@ -59,6 +59,7 @@ class KnowledgeProcessingService:
|
|
| 59 |
"filename": db_doc.filename,
|
| 60 |
"file_type": db_doc.file_type,
|
| 61 |
"chunk_index": i,
|
|
|
|
| 62 |
},
|
| 63 |
}
|
| 64 |
)
|
|
|
|
| 59 |
"filename": db_doc.filename,
|
| 60 |
"file_type": db_doc.file_type,
|
| 61 |
"chunk_index": i,
|
| 62 |
+
"page_label": None,
|
| 63 |
},
|
| 64 |
}
|
| 65 |
)
|
|
@@ -9,7 +9,7 @@ import time
|
|
| 9 |
|
| 10 |
def configure_logging():
|
| 11 |
"""Configure structured logging."""
|
| 12 |
-
logging.basicConfig(level=logging.
|
| 13 |
logging.getLogger("tabular_executor").setLevel(logging.INFO)
|
| 14 |
structlog.configure(
|
| 15 |
processors=[
|
|
|
|
| 9 |
|
| 10 |
def configure_logging():
|
| 11 |
"""Configure structured logging."""
|
| 12 |
+
logging.basicConfig(level=logging.INFO)
|
| 13 |
logging.getLogger("tabular_executor").setLevel(logging.INFO)
|
| 14 |
structlog.configure(
|
| 15 |
processors=[
|
|
@@ -1,68 +1,44 @@
|
|
| 1 |
-
"""DocumentRetriever — dense similarity over prose chunks
|
| 2 |
|
| 3 |
-
For unstructured sources only (PDF / DOCX / TXT). Backed by PGVector
|
| 4 |
-
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import functools
|
| 8 |
import math
|
| 9 |
|
| 10 |
-
from langchain_postgres import PGVector
|
| 11 |
-
from langchain_postgres.vectorstores import DistanceStrategy
|
| 12 |
from langchain_openai import AzureOpenAIEmbeddings
|
| 13 |
from sqlalchemy import text
|
| 14 |
|
| 15 |
from src.config.settings import settings
|
| 16 |
from src.db.postgres.connection import _pgvector_engine
|
| 17 |
-
from src.db.postgres.vector_store import get_vector_store
|
| 18 |
from src.middlewares.logging import get_logger
|
| 19 |
from src.retrieval.base import BaseRetriever, RetrievalResult
|
| 20 |
|
| 21 |
logger = get_logger("document_retriever")
|
| 22 |
|
| 23 |
# Change this one line to switch retrieval method
|
| 24 |
-
# Options: "
|
| 25 |
-
_RETRIEVAL_METHOD = "
|
| 26 |
|
| 27 |
_TABULAR_TYPES = {"csv", "xlsx"}
|
| 28 |
-
|
| 29 |
-
_LAMBDA_MULT = 0.5
|
| 30 |
-
_COLLECTION_NAME = "document_embeddings"
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
embeddings=_get_embeddings(),
|
| 46 |
-
connection=_pgvector_engine,
|
| 47 |
-
collection_name=_COLLECTION_NAME,
|
| 48 |
-
distance_strategy=DistanceStrategy.EUCLIDEAN,
|
| 49 |
-
use_jsonb=True,
|
| 50 |
-
async_mode=True,
|
| 51 |
-
create_extension=False,
|
| 52 |
-
)
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
@functools.cache
|
| 56 |
-
def _get_ip_store() -> PGVector:
|
| 57 |
-
return PGVector(
|
| 58 |
-
embeddings=_get_embeddings(),
|
| 59 |
-
connection=_pgvector_engine,
|
| 60 |
-
collection_name=_COLLECTION_NAME,
|
| 61 |
-
distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
|
| 62 |
-
use_jsonb=True,
|
| 63 |
-
async_mode=True,
|
| 64 |
-
create_extension=False,
|
| 65 |
-
)
|
| 66 |
|
| 67 |
_MANHATTAN_SQL = text("""
|
| 68 |
SELECT
|
|
@@ -79,71 +55,32 @@ _MANHATTAN_SQL = text("""
|
|
| 79 |
""")
|
| 80 |
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
|
|
|
| 86 |
async def retrieve(
|
| 87 |
self, query: str, user_id: str, k: int = 5
|
| 88 |
-
) -> list[RetrievalResult]:
|
| 89 |
-
filter_ = {"user_id": user_id, "source_type": "document"}
|
| 90 |
-
fetch_k = k + len(_TABULAR_TYPES)
|
| 91 |
-
|
| 92 |
-
if _RETRIEVAL_METHOD == "manhattan":
|
| 93 |
-
return await self._retrieve_manhattan(query, user_id, k, fetch_k)
|
| 94 |
-
|
| 95 |
-
if _RETRIEVAL_METHOD == "mmr":
|
| 96 |
-
docs = await self.vector_store.amax_marginal_relevance_search(
|
| 97 |
-
query=query,
|
| 98 |
-
k=fetch_k,
|
| 99 |
-
fetch_k=_FETCH_K,
|
| 100 |
-
lambda_mult=_LAMBDA_MULT,
|
| 101 |
-
filter=filter_,
|
| 102 |
-
)
|
| 103 |
-
cosine = await self.vector_store.asimilarity_search_with_score(
|
| 104 |
-
query=query, k=fetch_k, filter=filter_,
|
| 105 |
-
)
|
| 106 |
-
score_map = {doc.page_content: score for doc, score in cosine}
|
| 107 |
-
docs_with_scores = [(doc, score_map.get(doc.page_content, 0.0)) for doc in docs]
|
| 108 |
-
elif _RETRIEVAL_METHOD == "euclidean":
|
| 109 |
-
docs_with_scores = await _get_euclidean_store().asimilarity_search_with_score(
|
| 110 |
-
query=query, k=fetch_k, filter=filter_,
|
| 111 |
-
)
|
| 112 |
-
elif _RETRIEVAL_METHOD == "inner_product":
|
| 113 |
-
docs_with_scores = await _get_ip_store().asimilarity_search_with_score(
|
| 114 |
-
query=query, k=fetch_k, filter=filter_,
|
| 115 |
-
)
|
| 116 |
-
else: # cosine
|
| 117 |
-
docs_with_scores = await self.vector_store.asimilarity_search_with_score(
|
| 118 |
-
query=query, k=fetch_k, filter=filter_,
|
| 119 |
-
)
|
| 120 |
-
|
| 121 |
-
results = []
|
| 122 |
-
for doc, score in docs_with_scores:
|
| 123 |
-
file_type = doc.metadata.get("data", {}).get("file_type", "")
|
| 124 |
-
if file_type not in _TABULAR_TYPES:
|
| 125 |
-
results.append(RetrievalResult(
|
| 126 |
-
content=doc.page_content,
|
| 127 |
-
metadata=doc.metadata,
|
| 128 |
-
score=score,
|
| 129 |
-
source_type="document",
|
| 130 |
-
))
|
| 131 |
-
if len(results) == k:
|
| 132 |
-
break
|
| 133 |
-
|
| 134 |
-
logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
|
| 135 |
-
return results
|
| 136 |
-
|
| 137 |
-
async def _retrieve_manhattan(
|
| 138 |
-
self, query: str, user_id: str, k: int, fetch_k: int
|
| 139 |
) -> list[RetrievalResult]:
|
| 140 |
query_vector = await _get_embeddings().aembed_query(query)
|
| 141 |
if not all(math.isfinite(v) for v in query_vector):
|
| 142 |
raise ValueError("Embedding vector contains NaN or Infinity values.")
|
| 143 |
vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
async with _pgvector_engine.connect() as conn:
|
| 146 |
-
result = await conn.execute(
|
| 147 |
"embedding": vector_str,
|
| 148 |
"collection": _COLLECTION_NAME,
|
| 149 |
"user_id": user_id,
|
|
@@ -151,6 +88,8 @@ class DocumentRetriever(BaseRetriever):
|
|
| 151 |
})
|
| 152 |
rows = result.fetchall()
|
| 153 |
|
|
|
|
|
|
|
| 154 |
results = []
|
| 155 |
for row in rows:
|
| 156 |
file_type = row.cmetadata.get("data", {}).get("file_type", "")
|
|
@@ -164,7 +103,7 @@ class DocumentRetriever(BaseRetriever):
|
|
| 164 |
if len(results) == k:
|
| 165 |
break
|
| 166 |
|
| 167 |
-
logger.info("retrieved chunks", method=
|
| 168 |
return results
|
| 169 |
|
| 170 |
|
|
|
|
| 1 |
+
"""DocumentRetriever — dense similarity over prose chunks.
|
| 2 |
|
| 3 |
+
For unstructured sources only (PDF / DOCX / TXT). Backed by PGVector via
|
| 4 |
+
raw SQL to avoid LangChain ORM / asyncpg type-mapping issues (id UUID vs
|
| 5 |
+
String mismatch, jsonb_path_match asyncpg binding quirks).
|
| 6 |
+
Collection `document_embeddings`. Methods: cosine | manhattan.
|
| 7 |
"""
|
| 8 |
|
| 9 |
import functools
|
| 10 |
import math
|
| 11 |
|
|
|
|
|
|
|
| 12 |
from langchain_openai import AzureOpenAIEmbeddings
|
| 13 |
from sqlalchemy import text
|
| 14 |
|
| 15 |
from src.config.settings import settings
|
| 16 |
from src.db.postgres.connection import _pgvector_engine
|
|
|
|
| 17 |
from src.middlewares.logging import get_logger
|
| 18 |
from src.retrieval.base import BaseRetriever, RetrievalResult
|
| 19 |
|
| 20 |
logger = get_logger("document_retriever")
|
| 21 |
|
| 22 |
# Change this one line to switch retrieval method
|
| 23 |
+
# Options: "cosine" | "manhattan"
|
| 24 |
+
_RETRIEVAL_METHOD = "cosine"
|
| 25 |
|
| 26 |
_TABULAR_TYPES = {"csv", "xlsx"}
|
| 27 |
+
_COLLECTION_NAME = "documents"
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
_COSINE_SQL = text("""
|
| 30 |
+
SELECT
|
| 31 |
+
lpe.document,
|
| 32 |
+
lpe.cmetadata,
|
| 33 |
+
lpe.embedding <=> CAST(:embedding AS vector) AS distance
|
| 34 |
+
FROM langchain_pg_embedding lpe
|
| 35 |
+
JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
|
| 36 |
+
WHERE lpc.name = :collection
|
| 37 |
+
AND lpe.cmetadata->>'user_id' = :user_id
|
| 38 |
+
AND lpe.cmetadata->>'source_type' = 'document'
|
| 39 |
+
ORDER BY distance ASC
|
| 40 |
+
LIMIT :k
|
| 41 |
+
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
_MANHATTAN_SQL = text("""
|
| 44 |
SELECT
|
|
|
|
| 55 |
""")
|
| 56 |
|
| 57 |
|
| 58 |
+
@functools.cache
|
| 59 |
+
def _get_embeddings() -> AzureOpenAIEmbeddings:
|
| 60 |
+
return AzureOpenAIEmbeddings(
|
| 61 |
+
azure_deployment=settings.azureai_deployment_name_embedding,
|
| 62 |
+
openai_api_version=settings.azureai_api_version_embedding,
|
| 63 |
+
azure_endpoint=settings.azureai_endpoint_url_embedding,
|
| 64 |
+
api_key=settings.azureai_api_key_embedding,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
|
| 68 |
+
class DocumentRetriever(BaseRetriever):
|
| 69 |
async def retrieve(
|
| 70 |
self, query: str, user_id: str, k: int = 5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
) -> list[RetrievalResult]:
|
| 72 |
query_vector = await _get_embeddings().aembed_query(query)
|
| 73 |
if not all(math.isfinite(v) for v in query_vector):
|
| 74 |
raise ValueError("Embedding vector contains NaN or Infinity values.")
|
| 75 |
vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"
|
| 76 |
+
fetch_k = k + len(_TABULAR_TYPES)
|
| 77 |
+
|
| 78 |
+
sql = _COSINE_SQL if _RETRIEVAL_METHOD == "cosine" else _MANHATTAN_SQL
|
| 79 |
+
|
| 80 |
+
logger.info("retrieve called", user_id=user_id, collection=_COLLECTION_NAME, fetch_k=fetch_k)
|
| 81 |
|
| 82 |
async with _pgvector_engine.connect() as conn:
|
| 83 |
+
result = await conn.execute(sql, {
|
| 84 |
"embedding": vector_str,
|
| 85 |
"collection": _COLLECTION_NAME,
|
| 86 |
"user_id": user_id,
|
|
|
|
| 88 |
})
|
| 89 |
rows = result.fetchall()
|
| 90 |
|
| 91 |
+
logger.info("raw rows from db", row_count=len(rows))
|
| 92 |
+
|
| 93 |
results = []
|
| 94 |
for row in rows:
|
| 95 |
file_type = row.cmetadata.get("data", {}).get("file_type", "")
|
|
|
|
| 103 |
if len(results) == k:
|
| 104 |
break
|
| 105 |
|
| 106 |
+
logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
|
| 107 |
return results
|
| 108 |
|
| 109 |
|