sofhiaazzhr commited on
Commit
5670888
·
1 Parent(s): 12f8dea

[KM-556] rewire chat stream endpoint to Phase 2 IntentRouter + AnswerAgent + QueryService

Browse files
Files changed (2) hide show
  1. PROGRESS.md +2 -2
  2. src/api/v1/chat.py +62 -112
PROGRESS.md CHANGED
@@ -2,7 +2,7 @@
2
 
3
  Persistent tracker mirroring the 42-item ownership table in `REPO_CONTEXT.md` "Team — division of work". Update as PRs land. Future Claude Code sessions read this to know what's already done.
4
 
5
- **Last updated**: 2026-05-08 (item 41 done; item 16 done; item 31 done; item 35 done — upload endpoint wired to on_tabular_uploaded)
6
  **Current open PR**: none — all Phase 2 contracts shipped on `pr/1`. Cleanup PR pending (API rewiring + Phase 1 removal).
7
 
8
  ---
@@ -119,7 +119,7 @@ Persistent tracker mirroring the 42-item ownership table in `REPO_CONTEXT.md` "T
119
  |---|---|---|---|---|
120
  | 34 | DB client endpoints (`api/v1/db_client.py`) | DB | `[ ]` | Phase 1 endpoint exists — rewire `/ingest` to call `pipeline.triggers.on_db_registered`. Trigger is ready as of PR2a; deferred to a later PR until both teammates ack. |
121
  | 35 | Document/tabular upload endpoints (`api/v1/document.py`) | TAB | `[x]` | Rewired `/document/process` — after processing CSV/XLSX, calls `on_tabular_uploaded(document_id, user_id)`. Catalog ingestion failure is logged but does not fail the request (document already ingested to vector store). |
122
- | 36 | Chat stream endpoint (`api/v1/chat.py`) | B | `[ ]` | Phase 2 handler module ready (`agents/chat_handler.py`); rewiring of the actual `/chat/stream` endpoint deferred to cleanup PR to avoid breaking Phase 1 during the migration. |
123
  | 37 | Room / users endpoints (`api/v1/room.py`, `api/v1/users.py`) | B | `[ ]` | No catalog work; only touch if auth flow changes |
124
 
125
  ### Tests + eval
 
2
 
3
  Persistent tracker mirroring the 42-item ownership table in `REPO_CONTEXT.md` "Team — division of work". Update as PRs land. Future Claude Code sessions read this to know what's already done.
4
 
5
+ **Last updated**: 2026-05-08 (item 41 done; item 16 done; item 31 done; item 35 done; item 36 done chat endpoint rewired to Phase 2 QueryService)
6
  **Current open PR**: none — all Phase 2 contracts shipped on `pr/1`. Cleanup PR pending (API rewiring + Phase 1 removal).
7
 
8
  ---
 
119
  |---|---|---|---|---|
120
  | 34 | DB client endpoints (`api/v1/db_client.py`) | DB | `[ ]` | Phase 1 endpoint exists — rewire `/ingest` to call `pipeline.triggers.on_db_registered`. Trigger is ready as of PR2a; deferred to a later PR until both teammates ack. |
121
  | 35 | Document/tabular upload endpoints (`api/v1/document.py`) | TAB | `[x]` | Rewired `/document/process` — after processing CSV/XLSX, calls `on_tabular_uploaded(document_id, user_id)`. Catalog ingestion failure is logged but does not fail the request (document already ingested to vector store). |
122
+ | 36 | Chat stream endpoint (`api/v1/chat.py`) | B | `[x]` | Rewired `/chat/stream` replaced `query_executor.execute()` (Phase 1) with `CatalogReader + QueryService` (Phase 2). Kept Phase 1 structure: Redis cache, message persistence, fast intent, orchestrator, retriever, chatbot. Only query execution block swapped. |
123
  | 37 | Room / users endpoints (`api/v1/room.py`, `api/v1/users.py`) | B | `[ ]` | No catalog work; only touch if auth flow changes |
124
 
125
  ### Tests + eval
src/api/v1/chat.py CHANGED
@@ -1,17 +1,17 @@
1
  """Chat endpoint with streaming support."""
2
 
3
- import asyncio
4
  import uuid
5
  from fastapi import APIRouter, Depends, HTTPException
6
  from sqlalchemy.ext.asyncio import AsyncSession
7
  from src.db.postgres.connection import get_db
8
  from src.db.postgres.models import ChatMessage, MessageSource
9
- from src.agents.orchestration import orchestrator
10
- from src.agents.chatbot import chatbot
11
  from src.retrieval.router import retrieval_router as retriever
12
  from src.retrieval.base import RetrievalResult
13
- from src.query.query_executor import query_executor
14
- from src.query.base import QueryResult
 
15
  from src.db.redis.connection import get_redis
16
  from src.config.settings import settings
17
  from src.middlewares.logging import get_logger, log_execution
@@ -26,17 +26,16 @@ _GREETINGS = frozenset(["hi", "hello", "hey", "halo", "hai", "hei"])
26
  _GOODBYES = frozenset(["bye", "goodbye", "thanks", "thank you", "terima kasih", "sampai jumpa"])
27
 
28
 
29
- def _fast_intent(message: str) -> Optional[dict]:
30
- """Bypass LLM orchestrator for obvious greetings and farewells."""
31
  lower = message.lower().strip().rstrip("!.,?")
32
  if lower in _GREETINGS:
33
- return {"intent": "greeting", "needs_search": False,
34
- "direct_response": "Hello! How can I assist you today?", "search_query": ""}
35
  if lower in _GOODBYES:
36
- return {"intent": "goodbye", "needs_search": False,
37
- "direct_response": "Goodbye! Have a great day!", "search_query": ""}
38
  return None
39
 
 
40
  logger = get_logger("chat_api")
41
 
42
  router = APIRouter(prefix="/api/v1", tags=["Chat"])
@@ -48,18 +47,6 @@ class ChatRequest(BaseModel):
48
  message: str
49
 
50
 
51
- def _format_context(results: List[RetrievalResult]) -> str:
52
- """Format retrieval results as context string for the LLM."""
53
- lines = []
54
- for result in results:
55
- data = result.metadata.get("data", {})
56
- filename = data.get("filename", "Unknown")
57
- page = data.get("page_label")
58
- source_label = f"{filename}, p.{page}" if page else filename
59
- lines.append(f"[Source: {source_label}]\n{result.content}\n")
60
- return "\n".join(lines)
61
-
62
-
63
  def _extract_sources(results: List[RetrievalResult]) -> List[Dict[str, Any]]:
64
  """Extract deduplicated source references from retrieval results."""
65
  seen = set()
@@ -87,25 +74,22 @@ def _extract_sources(results: List[RetrievalResult]) -> List[Dict[str, Any]]:
87
  "filename": data.get("table_name", "Unknown"),
88
  "page_label": data.get("column_name", "Unknown"),
89
  })
90
-
91
  logger.debug(f"Extracted sources: {sources}")
92
  return sources
93
 
94
 
95
- def _format_query_results(results: list[QueryResult]) -> str:
96
- if not results:
97
- return ""
98
- lines = []
99
  for r in results:
100
- name = r.metadata.get("client_name", r.source_id)
101
- lines.append(f"[Query result — {name}, tables: {r.table_or_file}]")
102
- lines.append(f"SQL: {r.metadata.get('sql', '')}")
103
- if r.columns and r.rows:
104
- lines.append(" | ".join(r.columns))
105
- for row in r.rows[:20]:
106
- lines.append(" | ".join(str(row.get(c, "")) for c in r.columns))
107
- lines.append(f"({r.row_count} rows total)\n")
108
- return "\n".join(lines)
109
 
110
 
111
  async def get_cached_response(redis, cache_key: str) -> Optional[str]:
@@ -168,8 +152,9 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
168
  3. done — signals end of stream
169
  """
170
  redis = await get_redis()
171
-
172
  cache_key = f"{settings.redis_prefix}chat:{request.room_id}:{request.message}"
 
 
173
  cached = await get_cached_response(redis, cache_key)
174
  if cached:
175
  logger.info("Returning cached response")
@@ -183,91 +168,56 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
183
  return EventSourceResponse(stream_cached())
184
 
185
  try:
186
- # Step 1: Fast local intent check (skips LLM for greetings/farewells)
187
- intent_result = _fast_intent(request.message)
188
-
189
- context = ""
190
- sources: List[Dict[str, Any]] = []
191
-
192
- if intent_result is None:
193
- # Step 2: Launch retrieval and history loading in parallel, then run orchestrator.
194
- # k=5
195
- # tables — db_executor's FK expansion is one-hop and cannot bridge
196
- # 2-hop gaps (e.g. customers -> order_items -> products) on its own.
197
- retrieval_task = asyncio.create_task(
198
- retriever.retrieve(request.message, request.user_id, db, k=5)
199
- )
200
- history_task = asyncio.create_task(
201
- load_history(db, request.room_id, limit=6) # 6 msgs (3 pairs) for orchestrator
202
- )
203
- history = await history_task # fast DB query (<100ms), done before orchestrator finishes
204
- intent_result = await orchestrator.analyze_message(request.message, history)
205
-
206
- search_query = intent_result.get("search_query", request.message) or request.message
207
- if not intent_result.get("needs_search"):
208
- retrieval_task.cancel()
209
- try:
210
- await retrieval_task
211
- except asyncio.CancelledError:
212
- pass
213
- raw_results = []
214
- else:
215
- logger.info(f"Searching for: {search_query}")
216
- if search_query != request.message:
217
- retrieval_task.cancel()
218
- try:
219
- await retrieval_task
220
- except asyncio.CancelledError:
221
- pass
222
- raw_results = await retriever.retrieve(
223
- query=search_query,
224
- user_id=request.user_id,
225
- db=db,
226
- k=5,
227
- source_hint=intent_result.get("source_hint", "both"),
228
- )
229
- else:
230
- raw_results = await retrieval_task
231
-
232
- context = _format_context(raw_results)
233
- sources = _extract_sources(raw_results)
234
-
235
- source_hint = intent_result.get("source_hint", "both")
236
- if source_hint in ("schema", "both"):
237
- # Use search_query (orchestrator's standalone rewrite) so follow-up
238
- # messages like "dive deeper" or "show me last year" resolve correctly.
239
- # For first-turn questions search_query == request.message, so no change.
240
- query_results = await query_executor.execute(
241
- results=raw_results,
242
- user_id=request.user_id,
243
- db=db,
244
- question=search_query,
245
- )
246
- query_context = _format_query_results(query_results)
247
- if query_context:
248
- context = query_context + "\n\n" + context
249
-
250
- # Step 3: Direct response for greetings / non-document intents
251
- if intent_result.get("direct_response"):
252
- response = intent_result["direct_response"]
253
- await cache_response(redis, cache_key, response)
254
- await save_messages(db, request.room_id, request.message, response, sources=[])
255
 
256
  async def stream_direct():
257
  yield {"event": "sources", "data": json.dumps([])}
258
- yield {"event": "message", "data": response}
259
 
260
  return EventSourceResponse(stream_direct())
261
 
262
- # Step 4: Stream answer token-by-token as LLM generates it
263
- # Load full history (10 msgs) for chatbot — richer context than the 6 used by orchestrator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  full_history = await load_history(db, request.room_id, limit=10)
265
- messages = full_history + [HumanMessage(content=request.message)]
266
 
 
267
  async def stream_response():
268
  full_response = ""
269
  yield {"event": "sources", "data": json.dumps(sources)}
270
- async for token in chatbot.astream_response(messages, context):
 
 
 
 
 
271
  full_response += token
272
  yield {"event": "chunk", "data": token}
273
  yield {"event": "done", "data": ""}
 
1
  """Chat endpoint with streaming support."""
2
 
 
3
  import uuid
4
  from fastapi import APIRouter, Depends, HTTPException
5
  from sqlalchemy.ext.asyncio import AsyncSession
6
  from src.db.postgres.connection import get_db
7
  from src.db.postgres.models import ChatMessage, MessageSource
8
+ from src.agents.intent_router import IntentRouter
9
+ from src.agents.answer_agent import AnswerAgent, DocumentChunk
10
  from src.retrieval.router import retrieval_router as retriever
11
  from src.retrieval.base import RetrievalResult
12
+ from src.catalog.reader import CatalogReader
13
+ from src.catalog.store import CatalogStore
14
+ from src.query.service import QueryService
15
  from src.db.redis.connection import get_redis
16
  from src.config.settings import settings
17
  from src.middlewares.logging import get_logger, log_execution
 
26
  _GOODBYES = frozenset(["bye", "goodbye", "thanks", "thank you", "terima kasih", "sampai jumpa"])
27
 
28
 
29
+ def _fast_intent(message: str) -> Optional[str]:
30
+ """Return a direct response string for obvious greetings/farewells, else None."""
31
  lower = message.lower().strip().rstrip("!.,?")
32
  if lower in _GREETINGS:
33
+ return "Hello! How can I assist you today?"
 
34
  if lower in _GOODBYES:
35
+ return "Goodbye! Have a great day!"
 
36
  return None
37
 
38
+
39
  logger = get_logger("chat_api")
40
 
41
  router = APIRouter(prefix="/api/v1", tags=["Chat"])
 
47
  message: str
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def _extract_sources(results: List[RetrievalResult]) -> List[Dict[str, Any]]:
51
  """Extract deduplicated source references from retrieval results."""
52
  seen = set()
 
74
  "filename": data.get("table_name", "Unknown"),
75
  "page_label": data.get("column_name", "Unknown"),
76
  })
 
77
  logger.debug(f"Extracted sources: {sources}")
78
  return sources
79
 
80
 
81
+ def _to_document_chunks(results: List[RetrievalResult]) -> List[DocumentChunk]:
82
+ """Convert Phase 1 RetrievalResult list to Phase 2 DocumentChunk list."""
83
+ chunks = []
 
84
  for r in results:
85
+ data = r.metadata.get("data", {})
86
+ page = data.get("page_label")
87
+ chunks.append(DocumentChunk(
88
+ content=r.content,
89
+ filename=data.get("filename"),
90
+ page_label=str(page) if page is not None else None,
91
+ ))
92
+ return chunks
 
93
 
94
 
95
  async def get_cached_response(redis, cache_key: str) -> Optional[str]:
 
152
  3. done — signals end of stream
153
  """
154
  redis = await get_redis()
 
155
  cache_key = f"{settings.redis_prefix}chat:{request.room_id}:{request.message}"
156
+
157
+ # Redis cache hit
158
  cached = await get_cached_response(redis, cache_key)
159
  if cached:
160
  logger.info("Returning cached response")
 
168
  return EventSourceResponse(stream_cached())
169
 
170
  try:
171
+ # Fast intent: greetings/farewells bypass LLM entirely
172
+ direct = _fast_intent(request.message)
173
+ if direct:
174
+ await cache_response(redis, cache_key, direct)
175
+ await save_messages(db, request.room_id, request.message, direct, sources=[])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
  async def stream_direct():
178
  yield {"event": "sources", "data": json.dumps([])}
179
+ yield {"event": "message", "data": direct}
180
 
181
  return EventSourceResponse(stream_direct())
182
 
183
+ # Load history for intent routing
184
+ history = await load_history(db, request.room_id, limit=6)
185
+
186
+ # Phase 2: IntentRouter classifies message
187
+ decision = await IntentRouter().classify(request.message, history)
188
+ rewritten = decision.rewritten_query or request.message
189
+
190
+ query_result = None
191
+ chunks: List[DocumentChunk] | None = None
192
+ sources: List[Dict[str, Any]] = []
193
+
194
+ if decision.source_hint == "structured":
195
+ catalog = await CatalogReader(CatalogStore()).read(request.user_id, "structured")
196
+ query_result = await QueryService().run(request.user_id, rewritten, catalog)
197
+
198
+ elif decision.source_hint == "unstructured":
199
+ raw_results = await retriever.retrieve(
200
+ query=rewritten,
201
+ user_id=request.user_id,
202
+ db=db,
203
+ k=5,
204
+ )
205
+ chunks = _to_document_chunks(raw_results)
206
+ sources = _extract_sources(raw_results)
207
+
208
+ # Load full history for answer generation
209
  full_history = await load_history(db, request.room_id, limit=10)
 
210
 
211
+ # Phase 2: AnswerAgent streams answer tokens
212
  async def stream_response():
213
  full_response = ""
214
  yield {"event": "sources", "data": json.dumps(sources)}
215
+ async for token in AnswerAgent().astream(
216
+ request.message,
217
+ history=full_history,
218
+ query_result=query_result,
219
+ chunks=chunks,
220
+ ):
221
  full_response += token
222
  yield {"event": "chunk", "data": token}
223
  yield {"event": "done", "data": ""}