anhkhoiphan commited on
Commit
d7c0645
·
1 Parent(s): bcaa321

Định nghĩa các nodes trong graph

Browse files
Files changed (1) hide show
  1. nodes.py +175 -0
nodes.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ from typing import Any, Literal
6
+
7
+ from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
8
+
9
+ from src.llm import llm
10
+ from src.prompts import (
11
+ final_response_prompt,
12
+ orchestrator_prompt,
13
+ respond_prompt,
14
+ router_prompt,
15
+ )
16
+ from src.state import MAX_ITERS, AgentState, QueryType
17
+ from src.tools.chat_tools import TOOL_MAP, TOOLS
18
+
19
+ logging.basicConfig(
20
+ level=logging.INFO,
21
+ format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
22
+ )
23
+ logger = logging.getLogger("agent_router")
24
+
25
+ _llm_with_tools = llm.bind_tools(TOOLS)
26
+
27
+ VALID_LABELS: set[str] = {"respond", "summary"}
28
+ DEFAULT_LABEL: QueryType = "respond"
29
+
30
+
31
+ # ════════════════════════════════════════════════════════════════════
32
+ # NODE 1 — RouterNode
33
+ # ════════════════════════════════════════════════════════════════════
34
+ def router_node(state: AgentState) -> AgentState:
35
+ """Validate state, gọi LLM để phân loại query → query_type."""
36
+ if not state.get("raw_query"):
37
+ raise ValueError("AgentState thiếu trường 'raw_query'.")
38
+ if not state.get("sender_id"):
39
+ raise ValueError("AgentState thiếu trường 'sender_id'.")
40
+
41
+ logger.info("[RouterNode] Phân loại query từ %s", state["sender_id"])
42
+
43
+ chain = router_prompt | llm
44
+ response = chain.invoke({
45
+ "sender_id": state["sender_id"],
46
+ "time": state.get("time", ""),
47
+ "raw_query": state["raw_query"],
48
+ })
49
+
50
+ raw_label = response.content.strip().lower()
51
+ query_type: QueryType = raw_label if raw_label in VALID_LABELS else DEFAULT_LABEL # type: ignore
52
+
53
+ if raw_label not in VALID_LABELS:
54
+ logger.warning(
55
+ "[RouterNode] Nhãn không hợp lệ '%s', fallback → '%s'",
56
+ raw_label, DEFAULT_LABEL,
57
+ )
58
+
59
+ logger.info("[RouterNode] query_type = %s", query_type)
60
+ return {**state, "query_type": query_type, "messages": [HumanMessage(content=state["raw_query"])]}
61
+
62
+
63
+ def route_decision(state: AgentState) -> Literal["respond", "orchestrator"]:
64
+ """Conditional edge: router → respond | orchestrator."""
65
+ return "respond" if state.get("query_type") == "respond" else "orchestrator"
66
+
67
+
68
+ # ════════════════════════════════════════════════════════════════════
69
+ # NODE 2 — LLMResponseNode (luồng respond)
70
+ # ════════════════════════════════════════════════════════════════════
71
+ def llm_response_node(state: AgentState) -> AgentState:
72
+ """Trả lời trực tiếp câu hỏi không cần tra cứu lịch sử."""
73
+ logger.info("[LLMResponseNode] Trả lời trực tiếp cho %s", state["sender_id"])
74
+
75
+ chain = respond_prompt | llm
76
+ response = chain.invoke({"sender_id": state["sender_id"], "raw_query": state["raw_query"]})
77
+ answer = response.content
78
+
79
+ logger.info("[LLMResponseNode] Đã sinh câu trả lời (%d ký tự)", len(answer))
80
+ return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer}
81
+
82
+
83
+ # ════════════════════════════════════════════════════════════════════
84
+ # NODE 3 — OrchestratorNode
85
+ # ════════════════════════════════════════════════════════════════════
86
+ def orchestrator_node(state: AgentState) -> AgentState:
87
+ """Gọi LLM (đã bind tool) để sinh tool_call. Dừng nếu đạt max_iters."""
88
+ iters = state.get("iters", 0)
89
+ max_iters = state.get("max_iters", MAX_ITERS)
90
+
91
+ logger.info("[OrchestratorNode] iters=%d / max_iters=%d", iters, max_iters)
92
+
93
+ if iters >= max_iters:
94
+ logger.info("[OrchestratorNode] Đã đạt max_iters, bỏ qua gọi tool.")
95
+ return {**state, "iters": iters}
96
+
97
+ prompt_value = orchestrator_prompt.format_messages(
98
+ sender_id=state["sender_id"],
99
+ time=state.get("time", ""),
100
+ raw_query=state["raw_query"],
101
+ )
102
+
103
+ response: AIMessage = _llm_with_tools.invoke(prompt_value) # type: ignore
104
+ logger.info(
105
+ "[OrchestratorNode] tool_calls=%s",
106
+ [tc["name"] for tc in (response.tool_calls or [])],
107
+ )
108
+
109
+ return {**state, "messages": [response], "iters": iters + 1}
110
+
111
+
112
+ def should_call_tool(state: AgentState) -> Literal["tool_node", "final_response"]:
113
+ """Conditional edge: orchestrator → tool_node | final_response."""
114
+ last_msg = state["messages"][-1] if state["messages"] else None
115
+ if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
116
+ return "tool_node"
117
+ return "final_response"
118
+
119
+
120
+ # ════════════════════════════════════════════════════════════════════
121
+ # NODE 4 — ToolNode (custom)
122
+ # ════════════════════════════════════════════════════════════════════
123
+ def _run_tool(tool_name: str, tool_args: dict[str, Any]) -> str:
124
+ fn = TOOL_MAP.get(tool_name)
125
+ if fn is None:
126
+ return f"[Lỗi] Tool '{tool_name}' không tồn tại trong registry."
127
+ try:
128
+ result = fn.invoke(tool_args)
129
+ return result if isinstance(result, str) else json.dumps(result, ensure_ascii=False)
130
+ except Exception as e:
131
+ logger.exception("[ToolNode] Tool '%s' gặp lỗi", tool_name)
132
+ return f"[Lỗi] Tool '{tool_name}': {e}"
133
+
134
+
135
+ def tool_node(state: AgentState) -> AgentState:
136
+ """Thực thi tất cả tool_calls trong AIMessage cuối, sinh ToolMessage(s)."""
137
+ last_msg = state["messages"][-1]
138
+
139
+ if not isinstance(last_msg, AIMessage) or not last_msg.tool_calls:
140
+ logger.warning("[ToolNode] Không tìm thấy tool_call.")
141
+ return state
142
+
143
+ tool_messages: list[ToolMessage] = []
144
+ for tc in last_msg.tool_calls:
145
+ name, args, cid = tc["name"], tc.get("args", {}), tc["id"]
146
+ logger.info("[ToolNode] Thực thi tool='%s' args=%s", name, args)
147
+ result = _run_tool(name, args)
148
+ logger.info("[ToolNode] Tool='%s' → %d ký tự", name, len(result))
149
+ tool_messages.append(ToolMessage(content=result, tool_call_id=cid, name=name))
150
+
151
+ return {**state, "messages": tool_messages}
152
+
153
+
154
+ # ════════════════════════════════════════════════════════════════════
155
+ # NODE 5 — FinalResponseNode
156
+ # ════════════════════════════════════════════════════════════════════
157
+ def _extract_tool_results(state: AgentState) -> str:
158
+ parts = [f"[{m.name}]\n{m.content}" for m in state["messages"] if isinstance(m, ToolMessage)]
159
+ return "\n\n".join(parts) if parts else "(Không có kết quả từ tool)"
160
+
161
+
162
+ def final_response_node(state: AgentState) -> AgentState:
163
+ """Tổng hợp ToolMessage(s) và sinh câu trả lời cuối cùng."""
164
+ logger.info("[FinalResponseNode] Tổng hợp câu trả lời...")
165
+
166
+ chain = final_response_prompt | llm
167
+ response = chain.invoke({
168
+ "sender_id": state["sender_id"],
169
+ "raw_query": state["raw_query"],
170
+ "tool_results": _extract_tool_results(state),
171
+ })
172
+ answer = response.content
173
+
174
+ logger.info("[FinalResponseNode] Hoàn thành (%d ký tự)", len(answer))
175
+ return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer}