Spaces:
No application file
No application file
| import os | |
| import pytest | |
| import yaml | |
| from embedchain import App | |
| from embedchain.config import ChromaDbConfig | |
| from embedchain.embedder.base import BaseEmbedder | |
| from embedchain.llm.base import BaseLlm | |
| from embedchain.vectordb.base import BaseVectorDB | |
| from embedchain.vectordb.chroma import ChromaDB | |
@pytest.fixture
def app():
    """Provide a default ``App`` instance to tests that request ``app``.

    A placeholder ``OPENAI_API_KEY`` is written to the process environment
    before the app is constructed. The variable is not restored afterwards,
    so it stays set for anything that runs later in the same process.

    Returns:
        App: an instance created with no arguments.
    """
    # Registered as a fixture because ``test_app`` receives ``app`` as a
    # parameter; without the decorator pytest cannot resolve that argument.
    os.environ["OPENAI_API_KEY"] = "test_api_key"
    return App()
def test_app(app):
    """Check that the app's llm, db and embedding model derive from their base classes."""
    # (attribute on the app, base class it must be an instance of)
    component_bases = (
        ("llm", BaseLlm),
        ("db", BaseVectorDB),
        ("embedding_model", BaseEmbedder),
    )
    for attribute, base_class in component_bases:
        assert isinstance(getattr(app, attribute), base_class)
class TestConfigForAppComponents:
    """Checks that a ChromaDB handed to ``App`` keeps its configured collection name."""

    def test_constructor_config(self):
        """The app's db config reports the collection name given to ``ChromaDbConfig``."""
        expected_name = "my-test-collection"
        chroma_config = ChromaDbConfig(collection_name=expected_name)
        vector_db = ChromaDB(config=chroma_config)
        application = App(db=vector_db)
        assert application.db.config.collection_name == expected_name

    def test_component_config(self):
        """Same check as above: the collection name is readable through ``app.db.config``."""
        expected_name = "my-test-collection"
        chroma_config = ChromaDbConfig(collection_name=expected_name)
        application = App(db=ChromaDB(config=chroma_config))
        assert application.db.config.collection_name == expected_name
class TestAppFromConfig:
    """Builds ``App`` objects from YAML files and compares them with the parsed YAML."""

    def load_config_data(self, yaml_path):
        """Parse the YAML file at ``yaml_path`` and return the loaded data."""
        with open(yaml_path, "r") as stream:
            parsed = yaml.safe_load(stream)
        return parsed

    def test_from_chroma_config(self, mocker):
        """An app built from ``configs/chroma.yaml`` carries the values found in that file."""
        # Patch the chromadb Client as referenced from embedchain's chroma module.
        mocker.patch("embedchain.vectordb.chroma.chromadb.Client")

        yaml_path = "configs/chroma.yaml"
        expected = self.load_config_data(yaml_path)
        app = App.from_config(config_path=yaml_path)

        # The factory must hand back an App.
        assert isinstance(app, App)

        # App-level settings.
        assert app.config.id == expected["app"]["config"]["id"]
        # collect_metrics is not present in this config, so the default value is expected.
        assert app.config.collect_metrics is True

        # LLM settings.
        llm_expected = expected["llm"]["config"]
        for option in ("temperature", "max_tokens", "top_p", "stream"):
            assert getattr(app.llm.config, option) == llm_expected[option]

        # Vector database settings.
        db_expected = expected["vectordb"]["config"]
        for option in ("collection_name", "dir", "allow_reset"):
            assert getattr(app.db.config, option) == db_expected[option]

        # Embedder settings.
        embedder_expected = expected["embedder"]["config"]
        assert app.embedding_model.config.model == embedder_expected["model"]
        # .get(): when the key is absent the attribute is compared against None.
        assert app.embedding_model.config.deployment_name == embedder_expected.get("deployment_name")

    def test_from_opensource_config(self, mocker):
        """An app built from ``configs/opensource.yaml`` carries the values found in that file."""
        # Patch the chromadb Client as referenced from embedchain's chroma module.
        mocker.patch("embedchain.vectordb.chroma.chromadb.Client")

        yaml_path = "configs/opensource.yaml"
        expected = self.load_config_data(yaml_path)
        app = App.from_config(yaml_path)

        # The factory must hand back an App.
        assert isinstance(app, App)

        # App-level settings; here collect_metrics is read from the file.
        assert app.config.id == expected["app"]["config"]["id"]
        assert app.config.collect_metrics == expected["app"]["config"]["collect_metrics"]

        # LLM settings.
        llm_expected = expected["llm"]["config"]
        for option in ("model", "temperature", "max_tokens", "top_p", "stream"):
            assert getattr(app.llm.config, option) == llm_expected[option]

        # Vector database settings.
        db_expected = expected["vectordb"]["config"]
        for option in ("collection_name", "dir", "allow_reset"):
            assert getattr(app.db.config, option) == db_expected[option]

        # Embedder settings; deployment_name must exist in the file (plain key lookup).
        embedder_expected = expected["embedder"]["config"]
        assert app.embedding_model.config.deployment_name == embedder_expected["deployment_name"]