# gMAS / tests/test_callback_handlers.py
# Артём Боярских
# chore: initial commit
# 3193174
"""Tests for callback handlers: FileCallbackHandler, MetricsCallbackHandler, StdoutCallbackHandler"""
import json
import tempfile
from pathlib import Path
from uuid import uuid4
import pytest
from callbacks.handlers.file import FileCallbackHandler
from callbacks.handlers.metrics import MetricsCallbackHandler
from callbacks.handlers.stdout import StdoutCallbackHandler
# ─────────────────────────── FileCallbackHandler ─────────────────────────────
@pytest.fixture
def tmp_log_file(tmp_path):
    """Return the path of the JSONL event log for one test.

    The path lives inside pytest's per-test ``tmp_path``; this fixture only
    builds the path and does not create the file.
    """
    log_path = tmp_path / "events.jsonl"
    return log_path
@pytest.fixture
def file_handler(tmp_log_file):
    """Provide a FileCallbackHandler bound to ``tmp_log_file``.

    The handler is closed during fixture teardown, after the test finishes.
    """
    fch = FileCallbackHandler(tmp_log_file)
    yield fch
    # Teardown: release the handler's file once the test is done with it.
    fch.close()
def _read_events(path: Path) -> list[dict]:
    """Parse a JSON Lines file into a list of dicts.

    Blank (or whitespace-only) lines are skipped. A missing file yields an
    empty list rather than an error.
    """
    if not path.exists():
        return []
    with path.open(encoding="utf-8") as fh:
        stripped_lines = (raw.strip() for raw in fh)
        return [json.loads(text) for text in stripped_lines if text]
class TestFileCallbackHandlerInit:
    """Construction-time behaviour: file creation, parent dirs, append vs overwrite."""

    def test_creates_file(self, tmp_log_file):
        handler = FileCallbackHandler(tmp_log_file)
        handler.close()
        assert tmp_log_file.exists()

    def test_creates_parent_dirs(self, tmp_path):
        deep_path = tmp_path / "logs" / "sub" / "events.jsonl"
        handler = FileCallbackHandler(deep_path)
        handler.close()
        assert deep_path.exists()

    def test_append_mode(self, tmp_log_file):
        run_id = uuid4()
        h = FileCallbackHandler(tmp_log_file, append=True)
        h.on_run_start(run_id=run_id, query="q1")
        h.close()
        h2 = FileCallbackHandler(tmp_log_file, append=True)
        h2.on_run_start(run_id=run_id, query="q2")
        h2.close()
        events = _read_events(tmp_log_file)
        # Both sessions' records must survive, in write order.
        assert [e["query"] for e in events] == ["q1", "q2"]

    def test_overwrite_mode(self, tmp_log_file):
        run_id = uuid4()
        h = FileCallbackHandler(tmp_log_file, append=False)
        h.on_run_start(run_id=run_id, query="q1")
        h.close()
        h2 = FileCallbackHandler(tmp_log_file, append=False)
        h2.on_run_start(run_id=run_id, query="q2")
        h2.close()
        events = _read_events(tmp_log_file)
        # Only the second session's record may remain. A bare length check
        # would also pass if the second handler neither truncated the file
        # nor wrote anything (leaving the stale "q1" record behind).
        assert [e["query"] for e in events] == ["q2"]
class TestFileCallbackHandlerRunLifecycle:
    """``run_start`` / ``run_end`` hooks persisted as JSONL records.

    Each test flushes the handler's file before reading, because
    ``_read_events`` opens the log through a separate file handle.
    """

    def test_on_run_start(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_run_start(
            run_id=run_id,
            query="test query",
            num_agents=3,
            execution_order=["a", "b", "c"],
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        e = events[0]
        assert e["event_type"] == "run_start"
        assert e["query"] == "test query"
        assert e["num_agents"] == 3
        assert e["run_id"] == str(run_id)

    def test_on_run_end_success(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_run_end(
            run_id=run_id,
            output="result",
            success=True,
            total_tokens=500,
            total_time_ms=1234.5,
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        # One hook call must produce exactly one record; pin the count before
        # indexing so duplicate or missing writes fail with a clear assertion.
        assert len(events) == 1
        e = events[0]
        assert e["event_type"] == "run_end"
        assert e["success"] is True
        assert e["total_tokens"] == 500

    def test_on_run_end_failure(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_run_end(
            run_id=run_id,
            output="",
            success=False,
            error=ValueError("something failed"),
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        e = events[0]
        assert e["event_type"] == "run_end"
        assert e["success"] is False
        assert "something failed" in e["error"]
class TestFileCallbackHandlerAgentLifecycle:
    """Agent-level hooks (start/end/error/retry) persisted as JSONL records.

    Every test pins the record count to one before indexing ``events[0]`` so
    that duplicate or spurious writes are caught.
    """

    def test_on_agent_start(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_agent_start(
            run_id=run_id,
            agent_id="solver",
            agent_name="Solver",
            step_index=0,
            prompt="Solve this",
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        e = events[0]
        assert e["event_type"] == "agent_start"
        assert e["agent_id"] == "solver"

    def test_on_agent_end(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_agent_end(
            run_id=run_id,
            agent_id="solver",
            output="answer",
            tokens_used=100,
            duration_ms=250.0,
            is_final=True,
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        e = events[0]
        assert e["event_type"] == "agent_end"
        assert e["tokens_used"] == 100
        assert e["is_final"] is True

    def test_on_agent_error(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_agent_error(
            ValueError("oops"),
            run_id=run_id,
            agent_id="solver",
            will_retry=True,
            attempt=1,
            max_attempts=3,
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        e = events[0]
        assert e["event_type"] == "agent_error"
        assert e["will_retry"] is True
        # NOTE(review): this `or` passes when either field matches; tighten it
        # to check both once the agent_error record schema is confirmed.
        assert "ValueError" in e["error_type"] or "oops" in e["error_message"]

    def test_on_retry(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_retry(
            run_id=run_id,
            agent_id="solver",
            attempt=2,
            max_attempts=3,
            delay_ms=500.0,
            error="Timeout",
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        e = events[0]
        assert e["event_type"] == "retry"
        assert e["attempt"] == 2
class TestFileCallbackHandlerTokenStreaming:
    """Token streaming: first and last tokens are logged, middle tokens are not."""

    def test_on_llm_new_token_first(self, file_handler, tmp_log_file):
        file_handler.on_llm_new_token(
            "Hello",
            run_id=uuid4(),
            agent_id="solver",
            is_first=True,
            token_index=0,
        )
        file_handler._file.flush()
        # Unpacking into a 1-tuple doubles as the "exactly one record" check.
        (record,) = _read_events(tmp_log_file)
        assert record["is_first"] is True

    def test_on_llm_new_token_middle_not_logged(self, file_handler, tmp_log_file):
        file_handler.on_llm_new_token(
            "word",
            run_id=uuid4(),
            agent_id="solver",
            is_first=False,
            is_last=False,
            token_index=5,
        )
        file_handler._file.flush()
        # Middle tokens must not produce any record.
        assert _read_events(tmp_log_file) == []

    def test_on_llm_new_token_last(self, file_handler, tmp_log_file):
        file_handler.on_llm_new_token(
            "end",
            run_id=uuid4(),
            agent_id="solver",
            is_last=True,
            token_index=99,
        )
        file_handler._file.flush()
        (record,) = _read_events(tmp_log_file)
        assert record["is_last"] is True
class TestFileCallbackHandlerPlanningAndOther:
    """Planning, topology, pruning, parallelism, memory, budget and close/flush hooks.

    Each event test flushes the handler's file (``_read_events`` reads through
    a separate handle) and pins the record count to one before indexing.
    """

    def test_on_plan_created(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_plan_created(
            run_id=run_id,
            num_steps=3,
            execution_order=["a", "b", "c"],
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        e = events[0]
        assert e["event_type"] == "plan_created"
        assert e["num_steps"] == 3

    def test_on_topology_changed(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_topology_changed(
            run_id=run_id,
            reason="agent pruned",
            old_remaining=["a", "b"],
            new_remaining=["b"],
            change_count=1,
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        e = events[0]
        assert e["event_type"] == "topology_changed"
        assert e["change_count"] == 1

    def test_on_prune(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_prune(run_id=run_id, agent_id="slow_agent", reason="too slow")
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        assert events[0]["event_type"] == "prune"
        assert events[0]["reason"] == "too slow"

    def test_on_fallback(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_fallback(
            run_id=run_id,
            failed_agent_id="agent_a",
            fallback_agent_id="agent_b",
            reason="failure",
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        e = events[0]
        assert e["event_type"] == "fallback"
        assert e["failed_agent_id"] == "agent_a"

    def test_on_parallel_start(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_parallel_start(
            run_id=run_id,
            agent_ids=["a", "b"],
            group_index=0,
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        assert events[0]["event_type"] == "parallel_start"

    def test_on_parallel_end(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_parallel_end(
            run_id=run_id,
            agent_ids=["a", "b"],
            group_index=0,
            successful=["a", "b"],
            failed=[],
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        assert events[0]["event_type"] == "parallel_end"

    def test_on_memory_read(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_memory_read(
            run_id=run_id,
            agent_id="agent1",
            entries_count=5,
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        assert events[0]["event_type"] == "memory_read"

    def test_on_memory_write(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_memory_write(
            run_id=run_id,
            agent_id="agent1",
            key="context",
            value_size=1024,
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        assert events[0]["event_type"] == "memory_write"

    def test_on_budget_warning(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_budget_warning(
            run_id=run_id,
            budget_type="tokens",
            current=800.0,
            limit=1000.0,
            ratio=0.8,
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        e = events[0]
        assert e["event_type"] == "budget_warning"
        assert e["budget_type"] == "tokens"

    def test_on_budget_exceeded(self, file_handler, tmp_log_file):
        run_id = uuid4()
        file_handler.on_budget_exceeded(
            run_id=run_id,
            budget_type="requests",
            current=10.0,
            limit=10.0,
            action_taken="stop",
        )
        file_handler._file.flush()
        events = _read_events(tmp_log_file)
        assert len(events) == 1
        e = events[0]
        assert e["event_type"] == "budget_exceeded"
        assert e["action_taken"] == "stop"

    def test_close_idempotent(self, file_handler):
        file_handler.close()
        # The fixture's teardown calls close() once more after this test.
        file_handler.close()  # should not raise

    def test_flush_every(self, tmp_log_file):
        """flush_every=N must push buffered events to disk after the N-th event.

        The file is read BEFORE close(): close() flushes unconditionally, so a
        check made only after closing could never detect a broken trigger.
        """
        handler = FileCallbackHandler(tmp_log_file, flush_every=3)
        run_id = uuid4()
        try:
            for _ in range(3):
                handler.on_run_start(run_id=run_id, query="q")
            # Read through a separate handle before close().
            assert len(_read_events(tmp_log_file)) == 3
        finally:
            handler.close()
        assert len(_read_events(tmp_log_file)) == 3
# ─────────────────────────── MetricsCallbackHandler ──────────────────────────
class TestMetricsCallbackHandler:
    """Counters and aggregates reported by ``MetricsCallbackHandler``."""

    def setup_method(self):
        # Fresh handler and run id per test so counters never leak across tests.
        self.handler = MetricsCallbackHandler()
        self.run_id = uuid4()

    def test_initial_state(self):
        snapshot = self.handler.get_metrics()
        for counter in ("total_tokens", "runs_completed", "runs_failed", "retries"):
            assert snapshot[counter] == 0

    def test_on_run_start_records_time(self):
        self.handler.on_run_start(run_id=self.run_id, query="test")
        assert self.handler._run_start_time is not None

    def test_on_run_end_success(self):
        h = self.handler
        h.on_run_start(run_id=self.run_id, query="test")
        h.on_run_end(
            run_id=self.run_id,
            output="result",
            success=True,
            total_tokens=500,
            total_time_ms=1000.0,
        )
        assert (h.runs_completed, h.runs_failed, h.total_tokens) == (1, 0, 500)

    def test_on_run_end_failure(self):
        h = self.handler
        h.on_run_start(run_id=self.run_id, query="test")
        h.on_run_end(run_id=self.run_id, output="", success=False)
        assert (h.runs_failed, h.runs_completed) == (1, 0)

    def test_on_agent_end_accumulates(self):
        # Two calls from the same agent: tokens and call counts must add up.
        for text, tokens, elapsed in (("answer", 150, 200.0), ("another", 100, 100.0)):
            self.handler.on_agent_end(
                run_id=self.run_id,
                agent_id="solver",
                output=text,
                tokens_used=tokens,
                duration_ms=elapsed,
            )
        snapshot = self.handler.get_metrics()
        assert snapshot["total_tokens"] == 250
        assert snapshot["agent_tokens"]["solver"] == 250
        assert snapshot["agent_calls"]["solver"] == 2

    def test_on_agent_error(self):
        self.handler.on_agent_error(
            ValueError("timeout"), run_id=self.run_id, agent_id="solver"
        )
        snapshot = self.handler.get_metrics()
        assert snapshot["errors_count"] == 1
        latest = snapshot["errors"][-1]
        assert "ValueError" in latest["error_type"]

    def test_on_retry(self):
        self.handler.on_retry(
            run_id=self.run_id, agent_id="solver", attempt=1, max_attempts=3
        )
        assert self.handler.get_metrics()["retries"] == 1

    def test_on_budget_warning(self):
        self.handler.on_budget_warning(
            run_id=self.run_id, budget_type="tokens", current=800.0, limit=1000.0
        )
        assert self.handler.get_metrics()["budget_warnings"] == 1

    def test_on_tool_end(self):
        self.handler.on_tool_end(
            run_id=self.run_id,
            tool_name="code_interpreter",
            action="execute",
            success=True,
            duration_ms=100.0,
        )
        tool_calls = self.handler.get_metrics()["tool_calls"]
        # Tool calls with an action are keyed as "<tool>.<action>".
        key = "code_interpreter.execute"
        assert key in tool_calls
        assert tool_calls[key] == 1

    def test_on_tool_end_no_action(self):
        self.handler.on_tool_end(
            run_id=self.run_id, tool_name="file_search", success=True, duration_ms=50.0
        )
        assert "file_search" in self.handler.get_metrics()["tool_calls"]

    def test_on_tool_error(self):
        self.handler.on_tool_error(
            run_id=self.run_id,
            tool_name="web_search",
            error_type="TimeoutError",
            error_message="Connection timeout",
        )
        snapshot = self.handler.get_metrics()
        assert snapshot["tool_errors"] == 1
        assert snapshot["errors_count"] >= 1

    def test_reset(self):
        self.handler.on_agent_end(
            run_id=self.run_id, agent_id="solver", output="x", tokens_used=100
        )
        self.handler.reset()
        snapshot = self.handler.get_metrics()
        assert snapshot["total_tokens"] == 0
        assert snapshot["runs_completed"] == 0

    def test_avg_tokens_per_agent(self):
        for agent, text, tokens in (("a1", "x", 100), ("a2", "y", 200)):
            self.handler.on_agent_end(
                run_id=self.run_id, agent_id=agent, output=text, tokens_used=tokens
            )
        assert self.handler.get_metrics()["avg_tokens_per_agent"] == 150.0

    def test_multiple_runs(self):
        for query in ("q0", "q1", "q2"):
            rid = uuid4()
            self.handler.on_run_start(run_id=rid, query=query)
            self.handler.on_run_end(run_id=rid, output="ok", success=True)
        assert self.handler.runs_completed == 3

    def test_total_duration_accumulates(self):
        for text, elapsed in (("x", 100.0), ("y", 200.0)):
            self.handler.on_agent_end(
                run_id=self.run_id, agent_id="a1", output=text, duration_ms=elapsed
            )
        assert self.handler.total_duration_ms == 300.0
# ─────────────────────────── StdoutCallbackHandler ───────────────────────────
class TestStdoutCallbackHandler:
    """Smoke tests: every StdoutCallbackHandler hook must run without raising.

    Where a hook changes the handler's ``_indent`` nesting depth, the resulting
    depth is asserted too.
    """

    def setup_method(self):
        self.handler = StdoutCallbackHandler()
        self.run_id = uuid4()

    def test_on_run_start(self):
        self.handler.on_run_start(
            run_id=self.run_id,
            query="test",
            num_agents=3,
            execution_order=list("abc"),
        )
        assert self.handler._indent == 1

    def test_on_run_end_success(self):
        h = self.handler
        h._indent = 1
        h.on_run_end(
            run_id=self.run_id,
            output="result",
            success=True,
            total_tokens=100,
            total_time_ms=500.0,
        )
        assert h._indent == 0

    def test_on_run_end_failure(self):
        h = self.handler
        h._indent = 1
        h.on_run_end(
            run_id=self.run_id,
            output="",
            success=False,
            error=RuntimeError("failed"),
        )

    def test_on_agent_start(self):
        depth_before = self.handler._indent
        self.handler.on_agent_start(
            run_id=self.run_id,
            agent_id="solver",
            agent_name="Solver",
            step_index=0,
            prompt="hello",
        )
        # Starting an agent nests output exactly one level deeper.
        assert self.handler._indent - depth_before == 1

    def test_on_agent_start_with_prompt(self):
        verbose = StdoutCallbackHandler(show_prompts=True)
        verbose.on_agent_start(
            run_id=self.run_id,
            agent_id="solver",
            agent_name="Solver",
            step_index=0,
            prompt="A very long prompt that should be shown",
        )

    def test_on_agent_end(self):
        h = self.handler
        h._indent = 1
        h.on_agent_end(
            run_id=self.run_id,
            agent_id="solver",
            output="result",
            tokens_used=50,
            duration_ms=100.0,
            is_final=True,
        )
        assert h._indent == 0

    def test_on_agent_end_with_output(self):
        verbose = StdoutCallbackHandler(show_outputs=True)
        verbose._indent = 1
        verbose.on_agent_end(
            run_id=self.run_id,
            agent_id="solver",
            output="The answer is 42",
            tokens_used=50,
        )

    def test_on_agent_error_no_retry(self):
        self.handler.on_agent_error(
            ValueError("test"), run_id=self.run_id, agent_id="solver"
        )

    def test_on_agent_error_with_retry(self):
        self.handler.on_agent_error(
            ValueError("test"),
            run_id=self.run_id,
            agent_id="solver",
            will_retry=True,
            attempt=1,
            max_attempts=3,
        )

    def test_on_retry(self):
        self.handler.on_retry(
            run_id=self.run_id,
            agent_id="solver",
            attempt=2,
            max_attempts=3,
            delay_ms=500.0,
        )

    def test_on_llm_new_token(self):
        # Emit one "first" token followed by one "last" token.
        for flag in ("is_first", "is_last"):
            self.handler.on_llm_new_token(
                "token", run_id=self.run_id, agent_id="solver", **{flag: True}
            )

    def test_on_plan_created(self):
        self.handler.on_plan_created(
            run_id=self.run_id, num_steps=3, execution_order=list("abc")
        )

    def test_on_topology_changed(self):
        self.handler.on_topology_changed(
            run_id=self.run_id,
            reason="pruned",
            old_remaining=["a", "b"],
            new_remaining=["b"],
            change_count=1,
        )

    def test_on_prune(self):
        self.handler.on_prune(
            run_id=self.run_id, agent_id="slow_agent", reason="too slow"
        )

    def test_on_fallback(self):
        self.handler.on_fallback(
            run_id=self.run_id,
            failed_agent_id="agent_a",
            fallback_agent_id="agent_b",
        )

    def test_on_parallel_start(self):
        self.handler.on_parallel_start(
            run_id=self.run_id, agent_ids=["a", "b"], group_index=0
        )

    def test_on_parallel_end(self):
        h = self.handler
        h._indent = 1
        h.on_parallel_end(
            run_id=self.run_id,
            agent_ids=["a", "b"],
            group_index=0,
            successful=["a", "b"],
        )

    def test_on_budget_warning(self):
        self.handler.on_budget_warning(
            run_id=self.run_id,
            budget_type="tokens",
            current=800.0,
            limit=1000.0,
            ratio=0.8,
        )

    def test_on_budget_exceeded(self):
        self.handler.on_budget_exceeded(
            run_id=self.run_id,
            budget_type="tokens",
            current=1000.0,
            limit=1000.0,
            action_taken="stop",
        )

    def test_on_tool_start(self):
        self.handler.on_tool_start(
            run_id=self.run_id,
            tool_name="code_interpreter",
            action="execute",
            arguments={"code": "print(1)"},
        )

    def test_on_tool_end_success(self):
        h = self.handler
        h._indent = 1
        h.on_tool_end(
            run_id=self.run_id,
            tool_name="code_interpreter",
            action="execute",
            success=True,
            duration_ms=100.0,
            output_size=50,
        )

    def test_on_tool_end_failure(self):
        h = self.handler
        h._indent = 1
        h.on_tool_end(
            run_id=self.run_id,
            tool_name="code_interpreter",
            action="execute",
            success=False,
            duration_ms=100.0,
        )

    def test_on_tool_error(self):
        self.handler.on_tool_error(
            run_id=self.run_id,
            tool_name="web_search",
            action="search",
            error_type="TimeoutError",
            error_message="Connection timed out",
        )

    def test_truncate(self):
        clipper = StdoutCallbackHandler(truncate_length=10)
        # Text within the limit is returned unchanged.
        assert clipper._truncate("hello") == "hello"
        clipped = clipper._truncate("x" * 100)
        assert clipped.endswith("...")
        # The limit applies to the kept text; the ellipsis is added on top.
        assert len(clipped) == 10 + len("...")

    def test_indent_not_go_below_zero(self):
        h = self.handler
        h._indent = 0
        h.on_run_end(run_id=self.run_id, output="x", success=True)
        assert h._indent == 0

    def test_run_without_execution_order(self):
        self.handler.on_run_start(run_id=self.run_id, query="test", num_agents=0)
# ─────────────────────────── FileCallbackHandler - _file is None ─────────────
class TestFileCallbackHandlerFileNone:
    """``_write_event`` must be a silent no-op when the handler has no open file."""

    def test_write_event_returns_early_when_file_is_none(self, tmp_path):
        """_write_event returns early (no exception, no output) when _file is None."""
        tmp_log_file = tmp_path / "test_events.jsonl"
        handler = FileCallbackHandler(tmp_log_file)
        # Keep the real file object: once detached below, handler.close() can
        # no longer reach it, and it would otherwise leak until GC
        # (ResourceWarning).
        real_file = handler._file
        try:
            # Simulate the closed state by detaching the file object.
            handler._file = None
            # Should not raise and should not write anything.
            handler._write_event("test_event", {"key": "value"})
            handler.close()
        finally:
            if real_file is not None:
                real_file.close()
        # The file may exist (created on construction) but must hold no records.
        assert _read_events(tmp_log_file) == []