Refactor to Plan-and-Execute architecture

XQ committed · Commit 1441fa0 · 1 parent: e128a20

Changed files:
- .github/README.md +15 -10
- README.md +16 -11
- src/agent/memory.py +125 -0
- src/agent/plan_and_execute.py +506 -0
- src/agent/react_router.py +15 -6
- src/agent/tools.py +226 -32
- src/api/main.py +5 -3
- src/api/routes.py +8 -2
- src/models.py +4 -0
- src/ui/app.py +51 -6
- tests/test_memory.py +239 -0
- tests/test_plan_and_execute.py +370 -0
- tests/test_tools.py +381 -0
.github/README.md
CHANGED
@@ -4,7 +4,7 @@
 
 [xq-dokumentassistent.hf.space](https://xq-dokumentassistent.hf.space) — hosted on Hugging Face Spaces
 
-A document intelligence system covering PDF ingestion, semantic chunking, hybrid retrieval with reranking, and LLM-generated answers with source citations. The LLM layer is provider-agnostic. Two modes: a
+A document intelligence system covering PDF ingestion, semantic chunking, hybrid retrieval with reranking, and LLM-generated answers with source citations. The LLM layer is provider-agnostic. Two modes: a Plan-and-Execute agent (default) with conversation memory for complex multi-step queries, and a pipeline for lightweight models without tool-calling support. Retrieval quality is evaluated with RAGAS.
 
 ## How it works
 
@@ -14,15 +14,18 @@ At query time both indexes are searched and their results merged with reciprocal
 
 **Two routing modes, switchable via `AGENT_MODE`:**
 
-- **
+- **Plan-and-Execute Agent** (default): a structured multi-step pipeline — a planner decomposes the query into steps, an executor runs each step via a ReAct sub-agent with tool access, and a synthesizer produces the final cited answer. Includes conversation memory for multi-turn follow-ups. Requires a model with tool-calling support.
 
 | Tool | Purpose |
 |------|---------|
-| `hybrid_search(query, top_k)` | Retrieve relevant passages |
+| `hybrid_search(query, top_k)` | Retrieve relevant passages via hybrid search + reranking |
+| `multi_query_search(question, top_k)` | Decompose complex questions into sub-queries, search each, merge results |
+| `search_within_document(document_id, query, top_k)` | Find specific sections inside a known document |
+| `summarize_document(document_id)` | Generate a structured summary of a document |
 | `list_documents()` | See what's in the knowledge base |
 | `fetch_document(document_id)` | Read a full document |
 
-- **Pipeline** (`AGENT_MODE=pipeline`): a
+- **Pipeline** (`AGENT_MODE=pipeline`): a predefined LangGraph graph — language detection → optional translation → hybrid retrieval → reranking → generation, with a confidence-based retry loop. Works with lightweight local models that lack tool-calling support.
 
 ## Tech Stack
 
@@ -49,12 +52,12 @@ See `.env.example` for per-provider configuration.
 
 | Mode | `AGENT_MODE` | Notes |
 |------|-------------|-------|
-|
-| Pipeline | `pipeline` |
+| Plan-and-Execute | `react` (default) | Structured multi-step agent with conversation memory |
+| Pipeline | `pipeline` | Predefined graph, works with lightweight models that lack tool calling |
 
 Tool-calling is supported by OpenAI, Anthropic, Google GenAI, Azure OpenAI, Groq, and some Ollama models (`gemma4`, `llama3.1`, `qwen2.5`, `mistral-nemo`).
 
-
+Plan-and-Execute with local Ollama (default):
 
 ```dotenv
 AGENT_MODE=react
@@ -70,7 +73,7 @@ LLM_PROVIDER=ollama
 OLLAMA_MODEL=gemma3
 ```
 
-
+Plan-and-Execute with OpenAI:
 
 ```dotenv
 AGENT_MODE=react
@@ -149,8 +152,10 @@ src/
   agent/
     intent_classifier.py
     router.py # pipeline mode (AGENT_MODE=pipeline)
-    tools.py #
-    react_router.py # ReAct
+    tools.py # 6 retrieval tools + ToolResultStore
+    react_router.py # legacy ReAct loop (superseded by plan_and_execute)
+    plan_and_execute.py # Plan-and-Execute agent (AGENT_MODE=react)
+    memory.py # conversation memory for multi-turn
   evaluation/
     evaluator.py # RAGAS metrics
   ui/
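The README text in this diff describes the Plan-and-Execute mode as three phases: a planner, an executor that runs each step, and a synthesizer. The control flow can be sketched in plain Python as below. This is only an illustration under stated assumptions: `plan_and_execute`, the `toy_*` stubs, and the callable signatures are hypothetical and stand in for LLM calls; the commit itself builds the flow as a LangGraph graph.

```python
from typing import Callable

# A plan step carries an "action" and a "detail", as in the planner prompt.
Step = dict[str, str]


def plan_and_execute(
    query: str,
    planner: Callable[[str], list[Step]],
    executor: Callable[[Step, str], str],
    synthesizer: Callable[[str, list[tuple[str, str]]], str],
    max_steps: int = 6,
) -> str:
    """Run plan -> execute each step -> synthesize, capped at max_steps."""
    plan = planner(query)
    results: list[tuple[str, str]] = []
    for step in plan[:max_steps]:
        desc = f'{step["action"]}: {step["detail"]}'
        results.append((desc, executor(step, query)))
    return synthesizer(query, results)


# Toy stubs (hypothetical) so the loop can run without an LLM.
def toy_planner(query: str) -> list[Step]:
    return [{"action": "search", "detail": query}]


def toy_executor(step: Step, query: str) -> str:
    return f"found passages for '{step['detail']}'"


def toy_synthesizer(query: str, results: list[tuple[str, str]]) -> str:
    return " | ".join(f"{desc} -> {text}" for desc, text in results)


answer = plan_and_execute("exam policy", toy_planner, toy_executor, toy_synthesizer)
```

Passing the three phases in as callables keeps the loop testable without a model; the step cap mirrors the `_MAX_STEPS = 6` constant in `plan_and_execute.py`.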
README.md
CHANGED
@@ -12,7 +12,7 @@ noindex: true
 
 **Live Demo:** [xq-dokumentassistent.hf.space](https://xq-dokumentassistent.hf.space) — hosted on Hugging Face Spaces
 
-A document intelligence system built on a RAG architecture, covering PDF ingestion, semantic chunking, hybrid retrieval with reranking, and LLM-generated answers with source citations. The LLM layer is provider-agnostic. Two modes: a pipeline for lightweight models, a
+A document intelligence system built on a RAG architecture, covering PDF ingestion, semantic chunking, hybrid retrieval with reranking, and LLM-generated answers with source citations. The LLM layer is provider-agnostic. Two modes: a pipeline for lightweight models, and a Plan-and-Execute agent flow with conversation memory for complex multi-step queries. Retrieval quality is evaluated with RAGAS.
 
 ## How it works
 
@@ -22,13 +22,16 @@ At query time both indexes are searched and their results merged with reciprocal
 
 **Two routing modes, switchable via `AGENT_MODE`:**
 
-- **Pipeline**
+- **Pipeline**: a predefined LangGraph graph — language detection → optional translation → hybrid retrieval → reranking → generation, with a confidence-based retry loop. Works with lightweight local models.
 
-- **
+- **Plan-and-Execute Agent** (default, `AGENT_MODE=react`): a structured multi-step pipeline where a planner decomposes the query into steps, an executor runs each step via a ReAct sub-agent with tool access, and a synthesizer produces the final cited answer. Includes conversation memory for multi-turn follow-ups. Requires a model with tool-calling support.
 
 | Tool | Purpose |
 |------|---------|
-| `hybrid_search(query, top_k)` | Retrieve relevant passages |
+| `hybrid_search(query, top_k)` | Retrieve relevant passages via hybrid search + reranking |
+| `multi_query_search(question, top_k)` | Decompose complex questions into sub-queries, search each, merge results |
+| `search_within_document(document_id, query, top_k)` | Find specific sections inside a known document |
+| `summarize_document(document_id)` | Generate a structured summary of a document |
 | `list_documents()` | See what's in the knowledge base |
 | `fetch_document(document_id)` | Read a full document |
 
@@ -57,12 +60,12 @@ See `.env.example` for per-provider configuration.
 
 | Mode | `AGENT_MODE` | Notes |
 |------|-------------|-------|
-| Pipeline | `pipeline`
-|
+| Pipeline | `pipeline` | Predefined graph, works with lightweight models |
+| Plan-and-Execute (default) | `react` | Structured multi-step agent with conversation memory |
 
-Tool-calling is supported by OpenAI, Anthropic, Google GenAI, Azure OpenAI, Groq, and some Ollama models (`llama3.1`, `qwen2.5`, `mistral-nemo`).
+Tool-calling is supported by OpenAI, Anthropic, Google GenAI, Azure OpenAI, Groq, and some Ollama models (`llama3.1`, `qwen2.5`, `mistral-nemo`).
 
-
+Plan-and-Execute with OpenAI:
 
 ```dotenv
 AGENT_MODE=react
@@ -76,7 +79,7 @@ Pipeline with local Ollama:
 ```dotenv
 AGENT_MODE=pipeline
 LLM_PROVIDER=ollama
-OLLAMA_MODEL=
+OLLAMA_MODEL=gemma3
 ```
 
 ## Quick Start
@@ -149,8 +152,10 @@ src/
   agent/
     intent_classifier.py
     router.py # pipeline mode (AGENT_MODE=pipeline)
-    tools.py #
-    react_router.py # ReAct
+    tools.py # 6 retrieval tools + ToolResultStore
+    react_router.py # legacy ReAct loop (superseded by plan_and_execute)
+    plan_and_execute.py # Plan-and-Execute agent (AGENT_MODE=react)
+    memory.py # conversation memory for multi-turn
   evaluation/
     evaluator.py # RAGAS metrics
   ui/
src/agent/memory.py
ADDED
|
@@ -0,0 +1,125 @@
+"""Conversation memory for multi-turn interactions.
+
+Stores message history and retrieved sources across turns so that:
+- Follow-up questions can reference prior context ("what about the other one?")
+- The planner/synthesizer can see what was already discussed
+- Previously retrieved sources are available without re-searching
+"""
+
+import logging
+from dataclasses import dataclass, field
+
+from src.models import QueryResult
+
+logger = logging.getLogger(__name__)
+
+_MAX_TURNS = 20
+
+
+@dataclass
+class Turn:
+    """A single conversation turn.
+
+    Attributes:
+        query: The user's question.
+        answer: The assistant's response.
+        sources: Retrieved sources used to generate the answer.
+    """
+
+    query: str
+    answer: str
+    sources: list[QueryResult] = field(default_factory=list)
+
+
+class ConversationMemory:
+    """Manages multi-turn conversation state.
+
+    Stores a rolling window of recent turns and provides formatted
+    context for the planner and synthesizer prompts.
+    """
+
+    def __init__(self, max_turns: int = _MAX_TURNS) -> None:
+        """Initialize conversation memory.
+
+        Args:
+            max_turns: Maximum number of turns to retain.
+        """
+        self._max_turns = max_turns
+        self._turns: list[Turn] = []
+
+    @property
+    def turns(self) -> list[Turn]:
+        """Return the list of conversation turns (read-only copy)."""
+        return list(self._turns)
+
+    @property
+    def is_empty(self) -> bool:
+        """Return True if no conversation history exists."""
+        return len(self._turns) == 0
+
+    def add_turn(self, query: str, answer: str, sources: list[QueryResult] | None = None) -> None:
+        """Record a completed conversation turn.
+
+        Args:
+            query: The user's question.
+            answer: The assistant's response.
+            sources: Retrieved sources (optional).
+        """
+        self._turns.append(Turn(query=query, answer=answer, sources=sources or []))
+        if len(self._turns) > self._max_turns:
+            removed = self._turns.pop(0)
+            logger.debug("Evicted oldest turn: %s", removed.query[:50])
+        logger.debug("Memory now has %d turns", len(self._turns))
+
+    def clear(self) -> None:
+        """Clear all conversation history."""
+        self._turns.clear()
+        logger.info("Conversation memory cleared")
+
+    def format_history(self, max_recent: int = 5) -> str:
+        """Format recent conversation history for inclusion in prompts.
+
+        Args:
+            max_recent: Maximum number of recent turns to include.
+
+        Returns:
+            Formatted string of recent Q&A pairs, or empty string if no history.
+        """
+        if not self._turns:
+            return ""
+
+        recent = self._turns[-max_recent:]
+        parts: list[str] = []
+        for i, turn in enumerate(recent, 1):
+            source_note = ""
+            if turn.sources:
+                doc_ids = sorted({s.chunk.document_id for s in turn.sources})
+                source_note = f" [sources: {', '.join(doc_ids)}]"
+            parts.append(
+                f"Turn {i}:\n"
+                f"  User: {turn.query}\n"
+                f"  Assistant: {turn.answer[:500]}{source_note}"
+            )
+        return "\n\n".join(parts)
+
+    def get_prior_sources(self) -> list[QueryResult]:
+        """Return all unique sources from prior turns, sorted by score.
+
+        Returns:
+            Deduplicated list of QueryResult from all past turns.
+        """
+        by_id: dict[str, QueryResult] = {}
+        for turn in self._turns:
+            for r in turn.sources:
+                cid = r.chunk.chunk_id
+                if cid not in by_id or r.score > by_id[cid].score:
+                    by_id[cid] = r
+        return sorted(by_id.values(), key=lambda r: r.score, reverse=True)
+
+    def last_query(self) -> str:
+        """Return the last user query, or empty string."""
+        return self._turns[-1].query if self._turns else ""
+
+    def last_sources(self) -> list[QueryResult]:
+        """Return sources from the most recent turn."""
+        return self._turns[-1].sources if self._turns else []
src/agent/plan_and_execute.py
ADDED
|
@@ -0,0 +1,506 @@
+"""Plan-and-Execute agent router using LangGraph.
+
+Replaces the flat ReAct loop with a structured three-phase pipeline:
+
+1. **Planner** — analyses the user query and produces an ordered list of
+   steps (e.g. "search for exam rules", "search for grading policy",
+   "compare both").
+2. **Executor** — runs each step via a short ReAct sub-graph that has
+   access to all retrieval tools.
+3. **Synthesizer** — collects the results from all executed steps and
+   produces a final, cited answer.
+
+The separation gives the pipeline *predictable structure* while still
+allowing the executor to reason freely within each step.
+"""
+
+import json
+import logging
+from collections.abc import Generator
+from typing import TypedDict
+
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
+from langchain_core.runnables import Runnable
+from langgraph.graph import END, StateGraph
+from langgraph.prebuilt import create_react_agent
+
+from src.agent.memory import ConversationMemory
+from src.agent.tools import ToolResultStore, make_retrieval_tools
+from src.models import GenerationResponse, IntentType, PipelineDetails, QueryResult
+from src.retrieval.hybrid import HybridRetriever
+from src.retrieval.reranker import Reranker
+from src.retrieval.vector_store import VectorStore
+
+logger = logging.getLogger(__name__)
+
+_MAX_STEPS = 6
+
+# ------------------------------------------------------------------
+# Prompts
+# ------------------------------------------------------------------
+
+_PLANNER_PROMPT = (
+    "You are a planning assistant for the University of Copenhagen (KU) document system.\n\n"
+    "Given a user question, produce a JSON list of 1–4 steps needed to answer it.\n"
+    "Each step is an object with:\n"
+    '  - "action": one of "search", "search_within", "multi_search", '
+    '"summarize", "list_docs", "fetch_doc"\n'
+    '  - "detail": a short description of what to do (e.g. the search query, document ID)\n\n'
+    "Rules:\n"
+    "- For simple factual questions: 1 search step is enough.\n"
+    "- For comparison questions: use multi_search or separate search steps.\n"
+    "- For document overview requests: use summarize.\n"
+    "- Always end with the steps needed; do NOT include a final 'answer' step.\n\n"
+    "Reply with ONLY the JSON array, nothing else.\n\n"
+    "Examples:\n"
+    'Question: "What is the exam policy?"\n'
+    '[{"action": "search", "detail": "KU eksamensregler"}]\n\n'
+    'Question: "Compare vacation rules for academic vs administrative staff"\n'
+    '[{"action": "search", "detail": "ferieregler videnskabeligt personale"}, '
+    '{"action": "search", "detail": "ferieregler administrativt personale"}]\n\n'
+    'Question: "Summarize the AI policy document"\n'
+    '[{"action": "summarize", "detail": "ku_ai_policy.pdf"}]\n\n'
+    "Now plan for this question:\n"
+)
+
+_EXECUTOR_SYSTEM = (
+    "You are executing ONE step of a plan to answer a user's question about "
+    "University of Copenhagen (KU) documents.\n\n"
+    "You have retrieval tools available. Execute the step described below, "
+    "then summarise what you found in 2-3 sentences. If you find nothing "
+    "relevant, say so clearly.\n\n"
+    "Do NOT produce a final answer — just report what you found for this step."
+)
+
+_SYNTHESIZER_PROMPT = (
+    "You are a helpful assistant for administrative staff at the University "
+    "of Copenhagen (KU).\n\n"
+    "Below are the results gathered from multiple research steps. "
+    "Synthesize them into a single coherent answer to the user's original question.\n\n"
+    "Guidelines:\n"
+    "- Cite document sources using [1], [2], etc.\n"
+    "- Answer in the same language as the user's question.\n"
+    "- Be concise but thorough.\n"
+    "- If some steps found no results, acknowledge gaps honestly.\n\n"
+)
+
+
+# ------------------------------------------------------------------
+# Graph state
+# ------------------------------------------------------------------
+
+class PlanStep(TypedDict):
+    """A single step in the execution plan."""
+
+    action: str
+    detail: str
+
+
+class PlanExecState(TypedDict):
+    """State for the Plan-and-Execute graph.
+
+    Attributes:
+        query: The user's original question.
+        top_k: Number of results per retrieval call.
+        plan: Ordered list of steps produced by the planner.
+        step_index: Index of the next step to execute.
+        step_results: List of (step_description, result_text) pairs.
+        answer: Final synthesised answer.
+    """
+
+    query: str
+    top_k: int
+    plan: list[PlanStep]
+    step_index: int
+    step_results: list[tuple[str, str]]
+    answer: str
+
+
+# ------------------------------------------------------------------
+# Router class
+# ------------------------------------------------------------------
+
+class PlanAndExecuteRouter:
+    """Routes queries through a Plan-and-Execute pipeline.
+
+    Graph topology::
+
+        plan → should_execute? ─┬─ yes → execute_step → should_execute?
+                                └─ no → synthesize → END
+    """
+
+    def __init__(
+        self,
+        llm: Runnable,
+        hybrid_retriever: HybridRetriever,
+        reranker: Reranker,
+        vector_store: VectorStore,
+        default_top_k: int = 5,
+        memory: ConversationMemory | None = None,
+    ) -> None:
+        """Initialise the Plan-and-Execute router.
+
+        Args:
+            llm: LLM with tool-calling support.
+            hybrid_retriever: HybridRetriever instance.
+            reranker: Reranker instance.
+            vector_store: VectorStore instance.
+            default_top_k: Default number of results per retrieval call.
+            memory: Optional ConversationMemory for multi-turn context.
+                When provided, prior conversation history is injected into
+                planner and synthesizer prompts, and each completed turn
+                is automatically recorded.
+        """
+        self._llm = llm
+        self._hybrid_retriever = hybrid_retriever
+        self._reranker = reranker
+        self._vector_store = vector_store
+        self._default_top_k = default_top_k
+        self._store = ToolResultStore()
+        self._memory = memory or ConversationMemory()
+
+    # ------------------------------------------------------------------
+    # Node functions
+    # ------------------------------------------------------------------
+
+    def _plan_node(self, state: PlanExecState) -> dict:
+        """Generate an execution plan from the user query."""
+        history = self._memory.format_history()
+        history_section = ""
+        if history:
+            history_section = (
+                f"Conversation history (for context on follow-up questions):\n"
+                f"{history}\n\n"
+            )
+        prompt = _PLANNER_PROMPT + history_section + f'Question: "{state["query"]}"'
+        raw = str(self._llm.invoke(prompt)).strip()
+        logger.info("Planner raw output: %s", raw)
+
+        plan = _parse_plan(raw)
+        logger.info("Plan: %d steps — %s", len(plan), plan)
+        return {"plan": plan, "step_index": 0, "step_results": []}
+
+    @staticmethod
+    def _should_execute(state: PlanExecState) -> str:
+        """Decide whether to execute the next step or synthesize."""
+        if state["step_index"] < len(state["plan"]) and state["step_index"] < _MAX_STEPS:
+            return "execute"
+        return "synthesize"
+
+    def _execute_step_node(self, state: PlanExecState) -> dict:
+        """Execute the current plan step using a ReAct sub-agent."""
+        idx = state["step_index"]
+        step = state["plan"][idx]
+        step_desc = f'{step["action"]}: {step["detail"]}'
+        logger.info("Executing step %d/%d: %s", idx + 1, len(state["plan"]), step_desc)
+
+        # Build a fresh tool set and sub-agent for this step
+        tools = make_retrieval_tools(
+            self._hybrid_retriever,
+            self._reranker,
+            self._vector_store,
+            self._store,
+            self._default_top_k,
+            llm_chain=self._llm,
+        )
+        sub_agent = create_react_agent(self._llm, tools)
+
+        step_prompt = (
+            f'Step to execute: {step_desc}\n\n'
+            f'Original user question (for context): {state["query"]}'
+        )
+
+        result = sub_agent.invoke({
+            "messages": [
+                SystemMessage(content=_EXECUTOR_SYSTEM),
+                HumanMessage(content=step_prompt),
+            ]
+        })
+
+        # Extract the sub-agent's final text answer
+        answer = _extract_last_ai_text(result.get("messages", []))
+        logger.info("Step %d result: %s", idx + 1, answer[:200])
+
+        new_results = list(state["step_results"]) + [(step_desc, answer)]
+        return {"step_index": idx + 1, "step_results": new_results}
+
+    def _synthesize_node(self, state: PlanExecState) -> dict:
+        """Synthesize a final answer from all step results."""
+        step_texts = []
+        for i, (desc, result) in enumerate(state["step_results"], 1):
+            step_texts.append(f"### Step {i}: {desc}\n{result}")
+        gathered = "\n\n".join(step_texts)
+
+        history = self._memory.format_history()
+        history_section = ""
+        if history:
+            history_section = (
+                f"Prior conversation:\n{history}\n\n"
+            )
+
+        prompt = (
+            f"{_SYNTHESIZER_PROMPT}"
+            f"{history_section}"
+            f"Original question: {state['query']}\n\n"
+            f"Research results:\n{gathered}\n\n"
+            f"Answer:"
+        )
+        answer = str(self._llm.invoke(prompt)).strip()
+        logger.info("Synthesized final answer (%d chars)", len(answer))
+        return {"answer": answer}
+
+    # ------------------------------------------------------------------
+    # Graph construction
+    # ------------------------------------------------------------------
+
+    def _build_graph(self) -> object:
+        """Build the Plan-and-Execute LangGraph.
+
+        Returns:
+            Compiled LangGraph.
+        """
+        graph: StateGraph = StateGraph(PlanExecState)
+
+        graph.add_node("plan", self._plan_node)
+        graph.add_node("execute_step", self._execute_step_node)
+        graph.add_node("synthesize", self._synthesize_node)
+
+        graph.set_entry_point("plan")
+        graph.add_conditional_edges(
+            "plan",
+            self._should_execute,
+            {"execute": "execute_step", "synthesize": "synthesize"},
+        )
+        graph.add_conditional_edges(
+            "execute_step",
+            self._should_execute,
+            {"execute": "execute_step", "synthesize": "synthesize"},
+        )
+        graph.add_edge("synthesize", END)
+
+        return graph.compile()
+
+    # ------------------------------------------------------------------
+    # Public interface (mirrors QueryRouter / ReActRouter)
+    # ------------------------------------------------------------------
+
+    def route(self, query: str, top_k: int) -> GenerationResponse:
+        """Route a query through the Plan-and-Execute pipeline.
+
+        Args:
+            query: The user's natural language query.
+            top_k: Number of top documents to retrieve per tool call.
+
+        Returns:
+            GenerationResponse with answer, sources, intent, and confidence.
+        """
+        logger.info("PlanExec routing query: %s", query)
+        self._store = ToolResultStore()
+
+        initial_state = PlanExecState(
+            query=query,
+            top_k=top_k,
+            plan=[],
+            step_index=0,
+            step_results=[],
+            answer="",
+        )
+
+        graph = self._build_graph()
+        final_state: PlanExecState = graph.invoke(initial_state)
+
+        sources = self._store.retrieved[:top_k]
+        confidence = max((r.score for r in sources), default=0.0)
+
+        plan_step_strs = [
+            f'{s["action"]}: {s["detail"]}' for s in final_state.get("plan", [])
+        ]
+        tool_call_strs = [f"{name}: {arg}" for name, arg in self._store.tool_calls]
+
+        response = GenerationResponse(
+            answer=final_state["answer"],
+            sources=sources,
+            intent=IntentType.RAG if sources else IntentType.FACTUAL,
+            confidence=confidence,
+            pipeline_details=PipelineDetails(
+                original_query=query,
+                retrieval_query=", ".join(
+                    q for name, q in self._store.tool_calls if name == "hybrid_search"
+                ) or query,
+                dense_results=self._store.dense_results,
+                sparse_results=self._store.sparse_results,
+                fused_results=self._store.fused_results,
+                reranked_results=sources,
+                plan_steps=plan_step_strs,
+                tool_calls=tool_call_strs,
+            ),
+        )
| 338 |
+
|
| 339 |
+
self._memory.add_turn(query, response.answer, sources)
|
| 340 |
+
return response
|
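The source capping and confidence logic in `route` can be illustrated in isolation. This sketch uses bare floats in place of `QueryResult` objects, and the intent strings `"rag"` and `"factual"` are assumed stand-ins for the `IntentType` values, which are not shown in this diff.

```python
def summarize_sources(scores: list[float], top_k: int) -> tuple[list[float], float, str]:
    """Cap sources at top_k, take the best score as confidence, pick an intent."""
    sources = scores[:top_k]
    # default=0.0 keeps max() from raising when no tool retrieved anything.
    confidence = max(sources, default=0.0)
    intent = "rag" if sources else "factual"  # assumed labels for IntentType
    return sources, confidence, intent


with_hits = summarize_sources([0.82, 0.64, 0.31], top_k=2)
no_hits = summarize_sources([], top_k=5)
```

The slice relies on the store keeping results sorted by descending score, which is what the merge helper in `src/agent/tools.py` returns.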
| 341 |
+
|
| 342 |
+
def route_stream(self, query: str, top_k: int) -> Generator[dict, None, None]:
|
| 343 |
+
"""Stream Plan-and-Execute events step by step.
|
| 344 |
+
|
| 345 |
+
Yields event dicts with step types:
|
| 346 |
+
- ``plan`` — plan was generated; carries ``steps``.
|
| 347 |
+
- ``execute_step`` — a step was executed; carries ``step_index``,
|
| 348 |
+
``step_desc``, ``result_preview``.
|
| 349 |
+
- ``synthesize`` — final answer generated.
|
| 350 |
+
- ``done`` — final event with full result payload.
|
| 351 |
+
|
| 352 |
+
Args:
|
| 353 |
+
query: User query.
|
| 354 |
+
top_k: Number of results to retrieve per tool call.
|
| 355 |
+
|
| 356 |
+
Yields:
|
| 357 |
+
Step event dicts.
|
| 358 |
+
"""
|
| 359 |
+
self._store = ToolResultStore()
|
| 360 |
+
|
| 361 |
+
initial_state = PlanExecState(
|
| 362 |
+
query=query,
|
| 363 |
+
top_k=top_k,
|
| 364 |
+
plan=[],
|
| 365 |
+
step_index=0,
|
| 366 |
+
step_results=[],
|
| 367 |
+
answer="",
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
graph = self._build_graph()
|
| 371 |
+
accumulated: dict = dict(initial_state)
|
| 372 |
+
|
| 373 |
+
for chunk in graph.stream(initial_state, stream_mode="updates"):
|
| 374 |
+
for node_name, update in chunk.items():
|
| 375 |
+
if update is None:
|
| 376 |
+
continue
|
| 377 |
+
accumulated.update(update)
|
| 378 |
+
|
| 379 |
+
if node_name == "plan":
|
| 380 |
+
yield {
|
| 381 |
+
"step": "plan",
|
| 382 |
+
"steps": [
|
| 383 |
+
f'{s["action"]}: {s["detail"]}'
|
| 384 |
+
for s in update.get("plan", [])
|
| 385 |
+
],
|
| 386 |
+
}
|
| 387 |
+
elif node_name == "execute_step":
|
| 388 |
+
results = update.get("step_results", [])
|
| 389 |
+
if results:
|
| 390 |
+
last_desc, last_result = results[-1]
|
| 391 |
+
yield {
|
| 392 |
+
"step": "execute_step",
|
| 393 |
+
"step_index": update.get("step_index", 0),
|
| 394 |
+
"step_desc": last_desc,
|
| 395 |
+
"result_preview": last_result[:300],
|
| 396 |
+
}
|
| 397 |
+
elif node_name == "synthesize":
|
| 398 |
+
yield {"step": "synthesize"}
|
| 399 |
+
|
| 400 |
+
sources = self._store.retrieved[:top_k]
|
| 401 |
+
confidence = max((r.score for r in sources), default=0.0)
|
| 402 |
+
answer = accumulated.get("answer", "")
|
| 403 |
+
|
| 404 |
+
self._memory.add_turn(query, answer, sources)
|
| 405 |
+
|
| 406 |
+
yield {
|
| 407 |
+
"step": "done",
|
| 408 |
+
"result": {
|
| 409 |
+
"answer": answer,
|
| 410 |
+
"sources": [r.to_dict() for r in sources],
|
| 411 |
+
"intent": (IntentType.RAG if sources else IntentType.FACTUAL).value,
|
| 412 |
+
"confidence": confidence,
|
| 413 |
+
"pipeline_details": {
|
| 414 |
+
"original_query": query,
|
| 415 |
+
"retrieval_query": ", ".join(
|
| 416 |
+
q for name, q in self._store.tool_calls if name == "hybrid_search"
|
| 417 |
+
) or query,
|
| 418 |
+
"detected_language": "",
|
| 419 |
+
"translated": False,
|
| 420 |
+
"dense_results": [r.to_dict(include_text=False) for r in self._store.dense_results],
|
| 421 |
+
"sparse_results": [r.to_dict(include_text=False) for r in self._store.sparse_results],
|
| 422 |
+
"fused_results": [r.to_dict(include_text=False) for r in self._store.fused_results],
|
| 423 |
+
"reranked_results": [r.to_dict(include_text=False) for r in sources],
|
| 424 |
+
"plan_steps": [
|
| 425 |
+
f'{s["action"]}: {s["detail"]}'
|
| 426 |
+
for s in accumulated.get("plan", [])
|
| 427 |
+
],
|
| 428 |
+
"tool_calls": [f"{n}: {a}" for n, a in self._store.tool_calls],
|
| 429 |
+
},
|
| 430 |
+
},
|
| 431 |
+
}
|
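A caller of `route_stream` would typically dispatch on the `step` key of each event. The sketch below shows one way to do that; `fake_route_stream` is a hypothetical stand-in that only mimics the event shapes listed in the docstring, and the payload values are invented for illustration.

```python
from typing import Iterator


def fake_route_stream() -> Iterator[dict]:
    """Stand-in generator emitting the event shapes from the docstring."""
    yield {"step": "plan", "steps": ["search: exam rules"]}
    yield {
        "step": "execute_step",
        "step_index": 1,
        "step_desc": "search: exam rules",
        "result_preview": "[1] doc-a ...",
    }
    yield {"step": "synthesize"}
    yield {"step": "done", "result": {"answer": "See [1].", "sources": []}}


def consume(events: Iterator[dict]) -> tuple[list[str], dict]:
    """Collect progress lines and return them with the final payload."""
    progress: list[str] = []
    result: dict = {}
    for event in events:
        kind = event["step"]
        if kind == "plan":
            progress.append(f"planned {len(event['steps'])} step(s)")
        elif kind == "execute_step":
            progress.append(f"ran: {event['step_desc']}")
        elif kind == "synthesize":
            progress.append("synthesizing")
        elif kind == "done":
            result = event["result"]
    return progress, result


progress, result = consume(fake_route_stream())
```

Only the `done` event carries the full answer, so a UI can render progress from the earlier events and swap in the final payload at the end.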
| 432 |
+
|
| 433 |
+
|
| 434 |
+
# ------------------------------------------------------------------
|
| 435 |
+
# Helpers
|
| 436 |
+
# ------------------------------------------------------------------
|
| 437 |
+
|
| 438 |
+
def _parse_plan(raw: str) -> list[PlanStep]:
|
| 439 |
+
"""Parse the planner's JSON output into a list of PlanStep dicts.
|
| 440 |
+
|
| 441 |
+
Robust against markdown fences, trailing text, and minor formatting issues.
|
| 442 |
+
|
| 443 |
+
Args:
|
| 444 |
+
raw: Raw LLM output expected to contain a JSON array.
|
| 445 |
+
|
| 446 |
+
Returns:
|
| 447 |
+
List of PlanStep dicts. Falls back to a single search step on failure.
|
| 448 |
+
"""
|
| 449 |
+
# Strip markdown code fences if present
|
| 450 |
+
cleaned = raw.strip()
|
| 451 |
+
if cleaned.startswith("```"):
|
| 452 |
+
lines = cleaned.splitlines()
|
| 453 |
+
# Remove opening and closing fences
|
| 454 |
+
lines = [line for line in lines if not line.strip().startswith("```")]
|
| 455 |
+
cleaned = "\n".join(lines).strip()
|
| 456 |
+
|
| 457 |
+
try:
|
| 458 |
+
parsed = json.loads(cleaned)
|
| 459 |
+
except json.JSONDecodeError:
|
| 460 |
+
# Try to extract a JSON array from the text
|
| 461 |
+
start = cleaned.find("[")
|
| 462 |
+
end = cleaned.rfind("]")
|
| 463 |
+
if start != -1 and end != -1:
|
| 464 |
+
try:
|
| 465 |
+
parsed = json.loads(cleaned[start:end + 1])
|
| 466 |
+
except json.JSONDecodeError:
|
| 467 |
+
logger.warning("Failed to parse plan, falling back to single search")
|
| 468 |
+
return [PlanStep(action="search", detail=cleaned[:200])]
|
| 469 |
+
else:
|
| 470 |
+
logger.warning("No JSON array found in plan output, falling back")
|
| 471 |
+
return [PlanStep(action="search", detail=cleaned[:200])]
|
| 472 |
+
|
| 473 |
+
if not isinstance(parsed, list):
|
| 474 |
+
logger.warning("Plan is not a list, wrapping")
|
| 475 |
+
parsed = [parsed]
|
| 476 |
+
|
| 477 |
+
steps: list[PlanStep] = []
|
| 478 |
+
for item in parsed:
|
| 479 |
+
if isinstance(item, dict) and "action" in item and "detail" in item:
|
| 480 |
+
steps.append(PlanStep(action=str(item["action"]), detail=str(item["detail"])))
|
| 481 |
+
else:
|
| 482 |
+
logger.warning("Skipping malformed plan step: %s", item)
|
| 483 |
+
|
| 484 |
+
if not steps:
|
| 485 |
+
return [PlanStep(action="search", detail="general search")]
|
| 486 |
+
|
| 487 |
+
return steps
|
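The fallback chain in `_parse_plan` is easier to follow with concrete inputs. The following is a standalone approximation, not the function from the commit: it returns plain dicts instead of `PlanStep`, skips logging, and `parse_plan_sketch` and `FENCE` are hypothetical names.

```python
import json

FENCE = "`" * 3  # markdown code-fence marker, built indirectly for readability


def parse_plan_sketch(raw: str) -> list[dict]:
    """Strip fences, try strict JSON, then try the outermost [...] span."""
    cleaned = raw.strip()
    if cleaned.startswith(FENCE):
        kept = [ln for ln in cleaned.splitlines() if not ln.strip().startswith(FENCE)]
        cleaned = "\n".join(kept).strip()

    fallback = [{"action": "search", "detail": cleaned[:200]}]
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        start, end = cleaned.find("["), cleaned.rfind("]")
        if start == -1 or end == -1:
            return fallback
        try:
            parsed = json.loads(cleaned[start:end + 1])
        except json.JSONDecodeError:
            return fallback

    items = parsed if isinstance(parsed, list) else [parsed]
    steps = [
        {"action": str(i["action"]), "detail": str(i["detail"])}
        for i in items
        if isinstance(i, dict) and "action" in i and "detail" in i
    ]
    return steps or [{"action": "search", "detail": "general search"}]


fenced = FENCE + 'json\n[{"action": "search", "detail": "exam rules"}]\n' + FENCE
chatty = 'Here is the plan: [{"action": "search", "detail": "leave policy"}] Done.'
from_fenced = parse_plan_sketch(fenced)
from_chatty = parse_plan_sketch(chatty)
```

Note that when parsing fails outright, the raw planner text itself becomes the search query, truncated to 200 characters.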
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def _extract_last_ai_text(messages: list) -> str:
|
| 491 |
+
"""Return the text content of the last non-tool-call AI message.
|
| 492 |
+
|
| 493 |
+
Args:
|
| 494 |
+
messages: List of LangChain message objects.
|
| 495 |
+
|
| 496 |
+
Returns:
|
| 497 |
+
The extracted text, or empty string if none found.
|
| 498 |
+
"""
|
| 499 |
+
for msg in reversed(messages):
|
| 500 |
+
if (
|
| 501 |
+
isinstance(msg, AIMessage)
|
| 502 |
+
and msg.content
|
| 503 |
+
and not getattr(msg, "tool_calls", None)
|
| 504 |
+
):
|
| 505 |
+
return str(msg.content)
|
| 506 |
+
return ""
|
src/agent/react_router.py
CHANGED
|
@@ -26,14 +26,22 @@ logger = logging.getLogger(__name__)
|
|
| 26 |
|
| 27 |
_SYSTEM_PROMPT = (
|
| 28 |
"You are a helpful assistant for administrative staff at the University of Copenhagen (KU).\n\n"
|
| 29 |
-
"You have access to
|
| 30 |
-
"
|
| 31 |
"Guidelines:\n"
|
| 32 |
-
"- Always
|
| 33 |
"employment conditions, or administrative procedures.\n"
|
| 34 |
-
"-
|
| 35 |
-
"
|
| 36 |
-
"
|
| 37 |
"- Cite the document sources ([1], [2], …) in your answer.\n"
|
| 38 |
"- Answer in the same language as the user's question."
|
| 39 |
)
|
|
@@ -83,6 +91,7 @@ class ReActRouter:
|
|
| 83 |
self._vector_store,
|
| 84 |
store,
|
| 85 |
self._default_top_k,
|
|
|
|
| 86 |
)
|
| 87 |
return create_react_agent(self._llm, tools)
|
| 88 |
|
|
|
|
| 26 |
|
| 27 |
_SYSTEM_PROMPT = (
|
| 28 |
"You are a helpful assistant for administrative staff at the University of Copenhagen (KU).\n\n"
|
| 29 |
+
"You have access to several tools for searching KU policy documents:\n"
|
| 30 |
+
"- hybrid_search: General-purpose search across all documents.\n"
|
| 31 |
+
"- multi_query_search: For complex or comparison questions — decomposes into sub-queries.\n"
|
| 32 |
+
"- search_within_document: Pinpoint specific sections inside a known document.\n"
|
| 33 |
+
"- summarize_document: Generate an overview of an entire document.\n"
|
| 34 |
+
"- list_documents: See which documents are available.\n"
|
| 35 |
+
"- fetch_document: Get the full text of a specific document.\n\n"
|
| 36 |
"Guidelines:\n"
|
| 37 |
+
"- Always search before answering questions about KU rules, policies, exams, "
|
| 38 |
"employment conditions, or administrative procedures.\n"
|
| 39 |
+
"- Use multi_query_search for comparison questions or complex multi-part questions.\n"
|
| 40 |
+
"- Use search_within_document when you already know the relevant document and "
|
| 41 |
+
"need to find a specific clause or section.\n"
|
| 42 |
+
"- Use summarize_document when the user asks for an overview of a document.\n"
|
| 43 |
+
"- If the first search does not return sufficient information, try a different "
|
| 44 |
+
"tool or refine your query.\n"
|
| 45 |
"- Cite the document sources ([1], [2], …) in your answer.\n"
|
| 46 |
"- Answer in the same language as the user's question."
|
| 47 |
)
|
|
|
|
| 91 |
self._vector_store,
|
| 92 |
store,
|
| 93 |
self._default_top_k,
|
| 94 |
+
llm_chain=self._llm,
|
| 95 |
)
|
| 96 |
return create_react_agent(self._llm, tools)
|
| 97 |
|
src/agent/tools.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
import logging
|
| 4 |
from dataclasses import dataclass, field
|
| 5 |
|
|
|
|
| 6 |
from langchain_core.tools import tool
|
| 7 |
|
| 8 |
from src.models import QueryResult
|
|
@@ -33,12 +34,54 @@ class ToolResultStore:
|
|
| 33 |
fused_results: list[QueryResult] = field(default_factory=list)
|
| 34 |
|
| 35 |
|
| 36 |
def make_retrieval_tools(
|
| 37 |
hybrid_retriever: HybridRetriever,
|
| 38 |
reranker: Reranker,
|
| 39 |
vector_store: VectorStore,
|
| 40 |
store: ToolResultStore,
|
| 41 |
default_top_k: int = 5,
|
|
|
|
| 42 |
) -> list:
|
| 43 |
"""Create retrieval tools bound to the given components and result store.
|
| 44 |
|
|
@@ -52,11 +95,18 @@ def make_retrieval_tools(
|
|
| 52 |
vector_store: VectorStore instance for document-level access.
|
| 53 |
store: Shared ToolResultStore that captures structured results.
|
| 54 |
default_top_k: Default number of results to return per call.
|
| 55 |
|
| 56 |
Returns:
|
| 57 |
List of LangChain tool callables ready for bind_tools / ToolNode.
|
| 58 |
"""
|
| 59 |
|
| 60 |
@tool
|
| 61 |
def hybrid_search(query: str, top_k: int = default_top_k) -> str:
|
| 62 |
"""Search the KU document knowledge base using hybrid retrieval.
|
|
@@ -84,36 +134,16 @@ def make_retrieval_tools(
|
|
| 84 |
hybrid_result = hybrid_retriever.search_detailed(query, top_k=top_k)
|
| 85 |
results = reranker.rerank(query, hybrid_result.fused_results, top_k=top_k)
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
cid = r.chunk.chunk_id
|
| 92 |
-
if cid not in by_id or r.score > by_id[cid].score:
|
| 93 |
-
by_id[cid] = r
|
| 94 |
-
return sorted(by_id.values(), key=lambda r: r.score, reverse=True)
|
| 95 |
|
| 96 |
-
|
| 97 |
-
store.sparse_results = _merge(store.sparse_results, hybrid_result.sparse_results)
|
| 98 |
-
store.fused_results = _merge(store.fused_results, hybrid_result.fused_results)
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
cid = r.chunk.chunk_id
|
| 104 |
-
if cid not in existing or r.score > existing[cid].score:
|
| 105 |
-
existing[cid] = r
|
| 106 |
-
store.retrieved = sorted(existing.values(), key=lambda r: r.score, reverse=True)
|
| 107 |
-
|
| 108 |
-
if not results:
|
| 109 |
-
return "Ingen relevante dokumenter fundet. (No relevant documents found.)"
|
| 110 |
-
|
| 111 |
-
parts: list[str] = []
|
| 112 |
-
for i, r in enumerate(results, 1):
|
| 113 |
-
parts.append(
|
| 114 |
-
f"[{i}] {r.chunk.document_id} (relevance: {r.score:.3f})\n{r.chunk.text}"
|
| 115 |
-
)
|
| 116 |
-
return "\n\n---\n\n".join(parts)
|
| 117 |
|
| 118 |
@tool
|
| 119 |
def list_documents() -> str:
|
|
@@ -161,11 +191,8 @@ def make_retrieval_tools(
|
|
| 161 |
f"(Document not found. Use list_documents to see available IDs.)"
|
| 162 |
)
|
| 163 |
|
| 164 |
-
# Sort chunks by chunk_index to preserve document order
|
| 165 |
chunks.sort(key=lambda c: c.metadata.get("chunk_index", 0))
|
| 166 |
|
| 167 |
-
# Register chunks as QueryResult so confidence and sources are surfaced in the UI.
|
| 168 |
-
# Score 1.0 indicates a direct full-document fetch (no ranking involved).
|
| 169 |
existing = {r.chunk.chunk_id: r for r in store.retrieved}
|
| 170 |
for chunk in chunks:
|
| 171 |
if chunk.chunk_id not in existing:
|
|
@@ -178,4 +205,171 @@ def make_retrieval_tools(
|
|
| 178 |
f"{full_text}"
|
| 179 |
)
|
| 180 |
|
| 181 |
-
|
| 3 |
import logging
|
| 4 |
from dataclasses import dataclass, field
|
| 5 |
|
| 6 |
+
from langchain_core.runnables import Runnable
|
| 7 |
from langchain_core.tools import tool
|
| 8 |
|
| 9 |
from src.models import QueryResult
|
|
|
|
| 34 |
fused_results: list[QueryResult] = field(default_factory=list)
|
| 35 |
|
| 36 |
|
| 37 |
+
def _merge_results(existing: list[QueryResult], new: list[QueryResult]) -> list[QueryResult]:
|
| 38 |
+
"""Merge two QueryResult lists by chunk_id, keeping the highest score.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
existing: Previously accumulated results.
|
| 42 |
+
new: New results to merge in.
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
Merged list sorted by descending score.
|
| 46 |
+
"""
|
| 47 |
+
by_id = {r.chunk.chunk_id: r for r in existing}
|
| 48 |
+
for r in new:
|
| 49 |
+
cid = r.chunk.chunk_id
|
| 50 |
+
if cid not in by_id or r.score > by_id[cid].score:
|
| 51 |
+
by_id[cid] = r
|
| 52 |
+
return sorted(by_id.values(), key=lambda r: r.score, reverse=True)
|
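A small worked example may help show what the merge does when the same chunk is returned by more than one tool call. `Hit` is a hypothetical stand-in for `QueryResult`, reduced to a chunk id and a score, and `merge_hits` mirrors the logic of `_merge_results` under that simplification.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Hit:
    """Hypothetical stand-in for QueryResult: a chunk id plus a score."""

    chunk_id: str
    score: float


def merge_hits(existing: list[Hit], new: list[Hit]) -> list[Hit]:
    """Deduplicate by chunk_id, keep the higher score, sort descending."""
    by_id = {h.chunk_id: h for h in existing}
    for h in new:
        if h.chunk_id not in by_id or h.score > by_id[h.chunk_id].score:
            by_id[h.chunk_id] = h
    return sorted(by_id.values(), key=lambda h: h.score, reverse=True)


first = [Hit("a", 0.9), Hit("b", 0.4)]
second = [Hit("b", 0.7), Hit("c", 0.5), Hit("a", 0.2)]
merged = merge_hits(first, second)
```

Here chunk `b` is upgraded from 0.4 to 0.7 while chunk `a` keeps its earlier 0.9. The approach compares scores produced for different queries directly, which presumably assumes the reranker's scores are on a comparable scale across calls.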
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _format_results(results: list[QueryResult]) -> str:
|
| 56 |
+
"""Format a list of QueryResult into a readable string.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
results: Ranked results to format.
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
Formatted string with numbered entries, or a no-results message.
|
| 63 |
+
"""
|
| 64 |
+
if not results:
|
| 65 |
+
return "Ingen relevante dokumenter fundet. (No relevant documents found.)"
|
| 66 |
+
parts: list[str] = []
|
| 67 |
+
for i, r in enumerate(results, 1):
|
| 68 |
+
page_info = ""
|
| 69 |
+
page = r.chunk.metadata.get("page_number")
|
| 70 |
+
if page is not None:
|
| 71 |
+
page_info = f" side {page}"
|
| 72 |
+
parts.append(
|
| 73 |
+
f"[{i}] {r.chunk.document_id}{page_info} (relevance: {r.score:.3f})\n{r.chunk.text}"
|
| 74 |
+
)
|
| 75 |
+
return "\n\n---\n\n".join(parts)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
def make_retrieval_tools(
|
| 79 |
hybrid_retriever: HybridRetriever,
|
| 80 |
reranker: Reranker,
|
| 81 |
vector_store: VectorStore,
|
| 82 |
store: ToolResultStore,
|
| 83 |
default_top_k: int = 5,
|
| 84 |
+
llm_chain: Runnable | None = None,
|
| 85 |
) -> list:
|
| 86 |
"""Create retrieval tools bound to the given components and result store.
|
| 87 |
|
|
|
|
| 95 |
vector_store: VectorStore instance for document-level access.
|
| 96 |
store: Shared ToolResultStore that captures structured results.
|
| 97 |
default_top_k: Default number of results to return per call.
|
| 98 |
+
llm_chain: Optional LLM chain for tools that need generation
|
| 99 |
+
(summarize_document, multi_query_search). When None, those
|
| 100 |
+
tools are excluded from the returned list.
|
| 101 |
|
| 102 |
Returns:
|
| 103 |
List of LangChain tool callables ready for bind_tools / ToolNode.
|
| 104 |
"""
|
| 105 |
|
| 106 |
+
# ------------------------------------------------------------------
|
| 107 |
+
# Core search tool
|
| 108 |
+
# ------------------------------------------------------------------
|
| 109 |
+
|
| 110 |
@tool
|
| 111 |
def hybrid_search(query: str, top_k: int = default_top_k) -> str:
|
| 112 |
"""Search the KU document knowledge base using hybrid retrieval.
|
|
|
|
| 134 |
hybrid_result = hybrid_retriever.search_detailed(query, top_k=top_k)
|
| 135 |
results = reranker.rerank(query, hybrid_result.fused_results, top_k=top_k)
|
| 136 |
|
| 137 |
+
store.dense_results = _merge_results(store.dense_results, hybrid_result.dense_results)
|
| 138 |
+
store.sparse_results = _merge_results(store.sparse_results, hybrid_result.sparse_results)
|
| 139 |
+
store.fused_results = _merge_results(store.fused_results, hybrid_result.fused_results)
|
| 140 |
+
store.retrieved = _merge_results(store.retrieved, results)
|
| 141 |
|
| 142 |
+
return _format_results(results)
|
| 143 |
|
| 144 |
+
# ------------------------------------------------------------------
|
| 145 |
+
# Document-level tools
|
| 146 |
+
# ------------------------------------------------------------------
|
| 147 |
|
| 148 |
@tool
|
| 149 |
def list_documents() -> str:
|
|
|
|
| 191 |
f"(Document not found. Use list_documents to see available IDs.)"
|
| 192 |
)
|
| 193 |
|
|
|
|
| 194 |
chunks.sort(key=lambda c: c.metadata.get("chunk_index", 0))
|
| 195 |
|
|
|
|
|
|
|
| 196 |
existing = {r.chunk.chunk_id: r for r in store.retrieved}
|
| 197 |
for chunk in chunks:
|
| 198 |
if chunk.chunk_id not in existing:
|
|
|
|
| 205 |
f"{full_text}"
|
| 206 |
)
|
| 207 |
|
| 208 |
+
# ------------------------------------------------------------------
|
| 209 |
+
# Targeted within-document search
|
| 210 |
+
# ------------------------------------------------------------------
|
| 211 |
+
|
| 212 |
+
@tool
|
| 213 |
+
def search_within_document(document_id: str, query: str, top_k: int = 3) -> str:
|
| 214 |
+
"""Search for specific information within a single document.
|
| 215 |
+
|
| 216 |
+
Retrieves all chunks belonging to the document and uses the cross-encoder
|
| 217 |
+
reranker to find the most relevant passages for the query. Use this when
|
| 218 |
+
you already know which document to look in and need to pinpoint the exact
|
| 219 |
+
section (e.g. a specific clause, page, or paragraph).
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
document_id: The exact document ID to search within.
|
| 223 |
+
query: What to look for inside the document.
|
| 224 |
+
top_k: Number of top passages to return (1–10). Default is 3.
|
| 225 |
+
|
| 226 |
+
Returns:
|
| 227 |
+
The most relevant passages within the document, ranked by relevance.
|
| 228 |
+
"""
|
| 229 |
+
logger.info(
|
| 230 |
+
"Tool search_within_document: doc=%r query=%r top_k=%d",
|
| 231 |
+
document_id, query, top_k,
|
| 232 |
+
)
|
| 233 |
+
store.tool_calls.append(("search_within_document", f"{document_id}: {query}"))
|
| 234 |
+
|
| 235 |
+
chunks = vector_store.get_chunks_by_document_id(document_id)
|
| 236 |
+
if not chunks:
|
| 237 |
+
return (
|
| 238 |
+
f"Dokumentet '{document_id}' blev ikke fundet i vidensbasen. "
|
| 239 |
+
f"(Document not found. Use list_documents to see available IDs.)"
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# Wrap chunks as QueryResult so the reranker can score them
|
| 243 |
+
candidates = [
|
| 244 |
+
QueryResult(chunk=c, score=0.0, source="search_within_document")
|
| 245 |
+
for c in chunks
|
| 246 |
+
]
|
| 247 |
+
results = reranker.rerank(query, candidates, top_k=top_k)
|
| 248 |
+
|
| 249 |
+
store.retrieved = _merge_results(store.retrieved, results)
|
| 250 |
+
|
| 251 |
+
return _format_results(results)
|
| 252 |
+
|
| 253 |
+
# ------------------------------------------------------------------
|
| 254 |
+
# LLM-powered tools (only available when llm_chain is provided)
|
| 255 |
+
# ------------------------------------------------------------------
|
| 256 |
+
|
| 257 |
+
tools: list = [hybrid_search, list_documents, fetch_document, search_within_document]
|
| 258 |
+
|
| 259 |
+
if llm_chain is not None:
|
| 260 |
+
|
| 261 |
+
@tool
|
| 262 |
+
def multi_query_search(question: str, top_k: int = default_top_k) -> str:
|
| 263 |
+
"""Decompose a complex question into sub-queries and search each independently.
|
| 264 |
+
|
| 265 |
+
Use this tool instead of hybrid_search when the question involves
|
| 266 |
+
multiple aspects, comparisons, or requires information from different
|
| 267 |
+
topics. For example: "How do exam rules differ between bachelor and
|
| 268 |
+
master programmes?" would be split into separate searches for each
|
| 269 |
+
programme's exam rules, then merged.
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
question: The complex user question to decompose and search.
|
| 273 |
+
top_k: Number of results to retrieve per sub-query, and the cap on the merged results returned (1–10). Default is 5.
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
Combined results from all sub-queries, deduplicated and ranked.
|
| 277 |
+
"""
|
| 278 |
+
logger.info("Tool multi_query_search: question=%r", question)
|
| 279 |
+
store.tool_calls.append(("multi_query_search", question))
|
| 280 |
+
|
| 281 |
+
# Step 1: Ask LLM to decompose the question
|
| 282 |
+
decompose_prompt = (
|
| 283 |
+
"You are a search query planner. Given a complex question, "
|
| 284 |
+
"decompose it into 2-4 simple, independent search queries that "
|
| 285 |
+
"together cover all aspects of the question. The queries should "
|
| 286 |
+
"be in Danish (since the document base is Danish).\n\n"
|
| 287 |
+
"Reply with ONLY the queries, one per line, nothing else.\n\n"
|
| 288 |
+
f"Question: {question}"
|
| 289 |
+
)
|
| 290 |
+
raw = str(llm_chain.invoke(decompose_prompt)).strip()
|
| 291 |
+
sub_queries = [s for s in (q.strip().lstrip("0123456789.-) ") for q in raw.splitlines()) if s]
|
| 292 |
+
if not sub_queries:
|
| 293 |
+
sub_queries = [question]
|
| 294 |
+
|
| 295 |
+
logger.info("Decomposed into %d sub-queries: %s", len(sub_queries), sub_queries)
|
| 296 |
+
|
| 297 |
+
# Step 2: Search each sub-query independently
|
| 298 |
+
all_results: list[QueryResult] = []
|
| 299 |
+
for sq in sub_queries:
|
| 300 |
+
hybrid_result = hybrid_retriever.search_detailed(sq, top_k=top_k)
|
| 301 |
+
reranked = reranker.rerank(sq, hybrid_result.fused_results, top_k=top_k)
|
| 302 |
+
all_results = _merge_results(all_results, reranked)
|
| 303 |
+
|
| 304 |
+
store.dense_results = _merge_results(store.dense_results, hybrid_result.dense_results)
|
| 305 |
+
store.sparse_results = _merge_results(store.sparse_results, hybrid_result.sparse_results)
|
| 306 |
+
store.fused_results = _merge_results(store.fused_results, hybrid_result.fused_results)
|
| 307 |
+
|
| 308 |
+
# Step 3: Keep top results across all sub-queries
|
| 309 |
+
final = all_results[:top_k]
|
| 310 |
+
store.retrieved = _merge_results(store.retrieved, final)
|
| 311 |
+
|
| 312 |
+
header = f"Søgning opdelt i {len(sub_queries)} delforespørgsler:\n"
|
| 313 |
+
header += "\n".join(f" • {sq}" for sq in sub_queries)
|
| 314 |
+
header += "\n\n"
|
| 315 |
+
return header + _format_results(final)
|
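The list-marker cleanup applied to the decomposition reply can be tried on its own. This is a sketch under the assumption that the LLM answers with one query per line, possibly numbered or bulleted; `clean_sub_queries` is a hypothetical helper, and unlike a bare strip it also drops lines that are empty once the markers are removed.

```python
def clean_sub_queries(raw: str, fallback: str) -> list[str]:
    """Strip leading list markers from each line and drop empty lines."""
    markers = "0123456789.-) "
    cleaned = [line.strip().lstrip(markers) for line in raw.splitlines()]
    queries = [q for q in cleaned if q]
    # Fall back to the original question when nothing usable remains.
    return queries or [fallback]


raw_reply = "1. eksamensregler bachelor\n2) eksamensregler kandidat\n\n- reeksamen frister\n"
queries = clean_sub_queries(raw_reply, "original question")
```

One caveat of this design: `lstrip` removes a set of characters, not a prefix pattern, so a query that legitimately begins with digits (for example a year) would lose them.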
| 316 |
+
|
| 317 |
+
@tool
|
| 318 |
+
def summarize_document(document_id: str) -> str:
|
| 319 |
+
"""Generate a structured summary of a document in the knowledge base.
|
| 320 |
+
|
| 321 |
+
Fetches the full document and uses the LLM to produce a concise summary
|
| 322 |
+
covering the main topics, key rules, and important details. Use this
|
| 323 |
+
when the user asks "what is this document about?" or wants an overview
|
| 324 |
+
before diving into specifics.
|
| 325 |
+
|
| 326 |
+
Args:
|
| 327 |
+
document_id: The exact document ID to summarize.
|
| 328 |
+
|
| 329 |
+
Returns:
|
| 330 |
+
A structured summary of the document, or an error if not found.
|
| 331 |
+
"""
|
| 332 |
+
logger.info("Tool summarize_document: document_id=%r", document_id)
|
| 333 |
+
store.tool_calls.append(("summarize_document", document_id))
|
| 334 |
+
|
| 335 |
+
chunks = vector_store.get_chunks_by_document_id(document_id)
|
| 336 |
+
if not chunks:
|
| 337 |
+
return (
|
| 338 |
+
f"Dokumentet '{document_id}' blev ikke fundet i vidensbasen. "
|
| 339 |
+
f"(Document not found. Use list_documents to see available IDs.)"
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
chunks.sort(key=lambda c: c.metadata.get("chunk_index", 0))
|
| 343 |
+
full_text = "\n\n".join(c.text for c in chunks)
|
| 344 |
+
|
| 345 |
+
# Register chunks as sources
|
| 346 |
+
existing = {r.chunk.chunk_id: r for r in store.retrieved}
|
| 347 |
+
for chunk in chunks:
|
| 348 |
+
if chunk.chunk_id not in existing:
|
| 349 |
+
existing[chunk.chunk_id] = QueryResult(
|
| 350 |
+
chunk=chunk, score=1.0, source="summarize_document",
|
| 351 |
+
)
|
| 352 |
+
store.retrieved = sorted(existing.values(), key=lambda r: r.score, reverse=True)
|
| 353 |
+
|
| 354 |
+
# Truncate to avoid exceeding context limits
|
| 355 |
+
max_chars = 8000
|
| 356 |
+
if len(full_text) > max_chars:
|
| 357 |
+
full_text = full_text[:max_chars] + "\n\n[... teksten er forkortet ...]"
|
| 358 |
+
|
| 359 |
+
summary_prompt = (
|
| 360 |
+
"Produce a structured summary of the following document. "
|
| 361 |
+
"Include:\n"
|
| 362 |
+
"1. Document title/topic\n"
|
| 363 |
+
"2. Key points (3-7 bullet points)\n"
|
| 364 |
+
"3. Important rules, deadlines, or requirements mentioned\n"
|
| 365 |
+
"4. Who the document applies to\n\n"
|
| 366 |
+
"Write the summary in the same language as the document.\n\n"
|
| 367 |
+
f"Document ID: {document_id}\n\n"
|
| 368 |
+
f"Document text:\n{full_text}"
|
| 369 |
+
)
|
| 370 |
+
summary = str(llm_chain.invoke(summary_prompt)).strip()
|
| 371 |
+
return f"Resumé af {document_id}:\n\n{summary}"
|
| 372 |
+
|
| 373 |
+
tools.extend([multi_query_search, summarize_document])
|
| 374 |
+
|
| 375 |
+
return tools
|
src/api/main.py
CHANGED
|
@@ -16,7 +16,8 @@ from src.retrieval.hybrid import HybridRetriever
|
|
| 16 |
from src.retrieval.reranker import Reranker
|
| 17 |
from src.agent.intent_classifier import IntentClassifier
|
| 18 |
from src.agent.router import QueryRouter
|
| 19 |
-
from src.agent.
|
|
|
|
| 20 |
from src.ingestion.pipeline import IngestionPipeline
|
| 21 |
from src.api.routes import router, set_dependencies
|
| 22 |
|
|
@@ -72,13 +73,14 @@ def create_app() -> FastAPI:
|
|
| 72 |
reranker = Reranker(model=create_reranker(settings.reranker_model))
|
| 73 |
|
| 74 |
if settings.agent_mode == "react":
|
| 75 |
-
logger.info("Agent mode:
|
| 76 |
-
query_router: QueryRouter |
|
| 77 |
llm=llm,
|
| 78 |
hybrid_retriever=hybrid_retriever,
|
| 79 |
reranker=reranker,
|
| 80 |
vector_store=vector_store,
|
| 81 |
default_top_k=settings.top_k,
|
|
|
|
| 82 |
)
|
| 83 |
else:
|
| 84 |
logger.info("Agent mode: pipeline (fixed DAG)")
|
|
|
|
| 16 |
from src.retrieval.reranker import Reranker
|
| 17 |
from src.agent.intent_classifier import IntentClassifier
|
| 18 |
from src.agent.router import QueryRouter
|
| 19 |
+
from src.agent.plan_and_execute import PlanAndExecuteRouter
|
| 20 |
+
from src.agent.memory import ConversationMemory
|
| 21 |
from src.ingestion.pipeline import IngestionPipeline
|
| 22 |
from src.api.routes import router, set_dependencies
|
| 23 |
|
|
|
|
| 73 |
reranker = Reranker(model=create_reranker(settings.reranker_model))
|
| 74 |
|
| 75 |
if settings.agent_mode == "react":
|
| 76 |
+
logger.info("Agent mode: Plan-and-Execute (structured multi-step agent)")
|
| 77 |
+
query_router: QueryRouter | PlanAndExecuteRouter = PlanAndExecuteRouter(
|
| 78 |
llm=llm,
|
| 79 |
hybrid_retriever=hybrid_retriever,
|
| 80 |
reranker=reranker,
|
| 81 |
vector_store=vector_store,
|
| 82 |
default_top_k=settings.top_k,
|
| 83 |
+
memory=ConversationMemory(),
|
| 84 |
)
|
| 85 |
else:
|
| 86 |
logger.info("Agent mode: pipeline (fixed DAG)")
|
src/api/routes.py
CHANGED
|
@@ -15,6 +15,7 @@ from pydantic import BaseModel
|
|
| 15 |
if TYPE_CHECKING:
|
| 16 |
from src.agent.router import QueryRouter
|
| 17 |
from src.agent.react_router import ReActRouter
|
|
|
|
| 18 |
from src.config import Settings
|
| 19 |
from src.ingestion.pipeline import IngestionPipeline
|
| 20 |
from src.retrieval.bm25_search import BM25Search
|
|
@@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
|
|
| 25 |
|
| 26 |
router = APIRouter()
|
| 27 |
|
| 28 |
-
_query_router: "QueryRouter | ReActRouter | None" = None
|
| 29 |
_ingestion_pipeline: "IngestionPipeline | None" = None
|
| 30 |
_embedder: "Embedder | None" = None
|
| 31 |
_vector_store: "VectorStore | None" = None
|
|
@@ -34,7 +35,7 @@ _settings: "Settings | None" = None
|
|
| 34 |
|
| 35 |
|
| 36 |
def set_dependencies(
|
| 37 |
-
query_router: "QueryRouter | ReActRouter",
|
| 38 |
ingestion_pipeline: "IngestionPipeline",
|
| 39 |
embedder: "Embedder",
|
| 40 |
vector_store: "VectorStore",
|
|
@@ -75,6 +76,7 @@ class PipelineResultItem(BaseModel):
|
|
| 75 |
chunk_id: str
|
| 76 |
score: float
|
| 77 |
source: str
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
class PipelineDetailsResponse(BaseModel):
|
|
@@ -88,6 +90,8 @@ class PipelineDetailsResponse(BaseModel):
|
|
| 88 |
sparse_results: list[PipelineResultItem] = []
|
| 89 |
fused_results: list[PipelineResultItem] = []
|
| 90 |
reranked_results: list[PipelineResultItem] = []
|
| 91 |
|
| 92 |
|
| 93 |
class SourceItem(BaseModel):
|
|
@@ -206,6 +210,8 @@ async def query_documents(request: QueryRequest) -> QueryResponse:
|
|
| 206 |
sparse_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.sparse_results],
|
| 207 |
fused_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.fused_results],
|
| 208 |
reranked_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.reranked_results],
|
| 209 |
)
|
| 210 |
|
| 211 |
return QueryResponse(
|
|
|
|
| 15 |
if TYPE_CHECKING:
|
| 16 |
from src.agent.router import QueryRouter
|
| 17 |
from src.agent.react_router import ReActRouter
|
| 18 |
+
from src.agent.plan_and_execute import PlanAndExecuteRouter
|
| 19 |
from src.config import Settings
|
| 20 |
from src.ingestion.pipeline import IngestionPipeline
|
| 21 |
from src.retrieval.bm25_search import BM25Search
|
|
|
|
| 26 |
|
| 27 |
router = APIRouter()
|
| 28 |
|
| 29 |
+
_query_router: "QueryRouter | ReActRouter | PlanAndExecuteRouter | None" = None
|
| 30 |
_ingestion_pipeline: "IngestionPipeline | None" = None
|
| 31 |
_embedder: "Embedder | None" = None
|
| 32 |
_vector_store: "VectorStore | None" = None
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
def set_dependencies(
|
| 38 |
+
query_router: "QueryRouter | ReActRouter | PlanAndExecuteRouter",
|
| 39 |
ingestion_pipeline: "IngestionPipeline",
|
| 40 |
embedder: "Embedder",
|
| 41 |
vector_store: "VectorStore",
|
|
|
|
| 76 |
chunk_id: str
|
| 77 |
score: float
|
| 78 |
source: str
|
| 79 |
+
metadata: dict[str, str | int] = {}
|
| 80 |
|
| 81 |
|
| 82 |
class PipelineDetailsResponse(BaseModel):
|
|
|
|
| 90 |
sparse_results: list[PipelineResultItem] = []
|
| 91 |
fused_results: list[PipelineResultItem] = []
|
| 92 |
reranked_results: list[PipelineResultItem] = []
|
| 93 |
+
plan_steps: list[str] = []
|
| 94 |
+
tool_calls: list[str] = []
|
| 95 |
|
| 96 |
|
| 97 |
class SourceItem(BaseModel):
|
|
|
|
| 210 |
sparse_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.sparse_results],
|
| 211 |
fused_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.fused_results],
|
| 212 |
reranked_results=[PipelineResultItem(**r.to_dict(include_text=False)) for r in pd.reranked_results],
|
| 213 |
+
plan_steps=pd.plan_steps,
|
| 214 |
+
tool_calls=pd.tool_calls,
|
| 215 |
)
|
| 216 |
|
| 217 |
return QueryResponse(
|
src/models.py
CHANGED
|
@@ -91,6 +91,8 @@ class PipelineDetails:
|
|
| 91 |
sparse_results: Results from sparse (BM25) retrieval.
|
| 92 |
fused_results: Results after reciprocal rank fusion.
|
| 93 |
reranked_results: Results after cross-encoder reranking.
|
| 94 |
"""
|
| 95 |
|
| 96 |
original_query: str = ""
|
|
@@ -101,6 +103,8 @@ class PipelineDetails:
|
|
| 101 |
sparse_results: list[QueryResult] = field(default_factory=list)
|
| 102 |
fused_results: list[QueryResult] = field(default_factory=list)
|
| 103 |
reranked_results: list[QueryResult] = field(default_factory=list)
|
| 104 |
|
| 105 |
|
| 106 |
@dataclass
|
|
|
|
| 91 |
sparse_results: Results from sparse (BM25) retrieval.
|
| 92 |
fused_results: Results after reciprocal rank fusion.
|
| 93 |
reranked_results: Results after cross-encoder reranking.
|
| 94 |
+
plan_steps: Ordered descriptions of planned steps (Plan-and-Execute mode).
|
| 95 |
+
tool_calls: Log of tool invocations as "tool_name: argument" strings.
|
| 96 |
"""
|
| 97 |
|
| 98 |
original_query: str = ""
|
|
|
|
| 103 |
sparse_results: list[QueryResult] = field(default_factory=list)
|
| 104 |
fused_results: list[QueryResult] = field(default_factory=list)
|
| 105 |
reranked_results: list[QueryResult] = field(default_factory=list)
|
| 106 |
+
plan_steps: list[str] = field(default_factory=list)
|
| 107 |
+
tool_calls: list[str] = field(default_factory=list)
|
| 108 |
|
| 109 |
|
| 110 |
@dataclass
|
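The two fields added to `PipelineDetails` above can be illustrated with a minimal sketch. `PipelineDetailsSketch` and `log_tool_call` are hypothetical names used only for this example; the real dataclass has more fields, and only the `"tool_name: argument"` string format and the `field(default_factory=list)` defaults come from the diff.

```python
from dataclasses import dataclass, field


@dataclass
class PipelineDetailsSketch:
    """Hypothetical stand-in for PipelineDetails; only the new fields are modelled."""

    original_query: str = ""
    # default_factory gives every instance its own list. A bare `= []` default
    # is rejected by dataclasses because the list would be shared.
    plan_steps: list[str] = field(default_factory=list)
    tool_calls: list[str] = field(default_factory=list)


def log_tool_call(details: PipelineDetailsSketch, tool_name: str, argument: str) -> None:
    """Append one invocation in the "tool_name: argument" format from the docstring."""
    details.tool_calls.append(f"{tool_name}: {argument}")


first = PipelineDetailsSketch(original_query="What are the exam rules?")
second = PipelineDetailsSketch()
first.plan_steps.append("search: exam rules")
log_tool_call(first, "hybrid_search", "exam rules")
```

Because both defaults are created per instance, populating `first` leaves `second` untouched, which is why pipeline-mode responses can keep returning empty lists for these fields.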
src/ui/app.py
CHANGED
|
@@ -53,7 +53,7 @@ TEXTS: dict[str, dict[str, str]] = {
|
|
| 53 |
"- **LLM-integration** — provider-agnostisk, prompt-styret "
|
| 54 |
"svargenerering\n"
|
| 55 |
"- **Evaluering** — RAGAS-baseret kvalitetsmåling\n"
|
| 56 |
-
"- **Agent Flows** —
|
| 57 |
"- [**Kildedokumenter**](https://github.com/Xiiqiing/Dokumentassistent/tree/main/docs)"
|
| 58 |
" — de dokumenter systemet er indekseret fra"
|
| 59 |
),
|
|
@@ -67,8 +67,8 @@ TEXTS: dict[str, dict[str, str]] = {
|
|
| 67 |
"Et dokumentintelligens-system bygget på en RAG-arkitektur, dækkende file-indlæsning, semantisk chunking, "
|
| 68 |
"hybrid søgning med reranking "
|
| 69 |
"og LLM-genererede svar med kildehenvisninger. LLM-laget er provider-agnostisk. "
|
| 70 |
-
"To tilstande: en LangGraph
|
| 71 |
-
"og en pipeline til lette modeller
|
| 72 |
),
|
| 73 |
"search_label": "Stil et spørgsmål om ... ",
|
| 74 |
"search_placeholder": "F.eks.: Hvad er reglerne for behandling af personoplysninger?",
|
|
@@ -113,6 +113,8 @@ TEXTS: dict[str, dict[str, str]] = {
|
|
| 113 |
"pipeline_rank": "#",
|
| 114 |
"pipeline_no_results": "Ingen resultater",
|
| 115 |
"pipeline_score_change": "Score-ændring",
|
| 116 |
},
|
| 117 |
"en": {
|
| 118 |
"page_title": "Document Assistant",
|
|
@@ -131,7 +133,7 @@ TEXTS: dict[str, dict[str, str]] = {
|
|
| 131 |
"- **LLM integration** — provider-agnostic, prompt-driven "
|
| 132 |
"answer generation\n"
|
| 133 |
"- **Evaluation** — RAGAS-based quality measurement\n"
|
| 134 |
-
"- **Agent Flows** —
|
| 135 |
"- [**Source documents**](https://github.com/Xiiqiing/Dokumentassistent/tree/main/docs)"
|
| 136 |
" — the documents indexed into the knowledge base"
|
| 137 |
),
|
|
@@ -145,8 +147,8 @@ TEXTS: dict[str, dict[str, str]] = {
|
|
| 145 |
"A document intelligence system built on a RAG architecture, covering file ingestion, semantic chunking, "
|
| 146 |
"hybrid retrieval with reranking, "
|
| 147 |
"and LLM-generated answers with source citations. The LLM layer is provider-agnostic. "
|
| 148 |
-
"Two modes: a LangGraph
|
| 149 |
-
"and a pipeline for lightweight models
|
| 150 |
"Retrieval quality is evaluated with RAGAS."
|
| 151 |
),
|
| 152 |
"search_label": "Ask a question ...",
|
|
@@ -192,6 +194,8 @@ TEXTS: dict[str, dict[str, str]] = {
|
|
| 192 |
"pipeline_rank": "#",
|
| 193 |
"pipeline_no_results": "No results",
|
| 194 |
"pipeline_score_change": "Score change",
|
| 195 |
},
|
| 196 |
}
|
| 197 |
|
|
@@ -711,6 +715,32 @@ if search_clicked and question.strip():
|
|
| 711 |
else (f"Reranked to **{_rc}** results · confidence **{_cf:.0%}**")
|
| 712 |
)
|
| 713 |
|
| 714 |
elif _step == "tool_call":
|
| 715 |
_tool_name = _event.get("tool", "")
|
| 716 |
_tool_query = _event.get("query", "")
|
|
@@ -850,6 +880,21 @@ if search_clicked and question.strip():
|
|
| 850 |
pd = data.get("pipeline_details", {})
|
| 851 |
if pd:
|
| 852 |
with st.expander(t["pipeline_heading"], expanded=False):
|
| 853 |
# 1) Query translation (only show if translation actually happened)
|
| 854 |
if pd.get("translated"):
|
| 855 |
st.markdown(f'**{t["pipeline_translation"]}**')
|
|
|
|
| 53 |
"- **LLM-integration** — provider-agnostisk, prompt-styret "
|
| 54 |
"svargenerering\n"
|
| 55 |
"- **Evaluering** — RAGAS-baseret kvalitetsmåling\n"
|
| 56 |
+
"- **Agent Flows** — LangGraph Plan-and-Execute med værktøjskald og samtalehukommelse\n"
|
| 57 |
"- [**Kildedokumenter**](https://github.com/Xiiqiing/Dokumentassistent/tree/main/docs)"
|
| 58 |
" — de dokumenter systemet er indekseret fra"
|
| 59 |
),
|
|
|
|
| 67 |
"Et dokumentintelligens-system bygget på en RAG-arkitektur, dækkende file-indlæsning, semantisk chunking, "
|
| 68 |
"hybrid søgning med reranking "
|
| 69 |
"og LLM-genererede svar med kildehenvisninger. LLM-laget er provider-agnostisk. "
|
| 70 |
+
"To tilstande: en LangGraph Plan-and-Execute-agent (standard) med samtalehukommelse til komplekse forespørgsler, "
|
| 71 |
+
"og en foruddefineret pipeline til lette modeller. Søgekvaliteten evalueres med RAGAS."
|
| 72 |
),
|
| 73 |
"search_label": "Stil et spørgsmål om ... ",
|
| 74 |
"search_placeholder": "F.eks.: Hvad er reglerne for behandling af personoplysninger?",
|
|
|
|
| 113 |
"pipeline_rank": "#",
|
| 114 |
"pipeline_no_results": "Ingen resultater",
|
| 115 |
"pipeline_score_change": "Score-ændring",
|
| 116 |
+
"pipeline_plan_steps": "Udførelsesplan",
|
| 117 |
+
"pipeline_tool_calls": "Værktøjskald",
|
| 118 |
},
|
| 119 |
"en": {
|
| 120 |
"page_title": "Document Assistant",
|
|
|
|
| 133 |
"- **LLM integration** — provider-agnostic, prompt-driven "
|
| 134 |
"answer generation\n"
|
| 135 |
"- **Evaluation** — RAGAS-based quality measurement\n"
|
| 136 |
+
"- **Agent Flows** — LangGraph Plan-and-Execute with tool calling and conversation memory\n"
|
| 137 |
"- [**Source documents**](https://github.com/Xiiqiing/Dokumentassistent/tree/main/docs)"
|
| 138 |
" — the documents indexed into the knowledge base"
|
| 139 |
),
|
|
|
|
| 147 |
"A document intelligence system built on a RAG architecture, covering file ingestion, semantic chunking, "
|
| 148 |
"hybrid retrieval with reranking, "
|
| 149 |
"and LLM-generated answers with source citations. The LLM layer is provider-agnostic. "
|
| 150 |
+
"Two modes: a LangGraph Plan-and-Execute agent (default) with conversation memory for complex multi-step queries, "
|
| 151 |
+
"and a predefined pipeline for lightweight models. "
|
| 152 |
"Retrieval quality is evaluated with RAGAS."
|
| 153 |
),
|
| 154 |
"search_label": "Ask a question ...",
|
|
|
|
| 194 |
"pipeline_rank": "#",
|
| 195 |
"pipeline_no_results": "No results",
|
| 196 |
"pipeline_score_change": "Score change",
|
| 197 |
+
"pipeline_plan_steps": "Execution Plan",
|
| 198 |
+
"pipeline_tool_calls": "Tool Calls",
|
| 199 |
},
|
| 200 |
}
|
| 201 |
|
|
|
|
| 715 |
else (f"Reranked to **{_rc}** results · confidence **{_cf:.0%}**")
|
| 716 |
)
|
| 717 |
|
| 718 |
+
elif _step == "plan":
|
| 719 |
+
_steps = _event.get("steps", [])
|
| 720 |
+
st.write(
|
| 721 |
+
(f"Plan oprettet med **{len(_steps)}** trin")
|
| 722 |
+
if lang == "da"
|
| 723 |
+
else (f"Plan created with **{len(_steps)}** steps")
|
| 724 |
+
)
|
| 725 |
+
for _ps in _steps:
|
| 726 |
+
st.write(f" - {_ps}")
|
| 727 |
+
|
| 728 |
+
elif _step == "execute_step":
|
| 729 |
+
_si = _event.get("step_index", 0)
|
| 730 |
+
_sd = _event.get("step_desc", "")
|
| 731 |
+
st.write(
|
| 732 |
+
(f"Trin {_si} udført: _{_sd}_")
|
| 733 |
+
if lang == "da"
|
| 734 |
+
else (f"Step {_si} executed: _{_sd}_")
|
| 735 |
+
)
|
| 736 |
+
|
| 737 |
+
elif _step == "synthesize":
|
| 738 |
+
st.write(
|
| 739 |
+
"Syntetiserer endeligt svar ..."
|
| 740 |
+
if lang == "da"
|
| 741 |
+
else "Synthesizing final answer ..."
|
| 742 |
+
)
|
| 743 |
+
|
| 744 |
elif _step == "tool_call":
|
| 745 |
_tool_name = _event.get("tool", "")
|
| 746 |
_tool_query = _event.get("query", "")
|
|
|
|
| 880 |
pd = data.get("pipeline_details", {})
|
| 881 |
if pd:
|
| 882 |
with st.expander(t["pipeline_heading"], expanded=False):
|
| 883 |
+
# 0) Plan steps and tool calls (Plan-and-Execute mode)
|
| 884 |
+
plan_steps = pd.get("plan_steps", [])
|
| 885 |
+
if plan_steps:
|
| 886 |
+
st.markdown(f'**{t["pipeline_plan_steps"]}**')
|
| 887 |
+
for i, step in enumerate(plan_steps, 1):
|
| 888 |
+
st.markdown(f"{i}. {step}")
|
| 889 |
+
st.markdown("---")
|
| 890 |
+
|
| 891 |
+
tool_calls = pd.get("tool_calls", [])
|
| 892 |
+
if tool_calls:
|
| 893 |
+
st.markdown(f'**{t["pipeline_tool_calls"]}**')
|
| 894 |
+
for tc in tool_calls:
|
| 895 |
+
st.markdown(f"- `{tc}`")
|
| 896 |
+
st.markdown("---")
|
| 897 |
+
|
| 898 |
# 1) Query translation (only show if translation actually happened)
|
| 899 |
if pd.get("translated"):
|
| 900 |
st.markdown(f'**{t["pipeline_translation"]}**')
|
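The three new event branches in the UI (`plan`, `execute_step`, `synthesize`) can be exercised without Streamlit by moving the string building into a pure function. This is a sketch only: `describe_event` is a hypothetical helper, and the assumption that the event type lives under a `"step"` key is inferred from the `_step` variable rather than shown in the diff.

```python
def describe_event(event: dict, lang: str = "en") -> list[str]:
    """Return the status lines for one streamed agent event (hypothetical helper)."""
    step = event.get("step", "")  # assumed key name, see the note above

    if step == "plan":
        steps = event.get("steps", [])
        header = (
            f"Plan oprettet med **{len(steps)}** trin"
            if lang == "da"
            else f"Plan created with **{len(steps)}** steps"
        )
        # One bullet per planned step, mirroring the loop in the diff.
        return [header] + [f" - {s}" for s in steps]

    if step == "execute_step":
        index = event.get("step_index", 0)
        desc = event.get("step_desc", "")
        return [
            f"Trin {index} udført: _{desc}_"
            if lang == "da"
            else f"Step {index} executed: _{desc}_"
        ]

    if step == "synthesize":
        return [
            "Syntetiserer endeligt svar ..."
            if lang == "da"
            else "Synthesizing final answer ..."
        ]

    # Unknown event types produce no output in this sketch.
    return []
```

A caller in the Streamlit app could then loop over the returned lines and pass each to `st.write`, which keeps the localisation logic testable on its own.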
tests/test_memory.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
| 1 |
+
"""Tests for conversation memory."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from src.agent.memory import ConversationMemory, Turn
|
| 6 |
+
from src.models import DocumentChunk, QueryResult
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# ---------------------------------------------------------------------------
|
| 10 |
+
# Helpers
|
| 11 |
+
# ---------------------------------------------------------------------------
|
| 12 |
+
|
| 13 |
+
def _qr(chunk_id: str = "c1", doc_id: str = "doc.pdf", score: float = 0.8) -> QueryResult:
|
| 14 |
+
chunk = DocumentChunk(
|
| 15 |
+
chunk_id=chunk_id, document_id=doc_id, text="text",
|
| 16 |
+
metadata={"page_number": 1},
|
| 17 |
+
)
|
| 18 |
+
return QueryResult(chunk=chunk, score=score, source="test")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
# Basic operations
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
class TestConversationMemory:
|
| 26 |
+
def test_initially_empty(self) -> None:
|
| 27 |
+
mem = ConversationMemory()
|
| 28 |
+
assert mem.is_empty
|
| 29 |
+
assert mem.turns == []
|
| 30 |
+
assert mem.last_query() == ""
|
| 31 |
+
assert mem.last_sources() == []
|
| 32 |
+
|
| 33 |
+
def test_add_turn(self) -> None:
|
| 34 |
+
mem = ConversationMemory()
|
| 35 |
+
mem.add_turn("What is X?", "X is Y.", [_qr()])
|
| 36 |
+
assert not mem.is_empty
|
| 37 |
+
assert len(mem.turns) == 1
|
| 38 |
+
assert mem.last_query() == "What is X?"
|
| 39 |
+
|
| 40 |
+
def test_multiple_turns(self) -> None:
|
| 41 |
+
mem = ConversationMemory()
|
| 42 |
+
mem.add_turn("Q1", "A1")
|
| 43 |
+
mem.add_turn("Q2", "A2")
|
| 44 |
+
assert len(mem.turns) == 2
|
| 45 |
+
assert mem.last_query() == "Q2"
|
| 46 |
+
|
| 47 |
+
def test_clear(self) -> None:
|
| 48 |
+
mem = ConversationMemory()
|
| 49 |
+
mem.add_turn("Q1", "A1")
|
| 50 |
+
mem.clear()
|
| 51 |
+
assert mem.is_empty
|
| 52 |
+
|
| 53 |
+
def test_turns_returns_copy(self) -> None:
|
| 54 |
+
mem = ConversationMemory()
|
| 55 |
+
mem.add_turn("Q1", "A1")
|
| 56 |
+
turns = mem.turns
|
| 57 |
+
turns.append(Turn(query="fake", answer="fake"))
|
| 58 |
+
assert len(mem.turns) == 1 # original unaffected
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ---------------------------------------------------------------------------
|
| 62 |
+
# Eviction
|
| 63 |
+
# ---------------------------------------------------------------------------
|
| 64 |
+
|
| 65 |
+
class TestEviction:
|
| 66 |
+
def test_max_turns_eviction(self) -> None:
|
| 67 |
+
mem = ConversationMemory(max_turns=3)
|
| 68 |
+
for i in range(5):
|
| 69 |
+
mem.add_turn(f"Q{i}", f"A{i}")
|
| 70 |
+
assert len(mem.turns) == 3
|
| 71 |
+
# Oldest should be Q2 (Q0 and Q1 evicted)
|
| 72 |
+
assert mem.turns[0].query == "Q2"
|
| 73 |
+
|
| 74 |
+
def test_max_turns_one(self) -> None:
|
| 75 |
+
mem = ConversationMemory(max_turns=1)
|
| 76 |
+
mem.add_turn("Q1", "A1")
|
| 77 |
+
mem.add_turn("Q2", "A2")
|
| 78 |
+
assert len(mem.turns) == 1
|
| 79 |
+
assert mem.turns[0].query == "Q2"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# ---------------------------------------------------------------------------
|
| 83 |
+
# format_history
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
|
| 86 |
+
class TestFormatHistory:
|
| 87 |
+
def test_empty_history(self) -> None:
|
| 88 |
+
mem = ConversationMemory()
|
| 89 |
+
assert mem.format_history() == ""
|
| 90 |
+
|
| 91 |
+
def test_includes_query_and_answer(self) -> None:
|
| 92 |
+
mem = ConversationMemory()
|
| 93 |
+
mem.add_turn("What is X?", "X is a policy.")
|
| 94 |
+
text = mem.format_history()
|
| 95 |
+
assert "What is X?" in text
|
| 96 |
+
assert "X is a policy." in text
|
| 97 |
+
|
| 98 |
+
def test_includes_source_doc_ids(self) -> None:
|
| 99 |
+
mem = ConversationMemory()
|
| 100 |
+
sources = [_qr(doc_id="policy.pdf"), _qr(chunk_id="c2", doc_id="rules.pdf")]
|
| 101 |
+
mem.add_turn("Q", "A", sources)
|
| 102 |
+
text = mem.format_history()
|
| 103 |
+
assert "policy.pdf" in text
|
| 104 |
+
assert "rules.pdf" in text
|
| 105 |
+
|
| 106 |
+
def test_max_recent_limits_output(self) -> None:
|
| 107 |
+
mem = ConversationMemory()
|
| 108 |
+
for i in range(10):
|
| 109 |
+
mem.add_turn(f"Q{i}", f"A{i}")
|
| 110 |
+
text = mem.format_history(max_recent=2)
|
| 111 |
+
assert "Q8" in text
|
| 112 |
+
assert "Q9" in text
|
| 113 |
+
assert "Q0" not in text
|
| 114 |
+
|
| 115 |
+
def test_long_answer_truncated(self) -> None:
|
| 116 |
+
mem = ConversationMemory()
|
| 117 |
+
mem.add_turn("Q", "x" * 1000)
|
| 118 |
+
text = mem.format_history()
|
| 119 |
+
# Answer should be truncated to 500 chars
|
| 120 |
+
assert len(text) < 1000
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# ---------------------------------------------------------------------------
|
| 124 |
+
# get_prior_sources
|
| 125 |
+
# ---------------------------------------------------------------------------
|
| 126 |
+
|
| 127 |
+
class TestGetPriorSources:
|
| 128 |
+
def test_empty_returns_empty(self) -> None:
|
| 129 |
+
mem = ConversationMemory()
|
| 130 |
+
assert mem.get_prior_sources() == []
|
| 131 |
+
|
| 132 |
+
def test_collects_across_turns(self) -> None:
|
| 133 |
+
mem = ConversationMemory()
|
| 134 |
+
mem.add_turn("Q1", "A1", [_qr(chunk_id="c1", score=0.8)])
|
| 135 |
+
mem.add_turn("Q2", "A2", [_qr(chunk_id="c2", score=0.9)])
|
| 136 |
+
sources = mem.get_prior_sources()
|
| 137 |
+
assert len(sources) == 2
|
| 138 |
+
# Sorted by score descending
|
| 139 |
+
assert sources[0].score == 0.9
|
| 140 |
+
|
| 141 |
+
def test_deduplicates_by_chunk_id(self) -> None:
|
| 142 |
+
mem = ConversationMemory()
|
| 143 |
+
mem.add_turn("Q1", "A1", [_qr(chunk_id="c1", score=0.5)])
|
| 144 |
+
mem.add_turn("Q2", "A2", [_qr(chunk_id="c1", score=0.9)])
|
| 145 |
+
sources = mem.get_prior_sources()
|
| 146 |
+
assert len(sources) == 1
|
| 147 |
+
assert sources[0].score == 0.9 # keeps higher score
|
| 148 |
+
|
| 149 |
+
def test_no_sources_turns(self) -> None:
|
| 150 |
+
mem = ConversationMemory()
|
| 151 |
+
mem.add_turn("Q1", "A1") # no sources
|
| 152 |
+
assert mem.get_prior_sources() == []
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# ---------------------------------------------------------------------------
|
| 156 |
+
# Integration: memory in PlanAndExecuteRouter
|
| 157 |
+
# ---------------------------------------------------------------------------
|
| 158 |
+
|
| 159 |
+
class TestMemoryIntegration:
|
| 160 |
+
def test_route_records_turn(self) -> None:
|
| 161 |
+
"""After route(), the conversation turn should be recorded in memory."""
|
| 162 |
+
from unittest.mock import MagicMock, patch
|
| 163 |
+
from langchain_core.messages import AIMessage
|
| 164 |
+
from src.agent.plan_and_execute import PlanAndExecuteRouter
|
| 165 |
+
|
| 166 |
+
llm = MagicMock()
|
| 167 |
+
retriever = MagicMock()
|
| 168 |
+
reranker = MagicMock()
|
| 169 |
+
vector_store = MagicMock()
|
| 170 |
+
memory = ConversationMemory()
|
| 171 |
+
|
| 172 |
+
plan_json = '[{"action": "search", "detail": "test"}]'
|
| 173 |
+
llm.invoke.side_effect = [plan_json, "The answer."]
|
| 174 |
+
|
| 175 |
+
mock_agent = MagicMock()
|
| 176 |
+
mock_agent.invoke.return_value = {"messages": [AIMessage(content="Found info.")]}
|
| 177 |
+
|
| 178 |
+
router = PlanAndExecuteRouter(
|
| 179 |
+
llm, retriever, reranker, vector_store, memory=memory,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
|
| 183 |
+
router.route("test question", top_k=5)
|
| 184 |
+
|
| 185 |
+
assert not memory.is_empty
|
| 186 |
+
assert memory.last_query() == "test question"
|
| 187 |
+
assert memory.turns[0].answer == "The answer."
|
| 188 |
+
|
| 189 |
+
def test_history_injected_into_planner(self) -> None:
|
| 190 |
+
"""On a follow-up query, conversation history should appear in the planner prompt."""
|
| 191 |
+
from unittest.mock import MagicMock, patch
|
| 192 |
+
from langchain_core.messages import AIMessage
|
| 193 |
+
from src.agent.plan_and_execute import PlanAndExecuteRouter
|
| 194 |
+
|
| 195 |
+
llm = MagicMock()
|
| 196 |
+
memory = ConversationMemory()
|
| 197 |
+
memory.add_turn("What is the exam policy?", "The exam policy says...")
|
| 198 |
+
|
| 199 |
+
plan_json = '[{"action": "search", "detail": "follow-up"}]'
|
| 200 |
+
llm.invoke.side_effect = [plan_json, "Follow-up answer."]
|
| 201 |
+
|
| 202 |
+
mock_agent = MagicMock()
|
| 203 |
+
mock_agent.invoke.return_value = {"messages": [AIMessage(content="More info.")]}
|
| 204 |
+
|
| 205 |
+
router = PlanAndExecuteRouter(
|
| 206 |
+
llm, MagicMock(), MagicMock(), MagicMock(), memory=memory,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
|
| 210 |
+
router.route("What about the grading?", top_k=5)
|
| 211 |
+
|
| 212 |
+
# The first LLM call is the planner — check it includes history
|
| 213 |
+
planner_prompt = llm.invoke.call_args_list[0][0][0]
|
| 214 |
+
assert "exam policy" in planner_prompt
|
| 215 |
+
assert "Conversation history" in planner_prompt
|
| 216 |
+
|
| 217 |
+
def test_multi_turn_accumulates(self) -> None:
|
| 218 |
+
"""Multiple route() calls should accumulate turns in memory."""
|
| 219 |
+
from unittest.mock import MagicMock, patch
|
| 220 |
+
from langchain_core.messages import AIMessage
|
| 221 |
+
from src.agent.plan_and_execute import PlanAndExecuteRouter
|
| 222 |
+
|
| 223 |
+
llm = MagicMock()
|
| 224 |
+
memory = ConversationMemory()
|
| 225 |
+
|
| 226 |
+
mock_agent = MagicMock()
|
| 227 |
+
mock_agent.invoke.return_value = {"messages": [AIMessage(content="info")]}
|
| 228 |
+
|
| 229 |
+
router = PlanAndExecuteRouter(
|
| 230 |
+
llm, MagicMock(), MagicMock(), MagicMock(), memory=memory,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
for i in range(3):
|
| 234 |
+
plan_json = f'[{{"action": "search", "detail": "q{i}"}}]'
|
| 235 |
+
llm.invoke.side_effect = [plan_json, f"Answer {i}"]
|
| 236 |
+
with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
|
| 237 |
+
router.route(f"Question {i}", top_k=5)
|
| 238 |
+
|
| 239 |
+
assert len(memory.turns) == 3
|
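The behaviour these tests pin down (bounded turns, a defensive copy from `turns`, truncated history, and source de-duplication by `chunk_id`) could be implemented roughly as below. This is a sketch inferred from the assertions, not the contents of `src/agent/memory.py`: the class names, the `max_turns=10` and `max_recent=3` defaults, and the history line format are all assumptions; only the 500-character truncation is taken from the test comment.

```python
from collections import deque
from dataclasses import dataclass, field


@dataclass
class SourceSketch:
    """Hypothetical stand-in for QueryResult with only the fields the tests touch."""
    chunk_id: str
    document_id: str
    score: float


@dataclass
class TurnSketch:
    query: str
    answer: str
    sources: list[SourceSketch] = field(default_factory=list)


class MemorySketch:
    """Bounded conversation memory; the default max_turns is an assumed value."""

    def __init__(self, max_turns: int = 10) -> None:
        # deque(maxlen=...) evicts the oldest turn automatically once full.
        self._turns: deque = deque(maxlen=max_turns)

    @property
    def turns(self) -> list[TurnSketch]:
        # Return a copy so callers cannot mutate the internal state.
        return list(self._turns)

    def add_turn(self, query: str, answer: str, sources=None) -> None:
        self._turns.append(TurnSketch(query, answer, list(sources or [])))

    def format_history(self, max_recent: int = 3, max_answer_chars: int = 500) -> str:
        lines: list[str] = []
        for turn in self.turns[-max_recent:]:
            lines.append(f"User: {turn.query}")
            lines.append(f"Assistant: {turn.answer[:max_answer_chars]}")
            doc_ids = sorted({s.document_id for s in turn.sources})
            if doc_ids:
                lines.append(f"Sources: {', '.join(doc_ids)}")
        return "\n".join(lines)

    def get_prior_sources(self) -> list[SourceSketch]:
        # Keep the highest-scoring entry per chunk_id, then sort descending.
        best: dict[str, SourceSketch] = {}
        for turn in self._turns:
            for src in turn.sources:
                kept = best.get(src.chunk_id)
                if kept is None or src.score > kept.score:
                    best[src.chunk_id] = src
        return sorted(best.values(), key=lambda s: s.score, reverse=True)
```

Using `deque(maxlen=...)` keeps eviction out of `add_turn` entirely; an explicit list with slicing would work equally well if the real class needs to log or summarise evicted turns.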
tests/test_plan_and_execute.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
| 1 |
+
"""Tests for the Plan-and-Execute agent router."""
|
| 2 |
+
|
| 3 |
+
from unittest.mock import MagicMock, patch
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
from src.agent.plan_and_execute import (
|
| 9 |
+
PlanAndExecuteRouter,
|
| 10 |
+
PlanExecState,
|
| 11 |
+
PlanStep,
|
| 12 |
+
_extract_last_ai_text,
|
| 13 |
+
_parse_plan,
|
| 14 |
+
)
|
| 15 |
+
from src.models import (
|
| 16 |
+
DocumentChunk,
|
| 17 |
+
GenerationResponse,
|
| 18 |
+
IntentType,
|
| 19 |
+
QueryResult,
|
| 20 |
+
)
|
| 21 |
+
from src.retrieval.hybrid import HybridSearchResult
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
# Helpers
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
|
| 28 |
+
def _chunk(chunk_id: str = "c1", text: str = "text") -> DocumentChunk:
|
| 29 |
+
return DocumentChunk(
|
| 30 |
+
chunk_id=chunk_id, document_id="doc.pdf", text=text,
|
| 31 |
+
metadata={"page_number": 1, "chunk_index": 0},
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _qr(chunk_id: str = "c1", score: float = 0.8, text: str = "text") -> QueryResult:
|
| 36 |
+
return QueryResult(chunk=_chunk(chunk_id=chunk_id, text=text), score=score, source="hybrid")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _hybrid_result(results: list[QueryResult]) -> HybridSearchResult:
|
| 40 |
+
return HybridSearchResult(
|
| 41 |
+
dense_results=results, sparse_results=results, fused_results=results,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
# _parse_plan
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
|
| 49 |
+
class TestParsePlan:
|
| 50 |
+
def test_valid_json(self) -> None:
|
| 51 |
+
raw = '[{"action": "search", "detail": "exam rules"}]'
|
| 52 |
+
steps = _parse_plan(raw)
|
| 53 |
+
assert len(steps) == 1
|
| 54 |
+
assert steps[0]["action"] == "search"
|
| 55 |
+
assert steps[0]["detail"] == "exam rules"
|
| 56 |
+
|
| 57 |
+
def test_multiple_steps(self) -> None:
|
| 58 |
+
raw = json.dumps([
|
| 59 |
+
{"action": "search", "detail": "policy A"},
|
| 60 |
+
{"action": "search", "detail": "policy B"},
|
| 61 |
+
{"action": "summarize", "detail": "doc.pdf"},
|
| 62 |
+
])
|
| 63 |
+
steps = _parse_plan(raw)
|
| 64 |
+
assert len(steps) == 3
|
| 65 |
+
|
| 66 |
+
def test_markdown_fenced(self) -> None:
|
| 67 |
+
raw = '```json\n[{"action": "search", "detail": "test"}]\n```'
|
| 68 |
+
steps = _parse_plan(raw)
|
| 69 |
+
assert len(steps) == 1
|
| 70 |
+
assert steps[0]["action"] == "search"
|
| 71 |
+
|
| 72 |
+
def test_json_with_surrounding_text(self) -> None:
|
| 73 |
+
raw = 'Here is the plan:\n[{"action": "search", "detail": "x"}]\nDone.'
|
| 74 |
+
steps = _parse_plan(raw)
|
| 75 |
+
assert len(steps) == 1
|
| 76 |
+
|
| 77 |
+
def test_invalid_json_falls_back(self) -> None:
|
| 78 |
+
raw = "this is not json at all"
|
| 79 |
+
steps = _parse_plan(raw)
|
| 80 |
+
assert len(steps) == 1
|
| 81 |
+
assert steps[0]["action"] == "search"
|
| 82 |
+
|
| 83 |
+
def test_empty_array_falls_back(self) -> None:
|
| 84 |
+
raw = "[]"
|
| 85 |
+
steps = _parse_plan(raw)
|
| 86 |
+
assert len(steps) == 1 # fallback to single search
|
| 87 |
+
|
| 88 |
+
def test_malformed_items_skipped(self) -> None:
|
| 89 |
+
raw = json.dumps([
|
| 90 |
+
{"action": "search", "detail": "good"},
|
| 91 |
+
{"bad": "step"},
|
| 92 |
+
{"action": "search", "detail": "also good"},
|
| 93 |
+
])
|
| 94 |
+
steps = _parse_plan(raw)
|
| 95 |
+
assert len(steps) == 2
|
| 96 |
+
|
| 97 |
+
def test_non_list_wrapped(self) -> None:
|
| 98 |
+
raw = '{"action": "search", "detail": "test"}'
|
| 99 |
+
steps = _parse_plan(raw)
|
| 100 |
+
assert len(steps) == 1
|
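Taken together, the `TestParsePlan` cases describe a tolerant parser: accept clean JSON, strip a markdown fence, recover an array embedded in prose, wrap a lone object, skip malformed items, and fall back to a single search step. One way to satisfy all of them is sketched below. `parse_plan_sketch` and its `fallback_detail` parameter are hypothetical; the real `_parse_plan` is not shown in this part of the diff and may differ, in particular in what it puts into the fallback step's `detail`.

```python
import json
import re


def parse_plan_sketch(raw: str, fallback_detail: str = "") -> list[dict[str, str]]:
    """Parse planner output into steps, falling back to one search step."""
    fallback = [{"action": "search", "detail": fallback_detail}]

    text = raw.strip()
    # Drop a surrounding markdown fence such as ```json ... ```
    text = re.sub(r"^```[a-zA-Z]*\s*|\s*```$", "", text)

    parsed = None
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Recover an array embedded in prose by taking the outermost [...] span.
        start, end = text.find("["), text.rfind("]")
        if start != -1 and end > start:
            try:
                parsed = json.loads(text[start : end + 1])
            except json.JSONDecodeError:
                parsed = None

    if isinstance(parsed, dict):
        parsed = [parsed]  # a lone step object is treated as a one-step plan
    if not isinstance(parsed, list):
        return fallback

    # Keep only items that carry both required keys.
    steps = [
        {"action": str(item["action"]), "detail": str(item["detail"])}
        for item in parsed
        if isinstance(item, dict) and "action" in item and "detail" in item
    ]
    return steps or fallback
```

Returning a fallback plan instead of raising keeps the agent usable with models that do not reliably emit JSON, at the cost of hiding planner failures unless they are logged separately.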
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ---------------------------------------------------------------------------
|
| 104 |
+
# _extract_last_ai_text
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
|
| 107 |
+
class TestExtractLastAIText:
|
| 108 |
+
def test_returns_last_ai_message(self) -> None:
|
| 109 |
+
from langchain_core.messages import AIMessage, HumanMessage
|
| 110 |
+
messages = [
|
| 111 |
+
HumanMessage(content="question"),
|
| 112 |
+
AIMessage(content="first"),
|
| 113 |
+
AIMessage(content="second"),
|
| 114 |
+
]
|
| 115 |
+
assert _extract_last_ai_text(messages) == "second"
|
| 116 |
+
|
| 117 |
+
def test_skips_tool_calls(self) -> None:
|
| 118 |
+
from langchain_core.messages import AIMessage
|
| 119 |
+
msg_with_tools = AIMessage(content="calling tool", tool_calls=[{"name": "t", "args": {}, "id": "1"}])
|
| 120 |
+
msg_final = AIMessage(content="the answer")
|
| 121 |
+
assert _extract_last_ai_text([msg_with_tools, msg_final]) == "the answer"
|
| 122 |
+
|
| 123 |
+
def test_empty_messages(self) -> None:
|
| 124 |
+
assert _extract_last_ai_text([]) == ""
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ---------------------------------------------------------------------------
|
| 128 |
+
# PlanAndExecuteRouter — plan node
|
| 129 |
+
# ---------------------------------------------------------------------------
|
| 130 |
+
|
| 131 |
+
class TestPlanNode:
|
| 132 |
+
def test_plan_node_generates_steps(self) -> None:
|
| 133 |
+
llm = MagicMock()
|
| 134 |
+
llm.invoke.return_value = '[{"action": "search", "detail": "KU regler"}]'
|
| 135 |
+
|
| 136 |
+
router = PlanAndExecuteRouter(
|
| 137 |
+
llm=llm,
|
| 138 |
+
hybrid_retriever=MagicMock(),
|
| 139 |
+
reranker=MagicMock(),
|
| 140 |
+
vector_store=MagicMock(),
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
state = PlanExecState(
|
| 144 |
+
query="What are the rules?",
|
| 145 |
+
top_k=5, plan=[], step_index=0, step_results=[], answer="",
|
| 146 |
+
)
|
| 147 |
+
result = router._plan_node(state)
|
| 148 |
+
assert len(result["plan"]) == 1
|
| 149 |
+
assert result["plan"][0]["action"] == "search"
|
| 150 |
+
assert result["step_index"] == 0
|
| 151 |
+
|
| 152 |
+
def test_plan_node_handles_bad_llm_output(self) -> None:
|
| 153 |
+
llm = MagicMock()
|
| 154 |
+
llm.invoke.return_value = "I cannot produce JSON"
|
| 155 |
+
|
| 156 |
+
router = PlanAndExecuteRouter(
|
| 157 |
+
llm=llm,
|
| 158 |
+
hybrid_retriever=MagicMock(),
|
| 159 |
+
reranker=MagicMock(),
|
| 160 |
+
vector_store=MagicMock(),
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
state = PlanExecState(
|
| 164 |
+
query="test", top_k=5, plan=[], step_index=0, step_results=[], answer="",
|
| 165 |
+
)
|
| 166 |
+
result = router._plan_node(state)
|
| 167 |
+
assert len(result["plan"]) >= 1 # fallback plan
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# ---------------------------------------------------------------------------
|
| 171 |
+
# PlanAndExecuteRouter — should_execute
|
| 172 |
+
# ---------------------------------------------------------------------------
|
| 173 |
+
|
| 174 |
+
class TestShouldExecute:
|
| 175 |
+
def test_more_steps_returns_execute(self) -> None:
|
| 176 |
+
state = PlanExecState(
|
| 177 |
+
query="q", top_k=5,
|
| 178 |
+
plan=[PlanStep(action="search", detail="x")],
|
| 179 |
+
step_index=0, step_results=[], answer="",
|
| 180 |
+
)
|
| 181 |
+
assert PlanAndExecuteRouter._should_execute(state) == "execute"
|
| 182 |
+
|
| 183 |
+
def test_all_steps_done_returns_synthesize(self) -> None:
|
| 184 |
+
state = PlanExecState(
|
| 185 |
+
query="q", top_k=5,
|
| 186 |
+
plan=[PlanStep(action="search", detail="x")],
|
| 187 |
+
step_index=1, step_results=[], answer="",
|
| 188 |
+
)
|
| 189 |
+
assert PlanAndExecuteRouter._should_execute(state) == "synthesize"
|
| 190 |
+
|
| 191 |
+
def test_empty_plan_returns_synthesize(self) -> None:
|
| 192 |
+
state = PlanExecState(
|
| 193 |
+
query="q", top_k=5, plan=[], step_index=0, step_results=[], answer="",
|
| 194 |
+
)
|
| 195 |
+
assert PlanAndExecuteRouter._should_execute(state) == "synthesize"
|
| 196 |
+
|
| 197 |
+
def test_max_steps_cap(self) -> None:
|
| 198 |
+
"""Step index at _MAX_STEPS should stop execution."""
|
| 199 |
+
state = PlanExecState(
|
| 200 |
+
query="q", top_k=5,
|
| 201 |
+
plan=[PlanStep(action="search", detail=f"q{i}") for i in range(10)],
|
| 202 |
+
step_index=6, # == _MAX_STEPS
|
| 203 |
+
step_results=[], answer="",
|
| 204 |
+
)
|
| 205 |
+
assert PlanAndExecuteRouter._should_execute(state) == "synthesize"
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# ---------------------------------------------------------------------------
|
| 209 |
+
# PlanAndExecuteRouter — synthesize node
|
| 210 |
+
# ---------------------------------------------------------------------------
|
| 211 |
+
|
| 212 |
+
class TestSynthesizeNode:
|
| 213 |
+
def test_synthesize_combines_results(self) -> None:
|
| 214 |
+
llm = MagicMock()
|
| 215 |
+
llm.invoke.return_value = "Combined answer about exams."
|
| 216 |
+
|
| 217 |
+
router = PlanAndExecuteRouter(
|
| 218 |
+
llm=llm,
|
| 219 |
+
hybrid_retriever=MagicMock(),
|
| 220 |
+
reranker=MagicMock(),
|
| 221 |
+
vector_store=MagicMock(),
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
state = PlanExecState(
|
| 225 |
+
query="exam rules",
|
| 226 |
+
top_k=5, plan=[],
|
| 227 |
+
step_index=2,
|
| 228 |
+
step_results=[
|
| 229 |
+
("search: exam bachelor", "Found bachelor exam rules..."),
|
| 230 |
+
("search: exam master", "Found master exam rules..."),
|
| 231 |
+
],
|
| 232 |
+
answer="",
|
| 233 |
+
)
|
| 234 |
+
result = router._synthesize_node(state)
|
| 235 |
+
assert result["answer"] == "Combined answer about exams."
|
| 236 |
+
|
| 237 |
+
# Verify prompt includes both step results
|
| 238 |
+
prompt = llm.invoke.call_args[0][0]
|
| 239 |
+
assert "bachelor exam rules" in prompt
|
| 240 |
+
assert "master exam rules" in prompt
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
# ---------------------------------------------------------------------------
|
| 244 |
+
# PlanAndExecuteRouter — full route (integration with mocks)
|
| 245 |
+
# ---------------------------------------------------------------------------
|
| 246 |
+
|
| 247 |
+
class TestFullRoute:
|
| 248 |
+
def test_route_produces_response(self) -> None:
|
| 249 |
+
"""Full route with mocked LLM and retrieval components."""
|
| 250 |
+
llm = MagicMock()
|
| 251 |
+
retriever = MagicMock()
|
| 252 |
+
reranker = MagicMock()
|
| 253 |
+
vector_store = MagicMock()
|
| 254 |
+
|
| 255 |
+
# Plan: single search step
|
| 256 |
+
plan_json = '[{"action": "search", "detail": "test query"}]'
|
| 257 |
+
# Sub-agent answer after executing step
|
| 258 |
+
from langchain_core.messages import AIMessage
|
| 259 |
+
sub_agent_result = {"messages": [AIMessage(content="Found relevant info about test.")]}
|
| 260 |
+
# Final synthesis
|
| 261 |
+
final_answer = "The test policy states..."
|
| 262 |
+
|
| 263 |
+
# Direct LLM calls: plan, then synthesis (the executor sub-agent is patched below)
|
| 264 |
+
# We mock the LLM and also mock the sub-agent creation
|
| 265 |
+
llm.invoke.side_effect = [plan_json, final_answer]
|
| 266 |
+
|
| 267 |
+
results = [_qr(chunk_id="c1", score=0.9, text="test policy")]
|
| 268 |
+
retriever.search_detailed.return_value = _hybrid_result(results)
|
| 269 |
+
reranker.rerank.return_value = results
|
| 270 |
+
vector_store.list_document_ids.return_value = ["doc.pdf"]
|
| 271 |
+
|
| 272 |
+
router = PlanAndExecuteRouter(llm, retriever, reranker, vector_store)
|
| 273 |
+
|
| 274 |
+
# Patch create_react_agent to return a mock that returns our sub_agent_result
|
| 275 |
+
mock_agent = MagicMock()
|
| 276 |
+
mock_agent.invoke.return_value = sub_agent_result
|
| 277 |
+
|
| 278 |
+
with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
|
| 279 |
+
response = router.route("test question", top_k=5)
|
| 280 |
+
|
| 281 |
+
assert isinstance(response, GenerationResponse)
|
| 282 |
+
assert response.answer == "The test policy states..."
|
| 283 |
+
|
| 284 |
+
def test_route_with_no_results(self) -> None:
|
| 285 |
+
"""Route when retrieval finds nothing."""
|
| 286 |
+
llm = MagicMock()
|
| 287 |
+
retriever = MagicMock()
|
| 288 |
+
reranker = MagicMock()
|
| 289 |
+
vector_store = MagicMock()
|
| 290 |
+
|
| 291 |
+
plan_json = '[{"action": "search", "detail": "nonexistent"}]'
|
| 292 |
+
from langchain_core.messages import AIMessage
|
| 293 |
+
sub_agent_result = {"messages": [AIMessage(content="No relevant documents found.")]}
|
| 294 |
+
|
| 295 |
+
llm.invoke.side_effect = [plan_json, "I could not find information."]
|
| 296 |
+
|
| 297 |
+
mock_agent = MagicMock()
|
| 298 |
+
mock_agent.invoke.return_value = sub_agent_result
|
| 299 |
+
|
| 300 |
+
router = PlanAndExecuteRouter(llm, retriever, reranker, vector_store)
|
| 301 |
+
|
| 302 |
+
with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
|
| 303 |
+
response = router.route("nonexistent topic", top_k=5)
|
| 304 |
+
|
| 305 |
+
assert response.intent == IntentType.FACTUAL
|
| 306 |
+
assert response.confidence == 0.0
|
| 307 |
+
|
| 308 |
+
def test_route_multi_step(self) -> None:
|
| 309 |
+
"""Route with a multi-step plan."""
|
| 310 |
+
llm = MagicMock()
|
| 311 |
+
retriever = MagicMock()
|
| 312 |
+
reranker = MagicMock()
|
| 313 |
+
vector_store = MagicMock()
|
| 314 |
+
|
| 315 |
+
plan_json = json.dumps([
|
| 316 |
+
{"action": "search", "detail": "policy A"},
|
| 317 |
+
{"action": "search", "detail": "policy B"},
|
| 318 |
+
])
|
| 319 |
+
from langchain_core.messages import AIMessage
|
| 320 |
+
sub_result_1 = {"messages": [AIMessage(content="Policy A info")]}
|
| 321 |
+
sub_result_2 = {"messages": [AIMessage(content="Policy B info")]}
|
| 322 |
+
|
| 323 |
+
llm.invoke.side_effect = [plan_json, "Comparison of A and B."]
|
| 324 |
+
|
| 325 |
+
mock_agent = MagicMock()
|
| 326 |
+
mock_agent.invoke.side_effect = [sub_result_1, sub_result_2]
|
| 327 |
+
|
| 328 |
+
router = PlanAndExecuteRouter(llm, retriever, reranker, vector_store)
|
| 329 |
+
|
| 330 |
+
with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
|
| 331 |
+
response = router.route("Compare A and B", top_k=5)
|
| 332 |
+
|
| 333 |
+
assert response.answer == "Comparison of A and B."
|
| 334 |
+
# Sub-agent should have been called twice
|
| 335 |
+
assert mock_agent.invoke.call_count == 2
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
# ---------------------------------------------------------------------------
|
| 339 |
+
# PlanAndExecuteRouter — route_stream
|
| 340 |
+
# ---------------------------------------------------------------------------
|
| 341 |
+
|
| 342 |
+
class TestRouteStream:
|
| 343 |
+
def test_stream_yields_plan_execute_synthesize_done(self) -> None:
|
| 344 |
+
"""Streaming should yield events in order: plan, execute, synthesize, done."""
|
| 345 |
+
llm = MagicMock()
|
| 346 |
+
retriever = MagicMock()
|
| 347 |
+
reranker = MagicMock()
|
| 348 |
+
vector_store = MagicMock()
|
| 349 |
+
|
| 350 |
+
plan_json = '[{"action": "search", "detail": "test"}]'
|
| 351 |
+
from langchain_core.messages import AIMessage
|
| 352 |
+
sub_agent_result = {"messages": [AIMessage(content="Found info.")]}
|
| 353 |
+
|
| 354 |
+
llm.invoke.side_effect = [plan_json, "Final answer."]
|
| 355 |
+
|
| 356 |
+
mock_agent = MagicMock()
|
| 357 |
+
mock_agent.invoke.return_value = sub_agent_result
|
| 358 |
+
|
| 359 |
+
router = PlanAndExecuteRouter(llm, retriever, reranker, vector_store)
|
| 360 |
+
|
| 361 |
+
with patch("src.agent.plan_and_execute.create_react_agent", return_value=mock_agent):
|
| 362 |
+
events = list(router.route_stream("test", top_k=5))
|
| 363 |
+
|
| 364 |
+
step_names = [e["step"] for e in events]
|
| 365 |
+
assert "plan" in step_names
|
| 366 |
+
assert "done" in step_names
|
| 367 |
+
# done event has the result
|
| 368 |
+
done_event = [e for e in events if e["step"] == "done"][0]
|
| 369 |
+
assert "result" in done_event
|
| 370 |
+
assert done_event["result"]["answer"] == "Final answer."
|
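The `TestShouldExecute` cases above pin down a small routing rule: keep executing while plan steps remain and the step cap has not been reached, otherwise move on to synthesis. A minimal sketch of a predicate that would satisfy those four cases follows. It is not the code in `src/agent/plan_and_execute.py`: the `TypedDict` shapes, the standalone `should_execute` function, and `MAX_STEPS = 6` (taken from the `# == _MAX_STEPS` comment in the test) are assumptions made for illustration.

```python
from typing import TypedDict

# Assumed cap; the test comment suggests _MAX_STEPS is 6 in the real module.
MAX_STEPS = 6


class PlanStep(TypedDict):
    """Hypothetical stand-in for the PlanStep used in the tests."""
    action: str
    detail: str


class PlanExecState(TypedDict):
    """Hypothetical stand-in for the graph state used in the tests."""
    query: str
    top_k: int
    plan: list[PlanStep]
    step_index: int
    step_results: list[tuple[str, str]]
    answer: str


def should_execute(state: PlanExecState) -> str:
    """Return the name of the next node: "execute" or "synthesize"."""
    index = state["step_index"]
    has_remaining_step = index < len(state["plan"])
    under_cap = index < MAX_STEPS
    # Execute only when there is a pending step and the cap allows it.
    if has_remaining_step and under_cap:
        return "execute"
    return "synthesize"
```

Both conditions are checked independently so that an over-long plan (ten steps in `test_max_steps_cap`) still terminates at the cap, and an empty plan goes straight to synthesis.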
tests/test_tools.py
ADDED
|
@@ -0,0 +1,381 @@
|
| 1 |
+
"""Tests for agent tools (hybrid_search, list_documents, fetch_document,
|
| 2 |
+
search_within_document, multi_query_search, summarize_document)."""
|
| 3 |
+
|
| 4 |
+
from unittest.mock import MagicMock
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
from src.agent.tools import ToolResultStore, make_retrieval_tools, _merge_results, _format_results
|
| 9 |
+
from src.models import DocumentChunk, QueryResult
|
| 10 |
+
from src.retrieval.hybrid import HybridSearchResult
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
# Helpers
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
|
| 17 |
+
def _chunk(chunk_id: str = "c1", document_id: str = "doc.pdf", text: str = "text",
|
| 18 |
+
page_number: int = 1, chunk_index: int = 0) -> DocumentChunk:
|
| 19 |
+
return DocumentChunk(
|
| 20 |
+
chunk_id=chunk_id,
|
| 21 |
+
document_id=document_id,
|
| 22 |
+
text=text,
|
| 23 |
+
metadata={"page_number": page_number, "chunk_index": chunk_index},
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _qr(chunk_id: str = "c1", document_id: str = "doc.pdf", text: str = "text",
|
| 28 |
+
score: float = 0.8, source: str = "hybrid", page_number: int = 1) -> QueryResult:
|
| 29 |
+
return QueryResult(
|
| 30 |
+
chunk=_chunk(chunk_id=chunk_id, document_id=document_id, text=text, page_number=page_number),
|
| 31 |
+
score=score,
|
| 32 |
+
source=source,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _hybrid_result(results: list[QueryResult]) -> HybridSearchResult:
|
| 37 |
+
return HybridSearchResult(
|
| 38 |
+
dense_results=results,
|
| 39 |
+
sparse_results=results,
|
| 40 |
+
fused_results=results,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@pytest.fixture
|
| 45 |
+
def components():
|
| 46 |
+
"""Create mock retriever, reranker, vector_store, and store."""
|
| 47 |
+
retriever = MagicMock()
|
| 48 |
+
reranker = MagicMock()
|
| 49 |
+
vector_store = MagicMock()
|
| 50 |
+
store = ToolResultStore()
|
| 51 |
+
return retriever, reranker, vector_store, store
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
# Unit tests for helper functions
|
| 56 |
+
# ---------------------------------------------------------------------------
|
| 57 |
+
|
| 58 |
+
class TestMergeResults:
|
| 59 |
+
def test_merge_empty(self) -> None:
|
| 60 |
+
assert _merge_results([], []) == []
|
| 61 |
+
|
| 62 |
+
def test_merge_keeps_higher_score(self) -> None:
|
| 63 |
+
old = [_qr(chunk_id="c1", score=0.5)]
|
| 64 |
+
new = [_qr(chunk_id="c1", score=0.9)]
|
| 65 |
+
merged = _merge_results(old, new)
|
| 66 |
+
assert len(merged) == 1
|
| 67 |
+
assert merged[0].score == 0.9
|
| 68 |
+
|
| 69 |
+
def test_merge_keeps_old_if_higher(self) -> None:
|
| 70 |
+
old = [_qr(chunk_id="c1", score=0.9)]
|
| 71 |
+
new = [_qr(chunk_id="c1", score=0.5)]
|
| 72 |
+
merged = _merge_results(old, new)
|
| 73 |
+
assert merged[0].score == 0.9
|
| 74 |
+
|
| 75 |
+
def test_merge_combines_different_ids(self) -> None:
|
| 76 |
+
old = [_qr(chunk_id="c1", score=0.5)]
|
| 77 |
+
new = [_qr(chunk_id="c2", score=0.9)]
|
| 78 |
+
merged = _merge_results(old, new)
|
| 79 |
+
assert len(merged) == 2
|
| 80 |
+
assert merged[0].chunk.chunk_id == "c2" # higher score first
|
| 81 |
+
|
| 82 |
+
def test_merge_sorted_descending(self) -> None:
|
| 83 |
+
results = [_qr(chunk_id=f"c{i}", score=s) for i, s in enumerate([0.3, 0.9, 0.6])]
|
| 84 |
+
merged = _merge_results([], results)
|
| 85 |
+
scores = [r.score for r in merged]
|
| 86 |
+
assert scores == sorted(scores, reverse=True)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class TestFormatResults:
|
| 90 |
+
def test_empty_returns_no_results_message(self) -> None:
|
| 91 |
+
result = _format_results([])
|
| 92 |
+
assert "Ingen relevante" in result
|
| 93 |
+
|
| 94 |
+
def test_includes_document_id_and_score(self) -> None:
|
| 95 |
+
results = [_qr(document_id="policy.pdf", score=0.85)]
|
| 96 |
+
text = _format_results(results)
|
| 97 |
+
assert "policy.pdf" in text
|
| 98 |
+
assert "0.850" in text
|
| 99 |
+
|
| 100 |
+
def test_includes_page_number(self) -> None:
|
| 101 |
+
results = [_qr(page_number=5)]
|
| 102 |
+
text = _format_results(results)
|
| 103 |
+
assert "side 5" in text
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# ---------------------------------------------------------------------------
|
| 107 |
+
# hybrid_search
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
|
| 110 |
+
class TestHybridSearch:
|
| 111 |
+
def test_returns_formatted_results(self, components) -> None:
|
| 112 |
+
retriever, reranker, vector_store, store = components
|
| 113 |
+
results = [_qr(document_id="a.pdf", score=0.9, text="answer")]
|
| 114 |
+
retriever.search_detailed.return_value = _hybrid_result(results)
|
| 115 |
+
reranker.rerank.return_value = results
|
| 116 |
+
|
| 117 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store)
|
| 118 |
+
hybrid_search = tools[0]
|
| 119 |
+
output = hybrid_search.invoke({"query": "test", "top_k": 5})
|
| 120 |
+
|
| 121 |
+
assert "a.pdf" in output
|
| 122 |
+
assert "answer" in output
|
| 123 |
+
retriever.search_detailed.assert_called_once_with("test", top_k=5)
|
| 124 |
+
|
| 125 |
+
def test_accumulates_in_store(self, components) -> None:
|
| 126 |
+
retriever, reranker, vector_store, store = components
|
| 127 |
+
results = [_qr(chunk_id="c1", score=0.8)]
|
| 128 |
+
retriever.search_detailed.return_value = _hybrid_result(results)
|
| 129 |
+
reranker.rerank.return_value = results
|
| 130 |
+
|
| 131 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store)
|
| 132 |
+
tools[0].invoke({"query": "q1"})
|
| 133 |
+
|
| 134 |
+
assert len(store.retrieved) == 1
|
| 135 |
+
assert store.retrieved[0].chunk.chunk_id == "c1"
|
| 136 |
+
assert len(store.tool_calls) == 1
|
| 137 |
+
assert store.tool_calls[0] == ("hybrid_search", "q1")
|
| 138 |
+
|
| 139 |
+
def test_no_results(self, components) -> None:
|
| 140 |
+
retriever, reranker, vector_store, store = components
|
| 141 |
+
retriever.search_detailed.return_value = _hybrid_result([])
|
| 142 |
+
reranker.rerank.return_value = []
|
| 143 |
+
|
| 144 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store)
|
| 145 |
+
output = tools[0].invoke({"query": "nothing"})
|
| 146 |
+
assert "Ingen relevante" in output
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ---------------------------------------------------------------------------
|
| 150 |
+
# list_documents
|
| 151 |
+
# ---------------------------------------------------------------------------
|
| 152 |
+
|
| 153 |
+
class TestListDocuments:
|
| 154 |
+
def test_returns_document_list(self, components) -> None:
|
| 155 |
+
retriever, reranker, vector_store, store = components
|
| 156 |
+
vector_store.list_document_ids.return_value = ["a.pdf", "b.pdf"]
|
| 157 |
+
|
| 158 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store)
|
| 159 |
+
list_docs = tools[1]
|
| 160 |
+
output = list_docs.invoke({})
|
| 161 |
+
|
| 162 |
+
assert "a.pdf" in output
|
| 163 |
+
assert "b.pdf" in output
|
| 164 |
+
assert "2 i alt" in output
|
| 165 |
+
|
| 166 |
+
def test_empty_knowledge_base(self, components) -> None:
|
| 167 |
+
retriever, reranker, vector_store, store = components
|
| 168 |
+
vector_store.list_document_ids.return_value = []
|
| 169 |
+
|
| 170 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store)
|
| 171 |
+
output = tools[1].invoke({})
|
| 172 |
+
assert "empty" in output.lower() or "Ingen" in output
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# ---------------------------------------------------------------------------
|
| 176 |
+
# fetch_document
|
| 177 |
+
# ---------------------------------------------------------------------------
|
| 178 |
+
|
| 179 |
+
class TestFetchDocument:
|
| 180 |
+
def test_returns_full_text(self, components) -> None:
|
| 181 |
+
retriever, reranker, vector_store, store = components
|
| 182 |
+
chunks = [_chunk(chunk_id="c1", text="page1"), _chunk(chunk_id="c2", text="page2")]
|
| 183 |
+
vector_store.get_chunks_by_document_id.return_value = chunks
|
| 184 |
+
|
| 185 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store)
|
| 186 |
+
fetch = tools[2]
|
| 187 |
+
output = fetch.invoke({"document_id": "doc.pdf"})
|
| 188 |
+
|
| 189 |
+
assert "page1" in output
|
| 190 |
+
assert "page2" in output
|
| 191 |
+
assert len(store.retrieved) == 2
|
| 192 |
+
|
| 193 |
+
def test_document_not_found(self, components) -> None:
|
| 194 |
+
retriever, reranker, vector_store, store = components
|
| 195 |
+
vector_store.get_chunks_by_document_id.return_value = []
|
| 196 |
+
|
| 197 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store)
|
| 198 |
+
output = tools[2].invoke({"document_id": "missing.pdf"})
|
| 199 |
+
assert "ikke fundet" in output
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# ---------------------------------------------------------------------------
|
| 203 |
+
# search_within_document
|
| 204 |
+
# ---------------------------------------------------------------------------
|
| 205 |
+
|
| 206 |
+
class TestSearchWithinDocument:
|
| 207 |
+
def test_reranks_document_chunks(self, components) -> None:
|
| 208 |
+
retriever, reranker, vector_store, store = components
|
| 209 |
+
chunks = [
|
| 210 |
+
_chunk(chunk_id="c1", text="irrelevant"),
|
| 211 |
+
_chunk(chunk_id="c2", text="relevant answer"),
|
| 212 |
+
]
|
| 213 |
+
vector_store.get_chunks_by_document_id.return_value = chunks
|
| 214 |
+
reranker.rerank.return_value = [_qr(chunk_id="c2", text="relevant answer", score=0.95)]
|
| 215 |
+
|
| 216 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store)
|
| 217 |
+
search_within = tools[3]
|
| 218 |
+
output = search_within.invoke({"document_id": "doc.pdf", "query": "answer"})
|
| 219 |
+
|
| 220 |
+
assert "relevant answer" in output
|
| 221 |
+
assert "0.950" in output
|
| 222 |
+
reranker.rerank.assert_called_once()
|
| 223 |
+
# Verify it passed all chunks to reranker
|
| 224 |
+
candidates = reranker.rerank.call_args[0][1]
|
| 225 |
+
assert len(candidates) == 2
|
| 226 |
+
|
| 227 |
+
def test_document_not_found(self, components) -> None:
|
| 228 |
+
retriever, reranker, vector_store, store = components
|
| 229 |
+
vector_store.get_chunks_by_document_id.return_value = []
|
| 230 |
+
|
| 231 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store)
|
| 232 |
+
output = tools[3].invoke({"document_id": "missing.pdf", "query": "test"})
|
| 233 |
+
assert "ikke fundet" in output
|
| 234 |
+
|
| 235 |
+
def test_accumulates_in_store(self, components) -> None:
|
| 236 |
+
retriever, reranker, vector_store, store = components
|
| 237 |
+
chunks = [_chunk(chunk_id="c1")]
|
| 238 |
+
vector_store.get_chunks_by_document_id.return_value = chunks
|
| 239 |
+
reranker.rerank.return_value = [_qr(chunk_id="c1", score=0.7)]
|
| 240 |
+
|
| 241 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store)
|
| 242 |
+
tools[3].invoke({"document_id": "doc.pdf", "query": "q"})
|
| 243 |
+
|
| 244 |
+
assert len(store.retrieved) == 1
|
| 245 |
+
assert store.tool_calls[-1][0] == "search_within_document"
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# ---------------------------------------------------------------------------
|
| 249 |
+
# multi_query_search (requires llm_chain)
|
| 250 |
+
# ---------------------------------------------------------------------------
|
| 251 |
+
|
| 252 |
+
class TestMultiQuerySearch:
|
| 253 |
+
def test_decomposes_and_searches(self, components) -> None:
|
| 254 |
+
retriever, reranker, vector_store, store = components
|
| 255 |
+
llm_chain = MagicMock()
|
| 256 |
+
|
| 257 |
+
# LLM returns 2 sub-queries
|
| 258 |
+
llm_chain.invoke.return_value = "eksamensregler bachelor\neksamensregler kandidat"
|
| 259 |
+
|
| 260 |
+
results_a = [_qr(chunk_id="c1", score=0.9, text="bachelor exam")]
|
| 261 |
+
results_b = [_qr(chunk_id="c2", score=0.85, text="master exam")]
|
| 262 |
+
retriever.search_detailed.side_effect = [
|
| 263 |
+
_hybrid_result(results_a),
|
| 264 |
+
_hybrid_result(results_b),
|
| 265 |
+
]
|
| 266 |
+
reranker.rerank.side_effect = [results_a, results_b]
|
| 267 |
+
|
| 268 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain)
|
| 269 |
+
multi_search = tools[4]
|
| 270 |
+
output = multi_search.invoke({"question": "Compare exam rules"})
|
| 271 |
+
|
| 272 |
+
assert "delforespørgsler" in output
|
| 273 |
+
assert retriever.search_detailed.call_count == 2
|
| 274 |
+
assert reranker.rerank.call_count == 2
|
| 275 |
+
assert len(store.retrieved) == 2
|
| 276 |
+
|
| 277 |
+
def test_fallback_when_decompose_fails(self, components) -> None:
|
| 278 |
+
retriever, reranker, vector_store, store = components
|
| 279 |
+
llm_chain = MagicMock()
|
| 280 |
+
|
| 281 |
+
# LLM returns empty/garbage
|
| 282 |
+
llm_chain.invoke.return_value = ""
|
| 283 |
+
|
| 284 |
+
results = [_qr(chunk_id="c1", score=0.8)]
|
| 285 |
+
retriever.search_detailed.return_value = _hybrid_result(results)
|
| 286 |
+
reranker.rerank.return_value = results
|
| 287 |
+
|
| 288 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain)
|
| 289 |
+
output = tools[4].invoke({"question": "original question"})
|
| 290 |
+
|
| 291 |
+
# Should fall back to the original question as a single query
|
| 292 |
+
assert retriever.search_detailed.call_count == 1
|
| 293 |
+
assert "0.800" in output
|
| 294 |
+
|
| 295 |
+
def test_not_available_without_llm(self, components) -> None:
|
| 296 |
+
retriever, reranker, vector_store, store = components
|
| 297 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=None)
|
| 298 |
+
tool_names = [t.name for t in tools]
|
| 299 |
+
assert "multi_query_search" not in tool_names
|
| 300 |
+
assert "summarize_document" not in tool_names
|
| 301 |
+
|
| 302 |
+
def test_deduplicates_across_sub_queries(self, components) -> None:
|
| 303 |
+
retriever, reranker, vector_store, store = components
|
| 304 |
+
llm_chain = MagicMock()
|
| 305 |
+
llm_chain.invoke.return_value = "query1\nquery2"
|
| 306 |
+
|
| 307 |
+
# Both sub-queries return the same chunk
|
| 308 |
+
same_result = [_qr(chunk_id="c1", score=0.8)]
|
| 309 |
+
retriever.search_detailed.return_value = _hybrid_result(same_result)
|
| 310 |
+
reranker.rerank.return_value = same_result
|
| 311 |
+
|
| 312 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain)
|
| 313 |
+
tools[4].invoke({"question": "test"})
|
| 314 |
+
|
| 315 |
+
# Should be deduplicated to 1
|
| 316 |
+
assert len(store.retrieved) == 1
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
# ---------------------------------------------------------------------------
|
| 320 |
+
# summarize_document (requires llm_chain)
|
| 321 |
+
# ---------------------------------------------------------------------------
|
| 322 |
+
|
| 323 |
+
class TestSummarizeDocument:
|
| 324 |
+
def test_generates_summary(self, components) -> None:
|
| 325 |
+
retriever, reranker, vector_store, store = components
|
| 326 |
+
llm_chain = MagicMock()
|
| 327 |
+
llm_chain.invoke.return_value = "This document covers exam policies."
|
| 328 |
+
|
| 329 |
+
chunks = [_chunk(chunk_id="c1", text="Exam rules...")]
|
| 330 |
+
vector_store.get_chunks_by_document_id.return_value = chunks
|
| 331 |
+
|
| 332 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain)
|
| 333 |
+
summarize = tools[5]
|
| 334 |
+
output = summarize.invoke({"document_id": "exam.pdf"})
|
| 335 |
+
|
| 336 |
+
assert "Resumé af exam.pdf" in output
|
| 337 |
+
assert "exam policies" in output
|
| 338 |
+
llm_chain.invoke.assert_called_once()
|
| 339 |
+
# Verify the prompt includes the document text
|
| 340 |
+
prompt = llm_chain.invoke.call_args[0][0]
|
| 341 |
+
assert "Exam rules" in prompt
|
| 342 |
+
|
| 343 |
+
def test_document_not_found(self, components) -> None:
|
| 344 |
+
retriever, reranker, vector_store, store = components
|
| 345 |
+
llm_chain = MagicMock()
|
| 346 |
+
vector_store.get_chunks_by_document_id.return_value = []
|
| 347 |
+
|
| 348 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain)
|
| 349 |
+
output = tools[5].invoke({"document_id": "missing.pdf"})
|
| 350 |
+
assert "ikke fundet" in output
|
| 351 |
+
llm_chain.invoke.assert_not_called()
|
| 352 |
+
|
| 353 |
+
def test_truncates_long_documents(self, components) -> None:
|
| 354 |
+
retriever, reranker, vector_store, store = components
|
| 355 |
+
llm_chain = MagicMock()
|
| 356 |
+
llm_chain.invoke.return_value = "summary"
|
| 357 |
+
|
| 358 |
+
# Create a document longer than 8000 chars
|
| 359 |
+
long_text = "x" * 10000
|
| 360 |
+
chunks = [_chunk(chunk_id="c1", text=long_text)]
|
| 361 |
+
vector_store.get_chunks_by_document_id.return_value = chunks
|
| 362 |
+
|
| 363 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain)
|
| 364 |
+
tools[5].invoke({"document_id": "long.pdf"})
|
| 365 |
+
|
| 366 |
+
prompt = llm_chain.invoke.call_args[0][0]
|
| 367 |
+
assert "forkortet" in prompt
|
| 368 |
+
|
| 369 |
+
def test_registers_chunks_as_sources(self, components) -> None:
|
| 370 |
+
retriever, reranker, vector_store, store = components
|
| 371 |
+
llm_chain = MagicMock()
|
| 372 |
+
llm_chain.invoke.return_value = "summary"
|
| 373 |
+
|
| 374 |
+
chunks = [_chunk(chunk_id="c1"), _chunk(chunk_id="c2")]
|
| 375 |
+
vector_store.get_chunks_by_document_id.return_value = chunks
|
| 376 |
+
|
| 377 |
+
tools = make_retrieval_tools(retriever, reranker, vector_store, store, llm_chain=llm_chain)
|
| 378 |
+
tools[5].invoke({"document_id": "doc.pdf"})
|
| 379 |
+
|
| 380 |
+
assert len(store.retrieved) == 2
|
| 381 |
+
assert store.tool_calls[-1] == ("summarize_document", "doc.pdf")
|
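The `TestMergeResults` cases in this file describe the accumulation behaviour expected of `_merge_results`: results are deduplicated by chunk ID, the higher score wins when the same chunk appears twice, and the output is sorted by score in descending order. The sketch below shows one way to get that behaviour. It is an illustration only, not the implementation in `src/agent/tools.py`: the `Hit` dataclass is a hypothetical stand-in for `QueryResult`, and `merge_results` is a hypothetical name.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Hit:
    """Hypothetical stand-in for QueryResult: a chunk ID plus a score."""
    chunk_id: str
    score: float


def merge_results(old: list[Hit], new: list[Hit]) -> list[Hit]:
    """Merge two result lists, keeping the best-scoring hit per chunk."""
    best: dict[str, Hit] = {}
    for hit in [*old, *new]:
        current = best.get(hit.chunk_id)
        # Keep the first hit seen for a chunk unless a later one scores higher.
        if current is None or hit.score > current.score:
            best[hit.chunk_id] = hit
    # Highest score first, matching the ordering the tests assert.
    return sorted(best.values(), key=lambda h: h.score, reverse=True)


if __name__ == "__main__":
    merged = merge_results([Hit("c1", 0.5)], [Hit("c1", 0.9), Hit("c2", 0.7)])
    print([(h.chunk_id, h.score) for h in merged])
```

Keying on the chunk ID is what lets repeated tool calls (for example the two sub-queries in `test_deduplicates_across_sub_queries`) accumulate into a single source list without duplicates.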