zequn-fireworks committed on
Commit
d5f3acc
·
1 Parent(s): 8a068d0

Add structured execution trace to A2A task responses

Browse files

Introduces opt-in execution trace data (contracts_called, planner_steps,
cost_metrics) returned as a DataPart artifact. Controlled per-request via
metadata.include_trace or server-wide via FIREACTION_TRACE_ENABLED env var.
Supports eval scoring and production debugging.

Made-with: Cursor

src/fireaction_a2a/agent.py CHANGED
@@ -12,7 +12,7 @@ from typing import Any
12
  from a2a.server.agent_execution import AgentExecutor, RequestContext
13
  from a2a.server.events import EventQueue
14
  from a2a.server.tasks import TaskUpdater
15
- from a2a.types import Part, TaskState, TextPart
16
  from a2a.utils import new_agent_text_message, new_task
17
 
18
  from fireaction_a2a.planner import Planner
@@ -23,10 +23,28 @@ logger = logging.getLogger(__name__)
23
  class ProviderAgentExecutor(AgentExecutor):
24
  """A2A executor that delegates to a :class:`Planner`."""
25
 
26
- def __init__(self, planner: Planner) -> None:
 
 
 
 
 
27
  self.planner = planner
 
28
  self.conversations: dict[str, list[dict[str, Any]]] = {}
29
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  async def execute(
31
  self,
32
  context: RequestContext,
@@ -41,12 +59,15 @@ class ProviderAgentExecutor(AgentExecutor):
41
 
42
  updater = TaskUpdater(event_queue, task.id, task.context_id)
43
  history = self.conversations.get(task.context_id, [])
 
44
 
45
  await updater.start_work(
46
  new_agent_text_message("Planning...", task.context_id, task.id)
47
  )
48
 
49
- async for step in self.planner.run(user_input, history):
 
 
50
  if step.type == "status":
51
  await updater.update_status(
52
  TaskState.working,
@@ -65,10 +86,20 @@ class ProviderAgentExecutor(AgentExecutor):
65
  [Part(root=TextPart(text=step.message))],
66
  name="result",
67
  )
 
 
 
 
 
68
  await updater.complete()
69
  break
70
 
71
  elif step.type == "error":
 
 
 
 
 
72
  await updater.failed(
73
  new_agent_text_message(step.message, task.context_id, task.id),
74
  )
 
12
  from a2a.server.agent_execution import AgentExecutor, RequestContext
13
  from a2a.server.events import EventQueue
14
  from a2a.server.tasks import TaskUpdater
15
+ from a2a.types import DataPart, Part, TaskState, TextPart
16
  from a2a.utils import new_agent_text_message, new_task
17
 
18
  from fireaction_a2a.planner import Planner
 
23
  class ProviderAgentExecutor(AgentExecutor):
24
  """A2A executor that delegates to a :class:`Planner`."""
25
 
26
def __init__(
    self,
    planner: Planner,
    *,
    trace_enabled_default: bool = False,
) -> None:
    """Store the planner, the trace default, and an empty conversation map.

    Args:
        planner: Planner this executor delegates to.
        trace_enabled_default: Value ``_resolve_trace_flag`` falls back to
            when a request carries no ``include_trace`` metadata.
    """
    # Conversation history, looked up by ``task.context_id`` in ``execute``.
    self.conversations: dict[str, list[dict[str, Any]]] = {}
    self.trace_enabled_default = trace_enabled_default
    self.planner = planner
35
 
36
def _resolve_trace_flag(self, context: RequestContext) -> bool:
    """Decide whether this request should return an execution trace.

    Per-request metadata overrides the server-wide default. The current
    task's metadata is consulted first, then the incoming message's
    metadata; the first one that contains an ``include_trace`` key wins.
    Metadata that is present but lacks the key no longer hides a flag set
    on the other source.

    Args:
        context: Request context exposing ``current_task`` and ``message``.

    Returns:
        The requested flag, or ``self.trace_enabled_default`` when neither
        metadata source carries ``include_trace``.
    """
    for source in (context.current_task, context.message):
        # ``source`` may be None (no task yet) and ``metadata`` may be
        # missing or None; both cases simply fall through.
        meta = getattr(source, "metadata", None)
        if meta and "include_trace" in meta:
            value = meta["include_trace"]
            if isinstance(value, str):
                # bool("false") is True, so explicit negative strings
                # must be recognised by hand.
                return value.strip().lower() not in ("", "0", "false", "no", "off")
            return bool(value)
    return self.trace_enabled_default
47
+
48
  async def execute(
49
  self,
50
  context: RequestContext,
 
59
 
60
  updater = TaskUpdater(event_queue, task.id, task.context_id)
61
  history = self.conversations.get(task.context_id, [])
62
+ trace_enabled = self._resolve_trace_flag(context)
63
 
64
  await updater.start_work(
65
  new_agent_text_message("Planning...", task.context_id, task.id)
66
  )
67
 
68
+ async for step in self.planner.run(
69
+ user_input, history, trace_enabled=trace_enabled
70
+ ):
71
  if step.type == "status":
72
  await updater.update_status(
73
  TaskState.working,
 
86
  [Part(root=TextPart(text=step.message))],
87
  name="result",
88
  )
89
+ if step.trace_data:
90
+ await updater.add_artifact(
91
+ [Part(root=DataPart(data=step.trace_data))],
92
+ name="execution_trace",
93
+ )
94
  await updater.complete()
95
  break
96
 
97
  elif step.type == "error":
98
+ if step.trace_data:
99
+ await updater.add_artifact(
100
+ [Part(root=DataPart(data=step.trace_data))],
101
+ name="execution_trace",
102
+ )
103
  await updater.failed(
104
  new_agent_text_message(step.message, task.context_id, task.id),
105
  )
src/fireaction_a2a/planner.py CHANGED
@@ -14,6 +14,7 @@ from __future__ import annotations
14
  import json
15
  import logging
16
  import os
 
17
  from collections.abc import AsyncGenerator
18
  from dataclasses import dataclass
19
  from typing import Any, Literal
@@ -145,6 +146,7 @@ TOOLS: list[dict[str, Any]] = [
145
  class PlanStep:
146
  type: Literal["status", "completed", "input_required", "error"]
147
  message: str
 
148
 
149
 
150
  # ---------------------------------------------------------------------------
@@ -161,18 +163,22 @@ class Planner:
161
  contract_search: ContractSearch,
162
  llm_model: str,
163
  system_prompt: str,
 
164
  ) -> None:
165
  self.provider = provider
166
  self.provider_client = provider_client
167
  self.search = contract_search
168
  self.model = llm_model
169
  self.system_prompt = system_prompt
 
170
  self._last_messages: list[dict[str, Any]] = []
171
 
172
  async def run(
173
  self,
174
  user_message: str,
175
  history: list[dict[str, Any]],
 
 
176
  ) -> AsyncGenerator[PlanStep, None]:
177
  """Execute the planner loop. Yields PlanStep events."""
178
 
@@ -185,6 +191,37 @@ class Planner:
185
  retry_counts: dict[str, int] = {}
186
  total_steps = 0
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  while total_steps < MAX_TOTAL_STEPS:
189
  try:
190
  response = await litellm.acompletion(
@@ -194,9 +231,27 @@ class Planner:
194
  )
195
  except Exception:
196
  logger.exception("LLM call failed")
197
- yield PlanStep(type="error", message="LLM call failed")
 
 
 
 
198
  break
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  choice = response.choices[0]
201
  assistant_msg = choice.message
202
 
@@ -208,7 +263,11 @@ class Planner:
208
  if choice.finish_reason == "stop" or not getattr(assistant_msg, "tool_calls", None):
209
  content = assistant_msg.content or ""
210
  step_type = self._classify_response(content)
211
- yield PlanStep(type=step_type, message=content)
 
 
 
 
212
  break
213
 
214
  # Process tool calls
@@ -220,15 +279,41 @@ class Planner:
220
  except json.JSONDecodeError:
221
  args = {}
222
 
223
- result = await self._dispatch_tool(name, args, retry_counts)
 
 
 
 
 
 
224
 
225
  if isinstance(result, PlanStep):
226
- yield result
 
 
 
 
 
 
227
  if result.type == "error":
 
 
 
 
 
228
  self._last_messages = messages
229
  return
 
230
  continue
231
 
 
 
 
 
 
 
 
 
232
  # Yield status for successful executions
233
  if name == "execute_contract" and not (isinstance(result, dict) and result.get("error")):
234
  yield PlanStep(
@@ -246,6 +331,7 @@ class Planner:
246
  yield PlanStep(
247
  type="error",
248
  message=f"Exceeded maximum of {MAX_TOTAL_STEPS} tool calls.",
 
249
  )
250
 
251
  self._last_messages = messages
@@ -263,6 +349,9 @@ class Planner:
263
  name: str,
264
  args: dict[str, Any],
265
  retry_counts: dict[str, int],
 
 
 
266
  ) -> dict[str, Any] | list[dict[str, Any]] | PlanStep:
267
  """Route a tool call to the appropriate handler."""
268
 
@@ -286,11 +375,50 @@ class Planner:
286
 
287
  contract = self.provider.contracts[contract_name]
288
  try:
289
- result = await execute_contract(contract, instance_nodes, self.provider_client)
 
 
 
 
 
290
  retry_counts.pop(contract_name, None)
291
- return result
 
 
 
 
 
 
 
 
 
 
292
  except ContractError as exc:
293
  retry_counts[contract_name] = retry_counts.get(contract_name, 0) + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  if retry_counts[contract_name] > MAX_RETRIES_PER_CONTRACT:
295
  return PlanStep(
296
  type="error",
 
14
  import json
15
  import logging
16
  import os
17
+ import time
18
  from collections.abc import AsyncGenerator
19
  from dataclasses import dataclass
20
  from typing import Any, Literal
 
146
  class PlanStep:
147
  type: Literal["status", "completed", "input_required", "error"]
148
  message: str
149
+ trace_data: dict[str, Any] | None = None
150
 
151
 
152
  # ---------------------------------------------------------------------------
 
163
  contract_search: ContractSearch,
164
  llm_model: str,
165
  system_prompt: str,
166
+ provider_name: str = "",
167
  ) -> None:
168
  self.provider = provider
169
  self.provider_client = provider_client
170
  self.search = contract_search
171
  self.model = llm_model
172
  self.system_prompt = system_prompt
173
+ self.provider_name = provider_name
174
  self._last_messages: list[dict[str, Any]] = []
175
 
176
  async def run(
177
  self,
178
  user_message: str,
179
  history: list[dict[str, Any]],
180
+ *,
181
+ trace_enabled: bool = False,
182
  ) -> AsyncGenerator[PlanStep, None]:
183
  """Execute the planner loop. Yields PlanStep events."""
184
 
 
191
  retry_counts: dict[str, int] = {}
192
  total_steps = 0
193
 
194
+ # Trace accumulators (only meaningful when trace_enabled)
195
+ trace_steps: list[dict[str, Any]] = []
196
+ trace_contracts: list[dict[str, Any]] = []
197
+ trace_cost: dict[str, int] = {
198
+ "total_tokens": 0,
199
+ "prompt_tokens": 0,
200
+ "completion_tokens": 0,
201
+ "num_llm_calls": 0,
202
+ }
203
+ start_time = time.monotonic()
204
+
205
+ def _build_trace() -> dict[str, Any] | None:
206
+ if not trace_enabled:
207
+ return None
208
+ return {
209
+ "execution_trace": {
210
+ "contracts_called": trace_contracts,
211
+ "planner_steps": trace_steps,
212
+ "total_api_calls": sum(
213
+ 1 for c in trace_contracts if c.get("api_call")
214
+ ),
215
+ "total_planner_steps": len(trace_steps),
216
+ },
217
+ "cost_metrics": {
218
+ **trace_cost,
219
+ "total_duration_ms": int(
220
+ (time.monotonic() - start_time) * 1000
221
+ ),
222
+ },
223
+ }
224
+
225
  while total_steps < MAX_TOTAL_STEPS:
226
  try:
227
  response = await litellm.acompletion(
 
231
  )
232
  except Exception:
233
  logger.exception("LLM call failed")
234
+ yield PlanStep(
235
+ type="error",
236
+ message="LLM call failed",
237
+ trace_data=_build_trace(),
238
+ )
239
  break
240
 
241
+ if trace_enabled:
242
+ trace_cost["num_llm_calls"] += 1
243
+ usage = getattr(response, "usage", None)
244
+ if usage:
245
+ trace_cost["prompt_tokens"] += (
246
+ getattr(usage, "prompt_tokens", 0) or 0
247
+ )
248
+ trace_cost["completion_tokens"] += (
249
+ getattr(usage, "completion_tokens", 0) or 0
250
+ )
251
+ trace_cost["total_tokens"] += (
252
+ getattr(usage, "total_tokens", 0) or 0
253
+ )
254
+
255
  choice = response.choices[0]
256
  assistant_msg = choice.message
257
 
 
263
  if choice.finish_reason == "stop" or not getattr(assistant_msg, "tool_calls", None):
264
  content = assistant_msg.content or ""
265
  step_type = self._classify_response(content)
266
+ yield PlanStep(
267
+ type=step_type,
268
+ message=content,
269
+ trace_data=_build_trace(),
270
+ )
271
  break
272
 
273
  # Process tool calls
 
279
  except json.JSONDecodeError:
280
  args = {}
281
 
282
+ result = await self._dispatch_tool(
283
+ name,
284
+ args,
285
+ retry_counts,
286
+ trace_enabled=trace_enabled,
287
+ trace_contracts=trace_contracts,
288
+ )
289
 
290
  if isinstance(result, PlanStep):
291
+ if trace_enabled:
292
+ trace_steps.append({
293
+ "step": len(trace_steps) + 1,
294
+ "tool": name,
295
+ "input": args,
296
+ "output": {"error": result.message},
297
+ })
298
  if result.type == "error":
299
+ yield PlanStep(
300
+ type="error",
301
+ message=result.message,
302
+ trace_data=_build_trace(),
303
+ )
304
  self._last_messages = messages
305
  return
306
+ yield result
307
  continue
308
 
309
+ if trace_enabled:
310
+ trace_steps.append({
311
+ "step": len(trace_steps) + 1,
312
+ "tool": name,
313
+ "input": args,
314
+ "output": result,
315
+ })
316
+
317
  # Yield status for successful executions
318
  if name == "execute_contract" and not (isinstance(result, dict) and result.get("error")):
319
  yield PlanStep(
 
331
  yield PlanStep(
332
  type="error",
333
  message=f"Exceeded maximum of {MAX_TOTAL_STEPS} tool calls.",
334
+ trace_data=_build_trace(),
335
  )
336
 
337
  self._last_messages = messages
 
349
  name: str,
350
  args: dict[str, Any],
351
  retry_counts: dict[str, int],
352
+ *,
353
+ trace_enabled: bool = False,
354
+ trace_contracts: list[dict[str, Any]] | None = None,
355
  ) -> dict[str, Any] | list[dict[str, Any]] | PlanStep:
356
  """Route a tool call to the appropriate handler."""
357
 
 
375
 
376
  contract = self.provider.contracts[contract_name]
377
  try:
378
+ contract_result = await execute_contract(
379
+ contract,
380
+ instance_nodes,
381
+ self.provider_client,
382
+ trace_enabled=trace_enabled,
383
+ )
384
  retry_counts.pop(contract_name, None)
385
+
386
+ if trace_enabled and trace_contracts is not None and contract_result.trace:
387
+ trace_contracts.append({
388
+ "provider": self.provider_name,
389
+ "action": contract_name,
390
+ "version": contract.metadata.get("version", ""),
391
+ "instance_nodes": instance_nodes,
392
+ **contract_result.trace,
393
+ })
394
+
395
+ return contract_result.response
396
  except ContractError as exc:
397
  retry_counts[contract_name] = retry_counts.get(contract_name, 0) + 1
398
+
399
+ if trace_enabled and trace_contracts is not None:
400
+ entry: dict[str, Any] = {
401
+ "provider": self.provider_name,
402
+ "action": contract_name,
403
+ "version": contract.metadata.get("version", ""),
404
+ "instance_nodes": instance_nodes,
405
+ }
406
+ if exc.trace:
407
+ entry.update(exc.trace)
408
+ else:
409
+ entry.update({
410
+ "validation": {
411
+ "verify_passed": False,
412
+ "properties_passed": False,
413
+ "properties_failed": [],
414
+ "rules_passed": False,
415
+ "rules_failed": [],
416
+ },
417
+ "compiled_payload": None,
418
+ "api_call": None,
419
+ })
420
+ trace_contracts.append(entry)
421
+
422
  if retry_counts[contract_name] > MAX_RETRIES_PER_CONTRACT:
423
  return PlanStep(
424
  type="error",
src/fireaction_a2a/runner.py CHANGED
@@ -6,6 +6,8 @@ then delegates the actual API call to a ProviderClient.
6
 
7
  from __future__ import annotations
8
 
 
 
9
  from typing import Any
10
 
11
  from fireaction.contract import EndpointContract
@@ -13,16 +15,32 @@ from fireaction.contract import EndpointContract
13
  from fireaction_a2a.client import ProviderClient
14
 
15
 
 
 
 
 
 
 
 
 
16
  class ContractError(Exception):
17
  """A contract validation step failed.
18
 
19
  Carries structured detail so the planner can feed the errors back to
20
- the LLM for retry.
 
21
  """
22
 
23
- def __init__(self, stage: str, details: list[Any]) -> None:
 
 
 
 
 
 
24
  self.stage = stage
25
  self.details = details
 
26
  super().__init__(f"Contract validation failed at '{stage}': {details}")
27
 
28
  def to_dict(self) -> dict[str, Any]:
@@ -33,31 +51,74 @@ async def execute_contract(
33
  contract: EndpointContract,
34
  instance_nodes: list[dict[str, Any]],
35
  provider_client: ProviderClient,
36
- ) -> dict[str, Any]:
 
 
37
  """Run the full contract lifecycle and make the API call.
38
 
 
 
 
39
  Raises:
40
- ContractError: If any validation step fails (instantiate, verify,
41
- check_properties, or check_rules). The error is structured so
42
- the planner can serialize it and feed it back to the LLM.
43
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  results = contract.instantiate(instance_nodes)
45
  failed = [r for r in results if not r["success"]]
46
  if failed:
47
- raise ContractError("instantiate", failed)
48
 
49
  verify_errors = contract.verify()
50
  if verify_errors:
51
- raise ContractError("verify", verify_errors)
 
 
52
 
53
  prop_errors = contract.check_properties()
54
  if prop_errors:
55
- raise ContractError("check_properties", prop_errors)
 
 
 
 
56
 
57
  payload = contract.compile()
 
 
58
 
59
  rule_errors = contract.check_rules(payload)
60
  if rule_errors:
61
- raise ContractError("check_rules", rule_errors)
62
-
63
- return await provider_client.call(contract.endpoint_info, payload)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  from __future__ import annotations
8
 
9
+ import time
10
+ from dataclasses import dataclass
11
  from typing import Any
12
 
13
  from fireaction.contract import EndpointContract
 
15
  from fireaction_a2a.client import ProviderClient
16
 
17
 
18
@dataclass
class ContractResult:
    """Pairs a contract execution's response with its optional trace.

    ``execute_contract`` returns one of these; ``trace`` stays ``None``
    unless tracing was requested for that call.
    """

    # Value returned by the provider client call.
    response: dict[str, Any]
    # Validation / payload / API-call details; None when tracing is off.
    trace: dict[str, Any] | None = None
24
+
25
+
26
  class ContractError(Exception):
27
  """A contract validation step failed.
28
 
29
  Carries structured detail so the planner can feed the errors back to
30
+ the LLM for retry. When tracing is enabled, ``trace`` contains partial
31
+ execution data up to the point of failure.
32
  """
33
 
34
def __init__(
    self,
    stage: str,
    details: list[Any],
    *,
    trace: dict[str, Any] | None = None,
) -> None:
    """Record the failing stage, its details, and any partial trace.

    Args:
        stage: Name of the validation stage that failed.
        details: Structured failure details for that stage.
        trace: Partial execution trace gathered before the failure, or
            ``None`` when tracing was not enabled.
    """
    super().__init__(f"Contract validation failed at '{stage}': {details}")
    self.stage, self.details, self.trace = stage, details, trace
45
 
46
  def to_dict(self) -> dict[str, Any]:
 
51
  contract: EndpointContract,
52
  instance_nodes: list[dict[str, Any]],
53
  provider_client: ProviderClient,
54
+ *,
55
+ trace_enabled: bool = False,
56
+ ) -> ContractResult:
57
  """Run the full contract lifecycle and make the API call.
58
 
59
+ When *trace_enabled* is ``True``, captures per-stage validation results,
60
+ the compiled payload, and API call details in ``ContractResult.trace``.
61
+
62
  Raises:
63
+ ContractError: If any validation step fails. When tracing, the
64
+ exception carries partial trace data via its ``trace`` attribute.
 
65
  """
66
+ trace: dict[str, Any] | None = None
67
+ if trace_enabled:
68
+ trace = {
69
+ "validation": {
70
+ "verify_passed": False,
71
+ "properties_passed": False,
72
+ "properties_failed": [],
73
+ "rules_passed": False,
74
+ "rules_failed": [],
75
+ },
76
+ "compiled_payload": None,
77
+ "api_call": None,
78
+ }
79
+
80
  results = contract.instantiate(instance_nodes)
81
  failed = [r for r in results if not r["success"]]
82
  if failed:
83
+ raise ContractError("instantiate", failed, trace=trace)
84
 
85
  verify_errors = contract.verify()
86
  if verify_errors:
87
+ raise ContractError("verify", verify_errors, trace=trace)
88
+ if trace:
89
+ trace["validation"]["verify_passed"] = True
90
 
91
  prop_errors = contract.check_properties()
92
  if prop_errors:
93
+ if trace:
94
+ trace["validation"]["properties_failed"] = [str(e) for e in prop_errors]
95
+ raise ContractError("check_properties", prop_errors, trace=trace)
96
+ if trace:
97
+ trace["validation"]["properties_passed"] = True
98
 
99
  payload = contract.compile()
100
+ if trace:
101
+ trace["compiled_payload"] = payload
102
 
103
  rule_errors = contract.check_rules(payload)
104
  if rule_errors:
105
+ if trace:
106
+ trace["validation"]["rules_failed"] = [str(e) for e in rule_errors]
107
+ raise ContractError("check_rules", rule_errors, trace=trace)
108
+ if trace:
109
+ trace["validation"]["rules_passed"] = True
110
+
111
+ start = time.monotonic()
112
+ response = await provider_client.call(contract.endpoint_info, payload)
113
+ duration_ms = int((time.monotonic() - start) * 1000)
114
+
115
+ if trace:
116
+ trace["api_call"] = {
117
+ "method": contract.endpoint_info.get("method", ""),
118
+ "path": contract.endpoint_info.get("path", ""),
119
+ "status_code": 200,
120
+ "response_body": response,
121
+ "duration_ms": duration_ms,
122
+ }
123
+
124
+ return ContractResult(response=response, trace=trace)
src/fireaction_a2a/server.py CHANGED
@@ -69,6 +69,7 @@ async def _build_provider_app(
69
  llm_model: str,
70
  embedding_model: str,
71
  path_prefix: str = "",
 
72
  ) -> A2AStarletteApplication:
73
  """Build an A2A app for a single provider."""
74
  provider = load_provider(provider_name)
@@ -89,8 +90,11 @@ async def _build_provider_app(
89
  contract_search=contract_search,
90
  llm_model=llm_model,
91
  system_prompt=system_prompt,
 
 
 
 
92
  )
93
- executor = ProviderAgentExecutor(planner)
94
  card = build_card(provider_name, provider, host, port, path_prefix=path_prefix)
95
 
96
  handler = DefaultRequestHandler(
@@ -113,9 +117,11 @@ async def create_app(
113
  llm_model = os.environ.get("FIREACTION_LLM_MODEL", "gpt-4o")
114
  embedding_model = os.environ.get("FIREACTION_EMBEDDING_MODEL", "text-embedding-3-small")
115
  api_key = os.environ.get("FIREACTION_API_KEY", "")
 
116
 
117
  a2a_app = await _build_provider_app(
118
  provider_name, host, port, llm_model, embedding_model,
 
119
  )
120
  app = a2a_app.build()
121
 
@@ -142,6 +148,7 @@ async def create_multi_app(
142
  llm_model = os.environ.get("FIREACTION_LLM_MODEL", "gpt-4o")
143
  embedding_model = os.environ.get("FIREACTION_EMBEDDING_MODEL", "text-embedding-3-small")
144
  api_key = os.environ.get("FIREACTION_API_KEY", "")
 
145
 
146
  mounts: list[Mount] = []
147
  agent_directory: list[dict] = []
@@ -152,6 +159,7 @@ async def create_multi_app(
152
  a2a_app = await _build_provider_app(
153
  name, host, port, llm_model, embedding_model,
154
  path_prefix=prefix,
 
155
  )
156
  sub_app = a2a_app.build()
157
  mounts.append(Mount(prefix, app=sub_app))
 
69
  llm_model: str,
70
  embedding_model: str,
71
  path_prefix: str = "",
72
+ trace_enabled_default: bool = False,
73
  ) -> A2AStarletteApplication:
74
  """Build an A2A app for a single provider."""
75
  provider = load_provider(provider_name)
 
90
  contract_search=contract_search,
91
  llm_model=llm_model,
92
  system_prompt=system_prompt,
93
+ provider_name=provider_name,
94
+ )
95
+ executor = ProviderAgentExecutor(
96
+ planner, trace_enabled_default=trace_enabled_default,
97
  )
 
98
  card = build_card(provider_name, provider, host, port, path_prefix=path_prefix)
99
 
100
  handler = DefaultRequestHandler(
 
117
  llm_model = os.environ.get("FIREACTION_LLM_MODEL", "gpt-4o")
118
  embedding_model = os.environ.get("FIREACTION_EMBEDDING_MODEL", "text-embedding-3-small")
119
  api_key = os.environ.get("FIREACTION_API_KEY", "")
120
+ trace_enabled = os.environ.get("FIREACTION_TRACE_ENABLED", "").lower() in ("1", "true", "yes")
121
 
122
  a2a_app = await _build_provider_app(
123
  provider_name, host, port, llm_model, embedding_model,
124
+ trace_enabled_default=trace_enabled,
125
  )
126
  app = a2a_app.build()
127
 
 
148
  llm_model = os.environ.get("FIREACTION_LLM_MODEL", "gpt-4o")
149
  embedding_model = os.environ.get("FIREACTION_EMBEDDING_MODEL", "text-embedding-3-small")
150
  api_key = os.environ.get("FIREACTION_API_KEY", "")
151
+ trace_enabled = os.environ.get("FIREACTION_TRACE_ENABLED", "").lower() in ("1", "true", "yes")
152
 
153
  mounts: list[Mount] = []
154
  agent_directory: list[dict] = []
 
159
  a2a_app = await _build_provider_app(
160
  name, host, port, llm_model, embedding_model,
161
  path_prefix=prefix,
162
+ trace_enabled_default=trace_enabled,
163
  )
164
  sub_app = a2a_app.build()
165
  mounts.append(Mount(prefix, app=sub_app))
tests/test_agent.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for ProviderAgentExecutor: trace flag resolution and artifact emission."""
2
+
3
+ import asyncio
4
+ from unittest.mock import AsyncMock, MagicMock, patch
5
+
6
+ from fireaction_a2a.agent import ProviderAgentExecutor
7
+ from fireaction_a2a.planner import PlanStep
8
+
9
+
10
+ # ---- _resolve_trace_flag tests ----
11
+
12
+
13
def _make_executor(*, trace_default: bool = False) -> ProviderAgentExecutor:
    """Return an executor wrapping a mock planner with the given trace default."""
    return ProviderAgentExecutor(MagicMock(), trace_enabled_default=trace_default)
16
+
17
+
18
def _make_context(*, task_metadata=None, message_metadata=None, has_task=True):
    """Build a mock request context with configurable metadata.

    ``current_task`` is a mock carrying ``task_metadata`` when ``has_task``
    is truthy and ``None`` otherwise; ``message`` is always a mock carrying
    ``message_metadata``.
    """
    task = MagicMock(metadata=task_metadata) if has_task else None
    message = MagicMock(metadata=message_metadata)
    return MagicMock(current_task=task, message=message)
28
+
29
+
30
def test_resolve_trace_flag_default_false():
    """Empty task metadata with a False server default resolves to False."""
    resolved = _make_executor(trace_default=False)._resolve_trace_flag(
        _make_context(task_metadata={})
    )
    assert resolved is False
34
+
35
+
36
def test_resolve_trace_flag_default_true():
    """Empty task metadata with a True server default resolves to True."""
    resolved = _make_executor(trace_default=True)._resolve_trace_flag(
        _make_context(task_metadata={})
    )
    assert resolved is True
40
+
41
+
42
def test_resolve_trace_flag_task_metadata_overrides_default():
    """``include_trace: True`` in task metadata beats a False server default."""
    resolved = _make_executor(trace_default=False)._resolve_trace_flag(
        _make_context(task_metadata={"include_trace": True})
    )
    assert resolved is True
46
+
47
+
48
def test_resolve_trace_flag_task_metadata_disables():
    """``include_trace: False`` in task metadata beats a True server default."""
    resolved = _make_executor(trace_default=True)._resolve_trace_flag(
        _make_context(task_metadata={"include_trace": False})
    )
    assert resolved is False
52
+
53
+
54
def test_resolve_trace_flag_message_metadata_fallback():
    """Without a current task, the flag is read from the message metadata."""
    resolved = _make_executor(trace_default=False)._resolve_trace_flag(
        _make_context(has_task=False, message_metadata={"include_trace": True})
    )
    assert resolved is True
58
+
59
+
60
def test_resolve_trace_flag_no_metadata():
    """``None`` metadata everywhere falls back to the server default."""
    resolved = _make_executor(trace_default=False)._resolve_trace_flag(
        _make_context(task_metadata=None)
    )
    assert resolved is False
64
+
65
+
66
+ # ---- execute() artifact emission tests ----
67
+
68
+
69
def test_execute_emits_trace_artifact_when_trace_data_present():
    """When planner yields a completed step with trace_data, agent emits DataPart."""
    mock_planner = MagicMock()
    trace_payload = {
        "execution_trace": {"contracts_called": [], "planner_steps": [],
                            "total_api_calls": 0, "total_planner_steps": 0},
        "cost_metrics": {"total_tokens": 100, "prompt_tokens": 60,
                         "completion_tokens": 40, "num_llm_calls": 1,
                         "total_duration_ms": 500},
    }

    async def mock_run(user_msg, history, *, trace_enabled=False):
        yield PlanStep(type="completed", message="Done!", trace_data=trace_payload)

    mock_planner.run = mock_run
    mock_planner.get_messages.return_value = []

    executor = ProviderAgentExecutor(mock_planner, trace_enabled_default=True)

    context = MagicMock()
    context.get_user_input.return_value = "send email"
    context.current_task = MagicMock()
    context.current_task.metadata = {}
    context.current_task.id = "task_1"
    context.current_task.context_id = "ctx_1"
    context.message = MagicMock()

    event_queue = MagicMock()
    event_queue.enqueue_event = AsyncMock()

    async def _run():
        await executor.execute(context, event_queue)

    asyncio.run(_run())

    def _parts_of(event):
        # NOTE(review): artifact update events are assumed to nest their
        # parts under ``event.artifact.parts``; ``event.parts`` is checked
        # as a fallback -- confirm against the a2a SDK.
        artifact = getattr(event, "artifact", None)
        return getattr(artifact, "parts", None) or getattr(event, "parts", None) or []

    calls = event_queue.enqueue_event.call_args_list
    data_payloads = [
        part.root.data
        for call in calls
        for part in _parts_of(call[0][0])
        if getattr(getattr(part, "root", None), "kind", None) == "data"
    ]
    assert len(calls) >= 3
    # The original computed a list of artifacts but never asserted on it;
    # pin the behaviour the docstring promises.
    assert trace_payload in data_payloads
110
+
111
+
112
def test_execute_skips_trace_artifact_when_trace_disabled():
    """When trace_enabled is False, no DataPart artifact is emitted."""
    mock_planner = MagicMock()

    async def mock_run(user_msg, history, *, trace_enabled=False):
        yield PlanStep(type="completed", message="Done!", trace_data=None)

    mock_planner.run = mock_run
    mock_planner.get_messages.return_value = []

    executor = ProviderAgentExecutor(mock_planner, trace_enabled_default=False)

    context = MagicMock()
    context.get_user_input.return_value = "send email"
    context.current_task = MagicMock()
    context.current_task.metadata = {}
    context.current_task.id = "task_2"
    context.current_task.context_id = "ctx_2"
    context.message = MagicMock()

    event_queue = MagicMock()
    event_queue.enqueue_event = AsyncMock()

    async def _run():
        await executor.execute(context, event_queue)

    asyncio.run(_run())

    def _parts_of(event):
        # NOTE(review): artifact update events are assumed to nest their
        # parts under ``event.artifact.parts``; looking only at
        # ``event.parts`` would make this test pass vacuously. Both
        # locations are checked -- confirm against the a2a SDK.
        artifact = getattr(event, "artifact", None)
        return getattr(artifact, "parts", None) or getattr(event, "parts", None) or []

    all_events = [c[0][0] for c in event_queue.enqueue_event.call_args_list]
    data_part_events = [
        e for e in all_events
        if any(
            getattr(getattr(p, "root", None), "kind", None) == "data"
            for p in _parts_of(e)
        )
    ]
    assert len(data_part_events) == 0
tests/test_planner_retry.py CHANGED
@@ -1,7 +1,8 @@
1
  """Tests for planner error recovery and retry limits."""
2
 
3
  import asyncio
4
- from unittest.mock import AsyncMock, patch
 
5
 
6
  from fireaction import load_provider
7
 
@@ -20,6 +21,7 @@ def _make_planner() -> Planner:
20
  contract_search=search,
21
  llm_model="gpt-4o",
22
  system_prompt="Test",
 
23
  )
24
 
25
 
@@ -101,8 +103,171 @@ def test_dispatch_execute_exceeds_retry_limit():
101
  assert "failed validation" in result.message.lower() or "failed" in result.message.lower()
102
 
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  def test_classify_response():
105
  assert Planner._classify_response("Done! Email sent successfully.") == "completed"
106
  assert Planner._classify_response("Please specify the recipient.") == "input_required"
107
  assert Planner._classify_response("Could you provide the subject?") == "input_required"
108
  assert Planner._classify_response("Error: API returned 500") == "error"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Tests for planner error recovery and retry limits."""
2
 
3
  import asyncio
4
+ import json
5
+ from unittest.mock import AsyncMock, MagicMock, patch
6
 
7
  from fireaction import load_provider
8
 
 
21
  contract_search=search,
22
  llm_model="gpt-4o",
23
  system_prompt="Test",
24
+ provider_name="resend",
25
  )
26
 
27
 
 
103
  assert "failed validation" in result.message.lower() or "failed" in result.message.lower()
104
 
105
 
106
def test_dispatch_execute_contract_with_trace():
    """A traced, successful execute_contract call records one trace entry."""
    from tests.test_runner import _minimal_email_nodes

    planner = _make_planner()
    planner.provider_client.call.return_value = {"id": "email_traced"}
    recorded: list[dict] = []
    retries: dict[str, int] = {}
    tool_args = {
        "contract_name": "send_email",
        "instance_nodes": _minimal_email_nodes(),
    }

    outcome = asyncio.run(
        planner._dispatch_tool(
            "execute_contract",
            tool_args,
            retries,
            trace_enabled=True,
            trace_contracts=recorded,
        )
    )

    # The tool result is the bare API response, not a ContractResult wrapper.
    assert isinstance(outcome, dict)
    assert outcome == {"id": "email_traced"}
    assert len(recorded) == 1

    traced = recorded[0]
    assert traced["provider"] == "resend"
    assert traced["action"] == "send_email"
    assert traced["validation"]["verify_passed"] is True
    api_call = traced["api_call"]
    assert api_call is not None
    assert api_call["response_body"] == {"id": "email_traced"}
134
+
135
+
136
def test_dispatch_execute_contract_error_with_trace():
    """A traced, failing execute_contract call still records a partial entry."""
    planner = _make_planner()
    recorded: list[dict] = []
    retries: dict[str, int] = {}

    # A single root node whose element key does not exist in the contract.
    invalid_root = {
        "instance_key": "root",
        "element_key": "nonexistent",
        "variant_key": "send_email_skeleton",
        "data_type": "object",
        "compile_key": "root",
        "description": "bad",
        "index": 0,
        "parent_instance_key": None,
    }
    tool_args = {"contract_name": "send_email", "instance_nodes": [invalid_root]}

    outcome = asyncio.run(
        planner._dispatch_tool(
            "execute_contract",
            tool_args,
            retries,
            trace_enabled=True,
            trace_contracts=recorded,
        )
    )

    assert isinstance(outcome, dict)
    assert outcome["error"] is True
    assert len(recorded) == 1

    traced = recorded[0]
    assert traced["provider"] == "resend"
    assert traced["action"] == "send_email"
    # No API call is recorded for the failed validation.
    assert traced["api_call"] is None
165
+
166
+
167
  def test_classify_response():
168
  assert Planner._classify_response("Done! Email sent successfully.") == "completed"
169
  assert Planner._classify_response("Please specify the recipient.") == "input_required"
170
  assert Planner._classify_response("Could you provide the subject?") == "input_required"
171
  assert Planner._classify_response("Error: API returned 500") == "error"
172
+
173
+
174
+ # ---- Planner run() integration tests with trace ----
175
+
176
+
177
def _mock_llm_response(*, tool_calls=None, content=None, finish_reason="stop",
                       prompt_tokens=100, completion_tokens=50):
    """Build a mock litellm response.

    The returned ``MagicMock`` exposes a single entry in ``choices`` (with
    ``message`` and ``finish_reason``) and a ``usage`` object carrying
    ``prompt_tokens``, ``completion_tokens`` and their sum as ``total_tokens``.
    ``message.model_dump()`` returns an assistant-message dict that includes a
    ``tool_calls`` list only when ``tool_calls`` is truthy.
    """
    dumped = {"role": "assistant", "content": content}
    if tool_calls:
        serialized_calls = []
        for call in tool_calls:
            serialized_calls.append({
                "id": call.id,
                "type": "function",
                "function": {
                    "name": call.function.name,
                    "arguments": call.function.arguments,
                },
            })
        dumped["tool_calls"] = serialized_calls

    message = MagicMock()
    message.content = content
    message.tool_calls = tool_calls
    message.model_dump.return_value = dumped

    choice = MagicMock()
    choice.message = message
    choice.finish_reason = finish_reason

    usage = MagicMock()
    usage.prompt_tokens = prompt_tokens
    usage.completion_tokens = completion_tokens
    usage.total_tokens = prompt_tokens + completion_tokens

    response = MagicMock()
    response.choices = [choice]
    response.usage = usage
    return response
202
+
203
+
204
def _mock_tool_call(name, arguments, call_id="call_1"):
    """Build a mock tool call with ``id``, ``function.name`` and ``function.arguments``.

    ``arguments`` is stored as its JSON-encoded string.
    """
    call = MagicMock()
    call.configure_mock(**{
        "id": call_id,
        "function.name": name,
        "function.arguments": json.dumps(arguments),
    })
    return call
210
+
211
+
212
def test_planner_run_trace_accumulates_steps_and_cost():
    """Full run() with trace collects planner_steps and cost_metrics."""
    planner = _make_planner()

    # First mocked LLM reply carries a search_contracts tool call; the second
    # carries plain text content.
    tool_turn = _mock_llm_response(
        tool_calls=[
            _mock_tool_call("search_contracts", {"query": "email"}, "call_1"),
        ],
        finish_reason="tool_calls",
        prompt_tokens=100,
        completion_tokens=50,
    )
    text_turn = _mock_llm_response(
        content="Done! Email sent successfully.",
        prompt_tokens=200,
        completion_tokens=30,
    )

    async def _collect():
        with patch("fireaction_a2a.planner.litellm") as fake_litellm:
            fake_litellm.acompletion = AsyncMock(side_effect=[tool_turn, text_turn])
            return [
                step
                async for step in planner.run(
                    "send an email", [], trace_enabled=True
                )
            ]

    steps = asyncio.run(_collect())

    last_step = steps[-1]
    assert last_step.type == "completed"
    assert last_step.trace_data is not None

    trace = last_step.trace_data
    assert "execution_trace" in trace
    assert "cost_metrics" in trace

    execution = trace["execution_trace"]
    assert len(execution["planner_steps"]) == 1
    assert execution["planner_steps"][0]["tool"] == "search_contracts"
    assert execution["total_planner_steps"] == 1

    # Token counts are the sums over the two mocked replies above.
    cost = trace["cost_metrics"]
    expected_cost = {
        "num_llm_calls": 2,
        "prompt_tokens": 300,
        "completion_tokens": 80,
        "total_tokens": 380,
    }
    for metric, value in expected_cost.items():
        assert cost[metric] == value
    assert cost["total_duration_ms"] >= 0
254
+
255
+
256
def test_planner_run_without_trace_returns_none():
    """run() with trace_enabled=False yields PlanStep with trace_data=None."""
    planner = _make_planner()
    only_reply = _mock_llm_response(content="Done! Email sent.")

    async def _collect():
        with patch("fireaction_a2a.planner.litellm") as fake_litellm:
            fake_litellm.acompletion = AsyncMock(return_value=only_reply)
            return [
                step
                async for step in planner.run(
                    "send email", [], trace_enabled=False
                )
            ]

    steps = asyncio.run(_collect())

    last_step = steps[-1]
    assert last_step.type == "completed"
    assert last_step.trace_data is None
tests/test_runner.py CHANGED
@@ -1,11 +1,11 @@
1
  """Tests for the contract runner (lifecycle only, no HTTP calls)."""
2
 
3
  import asyncio
4
- from unittest.mock import AsyncMock
5
 
6
  import fireaction
7
 
8
- from fireaction_a2a.runner import ContractError, execute_contract
9
 
10
 
11
  V = "send_email_skeleton"
@@ -69,7 +69,9 @@ def test_execute_contract_success():
69
 
70
  result = asyncio.run(execute_contract(contract, _minimal_email_nodes(), mock_client))
71
 
72
- assert result == {"id": "email_123"}
 
 
73
  mock_client.call.assert_called_once()
74
  call_args = mock_client.call.call_args
75
  payload = call_args[0][1]
@@ -77,6 +79,38 @@ def test_execute_contract_success():
77
  assert payload["subject"] == "Test"
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def test_execute_contract_instantiate_error():
81
  contract = fireaction.load(provider="resend", action="send_email")
82
  mock_client = AsyncMock()
@@ -99,7 +133,123 @@ def test_execute_contract_instantiate_error():
99
  assert False, "Should have raised ContractError"
100
  except ContractError as e:
101
  assert e.stage == "instantiate"
 
102
  d = e.to_dict()
103
  assert d["error"] is True
104
  assert d["stage"] == "instantiate"
105
  mock_client.call.assert_not_called()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Tests for the contract runner (lifecycle only, no HTTP calls)."""
2
 
3
  import asyncio
4
+ from unittest.mock import AsyncMock, MagicMock
5
 
6
  import fireaction
7
 
8
+ from fireaction_a2a.runner import ContractError, ContractResult, execute_contract
9
 
10
 
11
  V = "send_email_skeleton"
 
69
 
70
  result = asyncio.run(execute_contract(contract, _minimal_email_nodes(), mock_client))
71
 
72
+ assert isinstance(result, ContractResult)
73
+ assert result.response == {"id": "email_123"}
74
+ assert result.trace is None
75
  mock_client.call.assert_called_once()
76
  call_args = mock_client.call.call_args
77
  payload = call_args[0][1]
 
79
  assert payload["subject"] == "Test"
80
 
81
 
82
def test_execute_contract_success_with_trace():
    """A successful traced run returns a ContractResult whose trace is fully populated."""
    contract = fireaction.load(provider="resend", action="send_email")
    client = AsyncMock()
    client.call.return_value = {"id": "email_456"}

    run = execute_contract(
        contract, _minimal_email_nodes(), client, trace_enabled=True,
    )
    result = asyncio.run(run)

    assert isinstance(result, ContractResult)
    assert result.response == {"id": "email_456"}
    assert result.trace is not None

    # Every validation stage should report success with no recorded failures.
    validation = result.trace["validation"]
    assert validation["verify_passed"] is True
    assert validation["properties_passed"] is True
    assert validation["properties_failed"] == []
    assert validation["rules_passed"] is True
    assert validation["rules_failed"] == []

    payload = result.trace["compiled_payload"]
    assert payload is not None
    assert payload["from"] == "onboarding@resend.dev"

    api_call = result.trace["api_call"]
    assert api_call is not None
    assert api_call["status_code"] == 200
    assert api_call["response_body"] == {"id": "email_456"}
    assert isinstance(api_call["duration_ms"], int)
112
+
113
+
114
  def test_execute_contract_instantiate_error():
115
  contract = fireaction.load(provider="resend", action="send_email")
116
  mock_client = AsyncMock()
 
133
  assert False, "Should have raised ContractError"
134
  except ContractError as e:
135
  assert e.stage == "instantiate"
136
+ assert e.trace is None
137
  d = e.to_dict()
138
  assert d["error"] is True
139
  assert d["stage"] == "instantiate"
140
  mock_client.call.assert_not_called()
141
+
142
+
143
def test_execute_contract_instantiate_error_with_trace():
    """An instantiate-stage failure with tracing on attaches a partial trace to the error."""
    contract = fireaction.load(provider="resend", action="send_email")
    client = AsyncMock()

    # A single root node whose element_key is "nonexistent".
    broken_root = {
        "instance_key": "root",
        "element_key": "nonexistent",
        "variant_key": V,
        "data_type": "object",
        "compile_key": "root",
        "description": "bad",
        "index": 0,
        "parent_instance_key": None,
    }

    try:
        asyncio.run(
            execute_contract(contract, [broken_root], client, trace_enabled=True)
        )
    except ContractError as exc:
        error = exc
    else:
        assert False, "Should have raised ContractError"

    assert error.stage == "instantiate"
    assert error.trace is not None
    assert error.trace["validation"]["verify_passed"] is False
    assert error.trace["compiled_payload"] is None
    assert error.trace["api_call"] is None
    client.call.assert_not_called()
172
+
173
+
174
+ # ---- Tests using mocked contracts for later-stage failures ----
175
+
176
+
177
def _mock_contract():
    """Return a MagicMock contract that passes all stages by default.

    ``verify``, ``check_properties`` and ``check_rules`` return empty lists;
    tests override one of these return values to provoke a failure at that stage.
    """
    defaults = {
        "instantiate.return_value": [{"success": True}],
        "verify.return_value": [],
        "check_properties.return_value": [],
        "compile.return_value": {"from": "a@b.com", "to": ["c@d.com"]},
        "check_rules.return_value": [],
        "endpoint_info": {"method": "POST", "path": "/emails"},
        "metadata": {"version": "v1"},
    }
    contract = MagicMock()
    contract.configure_mock(**defaults)
    return contract
188
+
189
+
190
def test_verify_failure_with_trace():
    """Verify failure captures verify_passed=False but no later stages."""
    contract = _mock_contract()
    contract.verify.return_value = ["Missing required child: to_array"]
    client = AsyncMock()

    try:
        asyncio.run(execute_contract(contract, [], client, trace_enabled=True))
    except ContractError as exc:
        error = exc
    else:
        assert False, "Should have raised ContractError"

    assert error.stage == "verify"
    assert error.trace is not None

    validation = error.trace["validation"]
    assert validation["verify_passed"] is False
    assert validation["properties_passed"] is False
    assert error.trace["compiled_payload"] is None
    assert error.trace["api_call"] is None
208
+
209
+
210
def test_check_properties_failure_with_trace():
    """Property failure captures verify_passed=True, properties_failed list."""
    contract = _mock_contract()
    contract.check_properties.return_value = ["from must be a valid email"]
    client = AsyncMock()

    try:
        asyncio.run(execute_contract(contract, [], client, trace_enabled=True))
    except ContractError as exc:
        error = exc
    else:
        assert False, "Should have raised ContractError"

    assert error.stage == "check_properties"
    assert error.trace is not None

    validation = error.trace["validation"]
    assert validation["verify_passed"] is True
    assert validation["properties_passed"] is False
    assert validation["properties_failed"] == ["from must be a valid email"]
    assert error.trace["compiled_payload"] is None
    assert error.trace["api_call"] is None
231
+
232
+
233
def test_check_rules_failure_with_trace():
    """Rule failure captures verify+properties passed, compiled payload present."""
    contract = _mock_contract()
    contract.check_rules.return_value = ["rule_html_not_empty"]
    client = AsyncMock()

    try:
        asyncio.run(execute_contract(contract, [], client, trace_enabled=True))
    except ContractError as exc:
        error = exc
    else:
        assert False, "Should have raised ContractError"

    assert error.stage == "check_rules"
    assert error.trace is not None

    validation = error.trace["validation"]
    assert validation["verify_passed"] is True
    assert validation["properties_passed"] is True
    assert validation["rules_passed"] is False
    assert validation["rules_failed"] == ["rule_html_not_empty"]

    # The trace payload should equal what the mocked contract.compile() returns.
    expected_payload = {"from": "a@b.com", "to": ["c@d.com"]}
    assert error.trace["compiled_payload"] == expected_payload
    assert error.trace["api_call"] is None
tests/test_stripe_flow.py CHANGED
@@ -34,7 +34,7 @@ def test_stripe_product_to_price_to_payment_link():
34
  "index": 1, "primitive_contents": "Pro Plan", "parent_instance_key": "root"},
35
  ]
36
  result = asyncio.run(execute_contract(contract, nodes, mock_client))
37
- product_id = result["id"]
38
  assert product_id == "prod_test123"
39
 
40
  # Step 2: create_price (references product_id)
@@ -62,7 +62,7 @@ def test_stripe_product_to_price_to_payment_link():
62
  "index": 3, "primitive_contents": product_id, "parent_instance_key": "root"},
63
  ]
64
  result = asyncio.run(execute_contract(contract, nodes, mock_client))
65
- price_id = result["id"]
66
  assert price_id == "price_test456"
67
 
68
  # Step 3: create_payment_link (references price_id)
@@ -92,8 +92,8 @@ def test_stripe_product_to_price_to_payment_link():
92
  "index": 4, "primitive_contents": 1, "parent_instance_key": "li_1"},
93
  ]
94
  result = asyncio.run(execute_contract(contract, nodes, mock_client))
95
- assert result["id"] == "plink_test789"
96
- assert result["url"] == "https://buy.stripe.com/test_xxx"
97
 
98
  # Verify the mock was called 3 times total
99
  assert mock_client.call.call_count == 3
 
34
  "index": 1, "primitive_contents": "Pro Plan", "parent_instance_key": "root"},
35
  ]
36
  result = asyncio.run(execute_contract(contract, nodes, mock_client))
37
+ product_id = result.response["id"]
38
  assert product_id == "prod_test123"
39
 
40
  # Step 2: create_price (references product_id)
 
62
  "index": 3, "primitive_contents": product_id, "parent_instance_key": "root"},
63
  ]
64
  result = asyncio.run(execute_contract(contract, nodes, mock_client))
65
+ price_id = result.response["id"]
66
  assert price_id == "price_test456"
67
 
68
  # Step 3: create_payment_link (references price_id)
 
92
  "index": 4, "primitive_contents": 1, "parent_instance_key": "li_1"},
93
  ]
94
  result = asyncio.run(execute_contract(contract, nodes, mock_client))
95
+ assert result.response["id"] == "plink_test789"
96
+ assert result.response["url"] == "https://buy.stripe.com/test_xxx"
97
 
98
  # Verify the mock was called 3 times total
99
  assert mock_client.call.call_count == 3