Kris8an's picture
Upload folder using huggingface_hub
a06facb verified
from botocore import UNSIGNED, translate
from botocore.exceptions import PartialCredentialsError
from botocore.session import (
EVENT_ALIASES,
ServiceModel,
Session,
UnknownServiceError,
copy,
)
from . import retryhandler
from .client import AioBaseClient, AioClientCreator
from .configprovider import AioSmartDefaultsConfigStoreFactory
from .credentials import AioCredentials, create_credential_resolver
from .hooks import AioHierarchicalEmitter
from .parsers import AioResponseParserFactory
from .tokens import create_token_resolver
from .utils import AioIMDSRegionProvider
class ClientCreatorContext:
    """Async context manager around a pending client-creation coroutine.

    The coroutine given at construction is only awaited on entry. The
    resulting client's own ``__aenter__``/``__aexit__`` are delegated to,
    and whatever the client's ``__aenter__`` returns is what the
    ``async with`` target receives.
    """

    def __init__(self, coro):
        # No client exists yet; it is produced when ``__aenter__`` awaits
        # the stored coroutine.
        self._client = None
        self._coro = coro

    async def __aenter__(self) -> AioBaseClient:
        client = await self._coro
        self._client = client
        entered = await client.__aenter__()
        return entered

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # The client's ``__aexit__`` result is intentionally discarded, so
        # this wrapper never suppresses exceptions itself.
        client = self._client
        await client.__aexit__(exc_type, exc_val, exc_tb)
class AioSession(Session):
    """Asynchronous variant of :class:`botocore.session.Session`.

    Overrides the session hooks that build the event emitter, token and
    credential resolvers, smart-defaults factory and response parser
    factory so they use the ``Aio*`` implementations. Methods that await
    event emission or credential loading are coroutines here.
    """

    def __init__(
        self,
        session_vars=None,
        event_hooks=None,
        include_builtin_handlers=True,
        profile=None,
    ):
        # Default to the async-aware emitter: events are awaited elsewhere
        # in this class (see ``get_service_data``). A caller-supplied
        # emitter is used unchanged.
        if event_hooks is None:
            event_hooks = AioHierarchicalEmitter()
        super().__init__(
            session_vars, event_hooks, include_builtin_handlers, profile
        )

    def _create_token_resolver(self):
        """Return the async token resolver bound to this session."""
        return create_token_resolver(self)

    def _create_credential_resolver(self):
        """Return the async credential resolver.

        The resolver is given the region of the most recently created
        client (``self._last_client_region_used``).
        """
        return create_credential_resolver(
            self, region_name=self._last_client_region_used
        )

    def _register_smart_defaults_factory(self):
        """Register a factory for the async smart-defaults config store.

        The factory callable (not an instance) is registered, so the
        store factory and its IMDS region provider are only built when the
        component is requested.
        """

        def create_smart_defaults_factory():
            default_config_resolver = self._get_internal_component(
                'default_config_resolver'
            )
            imds_region_provider = AioIMDSRegionProvider(session=self)
            return AioSmartDefaultsConfigStoreFactory(
                default_config_resolver, imds_region_provider
            )

        self._internal_components.lazy_register_component(
            'smart_defaults_factory', create_smart_defaults_factory
        )

    def _register_response_parser_factory(self):
        """Register the async response parser factory component."""
        self._components.register_component(
            'response_parser_factory', AioResponseParserFactory()
        )

    def set_credentials(self, access_key, secret_key, token=None):
        """Set explicit static credentials for this session.

        :param access_key: Access key ID.
        :param secret_key: Secret access key.
        :param token: Optional session token.
        """
        self._credentials = AioCredentials(access_key, secret_key, token)

    async def get_credentials(self):
        """Return the session credentials, loading them on first use.

        The value returned by the ``credential_provider`` component's
        ``load_credentials()`` is cached on the session. While that value
        is ``None`` the provider is consulted again on every call.

        :returns: The cached credentials; may be ``None`` if the provider
            returned ``None``.
        """
        if self._credentials is None:
            provider = self._components.get_component('credential_provider')
            self._credentials = await provider.load_credentials()
        return self._credentials

    async def get_service_model(self, service_name, api_version=None):
        """Return a :class:`ServiceModel` built from the merged service data.

        :param service_name: Name of the service (e.g. ``'s3'``).
        :param api_version: Optional API version to load.
        """
        service_description = await self.get_service_data(
            service_name, api_version
        )
        return ServiceModel(service_description, service_name=service_name)

    async def get_service_data(self, service_name, api_version=None):
        """
        Retrieve the fully merged data associated with a service.

        Loads the ``service-2`` model through the ``data_loader`` component,
        then awaits emission of ``service-data-loaded.<service id>`` with the
        loaded data before returning it. The service id is the
        ``EVENT_ALIASES`` entry for ``service_name`` if present, otherwise
        ``service_name`` itself.
        """
        service_data = self.get_component('data_loader').load_service_model(
            service_name, type_name='service-2', api_version=api_version
        )
        service_id = EVENT_ALIASES.get(service_name, service_name)
        await self._events.emit(
            'service-data-loaded.%s' % service_id,
            service_data=service_data,
            service_name=service_name,
            session=self,
        )
        return service_data

    def create_client(self, *args, **kwargs):
        """Return an async context manager that creates a client on entry.

        All arguments are forwarded to :meth:`_create_client`. Intended
        usage::

            async with session.create_client('s3') as client:
                ...
        """
        return ClientCreatorContext(self._create_client(*args, **kwargs))

    async def _create_client(
        self,
        service_name,
        region_name=None,
        api_version=None,
        use_ssl=True,
        verify=None,
        endpoint_url=None,
        aws_access_key_id=None,
        aws_secret_access_key=None,
        aws_session_token=None,
        config=None,
    ):
        """Build and return an async client for ``service_name``.

        :param service_name: Service to create a client for.
        :param region_name: Region; resolved via ``_resolve_region_name``.
        :param api_version: API version; defaults to the session's
            ``api_versions`` config entry for the service.
        :param use_ssl: Passed to the client creator as ``is_secure``.
        :param verify: TLS verification setting; defaults to the session's
            ``ca_bundle`` config variable.
        :param endpoint_url: Optional explicit endpoint URL.
        :param aws_access_key_id: Optional explicit access key ID.
        :param aws_secret_access_key: Optional explicit secret key.
        :param aws_session_token: Optional explicit session token.
        :param config: Optional client config, merged over the session's
            default client config when both exist.
        :raises PartialCredentialsError: when the explicit key pair is
            incomplete, as reported by ``_missing_cred_vars``.
        """
        default_client_config = self.get_default_client_config()
        # If a config is provided and a default config is set, then
        # use the config resulting from merging the two.
        if config is not None and default_client_config is not None:
            config = default_client_config.merge(config)
        # If a config was not provided then use the default
        # client config from the session
        elif default_client_config is not None:
            config = default_client_config

        region_name = self._resolve_region_name(region_name, config)

        # Figure out the verify value based on the various
        # configuration options.
        if verify is None:
            verify = self.get_config_variable('ca_bundle')
        if api_version is None:
            api_version = self.get_config_variable('api_versions').get(
                service_name, None
            )

        loader = self.get_component('data_loader')
        event_emitter = self.get_component('event_emitter')
        response_parser_factory = self.get_component('response_parser_factory')

        if config is not None and config.signature_version is UNSIGNED:
            # Unsigned signing: no credentials are attached to the client.
            credentials = None
        elif (
            aws_access_key_id is not None and aws_secret_access_key is not None
        ):
            credentials = AioCredentials(
                access_key=aws_access_key_id,
                secret_key=aws_secret_access_key,
                token=aws_session_token,
            )
        else:
            # Compute the missing variable once and reuse it for both the
            # check and the error payload.
            missing_cred_var = self._missing_cred_vars(
                aws_access_key_id, aws_secret_access_key
            )
            if missing_cred_var:
                raise PartialCredentialsError(
                    provider='explicit',
                    cred_var=missing_cred_var,
                )
            credentials = await self.get_credentials()
        auth_token = self.get_auth_token()

        endpoint_resolver = self._get_internal_component('endpoint_resolver')
        exceptions_factory = self._get_internal_component('exceptions_factory')
        # Copy so per-client smart defaults / endpoint-provider settings do
        # not mutate the session-wide config store.
        config_store = copy.copy(self.get_component('config_store'))
        user_agent_creator = self.get_component('user_agent_creator')
        # Session configuration values for the user agent string are applied
        # just before each client creation because they may have been modified
        # at any time between session creation and client creation.
        user_agent_creator.set_session_config(
            session_user_agent_name=self.user_agent_name,
            session_user_agent_version=self.user_agent_version,
            session_user_agent_extra=self.user_agent_extra,
        )

        defaults_mode = self._resolve_defaults_mode(config, config_store)
        if defaults_mode != 'legacy':
            smart_defaults_factory = self._get_internal_component(
                'smart_defaults_factory'
            )
            await smart_defaults_factory.merge_smart_defaults(
                config_store, defaults_mode, region_name
            )

        self._add_configured_endpoint_provider(
            client_name=service_name,
            config_store=config_store,
        )

        client_creator = AioClientCreator(
            loader,
            endpoint_resolver,
            self.user_agent(),
            event_emitter,
            retryhandler,
            translate,
            response_parser_factory,
            exceptions_factory,
            config_store,
            user_agent_creator=user_agent_creator,
        )
        client = await client_creator.create_client(
            service_name=service_name,
            region_name=region_name,
            is_secure=use_ssl,
            endpoint_url=endpoint_url,
            verify=verify,
            credentials=credentials,
            scoped_config=self.get_scoped_config(),
            client_config=config,
            api_version=api_version,
            auth_token=auth_token,
        )

        # When a monitor component is configured, attach it to the new
        # client's event system.
        monitor = self._get_internal_component('monitor')
        if monitor is not None:
            monitor.register(client.meta.events)
        return client

    async def get_available_regions(
        self, service_name, partition_name='aws', allow_non_regional=False
    ):
        """Return the endpoints the resolver lists for a service/partition.

        The service's ``endpointPrefix`` metadata (falling back to
        ``service_name``) is what the resolver is queried with.

        :returns: The resolver's result, or ``[]`` when
            ``UnknownServiceError`` is raised. That exception is
            deliberately swallowed so unknown services yield no regions.
        """
        resolver = self._get_internal_component('endpoint_resolver')
        results = []
        try:
            service_data = await self.get_service_data(service_name)
            endpoint_prefix = service_data['metadata'].get(
                'endpointPrefix', service_name
            )
            results = resolver.get_available_endpoints(
                endpoint_prefix, partition_name, allow_non_regional
            )
        except UnknownServiceError:
            pass
        return results
def get_session(env_vars=None):
    """Create and return a brand-new :class:`AioSession`.

    :param env_vars: Forwarded unchanged as the session's ``session_vars``.
    """
    session = AioSession(session_vars=env_vars)
    return session