Spaces:
No application file
No application file
| import os | |
| from unittest.mock import MagicMock, patch | |
| import pytest | |
| from embedchain import App | |
| from embedchain.config import AppConfig, BaseLlmConfig | |
| from embedchain.llm.openai import OpenAILlm | |
@pytest.fixture
def app():
    """Provide an App instance with metrics collection disabled.

    A dummy ``OPENAI_API_KEY`` is written to the environment before the App is
    built, presumably so construction works without real credentials.

    Returns:
        App: a freshly constructed ``App(config=AppConfig(collect_metrics=False))``.
    """
    # NOTE(review): this writes os.environ directly and never restores it, so the
    # dummy key remains set for the rest of the test session. Other tests may
    # implicitly rely on that; consider `monkeypatch.setenv` once none do.
    os.environ["OPENAI_API_KEY"] = "test_api_key"
    return App(config=AppConfig(collect_metrics=False))
def test_query(app):
    """App.query should look up context and return the LLM's answer unchanged.

    Both the database retrieval and the LLM call are stubbed. The test therefore
    only checks the wiring:
    - the query text reaches ``_retrieve_from_database`` as the ``input_query`` keyword;
    - the LLM is consulted exactly once;
    - its answer is what ``query`` returns.
    """
    with patch.object(
        app, "_retrieve_from_database", return_value=["Test context"]
    ) as retrieve_stub, patch.object(
        app.llm, "get_llm_model_answer", return_value="Test answer"
    ) as answer_stub:
        result = app.query(input_query="Test query")

    assert result == "Test answer"

    retrieve_stub.assert_called_once()
    forwarded_query = retrieve_stub.call_args.kwargs.get("input_query")
    assert forwarded_query == "Test query"

    answer_stub.assert_called_once()
@patch.object(OpenAILlm, "_get_answer")
def test_query_config_app_passing(mock_get_answer):
    """An LLM built with a custom BaseLlmConfig keeps that config inside App.

    The decorator replaces ``OpenAILlm._get_answer`` with a mock and passes it in
    as ``mock_get_answer``. This is presumably the method that would contact the
    OpenAI API.

    The test asserts that:
    - the system prompt given to the LLM config is still visible via ``app.llm``;
    - ``get_llm_model_answer`` returns the mocked answer.
    """
    mock_get_answer.return_value = "Test answer"

    # Supply a dummy key for the duration of this test only, instead of relying on
    # the `app` fixture (used by other tests) having leaked it into os.environ.
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test_api_key"}):
        config = AppConfig(collect_metrics=False)
        chat_config = BaseLlmConfig(system_prompt="Test system prompt")
        llm = OpenAILlm(config=chat_config)
        app = App(config=config, llm=llm)
        answer = app.llm.get_llm_model_answer("Test query")

    assert app.llm.config.system_prompt == "Test system prompt"
    assert answer == "Test answer"
    mock_get_answer.assert_called()
def test_query_with_where_in_params(app):
    """A ``where`` filter passed directly to App.query is forwarded to the retriever.

    Retrieval and the LLM call are stubbed. The test checks that
    ``_retrieve_from_database`` receives the query text and the exact ``where``
    dict, and that the LLM's answer is returned.
    """
    with patch.object(app, "_retrieve_from_database") as mock_retrieve:
        mock_retrieve.return_value = ["Test context"]
        with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
            mock_answer.return_value = "Test answer"
            answer = app.query("Test query", where={"attribute": "value"})

    assert answer == "Test answer"

    # Check the call happened before unpacking call_args. If the retriever was
    # never called, call_args is None and the unpacking would fail with an
    # unhelpful TypeError instead of a clear assertion message.
    mock_retrieve.assert_called_once()
    _, kwargs = mock_retrieve.call_args
    assert kwargs.get("input_query") == "Test query"
    assert kwargs.get("where") == {"attribute": "value"}

    mock_answer.assert_called_once()
def test_query_with_where_in_query_config(app):
    """A ``where`` filter set on the BaseLlmConfig reaches the database query.

    Only the LLM and ``app.db.query`` are stubbed, so the App's own retrieval
    logic between them runs for real. The test expects the ``where`` sent to the
    database to carry:
    - the config's ``attribute`` filter;
    - an ``app_id`` key, which the config itself did not supply.
    """
    with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
        mock_answer.return_value = "Test answer"
        with patch.object(app.db, "query") as mock_database_query:
            mock_database_query.return_value = ["Test context"]
            llm_config = BaseLlmConfig(where={"attribute": "value"})
            answer = app.query("Test query", llm_config)

    assert answer == "Test answer"

    # Fail with a clear message if the database was never queried, rather than
    # with a TypeError from unpacking a None call_args.
    mock_database_query.assert_called()
    _, kwargs = mock_database_query.call_args
    assert kwargs.get("input_query") == "Test query"

    where = kwargs.get("where")
    # Guard against `in` on None raising TypeError instead of a readable failure.
    assert where is not None, "db.query was called without a `where` filter"
    assert "app_id" in where
    assert "attribute" in where
    assert where["attribute"] == "value"

    mock_answer.assert_called_once()