File size: 5,607 Bytes
6bff5d9
027123c
6bff5d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
027123c
6bff5d9
 
 
 
027123c
6bff5d9
 
027123c
 
 
d973099
6bff5d9
 
 
d973099
 
6bff5d9
 
 
027123c
6bff5d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
027123c
6bff5d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
027123c
6bff5d9
 
 
 
027123c
6bff5d9
 
027123c
6bff5d9
 
 
 
027123c
6bff5d9
027123c
6bff5d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""ChatbotAgent — final answer formation. Phase 2 chatbot.

Receives one of:
  - a `QueryResult` (structured query path),
  - a list of document chunks (unstructured path), or
  - nothing (chat-only path: greeting, farewell, meta question).

Streams the answer token-by-token so the chat handler can wrap each token
into an SSE event. Conversation history is supported.
"""

from __future__ import annotations

from collections.abc import AsyncIterator
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable
from langchain_openai import AzureChatOpenAI

from src.middlewares.logging import get_logger

from ..query.executor.base import QueryResult

logger = get_logger("chatbot")


# Prompt files live in "config/prompts" under the directory two levels up from
# this file (i.e. the parent of the directory that contains this module).
_PROMPT_DIR = Path(__file__).resolve().parent.parent / "config" / "prompts"
# First half of the system prompt; read by _load_system_prompt().
_SYSTEM_PROMPT_PATH = _PROMPT_DIR / "chatbot_system.md"
# Second half of the system prompt; _load_system_prompt() appends it last.
_GUARDRAILS_PATH = _PROMPT_DIR / "guardrails.md"


@dataclass
class DocumentChunk:
    """One retrieved document chunk for the unstructured path.

    `_format_document_chunks` renders each instance as a ``[Source: ...]``
    header line followed by `content`.
    """

    # Chunk text; placed verbatim under the source header.
    content: str
    # Originating file name; left out of the source label when falsy.
    filename: str | None = None
    # Page label for the chunk; left out of the source label when falsy.
    page_label: str | None = None


def _load_system_prompt() -> str:
    """Build the full system prompt from the two prompt files on disk.

    Reads chatbot_system.md and then guardrails.md (both as UTF-8) and
    returns their contents separated by one blank line. The guardrails text
    is deliberately placed last so it takes precedence in case of conflict,
    which is the ordering the note at the top of guardrails.md refers to.
    """
    # Tuple order is the output order: base prompt first, guardrails last.
    sections = [
        prompt_file.read_text(encoding="utf-8")
        for prompt_file in (_SYSTEM_PROMPT_PATH, _GUARDRAILS_PATH)
    ]
    return "\n\n".join(sections)


def _format_query_result(qr: QueryResult, *, max_rows: int = 25) -> str:
    """Render a QueryResult as a compact context block for the LLM.

    Args:
        qr: Query outcome to render. The attributes read are `source_name`,
            `table_name`, `error`, `backend`, `row_count`, `truncated`,
            `elapsed_ms` and `rows`. Column names are taken from the keys of
            the first row and every rendered row is indexed by them, so rows
            are assumed to be mappings sharing those keys.
        max_rows: Upper bound on how many rows are written into the block;
            the LLM does not need the full result set. Defaults to 25 and
            negative values are treated as 0.

    Returns:
        A newline-joined text block. When `qr.error` is truthy the block is
        a short ``[Query result — FAILED]`` report with source, table and the
        error. Otherwise it is a ``[Query result]`` header with metadata,
        followed (when there are rows) by the column list, up to `max_rows`
        rows as ``key=repr(value)`` pairs, and a note counting omitted rows.
    """
    source_label = qr.source_name or "(unknown source)"
    table_label = qr.table_name or "(unknown table)"
    # Both the failure and the success block open with the same two lines.
    origin = [f"source: {source_label}", f"table: {table_label}"]

    if qr.error:
        return "\n".join(
            ["[Query result — FAILED]", *origin, f"error: {qr.error}"]
        )

    truncated_note = " (truncated)" if qr.truncated else ""
    lines: list[str] = [
        "[Query result]",
        *origin,
        f"backend: {qr.backend}",
        f"row_count: {qr.row_count}{truncated_note}",
        f"elapsed_ms: {qr.elapsed_ms}",
    ]

    rows = qr.rows
    if rows:
        # Clamp so a negative limit cannot turn into a from-the-end slice.
        shown = min(len(rows), max(max_rows, 0))
        columns = list(rows[0].keys())
        lines.append("columns: " + ", ".join(columns))
        lines.append("rows:")
        lines.extend(
            "  " + ", ".join(f"{col}={row[col]!r}" for col in columns)
            for row in rows[:shown]
        )
        omitted = len(rows) - shown
        if omitted:
            lines.append(f"  ... (+{omitted} more rows omitted from prompt)")
    return "\n".join(lines)


def _format_document_chunks(chunks: list[DocumentChunk]) -> str:
    """Render retrieved chunks as source-labelled text blocks.

    Each chunk becomes a ``[Source: <label>]`` line followed by its content,
    where the label is the chunk's filename and page label joined by ", "
    (falsy parts skipped, "Unknown source" when both are missing). Blocks are
    separated by a blank line. A falsy `chunks` yields an empty string.
    """
    if not chunks:
        return ""

    def _render(chunk: DocumentChunk) -> str:
        # filter(None, ...) drops a missing/empty filename or page label.
        origin = ", ".join(filter(None, (chunk.filename, chunk.page_label)))
        return f"[Source: {origin or 'Unknown source'}]\n{chunk.content}"

    return "\n\n".join(_render(chunk) for chunk in chunks)


def _build_context_block(
    query_result: QueryResult | None,
    chunks: list[DocumentChunk] | None,
) -> str:
    """Combine whatever data context exists for this turn into one string.

    A non-None `query_result` is rendered first, a non-empty `chunks` list
    second; the rendered parts are separated by a blank line. When neither
    is supplied, a fixed placeholder telling the model to answer
    conversationally is returned instead.
    """
    rendered_query = (
        _format_query_result(query_result) if query_result is not None else None
    )
    rendered_docs = _format_document_chunks(chunks) if chunks else None

    sections = [
        text for text in (rendered_query, rendered_docs) if text is not None
    ]
    if not sections:
        return "(no data context — answer conversationally)"
    return "\n\n".join(sections)


def _build_default_chain() -> Runnable:
    """Assemble the production chain: prompt, Azure chat model, string parser.

    The settings object is imported inside the function, so that import is
    deferred until a default chain is actually requested. The prompt consists
    of the combined system prompt, the optional conversation history, the
    user's message, and a trailing system message carrying the data context.
    """
    from src.config.settings import settings

    model = AzureChatOpenAI(
        azure_deployment=settings.azureai_deployment_name_4o,
        openai_api_version=settings.azureai_api_version_4o,
        azure_endpoint=settings.azureai_endpoint_url_4o,
        api_key=settings.azureai_api_key_4o,
        temperature=0.3,
    )

    # NOTE(review): the prompt-file text is passed as a template string, so
    # literal braces in the .md files would presumably be parsed as template
    # variables — confirm the prompt files contain none (or escape them).
    message_specs = [
        ("system", _load_system_prompt()),
        MessagesPlaceholder(variable_name="history", optional=True),
        ("human", "{message}"),
        ("system", "Data context for this turn:\n\n{context}"),
    ]
    template = ChatPromptTemplate.from_messages(message_specs)

    model_stage = template | model
    return model_stage | StrOutputParser()


class ChatbotAgent:
    """Produce the final user-facing answer as a stream of tokens.

    The `chain` argument is an injection point: tests can hand in a fake
    that yields canned tokens. When it is omitted, the production Azure
    OpenAI chain is built lazily the first time it is needed.
    """

    def __init__(self, chain: Runnable | None = None) -> None:
        # None means "build the default chain on first use".
        self._chain = chain

    def _ensure_chain(self) -> Runnable:
        """Return the chain, building and caching the default one if unset."""
        if self._chain is not None:
            return self._chain
        self._chain = _build_default_chain()
        return self._chain

    async def astream(
        self,
        message: str,
        history: list[BaseMessage] | None = None,
        query_result: QueryResult | None = None,
        chunks: list[DocumentChunk] | None = None,
    ) -> AsyncIterator[str]:
        """Yield the final answer one token at a time.

        The caller is responsible for wrapping each yielded token in the SSE
        format. With no `history` and no `query_result`/`chunks`, the model
        receives only the message plus a "no data context" placeholder, i.e.
        a pure chat reply.
        """
        runnable = self._ensure_chain()
        context = _build_context_block(query_result, chunks)
        inputs: dict[str, Any] = {
            "message": message,
            # Normalise a missing history to an empty list for the prompt.
            "history": history if history else [],
            "context": context,
        }
        async for piece in runnable.astream(inputs):
            yield piece