Spaces:
Sleeping
Sleeping
Commit ·
4eb7b2b
1
Parent(s): a892f49
Bỏ router, để orchestrator quyết định khi nào trả lời directly
Browse files
graph.py
CHANGED
|
@@ -1,42 +1,33 @@
|
|
| 1 |
from langgraph.graph import END, START, StateGraph
|
| 2 |
|
| 3 |
from src.nodes import (
|
|
|
|
| 4 |
final_response_node,
|
| 5 |
-
llm_response_node,
|
| 6 |
orchestrator_node,
|
| 7 |
-
route_decision,
|
| 8 |
-
router_node,
|
| 9 |
should_call_tool,
|
| 10 |
tool_node,
|
| 11 |
)
|
| 12 |
-
from src.state import
|
| 13 |
|
| 14 |
|
| 15 |
def build_graph():
|
| 16 |
builder = StateGraph(AgentState)
|
| 17 |
|
| 18 |
-
builder.add_node("
|
| 19 |
-
builder.add_node("
|
| 20 |
-
builder.add_node("
|
| 21 |
-
builder.add_node("
|
| 22 |
-
builder.add_node("final_response", final_response_node)
|
| 23 |
|
| 24 |
-
builder.add_edge(START, "
|
| 25 |
-
|
| 26 |
-
builder.add_conditional_edges(
|
| 27 |
-
"router", route_decision,
|
| 28 |
-
{"respond": "llm_response", "orchestrator": "orchestrator"},
|
| 29 |
-
)
|
| 30 |
-
|
| 31 |
-
builder.add_edge("llm_response", END)
|
| 32 |
|
| 33 |
builder.add_conditional_edges(
|
| 34 |
"orchestrator", should_call_tool,
|
| 35 |
-
{"tool_node": "tool_node", "
|
| 36 |
)
|
| 37 |
|
| 38 |
-
builder.add_edge("tool_node",
|
| 39 |
-
builder.add_edge("final_response",
|
|
|
|
| 40 |
|
| 41 |
return builder.compile()
|
| 42 |
|
|
@@ -61,7 +52,6 @@ def print_result(result: AgentState, label: str = "") -> None:
|
|
| 61 |
SEP = "═" * 60
|
| 62 |
if label:
|
| 63 |
print(f"\n{SEP}\n {label}\n{SEP}")
|
| 64 |
-
print(f" 📌 query_type : {result.get('query_type')}")
|
| 65 |
print(f" 📌 iters : {result.get('iters')} / {result.get('max_iters')}")
|
| 66 |
print(f" 📌 # messages : {len(result.get('messages', []))}")
|
| 67 |
print(f"{'─' * 60}")
|
|
|
|
| 1 |
from langgraph.graph import END, START, StateGraph
|
| 2 |
|
| 3 |
from src.nodes import (
|
| 4 |
+
direct_response_node,
|
| 5 |
final_response_node,
|
|
|
|
| 6 |
orchestrator_node,
|
|
|
|
|
|
|
| 7 |
should_call_tool,
|
| 8 |
tool_node,
|
| 9 |
)
|
| 10 |
+
from src.state import AgentState
|
| 11 |
|
| 12 |
|
| 13 |
def build_graph():
|
| 14 |
builder = StateGraph(AgentState)
|
| 15 |
|
| 16 |
+
builder.add_node("orchestrator", orchestrator_node)
|
| 17 |
+
builder.add_node("tool_node", tool_node)
|
| 18 |
+
builder.add_node("final_response", final_response_node)
|
| 19 |
+
builder.add_node("direct_response", direct_response_node)
|
|
|
|
| 20 |
|
| 21 |
+
builder.add_edge(START, "orchestrator")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
builder.add_conditional_edges(
|
| 24 |
"orchestrator", should_call_tool,
|
| 25 |
+
{"tool_node": "tool_node", "direct_response": "direct_response"},
|
| 26 |
)
|
| 27 |
|
| 28 |
+
builder.add_edge("tool_node", "final_response")
|
| 29 |
+
builder.add_edge("final_response", END)
|
| 30 |
+
builder.add_edge("direct_response", END)
|
| 31 |
|
| 32 |
return builder.compile()
|
| 33 |
|
|
|
|
| 52 |
SEP = "═" * 60
|
| 53 |
if label:
|
| 54 |
print(f"\n{SEP}\n {label}\n{SEP}")
|
|
|
|
| 55 |
print(f" 📌 iters : {result.get('iters')} / {result.get('max_iters')}")
|
| 56 |
print(f" 📌 # messages : {len(result.get('messages', []))}")
|
| 57 |
print(f"{'─' * 60}")
|
nodes.py
CHANGED
|
@@ -7,12 +7,15 @@ from typing import Any, Literal
|
|
| 7 |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
| 8 |
|
| 9 |
from src.llm import llm, multimodal_llm
|
|
|
|
| 10 |
from src.prompts import (
|
|
|
|
| 11 |
final_response_prompt,
|
| 12 |
orchestrator_prompt,
|
| 13 |
respond_prompt,
|
| 14 |
router_prompt,
|
| 15 |
)
|
|
|
|
| 16 |
from src.state import MAX_ITERS, AgentState, QueryType
|
| 17 |
from src.tools.chat_tools import TOOL_MAP, TOOLS
|
| 18 |
|
|
@@ -114,16 +117,41 @@ def orchestrator_node(state: AgentState) -> AgentState:
|
|
| 114 |
return {**state, "messages": [response], "iters": iters + 1}
|
| 115 |
|
| 116 |
|
| 117 |
-
def should_call_tool(state: AgentState) -> Literal["tool_node", "
|
| 118 |
-
"""Conditional edge: orchestrator → tool_node |
|
| 119 |
last_msg = state["messages"][-1] if state["messages"] else None
|
| 120 |
if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
|
| 121 |
return "tool_node"
|
| 122 |
-
return "
|
| 123 |
|
| 124 |
|
| 125 |
# ════════════════════════════════════════════════════════════════════
|
| 126 |
-
# NODE 4 —
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
# ════════════════════════════════════════════════════════════════════
|
| 128 |
def _run_tool(tool_name: str, tool_args: dict[str, Any]) -> str:
|
| 129 |
fn = TOOL_MAP.get(tool_name)
|
|
|
|
| 7 |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
| 8 |
|
| 9 |
from src.llm import llm, multimodal_llm
|
| 10 |
+
from src.pdf_processing import format_chat_history
|
| 11 |
from src.prompts import (
|
| 12 |
+
direct_response_prompt,
|
| 13 |
final_response_prompt,
|
| 14 |
orchestrator_prompt,
|
| 15 |
respond_prompt,
|
| 16 |
router_prompt,
|
| 17 |
)
|
| 18 |
+
from src.redis_client import redis_client
|
| 19 |
from src.state import MAX_ITERS, AgentState, QueryType
|
| 20 |
from src.tools.chat_tools import TOOL_MAP, TOOLS
|
| 21 |
|
|
|
|
| 117 |
return {**state, "messages": [response], "iters": iters + 1}
|
| 118 |
|
| 119 |
|
| 120 |
+
def should_call_tool(state: AgentState) -> Literal["tool_node", "direct_response"]:
|
| 121 |
+
"""Conditional edge: orchestrator → tool_node | direct_response."""
|
| 122 |
last_msg = state["messages"][-1] if state["messages"] else None
|
| 123 |
if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
|
| 124 |
return "tool_node"
|
| 125 |
+
return "direct_response"
|
| 126 |
|
| 127 |
|
| 128 |
# ════════════════════════════════════════════════════════════════════
|
| 129 |
+
# NODE 4 — DirectResponseNode
|
| 130 |
+
# ════════════════════════════════════════════════════════════════════
|
| 131 |
+
def direct_response_node(state: AgentState) -> AgentState:
|
| 132 |
+
"""Trả lời trực tiếp bằng LLM, kèm lịch sử chat từ Redis làm context."""
|
| 133 |
+
logger.info("[DirectResponseNode] Trả lời trực tiếp cho %s", state["sender_id"])
|
| 134 |
+
|
| 135 |
+
raw_history = redis_client.get_chat_history(state["conversation_id"])
|
| 136 |
+
chat_history = format_chat_history(raw_history) if raw_history else "(Chưa có lịch sử trò chuyện)"
|
| 137 |
+
|
| 138 |
+
msgs = direct_response_prompt.format_messages(
|
| 139 |
+
sender_id=state["sender_id"],
|
| 140 |
+
raw_query=state["raw_query"],
|
| 141 |
+
chat_history=chat_history,
|
| 142 |
+
)
|
| 143 |
+
if cp := state.get("custom_prompt"):
|
| 144 |
+
msgs[0] = SystemMessage(content=msgs[0].content + f"\n\n═══ YÊU CẦU BỔ SUNG ═══\n{cp}")
|
| 145 |
+
|
| 146 |
+
response = llm.invoke(msgs)
|
| 147 |
+
answer = response.content
|
| 148 |
+
|
| 149 |
+
logger.info("[DirectResponseNode] Hoàn thành (%d ký tự)", len(answer))
|
| 150 |
+
return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer}
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ════════════════════════════════════════════════════════════════════
|
| 154 |
+
# NODE 5 — ToolNode (custom)
|
| 155 |
# ════════════════════════════════════════════════════════════════════
|
| 156 |
def _run_tool(tool_name: str, tool_args: dict[str, Any]) -> str:
|
| 157 |
fn = TOOL_MAP.get(tool_name)
|