092_agent_api / nodes.py
anhkhoiphan's picture
Tích hợp summary memory vào luồng hiện tại
b9b60b0
from __future__ import annotations
import json
import logging
from typing import Any, Literal
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from src.conversation_memory import get_context
from src.llm import llm, multimodal_llm
from src.prompts import (
direct_response_prompt,
final_response_prompt,
orchestrator_prompt,
respond_prompt,
router_prompt,
)
from src.state import MAX_ITERS, AgentState, QueryType
from src.tools.chat_tools import TOOL_MAP, TOOLS
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
)
logger = logging.getLogger("agent_router")
_llm_with_tools = llm.bind_tools(TOOLS)
VALID_LABELS: set[str] = {"respond", "summary"}
DEFAULT_LABEL: QueryType = "respond"
# ════════════════════════════════════════════════════════════════════
# NODE 1 — RouterNode
# ════════════════════════════════════════════════════════════════════
def router_node(state: AgentState) -> AgentState:
"""Validate state, gọi LLM để phân loại query → query_type."""
if not state.get("raw_query"):
raise ValueError("AgentState thiếu trường 'raw_query'.")
if not state.get("sender_id"):
raise ValueError("AgentState thiếu trường 'sender_id'.")
if not state.get("conversation_id"):
raise ValueError("AgentState thiếu trường 'conversation_id'.")
logger.info("[RouterNode] Phân loại query từ %s", state["sender_id"])
chain = router_prompt | llm
response = chain.invoke({
"sender_id": state["sender_id"],
"time": state.get("time", ""),
"raw_query": state["raw_query"],
})
raw_label = response.content.strip().lower()
query_type: QueryType = raw_label if raw_label in VALID_LABELS else DEFAULT_LABEL # type: ignore
if raw_label not in VALID_LABELS:
logger.warning(
"[RouterNode] Nhãn không hợp lệ '%s', fallback → '%s'",
raw_label, DEFAULT_LABEL,
)
logger.info("[RouterNode] query_type = %s", query_type)
return {**state, "query_type": query_type, "messages": [HumanMessage(content=state["raw_query"])]}
def route_decision(state: AgentState) -> Literal["respond", "orchestrator"]:
"""Conditional edge: router → respond | orchestrator."""
return "respond" if state.get("query_type") == "respond" else "orchestrator"
# ════════════════════════════════════════════════════════════════════
# NODE 2 — LLMResponseNode (luồng respond)
# ════════════════════════════════════════════════════════════════════
def llm_response_node(state: AgentState) -> AgentState:
"""Trả lời trực tiếp câu hỏi không cần tra cứu lịch sử."""
logger.info("[LLMResponseNode] Trả lời trực tiếp cho %s", state["sender_id"])
msgs = respond_prompt.format_messages(sender_id=state["sender_id"], raw_query=state["raw_query"])
if cp := state.get("custom_prompt"):
msgs[0] = SystemMessage(content=msgs[0].content + f"\n\n═══ YÊU CẦU BỔ SUNG ═══\n{cp}")
response = llm.invoke(msgs)
answer = response.content
logger.info("[LLMResponseNode] Đã sinh câu trả lời (%d ký tự)", len(answer))
return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer}
# ════════════════════════════════════════════════════════════════════
# NODE 3 — OrchestratorNode
# ════════════════════════════════════════════════════════════════════
def orchestrator_node(state: AgentState) -> AgentState:
"""Gọi LLM (đã bind tool) để sinh tool_call. Dừng nếu đạt max_iters."""
iters = state.get("iters", 0)
max_iters = state.get("max_iters", MAX_ITERS)
logger.info("[OrchestratorNode] iters=%d / max_iters=%d", iters, max_iters)
if iters >= max_iters:
logger.info("[OrchestratorNode] Đã đạt max_iters, bỏ qua gọi tool.")
return {**state, "iters": iters}
prompt_value = orchestrator_prompt.format_messages(
sender_id=state["sender_id"],
conversation_id=state.get("conversation_id", ""),
time=state.get("time", ""),
raw_query=state["raw_query"],
)
response: AIMessage = _llm_with_tools.invoke(prompt_value) # type: ignore
logger.info(
"[OrchestratorNode] tool_calls=%s",
[tc["name"] for tc in (response.tool_calls or [])],
)
return {**state, "messages": [response], "iters": iters + 1}
def should_call_tool(state: AgentState) -> Literal["tool_node", "direct_response"]:
"""Conditional edge: orchestrator → tool_node | direct_response."""
last_msg = state["messages"][-1] if state["messages"] else None
if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
return "tool_node"
return "direct_response"
# ════════════════════════════════════════════════════════════════════
# NODE 4 — DirectResponseNode
# ════════════════════════════════════════════════════════════════════
def direct_response_node(state: AgentState) -> AgentState:
"""Trả lời trực tiếp bằng LLM, kèm lịch sử chat từ Redis làm context."""
logger.info("[DirectResponseNode] Trả lời trực tiếp cho %s", state["sender_id"])
chat_history = get_context(state["conversation_id"])
msgs = direct_response_prompt.format_messages(
sender_id=state["sender_id"],
raw_query=state["raw_query"],
chat_history=chat_history,
)
if cp := state.get("custom_prompt"):
msgs[0] = SystemMessage(content=msgs[0].content + f"\n\n═══ YÊU CẦU BỔ SUNG ═══\n{cp}")
response = llm.invoke(msgs)
answer = response.content
logger.info("[DirectResponseNode] Hoàn thành (%d ký tự)", len(answer))
return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer}
# ════════════════════════════════════════════════════════════════════
# NODE 5 — ToolNode (custom)
# ════════════════════════════════════════════════════════════════════
def _run_tool(tool_name: str, tool_args: dict[str, Any]) -> str:
fn = TOOL_MAP.get(tool_name)
if fn is None:
return f"[Lỗi] Tool '{tool_name}' không tồn tại trong registry."
try:
result = fn.invoke(tool_args)
return result if isinstance(result, str) else json.dumps(result, ensure_ascii=False)
except Exception as e:
logger.exception("[ToolNode] Tool '%s' gặp lỗi", tool_name)
return f"[Lỗi] Tool '{tool_name}': {e}"
def tool_node(state: AgentState) -> AgentState:
"""Thực thi tất cả tool_calls trong AIMessage cuối, sinh ToolMessage(s)."""
last_msg = state["messages"][-1]
if not isinstance(last_msg, AIMessage) or not last_msg.tool_calls:
logger.warning("[ToolNode] Không tìm thấy tool_call.")
return state
tool_messages: list[ToolMessage] = []
for tc in last_msg.tool_calls:
name, args, cid = tc["name"], tc.get("args", {}), tc["id"]
logger.info("[ToolNode] Thực thi tool='%s' args=%s", name, args)
result = _run_tool(name, args)
logger.info("[ToolNode] Tool='%s' → %d ký tự", name, len(result))
tool_messages.append(ToolMessage(content=result, tool_call_id=cid, name=name))
return {**state, "messages": tool_messages}
# ════════════════════════════════════════════════════════════════════
# NODE 5 — FinalResponseNode
# ════════════════════════════════════════════════════════════════════
def _extract_tool_results(state: AgentState) -> str:
parts = [f"[{m.name}]\n{m.content}" for m in state["messages"] if isinstance(m, ToolMessage)]
return "\n\n".join(parts) if parts else "(Không có kết quả từ tool)"
# ════════════════════════════════════════════════════════════════════
# NODE 6 — ImageResponseNode
# ════════════════════════════════════════════════════════════════════
def image_response_node(state: AgentState) -> AgentState:
"""Nhận HumanMessage chứa ảnh + text, gọi multimodal LLM sinh câu trả lời."""
logger.info("[ImageResponseNode] Xử lý ảnh cho %s", state["sender_id"])
msgs = list(state["messages"])
if cp := state.get("custom_prompt"):
msgs.insert(0, SystemMessage(content=cp))
response = multimodal_llm.invoke(msgs)
answer = response.content
logger.info("[ImageResponseNode] Hoàn thành (%d ký tự)", len(answer))
return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer}
# ════════════════════════════════════════════════════════════════════
# NODE 5 — FinalResponseNode
# ════════════════════════════════════════════════════════════════════
def final_response_node(state: AgentState) -> AgentState:
"""Tổng hợp ToolMessage(s) và sinh câu trả lời cuối cùng."""
logger.info("[FinalResponseNode] Tổng hợp câu trả lời...")
msgs = final_response_prompt.format_messages(
sender_id=state["sender_id"],
raw_query=state["raw_query"],
tool_results=_extract_tool_results(state),
)
if cp := state.get("custom_prompt"):
msgs[0] = SystemMessage(content=msgs[0].content + f"\n\n═══ YÊU CẦU BỔ SUNG ═══\n{cp}")
response = llm.invoke(msgs)
answer = response.content
logger.info("[FinalResponseNode] Hoàn thành (%d ký tự)", len(answer))
return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer}