|
|
import asyncio |
|
|
import logging |
|
|
from datetime import timedelta |
|
|
|
|
|
import dateutil.parser |
|
|
from botocore.compat import total_seconds |
|
|
from botocore.exceptions import ClientError, TokenRetrievalError |
|
|
from botocore.tokens import ( |
|
|
DeferredRefreshableToken, |
|
|
FrozenAuthToken, |
|
|
SSOTokenProvider, |
|
|
TokenProviderChain, |
|
|
_utc_now, |
|
|
) |
|
|
|
|
|
logger = logging.getLogger(__name__)


def create_token_resolver(session):
    """Build the provider chain used to resolve auth tokens for *session*.

    :param session: Session object handed to every provider in the chain.
    :return: A ``TokenProviderChain`` whose only provider is an
        ``AioSSOTokenProvider`` bound to *session*.
    """
    sso_provider = AioSSOTokenProvider(session)
    return TokenProviderChain(providers=[sso_provider])
|
|
|
|
|
|
|
|
class AioDeferredRefreshableToken(DeferredRefreshableToken):
    """Asyncio version of ``DeferredRefreshableToken``.

    Caches the token object produced by the ``refresh_using`` coroutine
    function and refreshes it on demand. Whether a refresh is due, and of
    which type, comes from the inherited ``_should_refresh`` helper.
    Concurrent coroutines are coordinated with an ``asyncio.Lock``.
    """

    def __init__(self, method, refresh_using, time_fetcher=_utc_now):
        # The parent constructor is not invoked. The lock created here is an
        # asyncio.Lock so it can be awaited with ``async with``.
        self.method = method
        self._refresh_using = refresh_using
        self._time_fetcher = time_fetcher

        # Guards _frozen_token / _next_refresh across concurrent coroutines.
        self._refresh_lock = asyncio.Lock()
        self._next_refresh = None
        self._frozen_token = None

    async def get_frozen_token(self):
        """Refresh the cached token if due, then return it."""
        await self._refresh()
        return self._frozen_token

    async def _refresh(self):
        """Run a refresh when the refresh policy asks for one.

        A "mandatory" refresh always waits for the lock. Any other truthy
        refresh type is skipped when another coroutine already holds it.
        """
        refresh_type = self._should_refresh()
        if not refresh_type:
            return None

        must_wait = refresh_type == "mandatory"
        if not must_wait and self._refresh_lock.locked():
            # A refresh is already in flight and this one is optional.
            return None

        async with self._refresh_lock:
            await self._protected_refresh()

    async def _protected_refresh(self):
        """Perform the refresh. The caller must hold ``_refresh_lock``.

        The refresh policy is checked again because another coroutine may
        have refreshed the token while this one waited for the lock.

        :raises TokenRetrievalError: If the token is still expired afterwards.
        """
        refresh_type = self._should_refresh()
        if not refresh_type:
            return None

        try:
            # Record the next permitted attempt time before awaiting, so it
            # is set even if the refresh call raises.
            self._next_refresh = self._time_fetcher() + timedelta(
                seconds=self._attempt_timeout
            )
            self._frozen_token = await self._refresh_using()
        except Exception:
            logger.warning(
                "Refreshing token failed during the %s refresh period.",
                refresh_type,
                exc_info=True,
            )
            # Only a mandatory refresh propagates the failure; otherwise
            # the previously cached token keeps being used.
            if refresh_type == "mandatory":
                raise

        if self._is_expired():
            raise TokenRetrievalError(
                provider=self.method,
                error_msg="Token has expired and refresh failed",
            )
|
|
|
|
|
|
|
|
class AioSSOTokenProvider(SSOTokenProvider):
    """Asyncio version of ``SSOTokenProvider``.

    Loads the cached SSO access token for the configured SSO session. When
    that token is within ``_REFRESH_WINDOW`` seconds of expiring, it tries
    to obtain a new one through the client's ``create_token`` call using
    the cached refresh token and client registration.
    """

    async def _attempt_create_token(self, token):
        """Exchange the cached refresh token for a new access token.

        :param token: Cached token dict. It must contain ``clientId``,
            ``clientSecret``, ``refreshToken`` and ``registrationExpiresAt``.
        :return: A new token dict for caching. Its ``expiresAt`` is computed
            as ``self._now() + expiresIn``, not parsed from a string.
        :raises: Errors from ``create_token`` (e.g. ``ClientError``)
            propagate to the caller.
        """
        # The client is only needed for this one call; close it right after.
        async with self._client as client:
            response = await client.create_token(
                grantType=self._GRANT_TYPE,
                clientId=token["clientId"],
                clientSecret=token["clientSecret"],
                refreshToken=token["refreshToken"],
            )

        expires_in = timedelta(seconds=response["expiresIn"])
        new_token = {
            "startUrl": self._sso_config["sso_start_url"],
            "region": self._sso_config["sso_region"],
            "accessToken": response["accessToken"],
            "expiresAt": self._now() + expires_in,
            # Carry the client registration forward so a later refresh
            # can reuse it.
            "clientId": token["clientId"],
            "clientSecret": token["clientSecret"],
            "registrationExpiresAt": token["registrationExpiresAt"],
        }
        # Store a refresh token only when the response includes one.
        if "refreshToken" in response:
            new_token["refreshToken"] = response["refreshToken"]
        logger.info("SSO Token refresh succeeded")
        return new_token

    async def _refresh_access_token(self, token):
        """Try to refresh *token*; return ``None`` when that is not possible.

        Refreshing is skipped in two cases: the cached token lacks a field
        needed for ``create_token``, or its client registration has expired.
        A ``ClientError`` from the call is logged and swallowed, and ``None``
        is returned, so the caller keeps using the cached token.

        :param token: Token dict as loaded from the token cache.
        :return: The refreshed token dict, or ``None``.
        """
        keys = (
            "refreshToken",
            "clientId",
            "clientSecret",
            "registrationExpiresAt",
        )
        missing_keys = [k for k in keys if k not in token]
        if missing_keys:
            logger.info(
                "Unable to refresh SSO token: missing keys: %s", missing_keys
            )
            return None

        expiry = dateutil.parser.parse(token["registrationExpiresAt"])
        if total_seconds(expiry - self._now()) <= 0:
            logger.info("SSO token registration expired at %s", expiry)
            return None

        try:
            return await self._attempt_create_token(token)
        except ClientError:
            logger.warning("SSO token refresh attempt failed", exc_info=True)
            return None

    async def _refresher(self):
        """Load the cached SSO token, refreshing it when close to expiry.

        Used as the ``refresh_using`` callback of the deferred token built
        by ``load_token``. A successfully refreshed token is written back
        through ``self._token_loader.save_token``.

        :return: ``FrozenAuthToken`` for the (possibly refreshed) token.
        """
        start_url = self._sso_config["sso_start_url"]
        session_name = self._sso_config["session_name"]
        logger.info("Loading cached SSO token for %s", session_name)
        # NOTE(review): the token loader and save_token are called
        # synchronously. If they touch the filesystem, they block the event
        # loop for that time. Confirm that this is acceptable.
        token_dict = self._token_loader(start_url, session_name=session_name)
        expiration = dateutil.parser.parse(token_dict["expiresAt"])
        logger.debug("Cached SSO token expires at %s", expiration)

        remaining = total_seconds(expiration - self._now())
        if remaining < self._REFRESH_WINDOW:
            new_token_dict = await self._refresh_access_token(token_dict)
            if new_token_dict is not None:
                token_dict = new_token_dict
                # A freshly created token already holds a datetime expiry.
                expiration = token_dict["expiresAt"]
                self._token_loader.save_token(
                    start_url, token_dict, session_name=session_name
                )

        return FrozenAuthToken(
            token_dict["accessToken"], expiration=expiration
        )

    def load_token(self, **kwargs):
        """Return a lazily refreshed SSO token, or ``None`` without SSO config.

        :param kwargs: Accepted and ignored. This keeps the signature
            compatible with callers that forward keyword arguments to every
            provider's ``load_token``. Without it, such a call would raise
            ``TypeError``.
        :return: ``AioDeferredRefreshableToken`` that calls ``_refresher``
            on demand, or ``None`` when ``self._sso_config`` is ``None``.
        """
        if self._sso_config is None:
            return None

        return AioDeferredRefreshableToken(
            self.METHOD, self._refresher, time_fetcher=self._now
        )
|
|
|