Rifqi Hafizuddin
[KM-564] fix source, now shows name instead of id. added diff retrieval vs catalog
96598f8 | """ChatbotAgent — final answer formation. Phase 2 chatbot. | |
| Receives one of: | |
| - a `QueryResult` (structured query path), | |
| - a list of document chunks (unstructured path), or | |
| - nothing (chat-only path: greeting, farewell, meta question). | |
| Streams the answer token-by-token so the chat handler can wrap each token | |
| into an SSE event. Conversation history is supported. | |
| """ | |
| from __future__ import annotations | |
| from collections.abc import AsyncIterator | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any | |
| from langchain_core.messages import BaseMessage | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_core.runnables import Runnable | |
| from langchain_openai import AzureChatOpenAI | |
| from src.middlewares.logging import get_logger | |
| from ..query.executor.base import QueryResult | |
# Module-level logger obtained from the project's logging middleware.
logger = get_logger("chatbot")

# Prompt files live under "config/prompts", resolved relative to the parent
# of the directory that contains this file.
_PROMPT_DIR = Path(__file__).resolve().parent.parent.joinpath("config", "prompts")
_SYSTEM_PROMPT_PATH = _PROMPT_DIR.joinpath("chatbot_system.md")
_GUARDRAILS_PATH = _PROMPT_DIR.joinpath("guardrails.md")
@dataclass
class DocumentChunk:
    """One retrieved document chunk for the unstructured path.

    Declared as a dataclass so that instances can be constructed with
    positional or keyword arguments and receive a generated ``__init__``,
    ``__repr__`` and ``__eq__``.
    """

    # Text of the chunk (required).
    content: str
    # Name of the originating file, when known.
    filename: str | None = None
    # Page label within the originating file, when known.
    page_label: str | None = None
def _load_system_prompt() -> str:
    """Return the system prompt: chatbot_system.md followed by guardrails.md.

    Both files are read as UTF-8 and joined with a single blank line.
    Guardrails are placed last with the intent that they win when the two
    texts conflict (the header of guardrails.md is said to state the same).
    """
    prompt_files = (_SYSTEM_PROMPT_PATH, _GUARDRAILS_PATH)
    return "\n\n".join(path.read_text(encoding="utf-8") for path in prompt_files)
| def _format_query_result(qr: QueryResult) -> str: | |
| """Render a QueryResult as a compact context block for the LLM.""" | |
| source_label = qr.source_name or "(unknown source)" | |
| table_label = qr.table_name or "(unknown table)" | |
| if qr.error: | |
| return ( | |
| f"[Query result — FAILED]\n" | |
| f"source: {source_label}\n" | |
| f"table: {table_label}\n" | |
| f"error: {qr.error}" | |
| ) | |
| lines: list[str] = [ | |
| "[Query result]", | |
| f"source: {source_label}", | |
| f"table: {table_label}", | |
| f"backend: {qr.backend}", | |
| f"row_count: {qr.row_count}" | |
| + (" (truncated)" if qr.truncated else ""), | |
| f"elapsed_ms: {qr.elapsed_ms}", | |
| ] | |
| if qr.rows: | |
| # Cap rendering at 25 rows; the LLM doesn't need the full set | |
| cap = min(len(qr.rows), 25) | |
| columns = list(qr.rows[0].keys()) | |
| lines.append("columns: " + ", ".join(columns)) | |
| lines.append("rows:") | |
| for row in qr.rows[:cap]: | |
| lines.append(" " + ", ".join(f"{k}={row[k]!r}" for k in columns)) | |
| if cap < len(qr.rows): | |
| lines.append(f" ... (+{len(qr.rows) - cap} more rows omitted from prompt)") | |
| return "\n".join(lines) | |
| def _format_document_chunks(chunks: list[DocumentChunk]) -> str: | |
| if not chunks: | |
| return "" | |
| blocks: list[str] = [] | |
| for c in chunks: | |
| label_parts = [p for p in (c.filename, c.page_label) if p] | |
| label = ", ".join(label_parts) if label_parts else "Unknown source" | |
| blocks.append(f"[Source: {label}]\n{c.content}") | |
| return "\n\n".join(blocks) | |
| def _build_context_block( | |
| query_result: QueryResult | None, | |
| chunks: list[DocumentChunk] | None, | |
| ) -> str: | |
| parts: list[str] = [] | |
| if query_result is not None: | |
| parts.append(_format_query_result(query_result)) | |
| if chunks: | |
| parts.append(_format_document_chunks(chunks)) | |
| return "\n\n".join(parts) if parts else "(no data context — answer conversationally)" | |
def _build_default_chain() -> Runnable:
    """Assemble the production chain: prompt -> Azure OpenAI chat model -> str.

    The prompt consists of the composed system prompt, the optional
    conversation history, the user's message and a trailing system message
    carrying the data context for the turn. The chain expects the input
    keys "message", "context" and, optionally, "history".
    """
    # Settings are imported inside the function, so they are only loaded
    # when the default chain is actually built.
    from src.config.settings import settings

    model = AzureChatOpenAI(
        azure_deployment=settings.azureai_deployment_name_4o,
        openai_api_version=settings.azureai_api_version_4o,
        azure_endpoint=settings.azureai_endpoint_url_4o,
        api_key=settings.azureai_api_key_4o,
        temperature=0.3,
    )

    # NOTE(review): the system prompt text is handed over as a template
    # string, so literal curly braces in the markdown files would presumably
    # be parsed as template variables — confirm the prompt files are
    # brace-free.
    messages = [
        ("system", _load_system_prompt()),
        MessagesPlaceholder(variable_name="history", optional=True),
        ("human", "{message}"),
        ("system", "Data context for this turn:\n\n{context}"),
    ]
    template = ChatPromptTemplate.from_messages(messages)
    to_text = StrOutputParser()
    return template | model | to_text
class ChatbotAgent:
    """Produces and streams the final user-facing answer.

    The `chain` is injectable so tests can pass a fake that yields canned
    tokens. When none is given, the production Azure OpenAI streaming chain
    is built on first use.
    """

    def __init__(self, chain: Runnable | None = None) -> None:
        """Store the injected chain; None defers construction to first use."""
        self._chain = chain

    def _ensure_chain(self) -> Runnable:
        """Return the chain, building and caching the default one if needed."""
        chain = self._chain
        if chain is None:
            chain = self._chain = _build_default_chain()
        return chain

    async def astream(
        self,
        message: str,
        history: list[BaseMessage] | None = None,
        query_result: QueryResult | None = None,
        chunks: list[DocumentChunk] | None = None,
    ) -> AsyncIterator[str]:
        """Yield the tokens of the final answer as the chain produces them.

        The caller is responsible for wrapping each token into the SSE
        format. With no `history` and no data context the result is a pure
        chat reply.
        """
        runnable = self._ensure_chain()
        inputs: dict[str, Any] = dict(
            message=message,
            history=history if history else [],
            context=_build_context_block(query_result, chunks),
        )
        async for piece in runnable.astream(inputs):
            yield piece