document_redaction / test /test_pi_session_usage.py
seanpedrickcase's picture
Sync: fix on agent task download links with root path
b5355b0
Raw
History Blame Contribute Delete
2.81 kB
"""Tests for Pi session token usage aggregation."""
from __future__ import annotations
import json
import sys
from pathlib import Path
from unittest.mock import MagicMock
# Directory holding the Pi helper modules (``<repo>/agent-redact/pi``), resolved
# relative to this test file's parent directory.
_PI = Path(__file__).resolve().parents[1].joinpath("agent-redact", "pi")

# Put it at the front of the import path (once) so ``pi_session_usage`` can be
# imported below without being installed as a package.
if str(_PI) not in sys.path:
    sys.path.insert(0, str(_PI))
from pi_session_usage import ( # noqa: E402
TokenUsageTotals,
sum_usage_from_jsonl,
sum_usage_from_messages,
totals_from_stats_payload,
totals_from_usage_dict,
usage_for_completed_turn,
)
def test_totals_from_usage_dict_sums_cache_into_input_column():
    """Cache read/write counts are folded into the reported input-token total."""
    raw_usage = {
        "input": 100,
        "output": 40,
        "cacheRead": 500,
        "cacheWrite": 10,
    }

    totals = totals_from_usage_dict(raw_usage)

    # 100 fresh input + 500 cache reads + 10 cache writes.
    assert totals.llm_input_tokens == 100 + 500 + 10
    # Output is reported unchanged.
    assert totals.llm_output_tokens == 40
def test_sum_usage_from_messages_since_last_user():
    """With ``since_last_user=True`` only the reply after the latest user message counts."""
    first_reply_usage = {"input": 10, "output": 5, "cacheRead": 0, "cacheWrite": 0}
    second_reply_usage = {"input": 20, "output": 8, "cacheRead": 1, "cacheWrite": 0}
    history = [
        {"role": "user", "content": "first"},
        {"role": "assistant", "usage": first_reply_usage},
        {"role": "user", "content": "second"},
        {"role": "assistant", "usage": second_reply_usage},
    ]

    latest_turn = sum_usage_from_messages(history, since_last_user=True)

    # The first turn (10 in / 5 out) is excluded; 20 input + 1 cache read remain.
    assert latest_turn.llm_input_tokens == 21
    assert latest_turn.llm_output_tokens == 8
def test_sum_usage_from_jsonl(tmp_path):
    """Assistant usage is summed from a JSONL session log.

    The log also contains a ``session`` record, which must contribute nothing.
    The usage entry omits the cache fields.
    """
    records = [
        {"type": "session", "id": "s1"},
        {
            "type": "message",
            "message": {
                "role": "assistant",
                "usage": {"input": 3, "output": 2},
            },
        },
    ]
    session_log = tmp_path / "session.jsonl"
    # One JSON document per line, with a trailing newline after the last one.
    payload = "".join(json.dumps(record) + "\n" for record in records)
    session_log.write_text(payload, encoding="utf-8")

    totals = sum_usage_from_jsonl(session_log)

    assert totals.llm_input_tokens == 3
    assert totals.llm_output_tokens == 2
def test_usage_for_completed_turn_prefers_stats_delta():
    """Turn usage equals the difference between two session-stats snapshots.

    ``get_messages`` returns no messages, so the expected totals can only come
    from the stats delta.
    """
    before = {"tokens": {"input": 100, "output": 10, "cacheRead": 0, "cacheWrite": 0}}
    after = {"tokens": {"input": 250, "output": 55, "cacheRead": 0, "cacheWrite": 0}}

    client = MagicMock()
    client.running = True
    # The first snapshot feeds the baseline below. The second is expected to be
    # consumed inside usage_for_completed_turn.
    client.get_session_stats.side_effect = [before, after]
    client.get_messages.return_value = []

    baseline = totals_from_stats_payload(client.get_session_stats())
    usage = usage_for_completed_turn(client, baseline)

    assert usage.llm_input_tokens == 250 - 100
    assert usage.llm_output_tokens == 55 - 10
def test_totals_from_stats_payload_empty():
    """A ``None`` stats payload yields default-constructed ``TokenUsageTotals``."""
    default_totals = TokenUsageTotals()
    assert totals_from_stats_payload(None) == default_totals