| import asyncio |
|
|
| import pytest |
|
|
| from server.progress import ProgressEvent, get_bus |
|
|
|
|
# Module-level pytest marker: pytest applies `pytest.mark.asyncio` to every
# test collected from this module (the tests visible here are coroutines).
| pytestmark = pytest.mark.asyncio |
|
|
|
|
async def test_subscribe_receives_published_events(reset_progress_bus):
    """An event published while subscribed arrives on the subscriber's queue.

    Checks that the delivered event keeps the published ``type`` and
    ``payload``. ``reset_progress_bus`` is a fixture defined outside this
    file; it is requested but never referenced in the body (presumably it
    resets the shared bus between tests -- confirm against conftest).
    """
    progress_bus = get_bus()
    async with progress_bus.subscribe() as queue:
        outgoing = ProgressEvent(type="tick", elapsed_s=0.1, payload={"foo": 1})
        await progress_bus.publish(outgoing)
        # Bounded wait: a missing delivery fails the test instead of hanging.
        delivered = await asyncio.wait_for(queue.get(), 0.5)
        assert delivered.type == "tick"
        assert delivered.payload == {"foo": 1}
|
|
|
|
async def test_two_subscribers_both_receive_events(reset_progress_bus):
    """A single publish is delivered to each of two concurrent subscribers.

    ``reset_progress_bus`` is a fixture defined outside this file; it is
    requested but never referenced in the body.
    """
    progress_bus = get_bus()
    async with progress_bus.subscribe() as first_queue:
        async with progress_bus.subscribe() as second_queue:
            await progress_bus.publish(ProgressEvent(type="tick", elapsed_s=0.0))
            # Each queue is read with its own timeout so a subscriber that
            # was skipped makes the test fail rather than block forever.
            from_first = await asyncio.wait_for(first_queue.get(), 0.5)
            from_second = await asyncio.wait_for(second_queue.get(), 0.5)
            assert from_first.type == "tick"
            assert from_second.type == "tick"
|
|
|
|
async def test_session_emits_start_and_done(reset_progress_bus):
    """A session brackets its events with ``start`` first and ``done`` last.

    Also checks that the seed passed to ``set_seed`` inside the session is
    reported as ``seed_used`` in the payload of the ``done`` event.
    ``reset_progress_bus`` is a fixture defined outside this file.
    """
    progress_bus = get_bus()
    seen: list[ProgressEvent] = []

    async def drain_until_done():
        # Record every event from the subscription, stopping after "done".
        async with progress_bus.subscribe() as queue:
            while True:
                event = await queue.get()
                seen.append(event)
                if event.type == "done":
                    break

    collector = asyncio.create_task(drain_until_done())
    # Yield once to the event loop so the collector task gets a chance to
    # run before the session starts emitting.
    await asyncio.sleep(0)

    async with progress_bus.session("single", total_turns=1) as sess:
        sess.set_seed(42)

    # Bounded wait: fail instead of hanging if "done" never arrives.
    await asyncio.wait_for(collector, 1.0)

    kinds = [event.type for event in seen]
    assert kinds[0] == "start"
    assert kinds[-1] == "done"
    final_event = seen[-1]
    assert final_event.payload["seed_used"] == 42
|
|
|
|
async def test_session_emits_error_on_exception_and_reraises(reset_progress_bus):
    """An exception inside a session yields an ``error`` event and propagates.

    The ``RuntimeError`` raised in the session body must escape the
    ``async with`` (checked via ``pytest.raises``), and some collected event
    must carry the exception text under the ``message`` payload key.
    ``reset_progress_bus`` is a fixture defined outside this file.
    """
    progress_bus = get_bus()
    seen: list[ProgressEvent] = []

    async def drain_until_finished():
        # Record events until the first terminal one ("done" or "error").
        async with progress_bus.subscribe() as queue:
            while True:
                event = await queue.get()
                seen.append(event)
                if event.type in ("done", "error"):
                    break

    collector = asyncio.create_task(drain_until_finished())
    # Yield once to the event loop so the collector task gets a chance to
    # run before the session starts emitting.
    await asyncio.sleep(0)

    with pytest.raises(RuntimeError):
        async with progress_bus.session("single", total_turns=1):
            raise RuntimeError("boom")

    # Bounded wait: fail instead of hanging if no terminal event arrives.
    await asyncio.wait_for(collector, 1.0)

    assert any(event.type == "error" for event in seen)
    assert any(event.payload.get("message") == "boom" for event in seen)
|
|
|
|
async def test_turn_complete_event_carries_turn_payload(reset_progress_bus):
    """Each ``turn_complete`` call emits an event with its turn and the total.

    Three turns are completed in order within a three-turn session; the
    ``turn_complete`` events must report turns ``1, 2, 3`` in that order and
    ``total_turns == 3`` on every one of them.
    ``reset_progress_bus`` is a fixture defined outside this file.
    """
    progress_bus = get_bus()
    seen: list[ProgressEvent] = []

    async def drain_until_done():
        # Record every event from the subscription, stopping after "done".
        async with progress_bus.subscribe() as queue:
            while True:
                event = await queue.get()
                seen.append(event)
                if event.type == "done":
                    break

    collector = asyncio.create_task(drain_until_done())
    # Yield once to the event loop so the collector task gets a chance to
    # run before the session starts emitting.
    await asyncio.sleep(0)

    async with progress_bus.session("dialog", total_turns=3) as sess:
        for turn_number in (1, 2, 3):
            await sess.turn_complete(turn_number)

    # Bounded wait: fail instead of hanging if "done" never arrives.
    await asyncio.wait_for(collector, 1.0)

    turn_payloads = [
        event.payload for event in seen if event.type == "turn_complete"
    ]
    assert [payload["turn"] for payload in turn_payloads] == [1, 2, 3]
    assert all(payload["total_turns"] == 3 for payload in turn_payloads)
|
|
|
|
async def test_late_subscriber_gets_snapshot(reset_progress_bus):
    """A subscriber joining mid-session immediately receives a state snapshot.

    The subscriber is started only after turn 2 of a four-turn ``dialog``
    session has completed and while the session is still open. The first
    event it reads must be a ``tick`` whose payload reports the session
    kind, the current turn and the total number of turns.
    ``reset_progress_bus`` is a fixture defined outside this file.
    """
    progress_bus = get_bus()
    seen: list[ProgressEvent] = []

    async def take_first_event():
        # Subscribe and record only the first event delivered.
        async with progress_bus.subscribe() as queue:
            first = await asyncio.wait_for(queue.get(), 1.0)
            seen.append(first)

    async with progress_bus.session("dialog", total_turns=4) as sess:
        await sess.turn_complete(2)

        # Subscribe late: the session is already under way at this point.
        late_subscriber = asyncio.create_task(take_first_event())
        await asyncio.wait_for(late_subscriber, 1.0)

    snapshot = seen[0]
    assert snapshot.type == "tick"
    assert snapshot.payload["kind"] == "dialog"
    assert snapshot.payload["turn"] == 2
    assert snapshot.payload["total_turns"] == 4
|
|