Spaces:
Runtime error
Runtime error
| import logging | |
| import multiprocessing | |
| import os | |
| import shutil | |
| import socket | |
| import subprocess | |
| import tempfile | |
| import time | |
| from typing import ( | |
| Any, | |
| Dict, | |
| Generator, | |
| Iterator, | |
| List, | |
| Optional, | |
| Sequence, | |
| Tuple, | |
| Callable, | |
| ) | |
| import hypothesis | |
| import pytest | |
| import uvicorn | |
| from requests.exceptions import ConnectionError | |
| from typing_extensions import Protocol | |
| import chromadb.server.fastapi | |
| from chromadb.api import ClientAPI, ServerAPI | |
| from chromadb.config import Settings, System | |
| from chromadb.db.mixins import embeddings_queue | |
| from chromadb.ingest import Producer | |
| from chromadb.types import SeqId, SubmitEmbeddingRecord | |
| from chromadb.api.client import Client as ClientCreator | |
# Make the root logger fully verbose; per the original note, this module is
# only imported when running the test suite.
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG)  # This will only run when testing
logger = logging.getLogger(__name__)

# "dev" Hypothesis profile: a 45000 ms per-example deadline and a few health
# checks disabled (large data, large base examples, function-scoped fixtures).
# The active profile can be overridden with the HYPOTHESIS_PROFILE env var.
hypothesis.settings.register_profile(
    "dev",
    deadline=45000,
    suppress_health_check=[
        hypothesis.HealthCheck.data_too_large,
        hypothesis.HealthCheck.large_base_example,
        hypothesis.HealthCheck.function_scoped_fixture,
    ],
)
hypothesis.settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "dev"))

# True unless CHROMA_CLUSTER_TEST_ONLY is set to exactly "1".
NOT_CLUSTER_ONLY = os.getenv("CHROMA_CLUSTER_TEST_ONLY") != "1"
def skip_if_not_cluster() -> pytest.MarkDecorator:
    """Return a ``skipif`` marker for tests that need a Kubernetes cluster.

    The marker skips the test whenever ``NOT_CLUSTER_ONLY`` is true, i.e.
    unless the CHROMA_CLUSTER_TEST_ONLY environment variable equals "1".
    """
    skip_reason = "Requires Kubernetes to be running with a valid config"
    return pytest.mark.skipif(NOT_CLUSTER_ONLY, reason=skip_reason)
def generate_self_signed_certificate() -> None:
    """Create ``servercert.pem`` / ``serverkey.pem`` in the current directory.

    Runs ``openssl req -x509`` to produce a 365-day, unencrypted (``-nodes``)
    RSA-4096 self-signed certificate for ``CN=localhost``, using the
    ``openssl.cnf`` that sits next to this file.

    Raises:
        FileNotFoundError: if ``openssl.cnf`` is missing.
        subprocess.CalledProcessError: if openssl exits with a non-zero status
            (previously ignored, which left SSL fixtures failing later with
            a confusing error about missing certificate files).
    """
    config_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "openssl.cnf"
    )
    logger.debug("Config path: %s", config_path)
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found at {config_path}")
    subprocess.run(
        [
            "openssl",
            "req",
            "-x509",
            "-newkey",
            "rsa:4096",
            "-keyout",
            "serverkey.pem",
            "-out",
            "servercert.pem",
            "-days",
            "365",
            "-nodes",
            "-subj",
            "/CN=localhost",
            "-config",
            config_path,
        ],
        check=True,
    )
def find_free_port() -> int:
    """Return a TCP port number the OS reported as free a moment ago.

    Binds an IPv4 TCP socket to port 0 so the OS picks the port, then closes
    the socket. Another process could still claim the port before the caller
    binds it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Socket options only influence bind() calls made after they are set,
        # so SO_REUSEADDR must precede bind() to have any effect.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        return int(s.getsockname()[1])
def _run_server(
    port: int,
    is_persistent: bool = False,
    persist_directory: Optional[str] = None,
    chroma_server_auth_provider: Optional[str] = None,
    chroma_server_auth_credentials_provider: Optional[str] = None,
    chroma_server_auth_credentials_file: Optional[str] = None,
    chroma_server_auth_credentials: Optional[str] = None,
    chroma_server_auth_token_transport_header: Optional[str] = None,
    chroma_server_authz_provider: Optional[str] = None,
    chroma_server_authz_config_file: Optional[str] = None,
    chroma_server_authz_config: Optional[Dict[str, Any]] = None,
    chroma_server_ssl_certfile: Optional[str] = None,
    chroma_server_ssl_keyfile: Optional[str] = None,
) -> None:
    """Run a Chroma server locally.

    Builds SegmentAPI/SQLite-backed ``Settings`` (with ``allow_reset=True``
    and the given auth/authz options) and serves
    ``chromadb.server.fastapi.FastAPI`` through ``uvicorn.run`` on
    ``0.0.0.0:<port>``. ``uvicorn.run`` blocks, so this function is used as
    the target of a child process.

    The server is persistent only when ``is_persistent`` is truthy AND
    ``persist_directory`` is non-empty; otherwise it runs with
    ``is_persistent=False`` and ``persist_directory`` is not passed on.
    ``chroma_server_ssl_certfile``/``chroma_server_ssl_keyfile`` are handed
    straight to uvicorn's ``ssl_certfile``/``ssl_keyfile``.
    """
    # Everything except the persistence fields is identical for both modes.
    common_settings: Dict[str, Any] = dict(
        chroma_api_impl="chromadb.api.segment.SegmentAPI",
        chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
        allow_reset=True,
        chroma_server_auth_provider=chroma_server_auth_provider,
        chroma_server_auth_credentials_provider=chroma_server_auth_credentials_provider,
        chroma_server_auth_credentials_file=chroma_server_auth_credentials_file,
        chroma_server_auth_credentials=chroma_server_auth_credentials,
        chroma_server_auth_token_transport_header=chroma_server_auth_token_transport_header,
        chroma_server_authz_provider=chroma_server_authz_provider,
        chroma_server_authz_config_file=chroma_server_authz_config_file,
        chroma_server_authz_config=chroma_server_authz_config,
    )
    if is_persistent and persist_directory:
        settings = Settings(
            is_persistent=is_persistent,
            persist_directory=persist_directory,
            **common_settings,
        )
    else:
        settings = Settings(is_persistent=False, **common_settings)
    server = chromadb.server.fastapi.FastAPI(settings)
    uvicorn.run(
        server.app(),
        host="0.0.0.0",
        port=port,
        log_level="error",
        timeout_keep_alive=30,
        ssl_keyfile=chroma_server_ssl_keyfile,
        ssl_certfile=chroma_server_ssl_certfile,
    )
def _await_server(api: ServerAPI, attempts: int = 0) -> None:
    """Block until ``api.heartbeat()`` succeeds.

    Each ``ConnectionError`` is followed by a 4-second sleep and a retry.
    ``attempts`` is the number of failed attempts already made. A failure
    that occurs while ``attempts > 15`` is logged and re-raised, so a call
    starting from 0 makes at most 17 heartbeat attempts.

    Raises:
        ConnectionError: the last connection error once retries run out.
    """
    max_retries = 15
    while True:
        try:
            api.heartbeat()
            return
        except ConnectionError:
            if attempts > max_retries:
                # Report the real number of attempts (the old message claimed 15).
                logger.error(
                    "Test server failed to start after %d attempts", attempts + 1
                )
                raise
            logger.info("Waiting for server to start...")
            time.sleep(4)
            attempts += 1
def _fastapi_fixture(
    is_persistent: bool = False,
    chroma_server_auth_provider: Optional[str] = None,
    chroma_server_auth_credentials_provider: Optional[str] = None,
    chroma_client_auth_provider: Optional[str] = None,
    chroma_server_auth_credentials_file: Optional[str] = None,
    chroma_client_auth_credentials: Optional[str] = None,
    chroma_server_auth_credentials: Optional[str] = None,
    chroma_client_auth_token_transport_header: Optional[str] = None,
    chroma_server_auth_token_transport_header: Optional[str] = None,
    chroma_server_authz_provider: Optional[str] = None,
    chroma_server_authz_config_file: Optional[str] = None,
    chroma_server_authz_config: Optional[Dict[str, Any]] = None,
    chroma_server_ssl_certfile: Optional[str] = None,
    chroma_server_ssl_keyfile: Optional[str] = None,
) -> Generator[System, None, None]:
    """Fixture generator that launches a server in a separate process, and yields a
    System whose FastAPI client is connected to it.

    The ``chroma_server_*`` arguments are forwarded to ``_run_server`` in a
    "spawn"-context daemon process listening on a free port. The
    ``chroma_client_*`` arguments configure the client ``Settings``. When
    ``chroma_server_ssl_certfile`` is set, the client enables SSL and uses
    that file for verification. With ``is_persistent`` the server stores its
    data in a fresh temporary directory.

    Teardown (stop the client System, kill the server, delete the temp dir)
    runs in ``finally`` clauses. It therefore also happens when the generator
    is closed or garbage-collected instead of exhausted, and when the server
    never becomes reachable.
    """
    port = find_free_port()
    logger.info(f"Running test FastAPI server on port {port}")
    ctx = multiprocessing.get_context("spawn")

    # Only persistent servers get a scratch directory.
    persist_directory: Optional[str] = tempfile.mkdtemp() if is_persistent else None
    # Positional order must match _run_server's signature.
    args: Tuple[
        int,
        bool,
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[Dict[str, Any]],
        Optional[str],
        Optional[str],
    ] = (
        port,
        is_persistent,
        persist_directory,
        chroma_server_auth_provider,
        chroma_server_auth_credentials_provider,
        chroma_server_auth_credentials_file,
        chroma_server_auth_credentials,
        chroma_server_auth_token_transport_header,
        chroma_server_authz_provider,
        chroma_server_authz_config_file,
        chroma_server_authz_config,
        chroma_server_ssl_certfile,
        chroma_server_ssl_keyfile,
    )
    try:
        proc = ctx.Process(target=_run_server, args=args, daemon=True)
        proc.start()
        try:
            settings = Settings(
                chroma_api_impl="chromadb.api.fastapi.FastAPI",
                chroma_server_host="localhost",
                chroma_server_http_port=str(port),
                allow_reset=True,
                chroma_client_auth_provider=chroma_client_auth_provider,
                chroma_client_auth_credentials=chroma_client_auth_credentials,
                chroma_client_auth_token_transport_header=chroma_client_auth_token_transport_header,
                chroma_server_ssl_verify=chroma_server_ssl_certfile,
                chroma_server_ssl_enabled=bool(chroma_server_ssl_certfile),
            )
            system = System(settings)
            api = system.instance(ServerAPI)
            system.start()
            try:
                _await_server(api)
                yield system
            finally:
                system.stop()
        finally:
            proc.kill()
    finally:
        if persist_directory is not None and os.path.exists(persist_directory):
            shutil.rmtree(persist_directory)
def fastapi() -> Generator[System, None, None]:
    """Return a fixture generator for an in-memory FastAPI test server."""
    fixture_gen = _fastapi_fixture(is_persistent=False)
    return fixture_gen
def fastapi_persistent() -> Generator[System, None, None]:
    """Return a fixture generator for a FastAPI test server backed by a temp directory."""
    fixture_gen = _fastapi_fixture(is_persistent=True)
    return fixture_gen
def fastapi_ssl() -> Generator[System, None, None]:
    """Return a fixture generator for an in-memory FastAPI server served over HTTPS.

    A self-signed certificate/key pair is generated in the current directory
    first, at call time; the server and client then use those files.
    """
    generate_self_signed_certificate()
    ssl_files = {
        "chroma_server_ssl_certfile": "./servercert.pem",
        "chroma_server_ssl_keyfile": "./serverkey.pem",
    }
    return _fastapi_fixture(is_persistent=False, **ssl_files)
def basic_http_client() -> Generator[System, None, None]:
    """Yield a System whose FastAPI client targets an already-running server on port 8000.

    No server is launched here; the fixture waits (via ``_await_server``)
    until the external server answers heartbeats. The System is started
    before use, matching ``_fastapi_fixture``, and is stopped in ``finally``
    so teardown also runs if the generator is closed early or the server
    never comes up.
    """
    settings = Settings(
        chroma_api_impl="chromadb.api.fastapi.FastAPI",
        chroma_server_http_port="8000",
        allow_reset=True,
    )
    system = System(settings)
    api = system.instance(ServerAPI)
    system.start()
    try:
        _await_server(api)
        yield system
    finally:
        system.stop()
def fastapi_server_basic_auth() -> Generator[System, None, None]:
    """Yield a System connected to a FastAPI server that uses htpasswd basic auth.

    Writes ``./server.htpasswd`` with a single bcrypt entry for ``admin``.
    The client authenticates with ``admin:admin``. The htpasswd file is
    removed in ``finally``, so it is cleaned up even when this generator is
    closed rather than exhausted. ``yield from`` propagates that close to the
    inner server fixture, which then shuts the server down first.
    """
    server_auth_file = os.path.abspath(os.path.join(".", "server.htpasswd"))
    with open(server_auth_file, "w") as f:
        f.write("admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n")
    try:
        yield from _fastapi_fixture(
            is_persistent=False,
            chroma_server_auth_provider="chromadb.auth.basic.BasicAuthServerProvider",
            chroma_server_auth_credentials_provider="chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider",
            chroma_server_auth_credentials_file="./server.htpasswd",
            chroma_client_auth_provider="chromadb.auth.basic.BasicAuthClientProvider",
            chroma_client_auth_credentials="admin:admin",
        )
    finally:
        os.remove(server_auth_file)
def fastapi_server_basic_auth_param() -> Generator[System, None, None]:
    """Yield a System connected to a FastAPI server that uses htpasswd basic auth.

    Providers are given by fully-qualified class path. Writes
    ``./server.htpasswd`` with a bcrypt entry for ``admin``, and the client
    authenticates with ``admin:admin``. The htpasswd file is removed in
    ``finally``. ``yield from`` forwards an early close to the inner server
    fixture so the server is shut down too.
    """
    server_auth_file = os.path.abspath(os.path.join(".", "server.htpasswd"))
    with open(server_auth_file, "w") as f:
        f.write("admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n")
    try:
        yield from _fastapi_fixture(
            is_persistent=False,
            chroma_server_auth_provider="chromadb.auth.basic.BasicAuthServerProvider",
            chroma_server_auth_credentials_provider="chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider",
            chroma_server_auth_credentials_file="./server.htpasswd",
            chroma_client_auth_provider="chromadb.auth.basic.BasicAuthClientProvider",
            chroma_client_auth_credentials="admin:admin",
        )
    finally:
        os.remove(server_auth_file)
# TODO: we need a generator for auth providers
def fastapi_server_basic_auth_file() -> Generator[System, None, None]:
    """Yield a System connected to a FastAPI server using htpasswd-file basic auth.

    Writes ``./server.htpasswd`` with a bcrypt entry for ``admin``, and the
    client authenticates with ``admin:admin``. The htpasswd file is removed
    in ``finally``. ``yield from`` forwards an early close to the inner
    server fixture so the server is shut down too.
    """
    server_auth_file = os.path.abspath(os.path.join(".", "server.htpasswd"))
    with open(server_auth_file, "w") as f:
        f.write("admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n")
    try:
        yield from _fastapi_fixture(
            is_persistent=False,
            chroma_server_auth_provider="chromadb.auth.basic.BasicAuthServerProvider",
            chroma_server_auth_credentials_provider="chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider",
            chroma_server_auth_credentials_file="./server.htpasswd",
            chroma_client_auth_provider="chromadb.auth.basic.BasicAuthClientProvider",
            chroma_client_auth_credentials="admin:admin",
        )
    finally:
        os.remove(server_auth_file)
def fastapi_server_basic_auth_shorthand() -> Generator[System, None, None]:
    """Yield a System for a basic-auth FastAPI server configured with shorthand names.

    Uses the short provider names ``"basic"`` and ``"htpasswd_file"`` instead
    of class paths. Writes ``./server.htpasswd`` with a bcrypt entry for
    ``admin``, and the client authenticates with ``admin:admin``. The
    htpasswd file is removed in ``finally``. ``yield from`` forwards an early
    close to the inner server fixture.
    """
    server_auth_file = os.path.abspath(os.path.join(".", "server.htpasswd"))
    with open(server_auth_file, "w") as f:
        f.write("admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n")
    try:
        yield from _fastapi_fixture(
            is_persistent=False,
            chroma_server_auth_provider="basic",
            chroma_server_auth_credentials_provider="htpasswd_file",
            chroma_server_auth_credentials_file="./server.htpasswd",
            chroma_client_auth_provider="basic",
            chroma_client_auth_credentials="admin:admin",
        )
    finally:
        os.remove(server_auth_file)
def fastapi_server_basic_auth_invalid_cred() -> Generator[System, None, None]:
    """Yield a System whose client sends credentials the basic-auth server should reject.

    The server's ``./server.htpasswd`` has a bcrypt entry for ``admin``. The
    client sends ``admin:admin1``, which (per this fixture's name) is meant
    not to match. The htpasswd file is removed in ``finally``. ``yield from``
    forwards an early close to the inner server fixture.
    """
    server_auth_file = os.path.abspath(os.path.join(".", "server.htpasswd"))
    with open(server_auth_file, "w") as f:
        f.write("admin:$2y$05$e5sRb6NCcSH3YfbIxe1AGu2h5K7OOd982OXKmd8WyQ3DRQ4MvpnZS\n")
    try:
        yield from _fastapi_fixture(
            is_persistent=False,
            chroma_server_auth_provider="chromadb.auth.basic.BasicAuthServerProvider",
            chroma_server_auth_credentials_provider="chromadb.auth.providers.HtpasswdFileServerAuthCredentialsProvider",
            chroma_server_auth_credentials_file="./server.htpasswd",
            chroma_client_auth_provider="chromadb.auth.basic.BasicAuthClientProvider",
            chroma_client_auth_credentials="admin:admin1",
        )
    finally:
        os.remove(server_auth_file)
def integration() -> Generator[System, None, None]:
    """Fixture generator for returning a client configured via environment
    variables, intended for externally configured integration tests.

    Only ``allow_reset=True`` is set explicitly; every other setting comes
    from ``Settings`` defaults and the environment. The System is stopped in
    ``finally`` so teardown also runs if the generator is closed early.
    """
    settings = Settings(allow_reset=True)
    system = System(settings)
    system.start()
    try:
        yield system
    finally:
        system.stop()
def sqlite() -> Generator[System, None, None]:
    """Fixture generator for segment-based API using in-memory Sqlite.

    The System is stopped in ``finally`` so teardown also runs when the
    generator is closed or garbage-collected instead of being exhausted.
    """
    settings = Settings(
        chroma_api_impl="chromadb.api.segment.SegmentAPI",
        chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
        is_persistent=False,
        allow_reset=True,
    )
    system = System(settings)
    system.start()
    try:
        yield system
    finally:
        system.stop()
def sqlite_persistent() -> Generator[System, None, None]:
    """Fixture generator for segment-based API using persistent Sqlite.

    Data lives in a fresh temporary directory. Stopping the System and
    deleting that directory both happen in ``finally`` clauses, so they also
    run if setup fails or the generator is closed early.
    """
    save_path = tempfile.mkdtemp()
    try:
        settings = Settings(
            chroma_api_impl="chromadb.api.segment.SegmentAPI",
            chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            allow_reset=True,
            is_persistent=True,
            persist_directory=save_path,
        )
        system = System(settings)
        system.start()
        try:
            yield system
        finally:
            system.stop()
    finally:
        if os.path.exists(save_path):
            shutil.rmtree(save_path)
def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]:
    """Select the System fixture generators to parametrize tests over.

    Precedence of the environment switches (presence is what matters):
      1. CHROMA_CLUSTER_TEST_ONLY     -> only ``basic_http_client``
      2. CHROMA_INTEGRATION_TEST_ONLY -> only ``integration``
      3. otherwise the four local fixtures, plus ``integration`` when
         CHROMA_INTEGRATION_TEST is set.
    """
    # NOTE(review): this checks only that CHROMA_CLUSTER_TEST_ONLY is present,
    # whereas NOT_CLUSTER_ONLY requires it to equal "1" — confirm intended.
    if "CHROMA_CLUSTER_TEST_ONLY" in os.environ:
        return [basic_http_client]
    if "CHROMA_INTEGRATION_TEST_ONLY" in os.environ:
        return [integration]
    selected = [fastapi, fastapi_persistent, sqlite, sqlite_persistent]
    if "CHROMA_INTEGRATION_TEST" in os.environ:
        selected.append(integration)
    return selected
def system_fixtures_auth() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the fixture generators for servers that accept the client's basic auth."""
    return [
        fastapi_server_basic_auth_param,
        fastapi_server_basic_auth_file,
        fastapi_server_basic_auth_shorthand,
    ]
def system_fixtures_wrong_auth() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the fixture generators whose client sends invalid basic-auth credentials."""
    return [fastapi_server_basic_auth_invalid_cred]
def system_fixtures_ssl() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the fixture generators for HTTPS-enabled servers."""
    return [fastapi_ssl]
def system_wrong_auth(
    request: pytest.FixtureRequest,
) -> Generator[System, None, None]:
    """Yield the System produced by the generator factory in ``request.param``.

    NOTE(review): ``request.param`` is presumably supplied by a parametrized
    ``@pytest.fixture`` registration (e.g. over ``system_fixtures_wrong_auth()``)
    that is not visible here — confirm. Delegating with ``yield from``
    (instead of ``yield next(...)``) lets the factory's own teardown run when
    this generator is resumed or closed.
    """
    yield from request.param()
def system(request: pytest.FixtureRequest) -> Generator[System, None, None]:
    """Yield the System produced by the generator factory in ``request.param``.

    NOTE(review): ``request.param`` is presumably supplied by a parametrized
    ``@pytest.fixture`` registration (e.g. over ``system_fixtures()``) that is
    not visible here — confirm. Delegating with ``yield from`` (instead of
    ``yield next(...)``) lets the factory's teardown (stopping the System,
    killing servers, deleting temp dirs) run when this generator is resumed
    or closed.
    """
    yield from request.param()
def system_ssl(request: pytest.FixtureRequest) -> Generator[System, None, None]:
    """Yield the System produced by the generator factory in ``request.param``.

    NOTE(review): ``request.param`` is presumably supplied by a parametrized
    ``@pytest.fixture`` registration (e.g. over ``system_fixtures_ssl()``)
    that is not visible here — confirm. ``yield from`` lets the factory's
    teardown run when this generator is resumed or closed.
    """
    yield from request.param()
def system_auth(request: pytest.FixtureRequest) -> Generator[System, None, None]:
    """Yield the System produced by the generator factory in ``request.param``.

    NOTE(review): ``request.param`` is presumably supplied by a parametrized
    ``@pytest.fixture`` registration (e.g. over ``system_fixtures_auth()``)
    that is not visible here — confirm. ``yield from`` lets the factory's
    teardown (including htpasswd-file removal) run when this generator is
    resumed or closed.
    """
    yield from request.param()
def api(system: System) -> Generator[ServerAPI, None, None]:
    """Reset ``system``'s state, then yield its ServerAPI component."""
    system.reset_state()
    yield system.instance(ServerAPI)
def client(system: System) -> Generator[ClientAPI, None, None]:
    """Yield a client built from a freshly reset ``system``.

    After the consumer resumes this generator, the client's system cache is
    cleared.
    """
    system.reset_state()
    built_client = ClientCreator.from_system(system)
    yield built_client
    built_client.clear_system_cache()
def client_ssl(system_ssl: System) -> Generator[ClientAPI, None, None]:
    """Yield a client built from a freshly reset SSL-enabled ``system_ssl``.

    After the consumer resumes this generator, the client's system cache is
    cleared.
    """
    system_ssl.reset_state()
    built_client = ClientCreator.from_system(system_ssl)
    yield built_client
    built_client.clear_system_cache()
def api_wrong_cred(
    system_wrong_auth: System,
) -> Generator[ServerAPI, None, None]:
    """Reset ``system_wrong_auth``, then yield its ServerAPI component."""
    system_wrong_auth.reset_state()
    yield system_wrong_auth.instance(ServerAPI)
def api_with_server_auth(system_auth: System) -> Generator[ServerAPI, None, None]:
    """Reset ``system_auth``, then yield its ServerAPI component."""
    system_auth.reset_state()
    yield system_auth.instance(ServerAPI)
# Producer / Consumer fixtures #


class ProducerFn(Protocol):
    """Structural type of the producer helpers (produce_n_single / produce_n_batch).

    A conforming callable pulls ``n`` records from ``embeddings``, submits them
    to ``topic`` through ``producer``, and returns the submitted records along
    with the sequence ids the producer assigned.
    """

    def __call__(
        self,
        producer: Producer,
        topic: str,
        embeddings: Iterator[SubmitEmbeddingRecord],
        n: int,
    ) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
        """Submit ``n`` records and return ``(records, seq_ids)``."""
def produce_n_single(
    producer: Producer,
    topic: str,
    embeddings: Iterator[SubmitEmbeddingRecord],
    n: int,
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
    """Submit ``n`` records one at a time via ``producer.submit_embedding``.

    Each record is pulled from ``embeddings`` immediately before it is
    submitted. ``StopIteration`` propagates if the iterator runs dry.

    Returns:
        The submitted records and their sequence ids, in submission order.
    """
    records = []
    ids = []
    for _ in range(n):
        record = next(embeddings)
        ids.append(producer.submit_embedding(topic, record))
        records.append(record)
    return records, ids
def produce_n_batch(
    producer: Producer,
    topic: str,
    embeddings: Iterator[SubmitEmbeddingRecord],
    n: int,
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
    """Pull ``n`` records and submit them in one ``producer.submit_embeddings`` call.

    ``StopIteration`` propagates if ``embeddings`` has fewer than ``n`` items.
    With ``n == 0`` an empty batch is still submitted.

    Returns:
        The submitted records and the sequence ids returned for the batch.
    """
    batch = [next(embeddings) for _ in range(n)]
    batch_ids: Sequence[SeqId] = producer.submit_embeddings(topic, batch)
    return batch, batch_ids
def produce_fn_fixtures() -> List[ProducerFn]:
    """Return the producer strategies tests are parametrized over (single, then batch)."""
    strategies: List[ProducerFn] = [produce_n_single, produce_n_batch]
    return strategies
def produce_fns(
    request: pytest.FixtureRequest,
) -> Generator[ProducerFn, None, None]:
    """Yield the producer strategy carried in ``request.param`` unchanged.

    NOTE(review): presumably registered as a fixture parametrized over
    ``produce_fn_fixtures()`` elsewhere — confirm.
    """
    strategy = request.param
    yield strategy
def pytest_configure(config):  # type: ignore
    """pytest hook: set ``embeddings_queue._called_from_test`` to True."""
    setattr(embeddings_queue, "_called_from_test", True)