Spaces:
Runtime error
Runtime error
| import importlib | |
| import logging | |
| from typing import Optional, cast, Dict, TypeVar, Any | |
| import requests | |
| from overrides import override | |
| from pydantic import SecretStr | |
| from chromadb.auth import ( | |
| ServerAuthCredentialsProvider, | |
| AbstractCredentials, | |
| ClientAuthCredentialsProvider, | |
| AuthInfoType, | |
| ClientAuthProvider, | |
| ClientAuthProtocolAdapter, | |
| SimpleUserIdentity, | |
| ) | |
| from chromadb.auth.registry import register_provider, resolve_provider | |
| from chromadb.config import System | |
| from chromadb.telemetry.opentelemetry import ( | |
| OpenTelemetryGranularity, | |
| trace_method, | |
| ) | |
| T = TypeVar("T") | |
| logger = logging.getLogger(__name__) | |
| class HtpasswdServerAuthCredentialsProvider(ServerAuthCredentialsProvider): | |
| _creds: Dict[str, SecretStr] | |
| def __init__(self, system: System) -> None: | |
| super().__init__(system) | |
| try: | |
| # Equivalent to import onnxruntime | |
| self.bc = importlib.import_module("bcrypt") | |
| except ImportError: | |
| raise ValueError( | |
| "The bcrypt python package is not installed. " | |
| "Please install it with `pip install bcrypt`" | |
| ) | |
| def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool: | |
| _creds = cast(Dict[str, SecretStr], credentials.get_credentials()) | |
| if len(_creds) != 2: | |
| logger.error( | |
| "Returned credentials did match expected format: " | |
| "dict[username:SecretStr, password: SecretStr]" | |
| ) | |
| return False | |
| if "username" not in _creds or "password" not in _creds: | |
| logger.error( | |
| "Returned credentials do not contain username or password") | |
| return False | |
| _usr_check = bool( | |
| _creds["username"].get_secret_value() | |
| == self._creds["username"].get_secret_value() | |
| ) | |
| return _usr_check and self.bc.checkpw( | |
| _creds["password"].get_secret_value().encode("utf-8"), | |
| self._creds["password"].get_secret_value().encode("utf-8"), | |
| ) | |
| def get_user_identity( | |
| self, credentials: AbstractCredentials[T] | |
| ) -> Optional[SimpleUserIdentity]: | |
| _creds = cast(Dict[str, SecretStr], credentials.get_credentials()) | |
| return SimpleUserIdentity(_creds["username"].get_secret_value()) | |
| class HtpasswdFileServerAuthCredentialsProvider(HtpasswdServerAuthCredentialsProvider): | |
| def __init__(self, system: System) -> None: | |
| super().__init__(system) | |
| system.settings.require("chroma_server_auth_credentials_file") | |
| _file = str(system.settings.chroma_server_auth_credentials_file) | |
| with open(_file, "r") as f: | |
| _raw_creds = [v for v in f.readline().strip().split(":")] | |
| self._creds = { | |
| "username": SecretStr(_raw_creds[0]), | |
| "password": SecretStr(_raw_creds[1]), | |
| } | |
| if ( | |
| len(self._creds) != 2 | |
| or "username" not in self._creds | |
| or "password" not in self._creds | |
| ): | |
| raise ValueError( | |
| "Invalid Htpasswd credentials found in " | |
| "[chroma_server_auth_credentials]. " | |
| "Must be <username>:<bcrypt passwd>." | |
| ) | |
| class HtpasswdConfigurationServerAuthCredentialsProvider( | |
| HtpasswdServerAuthCredentialsProvider | |
| ): | |
| def __init__(self, system: System) -> None: | |
| super().__init__(system) | |
| system.settings.require("chroma_server_auth_credentials") | |
| _raw_creds = ( | |
| str(system.settings.chroma_server_auth_credentials).strip().split(":") | |
| ) | |
| self._creds = { | |
| "username": SecretStr(_raw_creds[0]), | |
| "password": SecretStr(_raw_creds[1]), | |
| } | |
| if ( | |
| len(self._creds) != 2 | |
| or "username" not in self._creds | |
| or "password" not in self._creds | |
| ): | |
| raise ValueError( | |
| "Invalid Htpasswd credentials found in " | |
| "[chroma_server_auth_credentials]. " | |
| "Must be <username>:<bcrypt passwd>." | |
| ) | |
| class RequestsClientAuthProtocolAdapter( | |
| ClientAuthProtocolAdapter[requests.PreparedRequest] | |
| ): | |
| class _Session(requests.Session): | |
| _protocol_adapter: ClientAuthProtocolAdapter[requests.PreparedRequest] | |
| def __init__( | |
| self, protocol_adapter: ClientAuthProtocolAdapter[requests.PreparedRequest] | |
| ) -> None: | |
| super().__init__() | |
| self._protocol_adapter = protocol_adapter | |
| def send( | |
| self, request: requests.PreparedRequest, **kwargs: Any | |
| ) -> requests.Response: | |
| self._protocol_adapter.inject_credentials(request) | |
| return super().send(request, **kwargs) | |
| _session: _Session | |
| _auth_provider: ClientAuthProvider | |
| def __init__(self, system: System) -> None: | |
| super().__init__(system) | |
| system.settings.require("chroma_client_auth_provider") | |
| self._auth_provider = cast( | |
| ClientAuthProvider, | |
| system.require( | |
| resolve_provider( | |
| str(system.settings.chroma_client_auth_provider), ClientAuthProvider | |
| ), | |
| ), | |
| ) | |
| self._session = self._Session(self) | |
| self._auth_header = self._auth_provider.authenticate() | |
| def session(self) -> requests.Session: | |
| return self._session | |
| def inject_credentials(self, injection_context: requests.PreparedRequest) -> None: | |
| if self._auth_header.get_auth_info_type() == AuthInfoType.HEADER: | |
| _header_info = self._auth_header.get_auth_info() | |
| if isinstance(_header_info, tuple): | |
| injection_context.headers[_header_info[0]] = _header_info[ | |
| 1 | |
| ].get_secret_value() | |
| else: | |
| for header in _header_info: | |
| injection_context.headers[header[0] | |
| ] = header[1].get_secret_value() | |
| else: | |
| raise ValueError( | |
| f"Unsupported auth type: {self._auth_header.get_auth_info_type()}" | |
| ) | |
| class ConfigurationClientAuthCredentialsProvider( | |
| ClientAuthCredentialsProvider[SecretStr] | |
| ): | |
| _creds: SecretStr | |
| def __init__(self, system: System) -> None: | |
| super().__init__(system) | |
| system.settings.require("chroma_client_auth_credentials") | |
| self._creds = SecretStr( | |
| str(system.settings.chroma_client_auth_credentials)) | |
| def get_credentials(self) -> SecretStr: | |
| return self._creds | |