"""LangGraph router that sends each user turn to a SQL, graph, tools, or general handler."""

from __future__ import annotations

import re
from typing import Annotated, Any, Dict, List, Literal, TypedDict

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from orchestrator.factories import get_llm
from orchestrator.graph_agent import graph_answer
from orchestrator.settings import Settings
from orchestrator.sql_agent import sql_answer
from orchestrator.tools import make_web_wiki_arxiv_tools

Route = Literal["sql", "graph", "tools", "general"]

# Valid route labels, plus a pattern used to recover a label from a chatty router reply.
_ROUTE_LABELS = ("sql", "graph", "tools", "general")
_ROUTE_LABEL_RE = re.compile(r"\b(sql|graph|tools|general)\b")


class RouterState(TypedDict, total=False):
    """State shared by the router graph and the tools sub-graph."""

    # Conversation history; node updates are merged in by the add_messages reducer.
    messages: Annotated[list[BaseMessage], add_messages]
    # Route label written by the router node and re-asserted by the handler node.
    route: Route
    # Free-form diagnostics accumulated by the nodes (see _merge_debug).
    debug: Dict[str, Any]


def _safe_text(x: Any) -> str:
    """Coerce message content (or any other value) to plain text.

    ``None`` becomes ``""`` and strings pass through unchanged. A list is
    treated as a list of content blocks: plain strings and the ``text`` field
    of ``{"type": "text", ...}`` dicts are concatenated, every other block is
    ignored. Any other value falls back to ``str(x)``.
    """
    if x is None:
        return ""
    if isinstance(x, str):
        return x
    if isinstance(x, list):
        # Without this branch str() would leak the list's Python repr into
        # prompts, router labels and user-facing answers.
        parts: List[str] = []
        for block in x:
            if isinstance(block, str):
                parts.append(block)
            elif (
                isinstance(block, dict)
                and block.get("type") == "text"
                and isinstance(block.get("text"), str)
            ):
                parts.append(block["text"])
        return "".join(parts)
    return str(x)


def _last_user_text(messages: list[BaseMessage]) -> str:
    """Return the stripped text of the most recent HumanMessage, or ``""`` if there is none."""
    for m in reversed(messages):
        if isinstance(m, HumanMessage):
            return _safe_text(m.content).strip()
    return ""


def _messages_to_transcript(messages: list[BaseMessage], max_turns: int = 8) -> str:
    """
    Build a lightweight transcript from the last N Human/AI messages.

    We intentionally skip tool messages to keep prompts stable. Messages whose
    text is empty (for example assistant messages that only carry tool calls)
    are skipped as well, so they neither add blank "Assistant:" lines nor use
    up the budget of ``max_turns * 2`` messages. A non-positive ``max_turns``
    yields an empty transcript.
    """
    limit = max_turns * 2  # ~turns * 2 messages
    if limit <= 0:
        return ""

    lines: List[str] = []
    for m in reversed(messages):
        if isinstance(m, HumanMessage):
            speaker = "User"
        elif isinstance(m, AIMessage):
            speaker = "Assistant"
        else:
            continue
        text = _safe_text(m.content)
        if not text.strip():
            continue
        lines.append(f"{speaker}: {text}")
        if len(lines) >= limit:
            break

    # Collected newest-first; restore chronological order.
    lines.reverse()
    return "\n".join(lines).strip()


def _merge_debug(state: RouterState, **kv: Any) -> Dict[str, Any]:
    """Return a copy of ``state["debug"]`` updated with every non-None keyword value."""
    dbg = dict(state.get("debug") or {})
    dbg.update({k: v for k, v in kv.items() if v is not None})
    return dbg


def _extract_tool_names(messages: list[BaseMessage]) -> List[str]:
    """
    Extract tool names from AIMessage.tool_calls across LangChain variants.

    Names are returned de-duplicated, in order of first appearance.
    """
    # A dict keeps insertion order and gives O(1) de-duplication.
    names: Dict[str, None] = {}
    for m in messages:
        if not isinstance(m, AIMessage):
            continue
        for tc in getattr(m, "tool_calls", None) or []:
            # tc may be dict-like or object-like
            n = tc.get("name") if isinstance(tc, dict) else getattr(tc, "name", None)
            if n:
                names.setdefault(str(n), None)
    return list(names)


def _plain_conversation(messages: list[BaseMessage]) -> List[BaseMessage]:
    """Return the Human/AI turns of *messages* in a form that can be resent to a chat model.

    Non-Human/AI messages (e.g. tool results) are dropped. An assistant message
    that carries tool calls is replaced by a text-only copy, or dropped when it
    has no text: once the tool results are removed, forwarding the tool calls
    would leave them unanswered, which chat APIs such as OpenAI's reject.
    """
    convo: List[BaseMessage] = []
    for m in messages:
        if isinstance(m, HumanMessage):
            convo.append(m)
        elif isinstance(m, AIMessage):
            if not getattr(m, "tool_calls", None):
                convo.append(m)
                continue
            text = _safe_text(m.content)
            if text.strip():
                convo.append(AIMessage(content=text))
    return convo


def _parse_route_label(raw: Any) -> Route:
    """Map the router model's raw reply to a route label.

    An exact label (ignoring case and surrounding whitespace) is used as is.
    Otherwise the reply is searched for known labels as whole words, so
    replies such as ``"sql."`` or ``"Route: sql"`` still resolve. Replies
    naming no label, or more than one distinct label, fall back to
    ``"general"``.
    """
    text = _safe_text(raw).strip().lower()
    if text in _ROUTE_LABELS:
        return text
    found = set(_ROUTE_LABEL_RE.findall(text))
    if len(found) == 1:
        return found.pop()
    return "general"


def _rewrite_to_standalone(llm, messages: list[BaseMessage]) -> str:
    """
    If the user asks a follow-up like "show them", rewrite into a standalone question.

    Returns ``""`` when there is no user message, the latest question unchanged
    when it is the only user message, and otherwise the LLM's rewrite (falling
    back to the latest question if the rewrite comes back empty).
    """
    question = _last_user_text(messages)
    if not question:
        return ""

    # If there's only one user message total, no rewrite needed.
    num_user_msgs = sum(1 for m in messages if isinstance(m, HumanMessage))
    if num_user_msgs <= 1:
        return question

    transcript = _messages_to_transcript(messages, max_turns=8)
    prompt = (
        "Rewrite the user's latest question into a standalone question.\n"
        "Do NOT answer the question.\n\n"
        "Conversation:\n"
        f"{transcript}\n\n"
        "Latest user question:\n"
        f"{question}\n\n"
        "Standalone question:"
    )
    msg = llm.invoke(
        [
            SystemMessage(content="You rewrite follow-up questions into standalone questions."),
            HumanMessage(content=prompt),
        ]
    )
    rewritten = _safe_text(getattr(msg, "content", "")).strip()
    return rewritten or question


def build_tools_agent_graph(settings: Settings):
    """Compile the tool-using agent sub-graph.

    The graph loops between an ``assistant`` node (the LLM with the
    web/wiki/arXiv tools bound) and a ``tools`` node; ``tools_condition``
    decides after each assistant step whether the tools node runs next.
    """
    tools = make_web_wiki_arxiv_tools(
        wiki_chars=settings.wiki_doc_content_chars_max,
    )
    llm = get_llm(settings, temperature=0).bind_tools(tools)

    def assistant(state: RouterState):
        msg = llm.invoke(state["messages"])
        return {"messages": [msg]}

    g = StateGraph(RouterState)
    g.add_node("assistant", assistant)
    g.add_node("tools", ToolNode(tools))
    g.add_edge(START, "assistant")
    g.add_conditional_edges("assistant", tools_condition)
    g.add_edge("tools", "assistant")
    return g.compile()


def build_router_graph(settings: Settings):
    """Compile the top-level graph: a router node followed by exactly one handler node.

    The router asks the LLM for one of the ``Route`` labels and the matching
    node (``sql``, ``graph``, ``tools`` or ``general``) produces the answer.
    Every node records diagnostics under the ``debug`` state key.
    """
    tools_graph = build_tools_agent_graph(settings)
    llm_router = get_llm(settings, temperature=0)

    route_prompt = (
        "You are a router for a multi-agent system.\n"
        "Choose exactly ONE route label from: sql, graph, tools, general.\n\n"
        "Routing rules:\n"
        "- sql: querying a relational database (tables/rows, SQL, students DB, counts, filters).\n"
        "- graph: querying a Neo4j graph database (nodes/relationships, Cypher).\n"
        "- tools: needs external knowledge / searching (Wikipedia/arXiv/web) or tool use.\n"
        "- general: conceptual explanation or chat that doesn't need tools/DB queries.\n\n"
        "Return ONLY the label.\n"
    )

    def router(state: RouterState):
        """Ask the LLM for a route label; unrecognised replies fall back to "general"."""
        msgs = state.get("messages", [])
        q = _last_user_text(msgs)
        transcript = _messages_to_transcript(msgs, max_turns=8)

        payload = (
            "Conversation transcript:\n"
            f"{transcript}\n\n"
            "Latest user question:\n"
            f"{q}"
        )

        msg = llm_router.invoke(
            [SystemMessage(content=route_prompt), HumanMessage(content=payload)]
        )
        # Tolerates decorated replies ("sql.", "Route: sql") instead of
        # silently misrouting them to "general".
        label = _parse_route_label(msg.content)

        dbg = _merge_debug(state, router_label=label, router_raw=msg.content, routed_to=label)
        return {"route": label, "debug": dbg}

    def sql_node(state: RouterState):
        """Answer via sql_answer using a standalone rewrite of the latest question."""
        standalone = _rewrite_to_standalone(llm_router, state["messages"])
        out = sql_answer(settings, standalone)
        dbg = _merge_debug(state, routed_to="sql", sql=out, standalone_question=standalone)
        return {"route": "sql", "messages": [AIMessage(content=str(out["answer"]))], "debug": dbg}

    def graph_node(state: RouterState):
        """Answer via graph_answer using a standalone rewrite of the latest question."""
        standalone = _rewrite_to_standalone(llm_router, state["messages"])
        out = graph_answer(settings, standalone)
        dbg = _merge_debug(
            state,
            routed_to="graph",
            graph=out.get("debug", {}),
            standalone_question=standalone,
        )
        return {"route": "graph", "messages": [AIMessage(content=str(out["answer"]))], "debug": dbg}

    def tools_node(state: RouterState):
        """Run the tools sub-graph on the conversation and return its messages."""
        in_msgs = state["messages"]
        out_state = tools_graph.invoke({"messages": in_msgs})
        out_msgs = out_state.get("messages", [])

        # Report only the tools called during this turn. The sub-graph returns
        # the whole history, so messages whose id was already present in the
        # input are excluded. Messages without an id are treated as new.
        known_ids = {getattr(m, "id", None) for m in in_msgs} - {None}
        new_msgs = [m for m in out_msgs if getattr(m, "id", None) not in known_ids]
        tools_used = _extract_tool_names(new_msgs)

        dbg = _merge_debug(
            state,
            routed_to="tools",
            tools_used=tools_used,
            tools_graph={"messages_len": len(out_msgs)},
        )
        return {"route": "tools", "messages": out_msgs, "debug": dbg}

    def general_node(state: RouterState):
        """Answer directly with the LLM from the Human/AI conversation so far."""
        # Use the conversation itself (not just last message), with tool
        # results removed and tool calls stripped from assistant messages.
        convo = _plain_conversation(state["messages"])
        msg = llm_router.invoke([SystemMessage(content="You are a helpful assistant.")] + convo)
        dbg = _merge_debug(state, routed_to="general")
        return {
            "route": "general",
            "messages": [AIMessage(content=_safe_text(msg.content))],
            "debug": dbg,
        }

    g = StateGraph(RouterState)
    g.add_node("router", router)
    g.add_node("sql", sql_node)
    g.add_node("graph", graph_node)
    g.add_node("tools", tools_node)
    g.add_node("general", general_node)

    g.add_edge(START, "router")
    g.add_conditional_edges(
        "router",
        lambda s: s["route"],
        {"sql": "sql", "graph": "graph", "tools": "tools", "general": "general"},
    )
    g.add_edge("sql", END)
    g.add_edge("graph", END)
    g.add_edge("tools", END)
    g.add_edge("general", END)
    return g.compile()