Rifqi Hafizuddin Claude Opus 4.8 committed on
Commit
72306d0
·
1 Parent(s): 83ba6b1

[KM-567][AI] Planner agent: validator + service

Browse files

- validator.py: PlannerValidator runs the 8 checks from §7.3 (tools in registry,
catalog refs exist, DAG valid / no cycles, parallelism consistent, within task
cap, checkable success_criteria, args valid, inline query_structured IR via the
existing IRValidator). Raises PlannerValidationError with self-correctable
messages.
- service.py: PlannerService + plan_analysis(). LLM chain mirrors
query/planner/service.py; validate-and-retry loop (max 3) mirrors QueryService.
Takes the full Catalog, derives the PII-safe CatalogSummary for the prompt.
Static plan only — no replanning.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

src/agents/planner/service.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PlannerService — single LLM call: context + catalog + tools + question -> TaskList.
2
+
3
+ Mirrors `query/planner/service.py` (chain construction) and `query/service.py`
4
+ (validate-and-retry loop). The planner LLM emits a `TaskList` via structured
5
+ output; the `PlannerValidator` runs the 8 checks; on failure the planner is
6
+ re-prompted with the error context, up to `max_retries` (default 3). No
7
+ replanning happens at execution time — this loop only hardens the *initial*
8
+ static plan.
9
+
10
+ The service takes the full `Catalog` (not just a `CatalogSummary`): it derives
11
+ the PII-safe `CatalogSummary` for the prompt, but validation needs the full
12
+ catalog so the existing `IRValidator` can check inline `query_structured` IRs.
13
+
14
+ See AGENT_ARCHITECTURE_CONTEXT_new.md §7.3.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ from pathlib import Path
20
+
21
+ from langchain_core.messages import SystemMessage
22
+ from langchain_core.prompts import ChatPromptTemplate
23
+ from langchain_core.runnables import Runnable
24
+ from langchain_openai import AzureChatOpenAI
25
+
26
+ from src.middlewares.logging import get_logger
27
+
28
+ from ...catalog.models import Catalog
29
+ from .contracts import BusinessContext, ToolRegistry
30
+ from .errors import PlannerError, PlannerValidationError
31
+ from .inputs import CatalogSummary, Constraints
32
+ from .prompt import build_planner_prompt
33
+ from .schemas import TaskList
34
+ from .validator import PlannerValidator
35
+
36
+ logger = get_logger("planner_agent")
37
+
38
# Location of the planner system prompt: resolve this module's path, go up
# three directory levels, then into config/prompts/planner.md.
_PROMPT_PATH = (
    Path(__file__).resolve().parent.parent.parent / "config" / "prompts" / "planner.md"
)


def _load_prompt_text() -> str:
    """Return the planner system prompt text.

    Reads `_PROMPT_PATH` as UTF-8 on every call; any error from
    `Path.read_text` (e.g. a missing file) propagates to the caller.
    """
    return _PROMPT_PATH.read_text(encoding="utf-8")
45
+
46
+
47
def _build_default_chain() -> Runnable:
    """Build the default planner chain: prompt -> Azure chat model -> `TaskList`.

    The chain takes a single input variable, `human_content`, and returns the
    model's structured output parsed as a `TaskList`.

    Returns:
        A runnable composed of the chat prompt piped into the structured-output
        model.
    """
    # Imported here rather than at module top, so settings are only loaded
    # when a default chain is actually built.
    from src.config.settings import settings

    llm = AzureChatOpenAI(
        azure_deployment=settings.azureai_deployment_name_4o,
        openai_api_version=settings.azureai_api_version_4o,
        azure_endpoint=settings.azureai_endpoint_url_4o,
        api_key=settings.azureai_api_key_4o,
        temperature=0,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            # NOTE(review): the system prompt is passed as a message object,
            # not a ("system", text) tuple — presumably so that braces inside
            # planner.md are not treated as template variables; confirm
            # against the ChatPromptTemplate documentation.
            SystemMessage(content=_load_prompt_text()),
            ("human", "{human_content}"),
        ]
    )
    return prompt | llm.with_structured_output(TaskList)
64
+
65
+
66
# Module-level cache for the default chain; stays None until the first call
# to `_get_default_chain`.
_default_chain: Runnable | None = None


def _get_default_chain() -> Runnable:
    """Return the shared default chain, building it on first use.

    The chain is built once via `_build_default_chain` and stored in the
    module-level `_default_chain`; later calls return the same object.
    """
    global _default_chain
    if _default_chain is None:
        _default_chain = _build_default_chain()
    return _default_chain
74
+
75
+
76
class PlannerService:
    """Wraps the planner LLM call + the validate-and-retry loop.

    Inject `structured_chain` and/or `validator` for tests.
    """

    def __init__(
        self,
        structured_chain: Runnable | None = None,
        validator: PlannerValidator | None = None,
        max_retries: int = 3,
    ) -> None:
        """Store the collaborators used by `plan`.

        Args:
            structured_chain: Chain to invoke; when None the shared default
                chain is resolved lazily on the first `plan` call.
            validator: Validator applied to every produced plan; a default
                `PlannerValidator` is created when None.
            max_retries: Total number of LLM attempts made by `plan`
                (the first attempt included).

        Raises:
            ValueError: If `max_retries` is less than 1. With zero attempts
                `plan` could never call the LLM and would only ever fail with
                an uninformative "last error: None".
        """
        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")
        self._chain = structured_chain
        self._validator = validator if validator is not None else PlannerValidator()
        self._max_retries = max_retries

    def _ensure_chain(self) -> Runnable:
        """Return the injected chain, falling back to the shared default chain."""
        if self._chain is None:
            self._chain = _get_default_chain()
        return self._chain

    async def plan(
        self,
        context: BusinessContext,
        catalog: Catalog,
        tools: ToolRegistry,
        query: str,
        constraints: Constraints,
    ) -> TaskList:
        """Produce a validated `TaskList` for `query`.

        Each attempt builds the human prompt, invokes the chain, and validates
        the result. When validation fails, the error text is fed into the next
        attempt's prompt so the model can correct itself.

        Args:
            context: Business context; its `project_id` is used for logging.
            catalog: Full catalog. A `CatalogSummary` derived from it goes into
                the prompt, while the validator receives the full catalog.
            tools: Tool registry passed to both the prompt builder and the
                validator.
            query: The user question to plan for.
            constraints: Planning constraints passed to the prompt builder and
                the validator.

        Returns:
            The first `TaskList` that passes validation.

        Raises:
            PlannerError: If every attempt fails validation.
        """
        summary = CatalogSummary.from_catalog(catalog)
        chain = self._ensure_chain()
        previous_error: str | None = None

        for attempt in range(1, self._max_retries + 1):
            human_content = build_planner_prompt(
                context, summary, tools, query, constraints, previous_error
            )
            task_list: TaskList = await chain.ainvoke({"human_content": human_content})
            try:
                self._validator.validate(task_list, tools, catalog, constraints)
            except PlannerValidationError as e:
                # Keep the message: it becomes the error context of the next prompt.
                previous_error = str(e)
                logger.warning(
                    "planner validation failed",
                    project_id=context.project_id,
                    plan_id=task_list.plan_id,
                    attempt=attempt,
                    error=previous_error,
                )
                continue

            logger.info(
                "analysis planned",
                project_id=context.project_id,
                plan_id=task_list.plan_id,
                n_tasks=len(task_list.tasks),
                retry=attempt > 1,
            )
            return task_list

        raise PlannerError(
            f"planner failed validation after {self._max_retries} attempts; "
            f"last error: {previous_error}"
        )
140
+
141
+
142
async def plan_analysis(
    context: BusinessContext,
    catalog: Catalog,
    tools: ToolRegistry,
    query: str,
    constraints: Constraints,
) -> TaskList:
    """Convenience entry point using the default chain + validator.

    Creates a fresh `PlannerService` with default arguments and awaits its
    `plan` method with the given arguments. Returns, and raises, whatever
    `PlannerService.plan` does.
    """
    return await PlannerService().plan(context, catalog, tools, query, constraints)
src/agents/planner/validator.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PlannerValidator — checks a TaskList before it reaches the TaskRunner.
2
+
3
+ Runs the 8 checks from AGENT_ARCHITECTURE_CONTEXT_new.md §7.3. On failure it
4
+ raises `PlannerValidationError` with a message specific enough that the planner
5
+ can be re-prompted to self-correct (the retry loop lives in service.py).
6
+
7
+ Check #1 (Pydantic parse) is enforced at the structured-output boundary — by the
8
+ time a `TaskList` reaches here it has already parsed; this validator additionally
9
+ rejects structurally-invalid plans (duplicate ids, dangling edges, cycles).
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import re
15
+
16
+ from pydantic import ValidationError
17
+
18
+ from ...catalog.models import Catalog
19
+ from ...query.ir.models import QueryIR
20
+ from ...query.ir.validator import IRValidationError, IRValidator
21
+ from .contracts import ToolRegistry
22
+ from .errors import PlannerValidationError
23
+ from .inputs import Constraints
24
+ from .schemas import TaskList
25
+
26
# Heuristic: a checkable success_criteria mentions a measurable signal.
_CHECKABLE_TOKENS = ("rate", "count", "match", "produced", "above", "below", "equal")
# Matches "${...}" placeholders whose inner text starts with "t"; group 1
# captures everything between the braces.
_PLACEHOLDER_RE = re.compile(r"\$\{(t[^}]+)\}")

# DFS colors for cycle detection.
_WHITE, _GREY, _BLACK = 0, 1, 2
32
+
33
+
34
class PlannerValidator:
    """Validates a `TaskList` against the tool registry, catalog and constraints.

    Every failed check raises `PlannerValidationError`; the first failure
    encountered wins.
    """

    def __init__(self, ir_validator: IRValidator | None = None) -> None:
        """Store the IR validator used for inline `query_structured` IRs.

        Args:
            ir_validator: Validator to use; a default `IRValidator` is created
                when None (or any falsy value) is given.
        """
        self._ir_validator = ir_validator or IRValidator()

    def validate(
        self,
        task_list: TaskList,
        registry: ToolRegistry,
        catalog: Catalog,
        constraints: Constraints,
    ) -> None:
        """Run all plan checks, raising on the first failure.

        Order of checks: non-empty / task cap, unique task ids, then per task
        (tool exists, args keys, source_id exists, inline IR, checkable
        success_criteria), then the dependency graph, then parallelism.

        Args:
            task_list: The plan to validate.
            registry: Tool registry providing `names()` and `get(name)`.
            catalog: Full catalog; its `sources` supply the known source ids.
            constraints: Supplies `max_tasks`.

        Raises:
            PlannerValidationError: If any check fails.
        """
        tasks = task_list.tasks

        # Check 6 — plan non-empty and within the task cap.
        if not tasks:
            raise PlannerValidationError("plan is empty: at least one task is required")
        if len(tasks) > constraints.max_tasks:
            raise PlannerValidationError(
                f"plan has {len(tasks)} tasks, exceeds max_tasks={constraints.max_tasks}"
            )

        # Single pass: an id seen a second time is a duplicate.
        id_set: set[str] = set()
        repeated: set[str] = set()
        for task in tasks:
            if task.id in id_set:
                repeated.add(task.id)
            else:
                id_set.add(task.id)
        if repeated:
            dupes = sorted(repeated)
            raise PlannerValidationError(f"duplicate task id(s): {dupes}")
        tasks_by_id = {t.id: t for t in tasks}

        known_tools = registry.names()
        known_sources = {s.source_id for s in catalog.sources}

        for task in tasks:
            for call in task.tool_calls:
                # Check 2 — every tool exists in the registry.
                if call.tool not in known_tools:
                    raise PlannerValidationError(
                        f"task {task.id}: tool {call.tool!r} not in registry "
                        f"(known: {sorted(known_tools)})"
                    )
                spec = registry.get(call.tool)
                assert spec is not None  # guaranteed by the membership check above

                # Check 8a — args carry the required keys and no unknown keys.
                required = set(spec.input_schema.get("required", []))
                allowed = set(spec.input_schema.get("properties", {}).keys()) | required
                missing = required - set(call.args.keys())
                if missing:
                    raise PlannerValidationError(
                        f"task {task.id}: tool {call.tool!r} missing required arg(s): "
                        f"{sorted(missing)}"
                    )
                unknown = set(call.args.keys()) - allowed
                if unknown:
                    raise PlannerValidationError(
                        f"task {task.id}: tool {call.tool!r} has unknown arg(s): "
                        f"{sorted(unknown)} (allowed: {sorted(allowed)})"
                    )

                # Check 3 — concrete source_id args must exist in the catalog.
                # A source_id that is itself a placeholder is skipped here.
                src = call.args.get("source_id")
                if isinstance(src, str) and not _is_placeholder(src):
                    if src not in known_sources:
                        raise PlannerValidationError(
                            f"task {task.id}: tool {call.tool!r} references unknown "
                            f"source_id {src!r} (known: {sorted(known_sources)})"
                        )

                # Check 8b — inline query_structured IR validates against the catalog.
                if call.tool == "query_structured":
                    self._validate_inline_ir(task.id, call.args, catalog)

            # Check 7 — success_criteria is checkable.
            if not _is_checkable(task.success_criteria):
                raise PlannerValidationError(
                    f"task {task.id}: success_criteria is not checkable — include a "
                    f"measurable signal (one of {list(_CHECKABLE_TOKENS)}); "
                    f"got {task.success_criteria!r}"
                )

        # Check 4 — DAG: edges resolve, placeholders resolve, no cycles.
        self._validate_dag(tasks_by_id, id_set)

        # Check 5 — parallelizable_with is consistent with the dependency graph.
        # Runs after the DAG check, so the graph is known to be acyclic here.
        self._validate_parallelism(tasks_by_id, id_set)

    def _validate_inline_ir(self, task_id: str, args: dict, catalog: Catalog) -> None:
        """Validate the `ir` argument of a `query_structured` call.

        The value must be a dict, must parse as a `QueryIR`, and must pass the
        injected IR validator against `catalog`.

        Raises:
            PlannerValidationError: If any of the three steps fails; parse and
                catalog-validation errors are chained as the cause.
        """
        raw_ir = args.get("ir")
        if not isinstance(raw_ir, dict):
            raise PlannerValidationError(
                f"task {task_id}: query_structured.args.ir must be an inline QueryIR "
                f"object, got {type(raw_ir).__name__}"
            )
        try:
            ir = QueryIR.model_validate(raw_ir)
        except ValidationError as e:
            raise PlannerValidationError(
                f"task {task_id}: query_structured.args.ir is not a valid QueryIR: {e}"
            ) from e
        try:
            self._ir_validator.validate(ir, catalog)
        except IRValidationError as e:
            raise PlannerValidationError(
                f"task {task_id}: query_structured IR failed catalog validation: {e}"
            ) from e

    @staticmethod
    def _validate_dag(tasks_by_id: dict, id_set: set[str]) -> None:
        """Check `depends_on` edges, placeholder references, and acyclicity.

        Raises:
            PlannerValidationError: On an unknown or self dependency, a
                placeholder that names an unknown task or a task missing from
                `depends_on`, or a dependency cycle.
        """
        for task in tasks_by_id.values():
            for dep in task.depends_on:
                if dep not in id_set:
                    raise PlannerValidationError(
                        f"task {task.id}: depends_on references unknown task {dep!r}"
                    )
                if dep == task.id:
                    raise PlannerValidationError(
                        f"task {task.id}: depends_on includes itself"
                    )
            # Placeholders must reference an existing, declared dependency.
            for ref in _placeholder_refs(task):
                if ref not in id_set:
                    raise PlannerValidationError(
                        f"task {task.id}: placeholder '${{{ref}}}' references unknown task"
                    )
                if ref not in task.depends_on:
                    raise PlannerValidationError(
                        f"task {task.id}: placeholder '${{{ref}}}' used but {ref!r} is "
                        f"not in depends_on"
                    )

        cycle = _find_cycle(tasks_by_id)
        if cycle:
            raise PlannerValidationError(f"cycle detected in depends_on: {' -> '.join(cycle)}")

    @staticmethod
    def _validate_parallelism(tasks_by_id: dict, id_set: set[str]) -> None:
        """Check that `parallelizable_with` never contradicts the dependency graph.

        Raises:
            PlannerValidationError: If a task lists an unknown task, itself, or
                a task that is a (transitive) dependency in either direction.
        """
        ancestors = _all_ancestors(tasks_by_id)
        for task in tasks_by_id.values():
            for other in task.parallelizable_with:
                if other not in id_set:
                    raise PlannerValidationError(
                        f"task {task.id}: parallelizable_with references unknown task "
                        f"{other!r}"
                    )
                if other == task.id:
                    raise PlannerValidationError(
                        f"task {task.id}: parallelizable_with includes itself"
                    )
                if other in ancestors[task.id] or task.id in ancestors[other]:
                    raise PlannerValidationError(
                        f"task {task.id}: parallelizable_with {other!r} conflicts with a "
                        f"(transitive) depends_on relationship between them"
                    )
186
+
187
+
188
def _is_placeholder(value: str) -> bool:
    """Return True when `value`, stripped of surrounding whitespace, is exactly one placeholder.

    Uses a full match against `_PLACEHOLDER_RE`, so a string that merely
    contains a placeholder alongside other text is not a placeholder.
    """
    return bool(_PLACEHOLDER_RE.fullmatch(value.strip()))
190
+
191
+
192
def _placeholder_refs(task) -> set[str]:
    """Collect the placeholder references used in a task's tool-call arguments.

    For every tool call, each argument value that is a string is scanned with
    `_PLACEHOLDER_RE`, and the captured group of every match is collected.

    NOTE(review): only top-level string values are scanned; placeholders
    nested inside dict or list arguments are not seen — confirm that this is
    intended.
    """
    refs: set[str] = set()
    for call in task.tool_calls:
        for value in call.args.values():
            if isinstance(value, str):
                refs.update(_PLACEHOLDER_RE.findall(value))
    return refs
199
+
200
+
201
def _is_checkable(text: str) -> bool:
    """Return True when `text` contains any of `_CHECKABLE_TOKENS`.

    The test is a case-insensitive substring match, so a token also matches
    inside a longer word (for example "count" inside "row_count").
    """
    low = text.lower()
    return any(tok in low for tok in _CHECKABLE_TOKENS)
204
+
205
+
206
def _find_cycle(tasks_by_id: dict) -> list[str] | None:
    """Return one dependency cycle as a list of task ids, or None if there is none.

    Depth-first search over `depends_on` edges with three-color marking. A
    returned cycle starts and ends with the same id (e.g. `["a", "b", "a"]`).
    Dependencies that are not keys of `tasks_by_id` are ignored.
    """
    color = {tid: _WHITE for tid in tasks_by_id}
    # Ids on the current DFS path, in visit order.
    stack: list[str] = []

    def dfs(node: str) -> list[str] | None:
        color[node] = _GREY
        stack.append(node)
        for dep in tasks_by_id[node].depends_on:
            # A grey dependency is still on the current path: back edge, so
            # the path from that dependency onward forms a cycle.
            if color.get(dep) == _GREY:
                idx = stack.index(dep)
                return stack[idx:] + [dep]
            if color.get(dep) == _WHITE:
                found = dfs(dep)
                if found:
                    return found
        stack.pop()
        color[node] = _BLACK
        return None

    for tid in tasks_by_id:
        if color[tid] == _WHITE:
            found = dfs(tid)
            if found:
                return found
    return None
231
+
232
+
233
def _all_ancestors(tasks_by_id: dict) -> dict[str, set[str]]:
    """ancestors[id] = all tasks reachable by following depends_on edges.

    Dependencies that are not keys of `tasks_by_id` are ignored, and a task is
    never listed among its own ancestors. Each task is traversed
    independently, so the result does not depend on iteration order even if
    the graph contains a cycle.
    """
    # Direct dependencies restricted to known task ids.
    direct = {
        tid: [dep for dep in task.depends_on if dep in tasks_by_id]
        for tid, task in tasks_by_id.items()
    }

    ancestors: dict[str, set[str]] = {}
    for tid in tasks_by_id:
        reached: set[str] = set()
        frontier = list(direct[tid])
        while frontier:
            node = frontier.pop()
            if node == tid or node in reached:
                continue
            reached.add(node)
            frontier.extend(direct[node])
        ancestors[tid] = reached
    return ancestors