# gMAS / tests / test_core_events.py
# Author: Артём Боярских
# Commit message: chore: initial commit
# Commit: 3193174
"""Tests for src/core/events.py"""
import pytest
from core.events import (
BudgetEvent,
BudgetExceededEvent,
BudgetWarningEvent,
CallableHandler,
EdgeAddedEvent,
EdgeRemovedEvent,
EdgeUpdatedEvent,
Event,
EventBus,
EventHandler,
EventPriority,
EventType,
GlobalEventBus,
GraphEvent,
LoggingEventHandler,
MemoryExpiredEvent,
MemoryReadEvent,
MemoryWriteEvent,
MetricsEventHandler,
NodeAddedEvent,
NodeRemovedEvent,
NodeReplacedEvent,
RunCompletedEvent,
RunStartedEvent,
StepCompletedEvent,
StepFailedEvent,
StepRetriedEvent,
StepStartedEvent,
emit_event,
global_event_bus,
on_event,
)
# ─────────────────────────── Event Base ──────────────────────────────────────
class TestEventBase:
    """Checks on attributes every event exposes: type, priority, timestamp, source, metadata."""

    def test_create_event(self):
        """An event built with an explicit type reports that type, NORMAL priority and a timestamp."""
        created = NodeAddedEvent(event_type=EventType.NODE_ADDED, node_id="solver")

        assert created.event_type == EventType.NODE_ADDED
        assert created.priority == EventPriority.NORMAL
        assert created.timestamp is not None

    def test_to_dict(self):
        """``to_dict`` exposes the event type value plus ``timestamp`` and ``source`` keys."""
        payload = NodeAddedEvent(node_id="solver").to_dict()

        assert payload["event_type"] == EventType.NODE_ADDED.value
        for expected_key in ("timestamp", "source"):
            assert expected_key in payload

    def test_event_with_source(self):
        """The ``source`` passed at construction is readable back from the event."""
        started = RunStartedEvent(source="runner", query="test")
        assert started.source == "runner"

    def test_event_with_metadata(self):
        """Metadata passed at construction is readable back by key."""
        tagged = NodeAddedEvent(node_id="n1", metadata={"tag": "test"})
        assert tagged.metadata["tag"] == "test"
# ─────────────────────────── Graph Events ────────────────────────────────────
class TestGraphEvents:
    """Constructor round-trips for node and edge events: fields passed in are read back out."""

    def test_node_added_event(self):
        """``NodeAddedEvent`` keeps its node id and the list of connected nodes."""
        added = NodeAddedEvent(node_id="solver", connected_to=["a", "b"])
        assert added.node_id == "solver"
        assert added.connected_to == ["a", "b"]

    def test_node_removed_event(self):
        """``NodeRemovedEvent`` keeps the migration policy and the archived flag."""
        removed = NodeRemovedEvent(
            node_id="old",
            migration_policy="merge",
            state_archived=True,
        )
        assert removed.migration_policy == "merge"
        assert removed.state_archived is True

    def test_node_replaced_event(self):
        """``NodeReplacedEvent`` keeps both the old and the new node id."""
        replaced = NodeReplacedEvent(
            old_node_id="old",
            new_node_id="new",
            state_migrated=True,
        )
        assert replaced.old_node_id == "old"
        assert replaced.new_node_id == "new"

    def test_edge_added_event(self):
        """``EdgeAddedEvent`` keeps the source id and the weight."""
        added = EdgeAddedEvent(source_id="a", target_id="b", weight=0.5)
        assert added.source_id == "a"
        assert added.weight == 0.5

    def test_edge_removed_event(self):
        """``EdgeRemovedEvent`` keeps both endpoints."""
        removed = EdgeRemovedEvent(source_id="a", target_id="b")
        assert removed.source_id == "a"
        assert removed.target_id == "b"

    def test_edge_updated_event(self):
        """``EdgeUpdatedEvent`` keeps the new weight."""
        fields = {
            "source_id": "a",
            "target_id": "b",
            "old_weight": 1.0,
            "new_weight": 2.0,
            "changes": {"weight": 2.0},
        }
        updated = EdgeUpdatedEvent(**fields)
        assert updated.new_weight == 2.0
# ─────────────────────────── Execution Events ────────────────────────────────
class TestExecutionEvents:
    """Constructor round-trips for run-level and step-level execution events."""

    def test_run_started_event(self):
        """``RunStartedEvent`` keeps the query and the agent count."""
        started = RunStartedEvent(run_id="run1", query="Hello", num_agents=3)
        assert started.query == "Hello"
        assert started.num_agents == 3

    def test_run_completed_event_success(self):
        """A successful ``RunCompletedEvent`` keeps the success flag and token total."""
        fields = {
            "run_id": "run1",
            "success": True,
            "answer": "42",
            "total_tokens": 500,
            "total_steps": 3,
            "duration_ms": 1500.0,
        }
        completed = RunCompletedEvent(**fields)
        assert completed.success is True
        assert completed.total_tokens == 500

    def test_run_completed_event_failure(self):
        """A failed ``RunCompletedEvent`` keeps the failure flag and its error list."""
        failed = RunCompletedEvent(success=False, errors=["timeout"])
        assert failed.success is False
        assert "timeout" in failed.errors

    def test_step_started_event(self):
        """``StepStartedEvent`` keeps the agent id and predecessor list."""
        started = StepStartedEvent(
            agent_id="solver",
            step_index=0,
            predecessors=["planner"],
        )
        assert started.agent_id == "solver"
        assert started.predecessors == ["planner"]

    def test_step_completed_event(self):
        """``StepCompletedEvent`` keeps the number of tokens used."""
        fields = {
            "agent_id": "solver",
            "step_index": 1,
            "success": True,
            "tokens_used": 100,
            "duration_ms": 200.0,
        }
        completed = StepCompletedEvent(**fields)
        assert completed.tokens_used == 100

    def test_step_failed_event(self):
        """``StepFailedEvent`` is expected to carry HIGH priority and keep ``will_retry``."""
        fields = {
            "agent_id": "solver",
            "error_type": "TimeoutError",
            "error_message": "timeout",
            "will_retry": True,
        }
        failed = StepFailedEvent(**fields)
        assert failed.priority == EventPriority.HIGH
        assert failed.will_retry is True

    def test_step_retried_event(self):
        """``StepRetriedEvent`` keeps the attempt number."""
        retried = StepRetriedEvent(
            agent_id="solver",
            attempt=2,
            max_attempts=3,
            delay_ms=500.0,
        )
        assert retried.attempt == 2
# ─────────────────────────── Memory Events ───────────────────────────────────
class TestMemoryEvents:
    """Constructor round-trips for memory write, read and expiry events."""

    def test_memory_write_event(self):
        """``MemoryWriteEvent`` keeps the key and the value size."""
        written = MemoryWriteEvent(agent_id="solver", key="context", value_size=256)
        assert written.key == "context"
        assert written.value_size == 256

    def test_memory_read_event(self):
        """``MemoryReadEvent`` keeps the ``found`` flag."""
        read = MemoryReadEvent(agent_id="solver", key="context", found=True)
        assert read.found is True

    def test_memory_expired_event(self):
        """``MemoryExpiredEvent`` keeps the TTL in seconds."""
        expired = MemoryExpiredEvent(key="old_data", ttl_seconds=60.0)
        assert expired.ttl_seconds == 60.0
# ─────────────────────────── Budget Events ───────────────────────────────────
class TestBudgetEvents:
    """Constructor round-trips and expected priority for budget events."""

    def test_budget_warning_event(self):
        """``BudgetWarningEvent`` keeps the budget type and the ratio."""
        warning = BudgetWarningEvent(
            budget_type="tokens",
            current_value=800.0,
            limit=1000.0,
            ratio=0.8,
        )
        assert warning.budget_type == "tokens"
        assert warning.ratio == 0.8

    def test_budget_exceeded_event(self):
        """``BudgetExceededEvent`` is expected to be CRITICAL and keeps its budget type."""
        exceeded = BudgetExceededEvent(
            budget_type="requests",
            current_value=10.0,
            limit=10.0,
        )
        assert exceeded.priority == EventPriority.CRITICAL
        assert exceeded.budget_type == "requests"
# ─────────────────────────── EventBus ────────────────────────────────────────
class ConcreteHandler(EventHandler):
    """Test double that records every event passed to ``handle``."""

    def __init__(self):
        # Events in the order they were handled; tests inspect this list directly.
        self.received: list[Event] = []

    def handle(self, event: Event) -> None:
        """Append *event* to ``received``."""
        self.received.append(event)
class FailingHandler(EventHandler):
    """Test double whose ``handle`` always raises ``RuntimeError``."""

    # Class-level flag set to True; the tests below expect the error raised by
    # ``handle`` to reach the caller of ``publish`` for this handler.
    raise_error = True

    def handle(self, event: Event) -> None:
        """Raise unconditionally, ignoring *event*."""
        raise RuntimeError("handler error")
class TestEventBus:
    """Subscription, dispatch and error-handling behaviour of a fresh ``EventBus``."""

    def setup_method(self):
        """Create a new bus for every test so subscriptions do not carry over."""
        self.bus = EventBus()

    def test_subscribe_and_publish(self):
        """A handler subscribed to a type receives a published event of that type."""
        handler = ConcreteHandler()
        self.bus.subscribe(EventType.NODE_ADDED, handler)
        event = NodeAddedEvent(node_id="n1")
        self.bus.publish(event)
        assert len(handler.received) == 1

    def test_subscribe_global_handler(self):
        """A handler subscribed with ``None`` receives events of every type."""
        handler = ConcreteHandler()
        self.bus.subscribe(None, handler)
        self.bus.publish(NodeAddedEvent(node_id="n1"))
        self.bus.publish(StepCompletedEvent(agent_id="a"))
        assert len(handler.received) == 2

    def test_publish_to_wrong_event_type(self):
        """A handler is not called for event types it did not subscribe to."""
        handler = ConcreteHandler()
        self.bus.subscribe(EventType.NODE_REMOVED, handler)
        self.bus.publish(NodeAddedEvent(node_id="n1"))
        assert len(handler.received) == 0

    def test_unsubscribe_type_handler(self):
        """After unsubscribing from a type, the handler no longer receives it."""
        handler = ConcreteHandler()
        self.bus.subscribe(EventType.NODE_ADDED, handler)
        self.bus.unsubscribe(EventType.NODE_ADDED, handler)
        self.bus.publish(NodeAddedEvent(node_id="n1"))
        assert len(handler.received) == 0

    def test_unsubscribe_global_handler(self):
        """After unsubscribing a global handler, it no longer receives events."""
        handler = ConcreteHandler()
        self.bus.subscribe(None, handler)
        self.bus.unsubscribe(None, handler)
        self.bus.publish(NodeAddedEvent(node_id="n1"))
        assert len(handler.received) == 0

    def test_disable_and_enable(self):
        """Events published while disabled are dropped; delivery resumes on enable."""
        handler = ConcreteHandler()
        self.bus.subscribe(EventType.NODE_ADDED, handler)
        self.bus.disable()
        self.bus.publish(NodeAddedEvent(node_id="n1"))
        assert len(handler.received) == 0
        self.bus.enable()
        self.bus.publish(NodeAddedEvent(node_id="n1"))
        assert len(handler.received) == 1

    def test_clear_handlers(self):
        """``clear`` removes previously subscribed handlers."""
        handler = ConcreteHandler()
        self.bus.subscribe(EventType.NODE_ADDED, handler)
        self.bus.clear()
        self.bus.publish(NodeAddedEvent(node_id="n1"))
        assert len(handler.received) == 0

    def test_handler_error_with_raise(self):
        """With ``raise_error = True`` the handler's exception reaches the publisher."""
        handler = FailingHandler()
        self.bus.subscribe(EventType.NODE_ADDED, handler)
        with pytest.raises(RuntimeError, match="handler error"):
            self.bus.publish(NodeAddedEvent(node_id="n1"))

    def test_handler_error_without_raise(self):
        """With ``raise_error = False`` the handler's exception is not propagated."""

        class SilentFailingHandler(EventHandler):
            raise_error = False

            def handle(self, event: Event) -> None:
                raise ValueError("silent error")

        handler = SilentFailingHandler()
        self.bus.subscribe(EventType.NODE_ADDED, handler)
        self.bus.publish(NodeAddedEvent(node_id="n1"))  # should not raise

    def test_callable_handler(self):
        """A plain callable can be subscribed and is invoked with the event."""
        received = []
        self.bus.subscribe(EventType.NODE_ADDED, lambda e: received.append(e))
        self.bus.publish(NodeAddedEvent(node_id="n1"))
        assert len(received) == 1

    def test_multiple_handlers_for_same_type(self):
        """Every handler subscribed to a type receives the published event."""
        h1 = ConcreteHandler()
        h2 = ConcreteHandler()
        self.bus.subscribe(EventType.NODE_ADDED, h1)
        self.bus.subscribe(EventType.NODE_ADDED, h2)
        self.bus.publish(NodeAddedEvent(node_id="n1"))
        assert len(h1.received) == 1
        assert len(h2.received) == 1

    def test_can_handle_override(self):
        """A global handler only receives the events its ``can_handle`` accepts.

        The previous version subscribed the handler but never published an
        event or asserted anything, so it could not fail.
        """

        class FilteredHandler(EventHandler):
            def __init__(self):
                self.received: list[Event] = []

            def can_handle(self, event: Event) -> bool:
                return isinstance(event, NodeAddedEvent)

            def handle(self, event: Event) -> None:
                self.received.append(event)

        handler = FilteredHandler()
        self.bus.subscribe(None, handler)  # global

        accepted = NodeAddedEvent(node_id="n1")
        rejected = StepCompletedEvent(agent_id="a")

        # The filter itself, independent of the bus.
        assert handler.can_handle(accepted) is True
        assert handler.can_handle(rejected) is False

        # NOTE(review): relies on publish() consulting can_handle() per event,
        # as the original comment in this test stated — confirm against EventBus.
        self.bus.publish(accepted)
        self.bus.publish(rejected)
        assert len(handler.received) == 1
        assert handler.received[0] is accepted
# ─────────────────────────── CallableHandler ─────────────────────────────────
class TestCallableHandler:
    """Direct checks on ``CallableHandler`` without going through a bus."""

    def test_wraps_function(self):
        """``handle`` forwards the event to the wrapped callable."""
        seen = []
        wrapper = CallableHandler(lambda e: seen.append(e))

        wrapper.handle(NodeAddedEvent(node_id="n1"))

        assert len(seen) == 1

    def test_can_handle_default_true(self):
        """``can_handle`` returns True for an event when nothing overrides it."""
        wrapper = CallableHandler(lambda e: None)
        probe = NodeAddedEvent(node_id="n1")
        assert wrapper.can_handle(probe) is True
# ─────────────────────────── LoggingEventHandler ─────────────────────────────
class TestLoggingEventHandler:
    """Smoke tests: ``LoggingEventHandler.handle`` must accept each event kind without raising."""

    def test_handle_node_added(self):
        """Handling a ``NodeAddedEvent`` does not raise."""
        LoggingEventHandler().handle(NodeAddedEvent(node_id="solver"))

    def test_handle_node_removed(self):
        """Handling a ``NodeRemovedEvent`` does not raise."""
        removed = NodeRemovedEvent(node_id="old", migration_policy="discard")
        LoggingEventHandler().handle(removed)

    def test_handle_edge_added(self):
        """Handling an ``EdgeAddedEvent`` does not raise."""
        added = EdgeAddedEvent(source_id="a", target_id="b", weight=1.0)
        LoggingEventHandler().handle(added)

    def test_handle_step_completed(self):
        """Handling a ``StepCompletedEvent`` does not raise."""
        completed = StepCompletedEvent(agent_id="solver", tokens_used=100, success=True)
        LoggingEventHandler().handle(completed)

    def test_handle_step_failed(self):
        """Handling a ``StepFailedEvent`` does not raise."""
        failed = StepFailedEvent(agent_id="solver", error_message="timeout")
        LoggingEventHandler().handle(failed)

    def test_handle_budget_warning(self):
        """Handling a ``BudgetWarningEvent`` does not raise."""
        warning = BudgetWarningEvent(
            budget_type="tokens",
            current_value=800.0,
            limit=1000.0,
            ratio=0.8,
        )
        LoggingEventHandler().handle(warning)

    def test_handle_run_completed(self):
        """Handling a ``RunCompletedEvent`` does not raise."""
        completed = RunCompletedEvent(success=True, total_steps=3, total_tokens=500)
        LoggingEventHandler().handle(completed)

    def test_handle_critical_priority(self):
        """A CRITICAL-priority event is handled without raising."""
        logging_handler = LoggingEventHandler()
        exceeded = BudgetExceededEvent(
            budget_type="tokens",
            current_value=1000.0,
            limit=1000.0,
        )
        assert exceeded.priority == EventPriority.CRITICAL
        logging_handler.handle(exceeded)  # should not raise

    def test_handle_high_priority(self):
        """A HIGH-priority event is handled without raising."""
        logging_handler = LoggingEventHandler()
        failed = StepFailedEvent(agent_id="s", error_message="err")
        assert failed.priority == EventPriority.HIGH
        logging_handler.handle(failed)

    def test_handle_with_metadata(self):
        """With ``include_metadata=True`` an event carrying metadata is handled without raising."""
        logging_handler = LoggingEventHandler(include_metadata=True)
        tagged = NodeAddedEvent(node_id="n1", metadata={"tag": "test"})
        logging_handler.handle(tagged)

    def test_custom_format_func(self):
        """A user-supplied ``format_func`` is invoked once per handled event."""
        formatted_events = []

        def record_and_format(e):
            formatted_events.append(e)
            return "formatted"

        LoggingEventHandler(format_func=record_and_format).handle(
            NodeAddedEvent(node_id="n1")
        )
        assert len(formatted_events) == 1
# ─────────────────────────── MetricsEventHandler ─────────────────────────────
class TestMetricsEventHandler:
    """Checks the counters reported by ``MetricsEventHandler.get_metrics``."""

    def setup_method(self):
        """Start each test from a freshly constructed handler."""
        self.handler = MetricsEventHandler()

    def test_initial_state(self):
        """A new handler reports zero tokens and zero errors."""
        snapshot = self.handler.get_metrics()
        assert snapshot["total_tokens"] == 0
        assert snapshot["errors_count"] == 0

    def test_step_completed_accumulates_tokens(self):
        """Tokens and durations of completed steps are summed."""
        completed_steps = (
            StepCompletedEvent(agent_id="a", tokens_used=100, duration_ms=200.0),
            StepCompletedEvent(agent_id="b", tokens_used=50, duration_ms=100.0),
        )
        for step in completed_steps:
            self.handler.handle(step)

        snapshot = self.handler.get_metrics()
        assert snapshot["total_tokens"] == 150
        assert snapshot["total_duration_ms"] == 300.0

    def test_step_failed_records_error(self):
        """A failed step increments the error count and records its error type."""
        failure = StepFailedEvent(
            agent_id="a",
            error_type="TimeoutError",
            error_message="timeout",
        )
        self.handler.handle(failure)

        snapshot = self.handler.get_metrics()
        assert snapshot["errors_count"] == 1
        first_error = snapshot["errors"][0]
        assert "TimeoutError" in first_error["error_type"]

    def test_budget_warning(self):
        """A budget warning increments ``budget_warnings``."""
        warning = BudgetWarningEvent(
            budget_type="tokens",
            current_value=800.0,
            limit=1000.0,
        )
        self.handler.handle(warning)
        assert self.handler.get_metrics()["budget_warnings"] == 1

    def test_run_completed_success(self):
        """A successful run counts as completed and not as failed."""
        self.handler.handle(RunCompletedEvent(success=True))

        snapshot = self.handler.get_metrics()
        assert snapshot["runs_completed"] == 1
        assert snapshot["runs_failed"] == 0

    def test_run_completed_failure(self):
        """An unsuccessful run increments ``runs_failed``."""
        self.handler.handle(RunCompletedEvent(success=False))
        assert self.handler.get_metrics()["runs_failed"] == 1

    def test_event_count_tracking(self):
        """Events are counted per type under ``event_counts``."""
        for node_id in ("n1", "n2"):
            self.handler.handle(NodeAddedEvent(node_id=node_id))

        counts = self.handler.get_metrics()["event_counts"]
        assert counts["node_added"] == 2

    def test_avg_step_duration(self):
        """The average step duration is the mean of the handled step durations."""
        for agent, elapsed_ms in (("a", 200.0), ("b", 400.0)):
            self.handler.handle(StepCompletedEvent(agent_id=agent, duration_ms=elapsed_ms))

        snapshot = self.handler.get_metrics()
        assert snapshot["avg_step_duration_ms"] == 300.0

    def test_reset(self):
        """``reset`` returns the token total and the event counts to their empty state."""
        self.handler.handle(StepCompletedEvent(agent_id="a", tokens_used=100))
        self.handler.reset()

        snapshot = self.handler.get_metrics()
        assert snapshot["total_tokens"] == 0
        assert snapshot["event_counts"] == {}
# ─────────────────────────── Global Bus / emit_event / on_event ──────────────
class TestGlobalBus:
    """Tests for the module-level bus accessor and the ``emit_event`` / ``on_event`` helpers.

    These tests subscribe handlers on the bus returned by ``global_event_bus()``,
    which the singleton test below shows is the same object on every call. Cleanup
    therefore runs in ``finally`` so a failing assertion cannot leave a handler
    behind for later tests.
    """

    def test_global_event_bus_singleton(self):
        """Two calls to ``global_event_bus`` return the same object."""
        bus1 = global_event_bus()
        bus2 = global_event_bus()
        assert bus1 is bus2

    def test_GlobalEventBus_alias(self):
        """``GlobalEventBus`` is the same object as ``global_event_bus``."""
        assert GlobalEventBus is global_event_bus

    def test_emit_event(self):
        """``emit_event`` delivers the event to handlers subscribed on the global bus."""
        bus = global_event_bus()
        handler = ConcreteHandler()
        bus.subscribe(EventType.NODE_ADDED, handler)
        try:
            emit_event(NodeAddedEvent(node_id="test_emit"))
            assert any(
                e.node_id == "test_emit"
                for e in handler.received
                if isinstance(e, NodeAddedEvent)
            )
        finally:
            # Always detach, even when the assertion above fails.
            bus.unsubscribe(EventType.NODE_ADDED, handler)

    def test_on_event_decorator(self):
        """A function decorated with ``on_event`` is called for matching global events."""
        received = []

        @on_event(EventType.EDGE_ADDED)
        def my_handler(event):
            received.append(event)

        bus = global_event_bus()
        try:
            bus.publish(EdgeAddedEvent(source_id="x", target_id="y"))
            assert len(received) >= 1
        finally:
            # Best-effort cleanup. Per the original author's note this may not
            # detach anything, because the decorator wraps the function before
            # subscribing it.
            bus.unsubscribe(EventType.EDGE_ADDED, my_handler)
# ─────────────────────────── Missing branch coverage ─────────────────────────
class TestEventBusSubscribeTypeError:
    """Covers the rejection path of ``EventBus.subscribe`` for invalid handlers."""

    def test_subscribe_non_callable_raises_type_error(self):
        """Subscribing something that is neither an EventHandler nor callable raises TypeError."""
        fresh_bus = EventBus()
        expected_message = "Handler must be EventHandler or callable"
        with pytest.raises(TypeError, match=expected_message):
            fresh_bus.subscribe(EventType.NODE_ADDED, "not_a_handler")  # type: ignore
class TestLoggingEventHandlerNonStandardLogger:
    """Covers ``LoggingEventHandler`` when its logger has level methods but no ``.log()``."""

    def test_logger_without_log_method_uses_getattr(self):
        """A logger exposing only info/debug/warning/error still receives a message."""

        class SimpleLogger:
            """Minimal logger stand-in: four level methods, deliberately no ``log``."""

            def __init__(self):
                self.messages = []

            def _record(self, msg):
                self.messages.append(msg)

            # All four level methods share one implementation that records the message.
            info = debug = warning = error = _record

        stand_in = SimpleLogger()
        logging_handler = LoggingEventHandler(log_level="info")
        # Swap in the stand-in so the handler has to work without a .log() method.
        logging_handler._logger = stand_in

        logging_handler.handle(NodeAddedEvent(node_id="test_node"))

        assert len(stand_in.messages) > 0