Spaces:
Paused
Paused
Term-4-Project/myenv/lib/python3.12/site-packages/chromadb/segment/impl/manager/distributed.py
| from threading import Lock | |
| from chromadb.segment import ( | |
| SegmentImplementation, | |
| SegmentManager, | |
| MetadataReader, | |
| SegmentType, | |
| VectorReader, | |
| S, | |
| ) | |
| from chromadb.config import System, get_class | |
| from chromadb.db.system import SysDB | |
| from overrides import override | |
| from chromadb.segment.distributed import SegmentDirectory | |
| from chromadb.telemetry.opentelemetry import ( | |
| OpenTelemetryClient, | |
| OpenTelemetryGranularity, | |
| trace_method, | |
| ) | |
| from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata | |
| from typing import Dict, Type, Sequence, Optional, cast | |
| from uuid import UUID, uuid4 | |
| from collections import defaultdict | |
# TODO: It is odd that the segment manager differs between the distributed and
# local implementations; refactor so the two are consistent and share code.
# What is needed here is a way to specify the desired segment types for a
# collection. It is also odd that the segment manager is coupled to the segment
# implementations; this abstraction needs rethinking.
#
# Segment types that have a Python implementation, mapped to the dotted path of
# the implementing class (resolved with ``get_class`` where it is needed).
SEGMENT_TYPE_IMPLS: Dict[SegmentType, str] = dict(
    [
        (
            SegmentType.SQLITE,
            "chromadb.segment.impl.metadata.sqlite.SqliteMetadataSegment",
        ),
        (
            SegmentType.HNSW_DISTRIBUTED,
            "chromadb.segment.impl.vector.grpc_segment.GrpcVectorSegment",
        ),
        (
            SegmentType.BLOCKFILE_METADATA,
            "chromadb.segment.impl.metadata.grpc_segment.GrpcMetadataSegment",
        ),
    ]
)
class DistributedSegmentManager(SegmentManager):
    """Segment manager used by the distributed deployment.

    Segment records are read from the SysDB, annotated with the gRPC endpoint
    returned by the ``SegmentDirectory``, and turned into implementation
    instances via ``SEGMENT_TYPE_IMPLS``. Both the segment records and the
    started instances are cached in memory.
    """

    _sysdb: SysDB
    _system: System
    _opentelemetry_client: OpenTelemetryClient
    # segment id -> started segment implementation
    _instances: Dict[UUID, SegmentImplementation]
    # collection_id -> scope -> segment
    _segment_cache: Dict[UUID, Dict[SegmentScope, Segment]]
    _segment_directory: SegmentDirectory
    # Serializes creation and eviction of entries in ``_instances``.
    _lock: Lock

    def __init__(self, system: System):
        super().__init__(system)
        self._sysdb = self.require(SysDB)
        self._segment_directory = self.require(SegmentDirectory)
        self._system = system
        self._opentelemetry_client = system.require(OpenTelemetryClient)
        self._instances = {}
        self._segment_cache = defaultdict(dict)
        self._lock = Lock()

    def create_segments(self, collection: Collection) -> Sequence[Segment]:
        """Build the segment records for a new collection.

        Returns one vector, one record and one metadata segment, in that order,
        each with a freshly generated id. This method does not write them to
        the SysDB itself.
        """
        vector_segment = _segment(
            SegmentType.HNSW_DISTRIBUTED, SegmentScope.VECTOR, collection
        )
        metadata_segment = _segment(
            SegmentType.BLOCKFILE_METADATA, SegmentScope.METADATA, collection
        )
        record_segment = _segment(
            SegmentType.BLOCKFILE_RECORD, SegmentScope.RECORD, collection
        )
        return [vector_segment, record_segment, metadata_segment]

    def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
        """Return the ids of every SysDB segment of ``collection_id``.

        Nothing is deleted from the SysDB here (presumably the caller deletes
        the returned ids — confirm against callers). Any segment records and
        implementation instances cached locally for the collection are evicted,
        so later lookups do not return stale objects.
        """
        segments = self._sysdb.get_segments(collection=collection_id)
        segment_ids = [s["id"] for s in segments]
        with self._lock:
            cached = self._segment_cache.pop(collection_id, {})
            # Also drop instances for cached records whose ids the SysDB no
            # longer reports, so none of them outlive the collection here.
            stale_ids = set(segment_ids) | {s["id"] for s in cached.values()}
            for segment_id in stale_ids:
                # NOTE(review): evicted instances are not stop()ped, since other
                # threads may still hold references obtained from get_segment.
                self._instances.pop(segment_id, None)
        return segment_ids

    def get_segment(self, collection_id: UUID, type: Type[S]) -> S:
        """Return the shared segment implementation serving ``type``.

        Args:
            collection_id: Collection whose segment is requested.
            type: ``MetadataReader`` or ``VectorReader``; selects the scope.

        Returns:
            The started implementation for the first segment of a known type
            (a key of ``SEGMENT_TYPE_IMPLS``) in that scope, cast to ``type``.

        Raises:
            ValueError: if ``type`` is not a supported reader type, or if the
                collection has no segment of a known type in that scope.
        """
        if type == MetadataReader:
            scope = SegmentScope.METADATA
        elif type == VectorReader:
            scope = SegmentScope.VECTOR
        else:
            raise ValueError(f"Invalid segment type: {type}")

        cached = self._segment_cache[collection_id]
        segment = cached.get(scope)
        if segment is None:
            segments = self._sysdb.get_segments(collection=collection_id, scope=scope)
            known_types = {k.value for k in SEGMENT_TYPE_IMPLS}
            # Take the first segment whose type has a Python implementation.
            segment = next((s for s in segments if s["type"] in known_types), None)
            if segment is None:
                raise ValueError(
                    f"No segment of a known type found for collection "
                    f"{collection_id} with scope {scope}"
                )
            grpc_url = self._segment_directory.get_segment_endpoint(segment)
            # Build a fresh mapping rather than writing into the metadata object
            # returned by the SysDB.
            segment["metadata"] = {**(segment["metadata"] or {}), "grpc_url": grpc_url}
            # TODO: Register a callback to update the segment when it gets moved
            # self._segment_directory.register_updated_segment_callback()
            cached[scope] = segment

        # Instances must be created atomically, so the lock ensures only one
        # thread creates the instance for a given segment. The local ``segment``
        # is used (not a cache re-read) so a concurrent eviction cannot cause a
        # KeyError here.
        with self._lock:
            instance = self._instance(segment)
        return cast(S, instance)

    def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None:
        """No-op: this manager ignores usage hints."""
        pass

    # TODO: rethink duplication from local segment manager
    def _cls(self, segment: Segment) -> Type[SegmentImplementation]:
        """Resolve the implementation class for ``segment`` from its type string.

        Raises ValueError if the type string is not a ``SegmentType`` value and
        KeyError if that type has no entry in ``SEGMENT_TYPE_IMPLS``.
        """
        classname = SEGMENT_TYPE_IMPLS[SegmentType(segment["type"])]
        cls = get_class(classname, SegmentImplementation)
        return cls

    def _instance(self, segment: Segment) -> SegmentImplementation:
        """Return the cached instance for ``segment``, creating and starting it once.

        Not synchronized internally; ``get_segment`` calls it while holding
        ``self._lock``.
        """
        if segment["id"] not in self._instances:
            cls = self._cls(segment)
            instance = cls(self._system, segment)
            instance.start()
            self._instances[segment["id"]] = instance
        return self._instances[segment["id"]]
# TODO: rethink duplication from local segment manager
def _segment(type: SegmentType, scope: SegmentScope, collection: Collection) -> Segment:
    """Build a new ``Segment`` record of the given type and scope for ``collection``.

    The record gets a freshly generated UUID. Its metadata is ``None`` unless
    ``type`` has a Python implementation listed in ``SEGMENT_TYPE_IMPLS`` and the
    collection has non-empty metadata; in that case the implementation's
    ``propagate_collection_metadata`` decides what the segment inherits.
    """
    inherited: Optional[Metadata] = None
    has_python_impl = type in SEGMENT_TYPE_IMPLS
    if has_python_impl:
        impl_path = SEGMENT_TYPE_IMPLS[type]
        # The class is resolved even when there is no metadata to propagate.
        impl_cls = get_class(impl_path, SegmentImplementation)
        source_metadata = collection.metadata
        if source_metadata:
            inherited = impl_cls.propagate_collection_metadata(source_metadata)

    new_segment = Segment(
        id=uuid4(),
        type=type.value,
        scope=scope,
        collection=collection.id,
        metadata=inherited,
    )
    return new_segment