Spaces:
Runtime error
Runtime error
| from typing import Any, Callable, Dict, Optional, cast | |
| from overrides import EnforceOverrides, override | |
| from chromadb.config import System | |
| from chromadb.segment.distributed import ( | |
| Memberlist, | |
| MemberlistProvider, | |
| SegmentDirectory, | |
| ) | |
| from chromadb.types import Segment | |
| from kubernetes import client, config, watch | |
| from kubernetes.client.rest import ApiException | |
| import threading | |
| from chromadb.utils.rendezvous_hash import assign, murmur3hasher | |
# These could go in config but given that they will rarely change, they are here for now to avoid
# polluting the config file further.

# Passed as `timeout_seconds` to each k8s watch request issued by
# CustomResourceMemberlistProvider; when a watch times out the provider re-checks its
# stop flag before opening a new one.
WATCH_TIMEOUT_SECONDS = 60
# Namespace used for every memberlist custom-object call (get / patch / list+watch).
KUBERNETES_NAMESPACE = "chroma"
# API group of the `memberlists` custom resource.
KUBERNETES_GROUP = "chroma.cluster"
class MockMemberlistProvider(MemberlistProvider, EnforceOverrides):
    """A mock memberlist provider for testing.

    The memberlist lives purely in memory: it starts as ``["a", "b", "c"]`` and is
    only changed through :meth:`update_memberlist`, which stands in for an update
    coming from a k8s custom resource.
    """

    _memberlist: Memberlist

    def __init__(self, system: System):
        super().__init__(system)
        self._memberlist = ["a", "b", "c"]

    # NOTE: this class uses EnforceOverrides, so every method that implements or
    # overrides a base-class method must carry @override.
    @override
    def get_memberlist(self) -> Memberlist:
        """Return the current in-memory memberlist."""
        return self._memberlist

    @override
    def set_memberlist_name(self, memberlist: str) -> None:
        """Accept and ignore the memberlist name."""
        pass  # The mock provider does not need to set the memberlist name

    def update_memberlist(self, memberlist: Memberlist) -> None:
        """Updates the memberlist and calls all registered callbacks. This mocks an update from a k8s CR"""
        self._memberlist = memberlist
        for callback in self.callbacks:
            callback(memberlist)
class CustomResourceMemberlistProvider(MemberlistProvider, EnforceOverrides):
    """A memberlist provider that uses a k8s custom resource to store the memberlist.

    The memberlist is read from the ``memberlists`` custom object named by
    :meth:`set_memberlist_name` (group ``KUBERNETES_GROUP``, namespace
    ``KUBERNETES_NAMESPACE``). After :meth:`start`, a daemon thread watches that
    object and, on every event, stores the new memberlist and invokes all registered
    callbacks with it.
    """

    _kubernetes_api: client.CustomObjectsApi
    _memberlist_name: Optional[str]
    _curr_memberlist: Optional[Memberlist]
    _curr_memberlist_mutex: threading.Lock
    _watch_thread: Optional[threading.Thread]
    _kill_watch_thread: threading.Event

    def __init__(self, system: System):
        super().__init__(system)
        config.load_config()
        self._kubernetes_api = client.CustomObjectsApi()
        self._watch_thread = None
        self._memberlist_name = None
        self._curr_memberlist = None
        self._curr_memberlist_mutex = threading.Lock()
        self._kill_watch_thread = threading.Event()

    # NOTE: this class uses EnforceOverrides, so every method that implements or
    # overrides a base-class method must carry @override.
    @override
    def start(self) -> None:
        """Fetch the initial memberlist and start the background watch thread.

        Raises:
            ValueError: if no memberlist name has been set.
            Exception: if a watch thread is already running.
        """
        if self._memberlist_name is None:
            raise ValueError("Memberlist name must be set before starting")
        self.get_memberlist()
        self._watch_worker_memberlist()
        return super().start()

    @override
    def stop(self) -> None:
        """Stop the watch thread, then clear the cached memberlist and its name.

        The join can block until the in-flight watch request times out
        (at most ``WATCH_TIMEOUT_SECONDS``) if no event arrives earlier.
        """
        # Stop the watch thread *before* clearing state: the thread reads
        # `_memberlist_name` to build its field selector and writes
        # `_curr_memberlist`, so clearing first could start a watch on
        # "metadata.name=None" or repopulate the state we just reset.
        self._kill_watch_thread.set()
        if self._watch_thread is not None:
            self._watch_thread.join()
            self._watch_thread = None
        self._kill_watch_thread.clear()
        self._curr_memberlist = None
        self._memberlist_name = None
        return super().stop()

    @override
    def reset_state(self) -> None:
        """Empty the member list of the custom resource (when a name is set).

        Raises:
            ValueError: if the ``allow_reset`` setting is not enabled.
        """
        if not self._system.settings.require("allow_reset"):
            raise ValueError(
                "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
            )
        if self._memberlist_name:
            self._kubernetes_api.patch_namespaced_custom_object(
                group=KUBERNETES_GROUP,
                version="v1",
                namespace=KUBERNETES_NAMESPACE,
                plural="memberlists",
                name=self._memberlist_name,
                body={
                    "kind": "MemberList",
                    "spec": {"members": []},
                },
            )

    @override
    def get_memberlist(self) -> Memberlist:
        """Return the cached memberlist, fetching it from k8s on first use."""
        if self._curr_memberlist is None:
            self._curr_memberlist = self._fetch_memberlist()
        return self._curr_memberlist

    @override
    def set_memberlist_name(self, memberlist: str) -> None:
        """Set the name of the custom object that holds the memberlist."""
        self._memberlist_name = memberlist

    def _fetch_memberlist(self) -> Memberlist:
        """Read the custom object once and return its member urls.

        Returns an empty list when the object has no ``spec``.
        """
        api_response = self._kubernetes_api.get_namespaced_custom_object(
            group=KUBERNETES_GROUP,
            version="v1",
            namespace=KUBERNETES_NAMESPACE,
            plural="memberlists",
            name=f"{self._memberlist_name}",
        )
        api_response = cast(Dict[str, Any], api_response)
        if "spec" not in api_response:
            return []
        response_spec = cast(Dict[str, Any], api_response["spec"])
        return self._parse_response_memberlist(response_spec)

    def _watch_worker_memberlist(self) -> None:
        """Start a daemon thread that watches the memberlist custom object.

        Raises:
            Exception: if a watch thread is already running.
        """
        # Local import: logging is only needed here and the module's import block
        # does not provide it.
        import logging

        logger = logging.getLogger(__name__)
        # Pause between retries after an unexpected API error so a persistent
        # failure does not turn the retry loop into a busy loop.
        retry_backoff_seconds = 1.0

        if self._watch_thread is not None:
            raise Exception("A watch thread is already running.")

        # TODO: We may want to make this watch function a library function that can be used by other
        # components that need to watch k8s custom resources.
        def run_watch() -> None:
            w = watch.Watch()

            def do_watch() -> None:
                for event in w.stream(
                    self._kubernetes_api.list_namespaced_custom_object,
                    group=KUBERNETES_GROUP,
                    version="v1",
                    namespace=KUBERNETES_NAMESPACE,
                    plural="memberlists",
                    field_selector=f"metadata.name={self._memberlist_name}",
                    timeout_seconds=WATCH_TIMEOUT_SECONDS,
                ):
                    # A stop was requested while we were waiting for this event:
                    # drop it instead of publishing state during shutdown.
                    if self._kill_watch_thread.is_set():
                        w.stop()
                        return
                    event = cast(Dict[str, Any], event)
                    # Treat a missing spec as an empty memberlist, the same way
                    # _fetch_memberlist does, rather than letting a KeyError kill
                    # the watch thread.
                    response_spec = cast(
                        Dict[str, Any], event["object"].get("spec") or {}
                    )
                    memberlist = self._parse_response_memberlist(response_spec)
                    with self._curr_memberlist_mutex:
                        self._curr_memberlist = memberlist
                    # Notify outside the lock, with the value we just parsed, so
                    # callbacks never observe a concurrently modified attribute.
                    self._notify(memberlist)

            # Watch the custom resource for changes
            # Watch with a timeout and retry so we can gracefully stop this if needed
            while not self._kill_watch_thread.is_set():
                try:
                    do_watch()
                except ApiException as e:
                    # If status code is 410, the watch has expired and we need to start a new one.
                    if e.status == 410:
                        continue
                    # Anything else is unexpected: keep retrying (best effort) but
                    # make the failure visible and back off. Event.wait returns
                    # early if a stop is requested.
                    logger.warning(
                        "Kubernetes API error while watching memberlist %s (status %s): %s",
                        self._memberlist_name,
                        e.status,
                        e.reason,
                    )
                    self._kill_watch_thread.wait(retry_backoff_seconds)
            return

        thread = threading.Thread(target=run_watch, daemon=True)
        thread.start()
        self._watch_thread = thread

    def _parse_response_memberlist(
        self, api_response_spec: Dict[str, Any]
    ) -> Memberlist:
        """Extract the ``url`` of every entry under ``members`` in a spec dict.

        Returns an empty list when ``members`` is absent, null or empty.
        """
        members = api_response_spec.get("members")
        if not members:
            return []
        return [m["url"] for m in members]

    def _notify(self, memberlist: Memberlist) -> None:
        """Invoke every registered callback with the given memberlist."""
        for callback in self.callbacks:
            callback(memberlist)
class RendezvousHashSegmentDirectory(SegmentDirectory, EnforceOverrides):
    """A segment directory that maps each segment to a member of the memberlist.

    The memberlist comes from the required ``MemberlistProvider`` and is kept up to
    date through a callback registered in :meth:`start`. The member for a segment is
    chosen by ``assign`` (from ``chromadb.utils.rendezvous_hash``) over the segment
    id, using ``murmur3hasher``.
    """

    _memberlist_provider: MemberlistProvider
    _curr_memberlist_mutex: threading.Lock
    _curr_memberlist: Optional[Memberlist]

    def __init__(self, system: System):
        super().__init__(system)
        self._memberlist_provider = self.require(MemberlistProvider)
        memberlist_name = system.settings.require("worker_memberlist_name")
        self._memberlist_provider.set_memberlist_name(memberlist_name)
        self._curr_memberlist = None
        self._curr_memberlist_mutex = threading.Lock()

    # NOTE: this class uses EnforceOverrides, so every method that implements or
    # overrides a base-class method must carry @override.
    @override
    def start(self) -> None:
        """Load the current memberlist and subscribe to memberlist updates."""
        self._curr_memberlist = self._memberlist_provider.get_memberlist()
        self._memberlist_provider.register_updated_memberlist_callback(
            self._update_memberlist
        )
        return super().start()

    @override
    def stop(self) -> None:
        """Unsubscribe from memberlist updates."""
        self._memberlist_provider.unregister_updated_memberlist_callback(
            self._update_memberlist
        )
        return super().stop()

    @override
    def get_segment_endpoint(self, segment: Segment) -> str:
        """Return ``"<member>:50051"`` for the member assigned to ``segment``.

        Raises:
            ValueError: if the memberlist is unset or empty.
        """
        # Take one snapshot under the lock: `_update_memberlist` may replace the
        # list from another thread between the emptiness check and the assignment.
        with self._curr_memberlist_mutex:
            memberlist = self._curr_memberlist
        if not memberlist:
            raise ValueError("Memberlist is not initialized")
        assignment = assign(segment["id"].hex, memberlist, murmur3hasher)
        assignment = f"{assignment}:50051"  # TODO: make port configurable
        return assignment

    @override
    def register_updated_segment_callback(
        self, callback: Callable[[Segment], None]
    ) -> None:
        """Not supported by this directory; always raises NotImplementedError."""
        raise NotImplementedError()

    def _update_memberlist(self, memberlist: Memberlist) -> None:
        """Callback for the memberlist provider: store the new memberlist."""
        with self._curr_memberlist_mutex:
            self._curr_memberlist = memberlist