Kris8an's picture
Upload folder using huggingface_hub
a06facb verified
import io
from botocore.httpchecksum import (
_CHECKSUM_CLS,
AwsChunkedWrapper,
FlexibleChecksumError,
_apply_request_header_checksum,
_handle_streaming_response,
base64,
conditionally_calculate_md5,
determine_content_length,
logger,
)
from aiobotocore._helpers import resolve_awaitable
class AioAwsChunkedWrapper(AwsChunkedWrapper):
    """Asynchronous flavour of ``AwsChunkedWrapper``.

    Chunks are produced by awaiting the result of the wrapped stream's
    ``read`` call (via ``resolve_awaitable``) and are handed out through the
    async-iterator protocol (``__aiter__`` / ``__anext__``).

    The attributes ``_raw``, ``_chunk_size``, ``_checksum``,
    ``_checksum_name`` and ``_complete`` are read here but not assigned in
    this class apart from ``_complete`` — presumably the base class
    initialises them; verify against ``AwsChunkedWrapper``.
    """

    async def _make_chunk(self):
        """Read the next piece of the raw stream and frame it as a chunk.

        Returns:
            ``b"<hex length>\\r\\n<data>\\r\\n"`` for a regular read, or
            ``b"0\\r\\n<name>:<b64 digest>\\r\\n\\r\\n"`` when a checksum is
            configured and the read came back empty.
        """
        # NOTE: read() is allowed to return fewer bytes than requested, so
        # chunk sizes are not deterministic. Consequently the total length of
        # the encoded aws-chunked stream cannot be known up front unless a
        # consistent chunk size is enforced.
        data = await resolve_awaitable(self._raw.read(self._chunk_size))
        size_field = format(len(data), "x").encode("ascii")

        # An empty read marks the end of the underlying stream.
        self._complete = not data

        if self._checksum:
            self._checksum.update(data)
            if self._complete:
                # Final chunk: zero length followed by the checksum trailer.
                trailer_name = self._checksum_name.encode("ascii")
                trailer_value = self._checksum.b64digest().encode("ascii")
                return b"0\r\n%s:%s\r\n\r\n" % (trailer_name, trailer_value)

        return b"%s\r\n%s\r\n" % (size_field, data)

    def __aiter__(self):
        """Return ``self``; the wrapper is its own async iterator."""
        return self

    async def __anext__(self):
        """Return the next encoded chunk, or stop once the stream is done."""
        if self._complete:
            raise StopAsyncIteration()
        return await self._make_chunk()
async def handle_checksum_body(
    http_response, response, context, operation_model
):
    """Attach a checksum-validated body to ``response`` when possible.

    The first algorithm listed under ``context["checksum"]
    ["response_algorithms"]`` whose ``x-amz-checksum-<algorithm>`` header is
    present (and does not contain a ``-``) is used. The body is then
    replaced by the result of the streaming or bytes handler and the chosen
    algorithm is recorded in ``response["context"]["checksum"]``.

    Args:
        http_response: Raw HTTP response, forwarded to the body handlers.
        response: Parsed response dict; ``"headers"`` and ``"context"`` are
            read, ``"body"`` and ``"context"["checksum"]`` are written.
        context: Request context dict that may carry the ``"checksum"``
            settings.
        operation_model: Object whose ``has_streaming_output`` attribute
            selects between the streaming and the bytes handler.

    Returns:
        None. If no listed algorithm is usable, an info message is logged
        and the response is left untouched.
    """
    headers = response["headers"]
    algorithms = context.get("checksum", {}).get("response_algorithms")
    if not algorithms:
        return

    for algorithm in algorithms:
        header_name = "x-amz-checksum-%s" % algorithm
        # Skip this algorithm when its header is missing, or when the value
        # contains a "-": that is not valid Base64. S3 appends a -# suffix to
        # checksums derived from the hash of all part checksums, and such a
        # response cannot be wrapped.
        if header_name not in headers or "-" in headers[header_name]:
            continue

        if operation_model.has_streaming_output:
            validated_body = _handle_streaming_response(
                http_response, response, algorithm
            )
        else:
            validated_body = await _handle_bytes_response(
                http_response, response, algorithm
            )
        response["body"] = validated_body

        # Record which algorithm was used so callers can tell that the
        # checksum check really took place.
        checksum_metadata = response["context"].get("checksum", {})
        checksum_metadata["response_algorithm"] = algorithm
        response["context"]["checksum"] = checksum_metadata
        return

    logger.info(
        f'Skipping checksum validation. Response did not contain one of the '
        f'following algorithms: {algorithms}.'
    )
async def _handle_bytes_response(http_response, response, algorithm):
    """Read the whole response body and verify it against its checksum header.

    Args:
        http_response: Object whose ``content`` attribute is awaited to
            obtain the body.
        response: Parsed response dict; the expected checksum is read from
            ``response["headers"]["x-amz-checksum-<algorithm>"]``.
        algorithm: Key used to look up the checksum class in
            ``_CHECKSUM_CLS`` and to build the header name.

    Returns:
        The body that was read, once its digest matches the header value.

    Raises:
        FlexibleChecksumError: If the calculated digest differs from the
            Base64-decoded header value.
    """
    payload = await http_response.content
    header_name = "x-amz-checksum-%s" % algorithm

    hasher = _CHECKSUM_CLS.get(algorithm)()
    hasher.update(payload)

    expected = response["headers"][header_name]
    if hasher.digest() == base64.b64decode(expected):
        return payload

    error_msg = (
        "Expected checksum %s did not match calculated checksum: %s"
        % (expected, hasher.b64digest())
    )
    raise FlexibleChecksumError(error_msg=error_msg)
def apply_request_checksum(request):
    """Apply the checksum configured in the request context, if any.

    The setting is read from ``request["context"]["checksum"]
    ["request_algorithm"]``. It is either the string ``"conditional-md5"``
    or a mapping whose ``"in"`` entry says where the checksum goes
    (``"header"`` or ``"trailer"``).

    Args:
        request: Request dict; passed unchanged to the selected helper.

    Raises:
        FlexibleChecksumError: If ``"in"`` is neither ``"header"`` nor
            ``"trailer"``.
    """
    checksum_settings = request.get("context", {}).get("checksum", {})
    algorithm = checksum_settings.get("request_algorithm")
    if not algorithm:
        return

    if algorithm == "conditional-md5":
        # Special case: the "http checksum required" trait.
        conditionally_calculate_md5(request)
        return

    location = algorithm["in"]
    if location == "header":
        _apply_request_header_checksum(request)
        return
    if location == "trailer":
        _apply_request_trailer_checksum(request)
        return

    raise FlexibleChecksumError(
        error_msg="Unknown checksum variant: %s" % location
    )
def _apply_request_trailer_checksum(request):
    """Prepare ``request`` to send its checksum as an aws-chunked trailer.

    Sets the chunked-transfer headers, announces the trailer name, and
    replaces ``request["body"]`` with an ``AioAwsChunkedWrapper`` around the
    original body. Nothing is changed when the checksum header is already
    present in ``request["headers"]``.

    Args:
        request: Request dict with ``"headers"`` and ``"body"`` entries and a
            ``context.checksum.request_algorithm`` mapping that provides
            ``"name"`` (the header/trailer name) and ``"algorithm"`` (the key
            into ``_CHECKSUM_CLS``).
    """
    checksum_settings = request.get("context", {}).get("checksum", {})
    algorithm = checksum_settings.get("request_algorithm")
    location_name = algorithm["name"]
    checksum_cls = _CHECKSUM_CLS.get(algorithm["algorithm"])

    headers = request["headers"]
    body = request["body"]

    if location_name in headers:
        # The customer supplied the header already; skip the calculation.
        return

    # NOTE(review): the previous comment on this line read "Cannot set this
    # as aiohttp complains", yet the header is set unconditionally. Behaviour
    # is kept as-is; confirm which of the two is intended.
    headers["Transfer-Encoding"] = "chunked"

    if "Content-Encoding" not in headers:
        headers["Content-Encoding"] = "aws-chunked"
    else:
        # Keep whatever content encoding is already there and append
        # aws-chunked as an additional encoding.
        headers["Content-Encoding"] += ",aws-chunked"
    headers["X-Amz-Trailer"] = location_name

    decoded_length = determine_content_length(body)
    if decoded_length is not None:
        # Report the decoded content length whenever it can be determined;
        # some services, such as S3, may require it.
        headers["X-Amz-Decoded-Content-Length"] = str(decoded_length)

    # In-memory payloads are wrapped in a file-like object first.
    stream = io.BytesIO(body) if isinstance(body, (bytes, bytearray)) else body

    request["body"] = AioAwsChunkedWrapper(
        stream,
        checksum_cls=checksum_cls,
        checksum_name=location_name,
    )