Spaces:
Runtime error
Runtime error
| """ | |
| Contains only Auth abstractions, no implementations. | |
| """ | |
| import base64 | |
| from functools import partial | |
| import logging | |
| from abc import ABC, abstractmethod | |
| from enum import Enum | |
| from typing import ( | |
| Any, | |
| Callable, | |
| List, | |
| Optional, | |
| Dict, | |
| TypeVar, | |
| Tuple, | |
| Generic, | |
| Union, | |
| ) | |
| from dataclasses import dataclass | |
| from overrides import EnforceOverrides, override | |
| from pydantic import SecretStr | |
| from chromadb.config import ( | |
| DEFAULT_DATABASE, | |
| DEFAULT_TENANT, | |
| Component, | |
| System, | |
| ) | |
| from chromadb.errors import ChromaError | |
| logger = logging.getLogger(__name__) | |
| T = TypeVar("T") | |
| S = TypeVar("S") | |
class AuthInfoType(Enum):
    """Enumerates the kinds of carrier that auth info can be read from or written to."""

    COOKIE = "cookie"
    HEADER = "header"
    URL = "url"
    # Marked "gRPC" in the original source.
    METADATA = "metadata"
class UserIdentity(EnforceOverrides, ABC):
    """Abstract accessor interface for a user's id, tenant, databases and attributes.

    Every accessor is abstract, so a subclass that forgets one cannot be
    instantiated instead of silently returning ``None``.
    """

    @abstractmethod
    def get_user_id(self) -> str:
        """Return the user's identifier."""
        ...

    @abstractmethod
    def get_user_tenant(self) -> Optional[str]:
        """Return the user's tenant, or ``None`` if there is none."""
        ...

    @abstractmethod
    def get_user_databases(self) -> Optional[List[str]]:
        """Return the user's databases, or ``None`` if there are none."""
        ...

    @abstractmethod
    def get_user_attributes(self) -> Optional[Dict[str, Any]]:
        """Return the user's attribute mapping, or ``None`` if there is none."""
        ...
class SimpleUserIdentity(UserIdentity):
    """UserIdentity backed by values supplied at construction time."""

    def __init__(
        self,
        user_id: str,
        tenant: Optional[str] = None,
        databases: Optional[List[str]] = None,
        attributes: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Store the identity fields.

        :param user_id: The user's identifier
        :param tenant: The user's tenant; a falsy value makes
            ``get_user_tenant`` fall back to ``DEFAULT_TENANT``
        :param databases: The user's databases, if any
        :param attributes: Additional attributes, if any
        """
        self._user_id = user_id
        self._tenant = tenant
        self._attributes = attributes
        self._databases = databases

    # The base class uses EnforceOverrides, which rejects overriding methods
    # that are not explicitly marked with @override.
    @override
    def get_user_id(self) -> str:
        """Return the stored user id."""
        return self._user_id

    @override
    def get_user_tenant(self) -> Optional[str]:
        """Return the stored tenant, or ``DEFAULT_TENANT`` when it is falsy."""
        return self._tenant if self._tenant else DEFAULT_TENANT

    @override
    def get_user_databases(self) -> Optional[List[str]]:
        """Return the stored databases list (may be ``None``)."""
        return self._databases

    @override
    def get_user_attributes(self) -> Optional[Dict[str, Any]]:
        """Return the stored attributes mapping (may be ``None``)."""
        return self._attributes
class ClientAuthResponse(EnforceOverrides, ABC):
    """Abstract result of client-side authentication.

    Both accessors are abstract so that incomplete subclasses fail at
    instantiation rather than returning ``None``.
    """

    @abstractmethod
    def get_auth_info_type(self) -> AuthInfoType:
        """Return the AuthInfoType that the auth info belongs to."""
        ...

    @abstractmethod
    def get_auth_info(
        self,
    ) -> Union[Tuple[str, SecretStr], List[Tuple[str, SecretStr]]]:
        """Return the auth info as one ``(str, SecretStr)`` pair or a list of such pairs."""
        ...
class ClientAuthProvider(Component):
    """Abstract client-side component that produces a ClientAuthResponse."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    # Abstract: the previous empty body returned None despite the declared
    # ClientAuthResponse return type.
    @abstractmethod
    def authenticate(self) -> ClientAuthResponse:
        """Return the ClientAuthResponse for this client."""
        pass
class ClientAuthConfigurationProvider(Component):
    """Abstract client-side component that supplies auth configuration."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    # Abstract so that implementations must state what configuration they return.
    @abstractmethod
    def get_configuration(self) -> Optional[T]:
        """Return the configuration object, or ``None`` when there is none."""
        pass
class ClientAuthCredentialsProvider(Component, Generic[T]):
    """Abstract client-side component that supplies credentials of type ``T``."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    # Abstract: the previous empty body returned None despite the declared
    # return type T.
    @abstractmethod
    def get_credentials(self) -> T:
        """Return the credentials."""
        pass
class ClientAuthProtocolAdapter(Component, Generic[T]):
    """Abstract client-side component that injects credentials into a context of type ``T``."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    # Abstract so that an adapter which injects nothing cannot be created by
    # accident.
    @abstractmethod
    def inject_credentials(self, injection_context: T) -> None:
        """Inject credentials into ``injection_context``.

        :param injection_context: The object to receive the credentials; its
            concrete type is chosen by the implementation
        """
        pass
| # SERVER-SIDE Abstractions | |
class ServerAuthenticationRequest(EnforceOverrides, ABC, Generic[T]):
    """Abstract view of an incoming request from which auth info can be read."""

    @abstractmethod
    def get_auth_info(self, auth_info_type: AuthInfoType, auth_info_id: str) -> T:
        """
        This method should return the necessary auth info based on the type of
        authentication (e.g. header, cookie, url) and a given id for the respective
        auth type (e.g. name of the header, cookie, url param).

        :param auth_info_type: The type of auth info to return
        :param auth_info_id: The id of the auth info to return
        :return: The auth info which can be specific to the implementation
        """
        pass
class ServerAuthenticationResponse(EnforceOverrides, ABC):
    """Abstract outcome of a server-side authentication attempt."""

    @abstractmethod
    def success(self) -> bool:
        """Return whether authentication succeeded."""
        ...

    @abstractmethod
    def get_user_identity(self) -> Optional[UserIdentity]:
        """Return the associated UserIdentity, or ``None`` when there is none."""
        ...
class SimpleServerAuthenticationResponse(ServerAuthenticationResponse):
    """Simple implementation of ServerAuthenticationResponse"""

    _auth_success: bool
    _user_identity: Optional[UserIdentity]

    def __init__(
        self, auth_success: bool, user_identity: Optional[UserIdentity]
    ) -> None:
        """Store the outcome flag and the (possibly absent) user identity.

        :param auth_success: Whether authentication succeeded
        :param user_identity: The identity to report, or ``None``
        """
        self._auth_success = auth_success
        self._user_identity = user_identity

    # The base class uses EnforceOverrides, which rejects overriding methods
    # that are not explicitly marked with @override.
    @override
    def success(self) -> bool:
        """Return the stored success flag."""
        return self._auth_success

    @override
    def get_user_identity(self) -> Optional[UserIdentity]:
        """Return the stored user identity (may be ``None``)."""
        return self._user_identity
class ServerAuthProvider(Component):
    """Abstract server-side component that authenticates requests."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    # Abstract: the previous empty body returned None despite the declared
    # ServerAuthenticationResponse return type.
    @abstractmethod
    def authenticate(
        self, request: ServerAuthenticationRequest[T]
    ) -> ServerAuthenticationResponse:
        """Authenticate ``request`` and return the resulting response.

        :param request: The request to read auth info from
        :return: The authentication outcome
        """
        pass
class ChromaAuthMiddleware(Component):
    """Abstract middleware component that wires authentication into a server."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def authenticate(
        self, request: ServerAuthenticationRequest[T]
    ) -> ServerAuthenticationResponse:
        """Authenticate ``request`` and return the resulting response."""
        ...

    @abstractmethod
    def ignore_operation(self, verb: str, path: str) -> bool:
        """Return whether the operation given by ``verb`` and ``path`` is ignored.

        NOTE(review): presumably "ignored" means exempt from authentication;
        confirm against the concrete middleware.
        """
        ...

    @abstractmethod
    def instrument_server(self, app: T) -> None:
        """Instrument the server application ``app`` with this middleware."""
        ...
class ServerAuthConfigurationProvider(Component):
    """Abstract server-side component that supplies auth configuration."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    # Abstract so that implementations must state what configuration they return.
    @abstractmethod
    def get_configuration(self) -> Optional[T]:
        """Return the configuration object, or ``None`` when there is none."""
        pass
class AuthenticationError(ChromaError):
    """ChromaError subtype that reports code 401 and the name "AuthenticationError"."""

    # NOTE(review): @override assumes ChromaError declares `code` and `name`
    # and enforces explicit overrides; confirm against chromadb.errors.
    @override
    def code(self) -> int:
        """Return 401 as the numeric code of this error."""
        return 401

    # Declared as a classmethod so that it works on the class as well as on
    # instances; the first parameter was already named `cls`.
    @classmethod
    @override
    def name(cls) -> str:
        """Return the error's name, "AuthenticationError"."""
        return "AuthenticationError"
class AbstractCredentials(EnforceOverrides, ABC, Generic[T]):
    """
    The class is used by Auth Providers to encapsulate credentials received
    from the server and pass them to a ServerAuthCredentialsProvider.
    """

    @abstractmethod
    def get_credentials(self) -> Dict[str, T]:
        """
        Returns the data encapsulated by the credentials object.

        :return: A mapping from string keys to values of type ``T``
        """
        pass
class SecretStrAbstractCredentials(AbstractCredentials[SecretStr]):
    """AbstractCredentials specialisation whose values are ``SecretStr``."""

    # Still abstract here; @override is needed because the base class uses
    # EnforceOverrides, which rejects unmarked overriding methods.
    @abstractmethod
    @override
    def get_credentials(self) -> Dict[str, SecretStr]:
        """
        Returns the data encapsulated by the credentials object.

        :return: A mapping from string keys to ``SecretStr`` values
        """
        pass
class BasicAuthCredentials(SecretStrAbstractCredentials):
    """Username/password credentials held as ``SecretStr`` values."""

    def __init__(self, username: SecretStr, password: SecretStr) -> None:
        """Store the username and password.

        :param username: The username
        :param password: The password
        """
        self.username = username
        self.password = password

    # The base class uses EnforceOverrides, which rejects overriding methods
    # that are not explicitly marked with @override.
    @override
    def get_credentials(self) -> Dict[str, SecretStr]:
        """Return the credentials under the keys ``"username"`` and ``"password"``."""
        return {"username": self.username, "password": self.password}

    # Static: the function takes no instance and acts as an alternate
    # constructor, so it must be callable on the class and on instances alike.
    @staticmethod
    def from_header(header: str) -> "BasicAuthCredentials":
        """
        Parses a basic auth header and returns a BasicAuthCredentials object.

        :param header: The header value, with or without the ``"Basic "`` scheme
        :return: The credentials decoded from the header
        :raises ValueError: If the payload is not valid base64 or the decoded
            text contains no ``:`` separator
        :raises UnicodeDecodeError: If the decoded payload is not valid UTF-8
        """
        header = header.replace("Basic ", "")
        header = header.strip()
        base64_decoded = base64.b64decode(header).decode("utf-8")
        # Split on the first colon only: the password may itself contain
        # colons, and an unbounded split would then fail to unpack.
        username, password = base64_decoded.split(":", 1)
        return BasicAuthCredentials(SecretStr(username), SecretStr(password))
class ServerAuthCredentialsProvider(Component):
    """Abstract server-side component that validates credentials and resolves identities."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool:
        """Return whether ``credentials`` are valid."""
        ...

    @abstractmethod
    def get_user_identity(
        self, credentials: AbstractCredentials[T]
    ) -> Optional[UserIdentity]:
        """Return the UserIdentity for ``credentials``, or ``None`` when there is none."""
        ...
class AuthzResourceTypes(str, Enum):
    """String-valued enum of the resource types used in authorization."""

    DB = "db"
    COLLECTION = "collection"
    TENANT = "tenant"
class AuthzResourceActions(str, Enum):
    """String-valued enum of the actions that can be authorized."""

    # Database and tenant actions
    CREATE_DATABASE = "create_database"
    GET_DATABASE = "get_database"
    CREATE_TENANT = "create_tenant"
    GET_TENANT = "get_tenant"

    # Collection actions
    LIST_COLLECTIONS = "list_collections"
    COUNT_COLLECTIONS = "count_collections"
    GET_COLLECTION = "get_collection"
    CREATE_COLLECTION = "create_collection"
    GET_OR_CREATE_COLLECTION = "get_or_create_collection"
    DELETE_COLLECTION = "delete_collection"
    UPDATE_COLLECTION = "update_collection"

    # Remaining actions
    ADD = "add"
    DELETE = "delete"
    GET = "get"
    QUERY = "query"
    COUNT = "count"
    UPDATE = "update"
    UPSERT = "upsert"
    RESET = "reset"
# Declared as a dataclass so that the annotated fields become constructor
# parameters; the class had no other way to be given values.
@dataclass
class AuthzUser:
    """User record used in authorization decisions."""

    # User identifier; required, may be None.
    id: Optional[str]
    # Tenant of the user; defaults to DEFAULT_TENANT.
    tenant: Optional[str] = DEFAULT_TENANT
    # Optional attribute mapping.
    attributes: Optional[Dict[str, Any]] = None
    # Optional claims mapping.
    claims: Optional[Dict[str, Any]] = None
# Declared as a dataclass so that it can be built with the keyword arguments
# `id`, `type` and `attributes`; without a generated __init__ such a call
# raises TypeError.
@dataclass
class AuthzResource:
    """Resource record used in authorization decisions."""

    # Resource identifier; required, may be None.
    id: Optional[str]
    # Resource type; required, may be None.
    type: Optional[str]
    # Optional attribute mapping.
    attributes: Optional[Dict[str, Any]] = None
class DynamicAuthzResource:
    """Resource description whose fields are either literal values or callables.

    Callable fields are evaluated only when ``to_authz_resource`` is called,
    using the keyword arguments passed to it.
    """

    id: Optional[Union[str, Callable[..., str]]]
    type: Optional[Union[str, Callable[..., str]]]
    attributes: Optional[Union[Dict[str, Any], Callable[..., Dict[str, Any]]]]

    def __init__(
        self,
        id: Optional[Union[str, Callable[..., str]]] = None,
        attributes: Optional[
            Union[Dict[str, Any], Callable[..., Dict[str, Any]]]
        ] = lambda **kwargs: {},
        type: Optional[Union[str, Callable[..., str]]] = DEFAULT_DATABASE,
    ) -> None:
        """Store the three fields unchanged.

        :param id: Literal id or a callable producing it
        :param attributes: Literal mapping or a callable producing it; the
            default is a callable that returns an empty dict
        :param type: Literal type or a callable producing it; defaults to
            ``DEFAULT_DATABASE``
        """
        self.id = id
        self.attributes = attributes
        self.type = type

    def to_authz_resource(self, **kwargs: Any) -> AuthzResource:
        """Build an AuthzResource, calling any callable field with ``kwargs``.

        :param kwargs: Keyword arguments forwarded to each callable field
        :return: An AuthzResource holding the resolved id, type and attributes
        """

        def resolve(field: Any) -> Any:
            # Callables are invoked with the caller's kwargs; anything else
            # is passed through untouched.
            if callable(field):
                return field(**kwargs)
            return field

        # Resolution order (id, type, attributes) matches the argument order.
        resolved_id = resolve(self.id)
        resolved_type = resolve(self.type)
        resolved_attributes = resolve(self.attributes)
        return AuthzResource(
            id=resolved_id,
            type=resolved_type,
            attributes=resolved_attributes,
        )
class AuthzDynamicParams:
    """Factories for callables that extract values from call-time keyword arguments.

    Each factory returns a ``functools.partial`` that binds the keyword
    arguments given to the factory; the remaining keyword arguments are
    supplied when the partial is eventually called.

    The factories are static because they take no instance; without
    ``@staticmethod`` a call through an instance raises ``TypeError``.
    """

    @staticmethod
    def from_function_name(**kwargs: Any) -> Callable[..., str]:
        """Return a callable yielding ``__name__`` of its ``function`` keyword argument."""
        return partial(
            lambda **call_kwargs: call_kwargs["function"].__name__, **kwargs
        )

    @staticmethod
    def from_function_args(**kwargs: Any) -> Callable[..., str]:
        """Return a callable yielding ``function_args[arg_num]``.

        ``arg_num`` and ``function_args`` are read from the combined keyword
        arguments of the factory call and the later call.
        """
        return partial(
            lambda **call_kwargs: call_kwargs["function_args"][
                call_kwargs["arg_num"]
            ],
            **kwargs,
        )

    @staticmethod
    def from_function_kwargs(**kwargs: Any) -> Callable[..., str]:
        """Return a callable yielding ``function_kwargs[arg_name]``.

        ``arg_name`` and ``function_kwargs`` are read from the combined
        keyword arguments of the factory call and the later call.
        """
        return partial(
            lambda **call_kwargs: call_kwargs["function_kwargs"][
                call_kwargs["arg_name"]
            ],
            **kwargs,
        )

    @staticmethod
    def dict_from_function_kwargs(**kwargs: Any) -> Callable[..., Dict[str, Any]]:
        """Return a callable yielding the subset of ``function_kwargs`` named by ``arg_names``."""
        return partial(
            lambda **call_kwargs: {
                k: call_kwargs["function_kwargs"][k]
                for k in call_kwargs["arg_names"]
            },
            **kwargs,
        )
# Declared as a dataclass so that the annotated fields become constructor
# parameters; the class had no other way to be given values.
@dataclass
class AuthzAction:
    """Action record used in authorization decisions."""

    # Action identifier; required.
    id: str
    # Optional attribute mapping.
    attributes: Optional[Dict[str, Any]] = None
# Declared as a dataclass so that the annotated fields become constructor
# parameters; the class had no other way to be given values.
@dataclass
class AuthorizationContext:
    """Bundle of the user, resource and action of one authorization decision."""

    # The user on whose behalf the action is requested.
    user: AuthzUser
    # The resource the action targets.
    resource: AuthzResource
    # The action being requested.
    action: AuthzAction
class ServerAuthorizationProvider(Component):
    """Abstract server-side component that makes authorization decisions."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    # Abstract: the previous empty body returned None, which is falsy, despite
    # the declared bool return type.
    @abstractmethod
    def authorize(self, context: AuthorizationContext) -> bool:
        """Return whether the request described by ``context`` is authorized.

        :param context: The user, resource and action to decide on
        """
        pass
class AuthorizationRequestContext(EnforceOverrides, ABC, Generic[T]):
    """Abstract wrapper that gives access to an underlying request of type ``T``."""

    @abstractmethod
    def get_request(self) -> T:
        """Return the wrapped request."""
        ...
class ChromaAuthzMiddleware(Component, Generic[T, S]):
    """Abstract middleware component that wires authorization into a server.

    ``T`` is the server application type and ``S`` is the request type
    wrapped by AuthorizationRequestContext.
    """

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def pre_process(self, request: AuthorizationRequestContext[S]) -> None:
        """Process ``request`` before it is handled."""
        ...

    @abstractmethod
    def ignore_operation(self, verb: str, path: str) -> bool:
        """Return whether the operation given by ``verb`` and ``path`` is ignored.

        NOTE(review): presumably "ignored" means exempt from authorization;
        confirm against the concrete middleware.
        """
        ...

    @abstractmethod
    def instrument_server(self, app: T) -> None:
        """Instrument the server application ``app`` with this middleware."""
        ...
class ServerAuthorizationConfigurationProvider(Component, Generic[T]):
    """Abstract server-side component that supplies authorization configuration of type ``T``."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    # Abstract: the previous empty body returned None despite the declared
    # return type T.
    @abstractmethod
    def get_configuration(self) -> T:
        """Return the configuration object."""
        pass