| """ChatbotAgent — final answer formation. Phase 2 chatbot. |
| |
| Receives one of: |
| - a `QueryResult` (structured query path), |
| - a list of document chunks (unstructured path), or |
| - nothing (chat-only path: greeting, farewell, meta question). |
| |
| Streams the answer token-by-token so the chat handler can wrap each token |
| into an SSE event. Conversation history is supported. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from collections.abc import AsyncIterator |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any |
|
|
| from langchain_core.messages import BaseMessage |
| from langchain_core.output_parsers import StrOutputParser |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
| from langchain_core.runnables import Runnable |
| from langchain_openai import AzureChatOpenAI |
|
|
| from src.middlewares.logging import get_logger |
|
|
| from ..query.executor.base import QueryResult |
|
|
# Module logger, registered under the name "chatbot".
logger = get_logger("chatbot")

# Prompt markdown files live in ``config/prompts``, which sits next to this
# module's own directory (i.e. under the parent of the directory holding this file).
_PROMPT_DIR = Path(__file__).resolve().parent.parent.joinpath("config", "prompts")
# Main chatbot instructions.
_SYSTEM_PROMPT_PATH = _PROMPT_DIR.joinpath("chatbot_system.md")
# Guardrail rules, appended after the main instructions by _load_system_prompt.
_GUARDRAILS_PATH = _PROMPT_DIR.joinpath("guardrails.md")
|
|
|
|
@dataclass
class DocumentChunk:
    """A single retrieved piece of document text for the unstructured path.

    Only ``content`` is required. ``filename`` and ``page_label`` are optional
    provenance labels; when present they are shown in the ``[Source: ...]``
    header that precedes the chunk text in the prompt context.
    """

    # Chunk text handed to the LLM verbatim.
    content: str
    # Originating file name, if known.
    filename: str | None = None
    # Page label within that file, if known.
    page_label: str | None = None
|
|
|
|
def _load_system_prompt() -> str:
    """Return the full system prompt: chatbot instructions followed by guardrails.

    Both markdown files are read as UTF-8 and joined with a blank line.
    Guardrails go last so that, where the two conflict, they take precedence
    (as stated at the top of guardrails.md).
    """
    sections = [
        prompt_file.read_text(encoding="utf-8")
        for prompt_file in (_SYSTEM_PROMPT_PATH, _GUARDRAILS_PATH)
    ]
    return "\n\n".join(sections)
|
|
|
|
| def _format_query_result(qr: QueryResult) -> str: |
| """Render a QueryResult as a compact context block for the LLM.""" |
| source_label = qr.source_name or "(unknown source)" |
| table_label = qr.table_name or "(unknown table)" |
| if qr.error: |
| return ( |
| f"[Query result — FAILED]\n" |
| f"source: {source_label}\n" |
| f"table: {table_label}\n" |
| f"error: {qr.error}" |
| ) |
| lines: list[str] = [ |
| "[Query result]", |
| f"source: {source_label}", |
| f"table: {table_label}", |
| f"backend: {qr.backend}", |
| f"row_count: {qr.row_count}" |
| + (" (truncated)" if qr.truncated else ""), |
| f"elapsed_ms: {qr.elapsed_ms}", |
| ] |
| if qr.rows: |
| |
| cap = min(len(qr.rows), 25) |
| columns = list(qr.rows[0].keys()) |
| lines.append("columns: " + ", ".join(columns)) |
| lines.append("rows:") |
| for row in qr.rows[:cap]: |
| lines.append(" " + ", ".join(f"{k}={row[k]!r}" for k in columns)) |
| if cap < len(qr.rows): |
| lines.append(f" ... (+{len(qr.rows) - cap} more rows omitted from prompt)") |
| return "\n".join(lines) |
|
|
|
|
| def _format_document_chunks(chunks: list[DocumentChunk]) -> str: |
| if not chunks: |
| return "" |
| blocks: list[str] = [] |
| for c in chunks: |
| label_parts = [p for p in (c.filename, c.page_label) if p] |
| label = ", ".join(label_parts) if label_parts else "Unknown source" |
| blocks.append(f"[Source: {label}]\n{c.content}") |
| return "\n\n".join(blocks) |
|
|
|
|
def _build_context_block(
    query_result: QueryResult | None,
    chunks: list[DocumentChunk] | None,
) -> str:
    """Assemble the per-turn data context that fills the prompt's ``{context}`` slot.

    The structured query result (if any) comes first, then the document
    chunks (if any), separated by a blank line. With neither, a fixed
    placeholder tells the model to answer conversationally.
    """
    sections: list[str] = []
    if query_result is not None:
        sections.append(_format_query_result(query_result))
    if chunks:
        sections.append(_format_document_chunks(chunks))
    if not sections:
        return "(no data context — answer conversationally)"
    return "\n\n".join(sections)
|
|
|
|
def _build_default_chain() -> Runnable:
    """Build the production chain: prompt | AzureChatOpenAI (4o deployment) | str parser.

    ``settings`` is imported lazily, so an agent constructed with an injected
    chain never touches the Azure configuration.

    The composed system prompt is added as a literal ``SystemMessage`` rather
    than a ``("system", text)`` template tuple. Tuple messages are parsed as
    f-string templates, so any ``{`` or ``}`` in the markdown prompt files
    (e.g. a JSON example) would be treated as a template variable and fail at
    invoke time with a missing-variable error.
    NOTE(review): this assumes the prompt files contain no intentional
    ``{placeholders}``. Confirm before relying on templating inside them.

    Returns:
        A Runnable that takes ``{"message", "history", "context"}`` and yields
        the answer as strings (supports ``astream``).
    """
    from langchain_core.messages import SystemMessage

    from src.config.settings import settings

    llm = AzureChatOpenAI(
        azure_deployment=settings.azureai_deployment_name_4o,
        openai_api_version=settings.azureai_api_version_4o,
        azure_endpoint=settings.azureai_endpoint_url_4o,
        api_key=settings.azureai_api_key_4o,
        temperature=0.3,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            # Literal message: braces in the markdown are NOT template syntax.
            SystemMessage(content=_load_system_prompt()),
            MessagesPlaceholder(variable_name="history", optional=True),
            ("human", "{message}"),
            # Per-turn data context goes after the user's message.
            ("system", "Data context for this turn:\n\n{context}"),
        ]
    )
    return prompt | llm | StrOutputParser()
|
|
|
|
class ChatbotAgent:
    """Produce the final user-facing answer as a stream of text tokens.

    The chain is injectable: tests pass a fake that yields canned tokens.
    When no chain is given, the production Azure OpenAI streaming chain is
    built on first use and then reused for later calls.
    """

    def __init__(self, chain: Runnable | None = None) -> None:
        # None means "build the default chain lazily on first use".
        self._chain = chain

    def _ensure_chain(self) -> Runnable:
        """Return the chain, building and caching the default one if none is set."""
        if self._chain is not None:
            return self._chain
        self._chain = _build_default_chain()
        return self._chain

    async def astream(
        self,
        message: str,
        history: list[BaseMessage] | None = None,
        query_result: QueryResult | None = None,
        chunks: list[DocumentChunk] | None = None,
    ) -> AsyncIterator[str]:
        """Yield the answer's tokens as the chain produces them.

        Args:
            message: The user's current message.
            history: Prior conversation messages. ``None`` is sent as an empty list.
            query_result: Structured query result to ground the answer, if any.
            chunks: Retrieved document chunks to ground the answer, if any.

        Yields:
            Text tokens. The caller wraps each one into an SSE event. With
            empty history and no context this is a pure chat reply.
        """
        chain = self._ensure_chain()
        prior_turns = history or []
        context = _build_context_block(query_result, chunks)
        inputs: dict[str, Any] = dict(
            message=message,
            history=prior_turns,
            context=context,
        )
        async for piece in chain.astream(inputs):
            yield piece
|
|