Spaces:
No application file
No application file
File size: 2,760 Bytes
a85c9b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import os
import pytest
import embedchain
import embedchain.embedder.gpt4all
import embedchain.embedder.huggingface
import embedchain.embedder.openai
import embedchain.embedder.vertexai
import embedchain.llm.anthropic
import embedchain.llm.openai
import embedchain.vectordb.chroma
import embedchain.vectordb.elasticsearch
import embedchain.vectordb.opensearch
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
class TestFactories:
    """Check that each factory returns an instance of the class registered for a provider name."""

    @pytest.mark.parametrize(
        "provider_name, config_data, expected_class",
        [
            ("openai", {}, embedchain.llm.openai.OpenAILlm),
            ("anthropic", {}, embedchain.llm.anthropic.AnthropicLlm),
        ],
    )
    def test_llm_factory_create(self, monkeypatch, provider_name, config_data, expected_class):
        """LlmFactory.create must return an instance of ``expected_class`` for ``provider_name``."""
        # Use monkeypatch instead of assigning to os.environ directly: the fake
        # keys are removed (or the previous values restored) when the test ends,
        # so they cannot leak into unrelated tests running in the same process.
        monkeypatch.setenv("ANTHROPIC_API_KEY", "test_api_key")
        monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")

        llm_instance = LlmFactory.create(provider_name, config_data)

        assert isinstance(llm_instance, expected_class)

    @pytest.mark.parametrize(
        "provider_name, config_data, expected_class",
        [
            ("gpt4all", {}, embedchain.embedder.gpt4all.GPT4AllEmbedder),
            (
                "huggingface",
                {"model": "sentence-transformers/all-mpnet-base-v2", "vector_dimension": 768},
                embedchain.embedder.huggingface.HuggingFaceEmbedder,
            ),
            ("vertexai", {"model": "textembedding-gecko"}, embedchain.embedder.vertexai.VertexAIEmbedder),
            ("openai", {}, embedchain.embedder.openai.OpenAIEmbedder),
        ],
    )
    def test_embedder_factory_create(self, mocker, monkeypatch, provider_name, config_data, expected_class):
        """EmbedderFactory.create must return an instance of ``expected_class`` for ``provider_name``."""
        # Provide the key here as well so the "openai" case does not depend on a
        # value left behind in the environment by some other test.
        # NOTE(review): whether the OpenAI embedder reads the key at construction
        # time is not visible from this file - confirm against the embedder.
        monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
        # The Vertex AI embedder is replaced by an autospec mock (presumably to
        # avoid needing real cloud credentials - confirm). ``expected_class`` was
        # captured at collection time, i.e. it is still the real class; instances
        # produced by an autospec'd class report that class as ``__class__`` and
        # therefore still satisfy the isinstance check below.
        mocker.patch("embedchain.embedder.vertexai.VertexAIEmbedder", autospec=True)

        embedder_instance = EmbedderFactory.create(provider_name, config_data)

        assert isinstance(embedder_instance, expected_class)

    @pytest.mark.parametrize(
        "provider_name, config_data, expected_class",
        [
            ("chroma", {}, embedchain.vectordb.chroma.ChromaDB),
            (
                "opensearch",
                {"opensearch_url": "http://localhost:9200", "http_auth": ("admin", "admin")},
                embedchain.vectordb.opensearch.OpenSearchDB,
            ),
            ("elasticsearch", {"es_url": "http://localhost:9200"}, embedchain.vectordb.elasticsearch.ElasticsearchDB),
        ],
    )
    def test_vectordb_factory_create(self, mocker, provider_name, config_data, expected_class):
        """VectorDBFactory.create must return an instance of ``expected_class`` for ``provider_name``."""
        # OpenSearchDB is replaced by an autospec mock (presumably so that no
        # OpenSearch server has to be reachable - confirm). The isinstance check
        # still holds for the mocked case because autospec'd instances report
        # the original class as ``__class__``.
        mocker.patch("embedchain.vectordb.opensearch.OpenSearchDB", autospec=True)

        vectordb_instance = VectorDBFactory.create(provider_name, config_data)

        assert isinstance(vectordb_instance, expected_class)
|