anhkhoiphan commited on
Commit
4eb7b2b
·
1 Parent(s): a892f49

Bỏ router, để orchestrator quyết định khi nào trả lời directly

Browse files
Files changed (2) hide show
  1. graph.py +11 -21
  2. nodes.py +32 -4
graph.py CHANGED
@@ -1,42 +1,33 @@
1
  from langgraph.graph import END, START, StateGraph
2
 
3
  from src.nodes import (
 
4
  final_response_node,
5
- llm_response_node,
6
  orchestrator_node,
7
- route_decision,
8
- router_node,
9
  should_call_tool,
10
  tool_node,
11
  )
12
- from src.state import MAX_ITERS, AgentState
13
 
14
 
15
  def build_graph():
16
  builder = StateGraph(AgentState)
17
 
18
- builder.add_node("router", router_node)
19
- builder.add_node("llm_response", llm_response_node)
20
- builder.add_node("orchestrator", orchestrator_node)
21
- builder.add_node("tool_node", tool_node)
22
- builder.add_node("final_response", final_response_node)
23
 
24
- builder.add_edge(START, "router")
25
-
26
- builder.add_conditional_edges(
27
- "router", route_decision,
28
- {"respond": "llm_response", "orchestrator": "orchestrator"},
29
- )
30
-
31
- builder.add_edge("llm_response", END)
32
 
33
  builder.add_conditional_edges(
34
  "orchestrator", should_call_tool,
35
- {"tool_node": "tool_node", "final_response": "final_response"},
36
  )
37
 
38
- builder.add_edge("tool_node", "final_response")
39
- builder.add_edge("final_response", END)
 
40
 
41
  return builder.compile()
42
 
@@ -61,7 +52,6 @@ def print_result(result: AgentState, label: str = "") -> None:
61
  SEP = "═" * 60
62
  if label:
63
  print(f"\n{SEP}\n {label}\n{SEP}")
64
- print(f" 📌 query_type : {result.get('query_type')}")
65
  print(f" 📌 iters : {result.get('iters')} / {result.get('max_iters')}")
66
  print(f" 📌 # messages : {len(result.get('messages', []))}")
67
  print(f"{'─' * 60}")
 
1
  from langgraph.graph import END, START, StateGraph
2
 
3
  from src.nodes import (
4
+ direct_response_node,
5
  final_response_node,
 
6
  orchestrator_node,
 
 
7
  should_call_tool,
8
  tool_node,
9
  )
10
+ from src.state import AgentState
11
 
12
 
13
  def build_graph():
14
  builder = StateGraph(AgentState)
15
 
16
+ builder.add_node("orchestrator", orchestrator_node)
17
+ builder.add_node("tool_node", tool_node)
18
+ builder.add_node("final_response", final_response_node)
19
+ builder.add_node("direct_response", direct_response_node)
 
20
 
21
+ builder.add_edge(START, "orchestrator")
 
 
 
 
 
 
 
22
 
23
  builder.add_conditional_edges(
24
  "orchestrator", should_call_tool,
25
+ {"tool_node": "tool_node", "direct_response": "direct_response"},
26
  )
27
 
28
+ builder.add_edge("tool_node", "final_response")
29
+ builder.add_edge("final_response", END)
30
+ builder.add_edge("direct_response", END)
31
 
32
  return builder.compile()
33
 
 
52
  SEP = "═" * 60
53
  if label:
54
  print(f"\n{SEP}\n {label}\n{SEP}")
 
55
  print(f" 📌 iters : {result.get('iters')} / {result.get('max_iters')}")
56
  print(f" 📌 # messages : {len(result.get('messages', []))}")
57
  print(f"{'─' * 60}")
nodes.py CHANGED
@@ -7,12 +7,15 @@ from typing import Any, Literal
7
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
8
 
9
  from src.llm import llm, multimodal_llm
 
10
  from src.prompts import (
 
11
  final_response_prompt,
12
  orchestrator_prompt,
13
  respond_prompt,
14
  router_prompt,
15
  )
 
16
  from src.state import MAX_ITERS, AgentState, QueryType
17
  from src.tools.chat_tools import TOOL_MAP, TOOLS
18
 
@@ -114,16 +117,41 @@ def orchestrator_node(state: AgentState) -> AgentState:
114
  return {**state, "messages": [response], "iters": iters + 1}
115
 
116
 
117
- def should_call_tool(state: AgentState) -> Literal["tool_node", "final_response"]:
118
- """Conditional edge: orchestrator → tool_node | final_response."""
119
  last_msg = state["messages"][-1] if state["messages"] else None
120
  if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
121
  return "tool_node"
122
- return "final_response"
123
 
124
 
125
  # ════════════════════════════════════════════════════════════════════
126
- # NODE 4 — ToolNode (custom)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  # ════════════════════════════════════════════════════════════════════
128
  def _run_tool(tool_name: str, tool_args: dict[str, Any]) -> str:
129
  fn = TOOL_MAP.get(tool_name)
 
7
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
8
 
9
  from src.llm import llm, multimodal_llm
10
+ from src.pdf_processing import format_chat_history
11
  from src.prompts import (
12
+ direct_response_prompt,
13
  final_response_prompt,
14
  orchestrator_prompt,
15
  respond_prompt,
16
  router_prompt,
17
  )
18
+ from src.redis_client import redis_client
19
  from src.state import MAX_ITERS, AgentState, QueryType
20
  from src.tools.chat_tools import TOOL_MAP, TOOLS
21
 
 
117
  return {**state, "messages": [response], "iters": iters + 1}
118
 
119
 
120
+ def should_call_tool(state: AgentState) -> Literal["tool_node", "direct_response"]:
121
+ """Conditional edge: orchestrator → tool_node | direct_response."""
122
  last_msg = state["messages"][-1] if state["messages"] else None
123
  if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
124
  return "tool_node"
125
+ return "direct_response"
126
 
127
 
128
  # ════════════════════════════════════════════════════════════════════
129
+ # NODE 4 — DirectResponseNode
130
+ # ════════════════════════════════════════════════════════════════════
131
+ def direct_response_node(state: AgentState) -> AgentState:
132
+ """Trả lời trực tiếp bằng LLM, kèm lịch sử chat từ Redis làm context."""
133
+ logger.info("[DirectResponseNode] Trả lời trực tiếp cho %s", state["sender_id"])
134
+
135
+ raw_history = redis_client.get_chat_history(state["conversation_id"])
136
+ chat_history = format_chat_history(raw_history) if raw_history else "(Chưa có lịch sử trò chuyện)"
137
+
138
+ msgs = direct_response_prompt.format_messages(
139
+ sender_id=state["sender_id"],
140
+ raw_query=state["raw_query"],
141
+ chat_history=chat_history,
142
+ )
143
+ if cp := state.get("custom_prompt"):
144
+ msgs[0] = SystemMessage(content=msgs[0].content + f"\n\n═══ YÊU CẦU BỔ SUNG ═══\n{cp}")
145
+
146
+ response = llm.invoke(msgs)
147
+ answer = response.content
148
+
149
+ logger.info("[DirectResponseNode] Hoàn thành (%d ký tự)", len(answer))
150
+ return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer}
151
+
152
+
153
+ # ════════════════════════════════════════════════════════════════════
154
+ # NODE 5 — ToolNode (custom)
155
  # ════════════════════════════════════════════════════════════════════
156
  def _run_tool(tool_name: str, tool_args: dict[str, Any]) -> str:
157
  fn = TOOL_MAP.get(tool_name)