Spaces:
Running
Running
"""
Comprehensive tests for src/execution/streaming.py.
Covers all StreamEvent types, StreamBuffer, format_event, print/aprint helpers,
stream_to_string / astream_to_string utilities.
"""
| import asyncio | |
| from datetime import datetime | |
| from uuid import uuid4 | |
| import pytest | |
| from execution.streaming import ( | |
| AgentErrorEvent, | |
| AgentOutputEvent, | |
| AgentStartEvent, | |
| BudgetExceededEvent, | |
| BudgetWarningEvent, | |
| FallbackEvent, | |
| MemoryReadEvent, | |
| MemoryWriteEvent, | |
| ParallelEndEvent, | |
| ParallelStartEvent, | |
| PruneEvent, | |
| RunEndEvent, | |
| RunStartEvent, | |
| StreamBuffer, | |
| StreamEvent, | |
| StreamEventType, | |
| TokenEvent, | |
| TopologyChangedEvent, | |
| aprint_stream, | |
| astream_to_string, | |
| format_event, | |
| print_stream, | |
| stream_to_string, | |
| ) | |
def rid():
    """Return a freshly generated UUID4 as a string, used as a run id in tests."""
    fresh = uuid4()
    return str(fresh)
# ─────────────────────────── StreamEvent base ─────────────────────────────────
class TestStreamEventBase:
    """Checks on the plain StreamEvent base class."""

    def test_base_to_dict(self):
        """to_dict() carries the event type and contains a timestamp entry."""
        payload = StreamEvent(event_type="test_event", run_id=rid()).to_dict()
        assert payload["event_type"] == "test_event"
        assert "timestamp" in payload

    def test_timestamp_defaults_to_now(self):
        """An event built without a timestamp still exposes a datetime."""
        event = StreamEvent(event_type="x")
        assert isinstance(event.timestamp, datetime)

    def test_metadata_default_empty(self):
        """Metadata is an empty dict when none is supplied."""
        event = StreamEvent(event_type="x")
        assert event.metadata == {}
# ─────────────────────────── Event constructors ───────────────────────────────
class TestRunStartEvent:
    """Construction checks for RunStartEvent."""

    def test_basic(self):
        """Explicitly supplied fields are readable back from the event."""
        fields = {
            "run_id": rid(),
            "query": "What is 2+2?",
            "num_agents": 3,
            "execution_order": ["a", "b", "c"],
        }
        started = RunStartEvent(**fields)
        assert started.event_type == "run_start"
        assert started.query == "What is 2+2?"
        assert started.num_agents == 3

    def test_defaults(self):
        """Agent count and execution order fall back to 0 and []."""
        started = RunStartEvent(run_id=rid(), query="test")
        assert started.num_agents == 0
        assert started.execution_order == []

    def test_no_run_id(self):
        """run_id is None when the caller does not provide one."""
        assert RunStartEvent(query="test").run_id is None
class TestRunEndEvent:
    """Construction checks for RunEndEvent."""

    def test_basic(self):
        """A successful run end keeps its type tag, success flag and token total."""
        fields = {
            "run_id": rid(),
            "final_answer": "Final answer",
            "success": True,
            "total_time": 1.5,
            "total_tokens": 100,
        }
        finished = RunEndEvent(**fields)
        assert finished.event_type == "run_end"
        assert finished.success is True
        assert finished.total_tokens == 100

    def test_failure(self):
        """success=False is preserved on the event."""
        finished = RunEndEvent(run_id=rid(), final_answer="", success=False)
        assert finished.success is False

    def test_with_errors(self):
        """A single supplied error message is stored in errors."""
        finished = RunEndEvent(run_id=rid(), errors=["agent failed"], success=False)
        assert len(finished.errors) == 1
class TestAgentStartEvent:
    """Construction checks for AgentStartEvent."""

    def test_basic(self):
        """Agent id and step index supplied at construction are readable back."""
        fields = {
            "run_id": rid(),
            "agent_id": "solver",
            "agent_name": "Math Solver",
            "step_index": 1,
        }
        started = AgentStartEvent(**fields)
        assert started.event_type == "agent_start"
        assert started.agent_id == "solver"
        assert started.step_index == 1

    def test_defaults(self):
        """A bare event has an empty id, step 0 and no predecessors."""
        bare = AgentStartEvent()
        assert bare.agent_id == ""
        assert bare.step_index == 0
        assert bare.predecessors == []
class TestAgentOutputEvent:
    """Construction checks for AgentOutputEvent."""

    def test_basic(self):
        """Content and the final flag supplied at construction are readable back."""
        fields = {
            "run_id": rid(),
            "agent_id": "solver",
            "content": "The answer is 42",
            "tokens_used": 20,
            "duration_ms": 150.0,
            "is_final": True,
        }
        produced = AgentOutputEvent(**fields)
        assert produced.event_type == "agent_output"
        assert produced.content == "The answer is 42"
        assert produced.is_final is True

    def test_defaults(self):
        """tokens_used defaults to 0 and is_final to False."""
        produced = AgentOutputEvent(run_id=rid(), agent_id="a", content="output")
        assert produced.tokens_used == 0
        assert produced.is_final is False
class TestAgentErrorEvent:
    """Construction checks for AgentErrorEvent."""

    def test_basic(self):
        """The error message supplied at construction is readable back."""
        fields = {
            "run_id": rid(),
            "agent_id": "faulty",
            "error_message": "Something went wrong",
            "error_type": "RuntimeError",
        }
        failure = AgentErrorEvent(**fields)
        assert failure.event_type == "agent_error"
        assert failure.error_message == "Something went wrong"

    def test_will_retry_default(self):
        """will_retry is False unless requested."""
        failure = AgentErrorEvent(run_id=rid(), agent_id="a", error_message="err")
        assert failure.will_retry is False
class TestTokenEvent:
    """Construction checks for TokenEvent."""

    def test_basic(self):
        """Token text and the is_first flag are readable back."""
        fields = {
            "run_id": rid(),
            "agent_id": "writer",
            "token": "Hello",
            "token_index": 0,
            "is_first": True,
        }
        piece = TokenEvent(**fields)
        assert piece.event_type == "token"
        assert piece.token == "Hello"
        assert piece.is_first is True

    def test_is_last(self):
        """is_last=True is preserved on the event."""
        assert TokenEvent(token=".", is_last=True).is_last is True
class TestPruneEvent:
    """Construction checks for PruneEvent."""

    def test_basic(self):
        """The prune reason supplied at construction is readable back."""
        pruned = PruneEvent(
            run_id=rid(),
            agent_id="pruned_agent",
            reason="low trust score",
        )
        assert pruned.event_type == "prune"
        assert pruned.reason == "low trust score"

    def test_defaults(self):
        """A bare event has an empty agent id and an empty reason."""
        bare = PruneEvent()
        assert bare.agent_id == ""
        assert bare.reason == ""
class TestFallbackEvent:
    """Construction checks for FallbackEvent."""

    def test_basic(self):
        """The failed agent id supplied at construction is readable back."""
        fields = {
            "run_id": rid(),
            "failed_agent_id": "broken_agent",
            "fallback_agent_id": "backup_agent",
        }
        switched = FallbackEvent(**fields)
        assert switched.event_type == "fallback"
        assert switched.failed_agent_id == "broken_agent"

    def test_defaults(self):
        """failed_agent_id is an empty string by default."""
        assert FallbackEvent().failed_agent_id == ""
class TestParallelEvents:
    """Construction checks for ParallelStartEvent and ParallelEndEvent."""

    def test_parallel_start(self):
        """All three supplied agent ids are kept on the start event."""
        opened = ParallelStartEvent(
            run_id=rid(),
            agent_ids=["a", "b", "c"],
            group_index=0,
        )
        assert opened.event_type == "parallel_start"
        assert len(opened.agent_ids) == 3

    def test_parallel_end(self):
        """Successful and failed agent ids are kept in their own lists."""
        fields = {
            "run_id": rid(),
            "agent_ids": ["a", "b"],
            "group_index": 0,
            "successful": ["a"],
            "failed": ["b"],
        }
        closed = ParallelEndEvent(**fields)
        assert closed.event_type == "parallel_end"
        assert "a" in closed.successful
        assert "b" in closed.failed

    def test_parallel_end_defaults(self):
        """successful and failed both default to empty lists."""
        bare = ParallelEndEvent()
        assert bare.successful == []
        assert bare.failed == []
class TestMemoryEvents:
    """Construction checks for MemoryReadEvent and MemoryWriteEvent."""

    def test_memory_read(self):
        """The entries count supplied at construction is readable back."""
        fetched = MemoryReadEvent(run_id=rid(), agent_id="agent1", entries_count=5)
        assert fetched.event_type == "memory_read"
        assert fetched.entries_count == 5

    def test_memory_write(self):
        """The key supplied at construction is readable back."""
        stored = MemoryWriteEvent(
            run_id=rid(),
            agent_id="agent1",
            key="answer",
            value_size=42,
        )
        assert stored.event_type == "memory_write"
        assert stored.key == "answer"

    def test_memory_read_defaults(self):
        """entries_count is 0 by default."""
        assert MemoryReadEvent().entries_count == 0
class TestBudgetEvents:
    """Construction checks for BudgetWarningEvent and BudgetExceededEvent."""

    def test_budget_warning(self):
        """A warning event keeps its type tag and budget type."""
        fields = {
            "run_id": rid(),
            "budget_type": "tokens",
            "current": 800.0,
            "limit": 1000.0,
            "ratio": 0.8,
        }
        warning = BudgetWarningEvent(**fields)
        assert warning.event_type == "budget_warning"
        assert warning.budget_type == "tokens"

    def test_budget_exceeded(self):
        """An exceeded event keeps its type tag and budget type."""
        fields = {
            "run_id": rid(),
            "budget_type": "requests",
            "current": 100.0,
            "limit": 100.0,
        }
        exceeded = BudgetExceededEvent(**fields)
        assert exceeded.event_type == "budget_exceeded"
        assert exceeded.budget_type == "requests"
class TestTopologyChangedEvent:
    """Construction checks for TopologyChangedEvent."""

    def test_basic(self):
        """The previous remaining-agent list is kept on the event."""
        fields = {
            "run_id": rid(),
            "reason": "agent pruned",
            "old_remaining": ["a", "b", "c"],
            "new_remaining": ["b", "c"],
        }
        changed = TopologyChangedEvent(**fields)
        assert changed.event_type == "topology_changed"
        assert "a" in changed.old_remaining

    def test_defaults(self):
        """reason defaults to "" and old_remaining to []."""
        bare = TopologyChangedEvent()
        assert bare.reason == ""
        assert bare.old_remaining == []
# ─────────────────────────── format_event ─────────────────────────────────────
class TestFormatEvent:
    """Smoke tests: format_event returns a str for each event class tried here."""

    def test_format_run_start(self):
        """A run-start event formats to a non-empty string."""
        rendered = format_event(RunStartEvent(run_id=rid(), query="test", num_agents=2))
        assert isinstance(rendered, str)
        assert len(rendered) > 0

    def test_format_agent_output(self):
        """An agent-output event formats to a string."""
        event = AgentOutputEvent(run_id=rid(), agent_id="agent1", content="output text")
        assert isinstance(format_event(event), str)

    def test_format_run_end(self):
        """A run-end event formats to a string."""
        event = RunEndEvent(run_id=rid(), final_answer="final", success=True)
        assert isinstance(format_event(event), str)

    def test_format_token_event(self):
        """A token event formats to a string."""
        event = TokenEvent(run_id=rid(), agent_id="a", token="hello")
        assert isinstance(format_event(event), str)

    def test_format_agent_error(self):
        """An agent-error event formats to a string."""
        event = AgentErrorEvent(run_id=rid(), agent_id="bad", error_message="oops")
        assert isinstance(format_event(event), str)

    def test_format_prune(self):
        """A prune event formats to a string."""
        event = PruneEvent(run_id=rid(), agent_id="a", reason="trust too low")
        assert isinstance(format_event(event), str)

    def test_format_fallback(self):
        """A fallback event formats to a string."""
        event = FallbackEvent(run_id=rid(), failed_agent_id="a", fallback_agent_id="b")
        assert isinstance(format_event(event), str)

    def test_format_budget_warning(self):
        """A budget-warning event formats to a string."""
        event = BudgetWarningEvent(budget_type="tokens", current=80.0, limit=100.0)
        assert isinstance(format_event(event), str)

    def test_format_budget_exceeded(self):
        """A budget-exceeded event formats to a string."""
        event = BudgetExceededEvent(budget_type="time", current=60.0, limit=60.0)
        assert isinstance(format_event(event), str)

    def test_format_agent_output_verbose(self):
        """A 200-character output formats to a string with verbose=True."""
        event = AgentOutputEvent(agent_id="a", content="x" * 200)
        assert isinstance(format_event(event, verbose=True), str)

    def test_format_topology_changed(self):
        """A topology-changed event formats to a string."""
        event = TopologyChangedEvent(reason="pruned", old_remaining=["a"], new_remaining=[])
        assert isinstance(format_event(event), str)

    def test_format_parallel_start(self):
        """A parallel-start event formats to a string."""
        event = ParallelStartEvent(agent_ids=["a", "b"])
        assert isinstance(format_event(event), str)

    def test_format_parallel_end(self):
        """A parallel-end event formats to a string."""
        event = ParallelEndEvent(agent_ids=["a"], successful=["a"], failed=[])
        assert isinstance(format_event(event), str)

    def test_format_memory_read(self):
        """A memory-read event formats to a string."""
        event = MemoryReadEvent(agent_id="a", entries_count=3)
        assert isinstance(format_event(event), str)

    def test_format_memory_write(self):
        """A memory-write event formats to a string."""
        event = MemoryWriteEvent(agent_id="a", key="k", value_size=10)
        assert isinstance(format_event(event), str)

    def test_format_agent_start(self):
        """An agent-start event formats to a string."""
        event = AgentStartEvent(agent_id="a", step_index=0)
        assert isinstance(format_event(event), str)
# ─────────────────────────── StreamBuffer ─────────────────────────────────────
class TestStreamBuffer:
    """Checks on how StreamBuffer records events and derives outputs from them."""

    def test_add_and_events(self):
        """An added event is stored as the very same object."""
        buffer = StreamBuffer()
        started = RunStartEvent(run_id=rid(), query="test")
        buffer.add(started)
        assert len(buffer.events) == 1
        assert buffer.events[0] is started

    def test_add_agent_output_updates_final_answer(self):
        """A final agent output sets both final_answer and final_agent_id."""
        buffer = StreamBuffer()
        buffer.add(AgentOutputEvent(agent_id="a", content="The answer", is_final=True))
        assert buffer.final_answer == "The answer"
        assert buffer.final_agent_id == "a"

    def test_add_run_end_updates_final_answer(self):
        """A run-end event sets final_answer."""
        buffer = StreamBuffer()
        finished = RunEndEvent(
            final_answer="Run answer",
            final_agent_id="agent1",
            success=True,
        )
        buffer.add(finished)
        assert buffer.final_answer == "Run answer"

    def test_add_token_events_accumulate(self):
        """Tokens streamed for one agent show up combined in agent_outputs."""
        buffer = StreamBuffer()
        pieces = (
            {"token": "Hel", "is_first": True},
            {"token": "lo"},
            {"token": "!", "is_last": True},
        )
        for piece in pieces:
            buffer.add(TokenEvent(agent_id="w", **piece))
        assert "Hello!" in buffer.agent_outputs.get("w", "")

    def test_get_output_for(self):
        """get_output_for returns stored content, or "" for an unknown agent."""
        buffer = StreamBuffer()
        buffer.add(AgentOutputEvent(agent_id="writer", content="Written text"))
        assert buffer.get_output_for("writer") == "Written text"
        assert buffer.get_output_for("unknown") == ""

    def test_get_output_for_in_progress_tokens(self):
        """get_output_for surfaces streamed tokens before any AgentOutputEvent arrives."""
        buffer = StreamBuffer()
        # Only token events are added; no AgentOutputEvent exists for this agent.
        buffer.add(TokenEvent(agent_id="streamer", token="tok1", is_first=True))
        buffer.add(TokenEvent(agent_id="streamer", token="tok2"))
        partial = buffer.get_output_for("streamer")
        assert "tok1" in partial or "tok2" in partial

    def test_agent_outputs(self):
        """Each agent that produced output has its own key in agent_outputs."""
        buffer = StreamBuffer()
        for agent, text in (("a", "output_a"), ("b", "output_b")):
            buffer.add(AgentOutputEvent(agent_id=agent, content=text))
        assert "a" in buffer.agent_outputs
        assert "b" in buffer.agent_outputs

    def test_clear(self):
        """clear() drops stored events and resets final_answer to ""."""
        buffer = StreamBuffer()
        buffer.add(AgentOutputEvent(agent_id="a", content="data", is_final=True))
        buffer.clear()
        assert len(buffer.events) == 0
        assert buffer.final_answer == ""

    def test_buffer_init(self):
        """A new buffer has an empty final answer and no events."""
        buffer = StreamBuffer()
        assert buffer.final_answer == ""
        assert buffer.events == []
# ─────────────────────────── stream_to_string ─────────────────────────────────
class TestStreamToString:
    """Smoke tests: stream_to_string returns a str for several synchronous streams."""

    def test_collects_agent_output(self):
        """A final agent output followed by a run end yields a string."""
        def events():
            yield AgentOutputEvent(run_id=rid(), agent_id="a", content="Hello World", is_final=True)
            yield RunEndEvent(run_id=rid(), final_answer="Hello World", success=True)

        collected = stream_to_string(events())
        assert isinstance(collected, str)

    def test_with_token_events_updates_current_agent(self):
        """A stream containing TokenEvents between run start and run end yields a string."""
        def events():
            yield RunStartEvent(run_id=rid(), query="test")
            yield TokenEvent(agent_id="a", token="Hello", is_first=True)
            yield TokenEvent(agent_id="a", token=" World", is_last=True)
            yield RunEndEvent(run_id=rid(), final_answer="Hello World", success=True)

        collected = stream_to_string(events())
        assert isinstance(collected, str)

    def test_empty_stream_returns_string(self):
        """A stream holding only a run end with an empty answer yields a string."""
        def events():
            yield RunEndEvent(run_id=rid(), final_answer="", success=True)

        collected = stream_to_string(events())
        assert isinstance(collected, str)

    def test_multiple_agents(self):
        """Outputs from two agents followed by a run end yield a string."""
        def events():
            yield AgentOutputEvent(agent_id="a", content="Part 1 ")
            yield AgentOutputEvent(agent_id="b", content="Part 2")
            yield RunEndEvent(final_answer="done", success=True)

        collected = stream_to_string(events())
        assert isinstance(collected, str)
class TestAStreamToString:
    """Smoke tests: astream_to_string returns a str for async event streams.

    NOTE(review): these coroutine tests carry no explicit asyncio marker, so
    they presumably rely on an asyncio pytest plugin running in auto mode.
    Confirm against the project's pytest configuration.
    """

    async def test_collects_agent_output(self):
        """A final agent output followed by a run end yields a string."""
        async def events():
            yield AgentOutputEvent(run_id=rid(), agent_id="a", content="Hello", is_final=True)
            yield RunEndEvent(run_id=rid(), final_answer="Hello", success=True)

        collected = await astream_to_string(events())
        assert isinstance(collected, str)

    async def test_empty_async_stream(self):
        """A stream holding only a run end with an empty answer yields a string."""
        async def events():
            yield RunEndEvent(run_id=rid(), final_answer="", success=True)

        collected = await astream_to_string(events())
        assert isinstance(collected, str)
# ─────────────────────────── print_stream / aprint_stream ─────────────────────
class TestPrintStream:
    """Smoke tests: print_stream and aprint_stream consume a stream without raising."""

    def test_print_stream_runs(self, capsys):
        """print_stream consumes a start/output/end stream; nothing is asserted on output."""
        def events():
            yield RunStartEvent(run_id=rid(), query="test", num_agents=1)
            yield AgentOutputEvent(run_id=rid(), agent_id="a", content="output")
            yield RunEndEvent(run_id=rid(), final_answer="output", success=True)

        # Completing without an exception is the whole check.
        print_stream(events())

    async def test_aprint_stream_runs(self, capsys):
        """aprint_stream consumes a start/end async stream; nothing is asserted on output.

        NOTE(review): no explicit asyncio marker here, so this presumably
        relies on an asyncio pytest plugin in auto mode. Confirm.
        """
        async def events():
            yield RunStartEvent(run_id=rid(), query="test")
            yield RunEndEvent(run_id=rid(), final_answer="done", success=True)

        # Completing without an exception is the whole check.
        await aprint_stream(events())
# ─────────────────────────── StreamEventType enum ─────────────────────────────
class TestStreamEventType:
    """Checks on the values exposed by the StreamEventType enum."""

    def test_all_values_are_strings(self):
        """Every enum member's value is a str."""
        for member in StreamEventType:
            assert isinstance(member.value, str)

    def test_run_start_value(self):
        """RUN_START maps to "run_start"."""
        value = StreamEventType.RUN_START.value
        assert value == "run_start"

    def test_agent_output_value(self):
        """AGENT_OUTPUT maps to "agent_output"."""
        value = StreamEventType.AGENT_OUTPUT.value
        assert value == "agent_output"

    def test_token_value(self):
        """TOKEN maps to "token"."""
        value = StreamEventType.TOKEN.value
        assert value == "token"

    def test_all_expected_types_present(self):
        """The listed event type names all appear among the enum values."""
        types = {member.value for member in StreamEventType}
        expected = (
            "run_start",
            "run_end",
            "agent_start",
            "agent_output",
            "agent_error",
            "token",
            "prune",
            "fallback",
        )
        for name in expected:
            assert name in types