| """Tests for Pi session token usage aggregation.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import sys |
| from pathlib import Path |
| from unittest.mock import MagicMock |
|
|
# The Pi helper modules live in ../agent-redact/pi relative to this test
# directory and are not installed as a package; prepend that directory to
# sys.path (once) so `pi_session_usage` can be imported directly.
_PI = Path(__file__).resolve().parent.parent / "agent-redact" / "pi"
if str(_PI) not in sys.path:
    sys.path.insert(0, str(_PI))
|
|
| from pi_session_usage import ( |
| TokenUsageTotals, |
| sum_usage_from_jsonl, |
| sum_usage_from_messages, |
| totals_from_stats_payload, |
| totals_from_usage_dict, |
| usage_for_completed_turn, |
| ) |
|
|
|
|
def test_totals_from_usage_dict_sums_cache_into_input_column():
    """Cache reads and cache writes are folded into the input-token count."""
    raw_usage = {"input": 100, "output": 40, "cacheRead": 500, "cacheWrite": 10}

    totals = totals_from_usage_dict(raw_usage)

    # input + cacheRead + cacheWrite; output is reported unchanged.
    assert totals.llm_input_tokens == 100 + 500 + 10
    assert totals.llm_output_tokens == 40
|
|
|
|
def test_sum_usage_from_messages_since_last_user():
    """With ``since_last_user=True`` only the latest exchange is counted."""
    earlier_reply = {"input": 10, "output": 5, "cacheRead": 0, "cacheWrite": 0}
    latest_reply = {"input": 20, "output": 8, "cacheRead": 1, "cacheWrite": 0}
    history = [
        {"role": "user", "content": "first"},
        {"role": "assistant", "usage": earlier_reply},
        {"role": "user", "content": "second"},
        {"role": "assistant", "usage": latest_reply},
    ]

    latest = sum_usage_from_messages(history, since_last_user=True)

    # 20 input + 1 cache read from the latest reply; the earlier exchange
    # (10 input / 5 output) is excluded.
    assert latest.llm_input_tokens == 21
    assert latest.llm_output_tokens == 8
|
|
|
|
def test_sum_usage_from_jsonl(tmp_path):
    """Usage is totalled from a JSONL session log.

    The log holds a session header plus one assistant message whose usage
    omits the cache fields; the expected totals match that message alone.
    """
    records = [
        {"type": "session", "id": "s1"},
        {
            "type": "message",
            "message": {"role": "assistant", "usage": {"input": 3, "output": 2}},
        },
    ]
    session_log = tmp_path / "session.jsonl"
    # One JSON document per line, each newline-terminated.
    session_log.write_text(
        "".join(json.dumps(record) + "\n" for record in records),
        encoding="utf-8",
    )

    result = sum_usage_from_jsonl(session_log)

    assert result.llm_input_tokens == 3
    assert result.llm_output_tokens == 2
|
|
|
|
def test_usage_for_completed_turn_prefers_stats_delta():
    """Per-turn usage is the session-stats delta, even when messages carry usage.

    ``get_messages`` returns a turn whose usage disagrees with the stats delta,
    so the assertions pass only if the stats-based figure wins. An empty
    message list could not tell "prefer stats" apart from "fall back to stats
    when no messages are available".
    """
    client = MagicMock()
    client.running = True
    # First call establishes the baseline; the second is the post-turn snapshot.
    client.get_session_stats.side_effect = [
        {"tokens": {"input": 100, "output": 10, "cacheRead": 0, "cacheWrite": 0}},
        {"tokens": {"input": 250, "output": 55, "cacheRead": 0, "cacheWrite": 0}},
    ]
    # Message-derived usage (7 in / 3 out) deliberately differs from the
    # stats delta (150 in / 45 out).
    client.get_messages.return_value = [
        {"role": "user", "content": "turn"},
        {
            "role": "assistant",
            "usage": {"input": 7, "output": 3, "cacheRead": 0, "cacheWrite": 0},
        },
    ]

    baseline = totals_from_stats_payload(client.get_session_stats())
    usage = usage_for_completed_turn(client, baseline)

    assert usage.llm_input_tokens == 250 - 100
    assert usage.llm_output_tokens == 55 - 10
|
|
|
|
def test_totals_from_stats_payload_empty():
    """A ``None`` stats payload maps to default-constructed totals."""
    expected = TokenUsageTotals()
    assert totals_from_stats_payload(None) == expected
|
|