| from langchain.chat_models.base import BaseChatModel |
| from langchain_community.chat_models.fake import FakeListChatModel |
|
|
| from gpt_engineer.core.ai import AI |
|
|
|
|
def mock_create_chat_model(self) -> BaseChatModel:
    """Replacement for ``AI._create_chat_model`` installed via ``monkeypatch``.

    Builds a ``FakeListChatModel`` loaded with three canned replies. The tests
    in this module expect the first model call to yield ``"response1"`` and
    the second to yield ``"response2"``.

    Args:
        self: The ``AI`` instance the patched method is bound to (unused).

    Returns:
        A fake chat model serving the canned replies.
    """
    canned_replies = ["response1", "response2", "response3"]
    return FakeListChatModel(responses=canned_replies)
|
|
|
|
def test_start(monkeypatch):
    """``AI.start`` returns a conversation ending with the model's first reply.

    The model factory ``AI._create_chat_model`` is patched to return a fake
    chat model with canned replies, so the expected content is deterministic.
    """
    # arrange
    monkeypatch.setattr(AI, "_create_chat_model", mock_create_chat_model)
    ai = AI("gpt-4")

    # act
    # step_name is passed by keyword, matching the AI.next calls in this file.
    # NOTE(review): AI.start appears to take step_name keyword-only; a
    # positional value would raise TypeError. The keyword form works either way.
    response_messages = ai.start(
        "system prompt", "user prompt", step_name="step name"
    )

    # assert
    assert response_messages[-1].content == "response1"
|
|
|
|
def test_next(monkeypatch):
    """``AI.next`` continues a conversation with the model's next reply.

    After ``start`` consumes the fake model's first canned reply, the
    conversation returned by ``next`` should end with the second one.
    """
    # arrange
    monkeypatch.setattr(AI, "_create_chat_model", mock_create_chat_model)
    ai = AI("gpt-4")
    # step_name is passed by keyword, consistent with the AI.next call below.
    # NOTE(review): AI.start appears to take step_name keyword-only.
    response_messages = ai.start(
        "system prompt", "user prompt", step_name="step name"
    )

    # act
    response_messages = ai.next(
        response_messages, "next user prompt", step_name="step name"
    )

    # assert
    assert response_messages[-1].content == "response2"
|
|
|
|
def test_token_logging(monkeypatch):
    """Each model call increases the cost reported by ``ai.token_usage_log``.

    The model factory is patched to a fake chat model. Cost is sampled after
    ``start`` and again after ``next``, and must grow strictly.
    """
    # arrange
    monkeypatch.setattr(AI, "_create_chat_model", mock_create_chat_model)
    ai = AI("gpt-4")

    # act
    # step_name is passed by keyword, consistent with the AI.next call below.
    # NOTE(review): AI.start appears to take step_name keyword-only.
    response_messages = ai.start(
        "system prompt", "user prompt", step_name="step name"
    )
    usage_cost_after_start = ai.token_usage_log.usage_cost()
    ai.next(response_messages, "next user prompt", step_name="step name")
    usage_cost_after_next = ai.token_usage_log.usage_cost()

    # assert
    # The first exchange must register a positive cost, and the second
    # exchange must add to it.
    assert usage_cost_after_start > 0
    assert usage_cost_after_next > usage_cost_after_start
|
|