Spaces:
Runtime error
Runtime error
import hmac
import json
import logging
import string
from enum import Enum
from typing import List, Optional, Tuple, Any, TypedDict, cast, Dict, TypeVar

import yaml
from overrides import override
from pydantic import SecretStr

from chromadb.auth import (
    ServerAuthProvider,
    ClientAuthProvider,
    ServerAuthenticationRequest,
    ServerAuthCredentialsProvider,
    AuthInfoType,
    ClientAuthCredentialsProvider,
    ClientAuthResponse,
    SecretStrAbstractCredentials,
    AbstractCredentials,
    SimpleServerAuthenticationResponse,
    SimpleUserIdentity,
)
from chromadb.auth.registry import register_provider, resolve_provider
from chromadb.config import System
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.utils import get_class
# Type parameter used in the AbstractCredentials[T] provider signatures.
T = TypeVar("T")

logger = logging.getLogger(__name__)

# Public API of this module.
__all__ = [
    "TokenAuthServerProvider",
    "TokenAuthClientProvider",
]

# Header names that may carry the token; the same strings appear as the
# values of the TokenTransportHeader enum.
_token_transport_headers = [
    "Authorization",
    "X-Chroma-Token",
]
class TokenTransportHeader(Enum):
    """HTTP header used to carry the auth token.

    Settings select a member by its *name* (``TokenTransportHeader[name]``);
    the member *value* is the header name itself.
    """

    # Sent as ``Authorization: Bearer <token>``.
    AUTHORIZATION = "Authorization"
    # Sent as ``X-Chroma-Token: <token>`` with no scheme prefix.
    X_CHROMA_TOKEN = "X-Chroma-Token"
class TokenAuthClientAuthResponse(ClientAuthResponse):
    """Client-side auth payload: a single header carrying the token."""

    _token_transport_header: TokenTransportHeader

    def __init__(
        self,
        credentials: SecretStr,
        token_transport_header: TokenTransportHeader = TokenTransportHeader.AUTHORIZATION,
    ) -> None:
        """Remember the token and which header should carry it.

        Args:
            credentials: The raw token as a ``SecretStr``.
            token_transport_header: Header to emit; ``AUTHORIZATION`` by default.
        """
        self._credentials = credentials
        self._token_transport_header = token_transport_header

    def get_auth_info_type(self) -> AuthInfoType:
        """Report that this response is delivered as an HTTP header."""
        return AuthInfoType.HEADER

    def get_auth_info(self) -> Tuple[str, SecretStr]:
        """Return ``(header_name, header_value)`` for the configured transport.

        ``Authorization`` receives a ``Bearer `` scheme prefix, while
        ``X-Chroma-Token`` carries the bare token.

        Raises:
            ValueError: If the stored transport header is not a known member.
        """
        transport = self._token_transport_header
        if transport == TokenTransportHeader.AUTHORIZATION:
            raw = self._credentials.get_secret_value()
            return "Authorization", SecretStr(f"Bearer {raw}")
        if transport == TokenTransportHeader.X_CHROMA_TOKEN:
            raw = self._credentials.get_secret_value()
            return "X-Chroma-Token", SecretStr(f"{raw}")
        raise ValueError(f"Invalid token transport header: {transport}")
# Characters permitted in a token: ASCII letters, digits and punctuation
# (no whitespace, no non-ASCII). Built once instead of per character.
_ALLOWED_TOKEN_CHARS = frozenset(
    string.digits + string.ascii_letters + string.punctuation
)


def check_token(token: str) -> None:
    """Validate that *token* is a non-empty string of allowed characters.

    Allowed characters are ASCII letters, digits and ``string.punctuation``.
    Whitespace and non-ASCII characters are rejected.

    Args:
        token: The token to validate; it is converted with ``str()`` first.

    Raises:
        ValueError: If the token is empty or contains a disallowed character.
    """
    token_str = str(token)
    if not token_str:
        # An empty configured token would accept any request whose token
        # header is empty.
        raise ValueError("Invalid token. Must not be empty.")
    if not _ALLOWED_TOKEN_CHARS.issuperset(token_str):
        raise ValueError(
            "Invalid token. Must contain only ASCII letters, digits, and punctuation."
        )
class TokenConfigServerAuthCredentialsProvider(ServerAuthCredentialsProvider):
    """Server credentials provider backed by one shared token.

    The token comes from the ``chroma_server_auth_credentials`` setting and is
    validated with ``check_token`` at construction time.
    """

    _token: SecretStr

    def __init__(self, system: System) -> None:
        super().__init__(system)
        system.settings.require("chroma_server_auth_credentials")
        token_str = str(system.settings.chroma_server_auth_credentials)
        check_token(token_str)
        self._token = SecretStr(token_str)

    def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool:
        """Return True iff *credentials* carry a token equal to the configured one.

        The comparison uses ``hmac.compare_digest``, so its running time does
        not reveal how much of the presented token matched.
        """
        _creds = cast(Dict[str, SecretStr], credentials.get_credentials())
        if "token" not in _creds:
            logger.error("Returned credentials do not contain token")
            return False
        # Compare as UTF-8 bytes: compare_digest rejects non-ASCII str input,
        # and the presented token may come straight from a request header.
        presented = _creds["token"].get_secret_value().encode("utf-8")
        expected = self._token.get_secret_value().encode("utf-8")
        return hmac.compare_digest(presented, expected)

    def get_user_identity(
        self, credentials: AbstractCredentials[T]
    ) -> Optional[SimpleUserIdentity]:
        """Return None: a single shared token carries no per-user identity."""
        return None
class Token(TypedDict):
    """One entry of a user's ``tokens`` list in the user-token config."""

    token: str  # the token string matched against incoming credentials
    secret: str  # NOTE(review): not read anywhere in this module; confirm purpose
class User(TypedDict):
    """A user record from the user-token config (YAML ``users`` list or JSON list)."""

    id: str  # user id that each of this user's tokens maps back to
    role: str  # NOTE(review): not read anywhere in this module
    tenant: Optional[str]  # identity lookup falls back to "*" when the key is absent
    databases: Optional[List[str]]  # identity lookup falls back to ["*"] when absent
    tokens: List[Token]  # every token must be unique across all users
class UserTokenConfigServerAuthCredentialsProvider(ServerAuthCredentialsProvider):
    """Server credentials provider that maps tokens to configured users.

    Users are loaded from the YAML file named by
    ``chroma_server_auth_credentials_file`` (top-level ``users`` key). If that
    setting is unset, they are loaded from the JSON list in
    ``chroma_server_auth_credentials``.
    """

    _users: List[User]
    _token_user_mapping: Dict[str, str]  # reverse mapping of token to user id

    def __init__(self, system: System) -> None:
        """Load the users and build the token -> user-id index.

        Raises:
            ValueError: If neither credentials setting is configured, a token
                fails ``check_token``, or the same token appears twice.
        """
        super().__init__(system)
        if system.settings.chroma_server_auth_credentials_file:
            system.settings.require("chroma_server_auth_credentials_file")
            user_file = str(system.settings.chroma_server_auth_credentials_file)
            # Binary mode lets PyYAML detect UTF-8/UTF-16 itself instead of
            # relying on the platform's default text encoding.
            with open(user_file, "rb") as f:
                self._users = cast(List[User], yaml.safe_load(f)["users"])
        elif system.settings.chroma_server_auth_credentials:
            self._users = cast(
                List[User], json.loads(system.settings.chroma_server_auth_credentials)
            )
        else:
            # Without this, _users would stay unassigned and the loop below
            # would fail with an opaque AttributeError.
            raise ValueError(
                "Either chroma_server_auth_credentials_file or "
                "chroma_server_auth_credentials must be set"
            )
        self._token_user_mapping = {}
        for user in self._users:
            for t in user["tokens"]:
                token_str = t["token"]
                check_token(token_str)
                if token_str in self._token_user_mapping:
                    raise ValueError("Token already exists for another user")
                self._token_user_mapping[token_str] = user["id"]

    def find_user_by_id(self, _user_id: str) -> Optional[User]:
        """Return the first user whose ``id`` equals *_user_id*, or None."""
        return next((user for user in self._users if user["id"] == _user_id), None)

    def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool:
        """Return True iff the presented token belongs to some configured user."""
        _creds = cast(Dict[str, SecretStr], credentials.get_credentials())
        if "token" not in _creds:
            logger.error("Returned credentials do not contain token")
            return False
        return _creds["token"].get_secret_value() in self._token_user_mapping

    def get_user_identity(
        self, credentials: AbstractCredentials[T]
    ) -> Optional[SimpleUserIdentity]:
        """Return the identity of the token's owner, or None if unknown.

        A missing ``tenant``/``databases`` key defaults to ``"*"``/``["*"]``.
        """
        _creds = cast(Dict[str, SecretStr], credentials.get_credentials())
        if "token" not in _creds:
            logger.error("Returned credentials do not contain token")
            return None
        # below is just simple identity mapping and may need future work for more
        # complex use cases
        _user_id = self._token_user_mapping.get(_creds["token"].get_secret_value())
        if _user_id is None:
            # Unknown token: return no identity instead of raising KeyError,
            # whose repr would contain (and get logged with) the presented token.
            return None
        _user = self.find_user_by_id(_user_id)
        return SimpleUserIdentity(
            user_id=_user_id,
            tenant=_user["tenant"] if _user and "tenant" in _user else "*",
            databases=_user["databases"] if _user and "databases" in _user else ["*"],
        )
class TokenAuthCredentials(SecretStrAbstractCredentials):
    """Credentials consisting of a single token."""

    _token: SecretStr

    def __init__(self, token: SecretStr) -> None:
        self._token = token

    def get_credentials(self) -> Dict[str, SecretStr]:
        """Return ``{"token": <token>}``, the key the credential providers read."""
        return {"token": self._token}

    @staticmethod
    def from_header(
        header: str,
        token_transport_header: TokenTransportHeader = TokenTransportHeader.AUTHORIZATION,
    ) -> "TokenAuthCredentials":
        """Extract the token from a header value and wrap it.

        For ``AUTHORIZATION``, a leading ``Bearer`` scheme is removed; the
        match is case-insensitive, since auth schemes are (RFC 7235). A value
        without the scheme is used as-is. For ``X_CHROMA_TOKEN``, the value is
        only whitespace-trimmed.

        Raises:
            ValueError: If *token_transport_header* is not a known member.
        """
        if token_transport_header == TokenTransportHeader.AUTHORIZATION:
            value = header.strip()
            scheme, _, rest = value.partition(" ")
            # Only a leading scheme is removed, never an inner "Bearer ".
            token = rest.strip() if scheme.lower() == "bearer" else value
        elif token_transport_header == TokenTransportHeader.X_CHROMA_TOKEN:
            token = header.strip()
        else:
            raise ValueError(
                f"Invalid token transport header: {token_transport_header}"
            )
        return TokenAuthCredentials(SecretStr(token))
class TokenAuthServerProvider(ServerAuthProvider):
    """Server auth provider that authenticates requests by a token header."""

    _credentials_provider: ServerAuthCredentialsProvider
    _token_transport_header: TokenTransportHeader = TokenTransportHeader.AUTHORIZATION

    def __init__(self, system: System) -> None:
        """Resolve the credentials provider and the token transport header.

        Raises:
            KeyError: If the configured transport header is not a
                ``TokenTransportHeader`` member name.
        """
        super().__init__(system)
        self._settings = system.settings
        system.settings.require("chroma_server_auth_credentials_provider")
        self._credentials_provider = cast(
            ServerAuthCredentialsProvider,
            system.require(
                resolve_provider(
                    str(system.settings.chroma_server_auth_credentials_provider),
                    ServerAuthCredentialsProvider,
                )
            ),
        )
        if system.settings.chroma_server_auth_token_transport_header:
            self._token_transport_header = TokenTransportHeader[
                str(system.settings.chroma_server_auth_token_transport_header)
            ]

    def authenticate(
        self, request: ServerAuthenticationRequest[Any]
    ) -> SimpleServerAuthenticationResponse:
        """Authenticate *request* by the token in the configured header.

        Any error yields a failed response ``(False, None)``. The user identity
        is looked up only after the token has been validated.
        """
        try:
            _auth_header = request.get_auth_info(
                AuthInfoType.HEADER, self._token_transport_header.value
            )
            _token_creds = TokenAuthCredentials.from_header(
                _auth_header, self._token_transport_header
            )
            if not self._credentials_provider.validate_credentials(_token_creds):
                # Skip the identity lookup for rejected tokens. A provider may
                # raise for unknown tokens, and the except branch below logs
                # repr(e), which could contain the presented token.
                return SimpleServerAuthenticationResponse(False, None)
            return SimpleServerAuthenticationResponse(
                True,
                self._credentials_provider.get_user_identity(_token_creds),
            )
        except Exception as e:
            # Deliberate best-effort boundary: any failure is an auth failure.
            logger.error("TokenAuthServerProvider.authenticate failed: %r", e)
            return SimpleServerAuthenticationResponse(False, None)
class TokenAuthClientProvider(ClientAuthProvider):
    """Client auth provider that sends a token in a request header."""

    _credentials_provider: ClientAuthCredentialsProvider[Any]
    _token_transport_header: TokenTransportHeader = TokenTransportHeader.AUTHORIZATION

    def __init__(self, system: System) -> None:
        """Resolve the credentials provider, validate its token and pick the header.

        Raises:
            ValueError: If the provider's token fails ``check_token``.
            KeyError: If the configured transport header is not a
                ``TokenTransportHeader`` member name.
        """
        super().__init__(system)
        settings = system.settings
        self._settings = settings
        settings.require("chroma_client_auth_credentials_provider")
        provider_cls = get_class(
            str(settings.chroma_client_auth_credentials_provider),
            ClientAuthCredentialsProvider,
        )
        self._credentials_provider = system.require(provider_cls)
        # Validate the token up front so a malformed one fails at construction.
        initial_token = self._credentials_provider.get_credentials()
        check_token(initial_token.get_secret_value())
        header_setting = settings.chroma_client_auth_token_transport_header
        if header_setting:
            self._token_transport_header = TokenTransportHeader[str(header_setting)]

    def authenticate(self) -> ClientAuthResponse:
        """Build the auth response from the provider's current token."""
        return TokenAuthClientAuthResponse(
            self._credentials_provider.get_credentials(),
            self._token_transport_header,
        )