Spaces:
Runtime error
Runtime error
| import chromadb | |
| from contextvars import ContextVar | |
| from functools import wraps | |
| import logging | |
| from typing import Callable, Optional, Dict, List, Union, cast, Any | |
| from overrides import override | |
| from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint | |
| from starlette.requests import Request | |
| from starlette.responses import Response | |
| from starlette.types import ASGIApp | |
| from chromadb.config import DEFAULT_TENANT, System | |
| from chromadb.auth import ( | |
| AuthorizationContext, | |
| AuthorizationRequestContext, | |
| AuthzAction, | |
| AuthzResource, | |
| AuthzResourceActions, | |
| AuthzUser, | |
| DynamicAuthzResource, | |
| ServerAuthenticationRequest, | |
| AuthInfoType, | |
| ServerAuthenticationResponse, | |
| ServerAuthProvider, | |
| ChromaAuthMiddleware, | |
| ChromaAuthzMiddleware, | |
| ServerAuthorizationProvider, | |
| ) | |
| from chromadb.auth.registry import resolve_provider | |
| from chromadb.errors import AuthorizationError | |
| from chromadb.server.fastapi.utils import fastapi_json_response | |
| from chromadb.telemetry.opentelemetry import ( | |
| OpenTelemetryGranularity, | |
| trace_method, | |
| ) | |
# Module-level logger; records are tagged with this module's import name.
logger: logging.Logger = logging.getLogger(__name__)
class FastAPIServerAuthenticationRequest(ServerAuthenticationRequest[Optional[str]]):
    """Expose a Starlette/FastAPI ``Request`` as a ``ServerAuthenticationRequest``.

    Server auth providers call :meth:`get_auth_info` to read credentials from
    the request's headers, cookies, or query string.
    """

    # The incoming HTTP request that credentials are read from.
    _request: Request

    def __init__(self, request: Request) -> None:
        self._request = request

    def get_auth_info(
        self, auth_info_type: AuthInfoType, auth_info_id: str
    ) -> Optional[str]:
        """Return the value named ``auth_info_id`` found at ``auth_info_type``.

        Args:
            auth_info_type: Where to look. ``HEADER`` reads request headers,
                ``COOKIE`` reads cookies, and ``URL`` reads query parameters.
            auth_info_id: The header, cookie or query-parameter name.

        Returns:
            The located value, passed through ``str``.
            NOTE(review): annotated ``Optional`` but never returns ``None``; a
            missing name raises whatever the underlying mapping raises on
            ``[]`` lookup (presumably ``KeyError``).

        Raises:
            ValueError: For ``AuthInfoType.METADATA`` (unsupported here) or any
                unrecognised ``auth_info_type``.
        """
        # (auth info type, attribute of the Request holding those values)
        request_sources = (
            (AuthInfoType.HEADER, "headers"),
            (AuthInfoType.COOKIE, "cookies"),
            (AuthInfoType.URL, "query_params"),
        )
        for candidate_type, request_attribute in request_sources:
            if auth_info_type == candidate_type:
                values = getattr(self._request, request_attribute)
                return str(values[auth_info_id])
        if auth_info_type == AuthInfoType.METADATA:
            raise ValueError("Metadata not supported for FastAPI")
        raise ValueError(f"Unknown auth info type: {auth_info_type}")
class FastAPIServerAuthenticationResponse(ServerAuthenticationResponse):
    """Authentication response that records only whether authentication succeeded."""

    # Outcome supplied by the constructor and reported by success().
    _auth_success: bool

    def __init__(self, auth_success: bool) -> None:
        self._auth_success = auth_success

    def success(self) -> bool:
        """Report the outcome passed to the constructor."""
        outcome = self._auth_success
        return outcome
class FastAPIChromaAuthMiddleware(ChromaAuthMiddleware):
    """Authentication component used by the FastAPI server.

    Resolves the provider named by the ``chroma_server_auth_provider`` setting
    and delegates every authentication call to it. Entries in
    ``chroma_server_auth_ignore_paths`` map a URL path to HTTP verbs (compared
    against the upper-cased request method) for which auth is skipped.
    """

    _auth_provider: ServerAuthProvider

    def __init__(self, system: System) -> None:
        super().__init__(system)
        self._system = system
        settings = system.settings
        self._settings = settings
        settings.require("chroma_server_auth_provider")
        self._ignore_auth_paths: Dict[str, List[str]] = (
            settings.chroma_server_auth_ignore_paths
        )
        provider_name = settings.chroma_server_auth_provider
        if provider_name:
            logger.debug(f"Server Auth Provider: {provider_name}")
            provider_cls = resolve_provider(provider_name, ServerAuthProvider)
            self._auth_provider = cast(
                ServerAuthProvider, self.require(provider_cls)
            )

    def authenticate(
        self, request: ServerAuthenticationRequest[Any]
    ) -> ServerAuthenticationResponse:
        """Delegate authentication of ``request`` to the configured provider."""
        return self._auth_provider.authenticate(request)

    def ignore_operation(self, verb: str, path: str) -> bool:
        """Return ``True`` when ``verb`` on ``path`` is configured to bypass auth."""
        if path not in self._ignore_auth_paths:
            return False
        if verb.upper() not in self._ignore_auth_paths[path]:
            return False
        logger.debug(f"Skipping auth for path {path} and method {verb}")
        return True

    def instrument_server(self, app: ASGIApp) -> None:
        """Not supported; always raises ``NotImplementedError``."""
        # An `/auth` endpoint could be added to the server here to allow for
        # more complex auth flows.
        raise NotImplementedError("Not implemented yet")
class FastAPIChromaAuthMiddlewareWrapper(BaseHTTPMiddleware):  # type: ignore
    """Starlette middleware that authenticates each non-ignored request.

    On success, the authenticated identity is stored as
    ``request.state.user_identity`` and the request is forwarded. On failure, a
    JSON response for ``AuthorizationError("Unauthorized")`` is returned and
    the request is not forwarded.
    """

    def __init__(
        self, app: ASGIApp, auth_middleware: FastAPIChromaAuthMiddleware
    ) -> None:
        super().__init__(app)
        self._middleware = auth_middleware
        try:
            self._middleware.instrument_server(app)
        except NotImplementedError:
            # Instrumentation is optional; NotImplementedError means the
            # middleware offers none.
            pass

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        method = request.method
        path = request.url.path
        if self._middleware.ignore_operation(method, path):
            logger.debug(f"Skipping auth for path {path} and method {method}")
            return await call_next(request)

        auth_response = self._middleware.authenticate(
            FastAPIServerAuthenticationRequest(request)
        )
        if auth_response and auth_response.success():
            # authz_context() later reads the identity back from request.state.
            request.state.user_identity = auth_response.get_user_identity()
            return await call_next(request)
        return fastapi_json_response(AuthorizationError("Unauthorized"))
# Per-request context consumed by authz_context(). The authz middleware's
# pre_process() stores the current request and the configured authorization
# provider here before the endpoint is called; both default to None.
request_var: ContextVar[Optional[Request]] = ContextVar(
    "request_var",
    default=None,
)
authz_provider: ContextVar[Optional[ServerAuthorizationProvider]] = ContextVar(
    "authz_provider",
    default=None,
)

# Kept as module-level configuration because authz_context() runs without a
# System instance, and so without easy access to the settings. Changed via
# set_overwrite_singleton_tenant_database_access_from_auth().
overwrite_singleton_tenant_database_access_from_auth: bool = False
def set_overwrite_singleton_tenant_database_access_from_auth(
    overwrite: bool = False,
) -> None:
    """Set the module-wide tenant/database overwrite flag.

    When the flag is on, ``authz_context`` may replace a call's ``tenant`` and
    ``database`` keyword arguments with values taken from the authenticated
    user identity.

    Args:
        overwrite: The new flag value; calling with no argument resets it to
            ``False``.
    """
    global overwrite_singleton_tenant_database_access_from_auth
    new_value = overwrite
    overwrite_singleton_tenant_database_access_from_auth = new_value
def _authz_user_from_request(request: Request) -> AuthzUser:
    """Build the ``AuthzUser`` for ``request``.

    Uses ``request.state.user_identity`` (set by the authentication middleware
    wrapper) when present. Otherwise returns an ``"Anonymous"`` user on
    ``DEFAULT_TENANT`` with no attributes.
    """
    if not hasattr(request.state, "user_identity"):
        return AuthzUser(id="Anonymous", tenant=DEFAULT_TENANT, attributes={})
    identity = request.state.user_identity
    return AuthzUser(
        id=identity.get_user_id(),
        tenant=identity.get_user_tenant(),
        attributes=identity.get_user_attributes(),
    )


def _check_authorization(
    request: Request,
    actions: List[Union[str, AuthzAction]],
    resource: Union[AuthzResource, DynamicAuthzResource],
    dynamic_kwargs: Dict[str, Any],
) -> None:
    """Raise ``AuthorizationError`` unless at least one of ``actions`` is authorized.

    The resource and user are resolved once and shared by every action. The
    provider is consulted for every action (no short-circuit), and access is
    granted if any response is truthy.
    """
    provider = authz_provider.get()
    if not actions:
        # An empty action list can never be authorized; fail before resolving
        # a possibly dynamic resource.
        raise AuthorizationError("Unauthorized")
    authz_resource = (
        resource
        if isinstance(resource, AuthzResource)
        else resource.to_authz_resource(**dynamic_kwargs)
    )
    user = _authz_user_from_request(request)
    if not provider:
        raise AuthorizationError("Unauthorized")
    responses = [
        provider.authorize(
            AuthorizationContext(
                user=user,
                resource=authz_resource,
                action=a if isinstance(a, AuthzAction) else AuthzAction(id=a),
            )
        )
        for a in actions
    ]
    if not any(responses):
        raise AuthorizationError("Unauthorized")


def _overwrite_tenant_and_database_from_auth(
    request: Request, kwargs: Dict[str, Any]
) -> None:
    """Rewrite the ``tenant``/``database`` kwargs in place from the user identity.

    The tenant is replaced when the identity reports one. The database is
    replaced only when the identity lists exactly one. Plain strings are
    substituted; ``CreateTenant``/``CreateDatabase`` payloads get ``.name`` set.
    Raises ``AttributeError`` if the request carries no ``user_identity``.
    """
    # Imported here: this module only does ``import chromadb``, which does not
    # by itself guarantee that the ``chromadb.server.fastapi.types`` submodule
    # is loaded and reachable as an attribute.
    from chromadb.server.fastapi.types import CreateDatabase, CreateTenant

    identity = request.state.user_identity

    desired_tenant = identity.get_user_tenant()
    if desired_tenant and "tenant" in kwargs:
        if isinstance(kwargs["tenant"], str):
            kwargs["tenant"] = desired_tenant
        elif isinstance(kwargs["tenant"], CreateTenant):
            kwargs["tenant"].name = desired_tenant

    databases = identity.get_user_databases()
    if databases and len(databases) == 1 and "database" in kwargs:
        desired_database = databases[0]
        if isinstance(kwargs["database"], str):
            kwargs["database"] = desired_database
        elif isinstance(kwargs["database"], CreateDatabase):
            kwargs["database"].name = desired_database


def authz_context(
    action: Union[str, AuthzResourceActions, List[str], List[AuthzResourceActions]],
    resource: Union[AuthzResource, DynamicAuthzResource],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Return a decorator that authorizes calls to an API method before running it.

    The check runs only when ``request_var`` holds a request, which the authz
    middleware's ``pre_process`` sets. Without a request, the call passes
    straight through. With a request, the provider from ``authz_provider`` is
    asked about each action, and ``AuthorizationError("Unauthorized")`` is
    raised unless any answer is truthy. The same error is raised when no
    provider is set.

    In a multi-tenant setup, users may send requests without configuring a
    tenant and database. When ``overwrite_singleton_tenant_database_access_from_auth``
    is enabled, the call's ``tenant``/``database`` keyword arguments are
    therefore overwritten with the authenticated identity's values.

    Args:
        action: One action or a list of actions. Plain strings are wrapped in
            ``AuthzAction``. Access is granted if any action is authorized.
        resource: A static ``AuthzResource``, or a ``DynamicAuthzResource``
            resolved per call from ``args[0]._api`` and the call's arguments.

    Returns:
        A decorator whose wrapper preserves the wrapped function's metadata.
    """
    # Normalise once; a list is used as-is (same object), a scalar is wrapped.
    actions: List[Union[str, AuthzAction]] = (
        cast(List[Union[str, AuthzAction]], action)
        if isinstance(action, list)
        else [action]
    )

    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
        # wraps() copies __name__/__doc__ and sets __wrapped__, so
        # inspect.signature() and other introspection see f's real signature
        # instead of (*args, **kwargs).
        @wraps(f)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            # Evaluated on every call, so the first positional argument must
            # expose an ``_api`` attribute (the decorated methods' owner).
            dynamic_kwargs = {
                "api": args[0]._api,
                "function": f,
                "function_args": args,
                "function_kwargs": kwargs,
            }
            request = request_var.get()
            if request:
                _check_authorization(request, actions, resource, dynamic_kwargs)
                if overwrite_singleton_tenant_database_access_from_auth:
                    _overwrite_tenant_and_database_from_auth(request, kwargs)
            return f(*args, **kwargs)

        return wrapped

    return decorator
class FastAPIAuthorizationRequestContext(AuthorizationRequestContext[Request]):
    """Carries a Starlette ``Request`` into the authorization middleware."""

    # The HTTP request being authorized.
    _request: Request

    def __init__(self, request: Request) -> None:
        self._request = request

    def get_request(self) -> Request:
        """Return the wrapped request object unchanged."""
        wrapped_request = self._request
        return wrapped_request
class FastAPIChromaAuthzMiddleware(ChromaAuthzMiddleware[ASGIApp, Request]):
    """Authorization component used by the FastAPI server.

    Resolves the provider named by ``chroma_server_authz_provider``. Its
    :meth:`pre_process` publishes that provider and the current request through
    the ``request_var`` / ``authz_provider`` context variables that
    ``authz_context`` reads. Entries in ``chroma_server_authz_ignore_paths`` map
    a URL path to HTTP verbs (compared against the upper-cased request method)
    that skip this step.
    """

    _authz_provider: ServerAuthorizationProvider

    def __init__(self, system: System) -> None:
        super().__init__(system)
        self._system = system
        settings = system.settings
        self._settings = settings
        settings.require("chroma_server_authz_provider")
        self._ignore_auth_paths: Dict[str, List[str]] = (
            settings.chroma_server_authz_ignore_paths
        )
        provider_name = settings.chroma_server_authz_provider
        if provider_name:
            logger.debug(f"Server Authorization Provider: {provider_name}")
            provider_cls = resolve_provider(
                provider_name, ServerAuthorizationProvider
            )
            self._authz_provider = cast(
                ServerAuthorizationProvider, self.require(provider_cls)
            )

    def pre_process(self, request: AuthorizationRequestContext[Request]) -> None:
        """Store the current request and the provider in the authz context variables."""
        request_var.set(request.get_request())
        authz_provider.set(self._authz_provider)

    def ignore_operation(self, verb: str, path: str) -> bool:
        """Return ``True`` when ``verb`` on ``path`` is configured to bypass authz."""
        if path not in self._ignore_auth_paths:
            return False
        if verb.upper() not in self._ignore_auth_paths[path]:
            return False
        logger.debug(f"Skipping authz for path {path} and method {verb}")
        return True

    def instrument_server(self, app: ASGIApp) -> None:
        """Not supported; always raises ``NotImplementedError``."""
        # An `/auth` endpoint could be added to the server here to allow for
        # more complex auth flows.
        raise NotImplementedError("Not implemented yet")
class FastAPIChromaAuthzMiddlewareWrapper(BaseHTTPMiddleware):  # type: ignore
    """Starlette middleware that prepares per-request authorization state.

    It makes no authorization decision itself. For non-ignored routes, it calls
    the authz middleware's ``pre_process`` so that ``authz_context``-decorated
    handlers can find the current request and provider. It then forwards the
    request unchanged.
    """

    def __init__(
        self, app: ASGIApp, authz_middleware: FastAPIChromaAuthzMiddleware
    ) -> None:
        super().__init__(app)
        self._middleware = authz_middleware
        try:
            self._middleware.instrument_server(app)
        except NotImplementedError:
            # Instrumentation is optional; NotImplementedError means the
            # middleware offers none.
            pass

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Set up authz context for ``request`` (unless ignored) and forward it."""
        if self._middleware.ignore_operation(request.method, request.url.path):
            # Both values are interpolated lazily by the logging module, so the
            # method is actually rendered instead of a literal placeholder.
            logger.debug(
                "Skipping authz for path %s and method %s",
                request.url.path,
                request.method,
            )
            return await call_next(request)
        self._middleware.pre_process(FastAPIAuthorizationRequestContext(request))
        return await call_next(request)