|
|
import io |
|
|
|
|
|
from botocore.httpchecksum import ( |
|
|
_CHECKSUM_CLS, |
|
|
AwsChunkedWrapper, |
|
|
FlexibleChecksumError, |
|
|
_apply_request_header_checksum, |
|
|
_handle_streaming_response, |
|
|
base64, |
|
|
conditionally_calculate_md5, |
|
|
determine_content_length, |
|
|
logger, |
|
|
) |
|
|
|
|
|
from aiobotocore._helpers import resolve_awaitable |
|
|
|
|
|
|
|
|
class AioAwsChunkedWrapper(AwsChunkedWrapper):
    """Async variant of botocore's ``AwsChunkedWrapper``.

    Produces the wrapped body in ``aws-chunked`` framing, one frame per
    ``async for`` step. When a checksum object is present (``self._checksum``),
    every chunk is fed into it and its base64 digest is emitted as a trailer
    on the terminating zero-length frame.

    The attributes used here (``_raw``, ``_chunk_size``, ``_checksum``,
    ``_checksum_name``, ``_complete``) are initialised by the parent class,
    which is not defined in this module.

    NOTE(review): ``_make_chunk`` is a coroutine here. The parent's
    synchronous ``read()``/``__iter__`` presumably call it too. Confirm that
    callers only consume this wrapper via ``async for``.
    """

    async def _make_chunk(self):
        """Read and encode the next aws-chunked frame.

        Returns ``<hex-length>CRLF<data>CRLF`` for a data chunk. When the
        underlying read returns no bytes, the wrapper is marked complete and
        the terminating frame is returned instead. Without a checksum it is
        ``0 CRLF CRLF``. With a checksum it is
        ``0 CRLF <name>:<base64-digest> CRLF CRLF``.
        """
        # The read may return fewer bytes than requested, so chunk sizes (and
        # the total encoded length) are not known ahead of time.
        # resolve_awaitable is used so ``_raw.read`` may be either sync or
        # async (assumed from how the helper is used here).
        raw_chunk = await resolve_awaitable(self._raw.read(self._chunk_size))
        hex_len = hex(len(raw_chunk))[2:].encode("ascii")
        # An empty read means the wrapped stream is exhausted.
        self._complete = not raw_chunk

        if self._checksum:
            self._checksum.update(raw_chunk)

        if self._checksum and self._complete:
            # Final frame carries the checksum trailer.
            name = self._checksum_name.encode("ascii")
            checksum = self._checksum.b64digest().encode("ascii")
            return b"0\r\n%s:%s\r\n\r\n" % (name, checksum)

        return b"%s\r\n%s\r\n" % (hex_len, raw_chunk)

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next encoded frame, or stop once the stream is complete."""
        # Exactly one frame per call. The original ``while`` here never
        # looped, because its body always returned.
        if not self._complete:
            return await self._make_chunk()
        raise StopAsyncIteration()
|
|
|
|
|
|
|
|
async def handle_checksum_body(
    http_response, response, context, operation_model
):
    """Set up response checksum validation for the first usable algorithm.

    Walks ``context["checksum"]["response_algorithms"]`` in order and picks
    the first algorithm whose ``x-amz-checksum-<algorithm>`` header is present
    and does not contain ``-``. For that algorithm, ``response["body"]`` is
    replaced with a validating body, and the algorithm is recorded under
    ``response["context"]["checksum"]["response_algorithm"]``. If no algorithm
    qualifies, an info message is logged and the response is left untouched.

    Returns None in every case.
    """
    headers = response["headers"]
    wanted = context.get("checksum", {}).get("response_algorithms")
    if not wanted:
        return

    for algorithm in wanted:
        header_name = "x-amz-checksum-%s" % algorithm
        # Skip algorithms with no header. Also skip values containing "-":
        # they are not plain base64, so this code cannot validate them.
        if header_name not in headers or "-" in headers[header_name]:
            continue

        if operation_model.has_streaming_output:
            # NOTE(review): this helper is called without ``await``, so it is
            # the synchronous botocore implementation. Confirm the body it
            # returns works with this client's async raw stream.
            validated_body = _handle_streaming_response(
                http_response, response, algorithm
            )
        else:
            validated_body = await _handle_bytes_response(
                http_response, response, algorithm
            )
        response["body"] = validated_body

        # Record which algorithm was actually used for validation.
        checksum_meta = response["context"].get("checksum", {})
        checksum_meta["response_algorithm"] = algorithm
        response["context"]["checksum"] = checksum_meta
        return

    logger.info(
        f'Skipping checksum validation. Response did not contain one of the '
        f'following algorithms: {algorithms}.'
        if False
        else f'Skipping checksum validation. Response did not contain one of the '
        f'following algorithms: {wanted}.'
    )
|
|
|
|
|
|
|
|
async def _handle_bytes_response(http_response, response, algorithm):
    """Await the whole response body, verify its checksum, and return it.

    The body is hashed with ``_CHECKSUM_CLS[algorithm]``. The raw digest is
    compared with the base64-decoded ``x-amz-checksum-<algorithm>`` response
    header.

    Returns the body unchanged when the digests match.
    Raises ``FlexibleChecksumError`` when they differ.
    """
    payload = await http_response.content
    checksum_header = "x-amz-checksum-%s" % algorithm
    calculator = _CHECKSUM_CLS.get(algorithm)()
    calculator.update(payload)
    expected = response["headers"][checksum_header]

    if calculator.digest() == base64.b64decode(expected):
        return payload

    raise FlexibleChecksumError(
        error_msg="Expected checksum %s did not match calculated checksum: %s"
        % (expected, calculator.b64digest())
    )
|
|
|
|
|
|
|
|
def apply_request_checksum(request):
    """Apply the request checksum strategy named in the request context.

    ``request["context"]["checksum"]["request_algorithm"]`` selects the
    strategy:

    - falsy or missing: nothing is done;
    - ``"conditional-md5"``: delegates to ``conditionally_calculate_md5``;
    - a mapping whose ``"in"`` is ``"header"`` or ``"trailer"``: delegates to
      the matching helper.

    Returns None. Raises ``FlexibleChecksumError`` for any other ``"in"``
    value.
    """
    spec = request.get("context", {}).get("checksum", {}).get(
        "request_algorithm"
    )
    if not spec:
        return

    if spec == "conditional-md5":
        # Special case for operations that use the MD5-based checksum path.
        conditionally_calculate_md5(request)
        return

    placement = spec["in"]
    if placement == "header":
        _apply_request_header_checksum(request)
    elif placement == "trailer":
        _apply_request_trailer_checksum(request)
    else:
        raise FlexibleChecksumError(
            error_msg="Unknown checksum variant: %s" % placement
        )
|
|
|
|
|
|
|
|
def _apply_request_trailer_checksum(request):
    """Re-encode the request body as aws-chunked with a checksum trailer.

    Reads ``request["context"]["checksum"]["request_algorithm"]``, a mapping
    with ``"name"`` (the trailer/header name) and ``"algorithm"`` (the key
    into ``_CHECKSUM_CLS``). If a header with that name is already present,
    the request is left unchanged. Otherwise the chunked/trailer headers are
    set in place, and ``request["body"]`` is replaced with an
    ``AioAwsChunkedWrapper`` that emits the checksum as a trailer. A
    bytes/bytearray body is first wrapped in ``io.BytesIO``.
    """
    spec = request.get("context", {}).get("checksum", {}).get(
        "request_algorithm"
    )
    trailer_name = spec["name"]
    checksum_cls = _CHECKSUM_CLS.get(spec["algorithm"])
    headers, payload = request["headers"], request["body"]

    # A checksum header supplied by the caller takes precedence.
    if trailer_name in headers:
        return

    # Header insertion order is kept identical to the original implementation.
    headers["Transfer-Encoding"] = "chunked"
    # Keep any existing content encoding and add aws-chunked after it.
    headers["Content-Encoding"] = (
        headers["Content-Encoding"] + ",aws-chunked"
        if "Content-Encoding" in headers
        else "aws-chunked"
    )
    headers["X-Amz-Trailer"] = trailer_name

    # Advertise the un-chunked payload length when it can be determined.
    decoded_length = determine_content_length(payload)
    if decoded_length is not None:
        headers["X-Amz-Decoded-Content-Length"] = str(decoded_length)

    # The chunk wrapper reads from a stream, so raw bytes get a file-like view.
    stream = (
        io.BytesIO(payload)
        if isinstance(payload, (bytes, bytearray))
        else payload
    )
    request["body"] = AioAwsChunkedWrapper(
        stream,
        checksum_cls=checksum_cls,
        checksum_name=trailer_name,
    )
|
|
|