Spaces:
Running
Running
| """ | |
| Unit tests for agentgraph/testing/concurrent_executor.py | |
| Tests concurrent execution, retry logic, and rate limiting. | |
| """ | |
| import pytest | |
| import asyncio | |
| import time | |
| from unittest.mock import MagicMock, patch | |
| from agentgraph.testing.concurrent_executor import ( | |
| ConcurrentTestExecutor, | |
| AsyncConcurrentExecutor, | |
| create_executor, | |
| ) | |
class TestConcurrentTestExecutorInit:
    """Initialization behaviour of ConcurrentTestExecutor."""

    def test_default_values(self):
        """A bare constructor call yields the documented defaults."""
        ex = ConcurrentTestExecutor()
        expected = {
            "max_workers": 5,
            "max_retries": 3,
            "base_delay": 1.0,
            "max_delay": 60.0,
            "rate_limit_per_minute": 60,
        }
        for attr, value in expected.items():
            assert getattr(ex, attr) == value

    def test_custom_values(self):
        """Constructor arguments are stored unchanged on the instance."""
        params = dict(
            max_workers=10,
            max_retries=5,
            base_delay=2.0,
            max_delay=120.0,
            rate_limit_per_minute=100,
        )
        ex = ConcurrentTestExecutor(**params)
        for attr, value in params.items():
            assert getattr(ex, attr) == value
class TestShouldRetry:
    """Classification of exceptions by the retry predicate."""

    def test_rate_limit_errors_should_retry(self):
        """Transient failures (rate limits, timeouts, 5xx) are retryable."""
        ex = ConcurrentTestExecutor()
        messages = [
            "rate limit exceeded",
            "429 Too Many Requests",
            "Rate_limit error",
            "timeout waiting for response",
            "connection refused",
            "server error 500",
        ]
        for msg in messages:
            err = Exception(msg)
            assert ex._should_retry(err) is True, f"Should retry: {err}"

    def test_non_retryable_errors(self):
        """Permanent failures (auth, missing files, bad input) are not retried."""
        ex = ConcurrentTestExecutor()
        permanent = [
            Exception("Invalid API key"),
            Exception("File not found"),
            Exception("Permission denied"),
            ValueError("Invalid input"),
        ]
        for err in permanent:
            assert ex._should_retry(err) is False, f"Should not retry: {err}"
class TestExecuteWithRetry:
    """Behaviour of the execute_with_retry method."""

    def test_successful_execution(self):
        """A function that succeeds on the first call is returned directly."""
        ex = ConcurrentTestExecutor(max_retries=3)
        assert ex.execute_with_retry(lambda: "success") == "success"

    def test_successful_execution_with_args(self):
        """Positional arguments are forwarded to the wrapped callable."""
        ex = ConcurrentTestExecutor()
        assert ex.execute_with_retry(lambda a, b: a + b, 2, 3) == 5

    def test_retry_on_rate_limit(self):
        """A retryable error is retried until the callable succeeds."""
        # Tiny delays keep the retry loop fast under test.
        ex = ConcurrentTestExecutor(max_retries=3, base_delay=0.01, max_delay=0.1)
        calls = 0

        def flaky():
            nonlocal calls
            calls += 1
            if calls < 3:
                raise Exception("rate limit exceeded")
            return "success after retry"

        assert ex.execute_with_retry(flaky) == "success after retry"
        assert calls == 3

    def test_max_retries_exceeded(self):
        """When every attempt fails, the final exception propagates."""
        ex = ConcurrentTestExecutor(max_retries=3, base_delay=0.01, max_delay=0.1)

        def always_fails():
            raise Exception("rate limit exceeded")

        with pytest.raises(Exception) as exc_info:
            ex.execute_with_retry(always_fails)
        assert "rate limit" in str(exc_info.value).lower()

    def test_non_retryable_error_raises_immediately(self):
        """A non-retryable error is raised after exactly one attempt."""
        ex = ConcurrentTestExecutor(max_retries=5)
        calls = 0

        def auth_error():
            nonlocal calls
            calls += 1
            raise Exception("Invalid API key")

        with pytest.raises(Exception) as exc_info:
            ex.execute_with_retry(auth_error)
        assert calls == 1  # no retry happened
        assert "Invalid API key" in str(exc_info.value)
class TestExecuteBatch:
    """Behaviour of the execute_batch method."""

    def test_empty_batch(self):
        """An empty input list yields an empty result list."""
        assert ConcurrentTestExecutor().execute_batch([], lambda x: x) == []

    def test_successful_batch(self):
        """Every item is processed and mapped through the function."""
        ex = ConcurrentTestExecutor(max_workers=3)
        out = ex.execute_batch([1, 2, 3, 4, 5], lambda x: x * 2)
        assert out == [2, 4, 6, 8, 10]

    def test_batch_preserves_order(self):
        """Results are ordered by input position, not completion time."""
        ex = ConcurrentTestExecutor(max_workers=5)
        items = list(range(10))

        def delayed_identity(x):
            # Larger values finish first, stressing result reordering.
            time.sleep(0.01 * (10 - x))
            return x

        assert ex.execute_batch(items, delayed_identity) == items

    def test_batch_with_failures(self):
        """A failing item produces an error entry without aborting the batch."""
        ex = ConcurrentTestExecutor(max_workers=3, max_retries=1, base_delay=0.01)

        def doubler(x):
            if x == 3:
                raise Exception("Invalid API key for item 3")  # Non-retryable
            return x * 2

        results = ex.execute_batch([1, 2, 3, 4, 5], doubler)
        assert results[0] == 2
        assert results[1] == 4
        assert "error" in results[2]  # the failing item
        assert results[3] == 8
        assert results[4] == 10

    def test_batch_with_progress_callback(self):
        """The progress callback fires once per completed item."""
        ex = ConcurrentTestExecutor(max_workers=2)
        seen = []

        def on_progress(current, total, message):
            seen.append((current, total, message))

        ex.execute_batch([1, 2, 3], lambda x: x, on_progress)
        assert len(seen) == 3
        # Every completion count should have been reported exactly once.
        assert {entry[0] for entry in seen} == {1, 2, 3}
class TestRateLimiting:
    """Request-timestamp bookkeeping used by the rate limiter."""

    def test_rate_limit_tracking(self):
        """Each _wait_for_rate_limit call records one request timestamp."""
        ex = ConcurrentTestExecutor(max_workers=1, rate_limit_per_minute=5)
        ex._request_times = []  # start from a clean slate
        for _ in range(3):
            ex._wait_for_rate_limit()
        assert len(ex._request_times) == 3

    def test_rate_limit_clears_old_records(self):
        """Timestamps older than the one-minute window are discarded."""
        ex = ConcurrentTestExecutor(rate_limit_per_minute=100)
        stale = time.time() - 120  # well outside the 60s window
        ex._request_times = [stale, stale, stale]
        ex._wait_for_rate_limit()
        # Only the request just recorded should remain.
        assert len(ex._request_times) == 1
class TestAsyncConcurrentExecutor:
    """Tests for AsyncConcurrentExecutor.

    The async test methods carry an explicit ``@pytest.mark.asyncio`` marker.
    Without a marker, pytest-asyncio in its default "strict" mode (and plain
    pytest without the plugin) collects the coroutine functions but never
    awaits them — they are skipped with a warning instead of running.
    (If the project config sets ``asyncio_mode = auto`` the marker is
    redundant but harmless.)
    """

    def test_init(self):
        """Constructor stores the configured limits on the instance."""
        executor = AsyncConcurrentExecutor(
            max_concurrent=10,
            max_retries=5,
            base_delay=2.0,
            max_delay=120.0
        )
        assert executor.max_concurrent == 10
        assert executor.max_retries == 5
        assert executor.base_delay == 2.0
        assert executor.max_delay == 120.0

    @pytest.mark.asyncio
    async def test_async_execute_with_retry_success(self):
        """A coroutine that succeeds immediately returns its value."""
        executor = AsyncConcurrentExecutor()

        async def async_func():
            return "async success"

        result = await executor.execute_with_retry(async_func)
        assert result == "async success"

    @pytest.mark.asyncio
    async def test_async_retry_on_error(self):
        """A retryable error is retried until the coroutine succeeds."""
        # Tiny delays keep the retry loop fast under test.
        executor = AsyncConcurrentExecutor(
            max_retries=3,
            base_delay=0.01,
            max_delay=0.1
        )
        call_count = [0]

        async def flaky_async():
            call_count[0] += 1
            if call_count[0] < 3:
                raise Exception("rate limit exceeded")
            return "success"

        result = await executor.execute_with_retry(flaky_async)
        assert result == "success"
        assert call_count[0] == 3

    @pytest.mark.asyncio
    async def test_async_execute_batch(self):
        """All items are processed and results preserve input order."""
        executor = AsyncConcurrentExecutor(max_concurrent=3)

        async def process(x):
            await asyncio.sleep(0.01)
            return x * 2

        items = [1, 2, 3, 4, 5]
        results = await executor.execute_batch(items, process)
        assert results == [2, 4, 6, 8, 10]

    @pytest.mark.asyncio
    async def test_async_batch_with_failures(self):
        """A failing item yields an error entry without aborting the batch."""
        executor = AsyncConcurrentExecutor(
            max_concurrent=2,
            max_retries=1,
            base_delay=0.01
        )

        async def process(x):
            if x == 3:
                raise ValueError("Test error")  # Non-retryable
            return x

        items = [1, 2, 3, 4]
        results = await executor.execute_batch(items, process)
        assert results[0] == 1
        assert results[1] == 2
        assert "error" in results[2]
        assert results[3] == 4
class TestCreateExecutor:
    """Behaviour of the create_executor factory function."""

    def test_create_with_defaults(self):
        """With no arguments the factory returns a default-configured executor."""
        made = create_executor()
        assert isinstance(made, ConcurrentTestExecutor)
        assert made.max_workers == 5
        assert made.max_retries == 3

    def test_create_with_custom_values(self):
        """Keyword arguments are passed through to the executor unchanged."""
        settings = dict(
            max_workers=10,
            max_retries=5,
            base_delay=2.0,
            max_delay=100.0,
            rate_limit_per_minute=120,
        )
        made = create_executor(**settings)
        for attr, value in settings.items():
            assert getattr(made, attr) == value
class TestConcurrencyBehavior:
    """Observable effects of running work concurrently."""

    def test_concurrent_execution_faster_than_serial(self):
        """Five 0.1s tasks on five workers should finish well under 0.5s."""
        ex = ConcurrentTestExecutor(max_workers=5)

        def napper(x):
            time.sleep(0.1)
            return x

        t0 = time.time()
        ex.execute_batch(list(range(5)), napper)
        elapsed = time.time() - t0
        # Serial execution would take ~0.5s; allow generous scheduling overhead.
        assert elapsed < 0.3, f"Concurrent took {elapsed}s, expected < 0.3s"

    def test_worker_limit_respected(self):
        """No more than max_workers tasks run at the same time."""
        ex = ConcurrentTestExecutor(max_workers=2)
        in_flight = [0]
        peak = [0]

        def record(x):
            in_flight[0] += 1
            peak[0] = max(peak[0], in_flight[0])
            time.sleep(0.05)
            in_flight[0] -= 1
            return x

        ex.execute_batch(list(range(5)), record)
        assert peak[0] <= 2, f"Max active was {peak[0]}, expected <= 2"