feat/Planner Agent

#2
by rhbt6767 - opened
.gitignore CHANGED
@@ -37,8 +37,6 @@ playground_chat.py
37
  playground_flush_cache.py
38
  playground_create_user.py
39
  API_CONTRACT.md
40
- API_CONTRACT_AGENT.md
41
- API_CONTRACT_AGENT_ACTIVE.md
42
  context_engineering/
43
  sample_file/
44
  test_tesseract.py
@@ -48,4 +46,3 @@ software/
48
 
49
  tests/
50
  .claude/
51
- migratego/
 
37
  playground_flush_cache.py
38
  playground_create_user.py
39
  API_CONTRACT.md
 
 
40
  context_engineering/
41
  sample_file/
42
  test_tesseract.py
 
46
 
47
  tests/
48
  .claude/
 
.vscode/launch.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ // Use IntelliSense to learn about possible attributes.
3
+ // Hover to view descriptions of existing attributes.
4
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5
+ "version": "0.2.0",
6
+ "configurations": [
7
+ {
8
+ "name": "DataEyond: FastAPI (debug)",
9
+ "type": "debugpy",
10
+ "request": "launch",
11
+ "module": "uvicorn",
12
+ "args": [
13
+ "main:app",
14
+ "--host", "0.0.0.0",
15
+ "--port", "7860",
16
+ "--reload"
17
+ ],
18
+ "jinja": true,
19
+ "justMyCode": true,
20
+ "envFile": "${workspaceFolder}/.env",
21
+ "console": "integratedTerminal",
22
+ "cwd": "${workspaceFolder}"
23
+ }
24
+ ]
25
+ }
main.py CHANGED
@@ -14,7 +14,6 @@ from src.api.v1.users import router as users_router
14
  from src.api.v1.db_client import router as db_client_router
15
  from src.api.v1.data_catalog import router as data_catalog_router
16
  from src.db.postgres.init_db import init_db
17
- import os
18
  import uvicorn
19
 
20
  # Configure logging
@@ -25,11 +24,8 @@ logger = get_logger("main")
25
  @asynccontextmanager
26
  async def lifespan(app: FastAPI):
27
  logger.info("Starting application...")
28
- if os.getenv("SKIP_INIT_DB", "false").lower() != "true":
29
- await init_db()
30
- logger.info("Database initialized")
31
- else:
32
- logger.info("Skipping database initialization (SKIP_INIT_DB=true)")
33
  yield
34
 
35
 
 
14
  from src.api.v1.db_client import router as db_client_router
15
  from src.api.v1.data_catalog import router as data_catalog_router
16
  from src.db.postgres.init_db import init_db
 
17
  import uvicorn
18
 
19
  # Configure logging
 
24
  @asynccontextmanager
25
  async def lifespan(app: FastAPI):
26
  logger.info("Starting application...")
27
+ await init_db()
28
+ logger.info("Database initialized")
 
 
 
29
  yield
30
 
31
 
pyproject.toml CHANGED
@@ -77,6 +77,7 @@ dependencies = [
77
  "cachetools==5.5.0",
78
  "apscheduler==3.10.4",
79
  "jsonpatch>=1.33",
 
80
  "psycopg2>=2.9.11",
81
  # --- SQL parsing / guardrails ---
82
  "sqlglot>=25.0.0",
@@ -120,7 +121,8 @@ ignore = [
120
  ]
121
 
122
  [tool.ruff.lint.per-file-ignores]
123
- "tests/**" = ["S101", "S105", "S106"]
 
124
 
125
  [tool.mypy]
126
  python_version = "3.12"
 
77
  "cachetools==5.5.0",
78
  "apscheduler==3.10.4",
79
  "jsonpatch>=1.33",
80
+ "pymongo>=4.14.0",
81
  "psycopg2>=2.9.11",
82
  # --- SQL parsing / guardrails ---
83
  "sqlglot>=25.0.0",
 
121
  ]
122
 
123
  [tool.ruff.lint.per-file-ignores]
124
+ # S608 in tests is a false positive — tests assert literal SQL strings as fixtures.
125
+ "tests/**" = ["S101", "S105", "S106", "S608"]
126
 
127
  [tool.mypy]
128
  python_version = "3.12"
src/agents/chat_handler.py CHANGED
@@ -170,7 +170,6 @@ class ChatHandler:
170
  sources = _build_sources(
171
  decision.source_hint, user_id, query_result, raw_chunks
172
  )
173
- logger.info("built sources", source_hint=decision.source_hint, sources_count=len(sources), raw_chunks_count=len(raw_chunks) if raw_chunks else 0)
174
  yield {"event": "sources", "data": json.dumps(sources)}
175
 
176
  # ---- 3. Stream answer ----------------------------------------
 
170
  sources = _build_sources(
171
  decision.source_hint, user_id, query_result, raw_chunks
172
  )
 
173
  yield {"event": "sources", "data": json.dumps(sources)}
174
 
175
  # ---- 3. Stream answer ----------------------------------------
src/api/v1/chat.py CHANGED
@@ -42,19 +42,15 @@ class ChatRequest(BaseModel):
42
  message: str
43
 
44
 
45
- async def get_cached_response(redis, cache_key: str) -> Optional[dict]:
46
  cached = await redis.get(cache_key)
47
  if cached:
48
- data = json.loads(cached)
49
- if isinstance(data, dict) and "response" in data:
50
- return data
51
- # legacy: plain string cached before this change
52
- return {"response": data, "sources": []}
53
  return None
54
 
55
 
56
- async def cache_response(redis, cache_key: str, response: str, sources: list):
57
- await redis.setex(cache_key, 86400, json.dumps({"response": response, "sources": sources}))
58
 
59
 
60
  async def load_history(db: AsyncSession, room_id: str, limit: int = 10) -> list:
@@ -95,34 +91,6 @@ async def save_messages(
95
  await db.commit()
96
 
97
 
98
- @router.delete("/chat/cache")
99
- async def clear_chat_cache(room_id: str, message: str):
100
- """Delete the Redis cache entry for a specific room + message pair."""
101
- redis = await get_redis()
102
- cache_key = f"{settings.redis_prefix}chat:{room_id}:{message}"
103
- deleted = await redis.delete(cache_key)
104
- return {"deleted": deleted > 0, "cache_key": cache_key}
105
-
106
-
107
- @router.delete("/chat/cache/room/{room_id}")
108
- async def clear_room_cache(room_id: str):
109
- """Delete all Redis cache entries for a room."""
110
- redis = await get_redis()
111
- pattern = f"{settings.redis_prefix}chat:{room_id}:*"
112
- keys = await redis.keys(pattern)
113
- if keys:
114
- await redis.delete(*keys)
115
- return {"deleted_count": len(keys), "room_id": room_id}
116
-
117
-
118
- @router.delete("/retrieval/cache/{user_id}")
119
- async def clear_retrieval_cache(user_id: str):
120
- """Delete all cached retrieval results for a user. Call this after uploading/processing new documents."""
121
- from src.retrieval.router import retrieval_router
122
- deleted = await retrieval_router.invalidate_cache(user_id)
123
- return {"deleted_count": deleted, "user_id": user_id}
124
-
125
-
126
  @router.post("/chat/stream")
127
  @log_execution(logger)
128
  async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
@@ -139,17 +107,13 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
139
 
140
  # Redis cache hit
141
  cached = await get_cached_response(redis, cache_key)
142
- logger.info("cache check", cache_key=cache_key, cache_hit=cached is not None)
143
  if cached:
144
  logger.info("Returning cached response")
145
- cached_text = cached["response"]
146
- cached_sources = cached["sources"]
147
- await save_messages(db, request.room_id, request.message, cached_text, sources=cached_sources)
148
 
149
  async def stream_cached():
150
- yield {"event": "sources", "data": json.dumps(cached_sources)}
151
- for i in range(0, len(cached_text), 50):
152
- yield {"event": "chunk", "data": cached_text[i:i + 50]}
153
  yield {"event": "done", "data": ""}
154
 
155
  return EventSourceResponse(stream_cached())
@@ -158,7 +122,7 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
158
  # Fast intent: greetings/farewells bypass LLM entirely
159
  direct = _fast_intent(request.message)
160
  if direct:
161
- await cache_response(redis, cache_key, direct, sources=[])
162
  await save_messages(db, request.room_id, request.message, direct, sources=[])
163
 
164
  async def stream_direct():
@@ -172,7 +136,6 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
172
  handler = ChatHandler()
173
 
174
  async def stream_response():
175
- logger.info("stream_response started", room_id=request.room_id, user_id=request.user_id)
176
  full_response = ""
177
  sources: List[Dict[str, Any]] = []
178
  async for event in handler.handle(request.message, request.user_id, history):
@@ -186,12 +149,8 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
186
  full_response += event["data"]
187
  yield event
188
  elif event["event"] == "done":
189
- await cache_response(redis, cache_key, full_response, sources=sources)
190
- logger.info("saving messages", sources_count=len(sources), sources=sources)
191
- try:
192
- await save_messages(db, request.room_id, request.message, full_response, sources=sources)
193
- except Exception as e:
194
- logger.error("save_messages failed", room_id=request.room_id, error=str(e))
195
  yield event
196
  elif event["event"] == "error":
197
  yield event
 
42
  message: str
43
 
44
 
45
+ async def get_cached_response(redis, cache_key: str) -> Optional[str]:
46
  cached = await redis.get(cache_key)
47
  if cached:
48
+ return json.loads(cached)
 
 
 
 
49
  return None
50
 
51
 
52
+ async def cache_response(redis, cache_key: str, response: str):
53
+ await redis.setex(cache_key, 86400, json.dumps(response))
54
 
55
 
56
  async def load_history(db: AsyncSession, room_id: str, limit: int = 10) -> list:
 
91
  await db.commit()
92
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  @router.post("/chat/stream")
95
  @log_execution(logger)
96
  async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
 
107
 
108
  # Redis cache hit
109
  cached = await get_cached_response(redis, cache_key)
 
110
  if cached:
111
  logger.info("Returning cached response")
 
 
 
112
 
113
  async def stream_cached():
114
+ yield {"event": "sources", "data": json.dumps([])}
115
+ for i in range(0, len(cached), 50):
116
+ yield {"event": "chunk", "data": cached[i:i + 50]}
117
  yield {"event": "done", "data": ""}
118
 
119
  return EventSourceResponse(stream_cached())
 
122
  # Fast intent: greetings/farewells bypass LLM entirely
123
  direct = _fast_intent(request.message)
124
  if direct:
125
+ await cache_response(redis, cache_key, direct)
126
  await save_messages(db, request.room_id, request.message, direct, sources=[])
127
 
128
  async def stream_direct():
 
136
  handler = ChatHandler()
137
 
138
  async def stream_response():
 
139
  full_response = ""
140
  sources: List[Dict[str, Any]] = []
141
  async for event in handler.handle(request.message, request.user_id, history):
 
149
  full_response += event["data"]
150
  yield event
151
  elif event["event"] == "done":
152
+ await cache_response(redis, cache_key, full_response)
153
+ await save_messages(db, request.room_id, request.message, full_response, sources=sources)
 
 
 
 
154
  yield event
155
  elif event["event"] == "error":
156
  yield event
src/api/v1/document.py CHANGED
@@ -114,8 +114,5 @@ async def process_document(
114
  except Exception as e:
115
  logger.error("catalog ingestion failed after process", document_id=document_id, error=str(e))
116
 
117
- from src.retrieval.router import retrieval_router
118
- await retrieval_router.invalidate_cache(user_id)
119
-
120
  return {"status": "success", "message": "Document processed successfully", "data": data}
121
 
 
114
  except Exception as e:
115
  logger.error("catalog ingestion failed after process", document_id=document_id, error=str(e))
116
 
 
 
 
117
  return {"status": "success", "message": "Document processed successfully", "data": data}
118
 
src/config/prompts/intent_router.md CHANGED
@@ -7,16 +7,16 @@ Return three fields:
7
  - **`needs_search`** — `true` if we must look at the user's data to answer; `false` for greetings, farewells, off-topic chitchat, or meta questions about the assistant itself.
8
  - **`source_hint`** — one of:
9
  - `chat` — no data lookup needed (greetings, farewells, generic small talk).
10
- - `unstructured` — the user is asking about a topic, concept, feature, or factual knowledge that may exist in uploaded documents (PDF / DOCX / TXT). The user does not need to explicitly mention a document.
11
  - `structured` — the user is asking a **data question** answerable from a database or a tabular file (CSV / XLSX / Parquet). This includes counts, sums, top-N, filters, comparisons, trends, joins across registered structured sources.
12
  - **`rewritten_query`** — a **standalone** version of the user's question that incorporates necessary context from history. If the original message is already standalone, return it unchanged. If `needs_search` is `false`, leave this empty/null.
13
 
14
  ## Routing rules
15
 
16
- 1. If the message is ONLY a pure greeting / farewell / thanks / "how are you" / "what can you do" / compliment with no factual question → `chat` + `needs_search=false`.
17
- 2. If the message asks a data question answerable from a database or tabular file (counts, sums, top-N, filters, comparisons, trends, sheet rows, table columns) → `structured` + `needs_search=true`.
18
- 3. If the message asks about a topic, concept, feature, explanation, summary, or factual knowledge — even without explicitly mentioning a document — route to `unstructured` + `needs_search=true`. The user may have uploaded relevant documents covering that topic.
19
- 4. If ambiguous between structured and unstructured, prefer `unstructured`. Only prefer `structured` if there are clear signals of tabular/numeric data questions.
20
  5. Cross-source comparison ("compare DB sales to the customers.csv file") → `structured`. The planner sees both source types in one prompt and can correlate.
21
 
22
  ## Rewriting follow-ups
@@ -53,22 +53,6 @@ User: "Top 5 customers by revenue this year"
53
  → needs_search=true, source_hint="structured",
54
  rewritten_query="Top 5 customers by revenue this year"
55
 
56
- User: "apa key feature dari iot connectivity?"
57
- → needs_search=true, source_hint="unstructured",
58
- rewritten_query="What are the key features of IoT connectivity?"
59
-
60
- User: "jelaskan tentang machine learning"
61
- → needs_search=true, source_hint="unstructured",
62
- rewritten_query="Explain machine learning"
63
-
64
- User: "bagaimana cara kerja neural network?"
65
- → needs_search=true, source_hint="unstructured",
66
- rewritten_query="How does a neural network work?"
67
-
68
- User: "what is the main purpose of this system?"
69
- → needs_search=true, source_hint="unstructured",
70
- rewritten_query="What is the main purpose of this system?"
71
-
72
  History: assistant: "Pro Plan Annual led at $487,200 in April."
73
  User: "And in March?"
74
  → needs_search=true, source_hint="structured",
@@ -77,6 +61,6 @@ User: "And in March?"
77
 
78
  ## Constraints
79
 
80
- - Do not invent data. If the question is factual or knowledge-based (not clearly tabular), route to `unstructured` and let the retriever decide. Only route to `structured` if the question clearly involves counts, sums, filters, or trends from tabular sources.
81
  - Do not refuse — refusal happens later in guardrails. Just classify.
82
  - One JSON object as output; no prose, no markdown.
 
7
  - **`needs_search`** — `true` if we must look at the user's data to answer; `false` for greetings, farewells, off-topic chitchat, or meta questions about the assistant itself.
8
  - **`source_hint`** — one of:
9
  - `chat` — no data lookup needed (greetings, farewells, generic small talk).
10
+ - `unstructured` — the user is asking about the **content** of an uploaded document (PDF / DOCX / TXT).
11
  - `structured` — the user is asking a **data question** answerable from a database or a tabular file (CSV / XLSX / Parquet). This includes counts, sums, top-N, filters, comparisons, trends, joins across registered structured sources.
12
  - **`rewritten_query`** — a **standalone** version of the user's question that incorporates necessary context from history. If the original message is already standalone, return it unchanged. If `needs_search` is `false`, leave this empty/null.
13
 
14
  ## Routing rules
15
 
16
+ 1. If the message is a pure greeting / farewell / thanks / "how are you" / "what can you do" → `chat` + `needs_search=false`.
17
+ 2. If the message references content that lives in a registered DB or uploaded tabular file (sales numbers, customer counts, order trends, sheet rows, table columns) → `structured` + `needs_search=true`.
18
+ 3. If the message asks about prose content (a section of a PDF, what a memo says, a quote from a document) → `unstructured` + `needs_search=true`.
19
+ 4. If the message is ambiguous between structured and unstructured, prefer `structured` — the planner can fall back if the catalog has nothing relevant.
20
  5. Cross-source comparison ("compare DB sales to the customers.csv file") → `structured`. The planner sees both source types in one prompt and can correlate.
21
 
22
  ## Rewriting follow-ups
 
53
  → needs_search=true, source_hint="structured",
54
  rewritten_query="Top 5 customers by revenue this year"
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  History: assistant: "Pro Plan Annual led at $487,200 in April."
57
  User: "And in March?"
58
  → needs_search=true, source_hint="structured",
 
61
 
62
  ## Constraints
63
 
64
+ - Do not invent data. If you don't know whether a topic exists in the user's data, route to `structured` and let the planner decide.
65
  - Do not refuse — refusal happens later in guardrails. Just classify.
66
  - One JSON object as output; no prose, no markdown.
src/config/settings.py CHANGED
@@ -1,7 +1,7 @@
1
  """Centralized configuration management using pydantic-settings."""
2
 
3
- # import os
4
- # from typing import Optional
5
  from pydantic import Field
6
  from pydantic_settings import BaseSettings, SettingsConfigDict
7
 
@@ -51,8 +51,8 @@ class Settings(BaseSettings):
51
  LANGFUSE_HOST: str
52
 
53
  # MongoDB (for users - existing)
54
- # emarcal_mongo_endpoint_url: str = Field(alias="emarcal__mongo__endpoint__url", default="")
55
- # emarcal_buma_mongo_dbname: str = Field(alias="emarcal__buma__mongo__dbname", default="")
56
 
57
  # JWT (for users - existing)
58
  emarcal_jwt_secret_key: str = Field(alias="emarcal__jwt__secret_key", default="")
 
1
  """Centralized configuration management using pydantic-settings."""
2
 
3
+ import os
4
+ from typing import Optional
5
  from pydantic import Field
6
  from pydantic_settings import BaseSettings, SettingsConfigDict
7
 
 
51
  LANGFUSE_HOST: str
52
 
53
  # MongoDB (for users - existing)
54
+ emarcal_mongo_endpoint_url: str = Field(alias="emarcal__mongo__endpoint__url", default="")
55
+ emarcal_buma_mongo_dbname: str = Field(alias="emarcal__buma__mongo__dbname", default="")
56
 
57
  # JWT (for users - existing)
58
  emarcal_jwt_secret_key: str = Field(alias="emarcal__jwt__secret_key", default="")
src/db/postgres/vector_store.py CHANGED
@@ -19,7 +19,7 @@ embeddings = AzureOpenAIEmbeddings(
19
  vector_store = PGVector(
20
  embeddings=embeddings,
21
  connection=_pgvector_engine,
22
- collection_name="documents",
23
  use_jsonb=True,
24
  async_mode=True,
25
  create_extension=False, # Extension pre-created in init_db.py (avoids multi-statement asyncpg bug)
 
19
  vector_store = PGVector(
20
  embeddings=embeddings,
21
  connection=_pgvector_engine,
22
+ collection_name="document_embeddings",
23
  use_jsonb=True,
24
  async_mode=True,
25
  create_extension=False, # Extension pre-created in init_db.py (avoids multi-statement asyncpg bug)
src/knowledge/processing_service.py CHANGED
@@ -59,7 +59,6 @@ class KnowledgeProcessingService:
59
  "filename": db_doc.filename,
60
  "file_type": db_doc.file_type,
61
  "chunk_index": i,
62
- "page_label": None,
63
  },
64
  }
65
  )
 
59
  "filename": db_doc.filename,
60
  "file_type": db_doc.file_type,
61
  "chunk_index": i,
 
62
  },
63
  }
64
  )
src/middlewares/logging.py CHANGED
@@ -9,7 +9,7 @@ import time
9
 
10
  def configure_logging():
11
  """Configure structured logging."""
12
- logging.basicConfig(level=logging.INFO)
13
  logging.getLogger("tabular_executor").setLevel(logging.INFO)
14
  structlog.configure(
15
  processors=[
 
9
 
10
  def configure_logging():
11
  """Configure structured logging."""
12
+ logging.basicConfig(level=logging.WARNING)
13
  logging.getLogger("tabular_executor").setLevel(logging.INFO)
14
  structlog.configure(
15
  processors=[
src/retrieval/document.py CHANGED
@@ -1,44 +1,68 @@
1
- """DocumentRetriever — dense similarity over prose chunks.
2
 
3
- For unstructured sources only (PDF / DOCX / TXT). Backed by PGVector via
4
- raw SQL to avoid LangChain ORM / asyncpg type-mapping issues (id UUID vs
5
- String mismatch, jsonb_path_match asyncpg binding quirks).
6
- Collection `document_embeddings`. Methods: cosine | manhattan.
7
  """
8
 
9
  import functools
10
  import math
11
 
 
 
12
  from langchain_openai import AzureOpenAIEmbeddings
13
  from sqlalchemy import text
14
 
15
  from src.config.settings import settings
16
  from src.db.postgres.connection import _pgvector_engine
 
17
  from src.middlewares.logging import get_logger
18
  from src.retrieval.base import BaseRetriever, RetrievalResult
19
 
20
  logger = get_logger("document_retriever")
21
 
22
  # Change this one line to switch retrieval method
23
- # Options: "cosine" | "manhattan"
24
- _RETRIEVAL_METHOD = "cosine"
25
 
26
  _TABULAR_TYPES = {"csv", "xlsx"}
27
- _COLLECTION_NAME = "documents"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- _COSINE_SQL = text("""
30
- SELECT
31
- lpe.document,
32
- lpe.cmetadata,
33
- lpe.embedding <=> CAST(:embedding AS vector) AS distance
34
- FROM langchain_pg_embedding lpe
35
- JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
36
- WHERE lpc.name = :collection
37
- AND lpe.cmetadata->>'user_id' = :user_id
38
- AND lpe.cmetadata->>'source_type' = 'document'
39
- ORDER BY distance ASC
40
- LIMIT :k
41
- """)
42
 
43
  _MANHATTAN_SQL = text("""
44
  SELECT
@@ -55,32 +79,71 @@ _MANHATTAN_SQL = text("""
55
  """)
56
 
57
 
58
- @functools.cache
59
- def _get_embeddings() -> AzureOpenAIEmbeddings:
60
- return AzureOpenAIEmbeddings(
61
- azure_deployment=settings.azureai_deployment_name_embedding,
62
- openai_api_version=settings.azureai_api_version_embedding,
63
- azure_endpoint=settings.azureai_endpoint_url_embedding,
64
- api_key=settings.azureai_api_key_embedding,
65
- )
66
-
67
-
68
  class DocumentRetriever(BaseRetriever):
 
 
 
69
  async def retrieve(
70
  self, query: str, user_id: str, k: int = 5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  ) -> list[RetrievalResult]:
72
  query_vector = await _get_embeddings().aembed_query(query)
73
  if not all(math.isfinite(v) for v in query_vector):
74
  raise ValueError("Embedding vector contains NaN or Infinity values.")
75
  vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"
76
- fetch_k = k + len(_TABULAR_TYPES)
77
-
78
- sql = _COSINE_SQL if _RETRIEVAL_METHOD == "cosine" else _MANHATTAN_SQL
79
-
80
- logger.info("retrieve called", user_id=user_id, collection=_COLLECTION_NAME, fetch_k=fetch_k)
81
 
82
  async with _pgvector_engine.connect() as conn:
83
- result = await conn.execute(sql, {
84
  "embedding": vector_str,
85
  "collection": _COLLECTION_NAME,
86
  "user_id": user_id,
@@ -88,8 +151,6 @@ class DocumentRetriever(BaseRetriever):
88
  })
89
  rows = result.fetchall()
90
 
91
- logger.info("raw rows from db", row_count=len(rows))
92
-
93
  results = []
94
  for row in rows:
95
  file_type = row.cmetadata.get("data", {}).get("file_type", "")
@@ -103,7 +164,7 @@ class DocumentRetriever(BaseRetriever):
103
  if len(results) == k:
104
  break
105
 
106
- logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
107
  return results
108
 
109
 
 
1
+ """DocumentRetriever — dense similarity over prose chunks (Cu).
2
 
3
+ For unstructured sources only (PDF / DOCX / TXT). Backed by PGVector with
4
+ collection `document_embeddings`. Methods: MMR, cosine, euclidean, etc.
 
 
5
  """
6
 
7
  import functools
8
  import math
9
 
10
+ from langchain_postgres import PGVector
11
+ from langchain_postgres.vectorstores import DistanceStrategy
12
  from langchain_openai import AzureOpenAIEmbeddings
13
  from sqlalchemy import text
14
 
15
  from src.config.settings import settings
16
  from src.db.postgres.connection import _pgvector_engine
17
+ from src.db.postgres.vector_store import get_vector_store
18
  from src.middlewares.logging import get_logger
19
  from src.retrieval.base import BaseRetriever, RetrievalResult
20
 
21
  logger = get_logger("document_retriever")
22
 
23
  # Change this one line to switch retrieval method
24
+ # Options: "mmr" | "cosine" | "euclidean" | "inner_product" | "manhattan"
25
+ _RETRIEVAL_METHOD = "mmr"
26
 
27
  _TABULAR_TYPES = {"csv", "xlsx"}
28
+ _FETCH_K = 20
29
+ _LAMBDA_MULT = 0.5
30
+ _COLLECTION_NAME = "document_embeddings"
31
+
32
+ @functools.cache
33
+ def _get_embeddings() -> AzureOpenAIEmbeddings:
34
+ return AzureOpenAIEmbeddings(
35
+ azure_deployment=settings.azureai_deployment_name_embedding,
36
+ openai_api_version=settings.azureai_api_version_embedding,
37
+ azure_endpoint=settings.azureai_endpoint_url_embedding,
38
+ api_key=settings.azureai_api_key_embedding,
39
+ )
40
+
41
+
42
+ @functools.cache
43
+ def _get_euclidean_store() -> PGVector:
44
+ return PGVector(
45
+ embeddings=_get_embeddings(),
46
+ connection=_pgvector_engine,
47
+ collection_name=_COLLECTION_NAME,
48
+ distance_strategy=DistanceStrategy.EUCLIDEAN,
49
+ use_jsonb=True,
50
+ async_mode=True,
51
+ create_extension=False,
52
+ )
53
 
54
+
55
+ @functools.cache
56
+ def _get_ip_store() -> PGVector:
57
+ return PGVector(
58
+ embeddings=_get_embeddings(),
59
+ connection=_pgvector_engine,
60
+ collection_name=_COLLECTION_NAME,
61
+ distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
62
+ use_jsonb=True,
63
+ async_mode=True,
64
+ create_extension=False,
65
+ )
 
66
 
67
  _MANHATTAN_SQL = text("""
68
  SELECT
 
79
  """)
80
 
81
 
 
 
 
 
 
 
 
 
 
 
82
  class DocumentRetriever(BaseRetriever):
83
+ def __init__(self) -> None:
84
+ self.vector_store = get_vector_store()
85
+
86
  async def retrieve(
87
  self, query: str, user_id: str, k: int = 5
88
+ ) -> list[RetrievalResult]:
89
+ filter_ = {"user_id": user_id, "source_type": "document"}
90
+ fetch_k = k + len(_TABULAR_TYPES)
91
+
92
+ if _RETRIEVAL_METHOD == "manhattan":
93
+ return await self._retrieve_manhattan(query, user_id, k, fetch_k)
94
+
95
+ if _RETRIEVAL_METHOD == "mmr":
96
+ docs = await self.vector_store.amax_marginal_relevance_search(
97
+ query=query,
98
+ k=fetch_k,
99
+ fetch_k=_FETCH_K,
100
+ lambda_mult=_LAMBDA_MULT,
101
+ filter=filter_,
102
+ )
103
+ cosine = await self.vector_store.asimilarity_search_with_score(
104
+ query=query, k=fetch_k, filter=filter_,
105
+ )
106
+ score_map = {doc.page_content: score for doc, score in cosine}
107
+ docs_with_scores = [(doc, score_map.get(doc.page_content, 0.0)) for doc in docs]
108
+ elif _RETRIEVAL_METHOD == "euclidean":
109
+ docs_with_scores = await _get_euclidean_store().asimilarity_search_with_score(
110
+ query=query, k=fetch_k, filter=filter_,
111
+ )
112
+ elif _RETRIEVAL_METHOD == "inner_product":
113
+ docs_with_scores = await _get_ip_store().asimilarity_search_with_score(
114
+ query=query, k=fetch_k, filter=filter_,
115
+ )
116
+ else: # cosine
117
+ docs_with_scores = await self.vector_store.asimilarity_search_with_score(
118
+ query=query, k=fetch_k, filter=filter_,
119
+ )
120
+
121
+ results = []
122
+ for doc, score in docs_with_scores:
123
+ file_type = doc.metadata.get("data", {}).get("file_type", "")
124
+ if file_type not in _TABULAR_TYPES:
125
+ results.append(RetrievalResult(
126
+ content=doc.page_content,
127
+ metadata=doc.metadata,
128
+ score=score,
129
+ source_type="document",
130
+ ))
131
+ if len(results) == k:
132
+ break
133
+
134
+ logger.info("retrieved chunks", method=_RETRIEVAL_METHOD, count=len(results))
135
+ return results
136
+
137
+ async def _retrieve_manhattan(
138
+ self, query: str, user_id: str, k: int, fetch_k: int
139
  ) -> list[RetrievalResult]:
140
  query_vector = await _get_embeddings().aembed_query(query)
141
  if not all(math.isfinite(v) for v in query_vector):
142
  raise ValueError("Embedding vector contains NaN or Infinity values.")
143
  vector_str = "[" + ",".join(str(v) for v in query_vector) + "]"
 
 
 
 
 
144
 
145
  async with _pgvector_engine.connect() as conn:
146
+ result = await conn.execute(_MANHATTAN_SQL, {
147
  "embedding": vector_str,
148
  "collection": _COLLECTION_NAME,
149
  "user_id": user_id,
 
151
  })
152
  rows = result.fetchall()
153
 
 
 
154
  results = []
155
  for row in rows:
156
  file_type = row.cmetadata.get("data", {}).get("file_type", "")
 
164
  if len(results) == k:
165
  break
166
 
167
+ logger.info("retrieved chunks", method="manhattan", count=len(results))
168
  return results
169
 
170
 
uv.lock CHANGED
@@ -1,5 +1,5 @@
1
  version = 1
2
- revision = 2
3
  requires-python = "==3.12.*"
4
  resolution-markers = [
5
  "python_full_version >= '3.12.4'",
@@ -50,6 +50,7 @@ dependencies = [
50
  { name = "pyarrow" },
51
  { name = "pydantic" },
52
  { name = "pydantic-settings" },
 
53
  { name = "pymssql" },
54
  { name = "pymysql" },
55
  { name = "pypdf" },
@@ -137,6 +138,7 @@ requires-dist = [
137
  { name = "pyarrow", specifier = ">=24.0.0" },
138
  { name = "pydantic", specifier = "==2.10.3" },
139
  { name = "pydantic-settings", specifier = "==2.7.0" },
 
140
  { name = "pymssql", specifier = ">=2.3.0" },
141
  { name = "pymysql", specifier = ">=1.1.1" },
142
  { name = "pypdf", specifier = "==5.1.0" },
@@ -2558,6 +2560,27 @@ crypto = [
2558
  { name = "cryptography" },
2559
  ]
2560
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2561
  [[package]]
2562
  name = "pymssql"
2563
  version = "2.3.13"
 
1
  version = 1
2
+ revision = 3
3
  requires-python = "==3.12.*"
4
  resolution-markers = [
5
  "python_full_version >= '3.12.4'",
 
50
  { name = "pyarrow" },
51
  { name = "pydantic" },
52
  { name = "pydantic-settings" },
53
+ { name = "pymongo" },
54
  { name = "pymssql" },
55
  { name = "pymysql" },
56
  { name = "pypdf" },
 
138
  { name = "pyarrow", specifier = ">=24.0.0" },
139
  { name = "pydantic", specifier = "==2.10.3" },
140
  { name = "pydantic-settings", specifier = "==2.7.0" },
141
+ { name = "pymongo", specifier = ">=4.14.0" },
142
  { name = "pymssql", specifier = ">=2.3.0" },
143
  { name = "pymysql", specifier = ">=1.1.1" },
144
  { name = "pypdf", specifier = "==5.1.0" },
 
2560
  { name = "cryptography" },
2561
  ]
2562
 
2563
+ [[package]]
2564
+ name = "pymongo"
2565
+ version = "4.16.0"
2566
+ source = { registry = "https://pypi.org/simple" }
2567
+ dependencies = [
2568
+ { name = "dnspython" },
2569
+ ]
2570
+ sdist = { url = "https://files.pythonhosted.org/packages/65/9c/a4895c4b785fc9865a84a56e14b5bd21ca75aadc3dab79c14187cdca189b/pymongo-4.16.0.tar.gz", hash = "sha256:8ba8405065f6e258a6f872fe62d797a28f383a12178c7153c01ed04e845c600c", size = 2495323, upload-time = "2026-01-07T18:05:48.107Z" }
2571
+ wheels = [
2572
+ { url = "https://files.pythonhosted.org/packages/6a/03/6dd7c53cbde98de469a3e6fb893af896dca644c476beb0f0c6342bcc368b/pymongo-4.16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bd4911c40a43a821dfd93038ac824b756b6e703e26e951718522d29f6eb166a8", size = 917619, upload-time = "2026-01-07T18:04:19.173Z" },
2573
+ { url = "https://files.pythonhosted.org/packages/73/e1/328915f2734ea1f355dc9b0e98505ff670f5fab8be5e951d6ed70971c6aa/pymongo-4.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:25a6b03a68f9907ea6ec8bc7cf4c58a1b51a18e23394f962a6402f8e46d41211", size = 917364, upload-time = "2026-01-07T18:04:20.861Z" },
2574
+ { url = "https://files.pythonhosted.org/packages/41/fe/4769874dd9812a1bc2880a9785e61eba5340da966af888dd430392790ae0/pymongo-4.16.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:91ac0cb0fe2bf17616c2039dac88d7c9a5088f5cb5829b27c9d250e053664d31", size = 1686901, upload-time = "2026-01-07T18:04:22.219Z" },
2575
+ { url = "https://files.pythonhosted.org/packages/fa/8d/15707b9669fdc517bbc552ac60da7124dafe7ac1552819b51e97ed4038b4/pymongo-4.16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cf0ec79e8ca7077f455d14d915d629385153b6a11abc0b93283ed73a8013e376", size = 1723034, upload-time = "2026-01-07T18:04:24.055Z" },
2576
+ { url = "https://files.pythonhosted.org/packages/5b/af/3d5d16ff11d447d40c1472da1b366a31c7380d7ea2922a449c7f7f495567/pymongo-4.16.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2d0082631a7510318befc2b4fdab140481eb4b9dd62d9245e042157085da2a70", size = 1797161, upload-time = "2026-01-07T18:04:25.964Z" },
2577
+ { url = "https://files.pythonhosted.org/packages/fb/04/725ab8664eeec73ec125b5a873448d80f5d8cf2750aaaf804cbc538a50a5/pymongo-4.16.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:85dc2f3444c346ea019a371e321ac868a4fab513b7a55fe368f0cc78de8177cc", size = 1780938, upload-time = "2026-01-07T18:04:28.745Z" },
2578
+ { url = "https://files.pythonhosted.org/packages/22/50/dd7e9095e1ca35f93c3c844c92eb6eb0bc491caeb2c9bff3b32fe3c9b18f/pymongo-4.16.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dabbf3c14de75a20cc3c30bf0c6527157224a93dfb605838eabb1a2ee3be008d", size = 1714342, upload-time = "2026-01-07T18:04:30.331Z" },
2579
+ { url = "https://files.pythonhosted.org/packages/03/c9/542776987d5c31ae8e93e92680ea2b6e5a2295f398b25756234cabf38a39/pymongo-4.16.0-cp312-cp312-win32.whl", hash = "sha256:60307bb91e0ab44e560fe3a211087748b2b5f3e31f403baf41f5b7b0a70bd104", size = 887868, upload-time = "2026-01-07T18:04:32.124Z" },
2580
+ { url = "https://files.pythonhosted.org/packages/2e/d4/b4045a7ccc5680fb496d01edf749c7a9367cc8762fbdf7516cf807ef679b/pymongo-4.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:f513b2c6c0d5c491f478422f6b5b5c27ac1af06a54c93ef8631806f7231bd92e", size = 907554, upload-time = "2026-01-07T18:04:33.685Z" },
2581
+ { url = "https://files.pythonhosted.org/packages/60/4c/33f75713d50d5247f2258405142c0318ff32c6f8976171c4fcae87a9dbdf/pymongo-4.16.0-cp312-cp312-win_arm64.whl", hash = "sha256:dfc320f08ea9a7ec5b2403dc4e8150636f0d6150f4b9792faaae539c88e7db3b", size = 892971, upload-time = "2026-01-07T18:04:35.594Z" },
2582
+ ]
2583
+
2584
  [[package]]
2585
  name = "pymssql"
2586
  version = "2.3.13"