Spaces:
Running
Running
| from __future__ import annotations | |
| import json | |
| import logging | |
| from typing import Any, Literal | |
| from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage | |
| from src.conversation_memory import get_context | |
| from src.llm import llm, multimodal_llm | |
| from src.prompts import ( | |
| direct_response_prompt, | |
| final_response_prompt, | |
| orchestrator_prompt, | |
| respond_prompt, | |
| router_prompt, | |
| ) | |
| from src.state import MAX_ITERS, AgentState, QueryType | |
| from src.tools.chat_tools import TOOL_MAP, TOOLS | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s [%(levelname)s] %(name)s — %(message)s", | |
| ) | |
| logger = logging.getLogger("agent_router") | |
| _llm_with_tools = llm.bind_tools(TOOLS) | |
| VALID_LABELS: set[str] = {"respond", "summary"} | |
| DEFAULT_LABEL: QueryType = "respond" | |
| # ════════════════════════════════════════════════════════════════════ | |
| # NODE 1 — RouterNode | |
| # ════════════════════════════════════════════════════════════════════ | |
| def router_node(state: AgentState) -> AgentState: | |
| """Validate state, gọi LLM để phân loại query → query_type.""" | |
| if not state.get("raw_query"): | |
| raise ValueError("AgentState thiếu trường 'raw_query'.") | |
| if not state.get("sender_id"): | |
| raise ValueError("AgentState thiếu trường 'sender_id'.") | |
| if not state.get("conversation_id"): | |
| raise ValueError("AgentState thiếu trường 'conversation_id'.") | |
| logger.info("[RouterNode] Phân loại query từ %s", state["sender_id"]) | |
| chain = router_prompt | llm | |
| response = chain.invoke({ | |
| "sender_id": state["sender_id"], | |
| "time": state.get("time", ""), | |
| "raw_query": state["raw_query"], | |
| }) | |
| raw_label = response.content.strip().lower() | |
| query_type: QueryType = raw_label if raw_label in VALID_LABELS else DEFAULT_LABEL # type: ignore | |
| if raw_label not in VALID_LABELS: | |
| logger.warning( | |
| "[RouterNode] Nhãn không hợp lệ '%s', fallback → '%s'", | |
| raw_label, DEFAULT_LABEL, | |
| ) | |
| logger.info("[RouterNode] query_type = %s", query_type) | |
| return {**state, "query_type": query_type, "messages": [HumanMessage(content=state["raw_query"])]} | |
| def route_decision(state: AgentState) -> Literal["respond", "orchestrator"]: | |
| """Conditional edge: router → respond | orchestrator.""" | |
| return "respond" if state.get("query_type") == "respond" else "orchestrator" | |
| # ════════════════════════════════════════════════════════════════════ | |
| # NODE 2 — LLMResponseNode (luồng respond) | |
| # ════════════════════════════════════════════════════════════════════ | |
| def llm_response_node(state: AgentState) -> AgentState: | |
| """Trả lời trực tiếp câu hỏi không cần tra cứu lịch sử.""" | |
| logger.info("[LLMResponseNode] Trả lời trực tiếp cho %s", state["sender_id"]) | |
| msgs = respond_prompt.format_messages(sender_id=state["sender_id"], raw_query=state["raw_query"]) | |
| if cp := state.get("custom_prompt"): | |
| msgs[0] = SystemMessage(content=msgs[0].content + f"\n\n═══ YÊU CẦU BỔ SUNG ═══\n{cp}") | |
| response = llm.invoke(msgs) | |
| answer = response.content | |
| logger.info("[LLMResponseNode] Đã sinh câu trả lời (%d ký tự)", len(answer)) | |
| return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer} | |
| # ════════════════════════════════════════════════════════════════════ | |
| # NODE 3 — OrchestratorNode | |
| # ════════════════════════════════════════════════════════════════════ | |
| def orchestrator_node(state: AgentState) -> AgentState: | |
| """Gọi LLM (đã bind tool) để sinh tool_call. Dừng nếu đạt max_iters.""" | |
| iters = state.get("iters", 0) | |
| max_iters = state.get("max_iters", MAX_ITERS) | |
| logger.info("[OrchestratorNode] iters=%d / max_iters=%d", iters, max_iters) | |
| if iters >= max_iters: | |
| logger.info("[OrchestratorNode] Đã đạt max_iters, bỏ qua gọi tool.") | |
| return {**state, "iters": iters} | |
| prompt_value = orchestrator_prompt.format_messages( | |
| sender_id=state["sender_id"], | |
| conversation_id=state.get("conversation_id", ""), | |
| time=state.get("time", ""), | |
| raw_query=state["raw_query"], | |
| ) | |
| response: AIMessage = _llm_with_tools.invoke(prompt_value) # type: ignore | |
| logger.info( | |
| "[OrchestratorNode] tool_calls=%s", | |
| [tc["name"] for tc in (response.tool_calls or [])], | |
| ) | |
| return {**state, "messages": [response], "iters": iters + 1} | |
| def should_call_tool(state: AgentState) -> Literal["tool_node", "direct_response"]: | |
| """Conditional edge: orchestrator → tool_node | direct_response.""" | |
| last_msg = state["messages"][-1] if state["messages"] else None | |
| if isinstance(last_msg, AIMessage) and last_msg.tool_calls: | |
| return "tool_node" | |
| return "direct_response" | |
| # ════════════════════════════════════════════════════════════════════ | |
| # NODE 4 — DirectResponseNode | |
| # ════════════════════════════════════════════════════════════════════ | |
| def direct_response_node(state: AgentState) -> AgentState: | |
| """Trả lời trực tiếp bằng LLM, kèm lịch sử chat từ Redis làm context.""" | |
| logger.info("[DirectResponseNode] Trả lời trực tiếp cho %s", state["sender_id"]) | |
| chat_history = get_context(state["conversation_id"]) | |
| msgs = direct_response_prompt.format_messages( | |
| sender_id=state["sender_id"], | |
| raw_query=state["raw_query"], | |
| chat_history=chat_history, | |
| ) | |
| if cp := state.get("custom_prompt"): | |
| msgs[0] = SystemMessage(content=msgs[0].content + f"\n\n═══ YÊU CẦU BỔ SUNG ═══\n{cp}") | |
| response = llm.invoke(msgs) | |
| answer = response.content | |
| logger.info("[DirectResponseNode] Hoàn thành (%d ký tự)", len(answer)) | |
| return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer} | |
| # ════════════════════════════════════════════════════════════════════ | |
| # NODE 5 — ToolNode (custom) | |
| # ════════════════════════════════════════════════════════════════════ | |
| def _run_tool(tool_name: str, tool_args: dict[str, Any]) -> str: | |
| fn = TOOL_MAP.get(tool_name) | |
| if fn is None: | |
| return f"[Lỗi] Tool '{tool_name}' không tồn tại trong registry." | |
| try: | |
| result = fn.invoke(tool_args) | |
| return result if isinstance(result, str) else json.dumps(result, ensure_ascii=False) | |
| except Exception as e: | |
| logger.exception("[ToolNode] Tool '%s' gặp lỗi", tool_name) | |
| return f"[Lỗi] Tool '{tool_name}': {e}" | |
| def tool_node(state: AgentState) -> AgentState: | |
| """Thực thi tất cả tool_calls trong AIMessage cuối, sinh ToolMessage(s).""" | |
| last_msg = state["messages"][-1] | |
| if not isinstance(last_msg, AIMessage) or not last_msg.tool_calls: | |
| logger.warning("[ToolNode] Không tìm thấy tool_call.") | |
| return state | |
| tool_messages: list[ToolMessage] = [] | |
| for tc in last_msg.tool_calls: | |
| name, args, cid = tc["name"], tc.get("args", {}), tc["id"] | |
| logger.info("[ToolNode] Thực thi tool='%s' args=%s", name, args) | |
| result = _run_tool(name, args) | |
| logger.info("[ToolNode] Tool='%s' → %d ký tự", name, len(result)) | |
| tool_messages.append(ToolMessage(content=result, tool_call_id=cid, name=name)) | |
| return {**state, "messages": tool_messages} | |
| # ════════════════════════════════════════════════════════════════════ | |
| # NODE 5 — FinalResponseNode | |
| # ════════════════════════════════════════════════════════════════════ | |
| def _extract_tool_results(state: AgentState) -> str: | |
| parts = [f"[{m.name}]\n{m.content}" for m in state["messages"] if isinstance(m, ToolMessage)] | |
| return "\n\n".join(parts) if parts else "(Không có kết quả từ tool)" | |
| # ════════════════════════════════════════════════════════════════════ | |
| # NODE 6 — ImageResponseNode | |
| # ════════════════════════════════════════════════════════════════════ | |
| def image_response_node(state: AgentState) -> AgentState: | |
| """Nhận HumanMessage chứa ảnh + text, gọi multimodal LLM sinh câu trả lời.""" | |
| logger.info("[ImageResponseNode] Xử lý ảnh cho %s", state["sender_id"]) | |
| msgs = list(state["messages"]) | |
| if cp := state.get("custom_prompt"): | |
| msgs.insert(0, SystemMessage(content=cp)) | |
| response = multimodal_llm.invoke(msgs) | |
| answer = response.content | |
| logger.info("[ImageResponseNode] Hoàn thành (%d ký tự)", len(answer)) | |
| return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer} | |
| # ════════════════════════════════════════════════════════════════════ | |
| # NODE 5 — FinalResponseNode | |
| # ════════════════════════════════════════════════════════════════════ | |
| def final_response_node(state: AgentState) -> AgentState: | |
| """Tổng hợp ToolMessage(s) và sinh câu trả lời cuối cùng.""" | |
| logger.info("[FinalResponseNode] Tổng hợp câu trả lời...") | |
| msgs = final_response_prompt.format_messages( | |
| sender_id=state["sender_id"], | |
| raw_query=state["raw_query"], | |
| tool_results=_extract_tool_results(state), | |
| ) | |
| if cp := state.get("custom_prompt"): | |
| msgs[0] = SystemMessage(content=msgs[0].content + f"\n\n═══ YÊU CẦU BỔ SUNG ═══\n{cp}") | |
| response = llm.invoke(msgs) | |
| answer = response.content | |
| logger.info("[FinalResponseNode] Hoàn thành (%d ký tự)", len(answer)) | |
| return {**state, "messages": [AIMessage(content=answer)], "final_answer": answer} | |