Rifqi Hafizuddin commited on
Commit ·
d1e1264
1
Parent(s): 2814813
[KM-438][KM-439] framework for knowledge retriever
Browse files- src/agents/orchestration.py +5 -0
- src/api/v1/chat.py +1 -0
- src/models/structured_output.py +4 -0
- src/rag/base.py +20 -0
- src/rag/retriever.py +22 -48
- src/rag/retrievers/__init__.py +0 -0
- src/rag/retrievers/baseline.py +70 -0
- src/rag/retrievers/document.py +32 -0
- src/rag/retrievers/schema.py +86 -0
- src/rag/router.py +75 -0
src/agents/orchestration.py
CHANGED
|
@@ -35,6 +35,11 @@ Intent Routing:
|
|
| 35 |
- greeting -> needs_search=False, direct_response="Hello! How can I assist you today?"
|
| 36 |
- goodbye -> needs_search=False, direct_response="Goodbye! Have a great day!"
|
| 37 |
- other -> needs_search=True, search_query=<standalone rewritten query>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
"""),
|
| 39 |
MessagesPlaceholder(variable_name="history"),
|
| 40 |
("user", "{message}")
|
|
|
|
| 35 |
- greeting -> needs_search=False, direct_response="Hello! How can I assist you today?"
|
| 36 |
- goodbye -> needs_search=False, direct_response="Goodbye! Have a great day!"
|
| 37 |
- other -> needs_search=True, search_query=<standalone rewritten query>
|
| 38 |
+
|
| 39 |
+
Source Routing (set source_hint):
|
| 40 |
+
- Columns, tables, sheets, data types, schema, row counts, statistics -> source_hint=schema
|
| 41 |
+
- Document content, paragraphs, reports, articles, text -> source_hint=document
|
| 42 |
+
- Unclear or spans both -> source_hint=both
|
| 43 |
"""),
|
| 44 |
MessagesPlaceholder(variable_name="history"),
|
| 45 |
("user", "{message}")
|
src/api/v1/chat.py
CHANGED
|
@@ -192,6 +192,7 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
|
|
| 192 |
query=search_query,
|
| 193 |
user_id=request.user_id,
|
| 194 |
db=db,
|
|
|
|
| 195 |
)
|
| 196 |
else:
|
| 197 |
raw_results = await retrieval_task
|
|
|
|
| 192 |
query=search_query,
|
| 193 |
user_id=request.user_id,
|
| 194 |
db=db,
|
| 195 |
+
source_hint=intent_result.get("source_hint", "both"),
|
| 196 |
)
|
| 197 |
else:
|
| 198 |
raw_results = await retrieval_task
|
src/models/structured_output.py
CHANGED
|
@@ -19,3 +19,7 @@ class IntentClassification(BaseModel):
|
|
| 19 |
default="",
|
| 20 |
description="Direct response if no search needed (for greetings, etc.)"
|
| 21 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
default="",
|
| 20 |
description="Direct response if no search needed (for greetings, etc.)"
|
| 21 |
)
|
| 22 |
+
source_hint: str = Field(
|
| 23 |
+
default="both",
|
| 24 |
+
description="Which sources to search: 'document' (PDF/DOCX/TXT), 'schema' (DB/CSV/XLSX), or 'both'"
|
| 25 |
+
)
|
src/rag/base.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared contract for all retriever implementations."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
|
| 9 |
+
class RetrievalResult:
|
| 10 |
+
content: str
|
| 11 |
+
metadata: dict[str, Any]
|
| 12 |
+
score: float
|
| 13 |
+
source_type: str # "document" | "database"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseRetriever(ABC):
|
| 17 |
+
@abstractmethod
|
| 18 |
+
async def retrieve(
|
| 19 |
+
self, query: str, user_id: str, k: int = 5
|
| 20 |
+
) -> list[RetrievalResult]: ...
|
src/rag/retriever.py
CHANGED
|
@@ -1,69 +1,43 @@
|
|
| 1 |
-
"""
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
import hashlib
|
| 4 |
-
import json
|
| 5 |
-
from src.db.postgres.vector_store import get_vector_store
|
| 6 |
-
from src.db.redis.connection import get_redis
|
| 7 |
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
| 8 |
from src.middlewares.logging import get_logger
|
| 9 |
-
from
|
|
|
|
|
|
|
| 10 |
|
| 11 |
logger = get_logger("retriever")
|
| 12 |
|
| 13 |
-
_RETRIEVAL_CACHE_TTL = 3600 # 1 hour
|
| 14 |
-
|
| 15 |
|
| 16 |
class RetrieverService:
|
| 17 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
def __init__(self):
|
| 20 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
async def retrieve(
|
| 23 |
self,
|
| 24 |
query: str,
|
| 25 |
user_id: str,
|
| 26 |
db: AsyncSession,
|
| 27 |
-
k: int = 5
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
Returns:
|
| 32 |
-
List of dicts with keys: content, metadata
|
| 33 |
-
metadata includes: document_id, user_id, filename, chunk_index, page_label (if PDF)
|
| 34 |
-
"""
|
| 35 |
try:
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
cache_key = f"retrieval:{user_id}:{query_hash}:{k}"
|
| 39 |
-
|
| 40 |
-
cached = await redis.get(cache_key)
|
| 41 |
-
if cached:
|
| 42 |
-
logger.info("Returning cached retrieval results")
|
| 43 |
-
return json.loads(cached)
|
| 44 |
-
|
| 45 |
-
logger.info(f"Retrieving for user {user_id}, query: {query[:50]}...")
|
| 46 |
-
|
| 47 |
-
docs = await self.vector_store.asimilarity_search(
|
| 48 |
-
query=query,
|
| 49 |
-
k=k,
|
| 50 |
-
filter={"user_id": user_id}
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
-
results = [
|
| 54 |
-
{
|
| 55 |
-
"content": doc.page_content,
|
| 56 |
-
"metadata": doc.metadata,
|
| 57 |
-
}
|
| 58 |
-
for doc in docs
|
| 59 |
-
]
|
| 60 |
-
|
| 61 |
-
logger.info(f"Retrieved {len(results)} chunks")
|
| 62 |
-
await redis.setex(cache_key, _RETRIEVAL_CACHE_TTL, json.dumps(results))
|
| 63 |
-
return results
|
| 64 |
-
|
| 65 |
except Exception as e:
|
| 66 |
-
logger.error("
|
| 67 |
return []
|
| 68 |
|
| 69 |
|
|
|
|
| 1 |
+
"""Public retrieval API — thin wrapper around RetrievalRouter."""
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from sqlalchemy.ext.asyncio import AsyncSession
|
| 6 |
+
|
| 7 |
from src.middlewares.logging import get_logger
|
| 8 |
+
from src.rag.retrievers.document import document_retriever
|
| 9 |
+
from src.rag.retrievers.schema import schema_retriever
|
| 10 |
+
from src.rag.router import RetrievalRouter, SourceHint
|
| 11 |
|
| 12 |
logger = get_logger("retriever")
|
| 13 |
|
|
|
|
|
|
|
| 14 |
|
| 15 |
class RetrieverService:
|
| 16 |
+
"""Public retrieval service used by chat.py and search tools.
|
| 17 |
+
|
| 18 |
+
Delegates to RetrievalRouter which dispatches based on source_hint.
|
| 19 |
+
Returns List[Dict] to preserve backward compatibility with chat.py.
|
| 20 |
+
"""
|
| 21 |
|
| 22 |
def __init__(self):
|
| 23 |
+
self._router = RetrievalRouter(
|
| 24 |
+
schema_retriever=schema_retriever,
|
| 25 |
+
document_retriever=document_retriever,
|
| 26 |
+
)
|
| 27 |
|
| 28 |
async def retrieve(
|
| 29 |
self,
|
| 30 |
query: str,
|
| 31 |
user_id: str,
|
| 32 |
db: AsyncSession,
|
| 33 |
+
k: int = 5,
|
| 34 |
+
source_hint: SourceHint = "both",
|
| 35 |
+
) -> list[dict[str, Any]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
try:
|
| 37 |
+
results = await self._router.retrieve(query, user_id, source_hint, k)
|
| 38 |
+
return [{"content": r.content, "metadata": r.metadata} for r in results]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
except Exception as e:
|
| 40 |
+
logger.error("retrieval failed", error=str(e))
|
| 41 |
return []
|
| 42 |
|
| 43 |
|
src/rag/retrievers/__init__.py
ADDED
|
File without changes
|
src/rag/retrievers/baseline.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Service for retrieving relevant documents from vector store."""
|
| 2 |
+
|
| 3 |
+
import hashlib
|
| 4 |
+
import json
|
| 5 |
+
from src.db.postgres.vector_store import get_vector_store
|
| 6 |
+
from src.db.redis.connection import get_redis
|
| 7 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 8 |
+
from src.middlewares.logging import get_logger
|
| 9 |
+
from typing import List, Dict, Any
|
| 10 |
+
|
| 11 |
+
logger = get_logger("retriever")
|
| 12 |
+
|
| 13 |
+
_RETRIEVAL_CACHE_TTL = 3600 # 1 hour
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class RetrieverService:
|
| 17 |
+
"""Service for retrieving relevant documents."""
|
| 18 |
+
|
| 19 |
+
def __init__(self):
|
| 20 |
+
self.vector_store = get_vector_store()
|
| 21 |
+
|
| 22 |
+
async def retrieve(
|
| 23 |
+
self,
|
| 24 |
+
query: str,
|
| 25 |
+
user_id: str,
|
| 26 |
+
db: AsyncSession,
|
| 27 |
+
k: int = 5
|
| 28 |
+
) -> List[Dict[str, Any]]:
|
| 29 |
+
"""Retrieve relevant chunks for a query, scoped to the user's documents.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
List of dicts with keys: content, metadata
|
| 33 |
+
metadata includes: document_id, user_id, filename, chunk_index, page_label (if PDF)
|
| 34 |
+
"""
|
| 35 |
+
try:
|
| 36 |
+
redis = await get_redis()
|
| 37 |
+
query_hash = hashlib.md5(query.encode()).hexdigest()
|
| 38 |
+
cache_key = f"retrieval:{user_id}:{query_hash}:{k}"
|
| 39 |
+
|
| 40 |
+
cached = await redis.get(cache_key)
|
| 41 |
+
if cached:
|
| 42 |
+
logger.info("Returning cached retrieval results")
|
| 43 |
+
return json.loads(cached)
|
| 44 |
+
|
| 45 |
+
logger.info(f"Retrieving for user {user_id}, query: {query[:50]}...")
|
| 46 |
+
|
| 47 |
+
docs = await self.vector_store.asimilarity_search(
|
| 48 |
+
query=query,
|
| 49 |
+
k=k,
|
| 50 |
+
filter={"user_id": user_id}
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
results = [
|
| 54 |
+
{
|
| 55 |
+
"content": doc.page_content,
|
| 56 |
+
"metadata": doc.metadata,
|
| 57 |
+
}
|
| 58 |
+
for doc in docs
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
+
logger.info(f"Retrieved {len(results)} chunks")
|
| 62 |
+
await redis.setex(cache_key, _RETRIEVAL_CACHE_TTL, json.dumps(results))
|
| 63 |
+
return results
|
| 64 |
+
|
| 65 |
+
except Exception as e:
|
| 66 |
+
logger.error("Retrieval failed", error=str(e))
|
| 67 |
+
return []
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
retriever = RetrieverService()
|
src/rag/retrievers/document.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Document retriever — handles PDF, DOCX, TXT chunks (source_type="document", non-tabular).
|
| 2 |
+
|
| 3 |
+
TEAMMATE: implement retrieve() below.
|
| 4 |
+
Strategy: MMR (amax_marginal_relevance_search) + score threshold to avoid returning
|
| 5 |
+
near-identical chunks from the same PDF page.
|
| 6 |
+
Filter: source_type="document" AND data->>'file_type' NOT IN ('csv', 'xlsx')
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from src.db.postgres.vector_store import get_vector_store
|
| 10 |
+
from src.middlewares.logging import get_logger
|
| 11 |
+
from src.rag.base import BaseRetriever, RetrievalResult
|
| 12 |
+
|
| 13 |
+
logger = get_logger("document_retriever")
|
| 14 |
+
|
| 15 |
+
_SCORE_THRESHOLD = 0.45 # discard chunks with cosine distance above this
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DocumentRetriever(BaseRetriever):
|
| 19 |
+
def __init__(self):
|
| 20 |
+
self.vector_store = get_vector_store()
|
| 21 |
+
|
| 22 |
+
async def retrieve(
|
| 23 |
+
self, query: str, user_id: str, k: int = 5
|
| 24 |
+
) -> list[RetrievalResult]:
|
| 25 |
+
# TODO (teammate): implement MMR retrieval for prose documents
|
| 26 |
+
# Filter: {"user_id": user_id, "source_type": "document"}
|
| 27 |
+
# then post-filter to exclude file_type in ("csv", "xlsx")
|
| 28 |
+
logger.info("document retriever not yet implemented — returning empty")
|
| 29 |
+
return []
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
document_retriever = DocumentRetriever()
|
src/rag/retrievers/schema.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Schema retriever — handles DB schemas (source_type="database") and tabular file
|
| 2 |
+
columns stored as source_type="document" with file_type in ("csv","xlsx").
|
| 3 |
+
|
| 4 |
+
Strategy: similarity search with score threshold on two metadata shapes,
|
| 5 |
+
run in parallel, merged and re-ranked by score.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
|
| 10 |
+
from src.db.postgres.vector_store import get_vector_store
|
| 11 |
+
from src.middlewares.logging import get_logger
|
| 12 |
+
from src.rag.base import BaseRetriever, RetrievalResult
|
| 13 |
+
|
| 14 |
+
logger = get_logger("schema_retriever")
|
| 15 |
+
|
| 16 |
+
_SCORE_THRESHOLD = 0.45 # cosine distance — discard above this value
|
| 17 |
+
_TABULAR_FILE_TYPES = ("csv", "xlsx")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SchemaRetriever(BaseRetriever):
|
| 21 |
+
def __init__(self):
|
| 22 |
+
self.vector_store = get_vector_store()
|
| 23 |
+
|
| 24 |
+
async def _search_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
|
| 25 |
+
"""Retrieve DB schema chunks (source_type="database")."""
|
| 26 |
+
docs_with_scores = await self.vector_store.asimilarity_search_with_score(
|
| 27 |
+
query=query,
|
| 28 |
+
k=k,
|
| 29 |
+
filter={"user_id": user_id, "source_type": "database"},
|
| 30 |
+
)
|
| 31 |
+
results = []
|
| 32 |
+
for doc, distance in docs_with_scores:
|
| 33 |
+
if distance <= _SCORE_THRESHOLD:
|
| 34 |
+
results.append(
|
| 35 |
+
RetrievalResult(
|
| 36 |
+
content=doc.page_content,
|
| 37 |
+
metadata=doc.metadata,
|
| 38 |
+
score=1.0 - distance,
|
| 39 |
+
source_type="database",
|
| 40 |
+
)
|
| 41 |
+
)
|
| 42 |
+
return results
|
| 43 |
+
|
| 44 |
+
async def _search_tabular(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
|
| 45 |
+
"""Retrieve CSV/XLSX column chunks (source_type="document", file_type=csv|xlsx)."""
|
| 46 |
+
results = []
|
| 47 |
+
for file_type in _TABULAR_FILE_TYPES:
|
| 48 |
+
docs_with_scores = await self.vector_store.asimilarity_search_with_score(
|
| 49 |
+
query=query,
|
| 50 |
+
k=k,
|
| 51 |
+
filter={
|
| 52 |
+
"user_id": user_id,
|
| 53 |
+
"source_type": "document",
|
| 54 |
+
"data": {"file_type": file_type},
|
| 55 |
+
},
|
| 56 |
+
)
|
| 57 |
+
for doc, distance in docs_with_scores:
|
| 58 |
+
if distance <= _SCORE_THRESHOLD:
|
| 59 |
+
results.append(
|
| 60 |
+
RetrievalResult(
|
| 61 |
+
content=doc.page_content,
|
| 62 |
+
metadata=doc.metadata,
|
| 63 |
+
score=1.0 - distance,
|
| 64 |
+
source_type="document",
|
| 65 |
+
)
|
| 66 |
+
)
|
| 67 |
+
return results
|
| 68 |
+
|
| 69 |
+
async def retrieve(
|
| 70 |
+
self, query: str, user_id: str, k: int = 5
|
| 71 |
+
) -> list[RetrievalResult]:
|
| 72 |
+
db_results, tabular_results = await asyncio.gather(
|
| 73 |
+
self._search_db(query, user_id, k),
|
| 74 |
+
self._search_tabular(query, user_id, k),
|
| 75 |
+
)
|
| 76 |
+
combined = db_results + tabular_results
|
| 77 |
+
combined.sort(key=lambda r: r.score, reverse=True)
|
| 78 |
+
logger.info(
|
| 79 |
+
"schema retrieval",
|
| 80 |
+
db_chunks=len(db_results),
|
| 81 |
+
tabular_chunks=len(tabular_results),
|
| 82 |
+
)
|
| 83 |
+
return combined[:k]
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
schema_retriever = SchemaRetriever()
|
src/rag/router.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Routes retrieval requests to the appropriate retriever based on source_hint."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import hashlib
|
| 5 |
+
import json
|
| 6 |
+
from typing import Literal
|
| 7 |
+
|
| 8 |
+
from src.db.redis.connection import get_redis
|
| 9 |
+
from src.middlewares.logging import get_logger
|
| 10 |
+
from src.rag.base import BaseRetriever, RetrievalResult
|
| 11 |
+
|
| 12 |
+
logger = get_logger("retrieval_router")
|
| 13 |
+
|
| 14 |
+
_CACHE_TTL = 3600 # 1 hour
|
| 15 |
+
SourceHint = Literal["document", "schema", "both"]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RetrievalRouter:
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
schema_retriever: BaseRetriever,
|
| 22 |
+
document_retriever: BaseRetriever,
|
| 23 |
+
):
|
| 24 |
+
self._retrievers: dict[str, BaseRetriever] = {
|
| 25 |
+
"schema": schema_retriever,
|
| 26 |
+
"document": document_retriever,
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
def _route(self, source_hint: SourceHint) -> list[BaseRetriever]:
|
| 30 |
+
if source_hint == "schema":
|
| 31 |
+
return [self._retrievers["schema"]]
|
| 32 |
+
if source_hint == "document":
|
| 33 |
+
return [self._retrievers["document"]]
|
| 34 |
+
return list(self._retrievers.values())
|
| 35 |
+
|
| 36 |
+
async def retrieve(
|
| 37 |
+
self,
|
| 38 |
+
query: str,
|
| 39 |
+
user_id: str,
|
| 40 |
+
source_hint: SourceHint = "both",
|
| 41 |
+
k: int = 5,
|
| 42 |
+
) -> list[RetrievalResult]:
|
| 43 |
+
redis = await get_redis()
|
| 44 |
+
query_hash = hashlib.md5(query.encode()).hexdigest()
|
| 45 |
+
cache_key = f"retrieval:{user_id}:{source_hint}:{query_hash}:{k}"
|
| 46 |
+
|
| 47 |
+
cached = await redis.get(cache_key)
|
| 48 |
+
if cached:
|
| 49 |
+
logger.info("returning cached retrieval results", source_hint=source_hint)
|
| 50 |
+
raw = json.loads(cached)
|
| 51 |
+
return [RetrievalResult(**r) for r in raw]
|
| 52 |
+
|
| 53 |
+
retrievers = self._route(source_hint)
|
| 54 |
+
batches = await asyncio.gather(
|
| 55 |
+
*[r.retrieve(query, user_id, k) for r in retrievers],
|
| 56 |
+
return_exceptions=True,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
results: list[RetrievalResult] = []
|
| 60 |
+
for batch in batches:
|
| 61 |
+
if isinstance(batch, Exception):
|
| 62 |
+
logger.error("retriever failed", error=str(batch))
|
| 63 |
+
continue
|
| 64 |
+
results.extend(batch)
|
| 65 |
+
|
| 66 |
+
results.sort(key=lambda r: r.score, reverse=True)
|
| 67 |
+
results = results[:k]
|
| 68 |
+
|
| 69 |
+
logger.info("retrieved chunks", count=len(results), source_hint=source_hint)
|
| 70 |
+
await redis.setex(
|
| 71 |
+
cache_key,
|
| 72 |
+
_CACHE_TTL,
|
| 73 |
+
json.dumps([vars(r) for r in results]),
|
| 74 |
+
)
|
| 75 |
+
return results
|