| from typing import Callable |
| from Decipher.cmac import CMAC, xor_bytes, BLOCK_SIZE |
| from Decipher.ctr import CTR |
|
|
|
|
def _omac_with_prefix(cmac: CMAC, prefix: int, data: bytes) -> bytes:
    """Return ``cmac.digest`` of *data* preceded by a one-block prefix.

    The prefix block is ``BLOCK_SIZE`` bytes long: zero bytes everywhere
    except the final byte, which holds *prefix*.

    Args:
        cmac: Object whose ``digest`` method is applied to the prefixed message.
        prefix: Value stored in the last byte of the prefix block. It is
            converted with ``bytes([prefix])``, so it must fit in one byte.
        data: Message bytes appended after the prefix block.

    Returns:
        The result of ``cmac.digest`` over the prefix block followed by *data*.
    """
    zero_padding = b'\x00' * (BLOCK_SIZE - 1)
    prefix_block = zero_padding + bytes([prefix])
    message = prefix_block + data
    return cmac.digest(message)
|
|
|
|
class EAX:
    """Authenticated encryption built from a block-encrypt callable.

    Three prefixed CMAC values are computed (prefix 0x00 over the nonce,
    0x01 over the associated data, 0x02 over the ciphertext) and XORed
    together to form the authentication tag. The nonce value also seeds the
    ``CTR`` object that transforms the payload.
    NOTE(review): presumably ``CTR`` treats that value as its initial counter
    block - confirm against ``Decipher.ctr``.
    """

    def __init__(self, encrypt_block: Callable[[bytes], bytes]):
        """Store *encrypt_block* and build the ``CMAC`` helper around it.

        Args:
            encrypt_block: Callable mapping one block of bytes to one block
                of bytes; it is handed to both ``CMAC`` and ``CTR``.
        """
        self.encrypt_block = encrypt_block
        self.cmac = CMAC(encrypt_block)

    def _compute_tag(self, n_tag: bytes, aad: bytes, ciphertext: bytes):
        """Combine the nonce value with the header and ciphertext values."""
        h_tag = _omac_with_prefix(self.cmac, 0x01, aad)
        c_tag = _omac_with_prefix(self.cmac, 0x02, ciphertext)
        return xor_bytes(xor_bytes(n_tag, h_tag), c_tag)

    def encrypt(self, nonce: bytes, plaintext: bytes, aad: bytes = b''):
        """Encrypt *plaintext* and authenticate it together with *aad*.

        Args:
            nonce: Per-message nonce; its prefixed CMAC seeds ``CTR``.
            plaintext: Data passed through ``CTR.process``.
            aad: Associated data that is authenticated but not encrypted.

        Returns:
            A ``(ciphertext, tag)`` pair.
        """
        n_tag = _omac_with_prefix(self.cmac, 0x00, nonce)
        ciphertext = CTR(self.encrypt_block, n_tag).process(plaintext)
        return ciphertext, self._compute_tag(n_tag, aad, ciphertext)

    def decrypt(self, nonce: bytes, ciphertext: bytes, tag: bytes, aad: bytes = b''):
        """Verify *tag* and, only if it matches, decrypt *ciphertext*.

        Args:
            nonce: Nonce that was used for encryption.
            ciphertext: Data produced by ``encrypt``.
            tag: Authentication tag produced by ``encrypt``; must be a
                bytes-like object.
            aad: Associated data that was supplied to ``encrypt``.

        Returns:
            The result of running ``CTR.process`` over *ciphertext*.

        Raises:
            ValueError: If the recomputed tag does not equal *tag*.
        """
        # Function-scope import keeps this block self-contained.
        import hmac

        n_tag = _omac_with_prefix(self.cmac, 0x00, nonce)
        expected_tag = self._compute_tag(n_tag, aad, ciphertext)

        # Constant-time comparison: a plain != can leak, through timing, how
        # many leading bytes of a forged tag were correct.
        if not hmac.compare_digest(expected_tag, tag):
            raise ValueError("EAX authentication failed")

        # Authenticate first, decrypt second: unauthenticated ciphertext is
        # never run through the cipher.
        return CTR(self.encrypt_block, n_tag).process(ciphertext)
|
|