| from injector import Injector | |
| from taskweaver.config.config_mgt import AppConfigSource | |
| from taskweaver.logging import LoggingModule | |
| from taskweaver.memory import RoundCompressor | |
def test_round_compressor():
    """Exercise RoundCompressor on a conversation that grows from 1 to 4 rounds.

    The compressor is resolved through the injector with
    ``rounds_to_compress`` and ``rounds_to_retain`` both set to 2. After each
    new round is appended, ``compress_rounds`` is expected to report the
    summary ``"None"`` and to hand back as many rounds as were passed in.
    """
    from taskweaver.memory import Post, Round

    settings = {
        "llm.api_key": "test_key",
        "round_compressor.rounds_to_compress": 2,
        "round_compressor.rounds_to_retain": 2,
    }
    injector = Injector([LoggingModule])
    injector.binder.bind(AppConfigSource, to=AppConfigSource(config=settings))

    compressor = injector.get(RoundCompressor)
    assert compressor.rounds_to_compress == 2
    assert compressor.rounds_to_retain == 2

    # The same two posts (User -> Planner, Planner -> User) are reused in
    # every round.
    user_post = Post.create(
        message="hello",
        send_from="User",
        send_to="Planner",
        attachment_list=[],
    )
    planner_post = Post.create(
        message="hello",
        send_from="Planner",
        send_to="User",
        attachment_list=[],
    )

    def passthrough(rounds):
        # Identity formatter: hands the rounds back unchanged.
        return rounds

    history = []
    for index in range(1, 5):
        current = Round.create(user_query="hello", id=f"round-{index}")
        current.add_post(user_post)
        current.add_post(planner_post)
        history.append(current)

        # Pass a fresh list on every call so the compressor never sees the
        # list object that this loop keeps appending to.
        summary, retained = compressor.compress_rounds(
            list(history),
            passthrough,
            use_back_up_engine=False,
        )
        assert summary == "None"
        assert len(retained) == index