# AgentGraph / tests/unit/test_concurrent_executor.py
# Commit 795b72e (wu981526092):
#   Add comprehensive perturbation testing system with E2E tests
"""
Unit tests for agentgraph/testing/concurrent_executor.py
Tests concurrent execution, retry logic, and rate limiting.
"""
import pytest
import asyncio
import time
from unittest.mock import MagicMock, patch
from agentgraph.testing.concurrent_executor import (
ConcurrentTestExecutor,
AsyncConcurrentExecutor,
create_executor,
)
class TestConcurrentTestExecutorInit:
    """Check that ConcurrentTestExecutor's constructor records its settings."""

    def test_default_values(self):
        """An executor built without arguments exposes the default settings."""
        executor = ConcurrentTestExecutor()
        observed = (
            executor.max_workers,
            executor.max_retries,
            executor.base_delay,
            executor.max_delay,
            executor.rate_limit_per_minute,
        )
        assert observed == (5, 3, 1.0, 60.0, 60)

    def test_custom_values(self):
        """Every keyword passed to the constructor lands on the same-named attribute."""
        settings = {
            "max_workers": 10,
            "max_retries": 5,
            "base_delay": 2.0,
            "max_delay": 120.0,
            "rate_limit_per_minute": 100,
        }
        executor = ConcurrentTestExecutor(**settings)
        for attribute, expected in settings.items():
            assert getattr(executor, attribute) == expected
class TestShouldRetry:
    """Exercise the private classifier that decides whether a failure is retried."""

    def test_rate_limit_errors_should_retry(self):
        """Throttling, timeout, connection and 5xx style messages are retryable."""
        executor = ConcurrentTestExecutor()
        messages = (
            "rate limit exceeded",
            "429 Too Many Requests",
            "Rate_limit error",
            "timeout waiting for response",
            "connection refused",
            "server error 500",
        )
        for error in map(Exception, messages):
            assert executor._should_retry(error) is True, f"Should retry: {error}"

    def test_non_retryable_errors(self):
        """Auth, filesystem, permission and plain value errors are not retried."""
        executor = ConcurrentTestExecutor()
        cases = (
            (Exception, "Invalid API key"),
            (Exception, "File not found"),
            (Exception, "Permission denied"),
            (ValueError, "Invalid input"),
        )
        for exc_type, text in cases:
            error = exc_type(text)
            assert executor._should_retry(error) is False, f"Should not retry: {error}"
class TestExecuteWithRetry:
    """Cover the success, retry and give-up paths of execute_with_retry."""

    def test_successful_execution(self):
        """A callable that succeeds on the first call has its value returned."""
        executor = ConcurrentTestExecutor(max_retries=3)
        assert executor.execute_with_retry(lambda: "success") == "success"

    def test_successful_execution_with_args(self):
        """Positional arguments after the callable are forwarded to it."""
        executor = ConcurrentTestExecutor()
        assert executor.execute_with_retry(lambda a, b: a + b, 2, 3) == 5

    def test_retry_on_rate_limit(self):
        """Two rate-limit failures are retried; the third attempt succeeds."""
        # Tiny delays keep the backoff from slowing the suite down.
        executor = ConcurrentTestExecutor(max_retries=3, base_delay=0.01, max_delay=0.1)
        attempts = 0

        def flaky():
            nonlocal attempts
            attempts += 1
            if attempts >= 3:
                return "success after retry"
            raise Exception("rate limit exceeded")

        assert executor.execute_with_retry(flaky) == "success after retry"
        assert attempts == 3

    def test_max_retries_exceeded(self):
        """A callable that always hits the rate limit eventually re-raises."""
        executor = ConcurrentTestExecutor(max_retries=3, base_delay=0.01, max_delay=0.1)

        def never_succeeds():
            raise Exception("rate limit exceeded")

        with pytest.raises(Exception) as exc_info:
            executor.execute_with_retry(never_succeeds)
        message = str(exc_info.value).lower()
        assert "rate limit" in message

    def test_non_retryable_error_raises_immediately(self):
        """A non-retryable failure propagates after exactly one attempt."""
        executor = ConcurrentTestExecutor(max_retries=5)
        attempts = 0

        def rejected():
            nonlocal attempts
            attempts += 1
            raise Exception("Invalid API key")

        with pytest.raises(Exception) as exc_info:
            executor.execute_with_retry(rejected)
        assert attempts == 1  # no retry happened
        assert "Invalid API key" in str(exc_info.value)
class TestExecuteBatch:
    """Cover execute_batch: empty input, ordering, partial failure and progress."""

    def test_empty_batch(self):
        """An empty item list yields an empty result list."""
        executor = ConcurrentTestExecutor()
        assert executor.execute_batch([], lambda x: x) == []

    def test_successful_batch(self):
        """Every item is transformed by the worker function."""
        executor = ConcurrentTestExecutor(max_workers=3)
        values = list(range(1, 6))
        assert executor.execute_batch(values, lambda x: x * 2) == [v * 2 for v in values]

    def test_batch_preserves_order(self):
        """Results follow input order even when later items finish first."""
        executor = ConcurrentTestExecutor(max_workers=5)
        values = list(range(10))

        def inverse_delay(x):
            # Larger inputs sleep less, so completion order is roughly reversed.
            time.sleep(0.01 * (10 - x))
            return x

        assert executor.execute_batch(values, inverse_delay) == values

    def test_batch_with_failures(self):
        """One failing item yields an error entry without disturbing the others."""
        executor = ConcurrentTestExecutor(max_workers=3, max_retries=1, base_delay=0.01)

        def double_unless_three(x):
            if x != 3:
                return x * 2
            raise Exception("Invalid API key for item 3")  # Non-retryable

        results = executor.execute_batch([1, 2, 3, 4, 5], double_unless_three)
        assert [results[i] for i in (0, 1)] == [2, 4]
        assert "error" in results[2]  # the failed item
        assert [results[i] for i in (3, 4)] == [8, 10]

    def test_batch_with_progress_callback(self):
        """The progress callback fires once per item with every count reported."""
        executor = ConcurrentTestExecutor(max_workers=2)
        seen = []
        executor.execute_batch(
            [1, 2, 3],
            lambda x: x,
            lambda current, total, message: seen.append((current, total, message)),
        )
        assert len(seen) == 3
        assert {current for current, _, _ in seen} == {1, 2, 3}
class TestRateLimiting:
    """Check the request-time bookkeeping behind the executor's rate limiter."""

    def test_rate_limit_tracking(self):
        """Each call to _wait_for_rate_limit records one request time."""
        executor = ConcurrentTestExecutor(max_workers=1, rate_limit_per_minute=5)
        executor._request_times = []  # start from an empty window

        for _ in range(3):
            executor._wait_for_rate_limit()

        assert len(executor._request_times) == 3

    def test_rate_limit_clears_old_records(self):
        """Records older than a minute are dropped, leaving only the new request."""
        executor = ConcurrentTestExecutor(rate_limit_per_minute=100)
        stale = time.time() - 120  # two minutes in the past
        executor._request_times = [stale] * 3

        executor._wait_for_rate_limit()

        assert len(executor._request_times) == 1
class TestAsyncConcurrentExecutor:
    """Checks for the asyncio-based AsyncConcurrentExecutor."""

    def test_init(self):
        """Constructor keyword arguments are stored on matching attributes."""
        executor = AsyncConcurrentExecutor(
            max_concurrent=10,
            max_retries=5,
            base_delay=2.0,
            max_delay=120.0,
        )
        snapshot = (
            executor.max_concurrent,
            executor.max_retries,
            executor.base_delay,
            executor.max_delay,
        )
        assert snapshot == (10, 5, 2.0, 120.0)

    @pytest.mark.asyncio
    async def test_async_execute_with_retry_success(self):
        """A coroutine function that succeeds at once has its result returned."""
        executor = AsyncConcurrentExecutor()

        async def succeed():
            return "async success"

        assert await executor.execute_with_retry(succeed) == "async success"

    @pytest.mark.asyncio
    async def test_async_retry_on_error(self):
        """Two rate-limit failures are retried before the third call succeeds."""
        executor = AsyncConcurrentExecutor(max_retries=3, base_delay=0.01, max_delay=0.1)
        attempts = 0

        async def flaky_async():
            nonlocal attempts
            attempts += 1
            if attempts >= 3:
                return "success"
            raise Exception("rate limit exceeded")

        assert await executor.execute_with_retry(flaky_async) == "success"
        assert attempts == 3

    @pytest.mark.asyncio
    async def test_async_execute_batch(self):
        """Batch results are doubled and returned in input order."""
        executor = AsyncConcurrentExecutor(max_concurrent=3)

        async def double_later(value):
            await asyncio.sleep(0.01)
            return value * 2

        values = list(range(1, 6))
        assert await executor.execute_batch(values, double_later) == [v * 2 for v in values]

    @pytest.mark.asyncio
    async def test_async_batch_with_failures(self):
        """One failing coroutine yields an error entry; the others pass through."""
        executor = AsyncConcurrentExecutor(max_concurrent=2, max_retries=1, base_delay=0.01)

        async def echo_unless_three(value):
            if value != 3:
                return value
            raise ValueError("Test error")  # Non-retryable

        results = await executor.execute_batch([1, 2, 3, 4], echo_unless_three)
        assert [results[i] for i in (0, 1)] == [1, 2]
        assert "error" in results[2]
        assert results[3] == 4
class TestCreateExecutor:
    """Check that the create_executor factory returns configured executors."""

    def test_create_with_defaults(self):
        """With no arguments the factory yields a default ConcurrentTestExecutor."""
        executor = create_executor()
        assert isinstance(executor, ConcurrentTestExecutor)
        assert (executor.max_workers, executor.max_retries) == (5, 3)

    def test_create_with_custom_values(self):
        """Every keyword handed to the factory is reflected on the executor."""
        overrides = {
            "max_workers": 10,
            "max_retries": 5,
            "base_delay": 2.0,
            "max_delay": 100.0,
            "rate_limit_per_minute": 120,
        }
        executor = create_executor(**overrides)
        for attribute, expected in overrides.items():
            assert getattr(executor, attribute) == expected
class TestConcurrencyBehavior:
    """Tests for actual concurrent execution behavior."""

    def test_concurrent_execution_faster_than_serial(self):
        """Five 0.1 s jobs on five workers finish well under the ~0.5 s serial time."""
        executor = ConcurrentTestExecutor(max_workers=5)
        items = list(range(5))

        def slow_func(x):
            time.sleep(0.1)
            return x

        # perf_counter is monotonic and high-resolution; time.time() follows the
        # wall clock and can jump if the system time is adjusted mid-test.
        start = time.perf_counter()
        executor.execute_batch(items, slow_func)
        concurrent_time = time.perf_counter() - start
        # Serial would take ~0.5s, concurrent should be ~0.1s
        # Allow some overhead
        assert concurrent_time < 0.3, f"Concurrent took {concurrent_time}s, expected < 0.3s"

    def test_worker_limit_respected(self):
        """At most max_workers calls to the worker function run at the same time."""
        import threading

        executor = ConcurrentTestExecutor(max_workers=2)
        lock = threading.Lock()
        active = 0
        peak = 0

        def track_concurrency(x):
            nonlocal active, peak
            # `+=` / `-=` on shared state is a read-modify-write that worker
            # threads can interleave; without the lock the counts (and the
            # measured peak) can be wrong and the test becomes flaky.
            with lock:
                active += 1
                peak = max(peak, active)
            time.sleep(0.05)  # hold the slot outside the lock so workers overlap
            with lock:
                active -= 1
            return x

        executor.execute_batch(list(range(5)), track_concurrency)
        assert peak <= 2, f"Max active was {peak}, expected <= 2"