from __future__ import annotations

from langchain_core.messages import AIMessageChunk
from livekit.plugins.langchain import langgraph as lk_langgraph

from src.agent._langchain_usage_patch import (
    _PATCH_FLAG,
    _ORIGINAL_TO_CHAT_CHUNK,
    _completion_usage_from_message_chunk,
    apply_langchain_usage_patch,
)


def test_completion_usage_is_extracted_from_usage_metadata() -> None:
    """Token counts on ``usage_metadata`` map directly onto the usage object."""
    metadata = {
        "input_tokens": 12,
        "output_tokens": 24,
        "total_tokens": 36,
        "input_token_details": {"cache_read": 7},
    }
    chunk = AIMessageChunk(content="hello", usage_metadata=metadata)

    usage = _completion_usage_from_message_chunk(chunk)

    assert usage is not None
    assert (usage.prompt_tokens, usage.completion_tokens, usage.total_tokens) == (
        12,
        24,
        36,
    )
    assert usage.prompt_cached_tokens == 7


def test_completion_usage_falls_back_to_response_metadata_token_usage() -> None:
    """Without ``usage_metadata``, counts come from ``response_metadata['token_usage']``."""
    token_usage = {
        "prompt_tokens": 10,
        "completion_tokens": 5,
        "total_tokens": 15,
        "prompt_tokens_details": {"cached_tokens": 2},
    }
    chunk = AIMessageChunk(
        content="hello",
        response_metadata={"token_usage": token_usage},
    )

    usage = _completion_usage_from_message_chunk(chunk)

    assert usage is not None
    assert (usage.prompt_tokens, usage.completion_tokens, usage.total_tokens) == (
        10,
        5,
        15,
    )
    assert usage.prompt_cached_tokens == 2


def test_patch_preserves_usage_only_chunks(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """Patched ``_to_chat_chunk`` must surface usage from content-free chunks.

    Also verifies idempotence of ``apply_langchain_usage_patch``: the first
    call reports True (patch applied), the second False (already applied).
    """
    # Snapshot whatever _to_chat_chunk currently is, then force the module
    # into a known "unpatched" state (flag cleared, saved original removed)
    # so the patch application below starts from a clean slate.
    original = lk_langgraph._to_chat_chunk
    monkeypatch.setattr(lk_langgraph, _PATCH_FLAG, False, raising=False)
    monkeypatch.delattr(lk_langgraph, _ORIGINAL_TO_CHAT_CHUNK, raising=False)
    monkeypatch.setattr(lk_langgraph, "_to_chat_chunk", original)

    # First call patches; second is a no-op because the flag is now set.
    assert apply_langchain_usage_patch() is True
    assert apply_langchain_usage_patch() is False

    # A chunk carrying only usage metadata and empty content — presumably
    # dropped by the unpatched converter; the patched one must keep it.
    usage_only_chunk = AIMessageChunk(
        content="",
        usage_metadata={
            "input_tokens": 8,
            "output_tokens": 3,
            "total_tokens": 11,
        },
    )
    chat_chunk = lk_langgraph._to_chat_chunk(usage_only_chunk)
    assert chat_chunk is not None
    assert chat_chunk.delta is None  # no text delta for a usage-only chunk
    assert chat_chunk.usage is not None
    assert chat_chunk.usage.prompt_tokens == 8
    assert chat_chunk.usage.completion_tokens == 3

    # Restore to avoid bleeding monkeypatches between tests.
    # NOTE(review): apply_langchain_usage_patch appears to mutate the module
    # directly (outside monkeypatch's undo stack), so teardown alone would
    # not undo it — hence the explicit restore; confirm against the patch impl.
    monkeypatch.setattr(lk_langgraph, "_to_chat_chunk", original)
    monkeypatch.setattr(lk_langgraph, _PATCH_FLAG, False, raising=False)
    monkeypatch.delattr(lk_langgraph, _ORIGINAL_TO_CHAT_CHUNK, raising=False)