# Author: ishaq101
# Commit: feat/Catalog Retrieval System (#1)
# Revision: 6bff5d9
"""ChatbotAgent — final answer formation. Phase 2 chatbot.
Receives one of:
- a `QueryResult` (structured query path),
- a list of document chunks (unstructured path), or
- nothing (chat-only path: greeting, farewell, meta question).
Streams the answer token-by-token so the chat handler can wrap each token
into an SSE event. Conversation history is supported.
"""
from __future__ import annotations
from collections.abc import AsyncIterator
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable
from langchain_openai import AzureChatOpenAI
from src.middlewares.logging import get_logger
from ..query.executor.base import QueryResult
logger = get_logger("chatbot")

# Prompt markdown files live in config/prompts under the directory one level
# above this module's own directory.
_PROMPT_DIR = Path(__file__).resolve().parent.parent.joinpath("config", "prompts")
_SYSTEM_PROMPT_PATH = _PROMPT_DIR.joinpath("chatbot_system.md")
_GUARDRAILS_PATH = _PROMPT_DIR.joinpath("guardrails.md")
@dataclass
class DocumentChunk:
    """A single retrieved piece of document text used on the unstructured path.

    Only ``content`` is required; the location metadata is optional and is
    left as ``None`` when unknown.
    """

    # The chunk's text as it will be shown to the model.
    content: str
    # Name of the file the chunk came from, if known.
    filename: str | None = None
    # Page label within that file, if known.
    page_label: str | None = None
def _load_system_prompt() -> str:
    """Return the full system prompt: chatbot_system.md, then guardrails.md.

    The two files are read as UTF-8 and separated by a blank line. The
    guardrails come last so that they take precedence wherever the two
    conflict, as described at the top of guardrails.md.
    """
    prompt_files = (_SYSTEM_PROMPT_PATH, _GUARDRAILS_PATH)
    return "\n\n".join(path.read_text(encoding="utf-8") for path in prompt_files)
def _format_query_result(qr: QueryResult) -> str:
"""Render a QueryResult as a compact context block for the LLM."""
source_label = qr.source_name or "(unknown source)"
table_label = qr.table_name or "(unknown table)"
if qr.error:
return (
f"[Query result — FAILED]\n"
f"source: {source_label}\n"
f"table: {table_label}\n"
f"error: {qr.error}"
)
lines: list[str] = [
"[Query result]",
f"source: {source_label}",
f"table: {table_label}",
f"backend: {qr.backend}",
f"row_count: {qr.row_count}"
+ (" (truncated)" if qr.truncated else ""),
f"elapsed_ms: {qr.elapsed_ms}",
]
if qr.rows:
# Cap rendering at 25 rows; the LLM doesn't need the full set
cap = min(len(qr.rows), 25)
columns = list(qr.rows[0].keys())
lines.append("columns: " + ", ".join(columns))
lines.append("rows:")
for row in qr.rows[:cap]:
lines.append(" " + ", ".join(f"{k}={row[k]!r}" for k in columns))
if cap < len(qr.rows):
lines.append(f" ... (+{len(qr.rows) - cap} more rows omitted from prompt)")
return "\n".join(lines)
def _format_document_chunks(chunks: list[DocumentChunk]) -> str:
if not chunks:
return ""
blocks: list[str] = []
for c in chunks:
label_parts = [p for p in (c.filename, c.page_label) if p]
label = ", ".join(label_parts) if label_parts else "Unknown source"
blocks.append(f"[Source: {label}]\n{c.content}")
return "\n\n".join(blocks)
def _build_context_block(
query_result: QueryResult | None,
chunks: list[DocumentChunk] | None,
) -> str:
parts: list[str] = []
if query_result is not None:
parts.append(_format_query_result(query_result))
if chunks:
parts.append(_format_document_chunks(chunks))
return "\n\n".join(parts) if parts else "(no data context — answer conversationally)"
def _build_default_chain() -> Runnable:
    """Build the production answer chain: prompt | Azure OpenAI chat | str parser.

    Returns:
        A runnable that takes ``{"message", "history", "context"}`` and
        produces the answer text. It can be consumed token by token via
        ``astream``.
    """
    # Imported inside the function so that merely importing this module does
    # not load application settings.
    from langchain_core.messages import SystemMessage

    from src.config.settings import settings

    llm = AzureChatOpenAI(
        azure_deployment=settings.azureai_deployment_name_4o,
        openai_api_version=settings.azureai_api_version_4o,
        azure_endpoint=settings.azureai_endpoint_url_4o,
        api_key=settings.azureai_api_key_4o,
        temperature=0.3,
    )
    # The static system prompt is passed as a literal SystemMessage rather
    # than a ("system", text) template. Otherwise ChatPromptTemplate would
    # parse any "{" / "}" in the markdown prompt files (e.g. JSON examples)
    # as template variables and fail with missing-variable errors at invoke
    # time.
    static_system = SystemMessage(content=_load_system_prompt())
    prompt = ChatPromptTemplate.from_messages(
        [
            static_system,
            MessagesPlaceholder(variable_name="history", optional=True),
            ("human", "{message}"),
            # Per-turn data context goes in as a trailing system message,
            # placed after the user's message.
            ("system", "Data context for this turn:\n\n{context}"),
        ]
    )
    return prompt | llm | StrOutputParser()
class ChatbotAgent:
    """Produce the final user-facing answer as a stream of text tokens.

    The runnable that generates the answer may be supplied through ``chain``
    (tests inject a fake that yields canned tokens). When it is omitted, the
    production Azure OpenAI streaming chain is built on first use and then
    kept for the lifetime of the agent.
    """

    def __init__(self, chain: Runnable | None = None) -> None:
        self._chain = chain

    def _ensure_chain(self) -> Runnable:
        """Return the answer chain, building and caching the default on first call."""
        if self._chain is not None:
            return self._chain
        self._chain = _build_default_chain()
        return self._chain

    async def astream(
        self,
        message: str,
        history: list[BaseMessage] | None = None,
        query_result: QueryResult | None = None,
        chunks: list[DocumentChunk] | None = None,
    ) -> AsyncIterator[str]:
        """Yield the answer to ``message`` piece by piece.

        Args:
            message: The user's current message.
            history: Earlier conversation turns; ``None`` is sent as ``[]``.
            query_result: Structured-query result to ground the answer, if any.
            chunks: Retrieved document chunks to ground the answer, if any.

        Yields:
            Whatever the chain streams. The caller wraps each piece into an
            SSE event. With no history and no context this is a plain chat
            reply.
        """
        answer_chain = self._ensure_chain()
        inputs: dict[str, Any] = dict(
            message=message,
            history=history or [],
            context=_build_context_block(query_result, chunks),
        )
        async for piece in answer_chain.astream(inputs):
            yield piece