"""Plan-and-Execute agent router using LangGraph.
Replaces the flat ReAct loop with a structured three-phase pipeline:
1. **Planner** — analyses the user query and produces an ordered list of
steps (e.g. "search for exam rules", "search for grading policy",
"compare both").
2. **Executor** — runs each step via a short ReAct sub-graph that has
access to all retrieval tools.
3. **Synthesizer** — collects the results from all executed steps and
produces a final, cited answer.
The separation gives the pipeline *predictable structure* while still
allowing the executor to reason freely within each step.
"""
import json
import logging
import re
from collections.abc import Generator
from typing import TypedDict
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import Runnable
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import create_react_agent
from src.agent.memory import ConversationMemory
from src.agent.prompts import get_prompt
from src.agent.token_budget import measure as _measure_tokens
from src.agent.tools import ToolResultStore, detect_document_languages, make_retrieval_tools
from src.models import GenerationResponse, IntentType, PipelineDetails
from src.retrieval.hybrid import HybridRetriever
from src.retrieval.reranker import Reranker
from src.retrieval.vector_store import VectorStore
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Hard cap on how many plan steps are executed for one query: once the step
# index reaches this value the graph moves on to synthesis, even if the
# planner produced more steps.
_MAX_STEPS = 6
# ------------------------------------------------------------------
# Prompts (loaded from src/agent/prompts/*.yaml)
# ------------------------------------------------------------------
# Template text for each pipeline phase, resolved once at import time.
_PLANNER_PROMPT = get_prompt("planner").template
_EXECUTOR_SYSTEM = get_prompt("executor_system").template
_SYNTHESIZER_PROMPT = get_prompt("synthesizer").template
# ------------------------------------------------------------------
# Graph state
# ------------------------------------------------------------------
class PlanStep(TypedDict):
    """One entry of the planner's output: what to do, and on what."""

    # Short label for the kind of work; the fallback plan uses "search".
    action: str
    # Free-text argument for the action (for a search, the thing to look up).
    detail: str
class PlanExecState(TypedDict):
    """Shared state that flows through the Plan-and-Execute graph.

    The planner fills ``plan`` and resets the counters, each executed step
    appends to ``step_results`` and advances ``step_index``, and the
    synthesizer writes ``answer``.
    """

    # The user's original question.
    query: str
    # Number of results per retrieval call.
    top_k: int
    # Ordered list of steps produced by the planner.
    plan: list[PlanStep]
    # Index of the next step to execute.
    step_index: int
    # (step_description, result_text) pairs, in execution order.
    step_results: list[tuple[str, str]]
    # Final synthesised answer.
    answer: str
# ------------------------------------------------------------------
# Router class
# ------------------------------------------------------------------
class PlanAndExecuteRouter:
    """Routes queries through a Plan-and-Execute pipeline.

    Graph topology::

        plan → should_execute? ─┬─ yes → execute_step → should_execute?
                                └─ no  → synthesize → END

    A fresh graph and ``ToolResultStore`` are built for every request. The
    conversation memory used by a request is passed to the graph nodes
    explicitly, so a per-session memory override never touches shared
    instance state.
    """

    def __init__(
        self,
        llm: Runnable,
        hybrid_retriever: HybridRetriever,
        reranker: Reranker,
        vector_store: VectorStore,
        default_top_k: int = 5,
        memory: ConversationMemory | None = None,
        document_languages: list[str] | None = None,
        token_budget_enabled: bool = False,
    ) -> None:
        """Initialise the Plan-and-Execute router.

        Args:
            llm: LLM with tool-calling support.
            hybrid_retriever: HybridRetriever instance.
            reranker: Reranker instance.
            vector_store: VectorStore instance.
            default_top_k: Number of results per retrieval call, used when a
                request does not carry a positive ``top_k`` of its own.
            memory: Optional ConversationMemory for multi-turn context.
                Prior conversation history is injected into planner and
                synthesizer prompts, and each completed turn is recorded.
                When omitted, a fresh ConversationMemory is created.
            document_languages: Optional pre-detected list of corpus
                languages. When omitted (or empty), the router lazily
                detects them from the vector store on first use via the LLM.
            token_budget_enabled: Forwarded as ``enabled`` to the
                token-budget measurement of the planner and synthesizer
                prompts.
        """
        self._llm = llm
        self._hybrid_retriever = hybrid_retriever
        self._reranker = reranker
        self._vector_store = vector_store
        self._default_top_k = default_top_k
        # Explicit None check rather than ``memory or ...``: if the memory
        # type defines truthiness (e.g. via ``__len__``), an empty memory
        # supplied by the caller must still be the one that gets used.
        self._memory = memory if memory is not None else ConversationMemory()
        self._document_languages: list[str] | None = (
            list(document_languages) if document_languages else None
        )
        self._token_budget_enabled = token_budget_enabled

    def _ensure_document_languages(self) -> list[str]:
        """Lazily detect and cache the document corpus languages via the LLM.

        Returns:
            List of detected language names (e.g. ``["Danish"]`` or
            ``["Danish", "English"]``). Empty list when the corpus is empty
            or no readable text could be sampled.
        """
        if self._document_languages is not None:
            return self._document_languages
        self._document_languages = detect_document_languages(self._vector_store, self._llm)
        if self._document_languages:
            logger.info("Detected document corpus languages: %s", self._document_languages)
        return self._document_languages

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _initial_state(query: str, top_k: int) -> PlanExecState:
        """Build the empty graph state for a new request."""
        return PlanExecState(
            query=query,
            top_k=top_k,
            plan=[],
            step_index=0,
            step_results=[],
            answer="",
        )

    @staticmethod
    def _format_plan_steps(plan: list[PlanStep]) -> list[str]:
        """Render plan steps as ``"action: detail"`` strings."""
        return [f'{s["action"]}: {s["detail"]}' for s in plan]

    @staticmethod
    def _retrieval_query(store: ToolResultStore, query: str) -> str:
        """Join all ``hybrid_search`` arguments, falling back to *query*."""
        searched = ", ".join(q for name, q in store.tool_calls if name == "hybrid_search")
        return searched or query

    # ------------------------------------------------------------------
    # Node functions
    # ------------------------------------------------------------------
    def _plan_node(
        self,
        state: PlanExecState,
        memory: ConversationMemory | None = None,
    ) -> dict:
        """Generate an execution plan from the user query.

        Args:
            state: Current graph state; only ``query`` is read.
            memory: Memory whose history is injected into the prompt.
                Defaults to the router's own memory.

        Returns:
            State update with the parsed ``plan`` and reset counters.
        """
        active_memory = self._memory if memory is None else memory
        history = active_memory.format_history()
        history_section = ""
        if history:
            history_section = (
                f"Conversation history (for context on follow-up questions):\n"
                f"{history}\n\n"
            )
        prompt = _PLANNER_PROMPT + history_section + f'Question: "{state["query"]}"'
        _measure_tokens("planner", prompt, enabled=self._token_budget_enabled)
        raw = _extract_content(self._llm.invoke(prompt))
        logger.info("Planner raw output: %s", raw)
        plan = _parse_plan(raw)
        logger.info("Plan: %d steps — %s", len(plan), plan)
        return {"plan": plan, "step_index": 0, "step_results": []}

    @staticmethod
    def _should_execute(state: PlanExecState) -> str:
        """Decide whether to execute the next step or synthesize.

        Returns:
            ``"execute"`` while unexecuted steps remain and the
            ``_MAX_STEPS`` cap has not been reached, else ``"synthesize"``.
        """
        if state["step_index"] < len(state["plan"]) and state["step_index"] < _MAX_STEPS:
            return "execute"
        return "synthesize"

    def _make_execute_step_node(self, store: ToolResultStore):
        """Create an execute_step node closure bound to a request-scoped store.

        Args:
            store: ToolResultStore for this specific request.

        Returns:
            Node function for LangGraph.
        """

        def _execute_step_node(state: PlanExecState) -> dict:
            idx = state["step_index"]
            step = state["plan"][idx]
            step_desc = f'{step["action"]}: {step["detail"]}'
            logger.info("Executing step %d/%d: %s", idx + 1, len(state["plan"]), step_desc)

            # Honour the per-request top_k carried in the state; previously
            # it was stored but never read, so the retrieval tools always
            # used the router-wide default.
            requested_top_k = state.get("top_k")
            top_k = (
                requested_top_k
                if requested_top_k and requested_top_k > 0
                else self._default_top_k
            )
            tools = make_retrieval_tools(
                self._hybrid_retriever,
                self._reranker,
                self._vector_store,
                store,
                top_k,
                llm_chain=self._llm,
                document_languages=self._ensure_document_languages(),
            )
            sub_agent = create_react_agent(self._llm, tools)
            step_prompt = (
                f'Step to execute: {step_desc}\n\n'
                f'Original user question (for context): {state["query"]}'
            )
            result = sub_agent.invoke({
                "messages": [
                    SystemMessage(content=_EXECUTOR_SYSTEM),
                    HumanMessage(content=step_prompt),
                ]
            })
            answer = _extract_last_ai_text(result.get("messages", []))
            logger.info("Step %d result: %s", idx + 1, answer[:200])
            # Copy rather than mutate the list held by the incoming state.
            new_results = list(state["step_results"]) + [(step_desc, answer)]
            return {"step_index": idx + 1, "step_results": new_results}

        return _execute_step_node

    def _synthesize_node(
        self,
        state: PlanExecState,
        memory: ConversationMemory | None = None,
    ) -> dict:
        """Synthesize a final answer from all step results.

        Args:
            state: Current graph state; ``query`` and ``step_results`` are read.
            memory: Memory whose history is injected into the prompt.
                Defaults to the router's own memory.

        Returns:
            State update containing the final ``answer``.
        """
        active_memory = self._memory if memory is None else memory
        gathered = "\n\n".join(
            f"### Step {i}: {desc}\n{result}"
            for i, (desc, result) in enumerate(state["step_results"], 1)
        )
        history = active_memory.format_history()
        history_section = ""
        if history:
            history_section = (
                f"Prior conversation:\n{history}\n\n"
            )
        prompt = (
            f"{_SYNTHESIZER_PROMPT}"
            f"{history_section}"
            f"Original question: {state['query']}\n\n"
            f"Research results:\n{gathered}\n\n"
            f"Answer:"
        )
        _measure_tokens("synthesizer", prompt, enabled=self._token_budget_enabled)
        answer = _extract_content(self._llm.invoke(prompt))
        logger.info("Synthesized final answer (%d chars)", len(answer))
        return {"answer": answer}

    # ------------------------------------------------------------------
    # Graph construction
    # ------------------------------------------------------------------
    def _build_graph(
        self,
        store: ToolResultStore,
        memory: ConversationMemory | None = None,
    ) -> object:
        """Build the Plan-and-Execute LangGraph.

        Args:
            store: Request-scoped ToolResultStore for this invocation.
            memory: Memory the planner and synthesizer read history from.
                Defaults to the router's own memory.

        Returns:
            Compiled LangGraph.
        """
        active_memory = self._memory if memory is None else memory

        # Bind the request's memory through closures so concurrent or
        # interleaved requests never share mutable router state.
        def _plan(state: PlanExecState) -> dict:
            return self._plan_node(state, active_memory)

        def _synthesize(state: PlanExecState) -> dict:
            return self._synthesize_node(state, active_memory)

        graph: StateGraph = StateGraph(PlanExecState)
        graph.add_node("plan", _plan)
        graph.add_node("execute_step", self._make_execute_step_node(store))
        graph.add_node("synthesize", _synthesize)
        graph.set_entry_point("plan")
        branches = {"execute": "execute_step", "synthesize": "synthesize"}
        graph.add_conditional_edges("plan", self._should_execute, dict(branches))
        graph.add_conditional_edges("execute_step", self._should_execute, dict(branches))
        graph.add_edge("synthesize", END)
        return graph.compile()

    # ------------------------------------------------------------------
    # Public interface (mirrors QueryRouter)
    # ------------------------------------------------------------------
    def route(
        self,
        query: str,
        top_k: int,
        memory: ConversationMemory | None = None,
    ) -> GenerationResponse:
        """Route a query through the Plan-and-Execute pipeline.

        Args:
            query: The user's natural language query.
            top_k: Number of top documents to retrieve per tool call; also
                the maximum number of sources returned.
            memory: Optional per-session memory override. When provided
                this memory is used instead of the router's default memory.

        Returns:
            GenerationResponse with answer, sources, intent, and confidence.
        """
        active_memory = self._memory if memory is None else memory
        logger.info("PlanExec routing query: %s", query)
        store = ToolResultStore()
        graph = self._build_graph(store, active_memory)
        final_state: PlanExecState = graph.invoke(self._initial_state(query, top_k))

        sources = store.retrieved[:top_k]
        confidence = max((r.score for r in sources), default=0.0)
        response = GenerationResponse(
            answer=final_state["answer"],
            sources=sources,
            intent=IntentType.RAG if sources else IntentType.FACTUAL,
            confidence=confidence,
            pipeline_details=PipelineDetails(
                original_query=query,
                retrieval_query=self._retrieval_query(store, query),
                dense_results=store.dense_results,
                sparse_results=store.sparse_results,
                fused_results=store.fused_results,
                reranked_results=sources,
                plan_steps=self._format_plan_steps(final_state.get("plan", [])),
                tool_calls=[f"{name}: {arg}" for name, arg in store.tool_calls],
            ),
        )
        active_memory.add_turn(query, response.answer, sources)
        return response

    def route_stream(
        self,
        query: str,
        top_k: int,
        memory: ConversationMemory | None = None,
    ) -> Generator[dict, None, None]:
        """Stream Plan-and-Execute events step by step.

        Yields event dicts with step types:

        - ``plan`` — plan was generated; carries ``steps``.
        - ``execute_step`` — a step was executed; carries ``step_index``
          (the index of the *next* step, i.e. the 1-based number of the step
          just run), ``step_desc``, ``result_preview``.
        - ``synthesize`` — final answer generated.
        - ``done`` — final event with full result payload.

        Args:
            query: User query.
            top_k: Number of results to retrieve per tool call.
            memory: Optional per-session memory override.

        Yields:
            Step event dicts.
        """
        # Resolve the memory once and hand it down explicitly. The generator
        # may stay suspended between events for a long time, so it must not
        # hold shared instance state in a swapped condition while it waits.
        active_memory = self._memory if memory is None else memory
        yield from self._route_stream_inner(query, top_k, active_memory)

    def _route_stream_inner(
        self,
        query: str,
        top_k: int,
        memory: ConversationMemory | None = None,
    ) -> Generator[dict, None, None]:
        """Internal streaming implementation.

        Args:
            query: User query.
            top_k: Number of results to retrieve per tool call.
            memory: Memory used for history and turn recording. Defaults to
                the router's own memory.

        Yields:
            Step event dicts, ending with a single ``done`` event.
        """
        active_memory = self._memory if memory is None else memory
        store = ToolResultStore()
        initial_state = self._initial_state(query, top_k)
        graph = self._build_graph(store, active_memory)

        # Fold every node update into one dict so the final payload can be
        # assembled after the stream ends.
        accumulated: dict = dict(initial_state)
        for chunk in graph.stream(initial_state, stream_mode="updates"):
            for node_name, update in chunk.items():
                if update is None:
                    continue
                accumulated.update(update)
                if node_name == "plan":
                    yield {
                        "step": "plan",
                        "steps": self._format_plan_steps(update.get("plan", [])),
                    }
                elif node_name == "execute_step":
                    results = update.get("step_results", [])
                    if results:
                        last_desc, last_result = results[-1]
                        yield {
                            "step": "execute_step",
                            "step_index": update.get("step_index", 0),
                            "step_desc": last_desc,
                            "result_preview": last_result[:300],
                        }
                elif node_name == "synthesize":
                    yield {"step": "synthesize"}

        sources = store.retrieved[:top_k]
        confidence = max((r.score for r in sources), default=0.0)
        answer = accumulated.get("answer", "")
        active_memory.add_turn(query, answer, sources)
        yield {
            "step": "done",
            "result": {
                "answer": answer,
                "sources": [r.to_dict() for r in sources],
                "intent": (IntentType.RAG if sources else IntentType.FACTUAL).value,
                "confidence": confidence,
                "pipeline_details": {
                    "original_query": query,
                    "retrieval_query": self._retrieval_query(store, query),
                    "detected_language": "",
                    "translated": False,
                    "dense_results": [r.to_dict(include_text=False) for r in store.dense_results],
                    "sparse_results": [r.to_dict(include_text=False) for r in store.sparse_results],
                    "fused_results": [r.to_dict(include_text=False) for r in store.fused_results],
                    "reranked_results": [r.to_dict(include_text=False) for r in sources],
                    "plan_steps": self._format_plan_steps(accumulated.get("plan", [])),
                    "tool_calls": [f"{n}: {a}" for n, a in store.tool_calls],
                },
            },
        }
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
_THINK_CLOSED_RE = re.compile(r".*?\s*", re.DOTALL)
_THINK_UNCLOSED_RE = re.compile(r".*", re.DOTALL)
def _strip_think(text: str) -> str:
"""Remove ```` blocks — both closed and unclosed.
Some models (Qwen3) always emit ``...``; others may
leave the tag unclosed. This handles both cases.
"""
text = _THINK_CLOSED_RE.sub("", text)
text = _THINK_UNCLOSED_RE.sub("", text)
return text.strip()
def _extract_content(result: object) -> str:
    """Pull the textual payload out of whatever an LLM or chain call returned.

    Accepted shapes:

    - an object with a ``content`` attribute holding a ``str``
    - an object whose ``content`` is a list of ``str`` and/or ``dict``
      blocks; string blocks and dict blocks with a ``"text"`` key are
      joined with newlines, every other block is ignored
    - anything else (e.g. a plain string from StrOutputParser or a test
      mock), which is converted with ``str()``

    Args:
        result: Return value of ``llm.invoke()`` or ``chain.invoke()``.

    Returns:
        The text after it has been passed through ``_strip_think``.
    """
    payload = getattr(result, "content", result)
    if not isinstance(payload, list):
        return _strip_think(str(payload))

    fragments: list[str] = [
        piece if isinstance(piece, str) else piece["text"]
        for piece in payload
        if isinstance(piece, str) or (isinstance(piece, dict) and "text" in piece)
    ]
    return _strip_think("\n".join(fragments))
def _parse_plan(raw: str) -> list[PlanStep]:
    """Turn the planner's raw text into a list of PlanStep dicts.

    Parsing is attempted in two passes: first the whole (fence-stripped)
    text as JSON, then the span from the first ``[`` to the last ``]``.
    Tolerates markdown fences, surrounding prose, and a single object where
    an array was expected.

    Args:
        raw: Raw LLM output expected to contain a JSON array.

    Returns:
        List of PlanStep dicts. If nothing parses, a single ``search`` step
        whose detail is the first 200 characters of the cleaned text. If the
        JSON parses but holds no well-formed step, a single generic
        ``search`` step.
    """
    text = raw.strip()
    if text.startswith("```"):
        # Drop every fence line (opening and closing) and keep the payload.
        text = "\n".join(
            ln for ln in text.splitlines() if not ln.strip().startswith("```")
        ).strip()

    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        # Second chance: pull the outermost bracketed span out of the text.
        open_at, close_at = text.find("["), text.rfind("]")
        if open_at == -1 or close_at == -1:
            logger.warning("No JSON array found in plan output, falling back")
            return [PlanStep(action="search", detail=text[:200])]
        try:
            payload = json.loads(text[open_at:close_at + 1])
        except json.JSONDecodeError:
            logger.warning("Failed to parse plan, falling back to single search")
            return [PlanStep(action="search", detail=text[:200])]

    if isinstance(payload, list):
        candidates = payload
    else:
        logger.warning("Plan is not a list, wrapping")
        candidates = [payload]

    plan: list[PlanStep] = []
    for entry in candidates:
        well_formed = isinstance(entry, dict) and "action" in entry and "detail" in entry
        if not well_formed:
            logger.warning("Skipping malformed plan step: %s", entry)
            continue
        plan.append(PlanStep(action=str(entry["action"]), detail=str(entry["detail"])))

    return plan or [PlanStep(action="search", detail="general search")]
def _extract_last_ai_text(messages: list) -> str:
    """Find the final AI message that carries text rather than tool calls.

    The list is scanned from the end; the first AIMessage with non-empty
    content and no ``tool_calls`` wins.

    Args:
        messages: List of LangChain message objects.

    Returns:
        That message's text as produced by ``_extract_content``, or an
        empty string when no message qualifies.
    """
    final_reply = next(
        (
            candidate
            for candidate in reversed(messages)
            if isinstance(candidate, AIMessage)
            and candidate.content
            and not getattr(candidate, "tool_calls", None)
        ),
        None,
    )
    if final_reply is None:
        return ""
    return _extract_content(final_reply)