Spaces:
Paused
Paused
| import logging | |
| from typing import List, Optional, Sequence, Tuple, Union, cast | |
| from uuid import UUID | |
| from overrides import overrides | |
| from chromadb.api.configuration import CollectionConfigurationInternal | |
| from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System, logger | |
| from chromadb.db.base import NotFoundError, UniqueConstraintError | |
| from chromadb.db.system import SysDB | |
| from chromadb.proto.convert import ( | |
| from_proto_collection, | |
| from_proto_segment, | |
| to_proto_update_metadata, | |
| to_proto_segment, | |
| to_proto_segment_scope, | |
| ) | |
| from chromadb.proto.coordinator_pb2 import ( | |
| CreateCollectionRequest, | |
| CreateDatabaseRequest, | |
| CreateSegmentRequest, | |
| CreateTenantRequest, | |
| DeleteCollectionRequest, | |
| DeleteSegmentRequest, | |
| GetCollectionsRequest, | |
| GetCollectionsResponse, | |
| GetDatabaseRequest, | |
| GetSegmentsRequest, | |
| GetTenantRequest, | |
| UpdateCollectionRequest, | |
| UpdateSegmentRequest, | |
| ) | |
| from chromadb.proto.coordinator_pb2_grpc import SysDBStub | |
| from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor | |
| from chromadb.types import ( | |
| Collection, | |
| Database, | |
| Metadata, | |
| OptionalArgument, | |
| Segment, | |
| SegmentScope, | |
| Tenant, | |
| Unspecified, | |
| UpdateMetadata, | |
| ) | |
| from google.protobuf.empty_pb2 import Empty | |
| import grpc | |
class GrpcSysDB(SysDB):
    """A gRPC implementation of the SysDB. In the distributed system, the SysDB is also
    called the 'Coordinator'. This implementation is used by Chroma frontend servers
    to call a remote SysDB (Coordinator) service.

    Each method builds the matching request message, sends it through the
    ``SysDBStub`` and, where the method inspects it, translates the status code
    carried in the response into ``NotFoundError`` (404) or
    ``UniqueConstraintError`` (409).
    """

    _sys_db_stub: SysDBStub
    _channel: grpc.Channel
    _coordinator_url: str
    _coordinator_port: int

    def __init__(self, system: System):
        """Read the coordinator host and port from ``system.settings``.

        No channel is opened here; that happens in ``start``.
        """
        self._coordinator_url = system.settings.require("chroma_coordinator_host")
        # TODO: break out coordinator_port into a separate setting?
        self._coordinator_port = system.settings.require("chroma_server_grpc_port")
        super().__init__(system)

    def start(self) -> None:
        """Open the gRPC channel to the coordinator and create the stub."""
        # TODO: add retry policy here
        self._channel = grpc.insecure_channel(
            f"{self._coordinator_url}:{self._coordinator_port}"
        )
        # Wrap the raw channel so every call goes through the interceptor(s).
        interceptors = [OtelInterceptor()]
        self._channel = grpc.intercept_channel(self._channel, *interceptors)
        self._sys_db_stub = SysDBStub(self._channel)  # type: ignore
        return super().start()

    def stop(self) -> None:
        """Close the channel opened by ``start``."""
        self._channel.close()
        return super().stop()

    def reset_state(self) -> None:
        """Send ``ResetState`` to the coordinator, then reset the base component."""
        self._sys_db_stub.ResetState(Empty())
        return super().reset_state()

    def create_database(
        self, id: UUID, name: str, tenant: str = DEFAULT_TENANT
    ) -> None:
        """Create a database named ``name`` under ``tenant``.

        Raises:
            UniqueConstraintError: the coordinator answered with status 409.
        """
        request = CreateDatabaseRequest(id=id.hex, name=name, tenant=tenant)
        response = self._sys_db_stub.CreateDatabase(request)
        if response.status.code == 409:
            raise UniqueConstraintError()

    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
        """Look up a database by ``name`` and ``tenant``.

        Returns:
            A ``Database`` built from the id, name and tenant in the response.

        Raises:
            NotFoundError: the coordinator answered with status 404.
        """
        request = GetDatabaseRequest(name=name, tenant=tenant)
        response = self._sys_db_stub.GetDatabase(request)
        if response.status.code == 404:
            raise NotFoundError()
        found = response.database
        return Database(
            id=UUID(hex=found.id),
            name=found.name,
            tenant=found.tenant,
        )

    def create_tenant(self, name: str) -> None:
        """Create a tenant named ``name``.

        Raises:
            UniqueConstraintError: the coordinator answered with status 409.
        """
        request = CreateTenantRequest(name=name)
        response = self._sys_db_stub.CreateTenant(request)
        if response.status.code == 409:
            raise UniqueConstraintError()

    def get_tenant(self, name: str) -> Tenant:
        """Look up a tenant by ``name``.

        Raises:
            NotFoundError: the coordinator answered with status 404.
        """
        request = GetTenantRequest(name=name)
        response = self._sys_db_stub.GetTenant(request)
        if response.status.code == 404:
            raise NotFoundError()
        return Tenant(
            name=response.tenant.name,
        )

    def create_segment(self, segment: Segment) -> None:
        """Convert ``segment`` to its proto form and send ``CreateSegment``.

        Raises:
            UniqueConstraintError: the coordinator answered with status 409.
        """
        request = CreateSegmentRequest(
            segment=to_proto_segment(segment),
        )
        response = self._sys_db_stub.CreateSegment(request)
        if response.status.code == 409:
            raise UniqueConstraintError()

    def delete_segment(self, id: UUID) -> None:
        """Delete the segment identified by ``id``.

        Raises:
            NotFoundError: the coordinator answered with status 404.
        """
        request = DeleteSegmentRequest(
            id=id.hex,
        )
        response = self._sys_db_stub.DeleteSegment(request)
        if response.status.code == 404:
            raise NotFoundError()

    def get_segments(
        self,
        id: Optional[UUID] = None,
        type: Optional[str] = None,
        scope: Optional[SegmentScope] = None,
        collection: Optional[UUID] = None,
    ) -> Sequence[Segment]:
        """Fetch segments, optionally filtered by id, type, scope and collection.

        Filters left as ``None`` are passed to the request as ``None``.

        Returns:
            The segments in the response, converted with ``from_proto_segment``.
        """
        request = GetSegmentsRequest(
            id=id.hex if id is not None else None,
            type=type,
            scope=to_proto_segment_scope(scope) if scope is not None else None,
            collection=collection.hex if collection is not None else None,
        )
        response = self._sys_db_stub.GetSegments(request)
        return [from_proto_segment(proto_segment) for proto_segment in response.segments]

    def update_segment(
        self,
        id: UUID,
        collection: OptionalArgument[Optional[UUID]] = Unspecified(),
        metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
    ) -> None:
        """Update the collection and/or metadata of the segment ``id``.

        For both ``collection`` and ``metadata``:

        * ``Unspecified()`` (the default): the field is left unset in the request.
        * ``None``: the field is cleared and the matching ``reset_*`` flag is set.
        * a value: the value is sent (an empty/falsy ``metadata`` is sent as unset).
        """
        write_collection: Optional[UUID] = None
        if collection != Unspecified():
            write_collection = cast(Optional[UUID], collection)

        write_metadata: Optional[UpdateMetadata] = None
        if metadata != Unspecified():
            write_metadata = cast(Optional[UpdateMetadata], metadata)

        request = UpdateSegmentRequest(
            id=id.hex,
            collection=write_collection.hex if write_collection is not None else None,
            metadata=to_proto_update_metadata(write_metadata)
            if write_metadata
            else None,
        )

        # An explicit None means "reset", which is signalled with a dedicated flag
        # because an unset field alone cannot be told apart from "not specified".
        if collection is None:
            request.ClearField("collection")
            request.reset_collection = True

        if metadata is None:
            request.ClearField("metadata")
            request.reset_metadata = True

        # NOTE(review): the response status is not inspected here, unlike
        # update_collection which maps 404/409 to exceptions -- confirm intended.
        self._sys_db_stub.UpdateSegment(request)

    def create_collection(
        self,
        id: UUID,
        name: str,
        configuration: CollectionConfigurationInternal,
        metadata: Optional[Metadata] = None,
        dimension: Optional[int] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Tuple[Collection, bool]:
        """Create a collection (or fetch it when ``get_or_create`` is honoured).

        ``configuration`` is sent as its JSON string; a falsy ``metadata`` is
        sent as unset.

        Returns:
            ``(collection, created)`` where ``collection`` is converted from the
            response and ``created`` is the response's ``created`` flag.

        Raises:
            UniqueConstraintError: the coordinator answered with status 409.
        """
        request = CreateCollectionRequest(
            id=id.hex,
            name=name,
            configuration_json_str=configuration.to_json_str(),
            metadata=to_proto_update_metadata(metadata) if metadata else None,
            dimension=dimension,
            get_or_create=get_or_create,
            tenant=tenant,
            database=database,
        )
        response = self._sys_db_stub.CreateCollection(request)
        # TODO: this needs to be changed to try, catch instead of checking the status code
        status_code = response.status.code
        if status_code != 200:
            logger.info(f"failed to create collection, response: {response}")
            if status_code == 409:
                raise UniqueConstraintError()
            # NOTE(review): any other non-200 status falls through and the
            # (possibly empty) response.collection is still converted below.
        collection = from_proto_collection(response.collection)
        return collection, response.created

    def delete_collection(
        self, id: UUID, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
    ) -> None:
        """Delete the collection ``id`` within ``tenant``/``database``.

        Raises:
            NotFoundError: the coordinator answered with status 404.
        """
        request = DeleteCollectionRequest(
            id=id.hex,
            tenant=tenant,
            database=database,
        )
        response = self._sys_db_stub.DeleteCollection(request)
        # Use the module logger (as create_collection does) rather than the
        # root-level logging.debug helper.
        logger.debug(f"delete_collection response: {response}")
        if response.status.code == 404:
            raise NotFoundError()

    def get_collections(
        self,
        id: Optional[UUID] = None,
        name: Optional[str] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
    ) -> Sequence[Collection]:
        """Fetch collections by name, by id, or all collections of a database.

        Request selection:

        * ``name`` given: query by name, tenant and database. If ``id`` is given
          as well it is not sent -- ``name`` takes precedence.
        * only ``id`` given: query by id alone (tenant/database are not sent).
        * neither given: query by tenant and database.

        ``limit`` and ``offset`` are forwarded in every case.

        Raises:
            ValueError: ``name`` is given but ``tenant`` or ``database`` is None.
        """
        # TODO: implement limit and offset in the gRPC service
        if name is not None:
            # Both values are needed to identify a collection by name, so reject
            # the call when either one is missing.
            if tenant is None or database is None:
                raise ValueError(
                    "If name is specified, tenant and database must also be specified in order to uniquely identify the collection"
                )
            request = GetCollectionsRequest(
                name=name,
                tenant=tenant,
                database=database,
                limit=limit,
                offset=offset,
            )
        elif id is not None:
            request = GetCollectionsRequest(
                id=id.hex,
                limit=limit,
                offset=offset,
            )
        else:
            request = GetCollectionsRequest(
                tenant=tenant,
                database=database,
                limit=limit,
                offset=offset,
            )
        response: GetCollectionsResponse = self._sys_db_stub.GetCollections(request)
        return [from_proto_collection(collection) for collection in response.collections]

    def update_collection(
        self,
        id: UUID,
        name: OptionalArgument[str] = Unspecified(),
        dimension: OptionalArgument[Optional[int]] = Unspecified(),
        metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
    ) -> None:
        """Update the name, dimension and/or metadata of the collection ``id``.

        Arguments left as ``Unspecified()`` are passed to the request as ``None``.
        ``metadata=None`` clears the field and sets ``reset_metadata``; a falsy
        ``metadata`` value is sent as unset. ``dimension=None`` results in the
        same request value as leaving ``dimension`` unspecified.

        Raises:
            NotFoundError: the coordinator answered with status 404.
            UniqueConstraintError: the coordinator answered with status 409.
        """
        write_name: Optional[str] = None
        if name != Unspecified():
            write_name = cast(str, name)

        write_dimension: Optional[int] = None
        if dimension != Unspecified():
            write_dimension = cast(Optional[int], dimension)

        write_metadata: Optional[UpdateMetadata] = None
        if metadata != Unspecified():
            write_metadata = cast(Optional[UpdateMetadata], metadata)

        request = UpdateCollectionRequest(
            id=id.hex,
            name=write_name,
            dimension=write_dimension,
            metadata=to_proto_update_metadata(write_metadata)
            if write_metadata
            else None,
        )
        # An explicit None means "reset the metadata".
        if metadata is None:
            request.ClearField("metadata")
            request.reset_metadata = True

        response = self._sys_db_stub.UpdateCollection(request)
        status_code = response.status.code
        if status_code == 404:
            raise NotFoundError()
        if status_code == 409:
            raise UniqueConstraintError()

    def reset_and_wait_for_ready(self) -> None:
        """Send ``ResetState`` with ``wait_for_ready=True``.

        Unlike ``reset_state`` this does not call the base-class reset.
        """
        self._sys_db_stub.ResetState(Empty(), wait_for_ready=True)