Spaces:
Paused
Paused
| import httpx | |
| from typing import Optional, Sequence | |
| from uuid import UUID | |
| from overrides import override | |
| from chromadb.api import AsyncAdminAPI, AsyncClientAPI, AsyncServerAPI | |
| from chromadb.api.configuration import CollectionConfiguration | |
| from chromadb.api.models.AsyncCollection import AsyncCollection | |
| from chromadb.api.shared_system_client import SharedSystemClient | |
| from chromadb.api.types import ( | |
| CollectionMetadata, | |
| DataLoader, | |
| Documents, | |
| Embeddable, | |
| EmbeddingFunction, | |
| Embeddings, | |
| GetResult, | |
| IDs, | |
| Include, | |
| Loadable, | |
| Metadatas, | |
| QueryResult, | |
| URIs, | |
| ) | |
| from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System | |
| from chromadb.errors import ChromaError | |
| from chromadb.types import Database, Tenant, Where, WhereDocument | |
| import chromadb.utils.embedding_functions as ef | |
| class AsyncClient(SharedSystemClient, AsyncClientAPI): | |
| """A client for Chroma. This is the main entrypoint for interacting with Chroma. | |
| A client internally stores its tenant and database and proxies calls to a | |
| Server API instance of Chroma. It treats the Server API and corresponding System | |
| as a singleton, so multiple clients connecting to the same resource will share the | |
| same API instance. | |
| Client implementations should be implement their own API-caching strategies. | |
| """ | |
| # An internal admin client for verifying that databases and tenants exist | |
| _admin_client: AsyncAdminAPI | |
| tenant: str = DEFAULT_TENANT | |
| database: str = DEFAULT_DATABASE | |
| _server: AsyncServerAPI | |
| async def create( | |
| cls, | |
| tenant: str = DEFAULT_TENANT, | |
| database: str = DEFAULT_DATABASE, | |
| settings: Settings = Settings(), | |
| ) -> "AsyncClient": | |
| # Create an admin client for verifying that databases and tenants exist | |
| self = cls(settings=settings) | |
| SharedSystemClient._populate_data_from_system(self._system) | |
| self._admin_client = AsyncAdminClient.from_system(self._system) | |
| await self._validate_tenant_database(tenant=tenant, database=database) | |
| self.tenant = tenant | |
| self.database = database | |
| # Get the root system component we want to interact with | |
| self._server = self._system.instance(AsyncServerAPI) | |
| self._submit_client_start_event() | |
| return self | |
| # (we can't override and use from_system() because it's synchronous) | |
| async def from_system_async( | |
| cls, | |
| system: System, | |
| tenant: str = DEFAULT_TENANT, | |
| database: str = DEFAULT_DATABASE, | |
| ) -> "AsyncClient": | |
| """Create a client from an existing system. This is useful for testing and debugging.""" | |
| return await AsyncClient.create(tenant, database, system.settings) | |
| def from_system( | |
| cls, | |
| system: System, | |
| ) -> "SharedSystemClient": | |
| """AsyncClient cannot be created synchronously. Use .from_system_async() instead.""" | |
| raise NotImplementedError( | |
| "AsyncClient cannot be created synchronously. Use .from_system_async() instead." | |
| ) | |
| async def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None: | |
| await self._validate_tenant_database(tenant=tenant, database=database) | |
| self.tenant = tenant | |
| self.database = database | |
| async def set_database(self, database: str) -> None: | |
| await self._validate_tenant_database(tenant=self.tenant, database=database) | |
| self.database = database | |
| async def _validate_tenant_database(self, tenant: str, database: str) -> None: | |
| try: | |
| await self._admin_client.get_tenant(name=tenant) | |
| except httpx.ConnectError: | |
| raise ValueError( | |
| "Could not connect to a Chroma server. Are you sure it is running?" | |
| ) | |
| # Propagate ChromaErrors | |
| except ChromaError as e: | |
| raise e | |
| except Exception: | |
| raise ValueError( | |
| f"Could not connect to tenant {tenant}. Are you sure it exists?" | |
| ) | |
| try: | |
| await self._admin_client.get_database(name=database, tenant=tenant) | |
| except httpx.ConnectError: | |
| raise ValueError( | |
| "Could not connect to a Chroma server. Are you sure it is running?" | |
| ) | |
| except Exception: | |
| raise ValueError( | |
| f"Could not connect to database {database} for tenant {tenant}. Are you sure it exists?" | |
| ) | |
| # region BaseAPI Methods | |
| # Note - we could do this in less verbose ways, but they break type checking | |
| async def heartbeat(self) -> int: | |
| return await self._server.heartbeat() | |
| async def list_collections( | |
| self, limit: Optional[int] = None, offset: Optional[int] = None | |
| ) -> Sequence[AsyncCollection]: | |
| models = await self._server.list_collections( | |
| limit, offset, tenant=self.tenant, database=self.database | |
| ) | |
| return [ | |
| AsyncCollection( | |
| client=self._server, | |
| model=model, | |
| ) | |
| for model in models | |
| ] | |
| async def count_collections(self) -> int: | |
| return await self._server.count_collections( | |
| tenant=self.tenant, database=self.database | |
| ) | |
| async def create_collection( | |
| self, | |
| name: str, | |
| configuration: Optional[CollectionConfiguration] = None, | |
| metadata: Optional[CollectionMetadata] = None, | |
| embedding_function: Optional[ | |
| EmbeddingFunction[Embeddable] | |
| ] = ef.DefaultEmbeddingFunction(), # type: ignore | |
| data_loader: Optional[DataLoader[Loadable]] = None, | |
| get_or_create: bool = False, | |
| ) -> AsyncCollection: | |
| model = await self._server.create_collection( | |
| name=name, | |
| configuration=configuration, | |
| metadata=metadata, | |
| tenant=self.tenant, | |
| database=self.database, | |
| get_or_create=get_or_create, | |
| ) | |
| return AsyncCollection( | |
| client=self._server, | |
| model=model, | |
| embedding_function=embedding_function, | |
| data_loader=data_loader, | |
| ) | |
| async def get_collection( | |
| self, | |
| name: str, | |
| id: Optional[UUID] = None, | |
| embedding_function: Optional[ | |
| EmbeddingFunction[Embeddable] | |
| ] = ef.DefaultEmbeddingFunction(), # type: ignore | |
| data_loader: Optional[DataLoader[Loadable]] = None, | |
| ) -> AsyncCollection: | |
| model = await self._server.get_collection( | |
| id=id, | |
| name=name, | |
| tenant=self.tenant, | |
| database=self.database, | |
| ) | |
| return AsyncCollection( | |
| client=self._server, | |
| model=model, | |
| embedding_function=embedding_function, | |
| data_loader=data_loader, | |
| ) | |
| async def get_or_create_collection( | |
| self, | |
| name: str, | |
| configuration: Optional[CollectionConfiguration] = None, | |
| metadata: Optional[CollectionMetadata] = None, | |
| embedding_function: Optional[ | |
| EmbeddingFunction[Embeddable] | |
| ] = ef.DefaultEmbeddingFunction(), # type: ignore | |
| data_loader: Optional[DataLoader[Loadable]] = None, | |
| ) -> AsyncCollection: | |
| model = await self._server.get_or_create_collection( | |
| name=name, | |
| configuration=configuration, | |
| metadata=metadata, | |
| tenant=self.tenant, | |
| database=self.database, | |
| ) | |
| return AsyncCollection( | |
| client=self._server, | |
| model=model, | |
| embedding_function=embedding_function, | |
| data_loader=data_loader, | |
| ) | |
| async def _modify( | |
| self, | |
| id: UUID, | |
| new_name: Optional[str] = None, | |
| new_metadata: Optional[CollectionMetadata] = None, | |
| ) -> None: | |
| return await self._server._modify( | |
| id=id, | |
| new_name=new_name, | |
| new_metadata=new_metadata, | |
| ) | |
| async def delete_collection( | |
| self, | |
| name: str, | |
| ) -> None: | |
| return await self._server.delete_collection( | |
| name=name, | |
| tenant=self.tenant, | |
| database=self.database, | |
| ) | |
| # | |
| # ITEM METHODS | |
| # | |
| async def _add( | |
| self, | |
| ids: IDs, | |
| collection_id: UUID, | |
| embeddings: Embeddings, | |
| metadatas: Optional[Metadatas] = None, | |
| documents: Optional[Documents] = None, | |
| uris: Optional[URIs] = None, | |
| ) -> bool: | |
| return await self._server._add( | |
| ids=ids, | |
| collection_id=collection_id, | |
| embeddings=embeddings, | |
| metadatas=metadatas, | |
| documents=documents, | |
| uris=uris, | |
| ) | |
| async def _update( | |
| self, | |
| collection_id: UUID, | |
| ids: IDs, | |
| embeddings: Optional[Embeddings] = None, | |
| metadatas: Optional[Metadatas] = None, | |
| documents: Optional[Documents] = None, | |
| uris: Optional[URIs] = None, | |
| ) -> bool: | |
| return await self._server._update( | |
| collection_id=collection_id, | |
| ids=ids, | |
| embeddings=embeddings, | |
| metadatas=metadatas, | |
| documents=documents, | |
| uris=uris, | |
| ) | |
| async def _upsert( | |
| self, | |
| collection_id: UUID, | |
| ids: IDs, | |
| embeddings: Embeddings, | |
| metadatas: Optional[Metadatas] = None, | |
| documents: Optional[Documents] = None, | |
| uris: Optional[URIs] = None, | |
| ) -> bool: | |
| return await self._server._upsert( | |
| collection_id=collection_id, | |
| ids=ids, | |
| embeddings=embeddings, | |
| metadatas=metadatas, | |
| documents=documents, | |
| uris=uris, | |
| ) | |
| async def _count(self, collection_id: UUID) -> int: | |
| return await self._server._count( | |
| collection_id=collection_id, | |
| ) | |
| async def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: | |
| return await self._server._peek( | |
| collection_id=collection_id, | |
| n=n, | |
| ) | |
| async def _get( | |
| self, | |
| collection_id: UUID, | |
| ids: Optional[IDs] = None, | |
| where: Optional[Where] = {}, | |
| sort: Optional[str] = None, | |
| limit: Optional[int] = None, | |
| offset: Optional[int] = None, | |
| page: Optional[int] = None, | |
| page_size: Optional[int] = None, | |
| where_document: Optional[WhereDocument] = {}, | |
| include: Include = ["embeddings", "metadatas", "documents"], # type: ignore[list-item] | |
| ) -> GetResult: | |
| return await self._server._get( | |
| collection_id=collection_id, | |
| ids=ids, | |
| where=where, | |
| sort=sort, | |
| limit=limit, | |
| offset=offset, | |
| page=page, | |
| page_size=page_size, | |
| where_document=where_document, | |
| include=include, | |
| ) | |
| async def _delete( | |
| self, | |
| collection_id: UUID, | |
| ids: Optional[IDs], | |
| where: Optional[Where] = {}, | |
| where_document: Optional[WhereDocument] = {}, | |
| ) -> IDs: | |
| return await self._server._delete( | |
| collection_id=collection_id, | |
| ids=ids, | |
| where=where, | |
| where_document=where_document, | |
| ) | |
| async def _query( | |
| self, | |
| collection_id: UUID, | |
| query_embeddings: Embeddings, | |
| n_results: int = 10, | |
| where: Where = {}, | |
| where_document: WhereDocument = {}, | |
| include: Include = ["embeddings", "metadatas", "documents", "distances"], # type: ignore[list-item] | |
| ) -> QueryResult: | |
| return await self._server._query( | |
| collection_id=collection_id, | |
| query_embeddings=query_embeddings, | |
| n_results=n_results, | |
| where=where, | |
| where_document=where_document, | |
| include=include, | |
| ) | |
| async def reset(self) -> bool: | |
| return await self._server.reset() | |
| async def get_version(self) -> str: | |
| return await self._server.get_version() | |
| def get_settings(self) -> Settings: | |
| return self._server.get_settings() | |
| async def get_max_batch_size(self) -> int: | |
| return await self._server.get_max_batch_size() | |
| # endregion | |
| class AsyncAdminClient(SharedSystemClient, AsyncAdminAPI): | |
| _server: AsyncServerAPI | |
| def __init__(self, settings: Settings = Settings()) -> None: | |
| super().__init__(settings) | |
| self._server = self._system.instance(AsyncServerAPI) | |
| async def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: | |
| return await self._server.create_database(name=name, tenant=tenant) | |
| async def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database: | |
| return await self._server.get_database(name=name, tenant=tenant) | |
| async def create_tenant(self, name: str) -> None: | |
| return await self._server.create_tenant(name=name) | |
| async def get_tenant(self, name: str) -> Tenant: | |
| return await self._server.get_tenant(name=name) | |
| def from_system( | |
| cls, | |
| system: System, | |
| ) -> "AsyncAdminClient": | |
| SharedSystemClient._populate_data_from_system(system) | |
| instance = cls(settings=system.settings) | |
| return instance | |