| | |
| | |
| | |
| |
|
| | from __future__ import annotations |
| |
|
| | import typing |
| |
|
| | from cryptography.hazmat.primitives.ciphers import Cipher |
| | from cryptography.hazmat.primitives.ciphers.algorithms import AES |
| | from cryptography.hazmat.primitives.ciphers.modes import ECB |
| | from cryptography.hazmat.primitives.constant_time import bytes_eq |
| |
|
| |
|
def _wrap_core(
    wrapping_key: bytes,
    a: bytes,
    r: list[bytes],
) -> bytes:
    """Run the six key-wrap rounds (RFC 3394 index method) over ``r``.

    :param wrapping_key: AES key used to build the ECB encryptor.
    :param a: Initial 8-byte integrity register (the IV or AIV).
    :param r: The data to wrap, split into 8-byte blocks. The list is
        copied, so the caller's list is not modified.
    :returns: The final register value followed by the transformed
        blocks, concatenated.
    """
    encryptor = Cipher(AES(wrapping_key), ECB()).encryptor()
    blocks = list(r)
    n = len(blocks)
    for j in range(6):
        for i in range(n):
            # B = AES(K, A | R[i])
            # A = MSB(64, B) ^ t, where t = (n * j) + i + 1
            # R[i] = LSB(64, B)
            b = encryptor.update(a + blocks[i])
            a = (
                int.from_bytes(b[:8], byteorder="big") ^ ((n * j) + i + 1)
            ).to_bytes(length=8, byteorder="big")
            blocks[i] = b[-8:]

    # finalize() is called outside the assert so that the cipher context
    # is still finalized when asserts are stripped (``python -O``).
    leftover = encryptor.finalize()
    assert leftover == b""

    return a + b"".join(blocks)
| |
|
| |
|
def aes_key_wrap(
    wrapping_key: bytes,
    key_to_wrap: bytes,
    backend: typing.Any = None,
) -> bytes:
    """Wrap ``key_to_wrap`` under ``wrapping_key`` without padding.

    :param wrapping_key: AES key of 16, 24 or 32 bytes.
    :param key_to_wrap: Data to wrap; at least 16 bytes and a multiple
        of 8 bytes long.
    :param backend: Not used by this function.
    :returns: The wrapped key as produced by ``_wrap_core``.
    :raises ValueError: If either length requirement is not met.
    """
    if len(wrapping_key) not in (16, 24, 32):
        raise ValueError("The wrapping key must be a valid AES key length")

    plaintext_len = len(key_to_wrap)
    if plaintext_len < 16:
        raise ValueError("The key to wrap must be at least 16 bytes")
    if plaintext_len % 8 != 0:
        raise ValueError("The key to wrap must be a multiple of 8 bytes")

    # Fixed initial value for the integrity register.
    default_iv = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6"

    # Split the input into consecutive 8-byte blocks.
    blocks = []
    for offset in range(0, plaintext_len, 8):
        blocks.append(key_to_wrap[offset : offset + 8])

    return _wrap_core(wrapping_key, default_iv, blocks)
| |
|
| |
|
def _unwrap_core(
    wrapping_key: bytes,
    a: bytes,
    r: list[bytes],
) -> tuple[bytes, list[bytes]]:
    """Run the six key-unwrap rounds (RFC 3394 index method) over ``r``.

    :param wrapping_key: AES key used to build the ECB decryptor.
    :param a: The first 8-byte block of the wrapped key (the encrypted
        integrity register).
    :param r: The remaining wrapped data, split into 8-byte blocks. The
        list is copied, so the caller's list is not modified.
    :returns: A tuple of the recovered register value and the list of
        recovered blocks. The caller is responsible for checking the
        register against the expected IV.
    """
    decryptor = Cipher(AES(wrapping_key), ECB()).decryptor()
    blocks = list(r)
    n = len(blocks)
    for j in reversed(range(6)):
        for i in reversed(range(n)):
            atr = (
                int.from_bytes(a, byteorder="big") ^ ((n * j) + i + 1)
            ).to_bytes(length=8, byteorder="big") + blocks[i]
            # B = AES-1(K, (A ^ t) | R[i]), where t = (n * j) + i + 1
            # A = MSB(64, B); R[i] = LSB(64, B)
            b = decryptor.update(atr)
            a = b[:8]
            blocks[i] = b[-8:]

    # finalize() is called outside the assert so that it still runs (and
    # can still report leftover partial-block data) under ``python -O``.
    leftover = decryptor.finalize()
    assert leftover == b""
    return a, blocks
| |
|
| |
|
def aes_key_wrap_with_padding(
    wrapping_key: bytes,
    key_to_wrap: bytes,
    backend: typing.Any = None,
) -> bytes:
    """Wrap ``key_to_wrap`` under ``wrapping_key``, zero-padding as needed.

    The input is padded with zero bytes up to a multiple of 8 bytes, and
    its original length is stored in the last four bytes of the 8-byte
    alternative initial value (AIV).

    :param wrapping_key: AES key of 16, 24 or 32 bytes.
    :param key_to_wrap: Data to wrap; must not be empty.
    :param backend: Not used by this function.
    :returns: The wrapped key.
    :raises ValueError: If ``wrapping_key`` has an invalid length or
        ``key_to_wrap`` is empty.
    """
    if len(wrapping_key) not in [16, 24, 32]:
        raise ValueError("The wrapping key must be a valid AES key length")

    # An empty input would skip every wrapping round and hand back the
    # bare AIV unencrypted, which cannot be unwrapped again.
    if len(key_to_wrap) == 0:
        raise ValueError("The key to wrap must be at least 1 byte")

    aiv = b"\xa6\x59\x59\xa6" + len(key_to_wrap).to_bytes(
        length=4, byteorder="big"
    )
    # Number of zero bytes needed to reach the next multiple of 8.
    pad = (8 - (len(key_to_wrap) % 8)) % 8
    key_to_wrap = key_to_wrap + b"\x00" * pad
    if len(key_to_wrap) == 8:
        # A single padded block: AIV | block is exactly one AES block,
        # so it is encrypted directly instead of running the rounds.
        encryptor = Cipher(AES(wrapping_key), ECB()).encryptor()
        b = encryptor.update(aiv + key_to_wrap)
        # Keep finalize() outside the assert so it runs under -O too.
        leftover = encryptor.finalize()
        assert leftover == b""
        return b

    r = [key_to_wrap[i : i + 8] for i in range(0, len(key_to_wrap), 8)]
    return _wrap_core(wrapping_key, aiv, r)
| |
|
| |
|
def aes_key_unwrap_with_padding(
    wrapping_key: bytes,
    wrapped_key: bytes,
    backend: typing.Any = None,
) -> bytes:
    """Unwrap a key produced by the padded key-wrap scheme.

    :param wrapping_key: AES key of 16, 24 or 32 bytes.
    :param wrapped_key: Wrapped data; at least 16 bytes and a multiple
        of 8 bytes long.
    :param backend: Not used by this function.
    :returns: The unwrapped key with the zero padding removed.
    :raises InvalidUnwrap: If ``wrapped_key`` has an invalid length, or
        the recovered AIV prefix, length field or padding is wrong.
    :raises ValueError: If ``wrapping_key`` has an invalid length.
    """
    if len(wrapped_key) < 16:
        raise InvalidUnwrap("Must be at least 16 bytes")

    if len(wrapping_key) not in [16, 24, 32]:
        raise ValueError("The wrapping key must be a valid AES key length")

    # Reject misaligned input up front instead of letting a partial block
    # reach the ECB decryptor. Checked after the wrapping-key test so an
    # invalid wrapping key keeps raising ValueError as before.
    if len(wrapped_key) % 8 != 0:
        raise InvalidUnwrap("The wrapped key must be a multiple of 8 bytes")

    if len(wrapped_key) == 16:
        # Exactly one AES block: decrypt it directly, no rounds needed.
        decryptor = Cipher(AES(wrapping_key), ECB()).decryptor()
        out = decryptor.update(wrapped_key)
        # Keep finalize() outside the assert so it runs under -O too.
        leftover = decryptor.finalize()
        assert leftover == b""
        a = out[:8]
        data = out[8:]
        n = 1
    else:
        r = [wrapped_key[i : i + 8] for i in range(0, len(wrapped_key), 8)]
        encrypted_aiv = r.pop(0)
        n = len(r)
        a, r = _unwrap_core(wrapping_key, encrypted_aiv, r)
        data = b"".join(r)

    # 1) The first four bytes of A must be A6 59 59 A6.
    # 2) The last four bytes of A (MLI, the message length) must satisfy
    #    8 * (n - 1) < MLI <= 8 * n.
    # 3) With b = 8 * n - MLI, the last b bytes of the data must be zero.
    mli = int.from_bytes(a[4:], byteorder="big")
    b = (8 * n) - mli
    if (
        not bytes_eq(a[:4], b"\xa6\x59\x59\xa6")
        or not 8 * (n - 1) < mli <= 8 * n
        or (b != 0 and not bytes_eq(data[-b:], b"\x00" * b))
    ):
        raise InvalidUnwrap()

    # Strip the b padding bytes (data[:-0] would be empty, hence the test).
    return data if b == 0 else data[:-b]
| |
|
| |
|
def aes_key_unwrap(
    wrapping_key: bytes,
    wrapped_key: bytes,
    backend: typing.Any = None,
) -> bytes:
    """Unwrap a key produced by the unpadded key-wrap scheme.

    :param wrapping_key: AES key of 16, 24 or 32 bytes.
    :param wrapped_key: Wrapped data; at least 24 bytes and a multiple
        of 8 bytes long.
    :param backend: Not used by this function.
    :returns: The unwrapped key.
    :raises InvalidUnwrap: If ``wrapped_key`` has an invalid length or
        the recovered integrity value does not match the expected IV.
    :raises ValueError: If ``wrapping_key`` has an invalid length.
    """
    total_len = len(wrapped_key)
    if total_len < 24:
        raise InvalidUnwrap("Must be at least 24 bytes")
    if total_len % 8 != 0:
        raise InvalidUnwrap("The wrapped key must be a multiple of 8 bytes")
    if len(wrapping_key) not in (16, 24, 32):
        raise ValueError("The wrapping key must be a valid AES key length")

    expected_iv = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6"

    # The first 8-byte block is the encrypted integrity register; the
    # remaining blocks carry the key material.
    register = wrapped_key[:8]
    blocks = [
        wrapped_key[offset : offset + 8] for offset in range(8, total_len, 8)
    ]

    recovered_iv, plain_blocks = _unwrap_core(wrapping_key, register, blocks)
    if bytes_eq(recovered_iv, expected_iv):
        return b"".join(plain_blocks)

    raise InvalidUnwrap()
| |
|
| |
|
class InvalidUnwrap(Exception):
    """Raised when a wrapped key fails a length or integrity check."""
| |
|