Spaces:
Paused
Paused
| from starlette.datastructures import URL | |
| from starlette.responses import RedirectResponse | |
| from ..base_client import OAuthError | |
| from ..base_client import BaseApp | |
| from ..base_client.async_app import AsyncOAuth1Mixin, AsyncOAuth2Mixin | |
| from ..base_client.async_openid import AsyncOpenIDMixin | |
| from ..httpx_client import AsyncOAuth1Client, AsyncOAuth2Client | |
| class StarletteAppMixin: | |
| async def save_authorize_data(self, request, **kwargs): | |
| state = kwargs.pop('state', None) | |
| if state: | |
| if self.framework.cache: | |
| session = None | |
| else: | |
| session = request.session | |
| await self.framework.set_state_data(session, state, kwargs) | |
| else: | |
| raise RuntimeError('Missing state value') | |
| async def authorize_redirect(self, request, redirect_uri=None, **kwargs): | |
| """Create a HTTP Redirect for Authorization Endpoint. | |
| :param request: HTTP request instance from Starlette view. | |
| :param redirect_uri: Callback or redirect URI for authorization. | |
| :param kwargs: Extra parameters to include. | |
| :return: A HTTP redirect response. | |
| """ | |
| # Handle Starlette >= 0.26.0 where redirect_uri may now be a URL and not a string | |
| if redirect_uri and isinstance(redirect_uri, URL): | |
| redirect_uri = str(redirect_uri) | |
| rv = await self.create_authorization_url(redirect_uri, **kwargs) | |
| await self.save_authorize_data(request, redirect_uri=redirect_uri, **rv) | |
| return RedirectResponse(rv['url'], status_code=302) | |
| class StarletteOAuth1App(StarletteAppMixin, AsyncOAuth1Mixin, BaseApp): | |
| client_cls = AsyncOAuth1Client | |
| async def authorize_access_token(self, request, **kwargs): | |
| params = dict(request.query_params) | |
| state = params.get('oauth_token') | |
| if not state: | |
| raise OAuthError(description='Missing "oauth_token" parameter') | |
| data = await self.framework.get_state_data(request.session, state) | |
| if not data: | |
| raise OAuthError(description='Missing "request_token" in temporary data') | |
| params['request_token'] = data['request_token'] | |
| params.update(kwargs) | |
| await self.framework.clear_state_data(request.session, state) | |
| return await self.fetch_access_token(**params) | |
| class StarletteOAuth2App(StarletteAppMixin, AsyncOAuth2Mixin, AsyncOpenIDMixin, BaseApp): | |
| client_cls = AsyncOAuth2Client | |
| async def authorize_access_token(self, request, **kwargs): | |
| error = request.query_params.get('error') | |
| if error: | |
| description = request.query_params.get('error_description') | |
| raise OAuthError(error=error, description=description) | |
| params = { | |
| 'code': request.query_params.get('code'), | |
| 'state': request.query_params.get('state'), | |
| } | |
| if self.framework.cache: | |
| session = None | |
| else: | |
| session = request.session | |
| claims_options = kwargs.pop('claims_options', None) | |
| state_data = await self.framework.get_state_data(session, params.get('state')) | |
| await self.framework.clear_state_data(session, params.get('state')) | |
| params = self._format_state_params(state_data, params) | |
| token = await self.fetch_access_token(**params, **kwargs) | |
| if 'id_token' in token and 'nonce' in state_data: | |
| userinfo = await self.parse_id_token(token, nonce=state_data['nonce'], claims_options=claims_options) | |
| token['userinfo'] = userinfo | |
| return token | |