# chatterbox-voice-studio / tests/test_progress.py
# Author: techfreakworm
# Commit 422829d (unverified): feat(progress): ProgressBus with sessions, ticks,
# and turn-complete events

import asyncio
import pytest
from server.progress import ProgressEvent, get_bus
# Every test in this module is a coroutine; mark them all for pytest-asyncio.
pytestmark = [pytest.mark.asyncio]


async def test_subscribe_receives_published_events(reset_progress_bus):
    """An event published while subscribed is delivered to the subscriber's queue."""
    progress_bus = get_bus()
    async with progress_bus.subscribe() as queue:
        outgoing = ProgressEvent(type="tick", elapsed_s=0.1, payload={"foo": 1})
        await progress_bus.publish(outgoing)
        incoming = await asyncio.wait_for(queue.get(), 0.5)
    assert incoming.type == "tick"
    assert incoming.payload == {"foo": 1}


async def test_two_subscribers_both_receive_events(reset_progress_bus):
    """Each of two concurrent subscribers receives a single published event."""
    progress_bus = get_bus()
    async with progress_bus.subscribe() as first, progress_bus.subscribe() as second:
        await progress_bus.publish(ProgressEvent(type="tick", elapsed_s=0.0))
        delivered = [
            await asyncio.wait_for(queue.get(), 0.5) for queue in (first, second)
        ]
    assert [event.type for event in delivered] == ["tick", "tick"]


async def test_session_emits_start_and_done(reset_progress_bus):
    """A successful session emits ``start`` first and ``done`` last.

    The ``done`` payload must report the seed set via ``set_seed``.
    """
    bus = get_bus()
    received: list[ProgressEvent] = []
    subscribed = asyncio.Event()

    async def collect():
        async with bus.subscribe() as q:
            # Set only after subscribe() has been entered. Assuming it registers
            # the queue on entry, the session's "start" cannot be missed. A late
            # joiner would get a snapshot "tick" instead; a bare sleep(0) does
            # not guarantee this ordering.
            subscribed.set()
            while True:
                received.append(await q.get())
                if received[-1].type == "done":
                    return

    consumer = asyncio.create_task(collect())
    try:
        await asyncio.wait_for(subscribed.wait(), 1.0)
        async with bus.session("single", total_turns=1) as sess:
            sess.set_seed(42)
        await asyncio.wait_for(consumer, 1.0)
    finally:
        # No-op if the consumer already finished; otherwise avoid leaking a
        # pending task when the session (or a wait) fails.
        consumer.cancel()

    types = [e.type for e in received]
    assert types[0] == "start"
    assert types[-1] == "done"
    done_payload = received[-1].payload
    assert done_payload["seed_used"] == 42


async def test_session_emits_error_on_exception_and_reraises(reset_progress_bus):
    """An exception inside a session is re-raised and ends it with an ``error`` event.

    The ``error`` event's payload must carry the exception message.
    """
    bus = get_bus()
    received: list[ProgressEvent] = []
    subscribed = asyncio.Event()

    async def collect():
        async with bus.subscribe() as q:
            # Set only after subscribe() has been entered (see the start/done
            # test), rather than relying on a single sleep(0) for ordering.
            subscribed.set()
            while True:
                received.append(await q.get())
                if received[-1].type in ("done", "error"):
                    return

    consumer = asyncio.create_task(collect())
    try:
        await asyncio.wait_for(subscribed.wait(), 1.0)
        # match= ensures the original exception is re-raised, not a new one.
        with pytest.raises(RuntimeError, match="boom"):
            async with bus.session("single", total_turns=1):
                raise RuntimeError("boom")
        await asyncio.wait_for(consumer, 1.0)
    finally:
        # No-op if the consumer already finished; otherwise avoid leaking it.
        consumer.cancel()

    # collect() stops at the first terminal event, so the last received event
    # is the one that ended the session: it must be the error (not "done"),
    # and it is the event that must carry the exception message.
    terminal = received[-1]
    assert terminal.type == "error"
    assert terminal.payload.get("message") == "boom"


async def test_turn_complete_event_carries_turn_payload(reset_progress_bus):
    """Each ``turn_complete`` call emits an event with its turn number and the total."""
    bus = get_bus()
    received: list[ProgressEvent] = []
    subscribed = asyncio.Event()

    async def collect():
        async with bus.subscribe() as q:
            # Set only after subscribe() has been entered, so no session event is
            # published before we are listening (a sleep(0) does not guarantee it).
            subscribed.set()
            while True:
                received.append(await q.get())
                if received[-1].type == "done":
                    return

    consumer = asyncio.create_task(collect())
    try:
        await asyncio.wait_for(subscribed.wait(), 1.0)
        async with bus.session("dialog", total_turns=3) as sess:
            for turn in (1, 2, 3):
                await sess.turn_complete(turn)
        await asyncio.wait_for(consumer, 1.0)
    finally:
        # No-op if the consumer already finished; otherwise avoid leaking it.
        consumer.cancel()

    turn_events = [e for e in received if e.type == "turn_complete"]
    assert [e.payload["turn"] for e in turn_events] == [1, 2, 3]
    assert all(e.payload["total_turns"] == 3 for e in turn_events)


async def test_late_subscriber_gets_snapshot(reset_progress_bus):
    """A subscriber that joins mid-session first receives a ``tick`` snapshot.

    The snapshot describes the session's kind, current turn and total turns.
    """
    bus = get_bus()

    async def first_event() -> ProgressEvent:
        async with bus.subscribe() as queue:
            return await asyncio.wait_for(queue.get(), 1.0)

    async with bus.session("dialog", total_turns=4) as sess:
        await sess.turn_complete(2)
        # Subscribe only now, while the session is already in progress.
        snapshot = await asyncio.wait_for(asyncio.create_task(first_event()), 1.0)

    assert snapshot.type == "tick"
    expected = {"kind": "dialog", "turn": 2, "total_turns": 4}
    assert {key: snapshot.payload[key] for key in expected} == expected