import os
import shutil
import tempfile
import pytest
from typing import Generator, List, Callable, Dict, Union
from chromadb.db.impl.grpc.client import GrpcSysDB
from chromadb.db.impl.grpc.server import GrpcMockSysDB
from chromadb.test.conftest import find_free_port
from chromadb.types import Collection, Segment, SegmentScope
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.config import (
    DEFAULT_DATABASE,
    DEFAULT_TENANT,
    System,
    Settings,
)
from chromadb.db.system import SysDB
from chromadb.db.base import NotFoundError, UniqueConstraintError
from pytest import FixtureRequest
import uuid
from chromadb.api.configuration import CollectionConfigurationInternal

TENANT = "default"
NAMESPACE = "default"

# These are the sample collections that are used in the tests below. Tests can override
# the fields as needed (the autouse `restore_sample_collections` fixture puts the
# database/tenant fields back after every test).
sample_collections: List[Collection] = [
    Collection(
        id=uuid.UUID(int=1),
        name="test_collection_1",
        configuration=CollectionConfigurationInternal(),
        metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3},
        dimension=128,
        database=DEFAULT_DATABASE,
        tenant=DEFAULT_TENANT,
        version=0,
    ),
    Collection(
        id=uuid.UUID(int=2),
        name="test_collection_2",
        configuration=CollectionConfigurationInternal(),
        metadata={"test_str": "str2", "test_int": 2, "test_float": 2.3},
        dimension=None,
        database=DEFAULT_DATABASE,
        tenant=DEFAULT_TENANT,
        version=0,
    ),
    Collection(
        id=uuid.UUID(int=3),
        name="test_collection_3",
        configuration=CollectionConfigurationInternal(),
        metadata={"test_str": "str3", "test_int": 3, "test_float": 3.3},
        dimension=None,
        database=DEFAULT_DATABASE,
        tenant=DEFAULT_TENANT,
        version=0,
    ),
]


def sqlite() -> Generator[SysDB, None, None]:
    """Fixture generator for a sqlite DB created without persistence settings.

    Yields the started DB; resuming the generator after the yield stops it.
    """
    db = SqliteDB(
        System(
            Settings(
                allow_reset=True,
            )
        )
    )
    db.start()
    yield db
    db.stop()


def sqlite_persistent() -> Generator[SysDB, None, None]:
    """Fixture generator for a persistent sqlite DB stored in a fresh temp directory.

    Yields the started DB; resuming the generator after the yield stops it. The
    temporary directory is removed in a ``finally`` so it is also cleaned up if
    construction or start-up raises.
    """
    save_path = tempfile.mkdtemp()
    try:
        db = SqliteDB(
            System(
                Settings(
                    allow_reset=True,
                    is_persistent=True,
                    persist_directory=save_path,
                )
            )
        )
        db.start()
        yield db
        db.stop()
    finally:
        if os.path.exists(save_path):
            shutil.rmtree(save_path)


def grpc_with_mock_server() -> Generator[SysDB, None, None]:
    """Fixture generator for a gRPC SysDB client talking to a mock gRPC sysdb server.

    Both the mock server (GrpcMockSysDB) and the client (GrpcSysDB) are instances of
    the same System, configured to use a free port. Resuming the generator after the
    yield stops the System.
    """
    port = find_free_port()

    system = System(
        Settings(
            allow_reset=True,
            chroma_server_grpc_port=port,
        )
    )
    system.instance(GrpcMockSysDB)
    client = system.instance(GrpcSysDB)
    system.start()
    client.reset_and_wait_for_ready()
    yield client
    system.stop()


def grpc_with_real_server() -> Generator[SysDB, None, None]:
    """Fixture generator for a gRPC SysDB client talking to a sysdb server on port 50051.

    This fixture does not start a server; one is expected to already be listening
    (used when CHROMA_CLUSTER_TEST_ONLY is set). Resuming the generator after the
    yield stops the client's System.
    """
    system = System(
        Settings(
            allow_reset=True,
            chroma_server_grpc_port=50051,
        )
    )
    client = system.instance(GrpcSysDB)
    system.start()
    client.reset_and_wait_for_ready()
    yield client
    system.stop()


def db_fixtures() -> List[Callable[[], Generator[SysDB, None, None]]]:
    """Return the SysDB fixture generators to parametrize the tests with.

    Only the real gRPC server is used when CHROMA_CLUSTER_TEST_ONLY is set;
    otherwise both sqlite variants and the mock gRPC server are used.
    """
    if "CHROMA_CLUSTER_TEST_ONLY" in os.environ:
        return [grpc_with_real_server]
    else:
        return [sqlite, sqlite_persistent, grpc_with_mock_server]


@pytest.fixture(scope="module", params=db_fixtures())
def sysdb(request: FixtureRequest) -> Generator[SysDB, None, None]:
    """Module-scoped SysDB built from one of the `db_fixtures()` generators."""
    fixture_gen = request.param()
    yield next(fixture_gen)
    # Resume the generator past its yield so its teardown (stopping the DB/System and
    # removing temp directories) actually runs; the default avoids StopIteration.
    next(fixture_gen, None)


@pytest.fixture(autouse=True)
def restore_sample_collections() -> Generator[None, None, None]:
    """Restore the database/tenant fields of `sample_collections` after each test.

    Several tests reassign these fields on the shared module-level list; restoring
    them keeps each test independent of which tests ran before it.
    """
    saved = [(c["database"], c["tenant"]) for c in sample_collections]
    yield
    for collection, (database, tenant) in zip(sample_collections, saved):
        collection["database"] = database
        collection["tenant"] = tenant


# region Collection tests
def test_create_get_delete_collections(sysdb: SysDB) -> None:
    sysdb.reset_state()

    for collection in sample_collections:
        sysdb.create_collection(
            id=collection.id,
            name=collection.name,
            configuration=collection.get_configuration(),
            metadata=collection["metadata"],
            dimension=collection["dimension"],
        )
        collection["database"] = DEFAULT_DATABASE
        collection["tenant"] = DEFAULT_TENANT

    results = sorted(sysdb.get_collections(), key=lambda c: c.name)
    assert results == sample_collections

    # Duplicate create fails
    with pytest.raises(UniqueConstraintError):
        sysdb.create_collection(
            name=sample_collections[0].name,
            id=sample_collections[0].id,
            configuration=sample_collections[0].get_configuration(),
        )

    # Find by name
    for collection in sample_collections:
        result = sysdb.get_collections(name=collection["name"])
        assert result == [collection]

    # Find by id
    for collection in sample_collections:
        result = sysdb.get_collections(id=collection["id"])
        assert result == [collection]

    # Delete
    c1 = sample_collections[0]
    sysdb.delete_collection(c1.id)

    results = sysdb.get_collections()
    assert c1 not in results
    assert len(results) == len(sample_collections) - 1
    assert sorted(results, key=lambda c: c.name) == sample_collections[1:]

    by_id_result = sysdb.get_collections(id=c1["id"])
    assert by_id_result == []

    # Duplicate delete throws an exception
    with pytest.raises(NotFoundError):
        sysdb.delete_collection(c1.id)


def test_update_collections(sysdb: SysDB) -> None:
    coll = Collection(
        name=sample_collections[0].name,
        id=sample_collections[0].id,
        configuration=sample_collections[0].get_configuration(),
        metadata=sample_collections[0]["metadata"],
        dimension=sample_collections[0]["dimension"],
        database=DEFAULT_DATABASE,
        tenant=DEFAULT_TENANT,
        version=0,
    )

    sysdb.reset_state()

    sysdb.create_collection(
        id=coll.id,
        name=coll.name,
        configuration=coll.get_configuration(),
        metadata=coll["metadata"],
        dimension=coll["dimension"],
    )

    # Update name
    coll["name"] = "new_name"
    sysdb.update_collection(coll.id, name=coll.name)
    result = sysdb.get_collections(name=coll.name)
    assert result == [coll]

    # Update dimension
    coll["dimension"] = 128
    sysdb.update_collection(coll.id, dimension=coll.dimension)
    result = sysdb.get_collections(id=coll["id"])
    assert result == [coll]

    # Reset the metadata
    coll["metadata"] = {"test_str2": "str2"}
    sysdb.update_collection(coll.id, metadata=coll["metadata"])
    result = sysdb.get_collections(id=coll["id"])
    assert result == [coll]

    # Delete all metadata keys
    coll["metadata"] = None
    sysdb.update_collection(coll.id, metadata=None)
    result = sysdb.get_collections(id=coll["id"])
    assert result == [coll]


def test_get_or_create_collection(sysdb: SysDB) -> None:
    sysdb.reset_state()

    # get_or_create = True returns existing collection
    collection = sample_collections[0]

    sysdb.create_collection(
        id=collection.id,
        name=collection.name,
        configuration=collection.get_configuration(),
        metadata=collection["metadata"],
        dimension=collection["dimension"],
    )

    result, _ = sysdb.create_collection(
        name=collection.name,
        id=uuid.uuid4(),
        configuration=CollectionConfigurationInternal(),
        get_or_create=True,
        metadata=collection["metadata"],
    )
    assert result == collection

    # Only one collection with the same name exists
    get_result = sysdb.get_collections(name=collection["name"])
    assert get_result == [collection]

    # get_or_create = True creates new collection
    result, _ = sysdb.create_collection(
        name=sample_collections[1].name,
        id=sample_collections[1].id,
        configuration=sample_collections[1].get_configuration(),
        get_or_create=True,
        metadata=sample_collections[1]["metadata"],
    )
    assert result == sample_collections[1]

    # get_or_create = False creates new collection
    result, _ = sysdb.create_collection(
        name=sample_collections[2].name,
        id=sample_collections[2].id,
        configuration=sample_collections[2].get_configuration(),
        get_or_create=False,
        metadata=sample_collections[2]["metadata"],
    )
    assert result == sample_collections[2]

    # get_or_create = False fails if collection already exists
    with pytest.raises(UniqueConstraintError):
        sysdb.create_collection(
            name=sample_collections[2].name,
            id=sample_collections[2].id,
            configuration=sample_collections[2].get_configuration(),
            get_or_create=False,
            metadata=collection["metadata"],
        )

    # get_or_create = True overwrites metadata
    overlayed_metadata: Dict[str, Union[str, int, float]] = {
        "test_new_str": "new_str",
        "test_int": 1,
    }
    result, _ = sysdb.create_collection(
        name=sample_collections[2].name,
        id=sample_collections[2].id,
        configuration=sample_collections[2].get_configuration(),
        get_or_create=True,
        metadata=overlayed_metadata,
    )

    assert result["metadata"] == overlayed_metadata

    # get_or_create = True with None metadata does not overwrite metadata
    result, _ = sysdb.create_collection(
        name=sample_collections[2].name,
        id=sample_collections[2].id,
        configuration=sample_collections[2].get_configuration(),
        get_or_create=True,
        metadata=None,
    )
    assert result["metadata"] == overlayed_metadata


def test_create_get_delete_database_and_collection(sysdb: SysDB) -> None:
    sysdb.reset_state()

    # Create a new database
    sysdb.create_database(id=uuid.uuid4(), name="new_database")

    # Create a new collection in the new database
    sysdb.create_collection(
        id=sample_collections[0].id,
        name=sample_collections[0].name,
        configuration=sample_collections[0].get_configuration(),
        metadata=sample_collections[0]["metadata"],
        dimension=sample_collections[0]["dimension"],
        database="new_database",
    )

    # Create a new collection with the same id but different name in the new database
    # and expect an error
    with pytest.raises(UniqueConstraintError):
        sysdb.create_collection(
            id=sample_collections[0].id,
            name="new_name",
            configuration=sample_collections[0].get_configuration(),
            metadata=sample_collections[0]["metadata"],
            dimension=sample_collections[0]["dimension"],
            database="new_database",
            get_or_create=False,
        )

    # Create a new collection in the default database
    sysdb.create_collection(
        id=sample_collections[1].id,
        name=sample_collections[1].name,
        configuration=sample_collections[1].get_configuration(),
        metadata=sample_collections[1]["metadata"],
        dimension=sample_collections[1]["dimension"],
    )

    # Check that the new database and collections exist
    result = sysdb.get_collections(
        name=sample_collections[0]["name"], database="new_database"
    )
    assert len(result) == 1
    sample_collections[0]["database"] = "new_database"
    assert result[0] == sample_collections[0]

    # Check that the collection in the default database exists
    result = sysdb.get_collections(name=sample_collections[1]["name"])
    assert len(result) == 1
    assert result[0] == sample_collections[1]

    # Get for a database that doesn't exist with a name that exists in the new database
    # and expect no results
    assert (
        len(
            sysdb.get_collections(
                name=sample_collections[0]["name"], database="fake_db"
            )
        )
        == 0
    )

    # Delete the collection in the new database
    sysdb.delete_collection(id=sample_collections[0].id, database="new_database")

    # Check that the collection in the new database was deleted
    result = sysdb.get_collections(database="new_database")
    assert len(result) == 0

    # Check that the collection in the default database still exists
    result = sysdb.get_collections(name=sample_collections[1].name)
    assert len(result) == 1
    assert result[0] == sample_collections[1]

    # Delete the already-deleted collection via the default database and expect an error
    with pytest.raises(NotFoundError):
        sysdb.delete_collection(id=sample_collections[0].id)

    # Delete the collection that only exists in the default database, scoped to the
    # new database, and expect an error
    with pytest.raises(NotFoundError):
        sysdb.delete_collection(id=sample_collections[1].id, database="new_database")


def test_create_update_with_database(sysdb: SysDB) -> None:
    sysdb.reset_state()

    # Create a new database
    sysdb.create_database(id=uuid.uuid4(), name="new_database")

    # Create a new collection in the new database
    sysdb.create_collection(
        id=sample_collections[0].id,
        name=sample_collections[0].name,
        configuration=sample_collections[0].get_configuration(),
        metadata=sample_collections[0]["metadata"],
        dimension=sample_collections[0]["dimension"],
        database="new_database",
    )

    # Create a new collection in the default database
    sysdb.create_collection(
        id=sample_collections[1].id,
        name=sample_collections[1].name,
        configuration=sample_collections[1].get_configuration(),
        metadata=sample_collections[1]["metadata"],
        dimension=sample_collections[1]["dimension"],
    )

    # Update the collection in the default database
    sysdb.update_collection(
        id=sample_collections[1].id,
        name="new_name_1",
    )

    # Check that the collection in the default database was updated
    result = sysdb.get_collections(id=sample_collections[1]["id"])
    assert len(result) == 1
    assert result[0]["name"] == "new_name_1"

    # Update the collection in the new database
    sysdb.update_collection(
        id=sample_collections[0].id,
        name="new_name_0",
    )

    # Check that the collection in the new database was updated
    result = sysdb.get_collections(
        id=sample_collections[0]["id"], database="new_database"
    )
    assert len(result) == 1
    assert result[0]["name"] == "new_name_0"

    # Try to create the collection in the default database in the new database and
    # expect an error
    with pytest.raises(UniqueConstraintError):
        sysdb.create_collection(
            id=sample_collections[1].id,
            name=sample_collections[1].name,
            configuration=sample_collections[1].get_configuration(),
            metadata=sample_collections[1]["metadata"],
            dimension=sample_collections[1]["dimension"],
            database="new_database",
        )


def test_get_multiple_with_database(sysdb: SysDB) -> None:
    sysdb.reset_state()

    # Create a new database
    sysdb.create_database(id=uuid.uuid4(), name="new_database")

    # Create sample collections in the new database
    for collection in sample_collections:
        sysdb.create_collection(
            id=collection.id,
            name=collection.name,
            configuration=collection.get_configuration(),
            metadata=collection["metadata"],
            dimension=collection["dimension"],
            database="new_database",
        )
        collection["database"] = "new_database"

    # Get all collections in the new database
    result = sysdb.get_collections(database="new_database")
    assert len(result) == len(sample_collections)
    assert sorted(result, key=lambda c: c.name) == sample_collections

    # Get all collections in the default database
    result = sysdb.get_collections()
    assert len(result) == 0


def test_create_database_with_tenants(sysdb: SysDB) -> None:
    sysdb.reset_state()

    # Create a new tenant
    sysdb.create_tenant(name="tenant1")

    # Create tenant that already exists and expect an error
    with pytest.raises(UniqueConstraintError):
        sysdb.create_tenant(name="tenant1")

    with pytest.raises(UniqueConstraintError):
        sysdb.create_tenant(name=DEFAULT_TENANT)

    # Create a new database within this tenant and also in the default tenant
    sysdb.create_database(id=uuid.uuid4(), name="new_database", tenant="tenant1")
    sysdb.create_database(id=uuid.uuid4(), name="new_database")

    # Create a new collection in the new tenant
    sysdb.create_collection(
        id=sample_collections[0].id,
        name=sample_collections[0].name,
        configuration=sample_collections[0].get_configuration(),
        metadata=sample_collections[0]["metadata"],
        dimension=sample_collections[0]["dimension"],
        database="new_database",
        tenant="tenant1",
    )
    sample_collections[0]["tenant"] = "tenant1"
    sample_collections[0]["database"] = "new_database"

    # Create a new collection in the default tenant
    sysdb.create_collection(
        id=sample_collections[1].id,
        name=sample_collections[1].name,
        configuration=sample_collections[1].get_configuration(),
        metadata=sample_collections[1]["metadata"],
        dimension=sample_collections[1]["dimension"],
        database="new_database",
    )
    sample_collections[1]["database"] = "new_database"

    # Check that both tenants have the correct collections
    result = sysdb.get_collections(database="new_database", tenant="tenant1")
    assert len(result) == 1
    assert result[0] == sample_collections[0]

    result = sysdb.get_collections(database="new_database")
    assert len(result) == 1
    assert result[0] == sample_collections[1]

    # Creating a collection id that already exists in a tenant that does not have it
    # should error
    with pytest.raises(UniqueConstraintError):
        sysdb.create_collection(
            id=sample_collections[0].id,
            name=sample_collections[0].name,
            configuration=sample_collections[0].get_configuration(),
            metadata=sample_collections[0]["metadata"],
            dimension=sample_collections[0]["dimension"],
            database="new_database",
        )

    with pytest.raises(UniqueConstraintError):
        sysdb.create_collection(
            id=sample_collections[1].id,
            name=sample_collections[1].name,
            configuration=sample_collections[1].get_configuration(),
            metadata=sample_collections[1]["metadata"],
            dimension=sample_collections[1]["dimension"],
            database="new_database",
            tenant="tenant1",
        )

    # A new tenant DOES NOT have a default database. This does not error, instead 0
    # results are returned
    result = sysdb.get_collections(database=DEFAULT_DATABASE, tenant="tenant1")
    assert len(result) == 0


def test_get_database_with_tenants(sysdb: SysDB) -> None:
    sysdb.reset_state()

    # Create a new tenant
    sysdb.create_tenant(name="tenant1")

    # Get the tenant and check that it exists
    result = sysdb.get_tenant(name="tenant1")
    assert result["name"] == "tenant1"

    # Get a tenant that does not exist and expect an error
    with pytest.raises(NotFoundError):
        sysdb.get_tenant(name="tenant2")

    # Create a new database within this tenant
    sysdb.create_database(id=uuid.uuid4(), name="new_database", tenant="tenant1")

    # Get the database and check that it exists
    result = sysdb.get_database(name="new_database", tenant="tenant1")
    assert result["name"] == "new_database"
    assert result["tenant"] == "tenant1"

    # Get a database that does not exist in a tenant that does exist and expect an error
    with pytest.raises(NotFoundError):
        sysdb.get_database(name="new_database1", tenant="tenant1")

    # Get a database that does not exist in a tenant that does not exist and expect an
    # error
    with pytest.raises(NotFoundError):
        sysdb.get_database(name="new_database1", tenant="tenant2")


# endregion

# region Segment tests
sample_segments = [
    Segment(
        id=uuid.UUID("00000000-d7d7-413b-92e1-731098a6e492"),
        type="test_type_a",
        scope=SegmentScope.VECTOR,
        collection=sample_collections[0]["id"],
        metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3},
    ),
    Segment(
        id=uuid.UUID("11111111-d7d7-413b-92e1-731098a6e492"),
        type="test_type_b",
        scope=SegmentScope.VECTOR,
        collection=sample_collections[1]["id"],
        metadata={"test_str": "str2", "test_int": 2, "test_float": 2.3},
    ),
    Segment(
        id=uuid.UUID("22222222-d7d7-413b-92e1-731098a6e492"),
        type="test_type_b",
        scope=SegmentScope.METADATA,
        collection=None,
        metadata={"test_str": "str3", "test_int": 3, "test_float": 3.3},
    ),
]


def test_create_get_delete_segments(sysdb: SysDB) -> None:
    sysdb.reset_state()

    for collection in sample_collections:
        sysdb.create_collection(
            id=collection.id,
            name=collection.name,
            configuration=collection.get_configuration(),
            metadata=collection["metadata"],
            dimension=collection["dimension"],
        )

    for segment in sample_segments:
        sysdb.create_segment(segment)

    results = sorted(sysdb.get_segments(), key=lambda c: c["id"])
    assert results == sample_segments

    # Duplicate create fails
    with pytest.raises(UniqueConstraintError):
        sysdb.create_segment(sample_segments[0])

    # Find by id
    for segment in sample_segments:
        result = sysdb.get_segments(id=segment["id"])
        assert result == [segment]

    # Find by type
    result = sysdb.get_segments(type="test_type_a")
    assert result == sample_segments[:1]

    result = sysdb.get_segments(type="test_type_b")
    assert sorted(result, key=lambda c: c["id"]) == sample_segments[1:]

    # Find by collection ID
    result = sysdb.get_segments(collection=sample_collections[0]["id"])
    assert result == sample_segments[:1]

    # Find by type and collection ID (positive case)
    result = sysdb.get_segments(
        type="test_type_a", collection=sample_collections[0]["id"]
    )
    assert result == sample_segments[:1]

    # Find by type and collection ID (negative case)
    result = sysdb.get_segments(
        type="test_type_b", collection=sample_collections[0]["id"]
    )
    assert result == []

    # Delete
    s1 = sample_segments[0]
    sysdb.delete_segment(s1["id"])

    results = sysdb.get_segments()
    assert s1 not in results
    assert len(results) == len(sample_segments) - 1
    assert sorted(results, key=lambda c: c["id"]) == sample_segments[1:]

    # Duplicate delete throws an exception
    with pytest.raises(NotFoundError):
        sysdb.delete_segment(s1["id"])


def test_update_segment(sysdb: SysDB) -> None:
    metadata: Dict[str, Union[str, int, float]] = {
        "test_str": "str1",
        "test_int": 1,
        "test_float": 1.3,
    }
    segment = Segment(
        id=uuid.uuid4(),
        type="test_type_a",
        scope=SegmentScope.VECTOR,
        collection=sample_collections[0]["id"],
        metadata=metadata,
    )

    sysdb.reset_state()
    for c in sample_collections:
        sysdb.create_collection(
            id=c.id,
            name=c.name,
            configuration=c.get_configuration(),
            metadata=c["metadata"],
            dimension=c["dimension"],
        )

    sysdb.create_segment(segment)

    # TODO: revisit update segment - push collection id
    # The fetched segment's collection is overwritten with the expected value before
    # each comparison, so these asserts do not check the stored collection id.
    result = sysdb.get_segments(id=segment["id"])
    result[0]["collection"] = segment["collection"]
    assert result == [segment]

    # Update collection to new value
    segment["collection"] = sample_collections[1]["id"]
    sysdb.update_segment(segment["id"], collection=segment["collection"])
    result = sysdb.get_segments(id=segment["id"])
    result[0]["collection"] = segment["collection"]
    assert result == [segment]

    # Update collection to None
    segment["collection"] = None
    sysdb.update_segment(segment["id"], collection=segment["collection"])
    result = sysdb.get_segments(id=segment["id"])
    result[0]["collection"] = segment["collection"]
    assert result == [segment]

    # Add a new metadata key (`metadata` is the same dict held by `segment`)
    metadata["test_str2"] = "str2"
    sysdb.update_segment(segment["id"], metadata={"test_str2": "str2"})
    result = sysdb.get_segments(id=segment["id"])
    result[0]["collection"] = segment["collection"]
    assert result == [segment]

    # Update a metadata key
    metadata["test_str"] = "str3"
    sysdb.update_segment(segment["id"], metadata={"test_str": "str3"})
    result = sysdb.get_segments(id=segment["id"])
    result[0]["collection"] = segment["collection"]
    assert result == [segment]

    # Delete a metadata key
    del metadata["test_str"]
    sysdb.update_segment(segment["id"], metadata={"test_str": None})
    result = sysdb.get_segments(id=segment["id"])
    result[0]["collection"] = segment["collection"]
    assert result == [segment]

    # Delete all metadata keys
    segment["metadata"] = None
    sysdb.update_segment(segment["id"], metadata=None)
    result = sysdb.get_segments(id=segment["id"])
    result[0]["collection"] = segment["collection"]
    assert result == [segment]


# endregion