File size: 11,675 Bytes
d7c0645
 
 
 
 
 
b784540
d7c0645
b9b60b0
ba9644b
d7c0645
4eb7b2b
d7c0645
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ac6bb5
 
d7c0645
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b784540
 
 
 
d7c0645
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ac6bb5
d7c0645
 
 
 
 
 
 
 
 
 
 
 
 
4eb7b2b
 
d7c0645
 
 
4eb7b2b
d7c0645
 
 
4eb7b2b
 
 
 
 
 
b9b60b0
4eb7b2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7c0645
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba9644b
 
 
 
 
 
 
b784540
 
 
 
ba9644b
 
 
 
 
 
 
 
 
d7c0645
 
 
 
b784540
 
 
 
 
 
 
 
 
d7c0645
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
from __future__ import annotations

import json
import logging
from typing import Any, Literal

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage

from src.conversation_memory import get_context
from src.llm import llm, multimodal_llm
from src.prompts import (
    direct_response_prompt,
    final_response_prompt,
    orchestrator_prompt,
    respond_prompt,
    router_prompt,
)
from src.state import MAX_ITERS, AgentState, QueryType
from src.tools.chat_tools import TOOL_MAP, TOOLS

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
)
logger = logging.getLogger("agent_router")

_llm_with_tools = llm.bind_tools(TOOLS)

VALID_LABELS:  set[str]  = {"respond", "summary"}
DEFAULT_LABEL: QueryType = "respond"


# ════════════════════════════════════════════════════════════════════
# NODE 1 — RouterNode
# ════════════════════════════════════════════════════════════════════
def router_node(state: AgentState) -> AgentState:
    """Validate state, gọi LLM để phân loại query → query_type."""
    if not state.get("raw_query"):
        raise ValueError("AgentState thiếu trường 'raw_query'.")
    if not state.get("sender_id"):
        raise ValueError("AgentState thiếu trường 'sender_id'.")
    if not state.get("conversation_id"):
        raise ValueError("AgentState thiếu trường 'conversation_id'.")

    logger.info("[RouterNode] Phân loại query từ %s", state["sender_id"])

    chain     = router_prompt | llm
    response  = chain.invoke({
        "sender_id": state["sender_id"],
        "time":      state.get("time", ""),
        "raw_query": state["raw_query"],
    })

    raw_label  = response.content.strip().lower()
    query_type: QueryType = raw_label if raw_label in VALID_LABELS else DEFAULT_LABEL  # type: ignore

    if raw_label not in VALID_LABELS:
        logger.warning(
            "[RouterNode] Nhãn không hợp lệ '%s', fallback → '%s'",
            raw_label, DEFAULT_LABEL,
        )

    logger.info("[RouterNode] query_type = %s", query_type)
    return {**state, "query_type": query_type, "messages": [HumanMessage(content=state["raw_query"])]}


def route_decision(state: AgentState) -> Literal["respond", "orchestrator"]:
    """Conditional edge: router → respond | orchestrator."""
    return "respond" if state.get("query_type") == "respond" else "orchestrator"


# ════════════════════════════════════════════════════════════════════
# NODE 2 — LLMResponseNode (luồng respond)
# ════════════════════════════════════════════════════════════════════
def llm_response_node(state: AgentState) -> AgentState:
    """Trả lời trực tiếp câu hỏi không cần tra cứu lịch sử."""
    logger.info("[LLMResponseNode] Trả lời trực tiếp cho %s", state["sender_id"])

    msgs = respond_prompt.format_messages(sender_id=state["sender_id"], raw_query=state["raw_query"])
    if cp := state.get("custom_prompt"):
        msgs[0] = SystemMessage(content=msgs[0].content + f"\n\n═══ YÊU CẦU BỔ SUNG ═══\n{cp}")
    response = llm.invoke(msgs)
    answer   = response.content

    logger.info("[LLMResponseNode] Đã sinh câu trả lời (%d ký tự)", len(answer))
    return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer}


# ════════════════════════════════════════════════════════════════════
# NODE 3 — OrchestratorNode
# ════════════════════════════════════════════════════════════════════
def orchestrator_node(state: AgentState) -> AgentState:
    """Gọi LLM (đã bind tool) để sinh tool_call. Dừng nếu đạt max_iters."""
    iters     = state.get("iters", 0)
    max_iters = state.get("max_iters", MAX_ITERS)

    logger.info("[OrchestratorNode] iters=%d / max_iters=%d", iters, max_iters)

    if iters >= max_iters:
        logger.info("[OrchestratorNode] Đã đạt max_iters, bỏ qua gọi tool.")
        return {**state, "iters": iters}

    prompt_value = orchestrator_prompt.format_messages(
        sender_id=state["sender_id"],
        conversation_id=state.get("conversation_id", ""),
        time=state.get("time", ""),
        raw_query=state["raw_query"],
    )

    response: AIMessage = _llm_with_tools.invoke(prompt_value)  # type: ignore
    logger.info(
        "[OrchestratorNode] tool_calls=%s",
        [tc["name"] for tc in (response.tool_calls or [])],
    )

    return {**state, "messages": [response], "iters": iters + 1}


def should_call_tool(state: AgentState) -> Literal["tool_node", "direct_response"]:
    """Conditional edge: orchestrator → tool_node | direct_response."""
    last_msg = state["messages"][-1] if state["messages"] else None
    if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
        return "tool_node"
    return "direct_response"


# ════════════════════════════════════════════════════════════════════
# NODE 4 — DirectResponseNode
# ════════════════════════════════════════════════════════════════════
def direct_response_node(state: AgentState) -> AgentState:
    """Trả lời trực tiếp bằng LLM, kèm lịch sử chat từ Redis làm context."""
    logger.info("[DirectResponseNode] Trả lời trực tiếp cho %s", state["sender_id"])

    chat_history = get_context(state["conversation_id"])

    msgs = direct_response_prompt.format_messages(
        sender_id=state["sender_id"],
        raw_query=state["raw_query"],
        chat_history=chat_history,
    )
    if cp := state.get("custom_prompt"):
        msgs[0] = SystemMessage(content=msgs[0].content + f"\n\n═══ YÊU CẦU BỔ SUNG ═══\n{cp}")

    response = llm.invoke(msgs)
    answer   = response.content

    logger.info("[DirectResponseNode] Hoàn thành (%d ký tự)", len(answer))
    return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer}


# ════════════════════════════════════════════════════════════════════
# NODE 5 — ToolNode (custom)
# ════════════════════════════════════════════════════════════════════
def _run_tool(tool_name: str, tool_args: dict[str, Any]) -> str:
    fn = TOOL_MAP.get(tool_name)
    if fn is None:
        return f"[Lỗi] Tool '{tool_name}' không tồn tại trong registry."
    try:
        result = fn.invoke(tool_args)
        return result if isinstance(result, str) else json.dumps(result, ensure_ascii=False)
    except Exception as e:
        logger.exception("[ToolNode] Tool '%s' gặp lỗi", tool_name)
        return f"[Lỗi] Tool '{tool_name}': {e}"


def tool_node(state: AgentState) -> AgentState:
    """Thực thi tất cả tool_calls trong AIMessage cuối, sinh ToolMessage(s)."""
    last_msg = state["messages"][-1]

    if not isinstance(last_msg, AIMessage) or not last_msg.tool_calls:
        logger.warning("[ToolNode] Không tìm thấy tool_call.")
        return state

    tool_messages: list[ToolMessage] = []
    for tc in last_msg.tool_calls:
        name, args, cid = tc["name"], tc.get("args", {}), tc["id"]
        logger.info("[ToolNode] Thực thi tool='%s' args=%s", name, args)
        result = _run_tool(name, args)
        logger.info("[ToolNode] Tool='%s' → %d ký tự", name, len(result))
        tool_messages.append(ToolMessage(content=result, tool_call_id=cid, name=name))

    return {**state, "messages": tool_messages}


# ════════════════════════════════════════════════════════════════════
# NODE 5 — FinalResponseNode
# ════════════════════════════════════════════════════════════════════
def _extract_tool_results(state: AgentState) -> str:
    parts = [f"[{m.name}]\n{m.content}" for m in state["messages"] if isinstance(m, ToolMessage)]
    return "\n\n".join(parts) if parts else "(Không có kết quả từ tool)"


# ════════════════════════════════════════════════════════════════════
# NODE 6 — ImageResponseNode
# ════════════════════════════════════════════════════════════════════
def image_response_node(state: AgentState) -> AgentState:
    """Nhận HumanMessage chứa ảnh + text, gọi multimodal LLM sinh câu trả lời."""
    logger.info("[ImageResponseNode] Xử lý ảnh cho %s", state["sender_id"])

    msgs = list(state["messages"])
    if cp := state.get("custom_prompt"):
        msgs.insert(0, SystemMessage(content=cp))
    response = multimodal_llm.invoke(msgs)
    answer   = response.content

    logger.info("[ImageResponseNode] Hoàn thành (%d ký tự)", len(answer))
    return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer}


# ════════════════════════════════════════════════════════════════════
# NODE 5 — FinalResponseNode
# ════════════════════════════════════════════════════════════════════
def final_response_node(state: AgentState) -> AgentState:
    """Tổng hợp ToolMessage(s) và sinh câu trả lời cuối cùng."""
    logger.info("[FinalResponseNode] Tổng hợp câu trả lời...")

    msgs = final_response_prompt.format_messages(
        sender_id=state["sender_id"],
        raw_query=state["raw_query"],
        tool_results=_extract_tool_results(state),
    )
    if cp := state.get("custom_prompt"):
        msgs[0] = SystemMessage(content=msgs[0].content + f"\n\n═══ YÊU CẦU BỔ SUNG ═══\n{cp}")
    response = llm.invoke(msgs)
    answer   = response.content

    logger.info("[FinalResponseNode] Hoàn thành (%d ký tự)", len(answer))
    return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer}