Spaces:
Runtime error
Runtime error
| from typing import Any, Dict, List, Sequence, Set | |
| from uuid import UUID | |
| from chromadb.config import Settings, System | |
| from chromadb.ingest import CollectionAssignmentPolicy, Consumer | |
| from chromadb.proto.chroma_pb2_grpc import ( | |
| # SegmentServerServicer, | |
| # add_SegmentServerServicer_to_server, | |
| VectorReaderServicer, | |
| add_VectorReaderServicer_to_server, | |
| ) | |
| import chromadb.proto.chroma_pb2 as proto | |
| import grpc | |
| from concurrent import futures | |
| from chromadb.proto.convert import ( | |
| to_proto_vector_embedding_record | |
| ) | |
| from chromadb.segment import SegmentImplementation, SegmentType | |
| from chromadb.telemetry.opentelemetry import ( | |
| OpenTelemetryClient | |
| ) | |
| from chromadb.types import EmbeddingRecord | |
| from chromadb.segment.distributed import MemberlistProvider, Memberlist | |
| from chromadb.utils.rendezvous_hash import assign, murmur3hasher | |
| from chromadb.ingest.impl.pulsar_admin import PulsarAdmin | |
| import logging | |
| import os | |
| # This file is a prototype. It will be replaced with a real distributed segment server | |
| # written in a different language. This is just a proof of concept to get the distributed | |
| # segment type working end to end. | |
| # Run this with python -m chromadb.segment.impl.distributed.server | |
# Dotted import path of the implementation class for each segment type handled
# by this server. In the code visible in this file, only the commented-out
# `_cls` helper looks entries up in this table.
SEGMENT_TYPE_IMPLS: Dict[SegmentType, str] = {
    SegmentType.HNSW_DISTRIBUTED: (
        "chromadb.segment.impl.vector.local_persistent_hnsw.PersistentLocalHnswSegment"
    ),
}
class SegmentServer(VectorReaderServicer):
    """Prototype gRPC vector-reader servicer for the distributed segment type.

    On construction the server resolves its dependencies from the given
    ``System``, creates the Pulsar topics returned by the assignment policy,
    and subscribes to the topics that rendezvous hashing assigns to this node.
    The node identifies itself by the ``MY_POD_IP`` environment variable.
    Assignments are recomputed whenever the memberlist provider reports a
    non-empty memberlist.
    """

    # Instantiated segments keyed by segment id. Only the commented-out
    # GetVectors / _create_instance code below refers to it. The dict itself is
    # created per instance in __init__ so instances never share one cache.
    _segment_cache: Dict[UUID, SegmentImplementation]
    _system: System
    _opentelemetry_client: OpenTelemetryClient
    _memberlist_provider: MemberlistProvider
    _assignment_policy: CollectionAssignmentPolicy
    _curr_memberlist: Memberlist
    # Topics this node is currently responsible for.
    _assigned_topics: Set[str]
    # Subscription id returned by the consumer for each assigned topic.
    _topic_to_subscription: Dict[str, UUID]
    _consumer: Consumer

    def __init__(self, system: System) -> None:
        """Resolve dependencies and subscribe to the initially assigned topics.

        Args:
            system: The system from which dependency services are resolved
                via ``require``.
        """
        super().__init__()
        self._system = system
        # A dict literal on the class body would be shared by every instance.
        self._segment_cache = {}

        # Init dependency services
        self._opentelemetry_client = system.require(OpenTelemetryClient)
        # TODO: add term and epoch to segment server
        self._memberlist_provider = system.require(MemberlistProvider)
        self._memberlist_provider.set_memberlist_name("worker-memberlist")
        self._assignment_policy = system.require(CollectionAssignmentPolicy)
        self._create_pulsar_topics()
        self._consumer = system.require(Consumer)

        # Init data
        self._topic_to_subscription = {}
        self._assigned_topics = set()
        self._curr_memberlist = self._memberlist_provider.get_memberlist()
        self._compute_assigned_topics()

        self._memberlist_provider.register_updated_memberlist_callback(
            self._on_memberlist_update
        )

    def _compute_assigned_topics(self) -> None:
        """Uses rendezvous hashing to compute the topics that this node is responsible for.

        Topics that are no longer assigned are unsubscribed, newly assigned
        topics are subscribed with ``_on_message`` as the callback, and
        ``_assigned_topics`` / ``_topic_to_subscription`` are updated to match.

        Raises:
            KeyError: If the memberlist is non-empty and ``MY_POD_IP`` is not
                set in the environment.
        """
        if not self._curr_memberlist:
            # Nothing can be assigned without members. Release whatever is
            # still held so the subscription map stays in sync with the cleared
            # assignment set; otherwise a later recomputation would subscribe
            # to the same topics again and lose track of the old subscriptions.
            for subscription in self._topic_to_subscription.values():
                self._consumer.unsubscribe(subscription)
            self._topic_to_subscription.clear()
            self._assigned_topics = set()
            return

        topics = self._assignment_policy.get_topics()
        my_ip = os.environ["MY_POD_IP"]
        new_assignments_set = {
            topic
            for topic in topics
            if assign(topic, self._curr_memberlist, murmur3hasher) == my_ip
        }

        # TODO: We need to lock around this assignment
        net_new_assignments = new_assignments_set - self._assigned_topics
        removed_assignments = self._assigned_topics - new_assignments_set

        for topic in removed_assignments:
            subscription = self._topic_to_subscription.pop(topic)
            self._consumer.unsubscribe(subscription)

        for topic in net_new_assignments:
            subscription = self._consumer.subscribe(topic, self._on_message)
            self._topic_to_subscription[topic] = subscription

        self._assigned_topics = new_assignments_set
        print(
            f"Topic assigment updated and now assigned to {len(self._assigned_topics)} topics"
        )

    def _on_memberlist_update(self, memberlist: Memberlist) -> None:
        """Called when the memberlist is updated.

        Args:
            memberlist: The new memberlist; it replaces the stored one even
                when it is empty.
        """
        self._curr_memberlist = memberlist
        if self._curr_memberlist:
            self._compute_assigned_topics()
        else:
            # In this case we'd want to warn that there are no members but
            # this is not an error, as it could be that the cluster is just starting up
            print("Memberlist is empty")

    def _on_message(self, embedding_records: Sequence[EmbeddingRecord]) -> None:
        """Called when a message is received from the consumer.

        Currently this only prints the batch size and, for a non-empty batch,
        the first record together with its ``collection_id``.

        Args:
            embedding_records: The batch of records delivered by the consumer.
        """
        print(f"Received {len(embedding_records)} records")
        if not embedding_records:
            # An empty batch has no first record; indexing it would raise
            # IndexError inside the consumer callback.
            return None
        first_record = embedding_records[0]
        print(
            f"First record: {first_record} is for collection {first_record['collection_id']}"
        )
        return None

    def _create_pulsar_topics(self) -> None:
        """This creates the pulsar topics used by the system.

        HACK: THIS IS COMPLETELY A HACK AND WILL BE REPLACED
        BY A PROPER TOPIC MANAGEMENT SYSTEM IN THE COORDINATOR
        """
        topics = self._assignment_policy.get_topics()
        admin = PulsarAdmin(self._system)
        for topic in topics:
            admin.create_topic(topic)

    def QueryVectors(
        self, request: proto.QueryVectorsRequest, context: Any
    ) -> proto.QueryVectorsResponse:
        """Reject vector queries: this prototype does not implement them yet.

        Args:
            request: The incoming query request (ignored).
            context: The gRPC call context; its status code is set to
                ``UNIMPLEMENTED`` with an explanatory detail message.

        Returns:
            An empty ``QueryVectorsResponse``.
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Query segment not implemented yet")
        return proto.QueryVectorsResponse()

    # The code below is disabled prototype code, kept for reference.
    #
    # @trace_method(
    #     "SegmentServer.GetVectors", OpenTelemetryGranularity.OPERATION_AND_SEGMENT
    # )
    # def GetVectors(
    #     self, request: proto.GetVectorsRequest, context: Any
    # ) -> proto.GetVectorsResponse:
    #     segment_id = UUID(hex=request.segment_id)
    #     if segment_id not in self._segment_cache:
    #         context.set_code(grpc.StatusCode.NOT_FOUND)
    #         context.set_details("Segment not found")
    #         return proto.GetVectorsResponse()
    #     else:
    #         segment = self._segment_cache[segment_id]
    #         segment = cast(VectorReader, segment)
    #         segment_results = segment.get_vectors(request.ids)
    #         return_records = []
    #         for record in segment_results:
    #             # TODO: encoding should be based on stored encoding for segment
    #             # For now we just assume float32
    #             return_record = to_proto_vector_embedding_record(
    #                 record, ScalarEncoding.FLOAT32
    #             )
    #             return_records.append(return_record)
    #         return proto.GetVectorsResponse(records=return_records)

    # def _cls(self, segment: Segment) -> Type[SegmentImplementation]:
    #     classname = SEGMENT_TYPE_IMPLS[SegmentType(segment["type"])]
    #     cls = get_class(classname, SegmentImplementation)
    #     return cls

    # def _create_instance(self, segment: Segment) -> None:
    #     if segment["id"] not in self._segment_cache:
    #         cls = self._cls(segment)
    #         instance = cls(self._system, segment)
    #         instance.start()
    #         self._segment_cache[segment["id"]] = instance
if __name__ == "__main__":
    # Script entry point: build the system, register the segment server as a
    # vector-reader servicer on a thread-pool backed gRPC server, then serve
    # until the server terminates.
    logging.basicConfig(level=logging.INFO)

    system = System(Settings())
    worker_pool = futures.ThreadPoolExecutor(max_workers=10)
    server = grpc.server(worker_pool)

    servicer = SegmentServer(system)
    # add_SegmentServerServicer_to_server(servicer, server)  # type: ignore
    add_VectorReaderServicer_to_server(servicer, server)  # type: ignore

    # The port comes from the settings; the listener is added without
    # transport security.
    grpc_port = system.settings.require("chroma_server_grpc_port")
    server.add_insecure_port(f"[::]:{grpc_port}")

    system.start()
    server.start()
    server.wait_for_termination()