Rifqi Hafizuddin commited on
Commit
d1e1264
·
1 Parent(s): 2814813

[KM-438][KM-439] framework for knowledge retriever

Browse files
src/agents/orchestration.py CHANGED
@@ -35,6 +35,11 @@ Intent Routing:
35
  - greeting -> needs_search=False, direct_response="Hello! How can I assist you today?"
36
  - goodbye -> needs_search=False, direct_response="Goodbye! Have a great day!"
37
  - other -> needs_search=True, search_query=<standalone rewritten query>
 
 
 
 
 
38
  """),
39
  MessagesPlaceholder(variable_name="history"),
40
  ("user", "{message}")
 
35
  - greeting -> needs_search=False, direct_response="Hello! How can I assist you today?"
36
  - goodbye -> needs_search=False, direct_response="Goodbye! Have a great day!"
37
  - other -> needs_search=True, search_query=<standalone rewritten query>
38
+
39
+ Source Routing (set source_hint):
40
+ - Columns, tables, sheets, data types, schema, row counts, statistics -> source_hint=schema
41
+ - Document content, paragraphs, reports, articles, text -> source_hint=document
42
+ - Unclear or spans both -> source_hint=both
43
  """),
44
  MessagesPlaceholder(variable_name="history"),
45
  ("user", "{message}")
src/api/v1/chat.py CHANGED
@@ -192,6 +192,7 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
192
  query=search_query,
193
  user_id=request.user_id,
194
  db=db,
 
195
  )
196
  else:
197
  raw_results = await retrieval_task
 
192
  query=search_query,
193
  user_id=request.user_id,
194
  db=db,
195
+ source_hint=intent_result.get("source_hint", "both"),
196
  )
197
  else:
198
  raw_results = await retrieval_task
src/models/structured_output.py CHANGED
@@ -19,3 +19,7 @@ class IntentClassification(BaseModel):
19
  default="",
20
  description="Direct response if no search needed (for greetings, etc.)"
21
  )
 
 
 
 
 
19
  default="",
20
  description="Direct response if no search needed (for greetings, etc.)"
21
  )
22
+ source_hint: str = Field(
23
+ default="both",
24
+ description="Which sources to search: 'document' (PDF/DOCX/TXT), 'schema' (DB/CSV/XLSX), or 'both'"
25
+ )
src/rag/base.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared contract for all retriever implementations."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+
8
+ @dataclass
9
+ class RetrievalResult:
10
+ content: str
11
+ metadata: dict[str, Any]
12
+ score: float
13
+ source_type: str # "document" | "database"
14
+
15
+
16
+ class BaseRetriever(ABC):
17
+ @abstractmethod
18
+ async def retrieve(
19
+ self, query: str, user_id: str, k: int = 5
20
+ ) -> list[RetrievalResult]: ...
src/rag/retriever.py CHANGED
@@ -1,69 +1,43 @@
1
- """Service for retrieving relevant documents from vector store."""
 
 
2
 
3
- import hashlib
4
- import json
5
- from src.db.postgres.vector_store import get_vector_store
6
- from src.db.redis.connection import get_redis
7
  from sqlalchemy.ext.asyncio import AsyncSession
 
8
  from src.middlewares.logging import get_logger
9
- from typing import List, Dict, Any
 
 
10
 
11
  logger = get_logger("retriever")
12
 
13
- _RETRIEVAL_CACHE_TTL = 3600 # 1 hour
14
-
15
 
16
  class RetrieverService:
17
- """Service for retrieving relevant documents."""
 
 
 
 
18
 
19
  def __init__(self):
20
- self.vector_store = get_vector_store()
 
 
 
21
 
22
  async def retrieve(
23
  self,
24
  query: str,
25
  user_id: str,
26
  db: AsyncSession,
27
- k: int = 5
28
- ) -> List[Dict[str, Any]]:
29
- """Retrieve relevant chunks for a query, scoped to the user's documents.
30
-
31
- Returns:
32
- List of dicts with keys: content, metadata
33
- metadata includes: document_id, user_id, filename, chunk_index, page_label (if PDF)
34
- """
35
  try:
36
- redis = await get_redis()
37
- query_hash = hashlib.md5(query.encode()).hexdigest()
38
- cache_key = f"retrieval:{user_id}:{query_hash}:{k}"
39
-
40
- cached = await redis.get(cache_key)
41
- if cached:
42
- logger.info("Returning cached retrieval results")
43
- return json.loads(cached)
44
-
45
- logger.info(f"Retrieving for user {user_id}, query: {query[:50]}...")
46
-
47
- docs = await self.vector_store.asimilarity_search(
48
- query=query,
49
- k=k,
50
- filter={"user_id": user_id}
51
- )
52
-
53
- results = [
54
- {
55
- "content": doc.page_content,
56
- "metadata": doc.metadata,
57
- }
58
- for doc in docs
59
- ]
60
-
61
- logger.info(f"Retrieved {len(results)} chunks")
62
- await redis.setex(cache_key, _RETRIEVAL_CACHE_TTL, json.dumps(results))
63
- return results
64
-
65
  except Exception as e:
66
- logger.error("Retrieval failed", error=str(e))
67
  return []
68
 
69
 
 
1
+ """Public retrieval API thin wrapper around RetrievalRouter."""
2
+
3
+ from typing import Any
4
 
 
 
 
 
5
  from sqlalchemy.ext.asyncio import AsyncSession
6
+
7
  from src.middlewares.logging import get_logger
8
+ from src.rag.retrievers.document import document_retriever
9
+ from src.rag.retrievers.schema import schema_retriever
10
+ from src.rag.router import RetrievalRouter, SourceHint
11
 
12
  logger = get_logger("retriever")
13
 
 
 
14
 
15
  class RetrieverService:
16
+ """Public retrieval service used by chat.py and search tools.
17
+
18
+ Delegates to RetrievalRouter which dispatches based on source_hint.
19
+ Returns List[Dict] to preserve backward compatibility with chat.py.
20
+ """
21
 
22
  def __init__(self):
23
+ self._router = RetrievalRouter(
24
+ schema_retriever=schema_retriever,
25
+ document_retriever=document_retriever,
26
+ )
27
 
28
  async def retrieve(
29
  self,
30
  query: str,
31
  user_id: str,
32
  db: AsyncSession,
33
+ k: int = 5,
34
+ source_hint: SourceHint = "both",
35
+ ) -> list[dict[str, Any]]:
 
 
 
 
 
36
  try:
37
+ results = await self._router.retrieve(query, user_id, source_hint, k)
38
+ return [{"content": r.content, "metadata": r.metadata} for r in results]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  except Exception as e:
40
+ logger.error("retrieval failed", error=str(e))
41
  return []
42
 
43
 
src/rag/retrievers/__init__.py ADDED
File without changes
src/rag/retrievers/baseline.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Service for retrieving relevant documents from vector store."""
2
+
3
+ import hashlib
4
+ import json
5
+ from src.db.postgres.vector_store import get_vector_store
6
+ from src.db.redis.connection import get_redis
7
+ from sqlalchemy.ext.asyncio import AsyncSession
8
+ from src.middlewares.logging import get_logger
9
+ from typing import List, Dict, Any
10
+
11
+ logger = get_logger("retriever")
12
+
13
+ _RETRIEVAL_CACHE_TTL = 3600 # 1 hour
14
+
15
+
16
+ class RetrieverService:
17
+ """Service for retrieving relevant documents."""
18
+
19
+ def __init__(self):
20
+ self.vector_store = get_vector_store()
21
+
22
+ async def retrieve(
23
+ self,
24
+ query: str,
25
+ user_id: str,
26
+ db: AsyncSession,
27
+ k: int = 5
28
+ ) -> List[Dict[str, Any]]:
29
+ """Retrieve relevant chunks for a query, scoped to the user's documents.
30
+
31
+ Returns:
32
+ List of dicts with keys: content, metadata
33
+ metadata includes: document_id, user_id, filename, chunk_index, page_label (if PDF)
34
+ """
35
+ try:
36
+ redis = await get_redis()
37
+ query_hash = hashlib.md5(query.encode()).hexdigest()
38
+ cache_key = f"retrieval:{user_id}:{query_hash}:{k}"
39
+
40
+ cached = await redis.get(cache_key)
41
+ if cached:
42
+ logger.info("Returning cached retrieval results")
43
+ return json.loads(cached)
44
+
45
+ logger.info(f"Retrieving for user {user_id}, query: {query[:50]}...")
46
+
47
+ docs = await self.vector_store.asimilarity_search(
48
+ query=query,
49
+ k=k,
50
+ filter={"user_id": user_id}
51
+ )
52
+
53
+ results = [
54
+ {
55
+ "content": doc.page_content,
56
+ "metadata": doc.metadata,
57
+ }
58
+ for doc in docs
59
+ ]
60
+
61
+ logger.info(f"Retrieved {len(results)} chunks")
62
+ await redis.setex(cache_key, _RETRIEVAL_CACHE_TTL, json.dumps(results))
63
+ return results
64
+
65
+ except Exception as e:
66
+ logger.error("Retrieval failed", error=str(e))
67
+ return []
68
+
69
+
70
+ retriever = RetrieverService()
src/rag/retrievers/document.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Document retriever — handles PDF, DOCX, TXT chunks (source_type="document", non-tabular).
2
+
3
+ TEAMMATE: implement retrieve() below.
4
+ Strategy: MMR (amax_marginal_relevance_search) + score threshold to avoid returning
5
+ near-identical chunks from the same PDF page.
6
+ Filter: source_type="document" AND data->>'file_type' NOT IN ('csv', 'xlsx')
7
+ """
8
+
9
+ from src.db.postgres.vector_store import get_vector_store
10
+ from src.middlewares.logging import get_logger
11
+ from src.rag.base import BaseRetriever, RetrievalResult
12
+
13
+ logger = get_logger("document_retriever")
14
+
15
+ _SCORE_THRESHOLD = 0.45 # discard chunks with cosine distance above this
16
+
17
+
18
+ class DocumentRetriever(BaseRetriever):
19
+ def __init__(self):
20
+ self.vector_store = get_vector_store()
21
+
22
+ async def retrieve(
23
+ self, query: str, user_id: str, k: int = 5
24
+ ) -> list[RetrievalResult]:
25
+ # TODO (teammate): implement MMR retrieval for prose documents
26
+ # Filter: {"user_id": user_id, "source_type": "document"}
27
+ # then post-filter to exclude file_type in ("csv", "xlsx")
28
+ logger.info("document retriever not yet implemented — returning empty")
29
+ return []
30
+
31
+
32
+ document_retriever = DocumentRetriever()
src/rag/retrievers/schema.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Schema retriever — handles DB schemas (source_type="database") and tabular file
2
+ columns stored as source_type="document" with file_type in ("csv","xlsx").
3
+
4
+ Strategy: similarity search with score threshold on two metadata shapes,
5
+ run in parallel, merged and re-ranked by score.
6
+ """
7
+
8
+ import asyncio
9
+
10
+ from src.db.postgres.vector_store import get_vector_store
11
+ from src.middlewares.logging import get_logger
12
+ from src.rag.base import BaseRetriever, RetrievalResult
13
+
14
+ logger = get_logger("schema_retriever")
15
+
16
+ _SCORE_THRESHOLD = 0.45 # cosine distance — discard above this value
17
+ _TABULAR_FILE_TYPES = ("csv", "xlsx")
18
+
19
+
20
+ class SchemaRetriever(BaseRetriever):
21
+ def __init__(self):
22
+ self.vector_store = get_vector_store()
23
+
24
+ async def _search_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
25
+ """Retrieve DB schema chunks (source_type="database")."""
26
+ docs_with_scores = await self.vector_store.asimilarity_search_with_score(
27
+ query=query,
28
+ k=k,
29
+ filter={"user_id": user_id, "source_type": "database"},
30
+ )
31
+ results = []
32
+ for doc, distance in docs_with_scores:
33
+ if distance <= _SCORE_THRESHOLD:
34
+ results.append(
35
+ RetrievalResult(
36
+ content=doc.page_content,
37
+ metadata=doc.metadata,
38
+ score=1.0 - distance,
39
+ source_type="database",
40
+ )
41
+ )
42
+ return results
43
+
44
+ async def _search_tabular(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
45
+ """Retrieve CSV/XLSX column chunks (source_type="document", file_type=csv|xlsx)."""
46
+ results = []
47
+ for file_type in _TABULAR_FILE_TYPES:
48
+ docs_with_scores = await self.vector_store.asimilarity_search_with_score(
49
+ query=query,
50
+ k=k,
51
+ filter={
52
+ "user_id": user_id,
53
+ "source_type": "document",
54
+ "data": {"file_type": file_type},
55
+ },
56
+ )
57
+ for doc, distance in docs_with_scores:
58
+ if distance <= _SCORE_THRESHOLD:
59
+ results.append(
60
+ RetrievalResult(
61
+ content=doc.page_content,
62
+ metadata=doc.metadata,
63
+ score=1.0 - distance,
64
+ source_type="document",
65
+ )
66
+ )
67
+ return results
68
+
69
+ async def retrieve(
70
+ self, query: str, user_id: str, k: int = 5
71
+ ) -> list[RetrievalResult]:
72
+ db_results, tabular_results = await asyncio.gather(
73
+ self._search_db(query, user_id, k),
74
+ self._search_tabular(query, user_id, k),
75
+ )
76
+ combined = db_results + tabular_results
77
+ combined.sort(key=lambda r: r.score, reverse=True)
78
+ logger.info(
79
+ "schema retrieval",
80
+ db_chunks=len(db_results),
81
+ tabular_chunks=len(tabular_results),
82
+ )
83
+ return combined[:k]
84
+
85
+
86
+ schema_retriever = SchemaRetriever()
src/rag/router.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Routes retrieval requests to the appropriate retriever based on source_hint."""
2
+
3
+ import asyncio
4
+ import hashlib
5
+ import json
6
+ from typing import Literal
7
+
8
+ from src.db.redis.connection import get_redis
9
+ from src.middlewares.logging import get_logger
10
+ from src.rag.base import BaseRetriever, RetrievalResult
11
+
12
+ logger = get_logger("retrieval_router")
13
+
14
+ _CACHE_TTL = 3600 # 1 hour
15
+ SourceHint = Literal["document", "schema", "both"]
16
+
17
+
18
+ class RetrievalRouter:
19
+ def __init__(
20
+ self,
21
+ schema_retriever: BaseRetriever,
22
+ document_retriever: BaseRetriever,
23
+ ):
24
+ self._retrievers: dict[str, BaseRetriever] = {
25
+ "schema": schema_retriever,
26
+ "document": document_retriever,
27
+ }
28
+
29
+ def _route(self, source_hint: SourceHint) -> list[BaseRetriever]:
30
+ if source_hint == "schema":
31
+ return [self._retrievers["schema"]]
32
+ if source_hint == "document":
33
+ return [self._retrievers["document"]]
34
+ return list(self._retrievers.values())
35
+
36
+ async def retrieve(
37
+ self,
38
+ query: str,
39
+ user_id: str,
40
+ source_hint: SourceHint = "both",
41
+ k: int = 5,
42
+ ) -> list[RetrievalResult]:
43
+ redis = await get_redis()
44
+ query_hash = hashlib.md5(query.encode()).hexdigest()
45
+ cache_key = f"retrieval:{user_id}:{source_hint}:{query_hash}:{k}"
46
+
47
+ cached = await redis.get(cache_key)
48
+ if cached:
49
+ logger.info("returning cached retrieval results", source_hint=source_hint)
50
+ raw = json.loads(cached)
51
+ return [RetrievalResult(**r) for r in raw]
52
+
53
+ retrievers = self._route(source_hint)
54
+ batches = await asyncio.gather(
55
+ *[r.retrieve(query, user_id, k) for r in retrievers],
56
+ return_exceptions=True,
57
+ )
58
+
59
+ results: list[RetrievalResult] = []
60
+ for batch in batches:
61
+ if isinstance(batch, Exception):
62
+ logger.error("retriever failed", error=str(batch))
63
+ continue
64
+ results.extend(batch)
65
+
66
+ results.sort(key=lambda r: r.score, reverse=True)
67
+ results = results[:k]
68
+
69
+ logger.info("retrieved chunks", count=len(results), source_hint=source_hint)
70
+ await redis.setex(
71
+ cache_key,
72
+ _CACHE_TTL,
73
+ json.dumps([vars(r) for r in results]),
74
+ )
75
+ return results