Spaces:
Runtime error
Runtime error
File size: 8,855 Bytes
44655c5 2770657 44655c5 2770657 44655c5 2770657 44655c5 2770657 44655c5 2770657 44655c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 |
"""
Tests for TaskExecutor - backend-agnostic task execution.
These tests define the expected behavior of TaskExecutor BEFORE implementation.
They should FAIL initially (TaskExecutor doesn't exist yet).
Phase 2 of #73 adapter refactor.
"""
import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock
class TestTaskExecutor:
    """Behavioural spec for TaskExecutor, the backend-agnostic execution engine.

    Each test hands the executor a MagicMock adapter whose
    ``stream_completion`` is an async generator, then inspects the
    TaskResult objects yielded by ``TaskExecutor.execute``. Names from
    ``prompt_prix.executor`` are imported inside the tests (and the task
    helper), so this module still imports while that package is missing.
    """

    @staticmethod
    def _make_task(task_id="test-1", model_id="model-a", content="Hello", **extra):
        """Return a Task with one user message and the shared test settings.

        Keyword arguments in ``extra`` (e.g. ``tools``) are forwarded to Task
        only when a test supplies them.
        """
        from prompt_prix.executor import Task

        return Task(
            id=task_id,
            model_id=model_id,
            messages=[{"role": "user", "content": content}],
            temperature=0.7,
            max_tokens=100,
            timeout_seconds=60,
            **extra,
        )

    @pytest.fixture
    def mock_adapter(self):
        """Provide a mock HostAdapter with concurrency limit 2 that streams "Hello " then "world!"."""

        async def two_chunk_stream(*args, **kwargs):
            for chunk in ("Hello ", "world!"):
                yield chunk

        adapter = MagicMock()
        adapter.get_concurrency_limit.return_value = 2
        adapter.stream_completion = two_chunk_stream
        return adapter

    @pytest.mark.asyncio
    async def test_execute_single_task(self, mock_adapter):
        """One task yields one successful result whose response joins the streamed chunks."""
        from prompt_prix.executor import TaskExecutor

        executor = TaskExecutor(mock_adapter)
        results = [r async for r in executor.execute([self._make_task()])]

        assert len(results) == 1
        (result,) = results
        assert result.task_id == "test-1"
        assert result.model_id == "model-a"
        assert result.status == "success"
        assert result.response == "Hello world!"

    @pytest.mark.asyncio
    async def test_execute_multiple_tasks(self, mock_adapter):
        """Every submitted task produces exactly one result."""
        from prompt_prix.executor import TaskExecutor

        executor = TaskExecutor(mock_adapter)
        batch = [
            self._make_task("t1", model_id="model-a", content="Hi"),
            self._make_task("t2", model_id="model-b", content="Hey"),
        ]
        results = [r async for r in executor.execute(batch)]

        assert len(results) == 2
        assert {r.task_id for r in results} == {"t1", "t2"}

    @pytest.mark.asyncio
    async def test_error_handling(self, mock_adapter):
        """An exception raised by stream_completion is reported in the TaskResult."""
        from prompt_prix.executor import TaskExecutor

        async def failing_stream(*args, **kwargs):
            raise RuntimeError("Model error")
            yield  # never reached; only here to make this an async generator

        mock_adapter.stream_completion = failing_stream
        executor = TaskExecutor(mock_adapter)
        results = [r async for r in executor.execute([self._make_task()])]

        # The failure must surface as an error result, not propagate out of execute().
        assert len(results) == 1
        assert results[0].status == "error"
        assert "Model error" in results[0].error

    @pytest.mark.asyncio
    async def test_respects_concurrency_limit(self, mock_adapter):
        """No more streams run at once than adapter.get_concurrency_limit() allows."""
        from prompt_prix.executor import TaskExecutor

        stats = {"active": 0, "peak": 0}

        async def tracking_stream(*args, **kwargs):
            stats["active"] += 1
            stats["peak"] = max(stats["peak"], stats["active"])
            try:
                # Keep the stream open briefly so overlapping calls are observable.
                await asyncio.sleep(0.05)
                yield "response"
            finally:
                stats["active"] -= 1

        mock_adapter.stream_completion = tracking_stream
        mock_adapter.get_concurrency_limit.return_value = 2
        executor = TaskExecutor(mock_adapter)
        batch = [self._make_task(f"t{i}", content="Hi") for i in range(5)]
        results = [r async for r in executor.execute(batch)]

        assert len(results) == 5
        assert stats["peak"] <= 2

    @pytest.mark.asyncio
    async def test_result_includes_duration(self, mock_adapter):
        """TaskResult.duration_ms reports the elapsed time in milliseconds."""
        from prompt_prix.executor import TaskExecutor

        async def slow_stream(*args, **kwargs):
            await asyncio.sleep(0.01)  # gives the measurement a 10 ms lower bound
            yield "response"

        mock_adapter.stream_completion = slow_stream
        executor = TaskExecutor(mock_adapter)
        results = [r async for r in executor.execute([self._make_task()])]

        assert results[0].duration_ms >= 10

    @pytest.mark.asyncio
    async def test_passes_tools_to_adapter(self, mock_adapter):
        """Task.tools reaches adapter.stream_completion as the ``tools`` keyword."""
        from prompt_prix.executor import TaskExecutor

        seen_kwargs = {}

        async def capturing_stream(*args, **kwargs):
            seen_kwargs.update(kwargs)
            yield "response"

        mock_adapter.stream_completion = capturing_stream
        tools = [{"type": "function", "function": {"name": "get_weather"}}]
        executor = TaskExecutor(mock_adapter)
        task = self._make_task(content="Weather?", tools=tools)
        # Drain the results; only the captured keyword arguments matter here.
        _ = [r async for r in executor.execute([task])]

        assert seen_kwargs.get("tools") == tools
class TestTaskDataclass:
    """Structural checks on the Task dataclass exported by prompt_prix.executor."""

    def test_task_required_fields(self):
        """Each required constructor argument is stored on the matching attribute."""
        from prompt_prix.executor import Task

        def spec():
            # Build a fresh dict on every call so the expected values are
            # independent objects from the ones handed to Task.
            return {
                "id": "test-1",
                "model_id": "model-a",
                "messages": [{"role": "user", "content": "Hello"}],
                "temperature": 0.7,
                "max_tokens": 100,
                "timeout_seconds": 60,
            }

        task = Task(**spec())
        for attr, expected in spec().items():
            assert getattr(task, attr) == expected, attr

    def test_task_optional_fields(self):
        """The optional ``tools`` argument is accepted and stored unchanged."""
        from prompt_prix.executor import Task

        base = {
            "id": "test-1",
            "model_id": "model-a",
            "messages": [{"role": "user", "content": "Hello"}],
            "temperature": 0.7,
            "max_tokens": 100,
            "timeout_seconds": 60,
        }
        task = Task(**base, tools=[{"type": "function", "function": {"name": "test"}}])

        assert task.tools == [{"type": "function", "function": {"name": "test"}}]
class TestTaskResultDataclass:
    """Structural checks on the TaskResult dataclass exported by prompt_prix.executor."""

    def test_taskresult_success(self):
        """A successful result keeps every field it was given and leaves ``error`` as None."""
        from prompt_prix.executor import TaskResult

        given = {
            "task_id": "test-1",
            "model_id": "model-a",
            "response": "Hello world!",
            "status": "success",
            "duration_ms": 150,
        }
        result = TaskResult(**given)

        for name, value in given.items():
            assert getattr(result, name) == value, name
        # ``error`` is never passed here; this test requires its default to be None.
        assert result.error is None

    def test_taskresult_error(self):
        """A failed result carries an ``error`` status together with its message."""
        from prompt_prix.executor import TaskResult

        result = TaskResult(
            task_id="test-1",
            model_id="model-a",
            response="",
            status="error",
            duration_ms=50,
            error="Connection timeout",
        )

        assert (result.status, result.error) == ("error", "Connection timeout")
|