gMAS / tests /test_scheduler.py
Артём Боярских
chore: initial commit
3193174
"""Tests for src/execution/scheduler.py"""
import pytest
import torch
from execution.scheduler import (
AdaptiveScheduler,
ConditionContext,
ConditionEvaluator,
ExecutionPlan,
ExecutionStep,
PruningConfig,
RoutingPolicy,
StepResult,
build_execution_order,
extract_agent_adjacency,
filter_reachable_agents,
get_incoming_agents,
get_outgoing_agents,
get_parallel_groups,
)
# ─────────────────────────── ConditionContext ───────────────────────────────
class TestConditionContext:
    """Checks for the accessor helpers exposed by ``ConditionContext``."""

    @staticmethod
    def _solver_to_reviewer(**extra):
        """Build a solver -> reviewer context, forwarding any extra keyword fields."""
        return ConditionContext(source_agent="solver", target_agent="reviewer", **extra)

    def test_get_last_response_present(self):
        """The source agent's stored message is expected back as the last response."""
        context = self._solver_to_reviewer(messages={"solver": "great answer"})
        assert context.get_last_response() == "great answer"

    def test_get_last_response_absent(self):
        """With no messages supplied, the last response is expected to be None."""
        context = self._solver_to_reviewer()
        assert context.get_last_response() is None

    def test_source_succeeded_no_result(self):
        """A source message without a step result is expected to count as success."""
        context = self._solver_to_reviewer(messages={"solver": "ok"})
        assert context.source_succeeded() is True

    def test_source_succeeded_not_in_messages(self):
        """No message and no step result is expected to count as not succeeded."""
        context = self._solver_to_reviewer()
        assert context.source_succeeded() is False

    def test_source_succeeded_with_step_result(self):
        """A successful StepResult for the source is expected to report success."""
        ok_result = StepResult(agent_id="solver", success=True)
        context = self._solver_to_reviewer(step_results={"solver": ok_result})
        assert context.source_succeeded() is True

    def test_source_failed_with_step_result(self):
        """An unsuccessful StepResult for the source is expected to report failure."""
        failed_result = StepResult(agent_id="solver", success=False)
        context = self._solver_to_reviewer(step_results={"solver": failed_result})
        assert context.source_succeeded() is False

    def test_has_keyword_in_source(self):
        """Keyword lookup defaults to the source message; 'error' should match 'Error: ...'."""
        context = self._solver_to_reviewer(
            messages={"solver": "Error: something failed"},
        )
        assert context.has_keyword("error") is True
        assert context.has_keyword("success") is False

    def test_has_keyword_in_target(self):
        """With in_source=False the keyword is looked up in the target's message."""
        context = self._solver_to_reviewer(messages={"reviewer": "looks good"})
        assert context.has_keyword("looks", in_source=False) is True

    def test_get_state_value(self):
        """State lookups return stored values, None when missing, or the given default."""
        context = ConditionContext(
            source_agent="a",
            target_agent="b",
            state={"status": "ok", "count": 5},
        )
        lookup = context.get_state_value
        assert lookup("status") == "ok"
        assert lookup("count") == 5
        assert lookup("missing") is None
        assert lookup("missing", "default") == "default"
# ─────────────────────────── ConditionEvaluator ───────────────────────────────
class TestConditionEvaluator:
    """Checks for ``ConditionEvaluator``.

    Covers callable conditions, built-in condition names, registration,
    prefixed string expressions (``contains:``, ``not:``, ``state:``) and
    the ``compose_and`` / ``compose_or`` combinators.
    """

    def setup_method(self):
        """Create a fresh evaluator and a solver -> reviewer context for each test."""
        self.evaluator = ConditionEvaluator()
        self.ctx = ConditionContext(
            source_agent="solver",
            target_agent="reviewer",
            messages={"solver": "The answer is 42"},
            state={"quality": "high"},
        )

    @staticmethod
    def _empty_ctx():
        """Return a solver -> reviewer context with no messages or state."""
        return ConditionContext(source_agent="solver", target_agent="reviewer")

    def test_none_condition_always_true(self):
        """A missing (None) condition is expected to evaluate to True."""
        assert self.evaluator.evaluate(None, self.ctx) is True

    def test_callable_condition_true(self):
        """A callable returning True is passed through."""
        def cond(ctx):
            return True

        assert self.evaluator.evaluate(cond, self.ctx) is True

    def test_callable_condition_false(self):
        """A callable returning False is passed through."""
        def cond(ctx):
            return False

        assert self.evaluator.evaluate(cond, self.ctx) is False

    def test_callable_condition_exception_returns_false(self):
        """A callable that raises is expected to be treated as False."""
        # Only ValueError, TypeError, KeyError, AttributeError, RuntimeError are caught
        def cond(ctx):
            raise ValueError("bad")

        assert self.evaluator.evaluate(cond, self.ctx) is False

    def test_builtin_always(self):
        """The built-in "always" condition evaluates to True."""
        assert self.evaluator.evaluate("always", self.ctx) is True

    def test_builtin_never(self):
        """The built-in "never" condition evaluates to False."""
        assert self.evaluator.evaluate("never", self.ctx) is False

    def test_builtin_source_success(self):
        """"source_success" is expected True when the source has a message."""
        # solver is in messages, so source_success should be True
        assert self.evaluator.evaluate("source_success", self.ctx) is True

    def test_builtin_source_failed(self):
        """"source_failed" is expected True when the source produced nothing."""
        assert self.evaluator.evaluate("source_failed", self._empty_ctx()) is True

    def test_builtin_has_response(self):
        """"has_response" follows whether the source has a message."""
        assert self.evaluator.evaluate("has_response", self.ctx) is True
        assert self.evaluator.evaluate("has_response", self._empty_ctx()) is False

    def test_register_custom_condition(self):
        """A registered custom condition can be evaluated by name."""
        self.evaluator.register(
            "is_42", lambda ctx: "42" in (ctx.get_last_response() or "")
        )
        assert self.evaluator.evaluate("is_42", self.ctx) is True

    def test_unregister_condition(self):
        """unregister reports True for a known name and False for an unknown one."""
        self.evaluator.register("temp_cond", lambda _: True)
        assert self.evaluator.unregister("temp_cond") is True
        # Removing a name that was never registered reports False instead of raising.
        assert self.evaluator.unregister("nonexistent") is False

    def test_get_condition(self):
        """get returns the very callable that was registered."""
        def cond(ctx):
            return True

        self.evaluator.register("my_cond", cond)
        assert self.evaluator.get("my_cond") is cond

    def test_string_contains(self):
        """"contains:<text>" checks the source response for the text."""
        assert self.evaluator.evaluate("contains:42", self.ctx) is True
        assert self.evaluator.evaluate("contains:error", self.ctx) is False

    def test_string_not(self):
        """"not:<cond>" negates the named condition."""
        assert self.evaluator.evaluate("not:always", self.ctx) is False
        assert self.evaluator.evaluate("not:never", self.ctx) is True

    def test_string_state(self):
        """"state:<key>=<value>" compares a state entry against the value."""
        assert self.evaluator.evaluate("state:quality=high", self.ctx) is True
        assert self.evaluator.evaluate("state:quality=low", self.ctx) is False

    def test_string_state_key_exists(self):
        """"state:<key>" without a value checks for the key's presence."""
        assert self.evaluator.evaluate("state:quality", self.ctx) is True
        assert self.evaluator.evaluate("state:missing_key", self.ctx) is False

    def test_compose_and(self):
        """compose_and yields a callable that requires every condition to hold."""
        composed = self.evaluator.compose_and("always", "has_response")
        assert composed(self.ctx) is True
        assert composed(self._empty_ctx()) is False  # has_response fails

    def test_compose_or(self):
        """compose_or yields a callable that requires at least one condition to hold."""
        composed = self.evaluator.compose_or("never", "always")
        assert composed(self.ctx) is True
        never_composed = self.evaluator.compose_or("never", "never")
        assert never_composed(self.ctx) is False

    def test_unknown_string_returns_true(self):
        """An unrecognised condition string is expected to fall back to True."""
        # Unknown conditions fall back to True
        assert self.evaluator.evaluate("unknown_condition_xyz", self.ctx) is True
# ─────────────────────────── ExecutionPlan ────────────────────────────────────
class TestExecutionPlan:
    """Checks for ``ExecutionPlan`` bookkeeping: progress, skipping and step insertion."""

    def _make_plan(self, agents=("a", "b", "c")):
        """Return a plan with one predecessor-free step per agent id."""
        plan_steps = []
        for agent_id in agents:
            plan_steps.append(ExecutionStep(agent_id=agent_id, predecessors=[]))
        return ExecutionPlan(steps=plan_steps)

    def test_initial_state(self):
        """A new plan is incomplete and positioned at index 0."""
        fresh = self._make_plan()
        assert not fresh.is_complete
        assert fresh.current_index == 0

    def test_get_current_step(self):
        """The current step of a new plan is its first step."""
        fresh = self._make_plan(["a", "b"])
        current = fresh.get_current_step()
        assert current.agent_id == "a"

    def test_mark_completed_advances(self):
        """Completing a step records it, adds its tokens and advances the index."""
        plan = self._make_plan(["a", "b"])
        plan.mark_completed("a", tokens=100)
        assert "a" in plan.completed
        assert plan.tokens_used == 100
        assert plan.current_index == 1

    def test_mark_failed_advances(self):
        """Failing a step records it and advances the index."""
        plan = self._make_plan(["a", "b"])
        plan.mark_failed("a")
        assert "a" in plan.failed
        assert plan.current_index == 1

    def test_mark_skipped_advances(self):
        """Skipping a step records it and advances the index."""
        plan = self._make_plan(["a", "b"])
        plan.mark_skipped("a")
        assert "a" in plan.skipped
        assert plan.current_index == 1

    def test_is_complete_after_all(self):
        """A single-step plan is complete once that step is completed."""
        plan = self._make_plan(["a"])
        plan.mark_completed("a")
        assert plan.is_complete

    def test_remaining_steps(self):
        """Remaining steps are those after the completed one, in order."""
        plan = self._make_plan(["a", "b", "c"])
        plan.mark_completed("a")
        remaining_ids = [step.agent_id for step in plan.remaining_steps]
        assert remaining_ids == ["b", "c"]

    def test_remaining_steps_excludes_skipped(self):
        """Agents placed in the skipped set do not show up in remaining steps."""
        plan = self._make_plan(["a", "b", "c"])
        plan.skipped.add("b")
        pending = plan.remaining_steps
        assert not any(step.agent_id == "b" for step in pending)

    def test_execution_order(self):
        """execution_order lists agent ids in step order."""
        plan = self._make_plan(["a", "b", "c"])
        assert plan.execution_order == ["a", "b", "c"]

    def test_insert_fallback(self):
        """A fallback is inserted right after the given index and flagged optional."""
        plan = self._make_plan(["a", "b"])
        plan.insert_fallback("fallback", after_index=0)
        inserted = plan.steps[1]
        assert inserted.agent_id == "fallback"
        assert inserted.is_optional

    def test_insert_fallback_skipped(self):
        """A fallback agent that is already skipped is not inserted."""
        plan = self._make_plan(["a", "b"])
        plan.skipped.add("fallback")
        plan.insert_fallback("fallback", after_index=0)
        # Should not insert since fallback is skipped
        assert len(plan.steps) == 2

    def test_insert_conditional_step(self):
        """A conditional step is appended at the end and reported as inserted."""
        plan = self._make_plan(["a", "b"])
        inserted = plan.insert_conditional_step("a", predecessors=["b"])
        assert inserted is True
        assert plan.steps[-1].agent_id == "a"

    def test_insert_conditional_step_exceeds_max(self):
        """Insertion is refused once the agent's iteration count hits the maximum."""
        plan = self._make_plan(["a"])
        plan.max_iterations = 2
        plan.iteration_count["a"] = 2
        inserted = plan.insert_conditional_step("a")
        assert inserted is False

    def test_can_iterate(self):
        """can_iterate is True below max_iterations and False at the limit."""
        plan = ExecutionPlan(max_iterations=3)
        plan.iteration_count["a"] = 2
        assert plan.can_iterate("a") is True
        plan.iteration_count["a"] = 3
        assert plan.can_iterate("a") is False

    def test_get_current_step_complete(self):
        """A plan without steps has no current step."""
        empty_plan = ExecutionPlan()
        assert empty_plan.get_current_step() is None
# ─────────────────────────── Graph utility functions ──────────────────────────
def make_adj(n, edges, weight=1.0):
    """Return an n-by-n zero tensor with ``weight`` written at each (row, col) in ``edges``."""
    matrix = torch.zeros(n, n)
    for src, dst in edges:
        matrix[src, dst] = weight
    return matrix
class TestExtractAgentAdjacency:
    """Checks that ``extract_agent_adjacency`` shrinks a 3x3 matrix to 2x2."""

    def test_removes_task_row_col(self):
        """With task_idx=0 the result is expected to be 2x2."""
        # 3×3 matrix, task_idx=0
        full = make_adj(3, [(0, 1), (1, 2), (0, 2)])
        reduced = extract_agent_adjacency(full, task_idx=0)
        assert tuple(reduced.shape) == (2, 2)

    def test_removes_middle_task(self):
        """With task_idx=1 the result is expected to be 2x2."""
        full = make_adj(3, [(0, 1), (1, 2)])
        reduced = extract_agent_adjacency(full, task_idx=1)
        assert tuple(reduced.shape) == (2, 2)
class TestGetIncomingOutgoingAgents:
    """Checks for ``get_incoming_agents`` and ``get_outgoing_agents``."""

    def test_get_incoming(self):
        """Both a and b feed into c, so both are expected as incoming agents of c."""
        # a→b, a→c, b→c
        agent_ids = ["a", "b", "c"]
        adjacency = make_adj(3, [(0, 1), (0, 2), (1, 2)])
        incoming_c = get_incoming_agents("c", adjacency, agent_ids, threshold=0.5)
        for expected in ("a", "b"):
            assert expected in incoming_c

    def test_get_incoming_empty(self):
        """An agent with no inbound edges has an empty incoming list."""
        agent_ids = ["a", "b"]
        adjacency = make_adj(2, [(0, 1)])
        incoming_a = get_incoming_agents("a", adjacency, agent_ids, threshold=0.5)
        assert incoming_a == []

    def test_get_incoming_not_in_ids(self):
        """An unknown agent id yields an empty incoming list."""
        agent_ids = ["a", "b"]
        adjacency = make_adj(2, [(0, 1)])
        assert get_incoming_agents("unknown", adjacency, agent_ids) == []

    def test_get_outgoing(self):
        """a points at b and c, so both are expected as outgoing agents of a."""
        agent_ids = ["a", "b", "c"]
        adjacency = make_adj(3, [(0, 1), (0, 2)])
        outgoing_a = get_outgoing_agents("a", adjacency, agent_ids, threshold=0.5)
        for expected in ("b", "c"):
            assert expected in outgoing_a

    def test_get_outgoing_not_in_ids(self):
        """An unknown agent id yields an empty outgoing list."""
        agent_ids = ["a", "b"]
        adjacency = make_adj(2, [(0, 1)])
        assert get_outgoing_agents("unknown", adjacency, agent_ids) == []
class TestFilterReachableAgents:
    """Checks for ``filter_reachable_agents`` on small graphs."""

    def test_linear_chain(self):
        """On a->b->c from a to c, every agent is relevant and none is excluded."""
        agent_ids = ["a", "b", "c"]
        adjacency = make_adj(3, [(0, 1), (1, 2)])
        kept, dropped = filter_reachable_agents(adjacency, agent_ids, "a", "c")
        assert set(kept) == {"a", "b", "c"}
        assert dropped == []

    def test_isolated_node(self):
        """An agent with no edges is excluded; the a->b path stays relevant."""
        agent_ids = ["a", "b", "c"]
        adjacency = make_adj(3, [(0, 1)])  # c is isolated
        kept, dropped = filter_reachable_agents(adjacency, agent_ids, "a", "b")
        for expected in ("a", "b"):
            assert expected in kept
        assert "c" in dropped

    def test_no_start_end(self):
        """Omitting start and end agents still returns a list of relevant agents."""
        agent_ids = ["a", "b"]
        adjacency = make_adj(2, [(0, 1)])
        kept, _dropped = filter_reachable_agents(adjacency, agent_ids)
        # Without start/end, should still work
        assert isinstance(kept, list)

    def test_empty_ids(self):
        """An empty graph yields empty relevant and excluded lists."""
        empty = torch.zeros(0, 0)
        kept, dropped = filter_reachable_agents(empty, [])
        assert kept == []
        assert dropped == []

    def test_start_not_in_graph(self):
        """An unknown start agent still returns a list of relevant agents."""
        agent_ids = ["a", "b"]
        adjacency = make_adj(2, [(0, 1)])
        kept, _dropped = filter_reachable_agents(adjacency, agent_ids, "unknown", "b")
        assert isinstance(kept, list)
class TestBuildExecutionOrder:
    """Checks for ``build_execution_order``."""

    def test_simple_chain(self):
        """On a->b->c the order places a before b before c."""
        agent_ids = ["a", "b", "c"]
        adjacency = make_adj(3, [(0, 1), (1, 2)])
        order = build_execution_order(adjacency, agent_ids)
        pos_a, pos_b, pos_c = (order.index(name) for name in ("a", "b", "c"))
        assert pos_a < pos_b < pos_c

    def test_empty(self):
        """An empty graph yields an empty order."""
        empty = torch.zeros(0, 0)
        assert build_execution_order(empty, []) == []

    def test_size_mismatch_raises(self):
        """A 3x3 matrix with only two ids raises ValueError mentioning 'a_agents size'."""
        adjacency = make_adj(3, [])
        with pytest.raises(ValueError, match="a_agents size"):
            build_execution_order(adjacency, ["a", "b"])

    def test_with_start_agent(self):
        """The requested start agent comes first even if it is not first in the id list."""
        agent_ids = ["b", "a", "c"]
        adjacency = make_adj(3, [(0, 2), (1, 0)])  # b→c, a→b
        order = build_execution_order(adjacency, agent_ids, start_agent="a")
        assert order[0] == "a"

    def test_cyclic_graph(self):
        """A two-node cycle still yields an order containing both agents."""
        agent_ids = ["a", "b"]
        # a→b, b→a (cycle)
        adjacency = make_adj(2, [(0, 1), (1, 0)])
        order = build_execution_order(adjacency, agent_ids)
        assert set(order) == {"a", "b"}
class TestGetParallelGroups:
    """Checks for ``get_parallel_groups``."""

    def test_linear_chain(self):
        """A three-agent chain produces three groups."""
        agent_ids = ["a", "b", "c"]
        adjacency = make_adj(3, [(0, 1), (1, 2)])
        batches = get_parallel_groups(adjacency, agent_ids)
        # Should be executed one by one
        assert len(batches) == 3

    def test_two_parallel_then_merge(self):
        """Two independent sources share the first group; their common target follows."""
        # a→c, b→c: a and b can run in parallel
        agent_ids = ["a", "b", "c"]
        adjacency = make_adj(3, [(0, 2), (1, 2)])
        batches = get_parallel_groups(adjacency, agent_ids)
        assert len(batches) == 2
        first, second = batches[0], batches[1]
        assert set(first) == {"a", "b"}
        assert second == ["c"]

    def test_empty(self):
        """An empty graph yields no groups."""
        empty = torch.zeros(0, 0)
        assert get_parallel_groups(empty, []) == []
# ─────────────────────────── AdaptiveScheduler ────────────────────────────────
def make_chain_matrix(ids):
    """Return a square zero tensor sized to ``ids`` with 1.0 on the first superdiagonal."""
    size = len(ids)
    chain = torch.zeros(size, size)
    # Link each position to the one directly after it: 0→1, 1→2, ...
    for src, dst in zip(range(size), range(1, size)):
        chain[src, dst] = 1.0
    return chain
class TestAdaptiveScheduler:
    """Checks for ``AdaptiveScheduler`` plan building and edge-condition evaluation."""

    @staticmethod
    def _abc_chain():
        """Return the ids ["a", "b", "c"] together with their linear chain matrix."""
        agent_ids = ["a", "b", "c"]
        return agent_ids, make_chain_matrix(agent_ids)

    def test_default_policy(self):
        """A scheduler built without arguments uses the TOPOLOGICAL policy."""
        scheduler = AdaptiveScheduler()
        assert scheduler.policy == RoutingPolicy.TOPOLOGICAL

    def test_build_plan_empty(self):
        """A plan built from an empty graph is already complete."""
        scheduler = AdaptiveScheduler()
        plan = scheduler.build_plan(torch.zeros(0, 0), [])
        assert plan.is_complete

    def test_build_plan_topological(self):
        """TOPOLOGICAL on a chain yields three steps with a ahead of c."""
        agent_ids, adjacency = self._abc_chain()
        scheduler = AdaptiveScheduler(policy=RoutingPolicy.TOPOLOGICAL)
        plan = scheduler.build_plan(adjacency, agent_ids)
        assert len(plan.steps) == 3
        assert plan.execution_order.index("a") < plan.execution_order.index("c")

    def test_build_plan_greedy(self):
        """GREEDY on a chain yields at least one step."""
        agent_ids, adjacency = self._abc_chain()
        scheduler = AdaptiveScheduler(policy=RoutingPolicy.GREEDY)
        plan = scheduler.build_plan(adjacency, agent_ids)
        assert len(plan.steps) >= 1

    def test_build_plan_weighted_topo(self):
        """WEIGHTED_TOPO on a chain yields at least one step."""
        agent_ids, adjacency = self._abc_chain()
        scheduler = AdaptiveScheduler(policy=RoutingPolicy.WEIGHTED_TOPO)
        plan = scheduler.build_plan(adjacency, agent_ids)
        assert len(plan.steps) >= 1

    def test_build_plan_beam_search(self):
        """BEAM_SEARCH with beam_width=2 on a chain yields at least one step."""
        agent_ids, adjacency = self._abc_chain()
        scheduler = AdaptiveScheduler(policy=RoutingPolicy.BEAM_SEARCH, beam_width=2)
        plan = scheduler.build_plan(adjacency, agent_ids)
        assert len(plan.steps) >= 1

    def test_build_plan_k_shortest(self):
        """K_SHORTEST with k_paths=2 on a chain yields at least one step."""
        agent_ids, adjacency = self._abc_chain()
        scheduler = AdaptiveScheduler(policy=RoutingPolicy.K_SHORTEST, k_paths=2)
        plan = scheduler.build_plan(adjacency, agent_ids)
        assert len(plan.steps) >= 1

    def test_build_plan_with_start_end(self):
        """With start=a and end=c, agent d is either absent from the order or skipped."""
        agent_ids = ["a", "b", "c", "d"]
        adjacency = make_chain_matrix(agent_ids)
        scheduler = AdaptiveScheduler()
        plan = scheduler.build_plan(adjacency, agent_ids, start_agent="a", end_agent="c")
        # Only a, b, c should be included
        assert not ("d" in plan.execution_order and "d" not in plan.skipped)

    def test_build_plan_with_p_matrix(self):
        """Passing a p_matrix still yields a plan with at least one step."""
        agent_ids, adjacency = self._abc_chain()
        probabilities = torch.eye(3) * 0.9
        probabilities[0, 1] = 0.8
        probabilities[1, 2] = 0.9
        scheduler = AdaptiveScheduler()
        plan = scheduler.build_plan(adjacency, agent_ids, p_matrix=probabilities)
        assert len(plan.steps) >= 1

    def test_build_plan_with_edge_conditions(self):
        """Passing edge_conditions still yields a plan with at least one step."""
        agent_ids = ["a", "b"]
        adjacency = make_adj(2, [(0, 1)])
        edge_rules = {("a", "b"): lambda _ctx: True}
        scheduler = AdaptiveScheduler()
        plan = scheduler.build_plan(adjacency, agent_ids, edge_conditions=edge_rules)
        assert len(plan.steps) >= 1

    def test_evaluate_edge_condition_callable(self):
        """A callable edge condition returning True evaluates to True."""
        scheduler = AdaptiveScheduler()
        context = ConditionContext(source_agent="a", target_agent="b")

        def cond(c):
            return True

        assert scheduler.evaluate_edge_condition("a", "b", cond, context) is True

    def test_evaluate_edge_condition_none(self):
        """A None edge condition evaluates to True."""
        scheduler = AdaptiveScheduler()
        context = ConditionContext(source_agent="a", target_agent="b")
        assert scheduler.evaluate_edge_condition("a", "b", None, context) is True

    def test_evaluate_edge_condition_string(self):
        """The string edge condition "always" evaluates to True."""
        scheduler = AdaptiveScheduler()
        context = ConditionContext(source_agent="a", target_agent="b")
        assert scheduler.evaluate_edge_condition("a", "b", "always", context) is True

    def test_pruning_config(self):
        """The pruning config given to the scheduler is exposed as ``pruning``."""
        settings = PruningConfig(min_weight_threshold=0.3, max_consecutive_errors=5)
        scheduler = AdaptiveScheduler(pruning_config=settings)
        assert scheduler.pruning.min_weight_threshold == 0.3

    def test_filter_unreachable_false(self):
        """With filter_unreachable=False, all three agents remain in the plan."""
        agent_ids = ["a", "b", "c"]
        # c is not reachable from a
        adjacency = make_adj(3, [(0, 1)])
        scheduler = AdaptiveScheduler()
        plan = scheduler.build_plan(
            adjacency,
            agent_ids,
            start_agent="a",
            end_agent="b",
            filter_unreachable=False,
        )
        # All agents should be in plan since filter is off
        assert len(plan.steps) == 3
class TestSchedulerMissingCoverage:
    """Extra tests aimed at scheduler branches the other test classes do not reach.

    Each docstring names the branch it targets by behaviour, not by source
    line number, so the descriptions stay valid when scheduler.py changes.
    """

    def test_execution_plan_skipped_agent(self):
        """insert_conditional_step refuses an agent that is already skipped."""
        plan = ExecutionPlan()
        plan.skipped.add("agent_a")
        result = plan.insert_conditional_step("agent_a")
        assert result is False  # Skipped agent should not be added

    def test_evaluate_string_condition_method(self):
        """A string condition passed to evaluate is resolved; "always" yields True."""
        evaluator = ConditionEvaluator()
        ctx = ConditionContext(
            source_agent="a",
            target_agent="b",
            messages={"a": "good result"},
        )
        # Pin the actual value: a bare isinstance(bool) check would also
        # accept a wrong False result.
        assert evaluator.evaluate("always", ctx) is True

    def test_filter_reachable_no_zero_in_degree(self):
        """filter_reachable_agents copes with a graph where every node has an inbound edge."""
        ids = ["a", "b"]
        # Cycle: a→b, b→a (all have in-degree > 0)
        a = make_adj(2, [(0, 1), (1, 0)])
        # No start_agent is given and no node has in-degree 0, so the
        # function has to fall back to some default start.
        relevant, excluded = filter_reachable_agents(a, ids)
        # Must not raise, and may only report agents that were passed in.
        assert isinstance(relevant, list)
        assert set(relevant) | set(excluded) <= set(ids)

    def test_filter_reachable_no_zero_out_degree(self):
        """filter_reachable_agents copes with a graph where every node has an outbound edge."""
        ids = ["a", "b", "c"]
        # All have outgoing edges: a→b, b→c, c→a
        a = make_adj(3, [(0, 1), (1, 2), (2, 0)])
        relevant, excluded = filter_reachable_agents(a, ids)
        # Must not raise, and may only report agents that were passed in.
        assert isinstance(relevant, list)
        assert set(relevant) | set(excluded) <= set(ids)

    def test_build_execution_order_with_cycle(self):
        """build_execution_order on a three-node cycle still returns every agent."""
        ids = ["a", "b", "c"]
        # Cycle: a→b, b→c, c→a
        a = make_adj(3, [(0, 1), (1, 2), (2, 0)])
        order = build_execution_order(a, ids, start_agent="a")
        assert set(order) == {"a", "b", "c"}

    def test_adaptive_scheduler_topological_with_cycle(self):
        """TOPOLOGICAL build_plan on a cyclic graph still plans every agent."""
        ids = ["a", "b", "c"]
        # Cycle: a→b, b→c, c→a
        a = make_adj(3, [(0, 1), (1, 2), (2, 0)])
        sched = AdaptiveScheduler(policy=RoutingPolicy.TOPOLOGICAL)
        plan = sched.build_plan(a, ids)
        # Should handle cycle without error
        assert {s.agent_id for s in plan.steps} == {"a", "b", "c"}

    def test_should_prune_weight_threshold(self):
        """should_prune flags a step whose weight is below min_weight_threshold."""
        pruning = PruningConfig(min_weight_threshold=0.5)
        sched = AdaptiveScheduler(pruning_config=pruning)
        plan = ExecutionPlan()
        step = ExecutionStep(agent_id="a", predecessors=[], weight=0.1, probability=1.0)
        prune, reason = sched.should_prune(step, plan)
        assert prune is True
        assert "weight" in reason

    def test_should_prune_probability_threshold(self):
        """should_prune flags a step whose probability is below min_probability_threshold."""
        pruning = PruningConfig(min_probability_threshold=0.5)
        sched = AdaptiveScheduler(pruning_config=pruning)
        plan = ExecutionPlan()
        step = ExecutionStep(agent_id="a", predecessors=[], weight=1.0, probability=0.1)
        prune, reason = sched.should_prune(step, plan)
        assert prune is True
        assert "probability" in reason

    def test_should_prune_token_budget_exhausted(self):
        """should_prune flags a step once tokens_used exceeds the token budget."""
        pruning = PruningConfig(token_budget=100)
        sched = AdaptiveScheduler(pruning_config=pruning)
        plan = ExecutionPlan()
        plan.tokens_used = 200  # Over budget
        step = ExecutionStep(agent_id="a", predecessors=[], weight=1.0, probability=1.0)
        prune, reason = sched.should_prune(step, plan)
        assert prune is True
        assert "token" in reason.lower()

    def test_should_prune_consecutive_errors(self):
        """should_prune flags a step after max_consecutive_errors failed steps."""
        pruning = PruningConfig(max_consecutive_errors=2)
        sched = AdaptiveScheduler(pruning_config=pruning)
        plan = ExecutionPlan()
        # Simulate two steps that have both run and failed.
        plan.steps = [
            ExecutionStep(agent_id="x", predecessors=[]),
            ExecutionStep(agent_id="y", predecessors=[]),
        ]
        plan.failed.add("x")
        plan.failed.add("y")
        plan.current_index = 2
        step = ExecutionStep(agent_id="a", predecessors=[], weight=1.0, probability=1.0)
        prune, reason = sched.should_prune(step, plan)
        assert prune is True
        assert "error" in reason.lower()

    def test_should_prune_predecessor_failure(self):
        """should_prune flags a non-optional step whose predecessor failed."""
        pruning = PruningConfig(skip_on_predecessor_failure=True, enable_fallback=False)
        sched = AdaptiveScheduler(pruning_config=pruning)
        plan = ExecutionPlan()
        plan.failed.add("b")
        step = ExecutionStep(
            agent_id="a",
            predecessors=["b"],
            weight=1.0,
            probability=1.0,
            is_optional=False,
            fallback_agents=[],
        )
        prune, reason = sched.should_prune(step, plan)
        assert prune is True
        assert "predecessors failed" in reason

    def test_should_use_fallback_disabled(self):
        """should_use_fallback is False when fallback is disabled in the config."""
        pruning = PruningConfig(enable_fallback=False)
        sched = AdaptiveScheduler(pruning_config=pruning)
        step = ExecutionStep(agent_id="a", predecessors=[], fallback_agents=["b"])
        result_obj = StepResult(agent_id="a", success=False, quality_score=0.0)
        assert sched.should_use_fallback(step, result_obj, 0) is False

    def test_should_use_fallback_max_attempts(self):
        """should_use_fallback is False once attempts exceed max_fallback_attempts."""
        pruning = PruningConfig(enable_fallback=True, max_fallback_attempts=2)
        sched = AdaptiveScheduler(pruning_config=pruning)
        step = ExecutionStep(agent_id="a", predecessors=[], fallback_agents=["b"])
        result_obj = StepResult(agent_id="a", success=False, quality_score=0.0)
        assert sched.should_use_fallback(step, result_obj, 3) is False  # Over limit

    def test_should_use_fallback_no_fallback_agents(self):
        """should_use_fallback is False when the step lists no fallback agents."""
        pruning = PruningConfig(enable_fallback=True)
        sched = AdaptiveScheduler(pruning_config=pruning)
        step = ExecutionStep(agent_id="a", predecessors=[], fallback_agents=[])
        result_obj = StepResult(agent_id="a", success=False)
        assert sched.should_use_fallback(step, result_obj, 0) is False

    def test_should_use_fallback_on_failure(self):
        """should_use_fallback is True for a failed step that has a fallback agent."""
        pruning = PruningConfig(enable_fallback=True)
        sched = AdaptiveScheduler(pruning_config=pruning)
        step = ExecutionStep(agent_id="a", predecessors=[], fallback_agents=["b"])
        result_obj = StepResult(agent_id="a", success=False)
        assert sched.should_use_fallback(step, result_obj, 0) is True

    def test_greedy_order_basic(self):
        """GREEDY build_plan on a chain includes every agent."""
        ids = ["a", "b", "c"]
        a = make_adj(3, [(0, 1), (1, 2)])
        sched = AdaptiveScheduler(policy=RoutingPolicy.GREEDY)
        plan = sched.build_plan(a, ids)
        agent_ids = [s.agent_id for s in plan.steps]
        assert "a" in agent_ids
        assert "b" in agent_ids
        assert "c" in agent_ids

    def test_greedy_order_with_end_agent(self):
        """GREEDY build_plan with an end_agent still produces at least one step."""
        ids = ["a", "b", "c"]
        a = make_adj(3, [(0, 1), (1, 2)])
        sched = AdaptiveScheduler(policy=RoutingPolicy.GREEDY)
        plan = sched.build_plan(a, ids, start_agent="a", end_agent="b")
        # Should stop at b or process all in order
        assert len(plan.steps) >= 1

    def test_beam_search_order(self):
        """BEAM_SEARCH build_plan on a chain produces three steps."""
        ids = ["a", "b", "c"]
        a = make_adj(3, [(0, 1), (1, 2)])
        sched = AdaptiveScheduler(policy=RoutingPolicy.BEAM_SEARCH, beam_width=2)
        plan = sched.build_plan(a, ids)
        assert len(plan.steps) == 3

    def test_beam_search_with_end_agent(self):
        """BEAM_SEARCH build_plan with an end_agent still produces at least one step."""
        ids = ["a", "b", "c"]
        a = make_adj(3, [(0, 1), (1, 2)])
        sched = AdaptiveScheduler(policy=RoutingPolicy.BEAM_SEARCH)
        plan = sched.build_plan(a, ids, start_agent="a", end_agent="b")
        assert len(plan.steps) >= 1

    def test_k_shortest_order(self):
        """K_SHORTEST build_plan on a chain produces three steps."""
        ids = ["a", "b", "c"]
        a = make_adj(3, [(0, 1), (1, 2)])
        sched = AdaptiveScheduler(policy=RoutingPolicy.K_SHORTEST)
        plan = sched.build_plan(a, ids)
        assert len(plan.steps) == 3

    def test_get_parallel_groups_with_deadlock(self):
        """get_parallel_groups still returns a group when no agent is initially ready."""
        ids = ["a", "b"]
        # Both have mutual incoming edges - deadlock scenario
        a = make_adj(2, [(0, 1), (1, 0)])
        groups = get_parallel_groups(a, ids)
        # Should handle by forcing at least one group with one agent
        assert len(groups) >= 1

    def test_build_plan_with_p_matrix(self):
        """build_plan accepts a p_matrix together with start and end agents."""
        ids = ["a", "b", "c"]
        a = make_adj(3, [(0, 1), (1, 2)])
        p = torch.ones(3, 3) * 0.5
        sched = AdaptiveScheduler(policy=RoutingPolicy.TOPOLOGICAL)
        # Should use p_matrix for filtering
        plan = sched.build_plan(a, ids, p_matrix=p, start_agent="a", end_agent="c")
        assert len(plan.steps) > 0