from __future__ import annotations
from langchain_core.messages import AIMessageChunk
from livekit.plugins.langchain import langgraph as lk_langgraph
from src.agent._langchain_usage_patch import (
_PATCH_FLAG,
_ORIGINAL_TO_CHAT_CHUNK,
_completion_usage_from_message_chunk,
apply_langchain_usage_patch,
)
def test_completion_usage_is_extracted_from_usage_metadata() -> None:
    """Primary path: usage_metadata on the chunk maps directly onto CompletionUsage."""
    metadata = {
        "input_tokens": 12,
        "output_tokens": 24,
        "total_tokens": 36,
        "input_token_details": {"cache_read": 7},
    }
    result = _completion_usage_from_message_chunk(
        AIMessageChunk(content="hello", usage_metadata=metadata)
    )

    assert result is not None
    # One tuple comparison covers every mapped field, including the
    # cache_read -> prompt_cached_tokens translation.
    assert (
        result.prompt_tokens,
        result.completion_tokens,
        result.total_tokens,
        result.prompt_cached_tokens,
    ) == (12, 24, 36, 7)
def test_completion_usage_falls_back_to_response_metadata_token_usage() -> None:
    """Fallback path: with no usage_metadata, response_metadata["token_usage"] is used."""
    token_usage = {
        "prompt_tokens": 10,
        "completion_tokens": 5,
        "total_tokens": 15,
        "prompt_tokens_details": {"cached_tokens": 2},
    }
    result = _completion_usage_from_message_chunk(
        AIMessageChunk(content="hello", response_metadata={"token_usage": token_usage})
    )

    assert result is not None
    # OpenAI-style token_usage keys map onto the same CompletionUsage fields,
    # including cached_tokens -> prompt_cached_tokens.
    assert (
        result.prompt_tokens,
        result.completion_tokens,
        result.total_tokens,
        result.prompt_cached_tokens,
    ) == (10, 5, 15, 2)
def test_patch_preserves_usage_only_chunks(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """End-to-end: after patching, a usage-only chunk yields a ChatChunk with
    usage populated and no content delta (instead of being dropped).

    Statement order matters here: snapshot the original converter, reset the
    patch bookkeeping attributes, apply the patch, exercise it, then restore.
    """
    # Snapshot the (possibly already patched) converter so we can restore it.
    original = lk_langgraph._to_chat_chunk
    # Reset patch bookkeeping so apply_langchain_usage_patch starts clean,
    # regardless of whether an earlier test (or import) already patched.
    monkeypatch.setattr(lk_langgraph, _PATCH_FLAG, False, raising=False)
    monkeypatch.delattr(lk_langgraph, _ORIGINAL_TO_CHAT_CHUNK, raising=False)
    monkeypatch.setattr(lk_langgraph, "_to_chat_chunk", original)

    # First application patches; the second is a no-op (idempotence guard).
    assert apply_langchain_usage_patch() is True
    assert apply_langchain_usage_patch() is False

    # A chunk carrying only usage data and no text content — the case the
    # unpatched converter would discard.
    usage_only_chunk = AIMessageChunk(
        content="",
        usage_metadata={
            "input_tokens": 8,
            "output_tokens": 3,
            "total_tokens": 11,
        },
    )
    chat_chunk = lk_langgraph._to_chat_chunk(usage_only_chunk)
    assert chat_chunk is not None
    # No text delta, but usage must survive the conversion.
    assert chat_chunk.delta is None
    assert chat_chunk.usage is not None
    assert chat_chunk.usage.prompt_tokens == 8
    assert chat_chunk.usage.completion_tokens == 3

    # Restore to avoid bleeding monkeypatches between tests.
    monkeypatch.setattr(lk_langgraph, "_to_chat_chunk", original)
    monkeypatch.setattr(lk_langgraph, _PATCH_FLAG, False, raising=False)
    monkeypatch.delattr(lk_langgraph, _ORIGINAL_TO_CHAT_CHUNK, raising=False)