Spaces:
Paused
Paused
| from concurrent import futures | |
| from typing import Any, Dict, cast | |
| from uuid import UUID | |
| from overrides import overrides | |
| from chromadb.api.configuration import CollectionConfigurationInternal | |
| from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component, System | |
| from chromadb.proto.convert import ( | |
| from_proto_metadata, | |
| from_proto_update_metadata, | |
| from_proto_segment, | |
| from_proto_segment_scope, | |
| to_proto_collection, | |
| to_proto_segment, | |
| ) | |
| import chromadb.proto.chroma_pb2 as proto | |
| from chromadb.proto.coordinator_pb2 import ( | |
| CreateCollectionRequest, | |
| CreateCollectionResponse, | |
| CreateDatabaseRequest, | |
| CreateDatabaseResponse, | |
| CreateSegmentRequest, | |
| CreateSegmentResponse, | |
| CreateTenantRequest, | |
| CreateTenantResponse, | |
| DeleteCollectionRequest, | |
| DeleteCollectionResponse, | |
| DeleteSegmentRequest, | |
| DeleteSegmentResponse, | |
| GetCollectionsRequest, | |
| GetCollectionsResponse, | |
| GetDatabaseRequest, | |
| GetDatabaseResponse, | |
| GetSegmentsRequest, | |
| GetSegmentsResponse, | |
| GetTenantRequest, | |
| GetTenantResponse, | |
| ResetStateResponse, | |
| UpdateCollectionRequest, | |
| UpdateCollectionResponse, | |
| UpdateSegmentRequest, | |
| UpdateSegmentResponse, | |
| ) | |
| from chromadb.proto.coordinator_pb2_grpc import ( | |
| SysDBServicer, | |
| add_SysDBServicer_to_server, | |
| ) | |
| import grpc | |
| from google.protobuf.empty_pb2 import Empty | |
| from chromadb.types import Collection, Metadata, Segment | |
class GrpcMockSysDB(SysDBServicer, Component):
    """A mock sysdb implementation that can be used for testing the grpc client.

    It stores state in simple python data structures instead of a database.
    Every RPC handler answers with a ``proto.Status`` whose code is 200 on
    success, 404 when the addressed tenant/database/segment/collection is
    unknown, and 409 when the entity being created already exists.
    """

    _server: grpc.Server
    _server_port: int
    # Segments keyed by the 32-character hex form of their id.
    _segments: Dict[str, Segment]
    # tenant name -> database name -> collection id (hex) -> collection.
    _tenants_to_databases_to_collections: Dict[str, Dict[str, Dict[str, Collection]]]
    # tenant name -> database name -> database id.
    _tenants_to_database_to_id: Dict[str, Dict[str, UUID]]

    def __init__(self, system: System):
        """Read the gRPC port from the system settings and create empty state.

        The state dictionaries are created per instance so that two mock
        servers never share (or leak) tenants, databases or segments.
        """
        self._server_port = system.settings.require("chroma_server_grpc_port")
        self._segments = {}
        self._tenants_to_databases_to_collections = {}
        self._tenants_to_database_to_id = {}
        super().__init__(system)

    def start(self) -> None:
        """Start a gRPC server for this servicer on ``[::]:<port>`` (insecure)."""
        self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        add_SysDBServicer_to_server(self, self._server)  # type: ignore
        self._server.add_insecure_port(f"[::]:{self._server_port}")
        self._server.start()
        return super().start()

    def stop(self) -> None:
        """Stop the gRPC server without a grace period."""
        self._server.stop(None)
        return super().stop()

    def reset_state(self) -> None:
        """Drop all stored state and recreate the default tenant and database."""
        self._segments = {}
        self._tenants_to_databases_to_collections = {
            DEFAULT_TENANT: {DEFAULT_DATABASE: {}}
        }
        # Rebuilt together with the collection map so ids of tenants/databases
        # created before the reset do not survive it.
        self._tenants_to_database_to_id = {
            DEFAULT_TENANT: {DEFAULT_DATABASE: UUID(int=0)}
        }
        return super().reset_state()

    @staticmethod
    def _id_key(raw_id: str) -> str:
        """Return the dictionary key used for an id received over the wire.

        Ids are stored under their ``UUID.hex`` form. A string that does not
        parse as a UUID is returned unchanged, so lookups with it simply miss
        (and yield a 404) instead of raising.
        """
        try:
            return UUID(hex=raw_id).hex
        except ValueError:
            return raw_id

    def CreateDatabase(
        self, request: CreateDatabaseRequest, context: grpc.ServicerContext
    ) -> CreateDatabaseResponse:
        """Register database ``request.name`` with id ``request.id`` in a tenant.

        Returns 404 if the tenant is unknown and 409 if the database already
        exists. Raises ``ValueError`` if ``request.id`` is not a valid UUID.
        """
        tenant = request.tenant
        database = request.name
        databases = self._tenants_to_databases_to_collections.get(tenant)
        if databases is None:
            return CreateDatabaseResponse(
                status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
            )
        if database in databases:
            return CreateDatabaseResponse(
                status=proto.Status(
                    code=409, reason=f"Database {database} already exists"
                )
            )
        # Parse the id before mutating anything so a malformed id cannot leave
        # a database registered without an id.
        database_id = UUID(hex=request.id)
        databases[database] = {}
        self._tenants_to_database_to_id.setdefault(tenant, {})[database] = database_id
        return CreateDatabaseResponse(status=proto.Status(code=200))

    def GetDatabase(
        self, request: GetDatabaseRequest, context: grpc.ServicerContext
    ) -> GetDatabaseResponse:
        """Look up a database by tenant and name; 404 if either is unknown."""
        tenant = request.tenant
        database = request.name
        if tenant not in self._tenants_to_databases_to_collections:
            return GetDatabaseResponse(
                status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
            )
        if database not in self._tenants_to_databases_to_collections[tenant]:
            return GetDatabaseResponse(
                status=proto.Status(code=404, reason=f"Database {database} not found")
            )
        database_id = self._tenants_to_database_to_id[tenant][database]
        return GetDatabaseResponse(
            status=proto.Status(code=200),
            database=proto.Database(id=database_id.hex, name=database, tenant=tenant),
        )

    def CreateTenant(
        self, request: CreateTenantRequest, context: grpc.ServicerContext
    ) -> CreateTenantResponse:
        """Register a new, empty tenant; 409 if the name is already taken."""
        tenant = request.name
        if tenant in self._tenants_to_databases_to_collections:
            return CreateTenantResponse(
                status=proto.Status(code=409, reason=f"Tenant {tenant} already exists")
            )
        self._tenants_to_databases_to_collections[tenant] = {}
        self._tenants_to_database_to_id[tenant] = {}
        return CreateTenantResponse(status=proto.Status(code=200))

    def GetTenant(
        self, request: GetTenantRequest, context: grpc.ServicerContext
    ) -> GetTenantResponse:
        """Look up a tenant by name; 404 if it does not exist."""
        tenant = request.name
        if tenant not in self._tenants_to_databases_to_collections:
            return GetTenantResponse(
                status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
            )
        return GetTenantResponse(
            status=proto.Status(code=200),
            tenant=proto.Tenant(name=tenant),
        )

    # We are forced to use check_signature=False because the generated proto code
    # does not have type annotations for the request and response objects.
    # TODO: investigate generating types for the request and response objects
    # NOTE(review): no `overrides(check_signature=False)` decorator is present on
    # the RPC handlers in this copy of the file - confirm whether they were lost.
    def CreateSegment(
        self, request: CreateSegmentRequest, context: grpc.ServicerContext
    ) -> CreateSegmentResponse:
        """Store the segment carried by the request; 409 if its id is taken."""
        segment = from_proto_segment(request.segment)
        segment_key = segment["id"].hex
        if segment_key in self._segments:
            return CreateSegmentResponse(
                status=proto.Status(
                    code=409, reason=f"Segment {segment['id']} already exists"
                )
            )
        self._segments[segment_key] = segment
        return CreateSegmentResponse(
            status=proto.Status(code=200)
        )  # TODO: how are these codes used? Need to determine the standards for the code and reason.

    def DeleteSegment(
        self, request: DeleteSegmentRequest, context: grpc.ServicerContext
    ) -> DeleteSegmentResponse:
        """Remove the segment with id ``request.id``; 404 if it is not stored."""
        id_to_delete = request.id
        # Segments are stored under the hex form of their id, so normalize the
        # incoming id the same way UpdateSegment does before looking it up.
        segment_key = self._id_key(id_to_delete)
        if segment_key not in self._segments:
            return DeleteSegmentResponse(
                status=proto.Status(
                    code=404, reason=f"Segment {id_to_delete} not found"
                )
            )
        del self._segments[segment_key]
        return DeleteSegmentResponse(status=proto.Status(code=200))

    def GetSegments(
        self, request: GetSegmentsRequest, context: grpc.ServicerContext
    ) -> GetSegmentsResponse:
        """Return every stored segment matching the filters set on the request.

        The optional filters are ``id``, ``type``, ``scope`` and ``collection``;
        a filter that is not set on the request matches every segment.
        """
        target_id = UUID(hex=request.id) if request.HasField("id") else None
        target_type = request.type if request.HasField("type") else None
        target_scope = (
            from_proto_segment_scope(request.scope)
            if request.HasField("scope")
            else None
        )
        target_collection = (
            UUID(hex=request.collection) if request.HasField("collection") else None
        )

        found_segments = []
        for segment in self._segments.values():
            if target_id and segment["id"] != target_id:
                continue
            if target_type and segment["type"] != target_type:
                continue
            if target_scope and segment["scope"] != target_scope:
                continue
            if target_collection and segment["collection"] != target_collection:
                continue
            found_segments.append(segment)
        return GetSegmentsResponse(
            segments=[to_proto_segment(segment) for segment in found_segments]
        )

    def UpdateSegment(
        self, request: UpdateSegmentRequest, context: grpc.ServicerContext
    ) -> UpdateSegmentResponse:
        """Apply the collection/metadata changes in the request to a segment.

        Returns 404 if the segment is unknown. ``reset_collection`` and
        ``reset_metadata`` are applied after the corresponding set/merge, so a
        reset wins when both are present. Raises ``ValueError`` if
        ``request.id`` is not a valid UUID.
        """
        id_to_update = UUID(request.id)
        if id_to_update.hex not in self._segments:
            return UpdateSegmentResponse(
                status=proto.Status(
                    code=404, reason=f"Segment {id_to_update} not found"
                )
            )

        segment = self._segments[id_to_update.hex]
        if request.HasField("collection"):
            segment["collection"] = UUID(hex=request.collection)
        if request.HasField("reset_collection") and request.reset_collection:
            segment["collection"] = None
        if request.HasField("metadata"):
            # Install the empty dict *before* taking the merge target;
            # otherwise a segment without metadata would hand None to
            # _merge_metadata and the update would be lost in an exception.
            if segment["metadata"] is None:
                segment["metadata"] = {}
            target = cast(Dict[str, Any], segment["metadata"])
            self._merge_metadata(target, request.metadata)
        if request.HasField("reset_metadata") and request.reset_metadata:
            segment["metadata"] = {}
        return UpdateSegmentResponse(status=proto.Status(code=200))

    def CreateCollection(
        self, request: CreateCollectionRequest, context: grpc.ServicerContext
    ) -> CreateCollectionResponse:
        """Create a collection, or return the existing one for get_or_create.

        Returns 404 if the tenant or database is unknown, and 409 if the id is
        already used (anywhere, unless it is in the same tenant/database and
        ``get_or_create`` is set) or the name is already used in the target
        database without ``get_or_create``. With ``get_or_create`` and a name
        match, the existing collection is returned with ``created=False`` and
        its metadata replaced when the request carries metadata.
        """
        collection_name = request.name
        tenant = request.tenant
        database = request.database
        all_tenants = self._tenants_to_databases_to_collections
        if tenant not in all_tenants:
            return CreateCollectionResponse(
                status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
            )
        if database not in all_tenants[tenant]:
            return CreateCollectionResponse(
                status=proto.Status(code=404, reason=f"Database {database} not found")
            )

        # Collections are stored under the hex form of their id (the form
        # UpdateCollection looks them up by), so compare using that form.
        collection_key = self._id_key(request.id)

        # Check if the collection already exists globally by id
        for search_tenant, databases in all_tenants.items():
            for search_database, search_collections in databases.items():
                if collection_key not in search_collections:
                    continue
                same_location = (
                    search_tenant == request.tenant
                    and search_database == request.database
                )
                # An id clash is only tolerated inside the requested
                # tenant/database, and only for get_or_create.
                if not same_location or not request.get_or_create:
                    return CreateCollectionResponse(
                        status=proto.Status(
                            code=409,
                            reason=f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}",
                        )
                    )

        # Check if the collection already exists in this database by name
        collections = all_tenants[tenant][database]
        matches = [c for c in collections.values() if c["name"] == collection_name]
        assert len(matches) <= 1
        if len(matches) > 0:
            if request.get_or_create:
                existing_collection = matches[0]
                if request.HasField("metadata"):
                    existing_collection["metadata"] = from_proto_metadata(
                        request.metadata
                    )
                return CreateCollectionResponse(
                    status=proto.Status(code=200),
                    collection=to_proto_collection(existing_collection),
                    created=False,
                )
            return CreateCollectionResponse(
                status=proto.Status(
                    code=409, reason=f"Collection {request.name} already exists"
                )
            )

        configuration = CollectionConfigurationInternal.from_json_str(
            request.configuration_json_str
        )
        collection_id = UUID(hex=request.id)
        new_collection = Collection(
            id=collection_id,
            name=request.name,
            configuration=configuration,
            metadata=from_proto_metadata(request.metadata),
            dimension=request.dimension,
            database=database,
            tenant=tenant,
            version=0,
        )
        collections[collection_id.hex] = new_collection
        return CreateCollectionResponse(
            status=proto.Status(code=200),
            collection=to_proto_collection(new_collection),
            created=True,
        )

    def DeleteCollection(
        self, request: DeleteCollectionRequest, context: grpc.ServicerContext
    ) -> DeleteCollectionResponse:
        """Remove a collection from a tenant's database.

        Returns 404 if the tenant, the database or the collection is unknown.
        """
        collection_id = request.id
        tenant = request.tenant
        database = request.database
        if tenant not in self._tenants_to_databases_to_collections:
            return DeleteCollectionResponse(
                status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
            )
        if database not in self._tenants_to_databases_to_collections[tenant]:
            return DeleteCollectionResponse(
                status=proto.Status(code=404, reason=f"Database {database} not found")
            )
        collections = self._tenants_to_databases_to_collections[tenant][database]
        collection_key = self._id_key(collection_id)
        if collection_key not in collections:
            return DeleteCollectionResponse(
                status=proto.Status(
                    code=404, reason=f"Collection {collection_id} not found"
                )
            )
        del collections[collection_key]
        return DeleteCollectionResponse(status=proto.Status(code=200))

    def GetCollections(
        self, request: GetCollectionsRequest, context: grpc.ServicerContext
    ) -> GetCollectionsResponse:
        """Return the collections matching the filters on the request.

        An empty ``tenant`` or ``database`` on the request matches all tenants
        or databases respectively; ``id`` and ``name`` filter only when set.
        """
        target_id = UUID(hex=request.id) if request.HasField("id") else None
        target_name = request.name if request.HasField("name") else None

        candidates: Dict[str, Collection] = {}
        for tenant, databases in self._tenants_to_databases_to_collections.items():
            if request.tenant != "" and tenant != request.tenant:
                continue
            for database, collections in databases.items():
                if request.database != "" and database != request.database:
                    continue
                candidates.update(collections)

        found_collections = []
        for collection in candidates.values():
            if target_id and collection["id"] != target_id:
                continue
            if target_name and collection["name"] != target_name:
                continue
            found_collections.append(collection)
        return GetCollectionsResponse(
            collections=[
                to_proto_collection(collection) for collection in found_collections
            ]
        )

    def UpdateCollection(
        self, request: UpdateCollectionRequest, context: grpc.ServicerContext
    ) -> UpdateCollectionResponse:
        """Update name, dimension and/or metadata of the collection ``request.id``.

        The collection is searched for across all tenants and databases; 404 is
        returned if it is not found. Metadata on the request replaces the
        stored metadata (keys with a None value are dropped); otherwise
        ``reset_metadata`` clears it. Raises ``ValueError`` if ``request.id``
        is not a valid UUID.
        """
        id_to_update = UUID(request.id)
        collection_key = id_to_update.hex

        # Find the collection with this id
        collections: Dict[str, Collection] = {}
        for databases in self._tenants_to_databases_to_collections.values():
            for maybe_collections in databases.values():
                if collection_key in maybe_collections:
                    collections = maybe_collections

        if collection_key not in collections:
            return UpdateCollectionResponse(
                status=proto.Status(
                    code=404, reason=f"Collection {id_to_update} not found"
                )
            )

        collection = collections[collection_key]
        if request.HasField("name"):
            collection["name"] = request.name
        if request.HasField("dimension"):
            collection["dimension"] = request.dimension
        if request.HasField("metadata"):
            # TODO: IN SysDB SQlite we have technical debt where we
            # replace the entire metadata dict with the new one. We should
            # fix that by merging it. For now we just do the same thing here
            update_metadata = from_proto_update_metadata(request.metadata)
            cleaned_metadata = None
            if update_metadata is not None:
                cleaned_metadata = {
                    key: value
                    for key, value in update_metadata.items()
                    if value is not None
                }
            collection["metadata"] = cleaned_metadata
        elif request.HasField("reset_metadata"):
            if request.reset_metadata:
                collection["metadata"] = {}
        return UpdateCollectionResponse(status=proto.Status(code=200))

    def ResetState(
        self, request: Empty, context: grpc.ServicerContext
    ) -> ResetStateResponse:
        """RPC entry point for ``reset_state``; always answers with code 200."""
        self.reset_state()
        return ResetStateResponse(status=proto.Status(code=200))

    def _merge_metadata(self, target: Metadata, source: proto.UpdateMetadata) -> None:
        """Merge the update ``source`` into ``target`` in place.

        Keys whose update value is None are removed from ``target``; all other
        keys are set to the update value. An update that converts to None is
        treated as "nothing to merge".
        """
        target_metadata = cast(Dict[str, Any], target)
        source_metadata = from_proto_update_metadata(source)
        if source_metadata is None:
            return
        for key, value in source_metadata.items():
            if value is None:
                # If a key has a None value, remove it from the metadata
                target_metadata.pop(key, None)
            else:
                target_metadata[key] = value