File size: 2,277 Bytes
3d3d712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from injector import Injector

from taskweaver.config.config_mgt import AppConfigSource
from taskweaver.logging import LoggingModule
from taskweaver.memory import RoundCompressor


def test_round_compressor():
    """Check RoundCompressor config binding and its output for 1-4 rounds.

    The compressor is built through the injector with
    ``round_compressor.rounds_to_compress = 2`` and
    ``round_compressor.rounds_to_retain = 2``. A history of rounds is then
    grown one round at a time (``round-1`` .. ``round-4``, each holding the
    same User->Planner and Planner->User posts). After each growth step
    ``compress_rounds`` is called on a copy of the whole history. The test
    expects the summary ``"None"`` and every round in the history retained.
    """
    from taskweaver.memory import Post, Round

    injector = Injector([LoggingModule])
    config_source = AppConfigSource(
        config={
            "llm.api_key": "test_key",
            "round_compressor.rounds_to_compress": 2,
            "round_compressor.rounds_to_retain": 2,
        },
    )
    injector.binder.bind(AppConfigSource, to=config_source)
    compressor = injector.get(RoundCompressor)

    assert compressor.rounds_to_compress == 2
    assert compressor.rounds_to_retain == 2

    # The same two post objects are attached to every round below.
    user_post = Post.create(
        message="hello",
        send_from="User",
        send_to="Planner",
        attachment_list=[],
    )
    planner_post = Post.create(
        message="hello",
        send_from="Planner",
        send_to="User",
        attachment_list=[],
    )

    def passthrough_formatter(rounds):
        # Identity formatter handed to compress_rounds.
        return rounds

    history = []
    for round_number in range(1, 5):
        new_round = Round.create(user_query="hello", id=f"round-{round_number}")
        for post in (user_post, planner_post):
            new_round.add_post(post)
        history.append(new_round)

        # Hand over a fresh list each time rather than the accumulating one.
        summary, retained = compressor.compress_rounds(
            list(history),
            passthrough_formatter,
            use_back_up_engine=False,
        )
        assert summary == "None"
        assert len(retained) == round_number