from __future__ import annotations import json import logging from typing import Any, Literal from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from src.conversation_memory import get_context from src.llm import llm, multimodal_llm from src.prompts import ( direct_response_prompt, final_response_prompt, orchestrator_prompt, respond_prompt, router_prompt, ) from src.state import MAX_ITERS, AgentState, QueryType from src.tools.chat_tools import TOOL_MAP, TOOLS logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s — %(message)s", ) logger = logging.getLogger("agent_router") _llm_with_tools = llm.bind_tools(TOOLS) VALID_LABELS: set[str] = {"respond", "summary"} DEFAULT_LABEL: QueryType = "respond" # ════════════════════════════════════════════════════════════════════ # NODE 1 — RouterNode # ════════════════════════════════════════════════════════════════════ def router_node(state: AgentState) -> AgentState: """Validate state, gọi LLM để phân loại query → query_type.""" if not state.get("raw_query"): raise ValueError("AgentState thiếu trường 'raw_query'.") if not state.get("sender_id"): raise ValueError("AgentState thiếu trường 'sender_id'.") if not state.get("conversation_id"): raise ValueError("AgentState thiếu trường 'conversation_id'.") logger.info("[RouterNode] Phân loại query từ %s", state["sender_id"]) chain = router_prompt | llm response = chain.invoke({ "sender_id": state["sender_id"], "time": state.get("time", ""), "raw_query": state["raw_query"], }) raw_label = response.content.strip().lower() query_type: QueryType = raw_label if raw_label in VALID_LABELS else DEFAULT_LABEL # type: ignore if raw_label not in VALID_LABELS: logger.warning( "[RouterNode] Nhãn không hợp lệ '%s', fallback → '%s'", raw_label, DEFAULT_LABEL, ) logger.info("[RouterNode] query_type = %s", query_type) return {**state, "query_type": query_type, "messages": [HumanMessage(content=state["raw_query"])]} def route_decision(state: AgentState) -> Literal["respond", "orchestrator"]: """Conditional edge: router → respond | orchestrator.""" return "respond" if state.get("query_type") == "respond" else "orchestrator" # ════════════════════════════════════════════════════════════════════ # NODE 2 — LLMResponseNode (luồng respond) # ════════════════════════════════════════════════════════════════════ def llm_response_node(state: AgentState) -> AgentState: """Trả lời trực tiếp câu hỏi không cần tra cứu lịch sử.""" logger.info("[LLMResponseNode] Trả lời trực tiếp cho %s", state["sender_id"]) msgs = respond_prompt.format_messages(sender_id=state["sender_id"], raw_query=state["raw_query"]) if cp := state.get("custom_prompt"): msgs[0] = SystemMessage(content=msgs[0].content + f"\n\n═══ YÊU CẦU BỔ SUNG ═══\n{cp}") response = llm.invoke(msgs) answer = response.content logger.info("[LLMResponseNode] Đã sinh câu trả lời (%d ký tự)", len(answer)) return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer} # ════════════════════════════════════════════════════════════════════ # NODE 3 — OrchestratorNode # ════════════════════════════════════════════════════════════════════ def orchestrator_node(state: AgentState) -> AgentState: """Gọi LLM (đã bind tool) để sinh tool_call. Dừng nếu đạt max_iters.""" iters = state.get("iters", 0) max_iters = state.get("max_iters", MAX_ITERS) logger.info("[OrchestratorNode] iters=%d / max_iters=%d", iters, max_iters) if iters >= max_iters: logger.info("[OrchestratorNode] Đã đạt max_iters, bỏ qua gọi tool.") return {**state, "iters": iters} prompt_value = orchestrator_prompt.format_messages( sender_id=state["sender_id"], conversation_id=state.get("conversation_id", ""), time=state.get("time", ""), raw_query=state["raw_query"], ) response: AIMessage = _llm_with_tools.invoke(prompt_value) # type: ignore logger.info( "[OrchestratorNode] tool_calls=%s", [tc["name"] for tc in (response.tool_calls or [])], ) return {**state, "messages": [response], "iters": iters + 1} def should_call_tool(state: AgentState) -> Literal["tool_node", "direct_response"]: """Conditional edge: orchestrator → tool_node | direct_response.""" last_msg = state["messages"][-1] if state["messages"] else None if isinstance(last_msg, AIMessage) and last_msg.tool_calls: return "tool_node" return "direct_response" # ════════════════════════════════════════════════════════════════════ # NODE 4 — DirectResponseNode # ════════════════════════════════════════════════════════════════════ def direct_response_node(state: AgentState) -> AgentState: """Trả lời trực tiếp bằng LLM, kèm lịch sử chat từ Redis làm context.""" logger.info("[DirectResponseNode] Trả lời trực tiếp cho %s", state["sender_id"]) chat_history = get_context(state["conversation_id"]) msgs = direct_response_prompt.format_messages( sender_id=state["sender_id"], raw_query=state["raw_query"], chat_history=chat_history, ) if cp := state.get("custom_prompt"): msgs[0] = SystemMessage(content=msgs[0].content + f"\n\n═══ YÊU CẦU BỔ SUNG ═══\n{cp}") response = llm.invoke(msgs) answer = response.content logger.info("[DirectResponseNode] Hoàn thành (%d ký tự)", len(answer)) return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer} # ════════════════════════════════════════════════════════════════════ # NODE 5 — ToolNode (custom) # ════════════════════════════════════════════════════════════════════ def _run_tool(tool_name: str, tool_args: dict[str, Any]) -> str: fn = TOOL_MAP.get(tool_name) if fn is None: return f"[Lỗi] Tool '{tool_name}' không tồn tại trong registry." try: result = fn.invoke(tool_args) return result if isinstance(result, str) else json.dumps(result, ensure_ascii=False) except Exception as e: logger.exception("[ToolNode] Tool '%s' gặp lỗi", tool_name) return f"[Lỗi] Tool '{tool_name}': {e}" def tool_node(state: AgentState) -> AgentState: """Thực thi tất cả tool_calls trong AIMessage cuối, sinh ToolMessage(s).""" last_msg = state["messages"][-1] if not isinstance(last_msg, AIMessage) or not last_msg.tool_calls: logger.warning("[ToolNode] Không tìm thấy tool_call.") return state tool_messages: list[ToolMessage] = [] for tc in last_msg.tool_calls: name, args, cid = tc["name"], tc.get("args", {}), tc["id"] logger.info("[ToolNode] Thực thi tool='%s' args=%s", name, args) result = _run_tool(name, args) logger.info("[ToolNode] Tool='%s' → %d ký tự", name, len(result)) tool_messages.append(ToolMessage(content=result, tool_call_id=cid, name=name)) return {**state, "messages": tool_messages} # ════════════════════════════════════════════════════════════════════ # NODE 5 — FinalResponseNode # ════════════════════════════════════════════════════════════════════ def _extract_tool_results(state: AgentState) -> str: parts = [f"[{m.name}]\n{m.content}" for m in state["messages"] if isinstance(m, ToolMessage)] return "\n\n".join(parts) if parts else "(Không có kết quả từ tool)" # ════════════════════════════════════════════════════════════════════ # NODE 6 — ImageResponseNode # ════════════════════════════════════════════════════════════════════ def image_response_node(state: AgentState) -> AgentState: """Nhận HumanMessage chứa ảnh + text, gọi multimodal LLM sinh câu trả lời.""" logger.info("[ImageResponseNode] Xử lý ảnh cho %s", state["sender_id"]) msgs = list(state["messages"]) if cp := state.get("custom_prompt"): msgs.insert(0, SystemMessage(content=cp)) response = multimodal_llm.invoke(msgs) answer = response.content logger.info("[ImageResponseNode] Hoàn thành (%d ký tự)", len(answer)) return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer} # ════════════════════════════════════════════════════════════════════ # NODE 5 — FinalResponseNode # ════════════════════════════════════════════════════════════════════ def final_response_node(state: AgentState) -> AgentState: """Tổng hợp ToolMessage(s) và sinh câu trả lời cuối cùng.""" logger.info("[FinalResponseNode] Tổng hợp câu trả lời...") msgs = final_response_prompt.format_messages( sender_id=state["sender_id"], raw_query=state["raw_query"], tool_results=_extract_tool_results(state), ) if cp := state.get("custom_prompt"): msgs[0] = SystemMessage(content=msgs[0].content + f"\n\n═══ YÊU CẦU BỔ SUNG ═══\n{cp}") response = llm.invoke(msgs) answer = response.content logger.info("[FinalResponseNode] Hoàn thành (%d ký tự)", len(answer)) return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer}