Spaces:
Running
Running
File size: 8,139 Bytes
from datetime import datetime, timedelta
import pytest
from execution.errors import (
AgentNotFoundError,
BudgetExceededError,
ErrorAction,
ErrorPolicy,
ExecutionError,
ExecutionMetrics,
RetryExhaustedError,
StepExecutionResult,
TimeoutError, # noqa: A004
ValidationError,
)
class TestExecutionError:
    """Checks ExecutionError construction and its to_dict() output."""

    def test_full_init_and_to_dict(self):
        """Every constructor argument is stored and mirrored by to_dict()."""
        root_cause = ValueError("oops")
        error = ExecutionError(
            "fail",
            agent_id="a1",
            step_index=2,
            cause=root_cause,
            recoverable=False,
        )

        # Attributes populated by the constructor.
        assert str(error) == "fail"
        assert error.agent_id == "a1"
        assert error.step_index == 2
        assert error.cause is root_cause
        assert error.recoverable is False
        assert isinstance(error.timestamp, datetime)

        # Serialized form: the cause is expected as its string message.
        payload = error.to_dict()
        assert payload["type"] == "ExecutionError"
        assert payload["message"] == "fail"
        assert payload["agent_id"] == "a1"
        assert payload["step_index"] == 2
        assert payload["recoverable"] is False
        assert payload["cause"] == "oops"
        assert payload["timestamp"]

    def test_defaults_and_no_cause(self):
        """A message-only error has None fields, is recoverable, and serializes cause as None."""
        error = ExecutionError("msg")

        assert error.agent_id is None
        assert error.step_index is None
        assert error.cause is None
        assert error.recoverable is True

        payload = error.to_dict()
        assert payload["cause"] is None
class TestErrorSubclasses:
    """Checks each ExecutionError subclass: stored fields, base type, recoverable flag."""

    def test_timeout_error(self):
        """TimeoutError keeps the timeout, agent and step; it is recoverable."""
        timeout = TimeoutError("a1", 30.0, step_index=1)

        assert isinstance(timeout, ExecutionError)
        assert timeout.timeout_seconds == 30.0
        assert timeout.agent_id == "a1"
        assert timeout.step_index == 1
        assert timeout.recoverable is True

        message = str(timeout)
        assert "timed out" in message

    def test_retry_exhausted_error(self):
        """RetryExhaustedError exposes attempts and last error (also as cause); not recoverable."""
        last_failure = RuntimeError("boom")
        exhausted = RetryExhaustedError("a2", 3, last_error=last_failure, step_index=5)

        assert isinstance(exhausted, ExecutionError)
        assert exhausted.attempts == 3
        assert exhausted.last_error is last_failure
        assert exhausted.cause is last_failure
        assert exhausted.recoverable is False

        message = str(exhausted)
        assert "3 attempts" in message

    def test_budget_exceeded_error(self):
        """BudgetExceededError keeps budget type, limit and usage; message shows used/limit."""
        over_budget = BudgetExceededError("token", 1000.0, 1500.0, agent_id="a3")

        assert isinstance(over_budget, ExecutionError)
        assert over_budget.budget_type == "token"
        assert over_budget.limit == 1000.0
        assert over_budget.used == 1500.0
        assert over_budget.recoverable is False

        message = str(over_budget)
        assert "1500.0/1000.0" in message

    def test_agent_not_found_error(self):
        """AgentNotFoundError records the agent id and is not recoverable."""
        not_found = AgentNotFoundError("missing")

        assert isinstance(not_found, ExecutionError)
        assert not_found.agent_id == "missing"
        assert not_found.recoverable is False

    def test_validation_error(self):
        """ValidationError records the offending field and value; not recoverable."""
        invalid = ValidationError("bad input", field="name", value=42)

        assert isinstance(invalid, ExecutionError)
        assert invalid.field == "name"
        assert invalid.value == 42
        assert invalid.recoverable is False
class TestErrorPolicy:
    """Checks how ErrorPolicy.get_action maps errors to actions, plus field defaults."""

    def test_default_actions(self):
        """A default-constructed policy returns the expected action per error type."""
        default_policy = ErrorPolicy()

        # (error instance, action the default policy should choose for it)
        expectations = (
            (TimeoutError("a", 1.0), ErrorAction.RETRY),
            (RetryExhaustedError("a", 1), ErrorAction.PRUNE),
            (BudgetExceededError("t", 1, 2), ErrorAction.ABORT),
            (AgentNotFoundError("a"), ErrorAction.SKIP),
            (ValidationError("v"), ErrorAction.ABORT),
            (ExecutionError("unknown"), ErrorAction.SKIP),
        )
        for error, expected_action in expectations:
            assert default_policy.get_action(error) == expected_action

    def test_custom_policy(self):
        """Passing on_timeout overrides the action chosen for a TimeoutError."""
        strict_policy = ErrorPolicy(on_timeout=ErrorAction.ABORT)
        timeout = TimeoutError("a", 1.0)

        assert strict_policy.get_action(timeout) == ErrorAction.ABORT

    def test_model_fields(self):
        """max_skipped_agents defaults to 5 and abort_on_critical_path to True."""
        default_policy = ErrorPolicy()

        assert default_policy.max_skipped_agents == 5
        assert default_policy.abort_on_critical_path is True
class TestExecutionMetrics:
    """Checks ExecutionMetrics counters, derived properties and to_dict() layout."""

    def test_add_request_success(self):
        """One successful request updates token totals and request counters."""
        metrics = ExecutionMetrics()
        metrics.add_request(100, 50, success=True, latency_ms=200.0)

        # Token accounting.
        assert metrics.prompt_tokens == 100
        assert metrics.completion_tokens == 50
        assert metrics.total_tokens == 150

        # Request counters.
        assert metrics.total_requests == 1
        assert metrics.successful_requests == 1
        assert metrics.failed_requests == 0
        assert metrics.retried_requests == 0

    def test_add_request_failed_and_retried(self):
        """A failed, retried request bumps the failed and retried counters only."""
        metrics = ExecutionMetrics()
        metrics.add_request(10, 5, success=False, latency_ms=50.0, retried=True)

        assert metrics.failed_requests == 1
        assert metrics.retried_requests == 1
        assert metrics.successful_requests == 0

    def test_success_rate(self):
        """success_rate is 0.0 with no requests and 0.5 after one success and one failure."""
        metrics = ExecutionMetrics()

        # Must be checked before any request is recorded.
        assert metrics.success_rate == 0.0

        metrics.add_request(1, 1, success=True, latency_ms=1.0)
        metrics.add_request(1, 1, success=False, latency_ms=1.0)
        assert metrics.success_rate == 0.5

    def test_duration_seconds(self):
        """duration_seconds reflects the gap between start_time and end_time."""
        started = datetime.now()
        finished = started + timedelta(seconds=3)
        metrics = ExecutionMetrics(start_time=started, end_time=finished)

        assert metrics.duration_seconds == pytest.approx(3.0)

    def test_duration_seconds_no_times(self):
        """duration_seconds is 0.0 on a default-constructed instance."""
        metrics = ExecutionMetrics()

        assert metrics.duration_seconds == 0.0

    def test_to_dict_structure(self):
        """to_dict() exposes tokens, requests, agents and timing sections."""
        agent_counts = {
            "total_agents": 10,
            "executed_agents": 8,
            "skipped_agents": 1,
            "failed_agents": 1,
        }
        metrics = ExecutionMetrics(**agent_counts)
        metrics.add_request(100, 50, success=True, latency_ms=200.0)

        snapshot = metrics.to_dict()
        assert snapshot["tokens"]["total"] == 150
        assert snapshot["requests"]["success_rate"] == 1.0
        assert snapshot["agents"]["total"] == 10
        assert snapshot["timing"]["avg_latency_ms"] == 200.0
class TestStepExecutionResult:
    """Checks StepExecutionResult defaults, tokens_used and to_dict() output."""

    def test_tokens_used(self):
        """tokens_used is the sum of prompt_tokens and completion_tokens."""
        result = StepExecutionResult(
            agent_id="a1",
            success=True,
            prompt_tokens=100,
            completion_tokens=50,
        )

        assert result.tokens_used == 150

    def test_to_dict_success(self):
        """A successful step serializes its response length and no error."""
        result = StepExecutionResult(
            agent_id="a1",
            success=True,
            response="hello",
            latency_ms=10.0,
        )

        payload = result.to_dict()
        assert payload["agent_id"] == "a1"
        assert payload["success"] is True
        assert payload["response_length"] == 5
        assert payload["error"] is None
        assert payload["status"]["skipped"] is False

    def test_to_dict_with_error(self):
        """A failed step serializes the error as a nested dict and a zero response length."""
        failure = ExecutionError("oops", agent_id="a1")
        result = StepExecutionResult(agent_id="a1", success=False, error=failure)

        payload = result.to_dict()
        assert payload["error"]["type"] == "ExecutionError"
        assert payload["response_length"] == 0

    def test_defaults(self):
        """Optional fields take their default values when omitted."""
        result = StepExecutionResult(agent_id="x", success=True)

        assert result.response is None
        assert result.error is None
        assert result.attempts == 1
        assert result.quality_score == 1.0
        assert result.skipped is False
        assert result.fallback_used is False
|