"""
Unit tests for agentgraph/testing/concurrent_executor.py

Tests concurrent execution, retry logic, and rate limiting.
"""

import asyncio
import time
from unittest.mock import MagicMock, patch

import pytest

from agentgraph.testing.concurrent_executor import (
    ConcurrentTestExecutor,
    AsyncConcurrentExecutor,
    create_executor,
)


class TestConcurrentTestExecutorInit:
    """Tests for ConcurrentTestExecutor initialization."""

    def test_default_values(self):
        """Test default initialization values."""
        executor = ConcurrentTestExecutor()
        assert executor.max_workers == 5
        assert executor.max_retries == 3
        assert executor.base_delay == 1.0
        assert executor.max_delay == 60.0
        assert executor.rate_limit_per_minute == 60

    def test_custom_values(self):
        """Test custom initialization values."""
        executor = ConcurrentTestExecutor(
            max_workers=10,
            max_retries=5,
            base_delay=2.0,
            max_delay=120.0,
            rate_limit_per_minute=100,
        )
        assert executor.max_workers == 10
        assert executor.max_retries == 5
        assert executor.base_delay == 2.0
        assert executor.max_delay == 120.0
        assert executor.rate_limit_per_minute == 100


class TestShouldRetry:
    """Tests for retry decision logic."""

    # Each case is parametrized so a failure reports exactly which message
    # was misclassified, instead of the first failure aborting the rest.
    @pytest.mark.parametrize(
        "error",
        [
            Exception("rate limit exceeded"),
            Exception("429 Too Many Requests"),
            Exception("Rate_limit error"),
            Exception("timeout waiting for response"),
            Exception("connection refused"),
            Exception("server error 500"),
        ],
        ids=str,
    )
    def test_rate_limit_errors_should_retry(self, error):
        """Test that transient errors (rate limit, timeout, connection, 5xx) trigger retry."""
        executor = ConcurrentTestExecutor()
        assert executor._should_retry(error) is True, f"Should retry: {error}"

    @pytest.mark.parametrize(
        "error",
        [
            Exception("Invalid API key"),
            Exception("File not found"),
            Exception("Permission denied"),
            ValueError("Invalid input"),
        ],
        ids=str,
    )
    def test_non_retryable_errors(self, error):
        """Test that certain errors don't trigger retry."""
        executor = ConcurrentTestExecutor()
        assert executor._should_retry(error) is False, f"Should not retry: {error}"
class TestExecuteWithRetry:
    """Tests for execute_with_retry method."""

    def test_successful_execution(self):
        """Test successful execution without retry."""
        executor = ConcurrentTestExecutor(max_retries=3)

        def success_func():
            return "success"

        result = executor.execute_with_retry(success_func)
        assert result == "success"

    def test_successful_execution_with_args(self):
        """Test execution with arguments."""
        executor = ConcurrentTestExecutor()

        def add(a, b):
            return a + b

        result = executor.execute_with_retry(add, 2, 3)
        assert result == 5

    def test_retry_on_rate_limit(self):
        """Test retry behavior on rate limit error."""
        executor = ConcurrentTestExecutor(
            max_retries=3,
            base_delay=0.01,  # Fast retry for testing
            max_delay=0.1,
        )
        call_count = [0]

        def flaky_func():
            call_count[0] += 1
            if call_count[0] < 3:
                raise Exception("rate limit exceeded")
            return "success after retry"

        result = executor.execute_with_retry(flaky_func)
        assert result == "success after retry"
        assert call_count[0] == 3

    def test_max_retries_exceeded(self):
        """Test that exception is raised after max retries."""
        executor = ConcurrentTestExecutor(
            max_retries=3,
            base_delay=0.01,
            max_delay=0.1,
        )
        call_count = [0]

        def always_fails():
            call_count[0] += 1
            raise Exception("rate limit exceeded")

        with pytest.raises(Exception) as exc_info:
            executor.execute_with_retry(always_fails)

        assert "rate limit" in str(exc_info.value).lower()
        # Verify retries really happened before giving up. test_retry_on_rate_limit
        # shows at least 3 attempts with max_retries=3; whether max_retries counts
        # total attempts or retries after the first call, it never exceeds 3 + 1.
        assert 3 <= call_count[0] <= 4

    def test_non_retryable_error_raises_immediately(self):
        """Test that non-retryable errors raise immediately."""
        executor = ConcurrentTestExecutor(max_retries=5)
        call_count = [0]

        def auth_error():
            call_count[0] += 1
            raise Exception("Invalid API key")

        with pytest.raises(Exception) as exc_info:
            executor.execute_with_retry(auth_error)

        assert call_count[0] == 1  # Only called once
        assert "Invalid API key" in str(exc_info.value)


class TestExecuteBatch:
    """Tests for execute_batch method."""

    def test_empty_batch(self):
        """Test processing empty batch."""
        executor = ConcurrentTestExecutor()
        results = executor.execute_batch([], lambda x: x)
        assert results == []

    def test_successful_batch(self):
        """Test successful batch processing."""
        executor = ConcurrentTestExecutor(max_workers=3)
        items = [1, 2, 3, 4, 5]
        results = executor.execute_batch(items, lambda x: x * 2)
        assert results == [2, 4, 6, 8, 10]

    def test_batch_preserves_order(self):
        """Test that batch results maintain original order."""
        executor = ConcurrentTestExecutor(max_workers=5)
        items = list(range(10))

        def slow_process(x):
            time.sleep(0.01 * (10 - x))  # Higher numbers finish first
            return x

        results = executor.execute_batch(items, slow_process)
        assert results == items

    def test_batch_with_failures(self):
        """Test batch with some failing items."""
        executor = ConcurrentTestExecutor(
            max_workers=3,
            max_retries=1,
            base_delay=0.01,
        )

        def process(x):
            if x == 3:
                raise Exception("Invalid API key for item 3")  # Non-retryable
            return x * 2

        items = [1, 2, 3, 4, 5]
        results = executor.execute_batch(items, process)

        assert results[0] == 2
        assert results[1] == 4
        assert "error" in results[2]  # Item 3 failed
        assert results[3] == 8
        assert results[4] == 10

    def test_batch_with_progress_callback(self):
        """Test progress callback during batch processing."""
        executor = ConcurrentTestExecutor(max_workers=2)
        progress_calls = []

        def progress_callback(current, total, message):
            progress_calls.append((current, total, message))

        items = [1, 2, 3]
        executor.execute_batch(items, lambda x: x, progress_callback)

        assert len(progress_calls) == 3  # All items should be completed
        currents = [p[0] for p in progress_calls]
        assert set(currents) == {1, 2, 3}


class TestRateLimiting:
    """Tests for rate limiting functionality."""

    def test_rate_limit_tracking(self):
        """Test that rate limit tracks requests."""
        executor = ConcurrentTestExecutor(
            max_workers=1,
            rate_limit_per_minute=5,
        )

        # Clear any existing tracking
        executor._request_times = []

        # Make a few requests
        for _ in range(3):
            executor._wait_for_rate_limit()

        assert len(executor._request_times) == 3

    def test_rate_limit_clears_old_records(self):
        """Test that old request records are cleared."""
        executor = ConcurrentTestExecutor(rate_limit_per_minute=100)

        # Add old request times (more than 60s ago). Wall-clock time.time() is
        # kept here because these values are compared against the executor's
        # own stored timestamps.
        old_time = time.time() - 120  # 2 minutes ago
        executor._request_times = [old_time, old_time, old_time]

        # This should clear old records
        executor._wait_for_rate_limit()

        # Should only have the new request
        assert len(executor._request_times) == 1


class TestAsyncConcurrentExecutor:
    """Tests for AsyncConcurrentExecutor."""

    def test_init(self):
        """Test async executor initialization."""
        executor = AsyncConcurrentExecutor(
            max_concurrent=10,
            max_retries=5,
            base_delay=2.0,
            max_delay=120.0,
        )
        assert executor.max_concurrent == 10
        assert executor.max_retries == 5
        assert executor.base_delay == 2.0
        assert executor.max_delay == 120.0

    @pytest.mark.asyncio
    async def test_async_execute_with_retry_success(self):
        """Test successful async execution."""
        executor = AsyncConcurrentExecutor()

        async def async_func():
            return "async success"

        result = await executor.execute_with_retry(async_func)
        assert result == "async success"

    @pytest.mark.asyncio
    async def test_async_retry_on_error(self):
        """Test async retry on retryable error."""
        executor = AsyncConcurrentExecutor(
            max_retries=3,
            base_delay=0.01,
            max_delay=0.1,
        )
        call_count = [0]

        async def flaky_async():
            call_count[0] += 1
            if call_count[0] < 3:
                raise Exception("rate limit exceeded")
            return "success"

        result = await executor.execute_with_retry(flaky_async)
        assert result == "success"
        assert call_count[0] == 3

    @pytest.mark.asyncio
    async def test_async_execute_batch(self):
        """Test async batch execution."""
        executor = AsyncConcurrentExecutor(max_concurrent=3)

        async def process(x):
            await asyncio.sleep(0.01)
            return x * 2

        items = [1, 2, 3, 4, 5]
        results = await executor.execute_batch(items, process)
        assert results == [2, 4, 6, 8, 10]

    @pytest.mark.asyncio
    async def test_async_batch_with_failures(self):
        """Test async batch with failures."""
        executor = AsyncConcurrentExecutor(
            max_concurrent=2,
            max_retries=1,
            base_delay=0.01,
        )

        async def process(x):
            if x == 3:
                raise ValueError("Test error")  # Non-retryable
            return x

        items = [1, 2, 3, 4]
        results = await executor.execute_batch(items, process)

        assert results[0] == 1
        assert results[1] == 2
        assert "error" in results[2]
        assert results[3] == 4


class TestCreateExecutor:
    """Tests for create_executor factory function."""

    def test_create_with_defaults(self):
        """Test creating executor with default values."""
        executor = create_executor()
        assert isinstance(executor, ConcurrentTestExecutor)
        assert executor.max_workers == 5
        assert executor.max_retries == 3

    def test_create_with_custom_values(self):
        """Test creating executor with custom values."""
        executor = create_executor(
            max_workers=10,
            max_retries=5,
            base_delay=2.0,
            max_delay=100.0,
            rate_limit_per_minute=120,
        )
        assert executor.max_workers == 10
        assert executor.max_retries == 5
        assert executor.base_delay == 2.0
        assert executor.max_delay == 100.0
        assert executor.rate_limit_per_minute == 120


class TestConcurrencyBehavior:
    """Tests for actual concurrent execution behavior."""

    def test_concurrent_execution_faster_than_serial(self):
        """Test that concurrent execution is faster than serial."""
        executor = ConcurrentTestExecutor(max_workers=5)
        items = list(range(5))

        def slow_func(x):
            time.sleep(0.1)
            return x

        # perf_counter is monotonic, so a system clock adjustment during the
        # test cannot distort the measured duration (time.time() can).
        start = time.perf_counter()
        executor.execute_batch(items, slow_func)
        concurrent_time = time.perf_counter() - start

        # Serial would take ~0.5s, concurrent should be ~0.1s
        # Allow some overhead
        assert concurrent_time < 0.3, f"Concurrent took {concurrent_time}s, expected < 0.3s"

    def test_worker_limit_respected(self):
        """Test that max_workers limit is respected."""
        import threading

        executor = ConcurrentTestExecutor(max_workers=2)
        lock = threading.Lock()
        active_count = [0]
        max_active = [0]

        def track_concurrency(x):
            # ``+=`` and the max() update are read-modify-write sequences that
            # are not atomic across worker threads; without the lock the peak
            # could be mis-measured and a real limit violation go unnoticed.
            with lock:
                active_count[0] += 1
                max_active[0] = max(max_active[0], active_count[0])
            time.sleep(0.05)
            with lock:
                active_count[0] -= 1
            return x

        items = list(range(5))
        executor.execute_batch(items, track_concurrency)

        assert max_active[0] <= 2, f"Max active was {max_active[0]}, expected <= 2"