"""ChatbotAgent — final answer formation. Phase 2 chatbot.
Receives one of:
- a `QueryResult` (structured query path),
- a list of document chunks (unstructured path), or
- nothing (chat-only path: greeting, farewell, meta question).
Streams the answer token-by-token so the chat handler can wrap each token
into an SSE event. Conversation history is supported.
"""
from __future__ import annotations
from collections.abc import AsyncIterator
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable
from langchain_openai import AzureChatOpenAI
from src.middlewares.logging import get_logger
from ..query.executor.base import QueryResult
# Module-level logger registered under the name "chatbot".
logger = get_logger("chatbot")

# Prompt files live in `config/prompts` under the directory one level above
# the directory that contains this module (path is resolved to an absolute one).
_PROMPT_DIR = Path(__file__).resolve().parent.parent / "config" / "prompts"
# Base system prompt; read first by `_load_system_prompt`.
_SYSTEM_PROMPT_PATH = _PROMPT_DIR / "chatbot_system.md"
# Guardrails text; appended after the base prompt by `_load_system_prompt`.
_GUARDRAILS_PATH = _PROMPT_DIR / "guardrails.md"
@dataclass
class DocumentChunk:
    """One retrieved document chunk for the unstructured path.

    Only ``content`` is required. ``filename`` and ``page_label`` are
    optional; when set (and non-empty) they are joined by
    ``_format_document_chunks`` into the ``[Source: ...]`` label that is
    placed above the chunk text in the prompt context.
    """

    # Chunk text; inserted as-is below the source label.
    content: str
    # First part of the source label; skipped when None or empty.
    filename: str | None = None
    # Second part of the source label; skipped when None or empty.
    page_label: str | None = None
def _load_system_prompt() -> str:
    """Return the full system prompt: chatbot_system.md followed by guardrails.md.

    Both files are read as UTF-8 and joined with one blank line. The
    guardrails come last so that they win when the two texts conflict
    (this mirrors the note at the top of guardrails.md).
    """
    # Order matters: base prompt first, guardrails last.
    ordered_sources = (_SYSTEM_PROMPT_PATH, _GUARDRAILS_PATH)
    sections = [path.read_text(encoding="utf-8") for path in ordered_sources]
    return "\n\n".join(sections)
def _format_query_result(qr: QueryResult, *, max_rows: int = 25) -> str:
    """Render a QueryResult as a compact context block for the LLM.

    Args:
        qr: The query result to render. When ``qr.error`` is set, only a
            short failure block (source, table, error) is produced.
        max_rows: Upper bound on the number of rows written into the block.
            Defaults to 25; the LLM doesn't need the full set. Negative
            values are treated as 0.

    Returns:
        A newline-joined text block. For successful results it contains the
        source/table/backend metadata, the row count (flagged when
        truncated), the elapsed time and, if there are rows, a column list
        followed by one ``key=value`` line per rendered row.
    """
    source_label = qr.source_name or "(unknown source)"
    table_label = qr.table_name or "(unknown table)"

    if qr.error:
        return (
            f"[Query result — FAILED]\n"
            f"source: {source_label}\n"
            f"table: {table_label}\n"
            f"error: {qr.error}"
        )

    lines: list[str] = [
        "[Query result]",
        f"source: {source_label}",
        f"table: {table_label}",
        f"backend: {qr.backend}",
        f"row_count: {qr.row_count}"
        + (" (truncated)" if qr.truncated else ""),
        f"elapsed_ms: {qr.elapsed_ms}",
    ]

    if qr.rows:
        cap = min(len(qr.rows), max(max_rows, 0))
        shown = qr.rows[:cap]
        # Ordered union of the keys of the rendered rows, so a row that has
        # extra or missing keys compared to the first row neither loses data
        # nor raises KeyError. Falls back to the first row when cap is 0 so
        # the column list is still reported.
        columns = list(
            dict.fromkeys(key for row in (shown or qr.rows[:1]) for key in row)
        )
        lines.append("columns: " + ", ".join(columns))
        lines.append("rows:")
        lines.extend(
            # Keys absent from a given row are skipped rather than rendered
            # as a made-up value.
            " " + ", ".join(f"{k}={row[k]!r}" for k in columns if k in row)
            for row in shown
        )
        if cap < len(qr.rows):
            lines.append(f" ... (+{len(qr.rows) - cap} more rows omitted from prompt)")

    return "\n".join(lines)
def _format_document_chunks(chunks: list[DocumentChunk]) -> str:
    """Render document chunks as labelled text blocks separated by blank lines.

    Each block is a ``[Source: ...]`` header line followed by the chunk
    content. The header joins the chunk's filename and page label (whichever
    are non-empty) with ", ", or reads "Unknown source" when neither is set.
    An empty or missing chunk list yields an empty string.
    """

    def _render(chunk: DocumentChunk) -> str:
        # Build the header from whichever label parts are actually present.
        present = [part for part in (chunk.filename, chunk.page_label) if part]
        header = "Unknown source" if not present else ", ".join(present)
        return f"[Source: {header}]\n{chunk.content}"

    if chunks:
        return "\n\n".join([_render(chunk) for chunk in chunks])
    return ""
def _build_context_block(
    query_result: QueryResult | None,
    chunks: list[DocumentChunk] | None,
) -> str:
    """Combine the available data context into one text block for the prompt.

    The rendered query result (if one was given) comes first, then the
    rendered document chunks (if any), separated by a blank line. With
    neither present, a fixed placeholder tells the model to answer
    conversationally.
    """
    query_section = (
        None if query_result is None else _format_query_result(query_result)
    )
    document_section = _format_document_chunks(chunks) if chunks else None

    # Keep only the sections that were actually produced, in fixed order.
    sections = [s for s in (query_section, document_section) if s is not None]
    if not sections:
        return "(no data context — answer conversationally)"
    return "\n\n".join(sections)
def _build_default_chain() -> Runnable:
    """Assemble the production chain: prompt -> Azure OpenAI chat model -> string parser.

    The prompt consists of, in order: the composed system prompt, the
    optional conversation history, the user's message, and a trailing system
    message carrying the data context for the current turn. The chain expects
    the input keys ``message``, ``context`` and (optionally) ``history``.
    """
    # Imported at call time, so settings are only loaded when the default
    # chain is actually built.
    from src.config.settings import settings

    azure_options: dict[str, Any] = {
        "azure_deployment": settings.azureai_deployment_name_4o,
        "openai_api_version": settings.azureai_api_version_4o,
        "azure_endpoint": settings.azureai_endpoint_url_4o,
        "api_key": settings.azureai_api_key_4o,
        "temperature": 0.3,
    }
    chat_model = AzureChatOpenAI(**azure_options)

    # NOTE(review): the system prompt text is passed as a template string, so
    # literal "{" / "}" in the markdown files would presumably be parsed as
    # template variables — confirm the prompt files contain none (or escape
    # them as "{{" / "}}").
    message_specs = [
        ("system", _load_system_prompt()),
        MessagesPlaceholder(variable_name="history", optional=True),
        ("human", "{message}"),
        ("system", "Data context for this turn:\n\n{context}"),
    ]
    chat_prompt = ChatPromptTemplate.from_messages(message_specs)

    return chat_prompt | chat_model | StrOutputParser()
class ChatbotAgent:
    """Produce the final user-facing answer as a stream of tokens.

    The underlying ``chain`` can be supplied by the caller; tests inject a
    fake that yields canned tokens. When none is supplied, the production
    Azure OpenAI streaming chain is built on first use rather than at
    construction time.
    """

    def __init__(self, chain: Runnable | None = None) -> None:
        # None means "build the default chain lazily" (see _ensure_chain).
        self._chain = chain

    def _ensure_chain(self) -> Runnable:
        """Return the chain, building and caching the default one if unset."""
        if self._chain is not None:
            return self._chain
        self._chain = _build_default_chain()
        return self._chain

    async def astream(
        self,
        message: str,
        history: list[BaseMessage] | None = None,
        query_result: QueryResult | None = None,
        chunks: list[DocumentChunk] | None = None,
    ) -> AsyncIterator[str]:
        """Yield the final answer token by token.

        The caller is responsible for wrapping each token into the SSE
        format. With an empty ``history`` and neither ``query_result`` nor
        ``chunks``, the reply is a pure chat answer.
        """
        runnable = self._ensure_chain()
        prior_turns = history or []
        context_text = _build_context_block(query_result, chunks)
        chain_input: dict[str, Any] = {
            "message": message,
            "history": prior_turns,
            "context": context_text,
        }
        token_stream = runnable.astream(chain_input)
        async for piece in token_stream:
            yield piece
|