sofhiaazzhr Claude Sonnet 4.6 commited on
Commit
df220ea
·
1 Parent(s): f31f673

[KM-553] migrate retrieval layer and remove obsolete rag/tools modules

Browse files

- Replace src/rag/ with src/retrieval/: implement DocumentRetriever (MMR/cosine/euclidean/manhattan), simplified RetrievalRouter (unstructured-only, no schema leg, Redis cache preserved), and shared RetrievalResult/BaseRetriever base
- Remove src/tools/ (orphaned LangChain @tool wrapper, never called by production code)
- Update RetrievalResult imports in chat.py, query/base.py, query/executors/db_executor.py, query/executors/tabular.py, query/query_executor.py from src.rag.base to src.retrieval.base
- Wire chat.py to new retrieval_router (aliased as retriever, no call-site changes)
- Delete dead stubs: src/query/service.py, src/models/user_info.py, src/pipeline/document_pipeline.py (flat, shadowed by subfolder)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

src/api/v1/chat.py CHANGED
@@ -8,8 +8,8 @@ from src.db.postgres.connection import get_db
8
  from src.db.postgres.models import ChatMessage, MessageSource
9
  from src.agents.orchestration import orchestrator
10
  from src.agents.chatbot import chatbot
11
- from src.rag.retriever import retriever
12
- from src.rag.base import RetrievalResult
13
  from src.query.query_executor import query_executor
14
  from src.query.base import QueryResult
15
  from src.db.redis.connection import get_redis
 
8
  from src.db.postgres.models import ChatMessage, MessageSource
9
  from src.agents.orchestration import orchestrator
10
  from src.agents.chatbot import chatbot
11
+ from src.retrieval.router import retrieval_router as retriever
12
+ from src.retrieval.base import RetrievalResult
13
  from src.query.query_executor import query_executor
14
  from src.query.base import QueryResult
15
  from src.db.redis.connection import get_redis
src/models/user_info.py DELETED
@@ -1,15 +0,0 @@
1
- """User info models for existing users.py."""
2
-
3
- from pydantic import BaseModel
4
-
5
-
6
- class UserCreate(BaseModel):
7
- """User creation model."""
8
- fullname: str
9
- email: str
10
- password: str
11
- company: str | None = None
12
- company_size: str | None = None
13
- function: str | None = None
14
- site: str | None = None
15
- role: str | None = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/pipeline/document_pipeline.py DELETED
@@ -1,11 +0,0 @@
1
- """DocumentPipeline — extract text, chunk, embed, ingest to PGVector.
2
-
3
- For unstructured sources (PDF / DOCX / TXT). Receives the working
4
- implementation from the previous pipeline/document_pipeline/document_pipeline.py
5
- during the cleanup phase.
6
- """
7
-
8
-
9
- class DocumentPipeline:
10
- async def run(self, document_id: str, user_id: str) -> None:
11
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
src/query/base.py CHANGED
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
5
 
6
  from sqlalchemy.ext.asyncio import AsyncSession
7
 
8
- from src.rag.base import RetrievalResult
9
 
10
 
11
  @dataclass
 
5
 
6
  from sqlalchemy.ext.asyncio import AsyncSession
7
 
8
+ from src.retrieval.base import RetrievalResult
9
 
10
 
11
  @dataclass
src/query/executors/db_executor.py CHANGED
@@ -31,7 +31,7 @@ from src.middlewares.logging import get_logger
31
  from src.models.sql_query import SQLQuery
32
  from src.pipeline.db_pipeline import db_pipeline_service
33
  from src.query.base import BaseExecutor, QueryResult
34
- from src.rag.base import RetrievalResult
35
  from src.utils.db_credential_encryption import decrypt_credentials_dict
36
 
37
  logger = get_logger("db_executor")
 
31
  from src.models.sql_query import SQLQuery
32
  from src.pipeline.db_pipeline import db_pipeline_service
33
  from src.query.base import BaseExecutor, QueryResult
34
+ from src.retrieval.base import RetrievalResult
35
  from src.utils.db_credential_encryption import decrypt_credentials_dict
36
 
37
  logger = get_logger("db_executor")
src/query/executors/tabular.py CHANGED
@@ -22,7 +22,7 @@ from src.config.settings import settings
22
  from src.knowledge.parquet_service import download_parquet
23
  from src.middlewares.logging import get_logger
24
  from src.query.base import BaseExecutor, QueryResult
25
- from src.rag.base import RetrievalResult
26
 
27
  logger = get_logger("tabular_executor")
28
 
 
22
  from src.knowledge.parquet_service import download_parquet
23
  from src.middlewares.logging import get_logger
24
  from src.query.base import BaseExecutor, QueryResult
25
+ from src.retrieval.base import RetrievalResult
26
 
27
  logger = get_logger("tabular_executor")
28
 
src/query/query_executor.py CHANGED
@@ -8,7 +8,7 @@ from src.middlewares.logging import get_logger
8
  from src.query.base import QueryResult
9
  from src.query.executors.db_executor import db_executor
10
  from src.query.executors.tabular import tabular_executor
11
- from src.rag.base import RetrievalResult
12
 
13
  logger = get_logger("query_executor")
14
 
 
8
  from src.query.base import QueryResult
9
  from src.query.executors.db_executor import db_executor
10
  from src.query.executors.tabular import tabular_executor
11
+ from src.retrieval.base import RetrievalResult
12
 
13
  logger = get_logger("query_executor")
14
 
src/query/service.py DELETED
@@ -1,15 +0,0 @@
1
- """QueryService — orchestrates plan → validate → compile → execute.
2
-
3
- Top-level entry point for catalog-driven structured queries. Wired into
4
- the chat endpoint when source_hint == "structured".
5
- """
6
-
7
- from ..catalog.models import Catalog
8
- from .executor.base import QueryResult
9
-
10
-
11
- class QueryService:
12
- """End-to-end runner for a user question against a catalog."""
13
-
14
- async def run(self, user_id: str, question: str, catalog: Catalog) -> QueryResult:
15
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rag/__init__.py DELETED
File without changes
src/rag/retriever.py DELETED
@@ -1,46 +0,0 @@
1
- """Public retrieval API — thin wrapper around RetrievalRouter."""
2
-
3
- from sqlalchemy.ext.asyncio import AsyncSession
4
-
5
- from src.middlewares.logging import get_logger
6
- from src.rag.base import RetrievalResult
7
- from src.rag.retrievers.document import document_retriever
8
- from src.rag.retrievers.schema import schema_retriever
9
- from src.rag.router import RetrievalRouter, SourceHint
10
-
11
- logger = get_logger("retriever")
12
-
13
-
14
- class RetrieverService:
15
- """Public retrieval service used by chat.py and search tools.
16
-
17
- Delegates to RetrievalRouter which dispatches based on source_hint.
18
- Returns RetrievalResult objects directly so downstream consumers
19
- (db_executor, tabular_executor) can be fed without lossy dict
20
- conversion. The `db` parameter is accepted for call-site compatibility
21
- but currently unused — retrieval reads PGVector via _pgvector_engine
22
- inside each retriever.
23
- """
24
-
25
- def __init__(self):
26
- self._router = RetrievalRouter(
27
- schema_retriever=schema_retriever,
28
- document_retriever=document_retriever,
29
- )
30
-
31
- async def retrieve(
32
- self,
33
- query: str,
34
- user_id: str,
35
- db: AsyncSession,
36
- k: int = 5,
37
- source_hint: SourceHint = "both",
38
- ) -> list[RetrievalResult]:
39
- try:
40
- return await self._router.retrieve(query, user_id, source_hint, k)
41
- except Exception as e:
42
- logger.error("retrieval failed", error=str(e))
43
- return []
44
-
45
-
46
- retriever = RetrieverService()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rag/retrievers/__init__.py DELETED
File without changes
src/rag/retrievers/baseline.py DELETED
@@ -1,76 +0,0 @@
1
- """Service for retrieving relevant documents from vector store."""
2
-
3
- import hashlib
4
- import json
5
- from src.db.postgres.vector_store import get_vector_store
6
- from src.db.redis.connection import get_redis
7
- from sqlalchemy.ext.asyncio import AsyncSession
8
- from src.middlewares.logging import get_logger
9
- from typing import List, Dict, Any
10
-
11
- logger = get_logger("retriever")
12
-
13
- _RETRIEVAL_CACHE_TTL = 3600 # 1 hour
14
-
15
-
16
- class BaselineRetrieverService:
17
- """Baseline (pre-Phase-1) retriever — preserved for benchmark comparison.
18
-
19
- Renamed from RetrieverService so it doesn't shadow the production wrapper
20
- at src/rag/retriever.py. Production code imports from src.rag.retriever;
21
- benchmark scripts that want this baseline must import explicitly from
22
- src.rag.retrievers.baseline.
23
- """
24
-
25
- def __init__(self):
26
- self.vector_store = get_vector_store()
27
-
28
- async def retrieve(
29
- self,
30
- query: str,
31
- user_id: str,
32
- db: AsyncSession,
33
- k: int = 5
34
- ) -> List[Dict[str, Any]]:
35
- """Retrieve relevant chunks for a query, scoped to the user's documents.
36
-
37
- Returns:
38
- List of dicts with keys: content, metadata
39
- metadata includes: document_id, user_id, filename, chunk_index, page_label (if PDF)
40
- """
41
- try:
42
- redis = await get_redis()
43
- query_hash = hashlib.md5(query.encode()).hexdigest()
44
- cache_key = f"retrieval:{user_id}:{query_hash}:{k}"
45
-
46
- cached = await redis.get(cache_key)
47
- if cached:
48
- logger.info("Returning cached retrieval results")
49
- return json.loads(cached)
50
-
51
- logger.info(f"Retrieving for user {user_id}, query: {query[:50]}...")
52
-
53
- docs = await self.vector_store.asimilarity_search(
54
- query=query,
55
- k=k,
56
- filter={"user_id": user_id}
57
- )
58
-
59
- results = [
60
- {
61
- "content": doc.page_content,
62
- "metadata": doc.metadata,
63
- }
64
- for doc in docs
65
- ]
66
-
67
- logger.info(f"Retrieved {len(results)} chunks")
68
- await redis.setex(cache_key, _RETRIEVAL_CACHE_TTL, json.dumps(results))
69
- return results
70
-
71
- except Exception as e:
72
- logger.error("Retrieval failed", error=str(e))
73
- return []
74
-
75
-
76
- baseline_retriever = BaselineRetrieverService()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rag/retrievers/document.py DELETED
@@ -1,158 +0,0 @@
1
- """Document retriever — handles PDF, DOCX, TXT chunks (source_type="document", non-tabular)."""
2
-
3
- import math
4
-
5
- from langchain_postgres import PGVector
6
- from langchain_postgres.vectorstores import DistanceStrategy
7
- from langchain_openai import AzureOpenAIEmbeddings
8
- from sqlalchemy import text
9
-
10
- from src.config.settings import settings
11
- from src.db.postgres.connection import _pgvector_engine
12
- from src.db.postgres.vector_store import get_vector_store
13
- from src.middlewares.logging import get_logger
14
- from src.rag.base import BaseRetriever, RetrievalResult
15
-
16
- logger = get_logger("document_retriever")
17
-
18
- # Change this one line to switch retrieval method
19
- # Options: "mmr" | "cosine" | "euclidean" | "inner_product" | "manhattan"
20
- _RETRIEVAL_METHOD = "mmr"
21
-
22
- _TABULAR_TYPES = {"csv", "xlsx"}
23
- _FETCH_K = 20
24
- _LAMBDA_MULT = 0.5
25
- _COLLECTION_NAME = "document_embeddings"
26
-
27
- _embeddings = AzureOpenAIEmbeddings(
28
- azure_deployment=settings.azureai_deployment_name_embedding,
29
- openai_api_version=settings.azureai_api_version_embedding,
30
- azure_endpoint=settings.azureai_endpoint_url_embedding,
31
- api_key=settings.azureai_api_key_embedding,
32
- )
33
-
34
- _euclidean_store = PGVector(
35
- embeddings=_embeddings,
36
- connection=_pgvector_engine,
37
- collection_name=_COLLECTION_NAME,
38
- distance_strategy=DistanceStrategy.EUCLIDEAN,
39
- use_jsonb=True,
40
- async_mode=True,
41
- create_extension=False,
42
- )
43
-
44
- _ip_store = PGVector(
45
- embeddings=_embeddings,
46
- connection=_pgvector_engine,
47
- collection_name=_COLLECTION_NAME,
48
- distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
49
- use_jsonb=True,
50
- async_mode=True,
51
- create_extension=False,
52
- )
53
-
54
- _MANHATTAN_SQL = text("""
55
- SELECT
56
- lpe.document,
57
- lpe.cmetadata,
58
- lpe.embedding <+> CAST(:embedding AS vector) AS distance
59
- FROM langchain_pg_embedding lpe
60
- JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
61
- WHERE lpc.name = :collection
62
- AND lpe.cmetadata->>'user_id' = :user_id
63
- AND lpe.cmetadata->>'source_type' = 'document'
64
- ORDER BY distance ASC
65
- LIMIT :k
66
- """)
67
-
68
-
69
- class DocumentRetriever(BaseRetriever):
70
- def __init__(self) -> None:
71
- self.vector_store = get_vector_store()
72
-
73
- async def retrieve(
74
- self, query: str, user_id: str, k: int = 5
75
- ) -> list[RetrievalResult]:
76
- filter_ = {"user_id": user_id, "source_type": "document"}
77
- fetch_k = k + len(_TABULAR_TYPES)
78
-
79
- if _RETRIEVAL_METHOD == "manhattan":
80
- return await self._retrieve_manhattan(query, user_id, k, fetch_k)
81
-
82
- if _RETRIEVAL_METHOD == "mmr":
83
- docs = await self.vector_store.amax_marginal_relevance_search(
84
- query=query,
85
- k=fetch_k,
86
- fetch_k=_FETCH_K,
87
- lambda_mult=_LAMBDA_MULT,
88
- filter=filter_,
89
- )
90
- cosine = await self.vector_store.asimilarity_search_with_score(
91
- query=query, k=fetch_k, filter=filter_,
92
- )
93
- score_map = {doc.page_content: score for doc, score in cosine}
94
- docs_with_scores = [(doc, score_map.get(doc.page_content, 0.0)) for doc in docs]
95
- elif _RETRIEVAL_METHOD == "euclidean":
96
- docs_with_scores = await _euclidean_store.asimilarity_search_with_score(
97
- query=query, k=fetch_k, filter=filter_,
98
- )
99
- elif _RETRIEVAL_METHOD == "inner_product":
100
- docs_with_scores = await _ip_store.asimilarity_search_with_score(
101
- query=query, k=fetch_k, filter=filter_,
102
- )
103
- else: # cosine
104
- docs_with_scores = await self.vector_store.asimilarity_search_with_score(
105
- query=query, k=fetch_k, filter=filter_,
106
- )
107
-
108
- results = []
109
- for doc, score in docs_with_scores:
110
- file_type = doc.metadata.get("data", {}).get("file_type", "")
111
- if file_type not in _TABULAR_TYPES:
112
- results.append(RetrievalResult(
113
- content=doc.page_content,
114
- metadata=doc.metadata,
115
- score=score,
116
- source_type="document",
117
- ))
118
- if len(results) == k:
119
- break
120
-
121
- logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
122
- return results
123
-
124
- async def _retrieve_manhattan(
125
- self, query: str, user_id: str, k: int, fetch_k: int
126
- ) -> list[RetrievalResult]:
127
- query_vector = await _embeddings.aembed_query(query)
128
- if not all(math.isfinite(v) for v in query_vector):
129
- raise ValueError("Embedding vector contains NaN or Infinity values.")
130
- vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"
131
-
132
- async with _pgvector_engine.connect() as conn:
133
- result = await conn.execute(_MANHATTAN_SQL, {
134
- "embedding": vector_str,
135
- "collection": _COLLECTION_NAME,
136
- "user_id": user_id,
137
- "k": fetch_k,
138
- })
139
- rows = result.fetchall()
140
-
141
- results = []
142
- for row in rows:
143
- file_type = row.cmetadata.get("data", {}).get("file_type", "")
144
- if file_type not in _TABULAR_TYPES:
145
- results.append(RetrievalResult(
146
- content=row.document,
147
- metadata=row.cmetadata,
148
- score=float(row.distance),
149
- source_type="document",
150
- ))
151
- if len(results) == k:
152
- break
153
-
154
- logger.info("retrieved chunks", method="manhattan", count=len(results))
155
- return results
156
-
157
-
158
- document_retriever = DocumentRetriever()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rag/retrievers/schema.py DELETED
@@ -1,411 +0,0 @@
1
- """Schema retriever — handles DB schemas (source_type="database") and tabular file
2
- columns stored as source_type="document" with file_type in ("csv","xlsx").
3
-
4
- Strategy: hybrid_bm25 — RRF merge of dense cosine search (DB columns + DB tables
5
- + tabular columns + tabular sheets) and PostgreSQL full-text search (DB columns only).
6
- Embeds the query once, fans out five legs in parallel.
7
-
8
- The DB-tables leg surfaces table-level summary chunks (chunk_level='table') as
9
- a recall signal for multi-table questions: when a relevant table's columns
10
- don't individually win on similarity, the table chunk can still pull the table
11
- into the hit set, where db_executor's downstream full-schema fetch picks up
12
- the per-column detail.
13
-
14
- FTS requires a GIN index on langchain_pg_embedding.document (created by init_db.py).
15
- """
16
-
17
- import asyncio
18
-
19
- from sqlalchemy import text
20
-
21
- from src.db.postgres.connection import _pgvector_engine
22
- from src.db.postgres.vector_store import get_vector_store
23
- from src.middlewares.logging import get_logger
24
- from src.rag.base import BaseRetriever, RetrievalResult
25
-
26
- logger = get_logger("schema_retriever")
27
-
28
- _TABULAR_FILE_TYPES = ("csv", "xlsx")
29
- _TABLE_CHUNK_K_MULTIPLIER = 2 # how many table chunks to pull before RRF
30
-
31
-
32
- class SchemaRetriever(BaseRetriever):
33
- def __init__(self):
34
- self.vector_store = get_vector_store()
35
-
36
- # ------------------------------------------------------------------
37
- # Internal helpers
38
- # ------------------------------------------------------------------
39
-
40
- async def _embed_query(self, query: str) -> list[float]:
41
- return await asyncio.to_thread(self.vector_store.embeddings.embed_query, query)
42
-
43
- async def _search_db(
44
- self, embedding: list[float], user_id: str, k: int
45
- ) -> list[RetrievalResult]:
46
- """Cosine vector search over database chunks."""
47
- emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
48
-
49
- sql = text(f"""
50
- SELECT lpe.document, lpe.cmetadata,
51
- 1.0 - (lpe.embedding <=> '{emb_str}'::vector) AS score
52
- FROM langchain_pg_embedding lpe
53
- JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
54
- WHERE lpc.name = 'document_embeddings'
55
- AND lpe.cmetadata->>'user_id' = :user_id
56
- AND lpe.cmetadata->>'source_type' = 'database'
57
- AND lpe.cmetadata->>'chunk_level' = 'column'
58
- ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC
59
- LIMIT :k
60
- """)
61
-
62
- async with _pgvector_engine.connect() as conn:
63
- result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
64
- rows = result.fetchall()
65
-
66
- return [
67
- RetrievalResult(
68
- content=row.document,
69
- metadata=row.cmetadata,
70
- score=float(row.score),
71
- source_type="database",
72
- )
73
- for row in rows
74
- ]
75
-
76
- async def _search_db_tables(
77
- self, embedding: list[float], user_id: str, k: int
78
- ) -> list[RetrievalResult]:
79
- """Cosine vector search over database TABLE-level chunks.
80
-
81
- Recall channel for multi-table questions. The chunk's content is
82
- discarded downstream — db_executor only consumes its `data.table_name`
83
- to seed full-schema fetch.
84
- """
85
- emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
86
-
87
- sql = text(f"""
88
- SELECT lpe.document, lpe.cmetadata,
89
- 1.0 - (lpe.embedding <=> '{emb_str}'::vector) AS score
90
- FROM langchain_pg_embedding lpe
91
- JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
92
- WHERE lpc.name = 'document_embeddings'
93
- AND lpe.cmetadata->>'user_id' = :user_id
94
- AND lpe.cmetadata->>'source_type' = 'database'
95
- AND lpe.cmetadata->>'chunk_level' = 'table'
96
- ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC
97
- LIMIT :k
98
- """)
99
-
100
- async with _pgvector_engine.connect() as conn:
101
- result = await conn.execute(
102
- sql, {"user_id": user_id, "k": k * _TABLE_CHUNK_K_MULTIPLIER}
103
- )
104
- rows = result.fetchall()
105
-
106
- return [
107
- RetrievalResult(
108
- content=row.document,
109
- metadata=row.cmetadata,
110
- score=float(row.score),
111
- source_type="database",
112
- )
113
- for row in rows
114
- ]
115
-
116
- async def _search_tabular(
117
- self, embedding: list[float], user_id: str, k: int
118
- ) -> list[RetrievalResult]:
119
- """Cosine vector search over tabular document chunks (csv/xlsx)."""
120
- emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
121
-
122
- sql = text(f"""
123
- SELECT lpe.document, lpe.cmetadata,
124
- 1.0 - (lpe.embedding <=> '{emb_str}'::vector) AS score
125
- FROM langchain_pg_embedding lpe
126
- JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
127
- WHERE lpc.name = 'document_embeddings'
128
- AND lpe.cmetadata->>'user_id' = :user_id
129
- AND lpe.cmetadata->>'source_type' = 'document'
130
- AND lpe.cmetadata->>'chunk_level' = 'column'
131
- AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
132
- OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
133
- ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC
134
- LIMIT :k
135
- """)
136
-
137
- async with _pgvector_engine.connect() as conn:
138
- result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
139
- rows = result.fetchall()
140
-
141
- return [
142
- RetrievalResult(
143
- content=row.document,
144
- metadata=row.cmetadata,
145
- score=float(row.score),
146
- source_type="document",
147
- )
148
- for row in rows
149
- ]
150
-
151
- async def _search_tabular_sheets(
152
- self, embedding: list[float], user_id: str, k: int
153
- ) -> list[RetrievalResult]:
154
- """Leg 5: sheet-level summary chunks from CSV/XLSX files."""
155
- emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
156
-
157
- sql = text(f"""
158
- SELECT lpe.document, lpe.cmetadata,
159
- 1.0 - (lpe.embedding <=> '{emb_str}'::vector) AS score
160
- FROM langchain_pg_embedding lpe
161
- JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
162
- WHERE lpc.name = 'document_embeddings'
163
- AND lpe.cmetadata->>'user_id' = :user_id
164
- AND lpe.cmetadata->>'source_type' = 'document'
165
- AND lpe.cmetadata->>'chunk_level' = 'sheet'
166
- AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
167
- OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
168
- ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC
169
- LIMIT :k
170
- """)
171
-
172
- async with _pgvector_engine.connect() as conn:
173
- result = await conn.execute(sql, {"user_id": user_id, "k": k})
174
- rows = result.fetchall()
175
-
176
- return [
177
- RetrievalResult(
178
- content=row.document,
179
- metadata=row.cmetadata,
180
- score=float(row.score),
181
- source_type="document",
182
- )
183
- for row in rows
184
- ]
185
-
186
- async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
187
- """Full-text search over DB schema chunks using PostgreSQL tsvector."""
188
- sql = text("""
189
- SELECT lpe.document, lpe.cmetadata,
190
- ts_rank(to_tsvector('english', lpe.document),
191
- plainto_tsquery('english', :query)) AS rank
192
- FROM langchain_pg_embedding lpe
193
- JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
194
- WHERE lpc.name = 'document_embeddings'
195
- AND lpe.cmetadata->>'user_id' = :user_id
196
- AND lpe.cmetadata->>'source_type' = 'database'
197
- AND lpe.cmetadata->>'chunk_level' = 'column'
198
- AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query)
199
- ORDER BY rank DESC
200
- LIMIT :k
201
- """)
202
-
203
- async with _pgvector_engine.connect() as conn:
204
- result = await conn.execute(sql, {"query": query, "user_id": user_id, "k": k})
205
- rows = result.fetchall()
206
-
207
- return [
208
- RetrievalResult(
209
- content=row.document,
210
- metadata=row.cmetadata,
211
- score=float(row.rank),
212
- source_type="database",
213
- )
214
- for row in rows
215
- ]
216
-
217
- def _rank_tabular_sheets(
218
- self,
219
- sheet_results: list[RetrievalResult],
220
- column_results: list[RetrievalResult],
221
- top_k: int,
222
- k_rrf: int = 60,
223
- ) -> list[RetrievalResult]:
224
- """Rank tabular sheets by RRF across two voting legs:
225
- L1 (primary): sheet-chunk cosine score
226
- L2 (vote): best column-chunk position per (doc_id, sheet_name)
227
-
228
- Returns top-k sheet-level RetrievalResults. The full column list of
229
- each sheet is already in the sheet chunk's data.column_names from
230
- ingestion, so downstream tabular_executor can read full sheet context.
231
-
232
- For sheets surfaced by column votes but missing a sheet chunk (rare —
233
- ingestion always creates one), a minimal stub is returned and
234
- tabular_executor falls back to reading columns from the parquet.
235
- """
236
- # L1: sheets indexed by (doc_id, sheet_name) from sheet chunks
237
- sheet_index: dict[tuple, RetrievalResult] = {}
238
- sheet_ranked: list[tuple] = []
239
- for r in sheet_results:
240
- d = r.metadata.get("data", {})
241
- key = (d.get("document_id"), d.get("sheet_name"))
242
- if key[0] and key not in sheet_index:
243
- sheet_index[key] = r
244
- sheet_ranked.append(key)
245
-
246
- # L2: sheets ranked by first-appearance in column-chunk results
247
- col_sheet_ranked: list[tuple] = []
248
- seen: set[tuple] = set()
249
- for r in column_results:
250
- d = r.metadata.get("data", {})
251
- key = (d.get("document_id"), d.get("sheet_name"))
252
- if key[0] and key not in seen:
253
- col_sheet_ranked.append(key)
254
- seen.add(key)
255
-
256
- # RRF over (doc_id, sheet_name) across the two legs
257
- rrf_scores: dict[tuple, float] = {}
258
- for ranked_list in [sheet_ranked, col_sheet_ranked]:
259
- for rank, key in enumerate(ranked_list):
260
- rrf_scores[key] = rrf_scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
261
-
262
- top_sheets = sorted(rrf_scores, key=lambda k: rrf_scores[k], reverse=True)[:top_k]
263
-
264
- results: list[RetrievalResult] = []
265
- for key in top_sheets:
266
- if key in sheet_index:
267
- r = sheet_index[key]
268
- r.score = rrf_scores[key]
269
- results.append(r)
270
- else:
271
- # Surfaced by column votes only — build stub from a representative
272
- # column result so tabular_executor can group correctly.
273
- doc_id, sheet_name = key
274
- rep = next(
275
- (r for r in column_results
276
- if r.metadata.get("data", {}).get("document_id") == doc_id
277
- and r.metadata.get("data", {}).get("sheet_name") == sheet_name),
278
- None,
279
- )
280
- if rep is None:
281
- continue
282
- stub_data = dict(rep.metadata.get("data", {}))
283
- stub_data.pop("column_name", None)
284
- stub_data.pop("column_type", None)
285
- results.append(RetrievalResult(
286
- content=f"Sheet: {stub_data.get('filename', '')}"
287
- + (f" / sheet: {sheet_name}" if sheet_name else ""),
288
- metadata={**rep.metadata, "data": stub_data, "chunk_level": "sheet"},
289
- score=rrf_scores[key],
290
- source_type="document",
291
- ))
292
- return results
293
-
294
- def _rank_db_tables(
295
- self,
296
- tbl_results: list[RetrievalResult],
297
- col_results: list[RetrievalResult],
298
- fts_results: list[RetrievalResult],
299
- top_k: int,
300
- k_rrf: int = 60,
301
- ) -> list[RetrievalResult]:
302
- """Rank DB tables by RRF across three legs:
303
- L1 (primary): table-summary chunk similarity
304
- L2 (vote): best column-chunk position per table
305
- L3 (vote): best FTS position per table
306
-
307
- Returns top-k table-chunk RetrievalResults. For tables surfaced by
308
- L2/L3 but missing a table chunk, a minimal stub is returned so that
309
- db_executor._fetch_full_schema can seed off data.table_name.
310
- """
311
- # L1: tables ranked by table-chunk cosine score
312
- tbl_index: dict[str, RetrievalResult] = {}
313
- tbl_ranked: list[str] = []
314
- for r in tbl_results:
315
- tname = r.metadata.get("data", {}).get("table_name")
316
- if tname and tname not in tbl_index:
317
- tbl_index[tname] = r
318
- tbl_ranked.append(tname)
319
-
320
- # L2: tables ranked by first-appearance in column-chunk list (best col score)
321
- col_table_ranked: list[str] = []
322
- seen: set[str] = set()
323
- for r in col_results:
324
- tname = r.metadata.get("data", {}).get("table_name")
325
- if tname and tname not in seen:
326
- col_table_ranked.append(tname)
327
- seen.add(tname)
328
-
329
- # L3: tables ranked by first-appearance in FTS list
330
- fts_table_ranked: list[str] = []
331
- seen = set()
332
- for r in fts_results:
333
- tname = r.metadata.get("data", {}).get("table_name")
334
- if tname and tname not in seen:
335
- fts_table_ranked.append(tname)
336
- seen.add(tname)
337
-
338
- # RRF over table names across the three legs
339
- rrf_scores: dict[str, float] = {}
340
- for ranked_list in [tbl_ranked, col_table_ranked, fts_table_ranked]:
341
- for rank, tname in enumerate(ranked_list):
342
- rrf_scores[tname] = rrf_scores.get(tname, 0.0) + 1.0 / (k_rrf + rank + 1)
343
-
344
- top_tables = sorted(rrf_scores, key=lambda t: rrf_scores[t], reverse=True)[:top_k]
345
-
346
- results: list[RetrievalResult] = []
347
- for tname in top_tables:
348
- if tname in tbl_index:
349
- r = tbl_index[tname]
350
- r.score = rrf_scores[tname]
351
- results.append(r)
352
- else:
353
- # Surfaced by column/FTS votes with no table chunk — minimal stub
354
- results.append(RetrievalResult(
355
- content=f"Table: {tname}",
356
- metadata={"data": {"table_name": tname}, "source_type": "database"},
357
- score=rrf_scores[tname],
358
- source_type="database",
359
- ))
360
- return results
361
-
362
- # ------------------------------------------------------------------
363
- # Public interface — called by the router
364
- # ------------------------------------------------------------------
365
-
366
- async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
367
- """Table-first retrieval for DB sources; chunk-level for tabular.
368
-
369
- DB tables are ranked via RRF across three legs:
370
- L1 (primary): table-summary chunk similarity
371
- L2 (vote): top-K column-chunk cosine, grouped by table
372
- L3 (vote): top-K FTS column hits, grouped by table
373
-
374
- db_executor downstream fetches the full per-column schema for the
375
- ranked table set via _fetch_full_schema — the column chunks returned
376
- here are intentionally NOT used as the schema source, only for voting.
377
-
378
- Tabular (CSV/XLSX) sheets are ranked via RRF across two legs:
379
- L1: sheet-chunk cosine
380
- L2: column-chunk votes (best position per sheet)
381
- Returns sheet-level RetrievalResults so tabular_executor receives
382
- full sheet context (all columns) rather than fragmented column hits.
383
- """
384
- embedding = await self._embed_query(query)
385
- db_col_results, db_tbl_results, tabular_results, fts_results, sheet_results = await asyncio.gather(
386
- self._search_db(embedding, user_id, k),
387
- self._search_db_tables(embedding, user_id, k),
388
- self._search_tabular(embedding, user_id, k),
389
- self._search_fts_db(query, user_id, k * 4),
390
- self._search_tabular_sheets(embedding, user_id, k),
391
- )
392
-
393
- db_ranked = self._rank_db_tables(db_tbl_results, db_col_results, fts_results, top_k=k)
394
- tabular_ranked = self._rank_tabular_sheets(sheet_results, tabular_results, top_k=k)
395
-
396
- results = sorted(db_ranked + tabular_ranked, key=lambda r: r.score, reverse=True)
397
- logger.info(
398
- "schema retrieval",
399
- count=len(results),
400
- db_tables_ranked=len(db_ranked),
401
- db_cols=len(db_col_results),
402
- db_tables=len(db_tbl_results),
403
- tabular_cols=len(tabular_results),
404
- tabular_sheets=len(sheet_results),
405
- tabular_ranked=len(tabular_ranked),
406
- fts=len(fts_results),
407
- )
408
- return results
409
-
410
-
411
- schema_retriever = SchemaRetriever()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rag/router.py DELETED
@@ -1,179 +0,0 @@
1
- """Routes retrieval requests to the appropriate retriever based on source_hint.
2
-
3
- Cross-retriever merging uses Reciprocal Rank Fusion (RRF) on per-retriever
4
- ranked lists — score scales differ across retrievers (RRF, cosine, distance)
5
- and aren't directly comparable, so we rank-merge instead of score-merge.
6
- """
7
-
8
- import asyncio
9
- import hashlib
10
- import json
11
- from dataclasses import asdict
12
- from typing import Literal
13
-
14
- from src.db.redis.connection import get_redis
15
- from src.middlewares.logging import get_logger
16
- from src.rag.base import BaseRetriever, RetrievalResult
17
-
18
- logger = get_logger("retrieval_router")
19
-
20
- _CACHE_TTL = 3600 # 1 hour
21
- _CACHE_KEY_PREFIX = "retrieval"
22
- _RRF_K = 60 # standard RRF constant
23
- SourceHint = Literal["document", "schema", "both"]
24
-
25
-
26
- def _result_dedup_key(r: RetrievalResult) -> tuple:
27
- """Cross-retriever dedup key — distinguishes DB columns vs DB tables vs
28
- tabular columns vs prose chunks vs sheet-level chunks."""
29
- data = r.metadata.get("data", {})
30
- return (
31
- r.source_type,
32
- data.get("table_name"),
33
- data.get("column_name"),
34
- data.get("filename"),
35
- data.get("sheet_name"),
36
- data.get("chunk_index"), # disambiguates multiple prose chunks per doc
37
- r.metadata.get("chunk_level"), # distinguishes sheet vs column chunks
38
- )
39
-
40
-
41
- def _rrf_merge(
42
- ranked_lists: list[list[RetrievalResult]],
43
- top_k: int,
44
- k_rrf: int = _RRF_K,
45
- ) -> list[RetrievalResult]:
46
- """Reciprocal Rank Fusion across retriever batches.
47
-
48
- Each input list is treated as already best-first ordered. Items are
49
- deduped via _result_dedup_key and re-ranked by aggregated reciprocal
50
- rank across all lists. Score on the returned RetrievalResult is the
51
- aggregated RRF score (uniform scale across legs).
52
- """
53
- scores: dict[tuple, float] = {}
54
- index: dict[tuple, RetrievalResult] = {}
55
-
56
- for ranked in ranked_lists:
57
- for rank, result in enumerate(ranked):
58
- key = _result_dedup_key(result)
59
- scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
60
- # Keep the first occurrence; metadata is identical for the same
61
- # key across lists, so any copy is fine.
62
- if key not in index:
63
- index[key] = result
64
-
65
- merged = sorted(index.values(), key=lambda r: scores[_result_dedup_key(r)], reverse=True)
66
- # Overwrite score with RRF score so downstream consumers see a uniform scale.
67
- for r in merged:
68
- r.score = scores[_result_dedup_key(r)]
69
- return merged[:top_k]
70
-
71
-
72
- async def invalidate_retrieval_cache(user_id: str) -> int:
73
- """Delete every cached retrieval entry for `user_id`.
74
-
75
- Called by ingest/upload/delete API handlers after a successful write so
76
- the next retrieval picks up the new data instead of stale cached top-k.
77
- Returns the number of keys removed.
78
- """
79
- redis = await get_redis()
80
- pattern = f"{_CACHE_KEY_PREFIX}:{user_id}:*"
81
- keys = [key async for key in redis.scan_iter(match=pattern)]
82
- if not keys:
83
- return 0
84
- deleted = await redis.delete(*keys)
85
- logger.info("retrieval cache invalidated", user_id=user_id, deleted=deleted)
86
- return int(deleted)
87
-
88
-
89
- class RetrievalRouter:
90
- def __init__(
91
- self,
92
- schema_retriever: BaseRetriever,
93
- document_retriever: BaseRetriever,
94
- ):
95
- self._retrievers: dict[str, BaseRetriever] = {
96
- "schema": schema_retriever,
97
- "document": document_retriever,
98
- }
99
-
100
- def _route(self, source_hint: SourceHint) -> list[tuple[str, BaseRetriever]]:
101
- if source_hint == "schema":
102
- return [("schema", self._retrievers["schema"])]
103
- if source_hint == "document":
104
- return [("document", self._retrievers["document"])]
105
- return list(self._retrievers.items())
106
-
107
- async def retrieve(
108
- self,
109
- query: str,
110
- user_id: str,
111
- source_hint: SourceHint = "both",
112
- k: int = 10,
113
- ) -> list[RetrievalResult]:
114
- redis = await get_redis()
115
- query_hash = hashlib.md5(query.encode()).hexdigest()
116
- cache_key = f"{_CACHE_KEY_PREFIX}:{user_id}:{source_hint}:{query_hash}:{k}"
117
-
118
- cached = await redis.get(cache_key)
119
- if cached:
120
- try:
121
- raw = json.loads(cached)
122
- logger.info("returning cached retrieval results", source_hint=source_hint)
123
- return [RetrievalResult(**r) for r in raw]
124
- except Exception:
125
- logger.warning("corrupted retrieval cache, fetching fresh", cache_key=cache_key)
126
-
127
- results = await self._retrieve_uncached(query, user_id, source_hint, k)
128
-
129
- # Empty-result fallback: orchestrator may have misclassified intent.
130
- # Retry once with "both" before giving up. No-op when source_hint is
131
- # already "both".
132
- if not results and source_hint != "both":
133
- logger.warning(
134
- "empty retrieval, falling back to source_hint='both'",
135
- original_source_hint=source_hint,
136
- )
137
- results = await self._retrieve_uncached(query, user_id, "both", k)
138
-
139
- await redis.setex(
140
- cache_key,
141
- _CACHE_TTL,
142
- json.dumps([asdict(r) for r in results]),
143
- )
144
- return results
145
-
146
- async def _retrieve_uncached(
147
- self,
148
- query: str,
149
- user_id: str,
150
- source_hint: SourceHint,
151
- k: int,
152
- ) -> list[RetrievalResult]:
153
- routed = self._route(source_hint)
154
- batches = await asyncio.gather(
155
- *[r.retrieve(query, user_id, k) for _, r in routed],
156
- return_exceptions=True,
157
- )
158
-
159
- valid_lists: list[list[RetrievalResult]] = []
160
- per_retriever: dict[str, int | str] = {}
161
- for (name, _), batch in zip(routed, batches):
162
- if isinstance(batch, Exception):
163
- logger.error("retriever failed", retriever=name, error=str(batch))
164
- per_retriever[name] = "error"
165
- continue
166
- valid_lists.append(batch)
167
- per_retriever[name] = len(batch)
168
-
169
- results = _rrf_merge(valid_lists, top_k=k)
170
-
171
- logger.info(
172
- "router result",
173
- source_hint=source_hint,
174
- per_retriever=per_retriever,
175
- final_count=len(results),
176
- top_score=results[0].score if results else None,
177
- bottom_score=results[-1].score if results else None,
178
- )
179
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/{rag → retrieval}/base.py RENAMED
@@ -1,4 +1,4 @@
1
- """Shared contract for all retriever implementations."""
2
 
3
  from abc import ABC, abstractmethod
4
  from dataclasses import dataclass
 
1
+ """Shared types for the retrieval layer."""
2
 
3
  from abc import ABC, abstractmethod
4
  from dataclasses import dataclass
src/retrieval/document.py CHANGED
@@ -2,14 +2,161 @@
2
 
3
  For unstructured sources only (PDF / DOCX / TXT). Backed by PGVector with
4
  collection `document_embeddings`. Methods: MMR, cosine, euclidean, etc.
5
-
6
- Receives the working implementation from the previous src/rag/retrievers/document.py
7
- during the cleanup phase; for now this is a placeholder.
8
  """
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- class DocumentRetriever:
12
- """Dense retrieval over PGVector chunks for unstructured sources."""
13
 
14
- async def retrieve(self, query: str, user_id: str, k: int = 5) -> list:
15
- raise NotImplementedError
 
2
 
3
  For unstructured sources only (PDF / DOCX / TXT). Backed by PGVector with
4
  collection `document_embeddings`. Methods: MMR, cosine, euclidean, etc.
 
 
 
5
  """
6
 
7
+ import math
8
+
9
+ from langchain_postgres import PGVector
10
+ from langchain_postgres.vectorstores import DistanceStrategy
11
+ from langchain_openai import AzureOpenAIEmbeddings
12
+ from sqlalchemy import text
13
+
14
+ from src.config.settings import settings
15
+ from src.db.postgres.connection import _pgvector_engine
16
+ from src.db.postgres.vector_store import get_vector_store
17
+ from src.middlewares.logging import get_logger
18
+ from src.retrieval.base import BaseRetriever, RetrievalResult
19
+
20
+ logger = get_logger("document_retriever")
21
+
22
+ # Change this one line to switch retrieval method
23
+ # Options: "mmr" | "cosine" | "euclidean" | "inner_product" | "manhattan"
24
+ _RETRIEVAL_METHOD = "mmr"
25
+
26
+ _TABULAR_TYPES = {"csv", "xlsx"}
27
+ _FETCH_K = 20
28
+ _LAMBDA_MULT = 0.5
29
+ _COLLECTION_NAME = "document_embeddings"
30
+
31
+ _embeddings = AzureOpenAIEmbeddings(
32
+ azure_deployment=settings.azureai_deployment_name_embedding,
33
+ openai_api_version=settings.azureai_api_version_embedding,
34
+ azure_endpoint=settings.azureai_endpoint_url_embedding,
35
+ api_key=settings.azureai_api_key_embedding,
36
+ )
37
+
38
+ _euclidean_store = PGVector(
39
+ embeddings=_embeddings,
40
+ connection=_pgvector_engine,
41
+ collection_name=_COLLECTION_NAME,
42
+ distance_strategy=DistanceStrategy.EUCLIDEAN,
43
+ use_jsonb=True,
44
+ async_mode=True,
45
+ create_extension=False,
46
+ )
47
+
48
+ _ip_store = PGVector(
49
+ embeddings=_embeddings,
50
+ connection=_pgvector_engine,
51
+ collection_name=_COLLECTION_NAME,
52
+ distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
53
+ use_jsonb=True,
54
+ async_mode=True,
55
+ create_extension=False,
56
+ )
57
+
58
+ _MANHATTAN_SQL = text("""
59
+ SELECT
60
+ lpe.document,
61
+ lpe.cmetadata,
62
+ lpe.embedding <+> CAST(:embedding AS vector) AS distance
63
+ FROM langchain_pg_embedding lpe
64
+ JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
65
+ WHERE lpc.name = :collection
66
+ AND lpe.cmetadata->>'user_id' = :user_id
67
+ AND lpe.cmetadata->>'source_type' = 'document'
68
+ ORDER BY distance ASC
69
+ LIMIT :k
70
+ """)
71
+
72
+
73
+ class DocumentRetriever(BaseRetriever):
74
+ def __init__(self) -> None:
75
+ self.vector_store = get_vector_store()
76
+
77
+ async def retrieve(
78
+ self, query: str, user_id: str, k: int = 5
79
+ ) -> list[RetrievalResult]:
80
+ filter_ = {"user_id": user_id, "source_type": "document"}
81
+ fetch_k = k + len(_TABULAR_TYPES)
82
+
83
+ if _RETRIEVAL_METHOD == "manhattan":
84
+ return await self._retrieve_manhattan(query, user_id, k, fetch_k)
85
+
86
+ if _RETRIEVAL_METHOD == "mmr":
87
+ docs = await self.vector_store.amax_marginal_relevance_search(
88
+ query=query,
89
+ k=fetch_k,
90
+ fetch_k=_FETCH_K,
91
+ lambda_mult=_LAMBDA_MULT,
92
+ filter=filter_,
93
+ )
94
+ cosine = await self.vector_store.asimilarity_search_with_score(
95
+ query=query, k=fetch_k, filter=filter_,
96
+ )
97
+ score_map = {doc.page_content: score for doc, score in cosine}
98
+ docs_with_scores = [(doc, score_map.get(doc.page_content, 0.0)) for doc in docs]
99
+ elif _RETRIEVAL_METHOD == "euclidean":
100
+ docs_with_scores = await _euclidean_store.asimilarity_search_with_score(
101
+ query=query, k=fetch_k, filter=filter_,
102
+ )
103
+ elif _RETRIEVAL_METHOD == "inner_product":
104
+ docs_with_scores = await _ip_store.asimilarity_search_with_score(
105
+ query=query, k=fetch_k, filter=filter_,
106
+ )
107
+ else: # cosine
108
+ docs_with_scores = await self.vector_store.asimilarity_search_with_score(
109
+ query=query, k=fetch_k, filter=filter_,
110
+ )
111
+
112
+ results = []
113
+ for doc, score in docs_with_scores:
114
+ file_type = doc.metadata.get("data", {}).get("file_type", "")
115
+ if file_type not in _TABULAR_TYPES:
116
+ results.append(RetrievalResult(
117
+ content=doc.page_content,
118
+ metadata=doc.metadata,
119
+ score=score,
120
+ source_type="document",
121
+ ))
122
+ if len(results) == k:
123
+ break
124
+
125
+ logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
126
+ return results
127
+
128
+ async def _retrieve_manhattan(
129
+ self, query: str, user_id: str, k: int, fetch_k: int
130
+ ) -> list[RetrievalResult]:
131
+ query_vector = await _embeddings.aembed_query(query)
132
+ if not all(math.isfinite(v) for v in query_vector):
133
+ raise ValueError("Embedding vector contains NaN or Infinity values.")
134
+ vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"
135
+
136
+ async with _pgvector_engine.connect() as conn:
137
+ result = await conn.execute(_MANHATTAN_SQL, {
138
+ "embedding": vector_str,
139
+ "collection": _COLLECTION_NAME,
140
+ "user_id": user_id,
141
+ "k": fetch_k,
142
+ })
143
+ rows = result.fetchall()
144
+
145
+ results = []
146
+ for row in rows:
147
+ file_type = row.cmetadata.get("data", {}).get("file_type", "")
148
+ if file_type not in _TABULAR_TYPES:
149
+ results.append(RetrievalResult(
150
+ content=row.document,
151
+ metadata=row.cmetadata,
152
+ score=float(row.distance),
153
+ source_type="document",
154
+ ))
155
+ if len(results) == k:
156
+ break
157
+
158
+ logger.info("retrieved chunks", method="manhattan", count=len(results))
159
+ return results
160
 
 
 
161
 
162
+ document_retriever = DocumentRetriever()
 
src/retrieval/router.py CHANGED
@@ -1,11 +1,83 @@
1
- """Retrieval-side router.
2
 
3
- Currently dispatches only the `unstructured` route to DocumentRetriever.
4
- The `structured` route is owned by query/service.py not by retrieval.
5
- The `chat` route bypasses retrieval entirely.
 
 
 
 
6
  """
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  class RetrievalRouter:
10
- async def dispatch(self, query: str, user_id: str, source_hint: str) -> list:
11
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Retrieval router — dispatches to DocumentRetriever for unstructured sources.
2
 
3
+ Routing rules:
4
+ - unstructured / document / both DocumentRetriever (PGVector, PDF/DOCX/TXT)
5
+ - structured / schema → empty list; handled by query/service.py
6
+ - chat → empty list; bypasses retrieval entirely
7
+
8
+ Exposes the same interface as the old src/rag/retriever.py so call sites in
9
+ chat.py require no changes beyond the import path.
10
  """
11
 
12
+ import hashlib
13
+ import json
14
+ from dataclasses import asdict
15
+
16
+ from sqlalchemy.ext.asyncio import AsyncSession
17
+
18
+ from src.db.redis.connection import get_redis
19
+ from src.middlewares.logging import get_logger
20
+ from src.retrieval.base import RetrievalResult
21
+ from src.retrieval.document import document_retriever
22
+
23
+ logger = get_logger("retrieval_router")
24
+
25
+ _CACHE_TTL = 3600
26
+ _CACHE_KEY_PREFIX = "retrieval"
27
+ _UNSTRUCTURED_HINTS = frozenset({"document", "unstructured", "both"})
28
+
29
 
30
  class RetrievalRouter:
31
+ async def retrieve(
32
+ self,
33
+ query: str,
34
+ user_id: str,
35
+ db: AsyncSession,
36
+ k: int = 5,
37
+ source_hint: str = "both",
38
+ ) -> list[RetrievalResult]:
39
+ if source_hint not in _UNSTRUCTURED_HINTS:
40
+ return []
41
+
42
+ redis = await get_redis()
43
+ query_hash = hashlib.md5(query.encode()).hexdigest()
44
+ cache_key = f"{_CACHE_KEY_PREFIX}:{user_id}:{source_hint}:{query_hash}:{k}"
45
+
46
+ cached = await redis.get(cache_key)
47
+ if cached:
48
+ try:
49
+ raw = json.loads(cached)
50
+ logger.info("returning cached retrieval results", source_hint=source_hint)
51
+ return [RetrievalResult(**r) for r in raw]
52
+ except Exception:
53
+ logger.warning("corrupted retrieval cache, fetching fresh")
54
+
55
+ try:
56
+ results = await document_retriever.retrieve(query, user_id, k)
57
+ except Exception as e:
58
+ logger.error("retrieval failed", error=str(e))
59
+ return []
60
+
61
+ if not results and source_hint == "both":
62
+ logger.warning("empty retrieval result for source_hint='both'")
63
+
64
+ await redis.setex(
65
+ cache_key,
66
+ _CACHE_TTL,
67
+ json.dumps([asdict(r) for r in results]),
68
+ )
69
+ return results
70
+
71
+ async def invalidate_cache(self, user_id: str) -> int:
72
+ """Delete all cached retrieval entries for a user. Call after upload/delete."""
73
+ redis = await get_redis()
74
+ pattern = f"{_CACHE_KEY_PREFIX}:{user_id}:*"
75
+ keys = [key async for key in redis.scan_iter(match=pattern)]
76
+ if not keys:
77
+ return 0
78
+ deleted = await redis.delete(*keys)
79
+ logger.info("retrieval cache invalidated", user_id=user_id, deleted=deleted)
80
+ return int(deleted)
81
+
82
+
83
+ retrieval_router = RetrievalRouter()
src/tools/__init__.py DELETED
File without changes
src/tools/search.py DELETED
@@ -1,46 +0,0 @@
1
- """Search tool for agent."""
2
-
3
- from langchain_core.tools import tool
4
- from src.rag.retriever import retriever
5
- from sqlalchemy.ext.asyncio import AsyncSession
6
- from src.middlewares.logging import get_logger
7
-
8
- logger = get_logger("search_tool")
9
-
10
-
11
- @tool
12
- async def search_documents(
13
- query: str,
14
- user_id: str,
15
- db: AsyncSession,
16
- num_results: int = 5
17
- ) -> str:
18
- """Search user's uploaded documents for relevant information.
19
-
20
- Args:
21
- query: The search query or question
22
- user_id: The user's ID
23
- db: Database session
24
- num_results: Number of results to return (default: 5)
25
-
26
- Returns:
27
- Relevant document excerpts with source and page information
28
- """
29
- try:
30
- results = await retriever.retrieve(query, user_id, db, num_results)
31
-
32
- if not results:
33
- return "No relevant information found in the documents."
34
-
35
- formatted_results = []
36
- for result in results:
37
- filename = result.metadata.get("filename", "Unknown")
38
- page = result.metadata.get("page_label")
39
- source_label = f"{filename}, p.{page}" if page else filename
40
- formatted_results.append(f"[Source: {source_label}]\n{result.content}\n")
41
-
42
- return "\n".join(formatted_results)
43
-
44
- except Exception as e:
45
- logger.error("Search failed", error=str(e))
46
- return "Sorry, I encountered an error while searching the documents."