"""Plan-and-Execute agent router using LangGraph. Replaces the flat ReAct loop with a structured three-phase pipeline: 1. **Planner** — analyses the user query and produces an ordered list of steps (e.g. "search for exam rules", "search for grading policy", "compare both"). 2. **Executor** — runs each step via a short ReAct sub-graph that has access to all retrieval tools. 3. **Synthesizer** — collects the results from all executed steps and produces a final, cited answer. The separation gives the pipeline *predictable structure* while still allowing the executor to reason freely within each step. """ import json import logging import re from collections.abc import Generator from typing import TypedDict from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.runnables import Runnable from langgraph.graph import END, StateGraph from langgraph.prebuilt import create_react_agent from src.agent.memory import ConversationMemory from src.agent.prompts import get_prompt from src.agent.token_budget import measure as _measure_tokens from src.agent.tools import ToolResultStore, detect_document_languages, make_retrieval_tools from src.models import GenerationResponse, IntentType, PipelineDetails from src.retrieval.hybrid import HybridRetriever from src.retrieval.reranker import Reranker from src.retrieval.vector_store import VectorStore logger = logging.getLogger(__name__) _MAX_STEPS = 6 # ------------------------------------------------------------------ # Prompts (loaded from src/agent/prompts/*.yaml) # ------------------------------------------------------------------ _PLANNER_PROMPT = get_prompt("planner").template _EXECUTOR_SYSTEM = get_prompt("executor_system").template _SYNTHESIZER_PROMPT = get_prompt("synthesizer").template # ------------------------------------------------------------------ # Graph state # ------------------------------------------------------------------ class PlanStep(TypedDict): """A single step in the execution plan.""" action: str detail: str class PlanExecState(TypedDict): """State for the Plan-and-Execute graph. Attributes: query: The user's original question. top_k: Number of results per retrieval call. plan: Ordered list of steps produced by the planner. step_index: Index of the next step to execute. step_results: List of (step_description, result_text) pairs. answer: Final synthesised answer. """ query: str top_k: int plan: list[PlanStep] step_index: int step_results: list[tuple[str, str]] answer: str # ------------------------------------------------------------------ # Router class # ------------------------------------------------------------------ class PlanAndExecuteRouter: """Routes queries through a Plan-and-Execute pipeline. Graph topology:: plan → should_execute? ─┬─ yes → execute_step → should_execute? └─ no → synthesize → END """ def __init__( self, llm: Runnable, hybrid_retriever: HybridRetriever, reranker: Reranker, vector_store: VectorStore, default_top_k: int = 5, memory: ConversationMemory | None = None, document_languages: list[str] | None = None, token_budget_enabled: bool = False, ) -> None: """Initialise the Plan-and-Execute router. Args: llm: LLM with tool-calling support. hybrid_retriever: HybridRetriever instance. reranker: Reranker instance. vector_store: VectorStore instance. default_top_k: Default number of results per retrieval call. memory: Optional ConversationMemory for multi-turn context. When provided, prior conversation history is injected into planner and synthesizer prompts, and each completed turn is automatically recorded. document_languages: Optional pre-detected list of corpus languages. When omitted, the router lazily detects them from the vector store on first use via the LLM. """ self._llm = llm self._hybrid_retriever = hybrid_retriever self._reranker = reranker self._vector_store = vector_store self._default_top_k = default_top_k self._memory = memory or ConversationMemory() self._document_languages: list[str] | None = ( list(document_languages) if document_languages else None ) self._token_budget_enabled = token_budget_enabled def _ensure_document_languages(self) -> list[str]: """Lazily detect and cache the document corpus languages via the LLM. Returns: List of detected language names (e.g. ``["Danish"]`` or ``["Danish", "English"]``). Empty list when the corpus is empty or no readable text could be sampled. """ if self._document_languages is not None: return self._document_languages self._document_languages = detect_document_languages(self._vector_store, self._llm) if self._document_languages: logger.info("Detected document corpus languages: %s", self._document_languages) return self._document_languages # ------------------------------------------------------------------ # Node functions # ------------------------------------------------------------------ def _plan_node(self, state: PlanExecState) -> dict: """Generate an execution plan from the user query.""" history = self._memory.format_history() history_section = "" if history: history_section = ( f"Conversation history (for context on follow-up questions):\n" f"{history}\n\n" ) prompt = _PLANNER_PROMPT + history_section + f'Question: "{state["query"]}"' _measure_tokens("planner", prompt, enabled=self._token_budget_enabled) raw = _extract_content(self._llm.invoke(prompt)) logger.info("Planner raw output: %s", raw) plan = _parse_plan(raw) logger.info("Plan: %d steps — %s", len(plan), plan) return {"plan": plan, "step_index": 0, "step_results": []} @staticmethod def _should_execute(state: PlanExecState) -> str: """Decide whether to execute the next step or synthesize.""" if state["step_index"] < len(state["plan"]) and state["step_index"] < _MAX_STEPS: return "execute" return "synthesize" def _make_execute_step_node(self, store: ToolResultStore): """Create an execute_step node closure bound to a request-scoped store. Args: store: ToolResultStore for this specific request. Returns: Node function for LangGraph. """ def _execute_step_node(state: PlanExecState) -> dict: idx = state["step_index"] step = state["plan"][idx] step_desc = f'{step["action"]}: {step["detail"]}' logger.info("Executing step %d/%d: %s", idx + 1, len(state["plan"]), step_desc) tools = make_retrieval_tools( self._hybrid_retriever, self._reranker, self._vector_store, store, self._default_top_k, llm_chain=self._llm, document_languages=self._ensure_document_languages(), ) sub_agent = create_react_agent(self._llm, tools) step_prompt = ( f'Step to execute: {step_desc}\n\n' f'Original user question (for context): {state["query"]}' ) result = sub_agent.invoke({ "messages": [ SystemMessage(content=_EXECUTOR_SYSTEM), HumanMessage(content=step_prompt), ] }) answer = _extract_last_ai_text(result.get("messages", [])) logger.info("Step %d result: %s", idx + 1, answer[:200]) new_results = list(state["step_results"]) + [(step_desc, answer)] return {"step_index": idx + 1, "step_results": new_results} return _execute_step_node def _synthesize_node(self, state: PlanExecState) -> dict: """Synthesize a final answer from all step results.""" step_texts = [] for i, (desc, result) in enumerate(state["step_results"], 1): step_texts.append(f"### Step {i}: {desc}\n{result}") gathered = "\n\n".join(step_texts) history = self._memory.format_history() history_section = "" if history: history_section = ( f"Prior conversation:\n{history}\n\n" ) prompt = ( f"{_SYNTHESIZER_PROMPT}" f"{history_section}" f"Original question: {state['query']}\n\n" f"Research results:\n{gathered}\n\n" f"Answer:" ) _measure_tokens("synthesizer", prompt, enabled=self._token_budget_enabled) answer = _extract_content(self._llm.invoke(prompt)) logger.info("Synthesized final answer (%d chars)", len(answer)) return {"answer": answer} # ------------------------------------------------------------------ # Graph construction # ------------------------------------------------------------------ def _build_graph(self, store: ToolResultStore) -> object: """Build the Plan-and-Execute LangGraph. Args: store: Request-scoped ToolResultStore for this invocation. Returns: Compiled LangGraph. """ graph: StateGraph = StateGraph(PlanExecState) graph.add_node("plan", self._plan_node) graph.add_node("execute_step", self._make_execute_step_node(store)) graph.add_node("synthesize", self._synthesize_node) graph.set_entry_point("plan") graph.add_conditional_edges( "plan", self._should_execute, {"execute": "execute_step", "synthesize": "synthesize"}, ) graph.add_conditional_edges( "execute_step", self._should_execute, {"execute": "execute_step", "synthesize": "synthesize"}, ) graph.add_edge("synthesize", END) return graph.compile() # ------------------------------------------------------------------ # Public interface (mirrors QueryRouter) # ------------------------------------------------------------------ def route( self, query: str, top_k: int, memory: ConversationMemory | None = None, ) -> GenerationResponse: """Route a query through the Plan-and-Execute pipeline. Args: query: The user's natural language query. top_k: Number of top documents to retrieve per tool call. memory: Optional per-session memory override. When provided this memory is used instead of the router's default memory. Returns: GenerationResponse with answer, sources, intent, and confidence. """ original_memory = self._memory if memory is not None: self._memory = memory try: logger.info("PlanExec routing query: %s", query) store = ToolResultStore() initial_state = PlanExecState( query=query, top_k=top_k, plan=[], step_index=0, step_results=[], answer="", ) graph = self._build_graph(store) final_state: PlanExecState = graph.invoke(initial_state) sources = store.retrieved[:top_k] confidence = max((r.score for r in sources), default=0.0) plan_step_strs = [ f'{s["action"]}: {s["detail"]}' for s in final_state.get("plan", []) ] tool_call_strs = [f"{name}: {arg}" for name, arg in store.tool_calls] response = GenerationResponse( answer=final_state["answer"], sources=sources, intent=IntentType.RAG if sources else IntentType.FACTUAL, confidence=confidence, pipeline_details=PipelineDetails( original_query=query, retrieval_query=", ".join( q for name, q in store.tool_calls if name == "hybrid_search" ) or query, dense_results=store.dense_results, sparse_results=store.sparse_results, fused_results=store.fused_results, reranked_results=sources, plan_steps=plan_step_strs, tool_calls=tool_call_strs, ), ) self._memory.add_turn(query, response.answer, sources) return response finally: self._memory = original_memory def route_stream( self, query: str, top_k: int, memory: ConversationMemory | None = None, ) -> Generator[dict, None, None]: """Stream Plan-and-Execute events step by step. Yields event dicts with step types: - ``plan`` — plan was generated; carries ``steps``. - ``execute_step`` — a step was executed; carries ``step_index``, ``step_desc``, ``result_preview``. - ``synthesize`` — final answer generated. - ``done`` — final event with full result payload. Args: query: User query. top_k: Number of results to retrieve per tool call. memory: Optional per-session memory override. Yields: Step event dicts. """ original_memory = self._memory if memory is not None: self._memory = memory try: yield from self._route_stream_inner(query, top_k) finally: self._memory = original_memory def _route_stream_inner(self, query: str, top_k: int) -> Generator[dict, None, None]: """Internal streaming implementation.""" store = ToolResultStore() initial_state = PlanExecState( query=query, top_k=top_k, plan=[], step_index=0, step_results=[], answer="", ) graph = self._build_graph(store) accumulated: dict = dict(initial_state) for chunk in graph.stream(initial_state, stream_mode="updates"): for node_name, update in chunk.items(): if update is None: continue accumulated.update(update) if node_name == "plan": yield { "step": "plan", "steps": [ f'{s["action"]}: {s["detail"]}' for s in update.get("plan", []) ], } elif node_name == "execute_step": results = update.get("step_results", []) if results: last_desc, last_result = results[-1] yield { "step": "execute_step", "step_index": update.get("step_index", 0), "step_desc": last_desc, "result_preview": last_result[:300], } elif node_name == "synthesize": yield {"step": "synthesize"} sources = store.retrieved[:top_k] confidence = max((r.score for r in sources), default=0.0) answer = accumulated.get("answer", "") self._memory.add_turn(query, answer, sources) yield { "step": "done", "result": { "answer": answer, "sources": [r.to_dict() for r in sources], "intent": (IntentType.RAG if sources else IntentType.FACTUAL).value, "confidence": confidence, "pipeline_details": { "original_query": query, "retrieval_query": ", ".join( q for name, q in store.tool_calls if name == "hybrid_search" ) or query, "detected_language": "", "translated": False, "dense_results": [r.to_dict(include_text=False) for r in store.dense_results], "sparse_results": [r.to_dict(include_text=False) for r in store.sparse_results], "fused_results": [r.to_dict(include_text=False) for r in store.fused_results], "reranked_results": [r.to_dict(include_text=False) for r in sources], "plan_steps": [ f'{s["action"]}: {s["detail"]}' for s in accumulated.get("plan", []) ], "tool_calls": [f"{n}: {a}" for n, a in store.tool_calls], }, }, } # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ _THINK_CLOSED_RE = re.compile(r".*?\s*", re.DOTALL) _THINK_UNCLOSED_RE = re.compile(r".*", re.DOTALL) def _strip_think(text: str) -> str: """Remove ```` blocks — both closed and unclosed. Some models (Qwen3) always emit ``...``; others may leave the tag unclosed. This handles both cases. """ text = _THINK_CLOSED_RE.sub("", text) text = _THINK_UNCLOSED_RE.sub("", text) return text.strip() def _extract_content(result: object) -> str: """Extract plain text from an LLM invoke result. Handles: - AIMessage with ``content: str`` - AIMessage with ``content: list[str | dict]`` (some providers) - Plain strings (e.g. from StrOutputParser or test mocks) Args: result: Return value of ``llm.invoke()`` or ``chain.invoke()``. Returns: Cleaned text with ```` blocks removed. """ if hasattr(result, "content"): content = result.content else: content = result if isinstance(content, list): parts: list[str] = [] for block in content: if isinstance(block, str): parts.append(block) elif isinstance(block, dict) and "text" in block: parts.append(block["text"]) text = "\n".join(parts) else: text = str(content) return _strip_think(text) def _parse_plan(raw: str) -> list[PlanStep]: """Parse the planner's JSON output into a list of PlanStep dicts. Robust against markdown fences, trailing text, and minor formatting issues. Args: raw: Raw LLM output expected to contain a JSON array. Returns: List of PlanStep dicts. Falls back to a single search step on failure. """ # Strip markdown code fences if present cleaned = raw.strip() if cleaned.startswith("```"): lines = cleaned.splitlines() # Remove opening and closing fences lines = [line for line in lines if not line.strip().startswith("```")] cleaned = "\n".join(lines).strip() try: parsed = json.loads(cleaned) except json.JSONDecodeError: # Try to extract a JSON array from the text start = cleaned.find("[") end = cleaned.rfind("]") if start != -1 and end != -1: try: parsed = json.loads(cleaned[start:end + 1]) except json.JSONDecodeError: logger.warning("Failed to parse plan, falling back to single search") return [PlanStep(action="search", detail=cleaned[:200])] else: logger.warning("No JSON array found in plan output, falling back") return [PlanStep(action="search", detail=cleaned[:200])] if not isinstance(parsed, list): logger.warning("Plan is not a list, wrapping") parsed = [parsed] steps: list[PlanStep] = [] for item in parsed: if isinstance(item, dict) and "action" in item and "detail" in item: steps.append(PlanStep(action=str(item["action"]), detail=str(item["detail"]))) else: logger.warning("Skipping malformed plan step: %s", item) if not steps: return [PlanStep(action="search", detail="general search")] return steps def _extract_last_ai_text(messages: list) -> str: """Return the text content of the last non-tool-call AI message. Args: messages: List of LangChain message objects. Returns: The extracted text, or empty string if none found. """ for msg in reversed(messages): if ( isinstance(msg, AIMessage) and msg.content and not getattr(msg, "tool_calls", None) ): return _extract_content(msg) return ""