"""ChatbotAgent — final answer formation. Phase 2 chatbot. Receives one of: - a `QueryResult` (structured query path), - a list of document chunks (unstructured path), or - nothing (chat-only path: greeting, farewell, meta question). Streams the answer token-by-token so the chat handler can wrap each token into an SSE event. Conversation history is supported. """ from __future__ import annotations from collections.abc import AsyncIterator from dataclasses import dataclass from pathlib import Path from typing import Any from langchain_core.messages import BaseMessage from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import Runnable from langchain_openai import AzureChatOpenAI from src.middlewares.logging import get_logger from ..query.executor.base import QueryResult logger = get_logger("chatbot") _PROMPT_DIR = Path(__file__).resolve().parent.parent / "config" / "prompts" _SYSTEM_PROMPT_PATH = _PROMPT_DIR / "chatbot_system.md" _GUARDRAILS_PATH = _PROMPT_DIR / "guardrails.md" @dataclass class DocumentChunk: """One retrieved document chunk for the unstructured path.""" content: str filename: str | None = None page_label: str | None = None def _load_system_prompt() -> str: """Compose system prompt = chatbot_system.md + guardrails.md. Guardrails appended last so they take precedence in conflict (matches the docstring at the top of guardrails.md). """ chatbot = _SYSTEM_PROMPT_PATH.read_text(encoding="utf-8") guardrails = _GUARDRAILS_PATH.read_text(encoding="utf-8") return f"{chatbot}\n\n{guardrails}" def _format_query_result(qr: QueryResult) -> str: """Render a QueryResult as a compact context block for the LLM.""" source_label = qr.source_name or "(unknown source)" table_label = qr.table_name or "(unknown table)" if qr.error: return ( f"[Query result — FAILED]\n" f"source: {source_label}\n" f"table: {table_label}\n" f"error: {qr.error}" ) lines: list[str] = [ "[Query result]", f"source: {source_label}", f"table: {table_label}", f"backend: {qr.backend}", f"row_count: {qr.row_count}" + (" (truncated)" if qr.truncated else ""), f"elapsed_ms: {qr.elapsed_ms}", ] if qr.rows: # Cap rendering at 25 rows; the LLM doesn't need the full set cap = min(len(qr.rows), 25) columns = list(qr.rows[0].keys()) lines.append("columns: " + ", ".join(columns)) lines.append("rows:") for row in qr.rows[:cap]: lines.append(" " + ", ".join(f"{k}={row[k]!r}" for k in columns)) if cap < len(qr.rows): lines.append(f" ... (+{len(qr.rows) - cap} more rows omitted from prompt)") return "\n".join(lines) def _format_document_chunks(chunks: list[DocumentChunk]) -> str: if not chunks: return "" blocks: list[str] = [] for c in chunks: label_parts = [p for p in (c.filename, c.page_label) if p] label = ", ".join(label_parts) if label_parts else "Unknown source" blocks.append(f"[Source: {label}]\n{c.content}") return "\n\n".join(blocks) def _build_context_block( query_result: QueryResult | None, chunks: list[DocumentChunk] | None, ) -> str: parts: list[str] = [] if query_result is not None: parts.append(_format_query_result(query_result)) if chunks: parts.append(_format_document_chunks(chunks)) return "\n\n".join(parts) if parts else "(no data context — answer conversationally)" def _build_default_chain() -> Runnable: from src.config.settings import settings llm = AzureChatOpenAI( azure_deployment=settings.azureai_deployment_name_4o, openai_api_version=settings.azureai_api_version_4o, azure_endpoint=settings.azureai_endpoint_url_4o, api_key=settings.azureai_api_key_4o, temperature=0.3, ) prompt = ChatPromptTemplate.from_messages( [ ("system", _load_system_prompt()), MessagesPlaceholder(variable_name="history", optional=True), ("human", "{message}"), ("system", "Data context for this turn:\n\n{context}"), ] ) return prompt | llm | StrOutputParser() class ChatbotAgent: """Formats and streams the final user-facing answer. `chain` is injectable: tests pass a fake that yields canned tokens. Default constructs the production Azure OpenAI streaming chain on first use. """ def __init__(self, chain: Runnable | None = None) -> None: self._chain = chain def _ensure_chain(self) -> Runnable: if self._chain is None: self._chain = _build_default_chain() return self._chain async def astream( self, message: str, history: list[BaseMessage] | None = None, query_result: QueryResult | None = None, chunks: list[DocumentChunk] | None = None, ) -> AsyncIterator[str]: """Stream tokens of the final answer. Caller wraps each token into the SSE format. Empty `history` and no context = pure chat reply. """ chain = self._ensure_chain() payload: dict[str, Any] = { "message": message, "history": history or [], "context": _build_context_block(query_result, chunks), } async for token in chain.astream(payload): yield token