XQ committed on
Commit 1441fa0 · 1 Parent(s): e128a20

Refactor to Plan-and-Execute architecture

.github/README.md CHANGED
@@ -4,7 +4,7 @@
4
 
5
  [xq-dokumentassistent.hf.space](https://xq-dokumentassistent.hf.space) — hosted on Hugging Face Spaces
6
 
7
- A document intelligence system covering PDF ingestion, semantic chunking, hybrid retrieval with reranking, and LLM-generated answers with source citations. The LLM layer is provider-agnostic. Two modes: a LangGraph ReAct agent (default) for queries that need multiple retrieval steps, and a pipeline for lightweight models without tool-calling support. Retrieval quality is evaluated with RAGAS.
8
 
9
  ## How it works
10
 
@@ -14,15 +14,18 @@ At query time both indexes are searched and their results merged with reciprocal
14
 
15
  **Two routing modes, switchable via `AGENT_MODE`:**
16
 
17
- - **ReAct Agent** (default): a reasoning loop where the LLM calls tools as many times as it needs before answering. Useful for multi-hop questions or comparisons across documents. Requires a model with tool-calling support.
18
 
19
  | Tool | Purpose |
20
  |------|---------|
21
- | `hybrid_search(query, top_k)` | Retrieve relevant passages |
 
 
 
22
  | `list_documents()` | See what's in the knowledge base |
23
  | `fetch_document(document_id)` | Read a full document |
24
 
25
- - **Pipeline** (`AGENT_MODE=pipeline`): a fixed LangGraph graph — language detection → optional translation → hybrid retrieval → reranking → generation. Works with lightweight local models that lack tool-calling support.
26
 
27
  ## Tech Stack
28
 
@@ -49,12 +52,12 @@ See `.env.example` for per-provider configuration.
49
 
50
  | Mode | `AGENT_MODE` | Notes |
51
  |------|-------------|-------|
52
- | ReAct | `react` (default) | Tool-calling loop, needs a model that supports tool use |
53
- | Pipeline | `pipeline` | Fixed graph, works with lightweight models that lack tool calling |
54
 
55
  Tool-calling is supported by OpenAI, Anthropic, Google GenAI, Azure OpenAI, Groq, and some Ollama models (`gemma4`, `llama3.1`, `qwen2.5`, `mistral-nemo`).
56
 
57
- ReAct with local Ollama (default):
58
 
59
  ```dotenv
60
  AGENT_MODE=react
@@ -70,7 +73,7 @@ LLM_PROVIDER=ollama
70
  OLLAMA_MODEL=gemma3
71
  ```
72
 
73
- ReAct with OpenAI:
74
 
75
  ```dotenv
76
  AGENT_MODE=react
@@ -149,8 +152,10 @@ src/
149
  agent/
150
  intent_classifier.py
151
  router.py # pipeline mode (AGENT_MODE=pipeline)
152
- tools.py # hybrid_search + ToolResultStore
153
- react_router.py # ReAct mode (AGENT_MODE=react)
 
 
154
  evaluation/
155
  evaluator.py # RAGAS metrics
156
  ui/
 
4
 
5
  [xq-dokumentassistent.hf.space](https://xq-dokumentassistent.hf.space) — hosted on Hugging Face Spaces
6
 
7
+ A document intelligence system covering PDF ingestion, semantic chunking, hybrid retrieval with reranking, and LLM-generated answers with source citations. The LLM layer is provider-agnostic. Two modes: a Plan-and-Execute agent (default) with conversation memory for complex multi-step queries, and a pipeline for lightweight models without tool-calling support. Retrieval quality is evaluated with RAGAS.
8
 
9
  ## How it works
10
 
 
14
 
15
  **Two routing modes, switchable via `AGENT_MODE`:**
16
 
17
+ - **Plan-and-Execute Agent** (default): a structured multi-step pipeline where a planner decomposes the query into steps, an executor runs each step via a ReAct sub-agent with tool access, and a synthesizer produces the final cited answer. Includes conversation memory for multi-turn follow-ups. Requires a model with tool-calling support.
18
 
19
  | Tool | Purpose |
20
  |------|---------|
21
+ | `hybrid_search(query, top_k)` | Retrieve relevant passages via hybrid search + reranking |
22
+ | `multi_query_search(question, top_k)` | Decompose complex questions into sub-queries, search each, merge results |
23
+ | `search_within_document(document_id, query, top_k)` | Find specific sections inside a known document |
24
+ | `summarize_document(document_id)` | Generate a structured summary of a document |
25
  | `list_documents()` | See what's in the knowledge base |
26
  | `fetch_document(document_id)` | Read a full document |
27
 
28
+ - **Pipeline** (`AGENT_MODE=pipeline`): a predefined LangGraph graph — language detection → optional translation → hybrid retrieval → reranking → generation, with a confidence-based retry loop. Works with lightweight local models that lack tool-calling support.
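
The planner / executor / synthesizer flow described above can be sketched without any framework. This is only an illustration, not the code added in this commit: `parse_plan`, `plan_and_execute`, and the three `fake_*` callables are hypothetical stand-ins for the LLM-backed nodes, and the cap of 6 steps is an assumed value.

```python
import json
from typing import Callable

MAX_STEPS = 6  # assumed upper bound on executed steps


def parse_plan(raw: str) -> list[dict]:
    """Parse a JSON array of {"action", "detail"} steps.

    Falls back to a single search step when the text is not usable JSON.
    """
    fallback = [{"action": "search", "detail": raw[:200]}]
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return fallback
    if not isinstance(parsed, list):
        parsed = [parsed]
    steps = [
        s for s in parsed
        if isinstance(s, dict) and "action" in s and "detail" in s
    ]
    return steps or fallback


def plan_and_execute(
    question: str,
    planner: Callable[[str], str],
    executor: Callable[[str, str], str],
    synthesizer: Callable[[str, list[tuple[str, str]]], str],
) -> str:
    """Plan once, run each step in order, then synthesize one answer."""
    plan = parse_plan(planner(question))[:MAX_STEPS]
    results: list[tuple[str, str]] = []
    for step in plan:
        desc = f'{step["action"]}: {step["detail"]}'
        results.append((desc, executor(desc, question)))
    return synthesizer(question, results)


# Hypothetical stand-ins for the LLM-backed planner, executor and synthesizer.
def fake_planner(question: str) -> str:
    return (
        '[{"action": "search", "detail": "vacation rules academic"}, '
        '{"action": "search", "detail": "vacation rules administrative"}]'
    )


def fake_executor(step: str, question: str) -> str:
    return f"found notes for '{step}'"


def fake_synthesizer(question: str, results: list[tuple[str, str]]) -> str:
    return " | ".join(f"[{i}] {text}" for i, (_, text) in enumerate(results, 1))


answer = plan_and_execute(
    "Compare vacation rules", fake_planner, fake_executor, fake_synthesizer
)
print(answer)
```

Keeping planning separate from execution means the number of steps is known before any tool is called, which makes the run easier to log and to cap.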
29
 
30
  ## Tech Stack
31
 
 
52
 
53
  | Mode | `AGENT_MODE` | Notes |
54
  |------|-------------|-------|
55
+ | Plan-and-Execute | `react` (default) | Structured multi-step agent with conversation memory |
56
+ | Pipeline | `pipeline` | Predefined graph, works with lightweight models that lack tool calling |
57
 
58
  Tool-calling is supported by OpenAI, Anthropic, Google GenAI, Azure OpenAI, Groq, and some Ollama models (`gemma4`, `llama3.1`, `qwen2.5`, `mistral-nemo`).
59
 
60
+ Plan-and-Execute with local Ollama (default):
61
 
62
  ```dotenv
63
  AGENT_MODE=react
 
73
  OLLAMA_MODEL=gemma3
74
  ```
75
 
76
+ Plan-and-Execute with OpenAI:
77
 
78
  ```dotenv
79
  AGENT_MODE=react
 
152
  agent/
153
  intent_classifier.py
154
  router.py # pipeline mode (AGENT_MODE=pipeline)
155
+ tools.py # 6 retrieval tools + ToolResultStore
156
+ react_router.py # legacy ReAct loop (superseded by plan_and_execute)
157
+ plan_and_execute.py # Plan-and-Execute agent (AGENT_MODE=react)
158
+ memory.py # conversation memory for multi-turn
159
  evaluation/
160
  evaluator.py # RAGAS metrics
161
  ui/
README.md CHANGED
@@ -12,7 +12,7 @@ noindex: true
12
 
13
  **Live Demo:** [xq-dokumentassistent.hf.space](https://xq-dokumentassistent.hf.space) — hosted on Hugging Face Spaces
14
 
15
- A document intelligence system built on a RAG architecture, covering PDF ingestion, semantic chunking, hybrid retrieval with reranking, and LLM-generated answers with source citations. The LLM layer is provider-agnostic. Two modes: a pipeline for lightweight models, a LangGraph ReAct agent for queries that need multiple retrieval steps. Retrieval quality is evaluated with RAGAS.
16
 
17
  ## How it works
18
 
@@ -22,13 +22,16 @@ At query time both indexes are searched and their results merged with reciprocal
22
 
23
  **Two routing modes, switchable via `AGENT_MODE`:**
24
 
25
- - **Pipeline** (default): a fixed LangGraph DAG — language detection → optional translation → hybrid retrieval → reranking → generation. Works with lightweight local models like `gemma4`.
26
 
27
- - **ReAct Agent** (`AGENT_MODE=react`): replaces the DAG with a reasoning loop where the LLM calls tools as many times as it needs before answering. Useful for multi-hop questions or comparisons across documents. Requires a model with tool-calling support.
28
 
29
  | Tool | Purpose |
30
  |------|---------|
31
- | `hybrid_search(query, top_k)` | Retrieve relevant passages |
 
 
 
32
  | `list_documents()` | See what's in the knowledge base |
33
  | `fetch_document(document_id)` | Read a full document |
34
 
@@ -57,12 +60,12 @@ See `.env.example` for per-provider configuration.
57
 
58
  | Mode | `AGENT_MODE` | Notes |
59
  |------|-------------|-------|
60
- | Pipeline | `pipeline` (default) | Fixed DAG, works with `gemma4` |
61
- | ReAct | `react` | Tool-calling loop, needs a model that supports tool use |
62
 
63
- Tool-calling is supported by OpenAI, Anthropic, Google GenAI, Azure OpenAI, Groq, and some Ollama models (`llama3.1`, `qwen2.5`, `mistral-nemo`). The default `gemma4` does not support it — use `pipeline` mode with Ollama.
64
 
65
- ReAct with OpenAI:
66
 
67
  ```dotenv
68
  AGENT_MODE=react
@@ -76,7 +79,7 @@ Pipeline with local Ollama:
76
  ```dotenv
77
  AGENT_MODE=pipeline
78
  LLM_PROVIDER=ollama
79
- OLLAMA_MODEL=gemma4
80
  ```
81
 
82
  ## Quick Start
@@ -149,8 +152,10 @@ src/
149
  agent/
150
  intent_classifier.py
151
  router.py # pipeline mode (AGENT_MODE=pipeline)
152
- tools.py # hybrid_search + ToolResultStore
153
- react_router.py # ReAct mode (AGENT_MODE=react)
 
 
154
  evaluation/
155
  evaluator.py # RAGAS metrics
156
  ui/
 
12
 
13
  **Live Demo:** [xq-dokumentassistent.hf.space](https://xq-dokumentassistent.hf.space) — hosted on Hugging Face Spaces
14
 
15
+ A document intelligence system built on a RAG architecture, covering PDF ingestion, semantic chunking, hybrid retrieval with reranking, and LLM-generated answers with source citations. The LLM layer is provider-agnostic. Two modes: a pipeline for lightweight models, and a Plan-and-Execute agent flow with conversation memory for complex multi-step queries. Retrieval quality is evaluated with RAGAS.
16
 
17
  ## How it works
18
 
 
22
 
23
  **Two routing modes, switchable via `AGENT_MODE`:**
24
 
25
+ - **Pipeline**: a predefined LangGraph graph — language detection → optional translation → hybrid retrieval → reranking → generation, with a confidence-based retry loop. Works with lightweight local models.
26
 
27
+ - **Plan-and-Execute Agent** (default, `AGENT_MODE=react`): a structured multi-step pipeline where a planner decomposes the query into steps, an executor runs each step via a ReAct sub-agent with tool access, and a synthesizer produces the final cited answer. Includes conversation memory for multi-turn follow-ups. Requires a model with tool-calling support.
28
 
29
  | Tool | Purpose |
30
  |------|---------|
31
+ | `hybrid_search(query, top_k)` | Retrieve relevant passages via hybrid search + reranking |
32
+ | `multi_query_search(question, top_k)` | Decompose complex questions into sub-queries, search each, merge results |
33
+ | `search_within_document(document_id, query, top_k)` | Find specific sections inside a known document |
34
+ | `summarize_document(document_id)` | Generate a structured summary of a document |
35
  | `list_documents()` | See what's in the knowledge base |
36
  | `fetch_document(document_id)` | Read a full document |
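
To make the "search each, merge results" behaviour of `multi_query_search` concrete, here is a minimal sketch of one possible merge strategy: keep the best score per chunk and return the top `top_k`. The result shape `(chunk_id, score)`, the `fake_search` function, and the merge rule itself are assumptions for illustration; the README does not state how the real tool merges results.

```python
from typing import Callable

Hit = tuple[str, float]  # hypothetical result shape: (chunk_id, score)


def multi_query_search(
    sub_queries: list[str],
    search: Callable[[str, int], list[Hit]],
    top_k: int = 5,
) -> list[Hit]:
    """Run every sub-query, keep the best score per chunk, return the top_k."""
    best: dict[str, float] = {}
    for sub_query in sub_queries:
        for chunk_id, score in search(sub_query, top_k):
            if chunk_id not in best or score > best[chunk_id]:
                best[chunk_id] = score
    merged = sorted(best.items(), key=lambda item: item[1], reverse=True)
    return merged[:top_k]


# Hypothetical in-memory index standing in for the hybrid retriever.
FAKE_INDEX: dict[str, list[Hit]] = {
    "vacation academic": [("c1", 0.9), ("c2", 0.4)],
    "vacation administrative": [("c2", 0.7), ("c3", 0.6)],
}


def fake_search(query: str, top_k: int) -> list[Hit]:
    return FAKE_INDEX.get(query, [])[:top_k]


hits = multi_query_search(list(FAKE_INDEX), fake_search, top_k=3)
print(hits)
```

A chunk that is returned by several sub-queries appears once in the merged list, carrying its highest score.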
37
 
 
60
 
61
  | Mode | `AGENT_MODE` | Notes |
62
  |------|-------------|-------|
63
+ | Pipeline | `pipeline` | Predefined graph, works with lightweight models |
64
+ | Plan-and-Execute (default) | `react` | Structured multi-step agent with conversation memory |
65
 
66
+ Tool-calling is supported by OpenAI, Anthropic, Google GenAI, Azure OpenAI, Groq, and some Ollama models (`llama3.1`, `qwen2.5`, `mistral-nemo`).
67
 
68
+ Plan-and-Execute with OpenAI:
69
 
70
  ```dotenv
71
  AGENT_MODE=react
 
79
  ```dotenv
80
  AGENT_MODE=pipeline
81
  LLM_PROVIDER=ollama
82
+ OLLAMA_MODEL=gemma3
83
  ```
84
 
85
  ## Quick Start
 
152
  agent/
153
  intent_classifier.py
154
  router.py # pipeline mode (AGENT_MODE=pipeline)
155
+ tools.py # 6 retrieval tools + ToolResultStore
156
+ react_router.py # legacy ReAct loop (superseded by plan_and_execute)
157
+ plan_and_execute.py # Plan-and-Execute agent (AGENT_MODE=react)
158
+ memory.py # conversation memory for multi-turn
159
  evaluation/
160
  evaluator.py # RAGAS metrics
161
  ui/
src/agent/memory.py ADDED
@@ -0,0 +1,125 @@
+ """Conversation memory for multi-turn interactions.
+
+ Stores message history and retrieved sources across turns so that:
+ - Follow-up questions can reference prior context ("what about the other one?")
+ - The planner/synthesizer can see what was already discussed
+ - Previously retrieved sources are available without re-searching
+ """
+
+ import logging
+ from dataclasses import dataclass, field
+
+ from src.models import QueryResult
+
+ logger = logging.getLogger(__name__)
+
+ _MAX_TURNS = 20
+
+
+ @dataclass
+ class Turn:
+     """A single conversation turn.
+
+     Attributes:
+         query: The user's question.
+         answer: The assistant's response.
+         sources: Retrieved sources used to generate the answer.
+     """
+
+     query: str
+     answer: str
+     sources: list[QueryResult] = field(default_factory=list)
+
+
+ class ConversationMemory:
+     """Manages multi-turn conversation state.
+
+     Stores a rolling window of recent turns and provides formatted
+     context for the planner and synthesizer prompts.
+     """
+
+     def __init__(self, max_turns: int = _MAX_TURNS) -> None:
+         """Initialize conversation memory.
+
+         Args:
+             max_turns: Maximum number of turns to retain.
+         """
+         self._max_turns = max_turns
+         self._turns: list[Turn] = []
+
+     @property
+     def turns(self) -> list[Turn]:
+         """Return the list of conversation turns (read-only copy)."""
+         return list(self._turns)
+
+     @property
+     def is_empty(self) -> bool:
+         """Return True if no conversation history exists."""
+         return len(self._turns) == 0
+
+     def add_turn(self, query: str, answer: str, sources: list[QueryResult] | None = None) -> None:
+         """Record a completed conversation turn.
+
+         Args:
+             query: The user's question.
+             answer: The assistant's response.
+             sources: Retrieved sources (optional).
+         """
+         self._turns.append(Turn(query=query, answer=answer, sources=sources or []))
+         if len(self._turns) > self._max_turns:
+             removed = self._turns.pop(0)
+             logger.debug("Evicted oldest turn: %s", removed.query[:50])
+         logger.debug("Memory now has %d turns", len(self._turns))
+
+     def clear(self) -> None:
+         """Clear all conversation history."""
+         self._turns.clear()
+         logger.info("Conversation memory cleared")
+
+     def format_history(self, max_recent: int = 5) -> str:
+         """Format recent conversation history for inclusion in prompts.
+
+         Args:
+             max_recent: Maximum number of recent turns to include.
+
+         Returns:
+             Formatted string of recent Q&A pairs, or empty string if no history.
+         """
+         if not self._turns:
+             return ""
+
+         recent = self._turns[-max_recent:]
+         parts: list[str] = []
+         for i, turn in enumerate(recent, 1):
+             source_note = ""
+             if turn.sources:
+                 doc_ids = sorted({s.chunk.document_id for s in turn.sources})
+                 source_note = f" [sources: {', '.join(doc_ids)}]"
+             parts.append(
+                 f"Turn {i}:\n"
+                 f"  User: {turn.query}\n"
+                 f"  Assistant: {turn.answer[:500]}{source_note}"
+             )
+         return "\n\n".join(parts)
+
+     def get_prior_sources(self) -> list[QueryResult]:
+         """Return all unique sources from prior turns, sorted by score.
+
+         Returns:
+             Deduplicated list of QueryResult from all past turns.
+         """
+         by_id: dict[str, QueryResult] = {}
+         for turn in self._turns:
+             for r in turn.sources:
+                 cid = r.chunk.chunk_id
+                 if cid not in by_id or r.score > by_id[cid].score:
+                     by_id[cid] = r
+         return sorted(by_id.values(), key=lambda r: r.score, reverse=True)
+
+     def last_query(self) -> str:
+         """Return the last user query, or empty string."""
+         return self._turns[-1].query if self._turns else ""
+
+     def last_sources(self) -> list[QueryResult]:
+         """Return sources from the most recent turn."""
+         return self._turns[-1].sources if self._turns else []
src/agent/plan_and_execute.py ADDED
@@ -0,0 +1,506 @@
+ """Plan-and-Execute agent router using LangGraph.
+
+ Replaces the flat ReAct loop with a structured three-phase pipeline:
+
+ 1. **Planner** — analyses the user query and produces an ordered list of
+    steps (e.g. "search for exam rules", "search for grading policy",
+    "compare both").
+ 2. **Executor** — runs each step via a short ReAct sub-graph that has
+    access to all retrieval tools.
+ 3. **Synthesizer** — collects the results from all executed steps and
+    produces a final, cited answer.
+
+ The separation gives the pipeline *predictable structure* while still
+ allowing the executor to reason freely within each step.
+ """
+
+ import json
+ import logging
+ from collections.abc import Generator
+ from typing import TypedDict
+
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
+ from langchain_core.runnables import Runnable
+ from langgraph.graph import END, StateGraph
+ from langgraph.prebuilt import create_react_agent
+
+ from src.agent.memory import ConversationMemory
+ from src.agent.tools import ToolResultStore, make_retrieval_tools
+ from src.models import GenerationResponse, IntentType, PipelineDetails, QueryResult
+ from src.retrieval.hybrid import HybridRetriever
+ from src.retrieval.reranker import Reranker
+ from src.retrieval.vector_store import VectorStore
+
+ logger = logging.getLogger(__name__)
+
+ _MAX_STEPS = 6
+
+ # ------------------------------------------------------------------
+ # Prompts
+ # ------------------------------------------------------------------
+
+ _PLANNER_PROMPT = (
+     "You are a planning assistant for the University of Copenhagen (KU) document system.\n\n"
+     "Given a user question, produce a JSON list of 1–4 steps needed to answer it.\n"
+     "Each step is an object with:\n"
+     '  - "action": one of "search", "search_within", "multi_search", '
+     '"summarize", "list_docs", "fetch_doc"\n'
+     '  - "detail": a short description of what to do (e.g. the search query, document ID)\n\n'
+     "Rules:\n"
+     "- For simple factual questions: 1 search step is enough.\n"
+     "- For comparison questions: use multi_search or separate search steps.\n"
+     "- For document overview requests: use summarize.\n"
+     "- Always end with the steps needed; do NOT include a final 'answer' step.\n\n"
+     "Reply with ONLY the JSON array, nothing else.\n\n"
+     "Examples:\n"
+     'Question: "What is the exam policy?"\n'
+     '[{"action": "search", "detail": "KU eksamensregler"}]\n\n'
+     'Question: "Compare vacation rules for academic vs administrative staff"\n'
+     '[{"action": "search", "detail": "ferieregler videnskabeligt personale"}, '
+     '{"action": "search", "detail": "ferieregler administrativt personale"}]\n\n'
+     'Question: "Summarize the AI policy document"\n'
+     '[{"action": "summarize", "detail": "ku_ai_policy.pdf"}]\n\n'
+     "Now plan for this question:\n"
+ )
+
+ _EXECUTOR_SYSTEM = (
+     "You are executing ONE step of a plan to answer a user's question about "
+     "University of Copenhagen (KU) documents.\n\n"
+     "You have retrieval tools available. Execute the step described below, "
+     "then summarise what you found in 2-3 sentences. If you find nothing "
+     "relevant, say so clearly.\n\n"
+     "Do NOT produce a final answer — just report what you found for this step."
+ )
+
+ _SYNTHESIZER_PROMPT = (
+     "You are a helpful assistant for administrative staff at the University "
+     "of Copenhagen (KU).\n\n"
+     "Below are the results gathered from multiple research steps. "
+     "Synthesize them into a single coherent answer to the user's original question.\n\n"
+     "Guidelines:\n"
+     "- Cite document sources using [1], [2], etc.\n"
+     "- Answer in the same language as the user's question.\n"
+     "- Be concise but thorough.\n"
+     "- If some steps found no results, acknowledge gaps honestly.\n\n"
+ )
+
+
+ # ------------------------------------------------------------------
+ # Graph state
+ # ------------------------------------------------------------------
+
+ class PlanStep(TypedDict):
+     """A single step in the execution plan."""
+
+     action: str
+     detail: str
+
+
+ class PlanExecState(TypedDict):
+     """State for the Plan-and-Execute graph.
+
+     Attributes:
+         query: The user's original question.
+         top_k: Number of results per retrieval call.
+         plan: Ordered list of steps produced by the planner.
+         step_index: Index of the next step to execute.
+         step_results: List of (step_description, result_text) pairs.
+         answer: Final synthesised answer.
+     """
+
+     query: str
+     top_k: int
+     plan: list[PlanStep]
+     step_index: int
+     step_results: list[tuple[str, str]]
+     answer: str
+
+
+ # ------------------------------------------------------------------
+ # Router class
+ # ------------------------------------------------------------------
+
+ class PlanAndExecuteRouter:
+     """Routes queries through a Plan-and-Execute pipeline.
+
+     Graph topology::
+
+         plan → should_execute? ─┬─ yes → execute_step → should_execute?
+                                 └─ no → synthesize → END
+     """
+
+     def __init__(
+         self,
+         llm: Runnable,
+         hybrid_retriever: HybridRetriever,
+         reranker: Reranker,
+         vector_store: VectorStore,
+         default_top_k: int = 5,
+         memory: ConversationMemory | None = None,
+     ) -> None:
+         """Initialise the Plan-and-Execute router.
+
+         Args:
+             llm: LLM with tool-calling support.
+             hybrid_retriever: HybridRetriever instance.
+             reranker: Reranker instance.
+             vector_store: VectorStore instance.
+             default_top_k: Default number of results per retrieval call.
+             memory: Optional ConversationMemory for multi-turn context.
+                 When provided, prior conversation history is injected into
+                 planner and synthesizer prompts, and each completed turn
+                 is automatically recorded.
+         """
+         self._llm = llm
+         self._hybrid_retriever = hybrid_retriever
+         self._reranker = reranker
+         self._vector_store = vector_store
+         self._default_top_k = default_top_k
+         self._store = ToolResultStore()
+         self._memory = memory or ConversationMemory()
+
+     # ------------------------------------------------------------------
+     # Node functions
+     # ------------------------------------------------------------------
+
+     def _plan_node(self, state: PlanExecState) -> dict:
+         """Generate an execution plan from the user query."""
+         history = self._memory.format_history()
+         history_section = ""
+         if history:
+             history_section = (
+                 f"Conversation history (for context on follow-up questions):\n"
+                 f"{history}\n\n"
+             )
+         prompt = _PLANNER_PROMPT + history_section + f'Question: "{state["query"]}"'
+         raw = str(self._llm.invoke(prompt)).strip()
+         logger.info("Planner raw output: %s", raw)
+
+         plan = _parse_plan(raw)
+         logger.info("Plan: %d steps — %s", len(plan), plan)
+         return {"plan": plan, "step_index": 0, "step_results": []}
+
+     @staticmethod
+     def _should_execute(state: PlanExecState) -> str:
+         """Decide whether to execute the next step or synthesize."""
+         if state["step_index"] < len(state["plan"]) and state["step_index"] < _MAX_STEPS:
+             return "execute"
+         return "synthesize"
+
+     def _execute_step_node(self, state: PlanExecState) -> dict:
+         """Execute the current plan step using a ReAct sub-agent."""
+         idx = state["step_index"]
+         step = state["plan"][idx]
+         step_desc = f'{step["action"]}: {step["detail"]}'
+         logger.info("Executing step %d/%d: %s", idx + 1, len(state["plan"]), step_desc)
+
+         # Build a fresh tool set and sub-agent for this step
+         tools = make_retrieval_tools(
+             self._hybrid_retriever,
+             self._reranker,
+             self._vector_store,
+             self._store,
+             self._default_top_k,
+             llm_chain=self._llm,
+         )
+         sub_agent = create_react_agent(self._llm, tools)
+
+         step_prompt = (
+             f'Step to execute: {step_desc}\n\n'
+             f'Original user question (for context): {state["query"]}'
+         )
+
+         result = sub_agent.invoke({
+             "messages": [
+                 SystemMessage(content=_EXECUTOR_SYSTEM),
+                 HumanMessage(content=step_prompt),
+             ]
+         })
+
+         # Extract the sub-agent's final text answer
+         answer = _extract_last_ai_text(result.get("messages", []))
+         logger.info("Step %d result: %s", idx + 1, answer[:200])
+
+         new_results = list(state["step_results"]) + [(step_desc, answer)]
+         return {"step_index": idx + 1, "step_results": new_results}
+
+     def _synthesize_node(self, state: PlanExecState) -> dict:
+         """Synthesize a final answer from all step results."""
+         step_texts = []
+         for i, (desc, result) in enumerate(state["step_results"], 1):
+             step_texts.append(f"### Step {i}: {desc}\n{result}")
+         gathered = "\n\n".join(step_texts)
+
+         history = self._memory.format_history()
+         history_section = ""
+         if history:
+             history_section = (
+                 f"Prior conversation:\n{history}\n\n"
+             )
+
+         prompt = (
+             f"{_SYNTHESIZER_PROMPT}"
+             f"{history_section}"
+             f"Original question: {state['query']}\n\n"
+             f"Research results:\n{gathered}\n\n"
+             f"Answer:"
+         )
+         answer = str(self._llm.invoke(prompt)).strip()
+         logger.info("Synthesized final answer (%d chars)", len(answer))
+         return {"answer": answer}
+
+     # ------------------------------------------------------------------
+     # Graph construction
+     # ------------------------------------------------------------------
+
+     def _build_graph(self) -> object:
+         """Build the Plan-and-Execute LangGraph.
+
+         Returns:
+             Compiled LangGraph.
+         """
+         graph: StateGraph = StateGraph(PlanExecState)
+
+         graph.add_node("plan", self._plan_node)
+         graph.add_node("execute_step", self._execute_step_node)
+         graph.add_node("synthesize", self._synthesize_node)
+
+         graph.set_entry_point("plan")
+         graph.add_conditional_edges(
+             "plan",
+             self._should_execute,
+             {"execute": "execute_step", "synthesize": "synthesize"},
+         )
+         graph.add_conditional_edges(
+             "execute_step",
+             self._should_execute,
+             {"execute": "execute_step", "synthesize": "synthesize"},
+         )
+         graph.add_edge("synthesize", END)
+
+         return graph.compile()
+
+     # ------------------------------------------------------------------
+     # Public interface (mirrors QueryRouter / ReActRouter)
+     # ------------------------------------------------------------------
+
+     def route(self, query: str, top_k: int) -> GenerationResponse:
+         """Route a query through the Plan-and-Execute pipeline.
+
+         Args:
+             query: The user's natural language query.
+             top_k: Number of top documents to retrieve per tool call.
+
+         Returns:
+             GenerationResponse with answer, sources, intent, and confidence.
+         """
+         logger.info("PlanExec routing query: %s", query)
+         self._store = ToolResultStore()
+
+         initial_state = PlanExecState(
+             query=query,
+             top_k=top_k,
+             plan=[],
+             step_index=0,
+             step_results=[],
+             answer="",
+         )
+
+         graph = self._build_graph()
+         final_state: PlanExecState = graph.invoke(initial_state)
+
+         sources = self._store.retrieved[:top_k]
+         confidence = max((r.score for r in sources), default=0.0)
+
+         plan_step_strs = [
+             f'{s["action"]}: {s["detail"]}' for s in final_state.get("plan", [])
+         ]
+         tool_call_strs = [f"{name}: {arg}" for name, arg in self._store.tool_calls]
+
+         response = GenerationResponse(
+             answer=final_state["answer"],
+             sources=sources,
+             intent=IntentType.RAG if sources else IntentType.FACTUAL,
+             confidence=confidence,
+             pipeline_details=PipelineDetails(
+                 original_query=query,
+                 retrieval_query=", ".join(
+                     q for name, q in self._store.tool_calls if name == "hybrid_search"
+                 ) or query,
+                 dense_results=self._store.dense_results,
+                 sparse_results=self._store.sparse_results,
+                 fused_results=self._store.fused_results,
+                 reranked_results=sources,
+                 plan_steps=plan_step_strs,
+                 tool_calls=tool_call_strs,
+             ),
+         )
+
+         self._memory.add_turn(query, response.answer, sources)
+         return response
+
+     def route_stream(self, query: str, top_k: int) -> Generator[dict, None, None]:
+         """Stream Plan-and-Execute events step by step.
+
+         Yields event dicts with step types:
+         - ``plan`` — plan was generated; carries ``steps``.
+         - ``execute_step`` — a step was executed; carries ``step_index``,
+           ``step_desc``, ``result_preview``.
+         - ``synthesize`` — final answer generated.
+         - ``done`` — final event with full result payload.
+
+         Args:
+             query: User query.
+             top_k: Number of results to retrieve per tool call.
+
+         Yields:
+             Step event dicts.
+         """
+         self._store = ToolResultStore()
+
+         initial_state = PlanExecState(
+             query=query,
+             top_k=top_k,
+             plan=[],
+             step_index=0,
+             step_results=[],
+             answer="",
+         )
+
+         graph = self._build_graph()
+         accumulated: dict = dict(initial_state)
+
+         for chunk in graph.stream(initial_state, stream_mode="updates"):
+             for node_name, update in chunk.items():
+                 if update is None:
+                     continue
+                 accumulated.update(update)
+
+                 if node_name == "plan":
+                     yield {
+                         "step": "plan",
+                         "steps": [
+                             f'{s["action"]}: {s["detail"]}'
+                             for s in update.get("plan", [])
+                         ],
+                     }
+                 elif node_name == "execute_step":
+                     results = update.get("step_results", [])
+                     if results:
+                         last_desc, last_result = results[-1]
+                         yield {
+                             "step": "execute_step",
+                             "step_index": update.get("step_index", 0),
+                             "step_desc": last_desc,
+                             "result_preview": last_result[:300],
+                         }
+                 elif node_name == "synthesize":
+                     yield {"step": "synthesize"}
+
+         sources = self._store.retrieved[:top_k]
+         confidence = max((r.score for r in sources), default=0.0)
+         answer = accumulated.get("answer", "")
+
+         self._memory.add_turn(query, answer, sources)
+
+         yield {
+             "step": "done",
+             "result": {
+                 "answer": answer,
+                 "sources": [r.to_dict() for r in sources],
+                 "intent": (IntentType.RAG if sources else IntentType.FACTUAL).value,
+                 "confidence": confidence,
+                 "pipeline_details": {
+                     "original_query": query,
+                     "retrieval_query": ", ".join(
+                         q for name, q in self._store.tool_calls if name == "hybrid_search"
+                     ) or query,
+                     "detected_language": "",
+                     "translated": False,
+                     "dense_results": [r.to_dict(include_text=False) for r in self._store.dense_results],
+                     "sparse_results": [r.to_dict(include_text=False) for r in self._store.sparse_results],
+                     "fused_results": [r.to_dict(include_text=False) for r in self._store.fused_results],
+                     "reranked_results": [r.to_dict(include_text=False) for r in sources],
+                     "plan_steps": [
+                         f'{s["action"]}: {s["detail"]}'
+                         for s in accumulated.get("plan", [])
+                     ],
+                     "tool_calls": [f"{n}: {a}" for n, a in self._store.tool_calls],
+                 },
+             },
+         }
+
+
434
+ # ------------------------------------------------------------------
435
+ # Helpers
436
+ # ------------------------------------------------------------------
437
+
438
+ def _parse_plan(raw: str) -> list[PlanStep]:
439
+ """Parse the planner's JSON output into a list of PlanStep dicts.
440
+
441
+ Robust against markdown fences, trailing text, and minor formatting issues.
442
+
443
+ Args:
444
+ raw: Raw LLM output expected to contain a JSON array.
445
+
446
+ Returns:
447
+ List of PlanStep dicts. Falls back to a single search step on failure.
448
+ """
449
+ # Strip markdown code fences if present
450
+ cleaned = raw.strip()
451
+ if cleaned.startswith("```"):
452
+ lines = cleaned.splitlines()
453
+ # Remove opening and closing fences
454
+ lines = [l for l in lines if not l.strip().startswith("```")]
455
+ cleaned = "\n".join(lines).strip()
456
+
457
+ try:
458
+ parsed = json.loads(cleaned)
459
+ except json.JSONDecodeError:
460
+ # Try to extract a JSON array from the text
461
+ start = cleaned.find("[")
462
+ end = cleaned.rfind("]")
463
+ if start != -1 and end != -1:
464
+ try:
465
+ parsed = json.loads(cleaned[start:end + 1])
466
+ except json.JSONDecodeError:
467
+ logger.warning("Failed to parse plan, falling back to single search")
468
+ return [PlanStep(action="search", detail=cleaned[:200])]
469
+ else:
470
+ logger.warning("No JSON array found in plan output, falling back")
471
+ return [PlanStep(action="search", detail=cleaned[:200])]
472
+
473
+ if not isinstance(parsed, list):
474
+ logger.warning("Plan is not a list, wrapping")
475
+ parsed = [parsed]
476
+
477
+ steps: list[PlanStep] = []
478
+ for item in parsed:
479
+ if isinstance(item, dict) and "action" in item and "detail" in item:
480
+ steps.append(PlanStep(action=str(item["action"]), detail=str(item["detail"])))
481
+ else:
482
+ logger.warning("Skipping malformed plan step: %s", item)
483
+
484
+ if not steps:
485
+ return [PlanStep(action="search", detail="general search")]
486
+
487
+ return steps
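Reviewer note: `_parse_plan` layers three fallbacks (strip fences, parse JSON, extract the outermost `[...]`) before giving up and returning a single search step. The sketch below restates that strategy on plain dicts so the behaviour can be checked in isolation. It assumes `PlanStep` is a dict-like type with `action` and `detail` keys; `parse_plan` and the sample inputs are illustrative names, not part of the commit.

```python
import json

FENCE = "`" * 3  # built at runtime so the example contains no literal fence


def parse_plan(raw: str) -> list[dict[str, str]]:
    """Simplified restatement of the parsing strategy, returning plain dicts."""
    cleaned = raw.strip()
    if cleaned.startswith(FENCE):
        # Drop every line that is itself a fence marker.
        cleaned = "\n".join(
            line for line in cleaned.splitlines() if not line.strip().startswith(FENCE)
        ).strip()
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        # Fall back to the outermost [...] span inside surrounding chatter.
        start, end = cleaned.find("["), cleaned.rfind("]")
        if start == -1 or end == -1:
            return [{"action": "search", "detail": cleaned[:200]}]
        try:
            parsed = json.loads(cleaned[start:end + 1])
        except json.JSONDecodeError:
            return [{"action": "search", "detail": cleaned[:200]}]
    if not isinstance(parsed, list):
        parsed = [parsed]
    steps = [
        {"action": str(item["action"]), "detail": str(item["detail"])}
        for item in parsed
        if isinstance(item, dict) and "action" in item and "detail" in item
    ]
    return steps or [{"action": "search", "detail": "general search"}]


fenced = FENCE + 'json\n[{"action": "search", "detail": "exam rules"}]\n' + FENCE
chatty = 'Here is the plan: [{"action": "search", "detail": "leave policy"}] Hope it helps!'
broken = "I could not produce a plan"
```

Note that the last-resort fallback uses the raw model output (truncated to 200 characters) as the search query, so a refusal or apology from the planner would be searched verbatim.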
+
+
+ def _extract_last_ai_text(messages: list) -> str:
+     """Return the text content of the last non-tool-call AI message.
+
+     Args:
+         messages: List of LangChain message objects.
+
+     Returns:
+         The extracted text, or empty string if none found.
+     """
+     for msg in reversed(messages):
+         if (
+             isinstance(msg, AIMessage)
+             and msg.content
+             and not getattr(msg, "tool_calls", None)
+         ):
+             return str(msg.content)
+     return ""
src/agent/react_router.py CHANGED
@@ -26,14 +26,22 @@ logger = logging.getLogger(__name__)
26
 
27
  _SYSTEM_PROMPT = (
28
  "You are a helpful assistant for administrative staff at the University of Copenhagen (KU).\n\n"
29
- "You have access to a hybrid_search tool that searches KU policy documents stored in the "
30
- "knowledge base.\n\n"
 
 
 
 
 
31
  "Guidelines:\n"
32
- "- Always call hybrid_search before answering questions about KU rules, policies, exams, "
33
  "employment conditions, or administrative procedures.\n"
34
- "- If the first search does not return sufficient information, call hybrid_search again "
35
- "with a refined or more specific query.\n"
36
- "- For comparison questions, search for each item separately.\n"
 
 
 
37
  "- Cite the document sources ([1], [2], …) in your answer.\n"
38
  "- Answer in the same language as the user's question."
39
  )
@@ -83,6 +91,7 @@ class ReActRouter:
83
  self._vector_store,
84
  store,
85
  self._default_top_k,
 
86
  )
87
  return create_react_agent(self._llm, tools)
88
 
 
26
 
27
  _SYSTEM_PROMPT = (
28
  "You are a helpful assistant for administrative staff at the University of Copenhagen (KU).\n\n"
29
+ "You have access to several tools for searching KU policy documents:\n"
30
+ "- hybrid_search: General-purpose search across all documents.\n"
31
+ "- multi_query_search: For complex or comparison questions — decomposes into sub-queries.\n"
32
+ "- search_within_document: Pinpoint specific sections inside a known document.\n"
33
+ "- summarize_document: Generate an overview of an entire document.\n"
34
+ "- list_documents: See which documents are available.\n"
35
+ "- fetch_document: Get the full text of a specific document.\n\n"
36
  "Guidelines:\n"
37
+ "- Always search before answering questions about KU rules, policies, exams, "
38
  "employment conditions, or administrative procedures.\n"
39
+ "- Use multi_query_search for comparison questions or complex multi-part questions.\n"
40
+ "- Use search_within_document when you already know the relevant document and "
41
+ "need to find a specific clause or section.\n"
42
+ "- Use summarize_document when the user asks for an overview of a document.\n"
43
+ "- If the first search does not return sufficient information, try a different "
44
+ "tool or refine your query.\n"
45
  "- Cite the document sources ([1], [2], …) in your answer.\n"
46
  "- Answer in the same language as the user's question."
47
  )
 
91
  self._vector_store,
92
  store,
93
  self._default_top_k,
94
+ llm_chain=self._llm,
95
  )
96
  return create_react_agent(self._llm, tools)
97
 
src/agent/tools.py CHANGED
@@ -3,6 +3,7 @@
3
  import logging
4
  from dataclasses import dataclass, field
5
 
 
6
  from langchain_core.tools import tool
7
 
8
  from src.models import QueryResult
@@ -33,12 +34,54 @@ class ToolResultStore:
33
  fused_results: list[QueryResult] = field(default_factory=list)
34
 
35
 
36
  def make_retrieval_tools(
37
  hybrid_retriever: HybridRetriever,
38
  reranker: Reranker,
39
  vector_store: VectorStore,
40
  store: ToolResultStore,
41
  default_top_k: int = 5,
 
42
  ) -> list:
43
  """Create retrieval tools bound to the given components and result store.
44
 
@@ -52,11 +95,18 @@ def make_retrieval_tools(
52
  vector_store: VectorStore instance for document-level access.
53
  store: Shared ToolResultStore that captures structured results.
54
  default_top_k: Default number of results to return per call.
 
 
 
55
 
56
  Returns:
57
  List of LangChain tool callables ready for bind_tools / ToolNode.
58
  """
59
 
 
 
 
 
60
  @tool
61
  def hybrid_search(query: str, top_k: int = default_top_k) -> str:
62
  """Search the KU document knowledge base using hybrid retrieval.
@@ -84,36 +134,16 @@ def make_retrieval_tools(
84
  hybrid_result = hybrid_retriever.search_detailed(query, top_k=top_k)
85
  results = reranker.rerank(query, hybrid_result.fused_results, top_k=top_k)
86
 
87
- # Accumulate intermediate pipeline stages
88
- def _merge(existing_list: list[QueryResult], new_list: list[QueryResult]) -> list[QueryResult]:
89
- by_id = {r.chunk.chunk_id: r for r in existing_list}
90
- for r in new_list:
91
- cid = r.chunk.chunk_id
92
- if cid not in by_id or r.score > by_id[cid].score:
93
- by_id[cid] = r
94
- return sorted(by_id.values(), key=lambda r: r.score, reverse=True)
95
 
96
- store.dense_results = _merge(store.dense_results, hybrid_result.dense_results)
97
- store.sparse_results = _merge(store.sparse_results, hybrid_result.sparse_results)
98
- store.fused_results = _merge(store.fused_results, hybrid_result.fused_results)
99
 
100
- # Accumulate reranked results across multiple calls (union by chunk_id, keep highest score)
101
- existing = {r.chunk.chunk_id: r for r in store.retrieved}
102
- for r in results:
103
- cid = r.chunk.chunk_id
104
- if cid not in existing or r.score > existing[cid].score:
105
- existing[cid] = r
106
- store.retrieved = sorted(existing.values(), key=lambda r: r.score, reverse=True)
107
-
108
- if not results:
109
- return "Ingen relevante dokumenter fundet. (No relevant documents found.)"
110
-
111
- parts: list[str] = []
112
- for i, r in enumerate(results, 1):
113
- parts.append(
114
- f"[{i}] {r.chunk.document_id} (relevance: {r.score:.3f})\n{r.chunk.text}"
115
- )
116
- return "\n\n---\n\n".join(parts)
117
 
118
  @tool
119
  def list_documents() -> str:
@@ -161,11 +191,8 @@ def make_retrieval_tools(
161
  f"(Document not found. Use list_documents to see available IDs.)"
162
  )
163
 
164
- # Sort chunks by chunk_index to preserve document order
165
  chunks.sort(key=lambda c: c.metadata.get("chunk_index", 0))
166
 
167
- # Register chunks as QueryResult so confidence and sources are surfaced in the UI.
168
- # Score 1.0 indicates a direct full-document fetch (no ranking involved).
169
  existing = {r.chunk.chunk_id: r for r in store.retrieved}
170
  for chunk in chunks:
171
  if chunk.chunk_id not in existing:
@@ -178,4 +205,171 @@ def make_retrieval_tools(
178
  f"{full_text}"
179
  )
180
 
181
- return [hybrid_search, list_documents, fetch_document]
3
  import logging
4
  from dataclasses import dataclass, field
5
 
6
+ from langchain_core.runnables import Runnable
7
  from langchain_core.tools import tool
8
 
9
  from src.models import QueryResult
 
34
  fused_results: list[QueryResult] = field(default_factory=list)
35
 
36
 
+ def _merge_results(existing: list[QueryResult], new: list[QueryResult]) -> list[QueryResult]:
+     """Merge two QueryResult lists by chunk_id, keeping the highest score.
+
+     Args:
+         existing: Previously accumulated results.
+         new: New results to merge in.
+
+     Returns:
+         Merged list sorted by descending score.
+     """
+     by_id = {r.chunk.chunk_id: r for r in existing}
+     for r in new:
+         cid = r.chunk.chunk_id
+         if cid not in by_id or r.score > by_id[cid].score:
+             by_id[cid] = r
+     return sorted(by_id.values(), key=lambda r: r.score, reverse=True)
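Reviewer note: `_merge_results` replaces the two inline merge loops that `hybrid_search` carried before this commit. A small self-contained sketch of the same union-by-id, keep-highest-score behaviour follows. `Chunk` and `Result` are hypothetical stand-ins for the project's chunk and `QueryResult` types, reduced to the two attributes the merge actually reads.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Chunk:
    """Stand-in exposing only the id the merge keys on."""
    chunk_id: str


@dataclass(frozen=True)
class Result:
    """Stand-in for QueryResult with just chunk and score."""
    chunk: Chunk
    score: float


def merge_results(existing: list[Result], new: list[Result]) -> list[Result]:
    """Union by chunk_id, keep the higher score, sort by descending score."""
    by_id = {r.chunk.chunk_id: r for r in existing}
    for r in new:
        cid = r.chunk.chunk_id
        if cid not in by_id or r.score > by_id[cid].score:
            by_id[cid] = r
    return sorted(by_id.values(), key=lambda r: r.score, reverse=True)


first = [Result(Chunk("a"), 0.40), Result(Chunk("b"), 0.90)]
second = [Result(Chunk("a"), 0.75), Result(Chunk("c"), 0.10)]
merged = merge_results(first, second)
ids_and_scores = [(r.chunk.chunk_id, r.score) for r in merged]
```

Because the merged list is ordered purely by score, results from different queries are ranked against each other directly. Whether those scores are comparable across queries depends on the reranker in use.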
+
+
+ def _format_results(results: list[QueryResult]) -> str:
+     """Format a list of QueryResult into a readable string.
+
+     Args:
+         results: Ranked results to format.
+
+     Returns:
+         Formatted string with numbered entries, or a no-results message.
+     """
+     if not results:
+         return "Ingen relevante dokumenter fundet. (No relevant documents found.)"
+     parts: list[str] = []
+     for i, r in enumerate(results, 1):
+         page_info = ""
+         page = r.chunk.metadata.get("page_number")
+         if page is not None:
+             page_info = f" side {page}"
+         parts.append(
+             f"[{i}] {r.chunk.document_id}{page_info} (relevance: {r.score:.3f})\n{r.chunk.text}"
+         )
+     return "\n\n---\n\n".join(parts)
+
+
78
  def make_retrieval_tools(
79
  hybrid_retriever: HybridRetriever,
80
  reranker: Reranker,
81
  vector_store: VectorStore,
82
  store: ToolResultStore,
83
  default_top_k: int = 5,
84
+ llm_chain: Runnable | None = None,
85
  ) -> list:
86
  """Create retrieval tools bound to the given components and result store.
87
 
 
95
  vector_store: VectorStore instance for document-level access.
96
  store: Shared ToolResultStore that captures structured results.
97
  default_top_k: Default number of results to return per call.
98
+ llm_chain: Optional LLM chain for tools that need generation
99
+ (summarize_document, multi_query_search). When None, those
100
+ tools are excluded from the returned list.
101
 
102
  Returns:
103
  List of LangChain tool callables ready for bind_tools / ToolNode.
104
  """
105
 
106
+ # ------------------------------------------------------------------
107
+ # Core search tool
108
+ # ------------------------------------------------------------------
109
+
110
  @tool
111
  def hybrid_search(query: str, top_k: int = default_top_k) -> str:
112
  """Search the KU document knowledge base using hybrid retrieval.
 
134
  hybrid_result = hybrid_retriever.search_detailed(query, top_k=top_k)
135
  results = reranker.rerank(query, hybrid_result.fused_results, top_k=top_k)
136
 
137
+ store.dense_results = _merge_results(store.dense_results, hybrid_result.dense_results)
138
+ store.sparse_results = _merge_results(store.sparse_results, hybrid_result.sparse_results)
139
+ store.fused_results = _merge_results(store.fused_results, hybrid_result.fused_results)
140
+ store.retrieved = _merge_results(store.retrieved, results)
 
 
 
 
141
 
142
+ return _format_results(results)
 
 
143
 
144
+ # ------------------------------------------------------------------
145
+ # Document-level tools
146
+ # ------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  @tool
149
  def list_documents() -> str:
 
191
  f"(Document not found. Use list_documents to see available IDs.)"
192
  )
193
 
 
194
  chunks.sort(key=lambda c: c.metadata.get("chunk_index", 0))
195
 
 
 
196
  existing = {r.chunk.chunk_id: r for r in store.retrieved}
197
  for chunk in chunks:
198
  if chunk.chunk_id not in existing:
 
205
  f"{full_text}"
206
  )
207
 
+     # ------------------------------------------------------------------
+     # Targeted within-document search
+     # ------------------------------------------------------------------
+
+     @tool
+     def search_within_document(document_id: str, query: str, top_k: int = 3) -> str:
+         """Search for specific information within a single document.
+
+         Retrieves all chunks belonging to the document and uses the cross-encoder
+         reranker to find the most relevant passages for the query. Use this when
+         you already know which document to look in and need to pinpoint the exact
+         section (e.g. a specific clause, page, or paragraph).
+
+         Args:
+             document_id: The exact document ID to search within.
+             query: What to look for inside the document.
+             top_k: Number of top passages to return (1–10). Default is 3.
+
+         Returns:
+             The most relevant passages within the document, ranked by relevance.
+         """
+         logger.info(
+             "Tool search_within_document: doc=%r query=%r top_k=%d",
+             document_id, query, top_k,
+         )
+         store.tool_calls.append(("search_within_document", f"{document_id}: {query}"))
+
+         chunks = vector_store.get_chunks_by_document_id(document_id)
+         if not chunks:
+             return (
+                 f"Dokumentet '{document_id}' blev ikke fundet i vidensbasen. "
+                 f"(Document not found. Use list_documents to see available IDs.)"
+             )
+
+         # Wrap chunks as QueryResult so the reranker can score them
+         candidates = [
+             QueryResult(chunk=c, score=0.0, source="search_within_document")
+             for c in chunks
+         ]
+         results = reranker.rerank(query, candidates, top_k=top_k)
+
+         store.retrieved = _merge_results(store.retrieved, results)
+
+         return _format_results(results)
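Reviewer note: `search_within_document` skips retrieval entirely. It loads every chunk of one document and lets the reranker order them. The sketch below shows that filter-then-score shape in isolation. The token-overlap scorer is a deliberately crude, hypothetical substitute for the cross-encoder, and `Chunk`, `overlap_score` and `search_within` are illustrative names only.

```python
from dataclasses import dataclass


@dataclass
class Chunk:
    """Stand-in for the project's chunk type."""
    chunk_id: str
    document_id: str
    text: str


def overlap_score(query: str, text: str) -> float:
    """Hypothetical scorer: fraction of query tokens that appear in the text."""
    q = set(query.lower().split())
    t = set(text.lower().split())
    return len(q & t) / len(q) if q else 0.0


def search_within(
    chunks: list[Chunk], document_id: str, query: str, top_k: int = 3
) -> list[tuple[Chunk, float]]:
    """Score only the chunks of one document and return the best top_k."""
    candidates = [c for c in chunks if c.document_id == document_id]
    scored = [(c, overlap_score(query, c.text)) for c in candidates]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]


corpus = [
    Chunk("d1-0", "d1", "general introduction to the policy"),
    Chunk("d1-1", "d1", "deadline for exam registration is two weeks"),
    Chunk("d2-0", "d2", "exam registration deadline for guests"),
]
hits = search_within(corpus, "d1", "exam registration deadline", top_k=1)
```

One consequence of this design is that the reranker scores every chunk of the document on each call, so cost grows with document length rather than with `top_k`.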
+
+     # ------------------------------------------------------------------
+     # LLM-powered tools (only available when llm_chain is provided)
+     # ------------------------------------------------------------------
+
+     tools: list = [hybrid_search, list_documents, fetch_document, search_within_document]
+
+     if llm_chain is not None:
+
+         @tool
+         def multi_query_search(question: str, top_k: int = default_top_k) -> str:
+             """Decompose a complex question into sub-queries and search each independently.
+
+             Use this tool instead of hybrid_search when the question involves
+             multiple aspects, comparisons, or requires information from different
+             topics. For example: "How do exam rules differ between bachelor and
+             master programmes?" would be split into separate searches for each
+             programme's exam rules, then merged.
+
+             Args:
+                 question: The complex user question to decompose and search.
+                 top_k: Number of results to return per sub-query (1–10). Default is 5.
+
+             Returns:
+                 Combined results from all sub-queries, deduplicated and ranked.
+             """
+             logger.info("Tool multi_query_search: question=%r", question)
+             store.tool_calls.append(("multi_query_search", question))
+
+             # Step 1: Ask LLM to decompose the question
+             decompose_prompt = (
+                 "You are a search query planner. Given a complex question, "
+                 "decompose it into 2-4 simple, independent search queries that "
+                 "together cover all aspects of the question. The queries should "
+                 "be in Danish (since the document base is Danish).\n\n"
+                 "Reply with ONLY the queries, one per line, nothing else.\n\n"
+                 f"Question: {question}"
+             )
+             raw = str(llm_chain.invoke(decompose_prompt)).strip()
+             sub_queries = [q.strip().lstrip("0123456789.-) ") for q in raw.splitlines() if q.strip()]
+             if not sub_queries:
+                 sub_queries = [question]
+
+             logger.info("Decomposed into %d sub-queries: %s", len(sub_queries), sub_queries)
+
+             # Step 2: Search each sub-query independently
+             all_results: list[QueryResult] = []
+             for sq in sub_queries:
+                 hybrid_result = hybrid_retriever.search_detailed(sq, top_k=top_k)
+                 reranked = reranker.rerank(sq, hybrid_result.fused_results, top_k=top_k)
+                 all_results = _merge_results(all_results, reranked)
+
+                 store.dense_results = _merge_results(store.dense_results, hybrid_result.dense_results)
+                 store.sparse_results = _merge_results(store.sparse_results, hybrid_result.sparse_results)
+                 store.fused_results = _merge_results(store.fused_results, hybrid_result.fused_results)
+
+             # Step 3: Keep top results across all sub-queries
+             final = all_results[:top_k]
+             store.retrieved = _merge_results(store.retrieved, final)
+
+             header = f"Søgning opdelt i {len(sub_queries)} delforespørgsler:\n"
+             header += "\n".join(f"  • {sq}" for sq in sub_queries)
+             header += "\n\n"
+             return header + _format_results(final)
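Reviewer note: the sub-query cleanup in `multi_query_search` uses `str.lstrip` with a character set, which removes any leading run of those characters rather than just a list marker. A query that legitimately starts with a digit (a year, a paragraph number) would lose it. The sketch below shows one possible alternative that strips only a single leading marker. `clean_sub_queries` and `_LIST_MARKER` are hypothetical names, and the regex is an illustrative choice, not what the commit does.

```python
import re

# One leading list marker: "1.", "2)", "-", "*" or "•", followed by whitespace.
_LIST_MARKER = re.compile(r"^\s*(?:\d+[.)]|[-*•])\s+")


def clean_sub_queries(raw: str, fallback: str) -> list[str]:
    """Split LLM output into queries, removing at most one list marker per line."""
    queries: list[str] = []
    for line in raw.splitlines():
        text = _LIST_MARKER.sub("", line).strip()
        if text:
            queries.append(text)
    # Fall back to the original question when nothing usable came back.
    return queries or [fallback]


raw = "1. eksamensregler bachelor\n2) eksamensregler kandidat\n- 2024 frister\n"
cleaned = clean_sub_queries(raw, fallback="original question")

# For comparison: the character-set strip also removes the leading year.
stripped_year = "2024 frister".lstrip("0123456789.-) ")
```

The regex requires whitespace after the marker, so a bare year such as `2024 frister` is left untouched because `2024` is not followed by `.` or `)`.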
+
+         @tool
+         def summarize_document(document_id: str) -> str:
+             """Generate a structured summary of a document in the knowledge base.
+
+             Fetches the full document and uses the LLM to produce a concise summary
+             covering the main topics, key rules, and important details. Use this
+             when the user asks "what is this document about?" or wants an overview
+             before diving into specifics.
+
+             Args:
+                 document_id: The exact document ID to summarize.
+
+             Returns:
+                 A structured summary of the document, or an error if not found.
+             """
+             logger.info("Tool summarize_document: document_id=%r", document_id)
+             store.tool_calls.append(("summarize_document", document_id))
+
+             chunks = vector_store.get_chunks_by_document_id(document_id)
+             if not chunks:
+                 return (
+                     f"Dokumentet '{document_id}' blev ikke fundet i vidensbasen. "
+                     f"(Document not found. Use list_documents to see available IDs.)"
+                 )
+
+             chunks.sort(key=lambda c: c.metadata.get("chunk_index", 0))
+             full_text = "\n\n".join(c.text for c in chunks)
+
+             # Register chunks as sources
+             existing = {r.chunk.chunk_id: r for r in store.retrieved}
+             for chunk in chunks:
+                 if chunk.chunk_id not in existing:
+                     existing[chunk.chunk_id] = QueryResult(
+                         chunk=chunk, score=1.0, source="summarize_document",
+                     )
+             store.retrieved = sorted(existing.values(), key=lambda r: r.score, reverse=True)
+
+             # Truncate to avoid exceeding context limits
+             max_chars = 8000
+             if len(full_text) > max_chars:
+                 full_text = full_text[:max_chars] + "\n\n[... teksten er forkortet ...]"
+
+             summary_prompt = (
+                 "Produce a structured summary of the following document. "
+                 "Include:\n"
+                 "1. Document title/topic\n"
+                 "2. Key points (3-7 bullet points)\n"
+                 "3. Important rules, deadlines, or requirements mentioned\n"
+                 "4. Who the document applies to\n\n"
+                 "Write the summary in the same language as the document.\n\n"
+                 f"Document ID: {document_id}\n\n"
+                 f"Document text:\n{full_text}"
+             )
+             summary = str(llm_chain.invoke(summary_prompt)).strip()
+             return f"Resumé af {document_id}:\n\n{summary}"
+
+         tools.extend([multi_query_search, summarize_document])
+
+     return tools
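Reviewer note: both LLM-powered tools call `str(llm_chain.invoke(...))`. That is fine when `llm_chain` ends in a string output parser, but `react_router.py` passes the same object it hands to `create_react_agent`, which suggests a chat model. In that case `invoke` may return a message object, and `str()` on it may include more than the text. A defensive helper along these lines would handle both cases. This is a sketch under that assumption: `to_text` is a hypothetical name, and `FakeMessage` merely imitates an object with a `.content` attribute.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeMessage:
    """Hypothetical stand-in for a chat-model message exposing `.content`."""
    content: Any


def to_text(response: Any) -> str:
    """Return plain text from either a string or a message-like object."""
    content = getattr(response, "content", response)
    if isinstance(content, list):
        # Content may be a list of blocks (strings or dicts); keep the text parts.
        parts = [b.get("text", "") if isinstance(b, dict) else str(b) for b in content]
        content = "".join(parts)
    return str(content).strip()


from_string = to_text("  first query\nsecond query  ")
from_message = to_text(FakeMessage("first query\nsecond query"))
from_blocks = to_text(FakeMessage([{"type": "text", "text": "first "}, "query"]))
naive = str(FakeMessage("first query"))
```

With the stand-in, `naive` is the dataclass repr rather than the text, which illustrates how a sub-query list built from such a string could end up containing serialization noise.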
src/api/main.py CHANGED
@@ -16,7 +16,8 @@ from src.retrieval.hybrid import HybridRetriever
16
  from src.retrieval.reranker import Reranker
17
  from src.agent.intent_classifier import IntentClassifier
18
  from src.agent.router import QueryRouter
19
- from src.agent.react_router import ReActRouter
 
20
  from src.ingestion.pipeline import IngestionPipeline
21
  from src.api.routes import router, set_dependencies
22
 
@@ -72,13 +73,14 @@ def create_app() -> FastAPI:
72
  reranker = Reranker(model=create_reranker(settings.reranker_model))
73
 
74
  if settings.agent_mode == "react":
75
- logger.info("Agent mode: ReAct (tool-calling loop)")
76
- query_router: QueryRouter | ReActRouter = ReActRouter(
77
  llm=llm,
78
  hybrid_retriever=hybrid_retriever,
79
  reranker=reranker,
80
  vector_store=vector_store,
81
  default_top_k=settings.top_k,
 
82
  )
83
  else:
84
  logger.info("Agent mode: pipeline (fixed DAG)")
 
16
  from src.retrieval.reranker import Reranker
17
  from src.agent.intent_classifier import IntentClassifier
18
  from src.agent.router import QueryRouter
19
+ from src.agent.plan_and_execute import PlanAndExecuteRouter
20
+ from src.agent.memory import ConversationMemory
21
  from src.ingestion.pipeline import IngestionPipeline
22
  from src.api.routes import router, set_dependencies
23
 
 
73
  reranker = Reranker(model=create_reranker(settings.reranker_model))
74
 
75
  if settings.agent_mode == "react":
76
+ logger.info("Agent mode: Plan-and-Execute (structured multi-step agent)")
77
+ query_router: QueryRouter | PlanAndExecuteRouter = PlanAndExecuteRouter(
78
  llm=llm,
79
  hybrid_retriever=hybrid_retriever,
80
  reranker=reranker,
81
  vector_store=vector_store,
82
  default_top_k=settings.top_k,
83
+ memory=ConversationMemory(),
84
  )
85
  else:
86
  logger.info("Agent mode: pipeline (fixed DAG)")
src/api/routes.py CHANGED
@@ -15,6 +15,7 @@ from pydantic import BaseModel
15
  if TYPE_CHECKING:
16
  from src.agent.router import QueryRouter
17
  from src.agent.react_router import ReActRouter
 
18
  from src.config import Settings
19
  from src.ingestion.pipeline import IngestionPipeline
20
  from src.retrieval.bm25_search import BM25Search
@@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
25
 
26
  router = APIRouter()
27
 
28
- _query_router: "QueryRouter | ReActRouter | None" = None
29
  _ingestion_pipeline: "IngestionPipeline | None" = None
30
  _embedder: "Embedder | None" = None
31
  _vector_store: "VectorStore | None" = None
@@ -34,7 +35,7 @@ _settings: "Settings | None" = None
34
 
35
 
36
  def set_dependencies(
37
- query_router: "QueryRouter | ReActRouter",
38
  ingestion_pipeline: "IngestionPipeline",
39
  embedder: "Embedder",
40
  vector_store: "VectorStore",
@@ -75,6 +76,7 @@ class PipelineResultItem(BaseModel):
75
  chunk_id: str
76
  score: float
77
  source: str
 
78
 
79
 
80
  class PipelineDetailsResponse(BaseModel):
@@ -88,6 +90,8 @@ class PipelineDetailsResponse(BaseModel):
88
  sparse_results: list[PipelineResultItem] = []
89
  fused_results: list[PipelineResultItem] = []
90
  reranked_results: list[PipelineResultItem] = []
 
 
91
 
92
 
93
  class SourceItem(BaseModel):
@@ -206,6 +210,8 @@ async def query_documents(request: QueryRequest) -> QueryResponse:
206
  sparse_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.sparse_results],
207
  fused_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.fused_results],
208
  reranked_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.reranked_results],
 
 
209
  )
210
 
211
  return QueryResponse(
 
15
  if TYPE_CHECKING:
16
  from src.agent.router import QueryRouter
17
  from src.agent.react_router import ReActRouter
18
+ from src.agent.plan_and_execute import PlanAndExecuteRouter
19
  from src.config import Settings
20
  from src.ingestion.pipeline import IngestionPipeline
21
  from src.retrieval.bm25_search import BM25Search
 
26
 
27
  router = APIRouter()
28
 
29
+ _query_router: "QueryRouter | ReActRouter | PlanAndExecuteRouter | None" = None
30
  _ingestion_pipeline: "IngestionPipeline | None" = None
31
  _embedder: "Embedder | None" = None
32
  _vector_store: "VectorStore | None" = None
 
35
 
36
 
37
  def set_dependencies(
38
+ query_router: "QueryRouter | ReActRouter | PlanAndExecuteRouter",
39
  ingestion_pipeline: "IngestionPipeline",
40
  embedder: "Embedder",
41
  vector_store: "VectorStore",
 
76
  chunk_id: str
77
  score: float
78
  source: str
79
+ metadata: dict[str, str | int] = {}
80
 
81
 
82
  class PipelineDetailsResponse(BaseModel):
 
90
  sparse_results: list[PipelineResultItem] = []
91
  fused_results: list[PipelineResultItem] = []
92
  reranked_results: list[PipelineResultItem] = []
93
+ plan_steps: list[str] = []
94
+ tool_calls: list[str] = []
95
 
96
 
97
  class SourceItem(BaseModel):
 
210
  sparse_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.sparse_results],
211
  fused_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.fused_results],
212
  reranked_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.reranked_results],
213
+ plan_steps=pd.plan_steps,
214
+ tool_calls=pd.tool_calls,
215
  )
216
 
217
  return QueryResponse(
src/models.py CHANGED
@@ -91,6 +91,8 @@ class PipelineDetails:
91
  sparse_results: Results from sparse (BM25) retrieval.
92
  fused_results: Results after reciprocal rank fusion.
93
  reranked_results: Results after cross-encoder reranking.
 
 
94
  """
95
 
96
  original_query: str = ""
@@ -101,6 +103,8 @@ class PipelineDetails:
101
  sparse_results: list[QueryResult] = field(default_factory=list)
102
  fused_results: list[QueryResult] = field(default_factory=list)
103
  reranked_results: list[QueryResult] = field(default_factory=list)
 
 
104
 
105
 
106
  @dataclass
 
91
  sparse_results: Results from sparse (BM25) retrieval.
92
  fused_results: Results after reciprocal rank fusion.
93
  reranked_results: Results after cross-encoder reranking.
94
+ plan_steps: Ordered descriptions of planned steps (Plan-and-Execute mode).
95
+ tool_calls: Log of tool invocations as "tool_name: argument" strings.
96
  """
97
 
98
  original_query: str = ""
 
103
  sparse_results: list[QueryResult] = field(default_factory=list)
104
  fused_results: list[QueryResult] = field(default_factory=list)
105
  reranked_results: list[QueryResult] = field(default_factory=list)
106
+ plan_steps: list[str] = field(default_factory=list)
107
+ tool_calls: list[str] = field(default_factory=list)
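Reviewer note: the two new `PipelineDetails` fields are plain string lists, filled from the plan dicts and from the `(tool_name, argument)` tuples recorded in the tool store. The sketch below mirrors how the `done` event in `plan_and_execute.py` flattens them. `PipelineDetailsSketch` is a reduced, hypothetical copy of the dataclass with only the fields needed here, and `build_details` is an illustrative helper.

```python
from dataclasses import asdict, dataclass, field


@dataclass
class PipelineDetailsSketch:
    """Reduced stand-in for PipelineDetails (only the fields used below)."""
    original_query: str = ""
    retrieval_query: str = ""
    plan_steps: list[str] = field(default_factory=list)
    tool_calls: list[str] = field(default_factory=list)


def build_details(
    query: str,
    plan: list[dict[str, str]],
    tool_calls: list[tuple[str, str]],
) -> PipelineDetailsSketch:
    """Flatten plan dicts and tool-call tuples into display strings."""
    searches = [arg for name, arg in tool_calls if name == "hybrid_search"]
    return PipelineDetailsSketch(
        original_query=query,
        # Falls back to the user query when no hybrid_search call was made.
        retrieval_query=", ".join(searches) or query,
        plan_steps=[f'{s["action"]}: {s["detail"]}' for s in plan],
        tool_calls=[f"{n}: {a}" for n, a in tool_calls],
    )


details = build_details(
    "Compare exam rules",
    [{"action": "search", "detail": "bachelor exam rules"}],
    [("hybrid_search", "bachelor exam rules"), ("list_documents", "")],
)
as_dict = asdict(details)
```

Using `default_factory=list` gives each instance its own list, so appending to one result's `tool_calls` cannot leak into another.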
108
 
109
 
110
  @dataclass
src/ui/app.py CHANGED
@@ -53,7 +53,7 @@ TEXTS: dict[str, dict[str, str]] = {
53
  "- **LLM-integration** — provider-agnostisk, prompt-styret "
54
  "svargenerering\n"
55
  "- **Evaluering** — RAGAS-baseret kvalitetsmåling\n"
56
- "- **Agent Flows** — ReAct-loop med værktøjskald.\n"
57
  "- [**Kildedokumenter**](https://github.com/Xiiqiing/Dokumentassistent/tree/main/docs)"
58
  " — de dokumenter systemet er indekseret fra"
59
  ),
@@ -67,8 +67,8 @@ TEXTS: dict[str, dict[str, str]] = {
67
  "Et dokumentintelligens-system bygget på en RAG-arkitektur, dækkende file-indlæsning, semantisk chunking, "
68
  "hybrid søgning med reranking "
69
  "og LLM-genererede svar med kildehenvisninger. LLM-laget er provider-agnostisk. "
70
- "To tilstande: en LangGraph ReAct-agent (standard) til forespørgsler der kræver flere søgetrin, "
71
- "og en pipeline til lette modeller uden værktøjskald. Søgekvaliteten evalueres med RAGAS."
72
  ),
73
  "search_label": "Stil et spørgsmål om ... ",
74
  "search_placeholder": "F.eks.: Hvad er reglerne for behandling af personoplysninger?",
@@ -113,6 +113,8 @@ TEXTS: dict[str, dict[str, str]] = {
113
  "pipeline_rank": "#",
114
  "pipeline_no_results": "Ingen resultater",
115
  "pipeline_score_change": "Score-ændring",
 
 
116
  },
117
  "en": {
118
  "page_title": "Document Assistant",
@@ -131,7 +133,7 @@ TEXTS: dict[str, dict[str, str]] = {
131
  "- **LLM integration** — provider-agnostic, prompt-driven "
132
  "answer generation\n"
133
  "- **Evaluation** — RAGAS-based quality measurement\n"
134
- "- **Agent Flows** — ReAct loop with tool calling\n"
135
  "- [**Source documents**](https://github.com/Xiiqiing/Dokumentassistent/tree/main/docs)"
136
  " — the documents indexed into the knowledge base"
137
  ),
@@ -145,8 +147,8 @@ TEXTS: dict[str, dict[str, str]] = {
145
  "A document intelligence system built on a RAG architecture, covering file ingestion, semantic chunking, "
146
  "hybrid retrieval with reranking, "
147
  "and LLM-generated answers with source citations. The LLM layer is provider-agnostic. "
148
- "Two modes: a LangGraph ReAct agent (default) for queries that need multiple retrieval steps, "
149
- "and a pipeline for lightweight models without tool-calling support. "
150
  "Retrieval quality is evaluated with RAGAS."
151
  ),
152
  "search_label": "Ask a question ...",
@@ -192,6 +194,8 @@ TEXTS: dict[str, dict[str, str]] = {
192
  "pipeline_rank": "#",
193
  "pipeline_no_results": "No results",
194
  "pipeline_score_change": "Score change",
 
 
195
  },
196
  }
197
 
@@ -711,6 +715,32 @@ if search_clicked and question.strip():
711
  else (f"Reranked to **{_rc}** results · confidence **{_cf:.0%}**")
712
  )
713
 
714
  elif _step == "tool_call":
715
  _tool_name = _event.get("tool", "")
716
  _tool_query = _event.get("query", "")
@@ -850,6 +880,21 @@ if search_clicked and question.strip():
850
  pd = data.get("pipeline_details", {})
851
  if pd:
852
  with st.expander(t["pipeline_heading"], expanded=False):
853
  # 1) Query translation (only show if translation actually happened)
854
  if pd.get("translated"):
855
  st.markdown(f'**{t["pipeline_translation"]}**')
 
53
  "- **LLM-integration** — provider-agnostisk, prompt-styret "
54
  "svargenerering\n"
55
  "- **Evaluering** — RAGAS-baseret kvalitetsmåling\n"
56
+ "- **Agent Flows** — LangGraph Plan-and-Execute med værktøjskald og samtalehukommelse\n"
57
  "- [**Kildedokumenter**](https://github.com/Xiiqiing/Dokumentassistent/tree/main/docs)"
58
  " — de dokumenter systemet er indekseret fra"
59
  ),
 
67
  "Et dokumentintelligens-system bygget på en RAG-arkitektur, dækkende file-indlæsning, semantisk chunking, "
68
  "hybrid søgning med reranking "
69
  "og LLM-genererede svar med kildehenvisninger. LLM-laget er provider-agnostisk. "
70
+ "To tilstande: en LangGraph Plan-and-Execute-agent (standard) med samtalehukommelse til komplekse forespørgsler, "
71
+ "og en foruddefineret pipeline til lette modeller. Søgekvaliteten evalueres med RAGAS."
72
  ),
73
  "search_label": "Stil et spørgsmål om ... ",
74
  "search_placeholder": "F.eks.: Hvad er reglerne for behandling af personoplysninger?",
 
113
  "pipeline_rank": "#",
114
  "pipeline_no_results": "Ingen resultater",
115
  "pipeline_score_change": "Score-ændring",
116
+ "pipeline_plan_steps": "Udførelsesplan",
117
+ "pipeline_tool_calls": "Værktøjskald",
118
  },
119
  "en": {
120
  "page_title": "Document Assistant",
 
133
  "- **LLM integration** — provider-agnostic, prompt-driven "
134
  "answer generation\n"
135
  "- **Evaluation** — RAGAS-based quality measurement\n"
136
+ "- **Agent Flows** — LangGraph Plan-and-Execute with tool calling and conversation memory\n"
137
  "- [**Source documents**](https://github.com/Xiiqiing/Dokumentassistent/tree/main/docs)"
138
  " — the documents indexed into the knowledge base"
139
  ),
 
             "A document intelligence system built on a RAG architecture, covering file ingestion, semantic chunking, "
             "hybrid retrieval with reranking, "
             "and LLM-generated answers with source citations. The LLM layer is provider-agnostic. "
+            "Two modes: a LangGraph Plan-and-Execute agent (default) with conversation memory for complex multi-step queries, "
+            "and a predefined pipeline for lightweight models. "
             "Retrieval quality is evaluated with RAGAS."
         ),
         "search_label": "Ask a question ...",
 
         "pipeline_rank": "#",
         "pipeline_no_results": "No results",
         "pipeline_score_change": "Score change",
+        "pipeline_plan_steps": "Execution Plan",
+        "pipeline_tool_calls": "Tool Calls",
     },
 }
 
 
                     else (f"Reranked to **{_rc}** results · confidence **{_cf:.0%}**")
                 )
 
+            elif _step == "plan":
+                _steps = _event.get("steps", [])
+                st.write(
+                    (f"Plan oprettet med **{len(_steps)}** trin")
+                    if lang == "da"
+                    else (f"Plan created with **{len(_steps)}** steps")
+                )
+                for _ps in _steps:
+                    st.write(f" - {_ps}")
+
+            elif _step == "execute_step":
+                _si = _event.get("step_index", 0)
+                _sd = _event.get("step_desc", "")
+                st.write(
+                    (f"Trin {_si} udført: _{_sd}_")
+                    if lang == "da"
+                    else (f"Step {_si} executed: _{_sd}_")
+                )
+
+            elif _step == "synthesize":
+                st.write(
+                    "Syntetiserer endeligt svar ..."
+                    if lang == "da"
+                    else "Synthesizing final answer ..."
+                )
+
             elif _step == "tool_call":
                 _tool_name = _event.get("tool", "")
                 _tool_query = _event.get("query", "")
 
     pd = data.get("pipeline_details", {})
     if pd:
         with st.expander(t["pipeline_heading"], expanded=False):
+            # 0) Plan steps and tool calls (Plan-and-Execute mode)
+            plan_steps = pd.get("plan_steps", [])
+            if plan_steps:
+                st.markdown(f'**{t["pipeline_plan_steps"]}**')
+                for i, step in enumerate(plan_steps, 1):
+                    st.markdown(f"{i}. {step}")
+                st.markdown("---")
+
+            tool_calls = pd.get("tool_calls", [])
+            if tool_calls:
+                st.markdown(f'**{t["pipeline_tool_calls"]}**')
+                for tc in tool_calls:
+                    st.markdown(f"- `{tc}`")
+                st.markdown("---")
+
             # 1) Query translation (only show if translation actually happened)
             if pd.get("translated"):
                 st.markdown(f'**{t["pipeline_translation"]}**')
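
Reviewer note: the new `plan`, `execute_step` and `synthesize` branches above map each streamed event to one or more status lines. A minimal sketch of that mapping as a pure function, which may be easier to unit-test than the Streamlit calls. The function name `describe_event` is hypothetical and not part of this commit; the event keys (`step`, `steps`, `step_index`, `step_desc`) are taken from the diff.

```python
from typing import Any


def describe_event(event: dict[str, Any], lang: str = "en") -> list[str]:
    """Return the status lines the UI would write for one streamed event.

    Hypothetical helper: it mirrors the branches in the diff but returns
    strings instead of calling st.write().
    """
    step = event.get("step", "")
    if step == "plan":
        steps = event.get("steps", [])
        head = (
            f"Plan oprettet med **{len(steps)}** trin"
            if lang == "da"
            else f"Plan created with **{len(steps)}** steps"
        )
        # One bullet line per planned step, after the summary line.
        return [head] + [f"  - {s}" for s in steps]
    if step == "execute_step":
        idx = event.get("step_index", 0)
        desc = event.get("step_desc", "")
        return [
            f"Trin {idx} udført: _{desc}_"
            if lang == "da"
            else f"Step {idx} executed: _{desc}_"
        ]
    if step == "synthesize":
        return [
            "Syntetiserer endeligt svar ..."
            if lang == "da"
            else "Synthesizing final answer ..."
        ]
    # Other event types (tool_call, done, ...) are handled elsewhere.
    return []
```

With a helper like this, the UI loop could reduce to `for line in describe_event(_event, lang): st.write(line)` for these three event types.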
tests/test_memory.py ADDED
@@ -0,0 +1,239 @@
+"""Tests for conversation memory."""
+
+import pytest
+
+from src.agent.memory import ConversationMemory, Turn
+from src.models import DocumentChunk, QueryResult
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _qr(chunk_id: str = "c1", doc_id: str = "doc.pdf", score: float = 0.8) -> QueryResult:
+    chunk = DocumentChunk(
+        chunk_id=chunk_id, document_id=doc_id, text="text",
+        metadata={"page_number": 1},
+    )
+    return QueryResult(chunk=chunk, score=score, source="test")
+
+
+# ---------------------------------------------------------------------------
+# Basic operations
+# ---------------------------------------------------------------------------
+
+class TestConversationMemory:
+    def test_initially_empty(self) -> None:
+        mem = ConversationMemory()
+        assert mem.is_empty
+        assert mem.turns == []
+        assert mem.last_query() == ""
+        assert mem.last_sources() == []
+
+    def test_add_turn(self) -> None:
+        mem = ConversationMemory()
+        mem.add_turn("What is X?", "X is Y.", [_qr()])
+        assert not mem.is_empty
+        assert len(mem.turns) == 1
+        assert mem.last_query() == "What is X?"
+
+    def test_multiple_turns(self) -> None:
+        mem = ConversationMemory()
+        mem.add_turn("Q1", "A1")
+        mem.add_turn("Q2", "A2")
+        assert len(mem.turns) == 2
+        assert mem.last_query() == "Q2"
+
+    def test_clear(self) -> None:
+        mem = ConversationMemory()
+        mem.add_turn("Q1", "A1")
+        mem.clear()
+        assert mem.is_empty
+
+    def test_turns_returns_copy(self) -> None:
+        mem = ConversationMemory()
+        mem.add_turn("Q1", "A1")
+        turns = mem.turns
+        turns.append(Turn(query="fake", answer="fake"))
+        assert len(mem.turns) == 1  # original unaffected
+
+
+# ---------------------------------------------------------------------------
+# Eviction
+# ---------------------------------------------------------------------------
+
+class TestEviction:
+    def test_max_turns_eviction(self) -> None:
+        mem = ConversationMemory(max_turns=3)
+        for i in range(5):
+            mem.add_turn(f"Q{i}", f"A{i}")
+        assert len(mem.turns) == 3
+        # Oldest should be Q2 (Q0 and Q1 evicted)
+        assert mem.turns[0].query == "Q2"
+
+    def test_max_turns_one(self) -> None:
+        mem = ConversationMemory(max_turns=1)
+        mem.add_turn("Q1", "A1")
+        mem.add_turn("Q2", "A2")
+        assert len(mem.turns) == 1
+        assert mem.turns[0].query == "Q2"
+
+
+# ---------------------------------------------------------------------------
+# format_history
+# ---------------------------------------------------------------------------
+
+class TestFormatHistory:
+    def test_empty_history(self) -> None:
+        mem = ConversationMemory()
+        assert mem.format_history() == ""
+
+    def test_includes_query_and_answer(self) -> None:
+        mem = ConversationMemory()
+        mem.add_turn("What is X?", "X is a policy.")
+        text = mem.format_history()
+        assert "What is X?" in text
+        assert "X is a policy." in text
+
+    def test_includes_source_doc_ids(self) -> None:
+        mem = ConversationMemory()
+        sources = [_qr(doc_id="policy.pdf"), _qr(chunk_id="c2", doc_id="rules.pdf")]
+        mem.add_turn("Q", "A", sources)
+        text = mem.format_history()
+        assert "policy.pdf" in text
+        assert "rules.pdf" in text
+
+    def test_max_recent_limits_output(self) -> None:
+        mem = ConversationMemory()
+        for i in range(10):
+            mem.add_turn(f"Q{i}", f"A{i}")
+        text = mem.format_history(max_recent=2)
+        assert "Q8" in text
+        assert "Q9" in text
+        assert "Q0" not in text
+
+    def test_long_answer_truncated(self) -> None:
+        mem = ConversationMemory()
+        mem.add_turn("Q", "x" * 1000)
+        text = mem.format_history()
+        # Answer should be truncated to 500 chars
+        assert len(text) < 1000
+
+
+# ---------------------------------------------------------------------------
+# get_prior_sources
+# ---------------------------------------------------------------------------
+
+class TestGetPriorSources:
+    def test_empty_returns_empty(self) -> None:
+        mem = ConversationMemory()
+        assert mem.get_prior_sources() == []
+
+    def test_collects_across_turns(self) -> None:
+        mem = ConversationMemory()
+        mem.add_turn("Q1", "A1", [_qr(chunk_id="c1", score=0.8)])
+        mem.add_turn("Q2", "A2", [_qr(chunk_id="c2", score=0.9)])
+        sources = mem.get_prior_sources()
+        assert len(sources) == 2
+        # Sorted by score descending
+        assert sources[0].score == 0.9
+
+    def test_deduplicates_by_chunk_id(self) -> None:
+        mem = ConversationMemory()
+        mem.add_turn("Q1", "A1", [_qr(chunk_id="c1", score=0.5)])
+        mem.add_turn("Q2", "A2", [_qr(chunk_id="c1", score=0.9)])
+        sources = mem.get_prior_sources()
+        assert len(sources) == 1
+        assert sources[0].score == 0.9  # keeps higher score
+
+    def test_no_sources_turns(self) -> None:
+        mem = ConversationMemory()
+        mem.add_turn("Q1", "A1")  # no sources
+        assert mem.get_prior_sources() == []
+
+
+# ---------------------------------------------------------------------------
+# Integration: memory in PlanAndExecuteRouter
+# ---------------------------------------------------------------------------
+
+class TestMemoryIntegration:
+    def test_route_records_turn(self) -> None:
+        """After route(), the conversation turn should be recorded in memory."""
+        from unittest.mock import MagicMock, patch
+        from langchain_core.messages import AIMessage
+        from src.agent.plan_and_execute import PlanAndExecuteRouter
+
+        llm = MagicMock()
+        retriever = MagicMock()
+        reranker = MagicMock()
+        vector_store = MagicMock()
+        memory = ConversationMemory()
+
+        plan_json = '[{"action": "search", "detail": "test"}]'
+        llm.invoke.side_effect = [plan_json, "The answer."]
+
+        mock_agent = MagicMock()
+        mock_agent.invoke.return_value = {"messages": [AIMessage(content="Found info.")]}
+
+        router = PlanAndExecuteRouter(
+            llm, retriever, reranker, vector_store, memory=memory,
+        )
+
+        with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
+            router.route("test question", top_k=5)
+
+        assert not memory.is_empty
+        assert memory.last_query() == "test question"
+        assert memory.turns[0].answer == "The answer."
+
+    def test_history_injected_into_planner(self) -> None:
+        """On a follow-up query, conversation history should appear in the planner prompt."""
+        from unittest.mock import MagicMock, patch
+        from langchain_core.messages import AIMessage
+        from src.agent.plan_and_execute import PlanAndExecuteRouter
+
+        llm = MagicMock()
+        memory = ConversationMemory()
+        memory.add_turn("What is the exam policy?", "The exam policy says...")
+
+        plan_json = '[{"action": "search", "detail": "follow-up"}]'
+        llm.invoke.side_effect = [plan_json, "Follow-up answer."]
+
+        mock_agent = MagicMock()
+        mock_agent.invoke.return_value = {"messages": [AIMessage(content="More info.")]}
+
+        router = PlanAndExecuteRouter(
+            llm, MagicMock(), MagicMock(), MagicMock(), memory=memory,
+        )
+
+        with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
+            router.route("What about the grading?", top_k=5)
+
+        # The first LLM call is the planner — check it includes history
+        planner_prompt = llm.invoke.call_args_list[0][0][0]
+        assert "exam policy" in planner_prompt
+        assert "Conversation history" in planner_prompt
+
+    def test_multi_turn_accumulates(self) -> None:
+        """Multiple route() calls should accumulate turns in memory."""
+        from unittest.mock import MagicMock, patch
+        from langchain_core.messages import AIMessage
+        from src.agent.plan_and_execute import PlanAndExecuteRouter
+
+        llm = MagicMock()
+        memory = ConversationMemory()
+
+        mock_agent = MagicMock()
+        mock_agent.invoke.return_value = {"messages": [AIMessage(content="info")]}
+
+        router = PlanAndExecuteRouter(
+            llm, MagicMock(), MagicMock(), MagicMock(), memory=memory,
+        )
+
+        for i in range(3):
+            plan_json = f'[{{"action": "search", "detail": "q{i}"}}]'
+            llm.invoke.side_effect = [plan_json, f"Answer {i}"]
+            with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
+                router.route(f"Question {i}", top_k=5)
+
+        assert len(memory.turns) == 3
tests/test_plan_and_execute.py ADDED
@@ -0,0 +1,370 @@
+"""Tests for the Plan-and-Execute agent router."""
+
+from unittest.mock import MagicMock, patch
+import json
+
+import pytest
+
+from src.agent.plan_and_execute import (
+    PlanAndExecuteRouter,
+    PlanExecState,
+    PlanStep,
+    _extract_last_ai_text,
+    _parse_plan,
+)
+from src.models import (
+    DocumentChunk,
+    GenerationResponse,
+    IntentType,
+    QueryResult,
+)
+from src.retrieval.hybrid import HybridSearchResult
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _chunk(chunk_id: str = "c1", text: str = "text") -> DocumentChunk:
+    return DocumentChunk(
+        chunk_id=chunk_id, document_id="doc.pdf", text=text,
+        metadata={"page_number": 1, "chunk_index": 0},
+    )
+
+
+def _qr(chunk_id: str = "c1", score: float = 0.8, text: str = "text") -> QueryResult:
+    return QueryResult(chunk=_chunk(chunk_id=chunk_id, text=text), score=score, source="hybrid")
+
+
+def _hybrid_result(results: list[QueryResult]) -> HybridSearchResult:
+    return HybridSearchResult(
+        dense_results=results, sparse_results=results, fused_results=results,
+    )
+
+
+# ---------------------------------------------------------------------------
+# _parse_plan
+# ---------------------------------------------------------------------------
+
+class TestParsePlan:
+    def test_valid_json(self) -> None:
+        raw = '[{"action": "search", "detail": "exam rules"}]'
+        steps = _parse_plan(raw)
+        assert len(steps) == 1
+        assert steps[0]["action"] == "search"
+        assert steps[0]["detail"] == "exam rules"
+
+    def test_multiple_steps(self) -> None:
+        raw = json.dumps([
+            {"action": "search", "detail": "policy A"},
+            {"action": "search", "detail": "policy B"},
+            {"action": "summarize", "detail": "doc.pdf"},
+        ])
+        steps = _parse_plan(raw)
+        assert len(steps) == 3
+
+    def test_markdown_fenced(self) -> None:
+        raw = '```json\n[{"action": "search", "detail": "test"}]\n```'
+        steps = _parse_plan(raw)
+        assert len(steps) == 1
+        assert steps[0]["action"] == "search"
+
+    def test_json_with_surrounding_text(self) -> None:
+        raw = 'Here is the plan:\n[{"action": "search", "detail": "x"}]\nDone.'
+        steps = _parse_plan(raw)
+        assert len(steps) == 1
+
+    def test_invalid_json_falls_back(self) -> None:
+        raw = "this is not json at all"
+        steps = _parse_plan(raw)
+        assert len(steps) == 1
+        assert steps[0]["action"] == "search"
+
+    def test_empty_array_falls_back(self) -> None:
+        raw = "[]"
+        steps = _parse_plan(raw)
+        assert len(steps) == 1  # fallback to single search
+
+    def test_malformed_items_skipped(self) -> None:
+        raw = json.dumps([
+            {"action": "search", "detail": "good"},
+            {"bad": "step"},
+            {"action": "search", "detail": "also good"},
+        ])
+        steps = _parse_plan(raw)
+        assert len(steps) == 2
+
+    def test_non_list_wrapped(self) -> None:
+        raw = '{"action": "search", "detail": "test"}'
+        steps = _parse_plan(raw)
+        assert len(steps) == 1
+
+
+# ---------------------------------------------------------------------------
+# _extract_last_ai_text
+# ---------------------------------------------------------------------------
+
+class TestExtractLastAIText:
+    def test_returns_last_ai_message(self) -> None:
+        from langchain_core.messages import AIMessage, HumanMessage
+        messages = [
+            HumanMessage(content="question"),
+            AIMessage(content="first"),
+            AIMessage(content="second"),
+        ]
+        assert _extract_last_ai_text(messages) == "second"
+
+    def test_skips_tool_calls(self) -> None:
+        from langchain_core.messages import AIMessage
+        msg_with_tools = AIMessage(content="calling tool", tool_calls=[{"name": "t", "args": {}, "id": "1"}])
+        msg_final = AIMessage(content="the answer")
+        assert _extract_last_ai_text([msg_with_tools, msg_final]) == "the answer"
+
+    def test_empty_messages(self) -> None:
+        assert _extract_last_ai_text([]) == ""
+
+
+# ---------------------------------------------------------------------------
+# PlanAndExecuteRouter — plan node
+# ---------------------------------------------------------------------------
+
+class TestPlanNode:
+    def test_plan_node_generates_steps(self) -> None:
+        llm = MagicMock()
+        llm.invoke.return_value = '[{"action": "search", "detail": "KU regler"}]'
+
+        router = PlanAndExecuteRouter(
+            llm=llm,
+            hybrid_retriever=MagicMock(),
+            reranker=MagicMock(),
+            vector_store=MagicMock(),
+        )
+
+        state = PlanExecState(
+            query="What are the rules?",
+            top_k=5, plan=[], step_index=0, step_results=[], answer="",
+        )
+        result = router._plan_node(state)
+        assert len(result["plan"]) == 1
+        assert result["plan"][0]["action"] == "search"
+        assert result["step_index"] == 0
+
+    def test_plan_node_handles_bad_llm_output(self) -> None:
+        llm = MagicMock()
+        llm.invoke.return_value = "I cannot produce JSON"
+
+        router = PlanAndExecuteRouter(
+            llm=llm,
+            hybrid_retriever=MagicMock(),
+            reranker=MagicMock(),
+            vector_store=MagicMock(),
+        )
+
+        state = PlanExecState(
+            query="test", top_k=5, plan=[], step_index=0, step_results=[], answer="",
+        )
+        result = router._plan_node(state)
+        assert len(result["plan"]) >= 1  # fallback plan
+
+
+# ---------------------------------------------------------------------------
+# PlanAndExecuteRouter — should_execute
+# ---------------------------------------------------------------------------
+
+class TestShouldExecute:
+    def test_more_steps_returns_execute(self) -> None:
+        state = PlanExecState(
+            query="q", top_k=5,
+            plan=[PlanStep(action="search", detail="x")],
+            step_index=0, step_results=[], answer="",
+        )
+        assert PlanAndExecuteRouter._should_execute(state) == "execute"
+
+    def test_all_steps_done_returns_synthesize(self) -> None:
+        state = PlanExecState(
+            query="q", top_k=5,
+            plan=[PlanStep(action="search", detail="x")],
+            step_index=1, step_results=[], answer="",
+        )
+        assert PlanAndExecuteRouter._should_execute(state) == "synthesize"
+
+    def test_empty_plan_returns_synthesize(self) -> None:
+        state = PlanExecState(
+            query="q", top_k=5, plan=[], step_index=0, step_results=[], answer="",
+        )
+        assert PlanAndExecuteRouter._should_execute(state) == "synthesize"
+
+    def test_max_steps_cap(self) -> None:
+        """Step index at _MAX_STEPS should stop execution."""
+        state = PlanExecState(
+            query="q", top_k=5,
+            plan=[PlanStep(action="search", detail=f"q{i}") for i in range(10)],
+            step_index=6,  # == _MAX_STEPS
+            step_results=[], answer="",
+        )
+        assert PlanAndExecuteRouter._should_execute(state) == "synthesize"
+
+
+# ---------------------------------------------------------------------------
+# PlanAndExecuteRouter — synthesize node
+# ---------------------------------------------------------------------------
+
+class TestSynthesizeNode:
+    def test_synthesize_combines_results(self) -> None:
+        llm = MagicMock()
+        llm.invoke.return_value = "Combined answer about exams."
+
+        router = PlanAndExecuteRouter(
+            llm=llm,
+            hybrid_retriever=MagicMock(),
+            reranker=MagicMock(),
+            vector_store=MagicMock(),
+        )
+
+        state = PlanExecState(
+            query="exam rules",
+            top_k=5, plan=[],
+            step_index=2,
+            step_results=[
+                ("search: exam bachelor", "Found bachelor exam rules..."),
+                ("search: exam master", "Found master exam rules..."),
+            ],
+            answer="",
+        )
+        result = router._synthesize_node(state)
+        assert result["answer"] == "Combined answer about exams."
+
+        # Verify prompt includes both step results
+        prompt = llm.invoke.call_args[0][0]
+        assert "bachelor exam rules" in prompt
+        assert "master exam rules" in prompt
+
+
+# ---------------------------------------------------------------------------
+# PlanAndExecuteRouter — full route (integration with mocks)
+# ---------------------------------------------------------------------------
+
+class TestFullRoute:
+    def test_route_produces_response(self) -> None:
+        """Full route with mocked LLM and retrieval components."""
+        llm = MagicMock()
+        retriever = MagicMock()
+        reranker = MagicMock()
+        vector_store = MagicMock()
+
+        # Plan: single search step
+        plan_json = '[{"action": "search", "detail": "test query"}]'
+        # Sub-agent answer after executing step
+        from langchain_core.messages import AIMessage
+        sub_agent_result = {"messages": [AIMessage(content="Found relevant info about test.")]}
+        # Final synthesis
+        final_answer = "The test policy states..."
+
+        # LLM calls: plan, executor system/tools, synthesis
+        # We mock the LLM and also mock the sub-agent creation
+        llm.invoke.side_effect = [plan_json, final_answer]
+
+        results = [_qr(chunk_id="c1", score=0.9, text="test policy")]
+        retriever.search_detailed.return_value = _hybrid_result(results)
+        reranker.rerank.return_value = results
+        vector_store.list_document_ids.return_value = ["doc.pdf"]
+
+        router = PlanAndExecuteRouter(llm, retriever, reranker, vector_store)
+
+        # Patch create_react_agent to return a mock that returns our sub_agent_result
+        mock_agent = MagicMock()
+        mock_agent.invoke.return_value = sub_agent_result
+
+        with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
+            response = router.route("test question", top_k=5)
+
+        assert isinstance(response, GenerationResponse)
+        assert response.answer == "The test policy states..."
+
+    def test_route_with_no_results(self) -> None:
+        """Route when retrieval finds nothing."""
+        llm = MagicMock()
+        retriever = MagicMock()
+        reranker = MagicMock()
+        vector_store = MagicMock()
+
+        plan_json = '[{"action": "search", "detail": "nonexistent"}]'
+        from langchain_core.messages import AIMessage
+        sub_agent_result = {"messages": [AIMessage(content="No relevant documents found.")]}
+
+        llm.invoke.side_effect = [plan_json, "I could not find information."]
+
+        mock_agent = MagicMock()
+        mock_agent.invoke.return_value = sub_agent_result
+
+        router = PlanAndExecuteRouter(llm, retriever, reranker, vector_store)
+
+        with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
+            response = router.route("nonexistent topic", top_k=5)
+
+        assert response.intent == IntentType.FACTUAL
+        assert response.confidence == 0.0
+
+    def test_route_multi_step(self) -> None:
+        """Route with a multi-step plan."""
+        llm = MagicMock()
+        retriever = MagicMock()
+        reranker = MagicMock()
+        vector_store = MagicMock()
+
+        plan_json = json.dumps([
+            {"action": "search", "detail": "policy A"},
+            {"action": "search", "detail": "policy B"},
+        ])
+        from langchain_core.messages import AIMessage
+        sub_result_1 = {"messages": [AIMessage(content="Policy A info")]}
+        sub_result_2 = {"messages": [AIMessage(content="Policy B info")]}
+
+        llm.invoke.side_effect = [plan_json, "Comparison of A and B."]
+
+        mock_agent = MagicMock()
+        mock_agent.invoke.side_effect = [sub_result_1, sub_result_2]
+
+        router = PlanAndExecuteRouter(llm, retriever, reranker, vector_store)
+
+        with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
+            response = router.route("Compare A and B", top_k=5)
+
+        assert response.answer == "Comparison of A and B."
+        # Sub-agent should have been called twice
+        assert mock_agent.invoke.call_count == 2
+
+
+# ---------------------------------------------------------------------------
+# PlanAndExecuteRouter — route_stream
+# ---------------------------------------------------------------------------
+
+class TestRouteStream:
+    def test_stream_yields_plan_execute_synthesize_done(self) -> None:
+        """Streaming should yield events in order: plan, execute, synthesize, done."""
+        llm = MagicMock()
+        retriever = MagicMock()
+        reranker = MagicMock()
+        vector_store = MagicMock()
+
+        plan_json = '[{"action": "search", "detail": "test"}]'
+        from langchain_core.messages import AIMessage
+        sub_agent_result = {"messages": [AIMessage(content="Found info.")]}
+
+        llm.invoke.side_effect = [plan_json, "Final answer."]
+
+        mock_agent = MagicMock()
+        mock_agent.invoke.return_value = sub_agent_result
+
+        router = PlanAndExecuteRouter(llm, retriever, reranker, vector_store)
+
+        with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
+            events = list(router.route_stream("test", top_k=5))
+
+        step_names = [e["step"] for e in events]
+        assert "plan" in step_names
+        assert "done" in step_names
+        # done event has the result
+        done_event = [e for e in events if e["step"] == "done"][0]
+        assert "result" in done_event
+        assert done_event["result"]["answer"] == "Final answer."
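
Reviewer note: the `TestParsePlan` cases above pin down a fairly tolerant parser: fenced or prose-wrapped JSON is accepted, malformed items are dropped, a bare object is wrapped in a list, and anything unusable falls back to a single search step. The committed `_parse_plan` is not visible in this part of the diff, so the following is a sketch of one way to meet those cases under stated assumptions. The regex-based extraction, the name `parse_plan` and the `fallback_detail` parameter are all hypothetical.

```python
import json
import re
from typing import TypedDict


class PlanStep(TypedDict):
    action: str
    detail: str


def parse_plan(raw: str, fallback_detail: str = "") -> list[PlanStep]:
    """Parse LLM output into plan steps, falling back to one search step."""
    fallback: list[PlanStep] = [{"action": "search", "detail": fallback_detail}]

    # Take the outermost JSON array if there is one, otherwise the outermost
    # object. This also strips markdown fences and surrounding prose.
    match = re.search(r"\[.*\]", raw, re.DOTALL) or re.search(r"\{.*\}", raw, re.DOTALL)
    if match is None:
        return fallback
    try:
        data = json.loads(match.group(0))
    except json.JSONDecodeError:
        return fallback

    if isinstance(data, dict):
        data = [data]  # a single step object is treated as a one-step plan
    if not isinstance(data, list):
        return fallback

    # Keep only items that carry both required keys.
    steps: list[PlanStep] = [
        {"action": str(item["action"]), "detail": str(item["detail"])}
        for item in data
        if isinstance(item, dict) and "action" in item and "detail" in item
    ]
    return steps or fallback
```

Returning the fallback for an empty list as well as for unparseable output means the executor always has at least one step to run.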
tests/test_tools.py ADDED
@@ -0,0 +1,381 @@
1
+ """Tests for agent tools (hybrid_search, list_documents, fetch_document,
2
+ search_within_document, multi_query_search, summarize_document)."""
3
+
4
+ from unittest.mock import MagicMock
5
+
6
+ import pytest
7
+
8
+ from src.agent.tools import ToolResultStore, make_retrieval_tools, _merge_results, _format_results
9
+ from src.models import DocumentChunk, QueryResult
10
+ from src.retrieval.hybrid import HybridSearchResult
11
+
12
+
13
+ # ---------------------------------------------------------------------------
14
+ # Helpers
15
+ # ---------------------------------------------------------------------------
16
+
17
+ def _chunk(chunk_id: str = "c1", document_id: str = "doc.pdf", text: str = "text",
18
+ page_number: int = 1, chunk_index: int = 0) -> DocumentChunk:
19
+ return DocumentChunk(
20
+ chunk_id=chunk_id,
21
+ document_id=document_id,
22
+ text=text,
23
+ metadata={"page_number": page_number, "chunk_index": chunk_index},
24
+ )
25
+
26
+
27
+ def _qr(chunk_id: str = "c1", document_id: str = "doc.pdf", text: str = "text",
28
+ score: float = 0.8, source: str = "hybrid", page_number: int = 1) -> QueryResult:
29
+ return QueryResult(
30
+ chunk=_chunk(chunk_id=chunk_id, document_id=document_id, text=text, page_number=page_number),
31
+ score=score,
32
+ source=source,
33
+ )
34
+
35
+
36
+ def _hybrid_result(results: list[QueryResult]) -> HybridSearchResult:
37
+ return HybridSearchResult(
38
+ dense_results=results,
39
+ sparse_results=results,
40
+ fused_results=results,
41
+ )
42
+
43
+
44
+ @pytest.fixture
45
+ def components():
46
+ """Create mock retriever, reranker, vector_store, and store."""
47
+ retriever = MagicMock()
48
+ reranker = MagicMock()
49
+ vector_store = MagicMock()
50
+ store = ToolResultStore()
51
+ return retriever, reranker, vector_store, store
52
+
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # Unit tests for helper functions
56
+ # ---------------------------------------------------------------------------
57
+
58
+ class TestMergeResults:
59
+ def test_merge_empty(self) -> None:
60
+ assert _merge_results([], []) == []
61
+
62
+ def test_merge_keeps_higher_score(self) -> None:
63
+ old = [_qr(chunk_id="c1", score=0.5)]
64
+ new = [_qr(chunk_id="c1", score=0.9)]
65
+ merged = _merge_results(old, new)
66
+ assert len(merged) == 1
67
+ assert merged[0].score == 0.9
68
+
69
+ def test_merge_keeps_old_if_higher(self) -> None:
70
+ old = [_qr(chunk_id="c1", score=0.9)]
71
+ new = [_qr(chunk_id="c1", score=0.5)]
72
+ merged = _merge_results(old, new)
73
+ assert merged[0].score == 0.9
74
+
75
+ def test_merge_combines_different_ids(self) -> None:
76
+ old = [_qr(chunk_id="c1", score=0.5)]
77
+ new = [_qr(chunk_id="c2", score=0.9)]
78
+ merged = _merge_results(old, new)
79
+ assert len(merged) == 2
80
+ assert merged[0].chunk.chunk_id == "c2" # higher score first
81
+
82
+ def test_merge_sorted_descending(self) -> None:
83
+ results = [_qr(chunk_id=f"c{i}", score=s) for i, s in enumerate([0.3, 0.9, 0.6])]
84
+ merged = _merge_results([], results)
85
+ scores = [r.score for r in merged]
86
+ assert scores == sorted(scores, reverse=True)
87
+
88
+
89
+ class TestFormatResults:
90
+ def test_empty_returns_no_results_message(self) -> None:
91
+ result = _format_results([])
92
+ assert "Ingen relevante" in result
93
+
94
+ def test_includes_document_id_and_score(self) -> None:
95
+ results = [_qr(document_id="policy.pdf", score=0.85)]
96
+ text = _format_results(results)
97
+ assert "policy.pdf" in text
98
+ assert "0.850" in text
99
+
100
+ def test_includes_page_number(self) -> None:
101
+ results = [_qr(page_number=5)]
102
+ text = _format_results(results)
103
+ assert "side 5" in text
104
+
105
+
106
+ # ---------------------------------------------------------------------------
+ # hybrid_search
+ # ---------------------------------------------------------------------------
+
+ class TestHybridSearch:
+     def test_returns_formatted_results(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         results = [_qr(document_id="a.pdf", score=0.9, text="answer")]
+         retriever.search_detailed.return_value = _hybrid_result(results)
+         reranker.rerank.return_value = results
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store)
+         hybrid_search = tools[0]
+         output = hybrid_search.invoke({"query": "test", "top_k": 5})
+
+         assert "a.pdf" in output
+         assert "answer" in output
+         retriever.search_detailed.assert_called_once_with("test", top_k=5)
+
+     def test_accumulates_in_store(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         results = [_qr(chunk_id="c1", score=0.8)]
+         retriever.search_detailed.return_value = _hybrid_result(results)
+         reranker.rerank.return_value = results
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store)
+         tools[0].invoke({"query": "q1"})
+
+         assert len(store.retrieved) == 1
+         assert store.retrieved[0].chunk.chunk_id == "c1"
+         assert len(store.tool_calls) == 1
+         assert store.tool_calls[0] == ("hybrid_search", "q1")
+
+     def test_no_results(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         retriever.search_detailed.return_value = _hybrid_result([])
+         reranker.rerank.return_value = []
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store)
+         output = tools[0].invoke({"query": "nothing"})
+         assert "Ingen relevante" in output
+
+
+ # ---------------------------------------------------------------------------
+ # list_documents
+ # ---------------------------------------------------------------------------
+
+ class TestListDocuments:
+     def test_returns_document_list(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         vector_store.list_document_ids.return_value = ["a.pdf", "b.pdf"]
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store)
+         list_docs = tools[1]
+         output = list_docs.invoke({})
+
+         assert "a.pdf" in output
+         assert "b.pdf" in output
+         assert "2 i alt" in output
+
+     def test_empty_knowledge_base(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         vector_store.list_document_ids.return_value = []
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store)
+         output = tools[1].invoke({})
+         assert "empty" in output.lower() or "Ingen" in output
+
+
+ # ---------------------------------------------------------------------------
+ # fetch_document
+ # ---------------------------------------------------------------------------
+
+ class TestFetchDocument:
+     def test_returns_full_text(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         chunks = [_chunk(chunk_id="c1", text="page1"), _chunk(chunk_id="c2", text="page2")]
+         vector_store.get_chunks_by_document_id.return_value = chunks
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store)
+         fetch = tools[2]
+         output = fetch.invoke({"document_id": "doc.pdf"})
+
+         assert "page1" in output
+         assert "page2" in output
+         assert len(store.retrieved) == 2
+
+     def test_document_not_found(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         vector_store.get_chunks_by_document_id.return_value = []
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store)
+         output = tools[2].invoke({"document_id": "missing.pdf"})
+         assert "ikke fundet" in output
+
+
+ # ---------------------------------------------------------------------------
+ # search_within_document
+ # ---------------------------------------------------------------------------
+
+ class TestSearchWithinDocument:
+     def test_reranks_document_chunks(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         chunks = [
+             _chunk(chunk_id="c1", text="irrelevant"),
+             _chunk(chunk_id="c2", text="relevant answer"),
+         ]
+         vector_store.get_chunks_by_document_id.return_value = chunks
+         reranker.rerank.return_value = [_qr(chunk_id="c2", text="relevant answer", score=0.95)]
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store)
+         search_within = tools[3]
+         output = search_within.invoke({"document_id": "doc.pdf", "query": "answer"})
+
+         assert "relevant answer" in output
+         assert "0.950" in output
+         reranker.rerank.assert_called_once()
+         # Verify it passed all chunks to reranker
+         candidates = reranker.rerank.call_args[0][1]
+         assert len(candidates) == 2
+
+     def test_document_not_found(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         vector_store.get_chunks_by_document_id.return_value = []
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store)
+         output = tools[3].invoke({"document_id": "missing.pdf", "query": "test"})
+         assert "ikke fundet" in output
+
+     def test_accumulates_in_store(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         chunks = [_chunk(chunk_id="c1")]
+         vector_store.get_chunks_by_document_id.return_value = chunks
+         reranker.rerank.return_value = [_qr(chunk_id="c1", score=0.7)]
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store)
+         tools[3].invoke({"document_id": "doc.pdf", "query": "q"})
+
+         assert len(store.retrieved) == 1
+         assert store.tool_calls[-1][0] == "search_within_document"
+
+
+ # ---------------------------------------------------------------------------
+ # multi_query_search (requires llm_chain)
+ # ---------------------------------------------------------------------------
+
+ class TestMultiQuerySearch:
+     def test_decomposes_and_searches(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         llm_chain = MagicMock()
+
+         # LLM returns 2 sub-queries
+         llm_chain.invoke.return_value = "eksamenregler bachelor\neksamensregler kandidat"
+
+         results_a = [_qr(chunk_id="c1", score=0.9, text="bachelor exam")]
+         results_b = [_qr(chunk_id="c2", score=0.85, text="master exam")]
+         retriever.search_detailed.side_effect = [
+             _hybrid_result(results_a),
+             _hybrid_result(results_b),
+         ]
+         reranker.rerank.side_effect = [results_a, results_b]
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain)
+         multi_search = tools[4]
+         output = multi_search.invoke({"question": "Compare exam rules"})
+
+         assert "delforespørgsler" in output
+         assert retriever.search_detailed.call_count == 2
+         assert reranker.rerank.call_count == 2
+         assert len(store.retrieved) == 2
+
+     def test_fallback_when_decompose_fails(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         llm_chain = MagicMock()
+
+         # LLM returns empty/garbage
+         llm_chain.invoke.return_value = ""
+
+         results = [_qr(chunk_id="c1", score=0.8)]
+         retriever.search_detailed.return_value = _hybrid_result(results)
+         reranker.rerank.return_value = results
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain)
+         output = tools[4].invoke({"question": "original question"})
+
+         # Should fall back to the original question as single query
+         assert retriever.search_detailed.call_count == 1
+         assert "0.800" in output
+
+     def test_not_available_without_llm(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=None)
+         tool_names = [t.name for t in tools]
+         assert "multi_query_search" not in tool_names
+         assert "summarize_document" not in tool_names
+
+     def test_deduplicates_across_sub_queries(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         llm_chain = MagicMock()
+         llm_chain.invoke.return_value = "query1\nquery2"
+
+         # Both sub-queries return the same chunk
+         same_result = [_qr(chunk_id="c1", score=0.8)]
+         retriever.search_detailed.return_value = _hybrid_result(same_result)
+         reranker.rerank.return_value = same_result
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain)
+         tools[4].invoke({"question": "test"})
+
+         # Should be deduplicated to 1
+         assert len(store.retrieved) == 1
+
+
+ # ---------------------------------------------------------------------------
+ # summarize_document (requires llm_chain)
+ # ---------------------------------------------------------------------------
+
+ class TestSummarizeDocument:
+     def test_generates_summary(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         llm_chain = MagicMock()
+         llm_chain.invoke.return_value = "This document covers exam policies."
+
+         chunks = [_chunk(chunk_id="c1", text="Exam rules...")]
+         vector_store.get_chunks_by_document_id.return_value = chunks
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain)
+         summarize = tools[5]
+         output = summarize.invoke({"document_id": "exam.pdf"})
+
+         assert "Resumé af exam.pdf" in output
+         assert "exam policies" in output
+         llm_chain.invoke.assert_called_once()
+         # Verify the prompt includes the document text
+         prompt = llm_chain.invoke.call_args[0][0]
+         assert "Exam rules" in prompt
+
+     def test_document_not_found(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         llm_chain = MagicMock()
+         vector_store.get_chunks_by_document_id.return_value = []
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain)
+         output = tools[5].invoke({"document_id": "missing.pdf"})
+         assert "ikke fundet" in output
+         llm_chain.invoke.assert_not_called()
+
+     def test_truncates_long_documents(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         llm_chain = MagicMock()
+         llm_chain.invoke.return_value = "summary"
+
+         # Create a document longer than 8000 chars
+         long_text = "x" * 10000
+         chunks = [_chunk(chunk_id="c1", text=long_text)]
+         vector_store.get_chunks_by_document_id.return_value = chunks
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain)
+         tools[5].invoke({"document_id": "long.pdf"})
+
+         prompt = llm_chain.invoke.call_args[0][0]
+         assert "forkortet" in prompt
+
+     def test_registers_chunks_as_sources(self, components) -> None:
+         retriever, reranker, vector_store, store = components
+         llm_chain = MagicMock()
+         llm_chain.invoke.return_value = "summary"
+
+         chunks = [_chunk(chunk_id="c1"), _chunk(chunk_id="c2")]
+         vector_store.get_chunks_by_document_id.return_value = chunks
+
+         tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain)
+         tools[5].invoke({"document_id": "doc.pdf"})
+
+         assert len(store.retrieved) == 2
+         assert store.tool_calls[-1] == ("summarize_document", "doc.pdf")