tskwvr / tests /unit_tests /test_planner.py
TRaw's picture
Upload 297 files
3d3d712
import os
from injector import Injector
from taskweaver.config.config_mgt import AppConfigSource
from taskweaver.logging import LoggingModule
from taskweaver.memory.plugin import PluginModule
def test_compose_prompt():
    """Check the chat messages ``Planner.compose_prompt`` builds from two rounds.

    Round one holds a User -> Planner -> CodeInterpreter -> Planner -> User
    exchange whose Planner posts carry ``init_plan``, ``plan`` and
    ``current_plan_step`` attachments; round two holds a single user greeting.
    The composed prompt must start with the Planner system message, followed
    by alternating user/assistant turns whose contents match exactly.
    """
    from taskweaver.memory import Attachment, Memory, Post, Round
    from taskweaver.planner import Planner

    # Plugins are resolved relative to this test file.
    plugin_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/plugins")

    injector = Injector([LoggingModule, PluginModule])
    config_source = AppConfigSource(
        config={
            "llm.api_key": "test_key",
            "plugin.base_path": plugin_dir,
        },
    )
    injector.binder.bind(AppConfigSource, to=config_source)
    planner = injector.create_object(Planner)

    # Plan texts shared by both Planner posts.
    init_plan_text = "1. load the data file\n2. count the rows of the loaded data <narrow depend on 1>\n3. report the result to the user <wide depend on 2>"
    plan_text = "1. instruct CodeInterpreter to load the data file and count the rows of the loaded data\n2. report the result to the user"

    def attach_all(post, items):
        # Attach each (type, content) pair to the post, in the order given.
        for kind, content in items:
            post.add_attachment(Attachment.create(kind, content))

    user_request = Post.create(
        message="count the rows of /home/data.csv",
        send_from="User",
        send_to="Planner",
        attachment_list=[],
    )

    planner_instruction = Post.create(
        message="Please load the data file /home/data.csv and count the rows of the loaded data",
        send_from="Planner",
        send_to="CodeInterpreter",
        attachment_list=[],
    )
    attach_all(
        planner_instruction,
        [
            ("init_plan", init_plan_text),
            ("plan", plan_text),
            (
                "current_plan_step",
                "1. instruct CodeInterpreter to load the data file and count the rows of the loaded data",
            ),
        ],
    )

    interpreter_result = Post.create(
        message="Load the data file /home/data.csv successfully and there are 100 rows in the data file",
        send_from="CodeInterpreter",
        send_to="Planner",
        attachment_list=[],
    )

    planner_answer = Post.create(
        message="The data file /home/data.csv is loaded and there are 100 rows in the data file",
        send_from="Planner",
        send_to="User",
        attachment_list=[],
    )
    attach_all(
        planner_answer,
        [
            ("init_plan", init_plan_text),
            ("plan", plan_text),
            ("current_plan_step", "2. report the result to the user"),
        ],
    )

    first_round = Round.create(user_query="count the rows of ./data.csv", id="round-1")
    for post in (user_request, planner_instruction, interpreter_result, planner_answer):
        first_round.add_post(post)

    second_round = Round.create(user_query="hello", id="round-2")
    greeting = Post.create(
        message="hello",
        send_from="User",
        send_to="Planner",
        attachment_list=[],
    )
    second_round.add_post(greeting)

    memory = Memory(session_id="session-1")
    for conversation_round in (first_round, second_round):
        memory.conversation.add_round(conversation_round)

    messages = planner.compose_prompt(rounds=memory.conversation.rounds)

    system_message = messages[0]
    assert system_message["role"] == "system"
    assert system_message["content"].startswith(
        "You are the Planner who can coordinate CodeInterpreter to finish the user task.",
    )

    # Expected (role, content) for messages[1] through messages[5].
    expected_turns = [
        (
            "user",
            "User: Let's start the new conversation!\ncount the rows of /home/data.csv",
        ),
        (
            "assistant",
            '{"response": [{"type": "init_plan", "content": "1. load the data file\\n2. count the rows of the loaded data <narrow depend on 1>\\n3. report the result to the user <wide depend on 2>"}, {"type": "plan", "content": "1. instruct CodeInterpreter to load the data file and count the rows of the loaded data\\n2. report the result to the user"}, {"type": "current_plan_step", "content": "1. instruct CodeInterpreter to load the data file and count the rows of the loaded data"}, {"type": "send_to", "content": "CodeInterpreter"}, {"type": "message", "content": "Please load the data file /home/data.csv and count the rows of the loaded data"}]}',
        ),
        (
            "user",
            "CodeInterpreter: Load the data file /home/data.csv successfully and there are 100 rows in the data file",
        ),
        (
            "assistant",
            '{"response": [{"type": "init_plan", "content": "1. load the data file\\n2. count the rows of the loaded data <narrow depend on 1>\\n3. report the result to the user <wide depend on 2>"}, {"type": "plan", "content": "1. instruct CodeInterpreter to load the data file and count the rows of the loaded data\\n2. report the result to the user"}, {"type": "current_plan_step", "content": "2. report the result to the user"}, {"type": "send_to", "content": "User"}, {"type": "message", "content": "The data file /home/data.csv is loaded and there are 100 rows in the data file"}]}',
        ),
        (
            "user",
            "User: hello",
        ),
    ]
    for index, (expected_role, expected_content) in enumerate(expected_turns, start=1):
        turn = messages[index]
        assert turn["role"] == expected_role
        assert turn["content"] == expected_content
def test_compose_example_for_prompt():
    """Check the composed prompt when ``planner.use_example`` is enabled.

    The conversation holds only a single "hello" post from the user, yet the
    first non-system message is expected to be the "count the rows" request
    (presumably loaded from the files under ``data/examples/planner_examples``
    -- confirm against those fixtures). The real user query must be the last
    message.
    """
    from taskweaver.memory import Memory, Post, Round
    from taskweaver.planner import Planner

    # Fixture directories live next to this test file.
    tests_dir = os.path.dirname(os.path.abspath(__file__))
    settings = {
        "llm.api_key": "test_key",
        "planner.use_example": True,
        "planner.example_base_path": os.path.join(tests_dir, "data/examples/planner_examples"),
        "plugin.base_path": os.path.join(tests_dir, "data/plugins"),
    }

    injector = Injector([LoggingModule, PluginModule])
    config_source = AppConfigSource(config=settings)
    injector.binder.bind(AppConfigSource, to=config_source)
    planner = injector.create_object(Planner)

    only_round = Round.create(user_query="hello", id="round-1")
    greeting = Post.create(
        message="hello",
        send_from="User",
        send_to="Planner",
        attachment_list=[],
    )
    only_round.add_post(greeting)

    memory = Memory(session_id="session-1")
    memory.conversation.add_round(only_round)

    messages = planner.compose_prompt(rounds=memory.conversation.rounds)

    system_message = messages[0]
    assert system_message["role"] == "system"
    assert system_message["content"].startswith(
        "You are the Planner who can coordinate CodeInterpreter to finish the user task.",
    )

    first_turn = messages[1]
    assert first_turn["role"] == "user"
    assert first_turn["content"] == "User: Let's start the new conversation!\ncount the rows of /home/data.csv"

    final_turn = messages[-1]
    assert final_turn["role"] == "user"
    assert final_turn["content"] == "User: Let's start the new conversation!\nhello"
def test_skip_planning():
    """Check ``Planner.reply`` when ``planner.skip_planning`` is enabled.

    The reply to a single user request must be a post from the Planner to the
    CodeInterpreter whose message is the user's request prefixed with
    "Please process this request: ".
    """
    from taskweaver.memory import Memory, Post, Round
    from taskweaver.planner import Planner

    plugin_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/plugins")
    settings = {
        "llm.api_key": "test_key",
        "plugin.base_path": plugin_dir,
        "planner.skip_planning": True,
    }

    injector = Injector([LoggingModule, PluginModule])
    config_source = AppConfigSource(config=settings)
    injector.binder.bind(AppConfigSource, to=config_source)
    planner = injector.create_object(Planner)

    user_request = Post.create(
        message="count the rows of /home/data.csv",
        send_from="User",
        send_to="Planner",
        attachment_list=[],
    )
    only_round = Round.create(user_query="count the rows of ./data.csv", id="round-1")
    only_round.add_post(user_request)

    memory = Memory(session_id="session-1")
    memory.conversation.add_round(only_round)

    def ignore_event(*args):
        # Event callback that accepts any positional arguments and does nothing.
        return None

    response_post = planner.reply(
        memory,
        prompt_log_path=None,
        event_handler=ignore_event,
        use_back_up_engine=False,
    )

    assert response_post.message == "Please process this request: count the rows of /home/data.csv"
    assert response_post.send_from == "Planner"
    assert response_post.send_to == "CodeInterpreter"