# File size: 2,277 Bytes
# 3d3d712
from injector import Injector
from taskweaver.config.config_mgt import AppConfigSource
from taskweaver.logging import LoggingModule
from taskweaver.memory import RoundCompressor
def test_round_compressor():
    """Exercise RoundCompressor on conversation histories of one to four rounds.

    The compressor is resolved through the injector with
    ``round_compressor.rounds_to_compress`` and
    ``round_compressor.rounds_to_retain`` both set to 2. The history is then
    grown one round at a time; for every length from 1 to 4 the test expects
    ``compress_rounds`` to return the summary string ``"None"`` together with
    as many retained rounds as were passed in.
    """
    from taskweaver.memory import Post, Round

    container = Injector(
        [LoggingModule],
    )
    settings = AppConfigSource(
        config={
            "llm.api_key": "test_key",
            "round_compressor.rounds_to_compress": 2,
            "round_compressor.rounds_to_retain": 2,
        },
    )
    container.binder.bind(AppConfigSource, to=settings)
    compressor = container.get(RoundCompressor)

    # The configured values must be visible on the resolved instance.
    assert compressor.rounds_to_compress == 2
    assert compressor.rounds_to_retain == 2

    # One User -> Planner post and one Planner -> User reply; the same two
    # Post objects are attached to every round below.
    exchange = [
        Post.create(
            message="hello",
            send_from=sender,
            send_to=receiver,
            attachment_list=[],
        )
        for sender, receiver in (("User", "Planner"), ("Planner", "User"))
    ]

    def identity(rounds):
        # Formatter stand-in: hands its argument back unchanged.
        return rounds

    history = []
    for round_id in ("round-1", "round-2", "round-3", "round-4"):
        current = Round.create(user_query="hello", id=round_id)
        for post in exchange:
            current.add_post(post)
        history.append(current)

        # Pass a fresh list on every call so the compressor never receives
        # the list object this loop keeps appending to.
        summary, retained = compressor.compress_rounds(
            list(history),
            identity,
            use_back_up_engine=False,
        )
        assert summary == "None"
        assert len(retained) == len(history)
|