# Kris8an's picture
# Upload folder using huggingface_hub
# a06facb verified
import asyncio
import logging
from datetime import timedelta
import dateutil.parser
from botocore.compat import total_seconds
from botocore.exceptions import ClientError, TokenRetrievalError
from botocore.tokens import (
DeferredRefreshableToken,
FrozenAuthToken,
SSOTokenProvider,
TokenProviderChain,
_utc_now,
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def create_token_resolver(session):
    """Build the token provider chain for ``session``.

    :param session: Passed unchanged to ``AioSSOTokenProvider``.
    :returns: A ``TokenProviderChain`` whose only provider is an
        ``AioSSOTokenProvider`` created from ``session``.
    """
    return TokenProviderChain(providers=[AioSSOTokenProvider(session)])
class AioDeferredRefreshableToken(DeferredRefreshableToken):
    """Lazily refreshed token whose refresh callback is awaited.

    The constructor and the refresh path are re-implemented here as
    coroutines guarded by an ``asyncio.Lock``.  ``_should_refresh``,
    ``_is_expired`` and ``_attempt_timeout`` are not defined in this class;
    they are expected to be provided by ``DeferredRefreshableToken``.
    """

    def __init__(
        self, method, refresh_using, time_fetcher=_utc_now
    ):  # noqa: E501, lgtm [py/missing-call-to-init]
        """Store the refresh callback and start with no token loaded.

        The base-class ``__init__`` is not called; the lint suppression on
        the signature acknowledges this.

        :param method: Reported as ``provider`` when a
            ``TokenRetrievalError`` is raised.
        :param refresh_using: Zero-argument callable; its awaited result
            becomes the new frozen token.
        :param time_fetcher: Zero-argument callable returning the current
            time (a ``timedelta`` is added to its result).
        """
        self.method = method
        self._refresh_using = refresh_using
        self._time_fetcher = time_fetcher

        # Updates to the frozen token happen while holding this lock.
        self._refresh_lock = asyncio.Lock()
        self._frozen_token = None
        self._next_refresh = None

    async def get_frozen_token(self):
        """Refresh if needed, then return the current frozen token."""
        await self._refresh()
        return self._frozen_token

    async def _refresh(self):
        """Run a refresh under the lock when one is required.

        A ``"mandatory"`` refresh always waits for the lock.  Any other
        refresh type is skipped when the lock is already held, i.e. when
        another caller is currently inside the locked section.
        """
        kind = self._should_refresh()
        if not kind:
            # Nothing to do.
            return None

        if kind != "mandatory" and self._refresh_lock.locked():
            # Non-mandatory refresh and somebody else holds the lock:
            # do not wait for it.
            return None

        async with self._refresh_lock:
            await self._protected_refresh()

    async def _protected_refresh(self):
        """Perform the refresh; must only be called with the lock held.

        :raises Exception: Whatever the refresh callback raised, when the
            refresh type is ``"mandatory"``.
        :raises TokenRetrievalError: When ``_is_expired()`` is true after
            the refresh attempt.
        """
        # The state may have changed while waiting for the lock, so ask
        # again whether a refresh is still needed.
        kind = self._should_refresh()
        if not kind:
            return None

        try:
            # Set the next refresh time before awaiting the callback.
            self._next_refresh = self._time_fetcher() + timedelta(
                seconds=self._attempt_timeout
            )
            self._frozen_token = await self._refresh_using()
        except Exception:
            logger.warning(
                "Refreshing token failed during the %s refresh period.",
                kind,
                exc_info=True,
            )
            if kind == "mandatory":
                # A mandatory refresh may not fail silently.
                raise

        if not self._is_expired():
            return None

        # Reached when the token held after the attempt is expired.
        raise TokenRetrievalError(
            provider=self.method,
            error_msg="Token has expired and refresh failed",
        )
class AioSSOTokenProvider(SSOTokenProvider):
    """SSO token provider whose refresh path is written as coroutines.

    Uses several attributes that are not defined in this class
    (``_client``, ``_sso_config``, ``_token_loader``, ``_now``,
    ``_GRANT_TYPE``, ``_REFRESH_WINDOW``, ``METHOD``); they are expected to
    be provided by ``SSOTokenProvider``.
    """

    async def _attempt_create_token(self, token):
        """Request a new access token and build the entry to cache.

        :param token: Mapping with ``clientId``, ``clientSecret``,
            ``refreshToken`` and ``registrationExpiresAt``.
        :returns: A dict with the start URL, region, new access token, its
            expiry (``self._now()`` plus ``expiresIn`` seconds) and the
            client registration copied from ``token``.  ``refreshToken`` is
            included only when the response contains one.
        """
        async with self._client as sso_client:
            response = await sso_client.create_token(
                grantType=self._GRANT_TYPE,
                clientId=token["clientId"],
                clientSecret=token["clientSecret"],
                refreshToken=token["refreshToken"],
            )

        lifetime = timedelta(seconds=response["expiresIn"])
        refreshed = {
            "startUrl": self._sso_config["sso_start_url"],
            "region": self._sso_config["sso_region"],
            "accessToken": response["accessToken"],
            "expiresAt": self._now() + lifetime,
            # The client registration is stored together with the token.
            "clientId": token["clientId"],
            "clientSecret": token["clientSecret"],
            "registrationExpiresAt": token["registrationExpiresAt"],
        }
        if "refreshToken" in response:
            refreshed["refreshToken"] = response["refreshToken"]

        logger.info("SSO Token refresh succeeded")
        return refreshed

    async def _refresh_access_token(self, token):
        """Try to refresh ``token``; return the new entry or ``None``.

        ``None`` is returned when a required key is missing, when the
        client registration has expired, or when the create-token call
        raises ``ClientError``.
        """
        required = (
            "refreshToken",
            "clientId",
            "clientSecret",
            "registrationExpiresAt",
        )
        missing_keys = [name for name in required if name not in token]
        if missing_keys:
            logger.info(
                f"Unable to refresh SSO token: missing keys: {missing_keys}"
            )
            return None

        expiry = dateutil.parser.parse(token["registrationExpiresAt"])
        if total_seconds(expiry - self._now()) <= 0:
            logger.info(f"SSO token registration expired at {expiry}")
            return None

        try:
            refreshed = await self._attempt_create_token(token)
        except ClientError:
            logger.warning("SSO token refresh attempt failed", exc_info=True)
            refreshed = None
        return refreshed

    async def _refresher(self):
        """Load the cached token, refreshing and re-saving it if close to expiry.

        :returns: A ``FrozenAuthToken`` built from the (possibly refreshed)
            access token and its expiration.
        """
        start_url = self._sso_config["sso_start_url"]
        session_name = self._sso_config["session_name"]

        logger.info(f"Loading cached SSO token for {session_name}")
        token_dict = self._token_loader(start_url, session_name=session_name)
        expiration = dateutil.parser.parse(token_dict["expiresAt"])
        logger.debug(f"Cached SSO token expires at {expiration}")

        if total_seconds(expiration - self._now()) < self._REFRESH_WINDOW:
            refreshed = await self._refresh_access_token(token_dict)
            if refreshed is not None:
                # Use the refreshed entry from here on and write it back
                # through the token loader.
                token_dict = refreshed
                expiration = refreshed["expiresAt"]
                self._token_loader.save_token(
                    start_url, refreshed, session_name=session_name
                )

        access_token = token_dict["accessToken"]
        return FrozenAuthToken(access_token, expiration=expiration)

    def load_token(self):
        """Return a deferred token, or ``None`` when ``_sso_config`` is ``None``."""
        if self._sso_config is not None:
            return AioDeferredRefreshableToken(
                self.METHOD, self._refresher, time_fetcher=self._now
            )
        return None