| | |
| | |
| | |
| |
|
| | from __future__ import annotations |
| |
|
| | import binascii |
| | import enum |
| | import os |
| | import re |
| | import typing |
| | import warnings |
| | from base64 import encodebytes as _base64_encode |
| | from dataclasses import dataclass |
| |
|
| | from cryptography import utils |
| | from cryptography.exceptions import UnsupportedAlgorithm |
| | from cryptography.hazmat.primitives import hashes |
| | from cryptography.hazmat.primitives.asymmetric import ( |
| | dsa, |
| | ec, |
| | ed25519, |
| | padding, |
| | rsa, |
| | ) |
| | from cryptography.hazmat.primitives.asymmetric import utils as asym_utils |
| | from cryptography.hazmat.primitives.ciphers import ( |
| | AEADDecryptionContext, |
| | Cipher, |
| | algorithms, |
| | modes, |
| | ) |
| | from cryptography.hazmat.primitives.serialization import ( |
| | Encoding, |
| | KeySerializationEncryption, |
| | NoEncryption, |
| | PrivateFormat, |
| | PublicFormat, |
| | _KeySerializationEncryption, |
| | ) |
| |
|
| | try: |
| | from bcrypt import kdf as _bcrypt_kdf |
| |
|
| | _bcrypt_supported = True |
| | except ImportError: |
| | _bcrypt_supported = False |
| |
|
| | def _bcrypt_kdf( |
| | password: bytes, |
| | salt: bytes, |
| | desired_key_bytes: int, |
| | rounds: int, |
| | ignore_few_rounds: bool = False, |
| | ) -> bytes: |
| | raise UnsupportedAlgorithm("Need bcrypt module") |
| |
|
| |
|
| | _SSH_ED25519 = b"ssh-ed25519" |
| | _SSH_RSA = b"ssh-rsa" |
| | _SSH_DSA = b"ssh-dss" |
| | _ECDSA_NISTP256 = b"ecdsa-sha2-nistp256" |
| | _ECDSA_NISTP384 = b"ecdsa-sha2-nistp384" |
| | _ECDSA_NISTP521 = b"ecdsa-sha2-nistp521" |
| | _CERT_SUFFIX = b"-cert-v01@openssh.com" |
| |
|
| | |
| | _SK_SSH_ED25519 = b"sk-ssh-ed25519@openssh.com" |
| | _SK_SSH_ECDSA_NISTP256 = b"sk-ecdsa-sha2-nistp256@openssh.com" |
| |
|
| | |
| | |
| | _SSH_RSA_SHA256 = b"rsa-sha2-256" |
| | _SSH_RSA_SHA512 = b"rsa-sha2-512" |
| |
|
| | _SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)") |
| | _SK_MAGIC = b"openssh-key-v1\0" |
| | _SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----" |
| | _SK_END = b"-----END OPENSSH PRIVATE KEY-----" |
| | _BCRYPT = b"bcrypt" |
| | _NONE = b"none" |
| | _DEFAULT_CIPHER = b"aes256-ctr" |
| | _DEFAULT_ROUNDS = 16 |
| |
|
| | |
| | _PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL) |
| |
|
| | |
| | _PADDING = memoryview(bytearray(range(1, 1 + 16))) |
| |
|
| |
|
| | @dataclass |
| | class _SSHCipher: |
| | alg: type[algorithms.AES] |
| | key_len: int |
| | mode: type[modes.CTR] | type[modes.CBC] | type[modes.GCM] |
| | block_len: int |
| | iv_len: int |
| | tag_len: int | None |
| | is_aead: bool |
| |
|
| |
|
| | |
| | _SSH_CIPHERS: dict[bytes, _SSHCipher] = { |
| | b"aes256-ctr": _SSHCipher( |
| | alg=algorithms.AES, |
| | key_len=32, |
| | mode=modes.CTR, |
| | block_len=16, |
| | iv_len=16, |
| | tag_len=None, |
| | is_aead=False, |
| | ), |
| | b"aes256-cbc": _SSHCipher( |
| | alg=algorithms.AES, |
| | key_len=32, |
| | mode=modes.CBC, |
| | block_len=16, |
| | iv_len=16, |
| | tag_len=None, |
| | is_aead=False, |
| | ), |
| | b"aes256-gcm@openssh.com": _SSHCipher( |
| | alg=algorithms.AES, |
| | key_len=32, |
| | mode=modes.GCM, |
| | block_len=16, |
| | iv_len=12, |
| | tag_len=16, |
| | is_aead=True, |
| | ), |
| | } |
| |
|
| | |
| | _ECDSA_KEY_TYPE = { |
| | "secp256r1": _ECDSA_NISTP256, |
| | "secp384r1": _ECDSA_NISTP384, |
| | "secp521r1": _ECDSA_NISTP521, |
| | } |
| |
|
| |
|
| | def _get_ssh_key_type(key: SSHPrivateKeyTypes | SSHPublicKeyTypes) -> bytes: |
| | if isinstance(key, ec.EllipticCurvePrivateKey): |
| | key_type = _ecdsa_key_type(key.public_key()) |
| | elif isinstance(key, ec.EllipticCurvePublicKey): |
| | key_type = _ecdsa_key_type(key) |
| | elif isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)): |
| | key_type = _SSH_RSA |
| | elif isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)): |
| | key_type = _SSH_DSA |
| | elif isinstance( |
| | key, (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey) |
| | ): |
| | key_type = _SSH_ED25519 |
| | else: |
| | raise ValueError("Unsupported key type") |
| |
|
| | return key_type |
| |
|
| |
|
| | def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes: |
| | """Return SSH key_type and curve_name for private key.""" |
| | curve = public_key.curve |
| | if curve.name not in _ECDSA_KEY_TYPE: |
| | raise ValueError( |
| | f"Unsupported curve for ssh private key: {curve.name!r}" |
| | ) |
| | return _ECDSA_KEY_TYPE[curve.name] |
| |
|
| |
|
| | def _ssh_pem_encode( |
| | data: utils.Buffer, |
| | prefix: bytes = _SK_START + b"\n", |
| | suffix: bytes = _SK_END + b"\n", |
| | ) -> bytes: |
| | return b"".join([prefix, _base64_encode(data), suffix]) |
| |
|
| |
|
| | def _check_block_size(data: utils.Buffer, block_len: int) -> None: |
| | """Require data to be full blocks""" |
| | if not data or len(data) % block_len != 0: |
| | raise ValueError("Corrupt data: missing padding") |
| |
|
| |
|
| | def _check_empty(data: utils.Buffer) -> None: |
| | """All data should have been parsed.""" |
| | if data: |
| | raise ValueError("Corrupt data: unparsed data") |
| |
|
| |
|
| | def _init_cipher( |
| | ciphername: bytes, |
| | password: bytes | None, |
| | salt: bytes, |
| | rounds: int, |
| | ) -> Cipher[modes.CBC | modes.CTR | modes.GCM]: |
| | """Generate key + iv and return cipher.""" |
| | if not password: |
| | raise TypeError( |
| | "Key is password-protected, but password was not provided." |
| | ) |
| |
|
| | ciph = _SSH_CIPHERS[ciphername] |
| | seed = _bcrypt_kdf( |
| | password, salt, ciph.key_len + ciph.iv_len, rounds, True |
| | ) |
| | return Cipher( |
| | ciph.alg(seed[: ciph.key_len]), |
| | ciph.mode(seed[ciph.key_len :]), |
| | ) |
| |
|
| |
|
| | def _get_u32(data: memoryview) -> tuple[int, memoryview]: |
| | """Uint32""" |
| | if len(data) < 4: |
| | raise ValueError("Invalid data") |
| | return int.from_bytes(data[:4], byteorder="big"), data[4:] |
| |
|
| |
|
| | def _get_u64(data: memoryview) -> tuple[int, memoryview]: |
| | """Uint64""" |
| | if len(data) < 8: |
| | raise ValueError("Invalid data") |
| | return int.from_bytes(data[:8], byteorder="big"), data[8:] |
| |
|
| |
|
| | def _get_sshstr(data: memoryview) -> tuple[memoryview, memoryview]: |
| | """Bytes with u32 length prefix""" |
| | n, data = _get_u32(data) |
| | if n > len(data): |
| | raise ValueError("Invalid data") |
| | return data[:n], data[n:] |
| |
|
| |
|
| | def _get_mpint(data: memoryview) -> tuple[int, memoryview]: |
| | """Big integer.""" |
| | val, data = _get_sshstr(data) |
| | if val and val[0] > 0x7F: |
| | raise ValueError("Invalid data") |
| | return int.from_bytes(val, "big"), data |
| |
|
| |
|
| | def _to_mpint(val: int) -> bytes: |
| | """Storage format for signed bigint.""" |
| | if val < 0: |
| | raise ValueError("negative mpint not allowed") |
| | if not val: |
| | return b"" |
| | nbytes = (val.bit_length() + 8) // 8 |
| | return utils.int_to_bytes(val, nbytes) |
| |
|
| |
|
| | class _FragList: |
| | """Build recursive structure without data copy.""" |
| |
|
| | flist: list[utils.Buffer] |
| |
|
| | def __init__(self, init: list[utils.Buffer] | None = None) -> None: |
| | self.flist = [] |
| | if init: |
| | self.flist.extend(init) |
| |
|
| | def put_raw(self, val: utils.Buffer) -> None: |
| | """Add plain bytes""" |
| | self.flist.append(val) |
| |
|
| | def put_u32(self, val: int) -> None: |
| | """Big-endian uint32""" |
| | self.flist.append(val.to_bytes(length=4, byteorder="big")) |
| |
|
| | def put_u64(self, val: int) -> None: |
| | """Big-endian uint64""" |
| | self.flist.append(val.to_bytes(length=8, byteorder="big")) |
| |
|
| | def put_sshstr(self, val: bytes | _FragList) -> None: |
| | """Bytes prefixed with u32 length""" |
| | if isinstance(val, (bytes, memoryview, bytearray)): |
| | self.put_u32(len(val)) |
| | self.flist.append(val) |
| | else: |
| | self.put_u32(val.size()) |
| | self.flist.extend(val.flist) |
| |
|
| | def put_mpint(self, val: int) -> None: |
| | """Big-endian bigint prefixed with u32 length""" |
| | self.put_sshstr(_to_mpint(val)) |
| |
|
| | def size(self) -> int: |
| | """Current number of bytes""" |
| | return sum(map(len, self.flist)) |
| |
|
| | def render(self, dstbuf: memoryview, pos: int = 0) -> int: |
| | """Write into bytearray""" |
| | for frag in self.flist: |
| | flen = len(frag) |
| | start, pos = pos, pos + flen |
| | dstbuf[start:pos] = frag |
| | return pos |
| |
|
| | def tobytes(self) -> bytes: |
| | """Return as bytes""" |
| | buf = memoryview(bytearray(self.size())) |
| | self.render(buf) |
| | return buf.tobytes() |
| |
|
| |
|
| | class _SSHFormatRSA: |
| | """Format for RSA keys. |
| | |
| | Public: |
| | mpint e, n |
| | Private: |
| | mpint n, e, d, iqmp, p, q |
| | """ |
| |
|
| | def get_public( |
| | self, data: memoryview |
| | ) -> tuple[tuple[int, int], memoryview]: |
| | """RSA public fields""" |
| | e, data = _get_mpint(data) |
| | n, data = _get_mpint(data) |
| | return (e, n), data |
| |
|
| | def load_public( |
| | self, data: memoryview |
| | ) -> tuple[rsa.RSAPublicKey, memoryview]: |
| | """Make RSA public key from data.""" |
| | (e, n), data = self.get_public(data) |
| | public_numbers = rsa.RSAPublicNumbers(e, n) |
| | public_key = public_numbers.public_key() |
| | return public_key, data |
| |
|
| | def load_private( |
| | self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool |
| | ) -> tuple[rsa.RSAPrivateKey, memoryview]: |
| | """Make RSA private key from data.""" |
| | n, data = _get_mpint(data) |
| | e, data = _get_mpint(data) |
| | d, data = _get_mpint(data) |
| | iqmp, data = _get_mpint(data) |
| | p, data = _get_mpint(data) |
| | q, data = _get_mpint(data) |
| |
|
| | if (e, n) != pubfields: |
| | raise ValueError("Corrupt data: rsa field mismatch") |
| | dmp1 = rsa.rsa_crt_dmp1(d, p) |
| | dmq1 = rsa.rsa_crt_dmq1(d, q) |
| | public_numbers = rsa.RSAPublicNumbers(e, n) |
| | private_numbers = rsa.RSAPrivateNumbers( |
| | p, q, d, dmp1, dmq1, iqmp, public_numbers |
| | ) |
| | private_key = private_numbers.private_key( |
| | unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation |
| | ) |
| | return private_key, data |
| |
|
| | def encode_public( |
| | self, public_key: rsa.RSAPublicKey, f_pub: _FragList |
| | ) -> None: |
| | """Write RSA public key""" |
| | pubn = public_key.public_numbers() |
| | f_pub.put_mpint(pubn.e) |
| | f_pub.put_mpint(pubn.n) |
| |
|
| | def encode_private( |
| | self, private_key: rsa.RSAPrivateKey, f_priv: _FragList |
| | ) -> None: |
| | """Write RSA private key""" |
| | private_numbers = private_key.private_numbers() |
| | public_numbers = private_numbers.public_numbers |
| |
|
| | f_priv.put_mpint(public_numbers.n) |
| | f_priv.put_mpint(public_numbers.e) |
| |
|
| | f_priv.put_mpint(private_numbers.d) |
| | f_priv.put_mpint(private_numbers.iqmp) |
| | f_priv.put_mpint(private_numbers.p) |
| | f_priv.put_mpint(private_numbers.q) |
| |
|
| |
|
| | class _SSHFormatDSA: |
| | """Format for DSA keys. |
| | |
| | Public: |
| | mpint p, q, g, y |
| | Private: |
| | mpint p, q, g, y, x |
| | """ |
| |
|
| | def get_public(self, data: memoryview) -> tuple[tuple, memoryview]: |
| | """DSA public fields""" |
| | p, data = _get_mpint(data) |
| | q, data = _get_mpint(data) |
| | g, data = _get_mpint(data) |
| | y, data = _get_mpint(data) |
| | return (p, q, g, y), data |
| |
|
| | def load_public( |
| | self, data: memoryview |
| | ) -> tuple[dsa.DSAPublicKey, memoryview]: |
| | """Make DSA public key from data.""" |
| | (p, q, g, y), data = self.get_public(data) |
| | parameter_numbers = dsa.DSAParameterNumbers(p, q, g) |
| | public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers) |
| | self._validate(public_numbers) |
| | public_key = public_numbers.public_key() |
| | return public_key, data |
| |
|
| | def load_private( |
| | self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool |
| | ) -> tuple[dsa.DSAPrivateKey, memoryview]: |
| | """Make DSA private key from data.""" |
| | (p, q, g, y), data = self.get_public(data) |
| | x, data = _get_mpint(data) |
| |
|
| | if (p, q, g, y) != pubfields: |
| | raise ValueError("Corrupt data: dsa field mismatch") |
| | parameter_numbers = dsa.DSAParameterNumbers(p, q, g) |
| | public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers) |
| | self._validate(public_numbers) |
| | private_numbers = dsa.DSAPrivateNumbers(x, public_numbers) |
| | private_key = private_numbers.private_key() |
| | return private_key, data |
| |
|
| | def encode_public( |
| | self, public_key: dsa.DSAPublicKey, f_pub: _FragList |
| | ) -> None: |
| | """Write DSA public key""" |
| | public_numbers = public_key.public_numbers() |
| | parameter_numbers = public_numbers.parameter_numbers |
| | self._validate(public_numbers) |
| |
|
| | f_pub.put_mpint(parameter_numbers.p) |
| | f_pub.put_mpint(parameter_numbers.q) |
| | f_pub.put_mpint(parameter_numbers.g) |
| | f_pub.put_mpint(public_numbers.y) |
| |
|
| | def encode_private( |
| | self, private_key: dsa.DSAPrivateKey, f_priv: _FragList |
| | ) -> None: |
| | """Write DSA private key""" |
| | self.encode_public(private_key.public_key(), f_priv) |
| | f_priv.put_mpint(private_key.private_numbers().x) |
| |
|
| | def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None: |
| | parameter_numbers = public_numbers.parameter_numbers |
| | if parameter_numbers.p.bit_length() != 1024: |
| | raise ValueError("SSH supports only 1024 bit DSA keys") |
| |
|
| |
|
| | class _SSHFormatECDSA: |
| | """Format for ECDSA keys. |
| | |
| | Public: |
| | str curve |
| | bytes point |
| | Private: |
| | str curve |
| | bytes point |
| | mpint secret |
| | """ |
| |
|
| | def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve): |
| | self.ssh_curve_name = ssh_curve_name |
| | self.curve = curve |
| |
|
| | def get_public( |
| | self, data: memoryview |
| | ) -> tuple[tuple[memoryview, memoryview], memoryview]: |
| | """ECDSA public fields""" |
| | curve, data = _get_sshstr(data) |
| | point, data = _get_sshstr(data) |
| | if curve != self.ssh_curve_name: |
| | raise ValueError("Curve name mismatch") |
| | if point[0] != 4: |
| | raise NotImplementedError("Need uncompressed point") |
| | return (curve, point), data |
| |
|
| | def load_public( |
| | self, data: memoryview |
| | ) -> tuple[ec.EllipticCurvePublicKey, memoryview]: |
| | """Make ECDSA public key from data.""" |
| | (_, point), data = self.get_public(data) |
| | public_key = ec.EllipticCurvePublicKey.from_encoded_point( |
| | self.curve, point.tobytes() |
| | ) |
| | return public_key, data |
| |
|
| | def load_private( |
| | self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool |
| | ) -> tuple[ec.EllipticCurvePrivateKey, memoryview]: |
| | """Make ECDSA private key from data.""" |
| | (curve_name, point), data = self.get_public(data) |
| | secret, data = _get_mpint(data) |
| |
|
| | if (curve_name, point) != pubfields: |
| | raise ValueError("Corrupt data: ecdsa field mismatch") |
| | private_key = ec.derive_private_key(secret, self.curve) |
| | return private_key, data |
| |
|
| | def encode_public( |
| | self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList |
| | ) -> None: |
| | """Write ECDSA public key""" |
| | point = public_key.public_bytes( |
| | Encoding.X962, PublicFormat.UncompressedPoint |
| | ) |
| | f_pub.put_sshstr(self.ssh_curve_name) |
| | f_pub.put_sshstr(point) |
| |
|
| | def encode_private( |
| | self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList |
| | ) -> None: |
| | """Write ECDSA private key""" |
| | public_key = private_key.public_key() |
| | private_numbers = private_key.private_numbers() |
| |
|
| | self.encode_public(public_key, f_priv) |
| | f_priv.put_mpint(private_numbers.private_value) |
| |
|
| |
|
| | class _SSHFormatEd25519: |
| | """Format for Ed25519 keys. |
| | |
| | Public: |
| | bytes point |
| | Private: |
| | bytes point |
| | bytes secret_and_point |
| | """ |
| |
|
| | def get_public( |
| | self, data: memoryview |
| | ) -> tuple[tuple[memoryview], memoryview]: |
| | """Ed25519 public fields""" |
| | point, data = _get_sshstr(data) |
| | return (point,), data |
| |
|
| | def load_public( |
| | self, data: memoryview |
| | ) -> tuple[ed25519.Ed25519PublicKey, memoryview]: |
| | """Make Ed25519 public key from data.""" |
| | (point,), data = self.get_public(data) |
| | public_key = ed25519.Ed25519PublicKey.from_public_bytes( |
| | point.tobytes() |
| | ) |
| | return public_key, data |
| |
|
| | def load_private( |
| | self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool |
| | ) -> tuple[ed25519.Ed25519PrivateKey, memoryview]: |
| | """Make Ed25519 private key from data.""" |
| | (point,), data = self.get_public(data) |
| | keypair, data = _get_sshstr(data) |
| |
|
| | secret = keypair[:32] |
| | point2 = keypair[32:] |
| | if point != point2 or (point,) != pubfields: |
| | raise ValueError("Corrupt data: ed25519 field mismatch") |
| | private_key = ed25519.Ed25519PrivateKey.from_private_bytes(secret) |
| | return private_key, data |
| |
|
| | def encode_public( |
| | self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList |
| | ) -> None: |
| | """Write Ed25519 public key""" |
| | raw_public_key = public_key.public_bytes( |
| | Encoding.Raw, PublicFormat.Raw |
| | ) |
| | f_pub.put_sshstr(raw_public_key) |
| |
|
| | def encode_private( |
| | self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList |
| | ) -> None: |
| | """Write Ed25519 private key""" |
| | public_key = private_key.public_key() |
| | raw_private_key = private_key.private_bytes( |
| | Encoding.Raw, PrivateFormat.Raw, NoEncryption() |
| | ) |
| | raw_public_key = public_key.public_bytes( |
| | Encoding.Raw, PublicFormat.Raw |
| | ) |
| | f_keypair = _FragList([raw_private_key, raw_public_key]) |
| |
|
| | self.encode_public(public_key, f_priv) |
| | f_priv.put_sshstr(f_keypair) |
| |
|
| |
|
| | def load_application(data) -> tuple[memoryview, memoryview]: |
| | """ |
| | U2F application strings |
| | """ |
| | application, data = _get_sshstr(data) |
| | if not application.tobytes().startswith(b"ssh:"): |
| | raise ValueError( |
| | "U2F application string does not start with b'ssh:' " |
| | f"({application})" |
| | ) |
| | return application, data |
| |
|
| |
|
| | class _SSHFormatSKEd25519: |
| | """ |
| | The format of a sk-ssh-ed25519@openssh.com public key is: |
| | |
| | string "sk-ssh-ed25519@openssh.com" |
| | string public key |
| | string application (user-specified, but typically "ssh:") |
| | """ |
| |
|
| | def load_public( |
| | self, data: memoryview |
| | ) -> tuple[ed25519.Ed25519PublicKey, memoryview]: |
| | """Make Ed25519 public key from data.""" |
| | public_key, data = _lookup_kformat(_SSH_ED25519).load_public(data) |
| | _, data = load_application(data) |
| | return public_key, data |
| |
|
| | def get_public(self, data: memoryview) -> typing.NoReturn: |
| | |
| | |
| | raise UnsupportedAlgorithm( |
| | "sk-ssh-ed25519 private keys cannot be loaded" |
| | ) |
| |
|
| |
|
| | class _SSHFormatSKECDSA: |
| | """ |
| | The format of a sk-ecdsa-sha2-nistp256@openssh.com public key is: |
| | |
| | string "sk-ecdsa-sha2-nistp256@openssh.com" |
| | string curve name |
| | ec_point Q |
| | string application (user-specified, but typically "ssh:") |
| | """ |
| |
|
| | def load_public( |
| | self, data: memoryview |
| | ) -> tuple[ec.EllipticCurvePublicKey, memoryview]: |
| | """Make ECDSA public key from data.""" |
| | public_key, data = _lookup_kformat(_ECDSA_NISTP256).load_public(data) |
| | _, data = load_application(data) |
| | return public_key, data |
| |
|
| | def get_public(self, data: memoryview) -> typing.NoReturn: |
| | |
| | |
| | raise UnsupportedAlgorithm( |
| | "sk-ecdsa-sha2-nistp256 private keys cannot be loaded" |
| | ) |
| |
|
| |
|
| | _KEY_FORMATS = { |
| | _SSH_RSA: _SSHFormatRSA(), |
| | _SSH_DSA: _SSHFormatDSA(), |
| | _SSH_ED25519: _SSHFormatEd25519(), |
| | _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()), |
| | _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()), |
| | _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()), |
| | _SK_SSH_ED25519: _SSHFormatSKEd25519(), |
| | _SK_SSH_ECDSA_NISTP256: _SSHFormatSKECDSA(), |
| | } |
| |
|
| |
|
| | def _lookup_kformat(key_type: utils.Buffer): |
| | """Return valid format or throw error""" |
| | if not isinstance(key_type, bytes): |
| | key_type = memoryview(key_type).tobytes() |
| | if key_type in _KEY_FORMATS: |
| | return _KEY_FORMATS[key_type] |
| | raise UnsupportedAlgorithm(f"Unsupported key type: {key_type!r}") |
| |
|
| |
|
| | SSHPrivateKeyTypes = typing.Union[ |
| | ec.EllipticCurvePrivateKey, |
| | rsa.RSAPrivateKey, |
| | dsa.DSAPrivateKey, |
| | ed25519.Ed25519PrivateKey, |
| | ] |
| |
|
| |
|
| | def load_ssh_private_key( |
| | data: utils.Buffer, |
| | password: bytes | None, |
| | backend: typing.Any = None, |
| | *, |
| | unsafe_skip_rsa_key_validation: bool = False, |
| | ) -> SSHPrivateKeyTypes: |
| | """Load private key from OpenSSH custom encoding.""" |
| | utils._check_byteslike("data", data) |
| | if password is not None: |
| | utils._check_bytes("password", password) |
| |
|
| | m = _PEM_RC.search(data) |
| | if not m: |
| | raise ValueError("Not OpenSSH private key format") |
| | p1 = m.start(1) |
| | p2 = m.end(1) |
| | data = binascii.a2b_base64(memoryview(data)[p1:p2]) |
| | if not data.startswith(_SK_MAGIC): |
| | raise ValueError("Not OpenSSH private key format") |
| | data = memoryview(data)[len(_SK_MAGIC) :] |
| |
|
| | |
| | ciphername, data = _get_sshstr(data) |
| | kdfname, data = _get_sshstr(data) |
| | kdfoptions, data = _get_sshstr(data) |
| | nkeys, data = _get_u32(data) |
| | if nkeys != 1: |
| | raise ValueError("Only one key supported") |
| |
|
| | |
| | pubdata, data = _get_sshstr(data) |
| | pub_key_type, pubdata = _get_sshstr(pubdata) |
| | kformat = _lookup_kformat(pub_key_type) |
| | pubfields, pubdata = kformat.get_public(pubdata) |
| | _check_empty(pubdata) |
| |
|
| | if ciphername != _NONE or kdfname != _NONE: |
| | ciphername_bytes = ciphername.tobytes() |
| | if ciphername_bytes not in _SSH_CIPHERS: |
| | raise UnsupportedAlgorithm( |
| | f"Unsupported cipher: {ciphername_bytes!r}" |
| | ) |
| | if kdfname != _BCRYPT: |
| | raise UnsupportedAlgorithm(f"Unsupported KDF: {kdfname!r}") |
| | blklen = _SSH_CIPHERS[ciphername_bytes].block_len |
| | tag_len = _SSH_CIPHERS[ciphername_bytes].tag_len |
| | |
| | edata, data = _get_sshstr(data) |
| | |
| | |
| | if _SSH_CIPHERS[ciphername_bytes].is_aead: |
| | tag = bytes(data) |
| | if len(tag) != tag_len: |
| | raise ValueError("Corrupt data: invalid tag length for cipher") |
| | else: |
| | _check_empty(data) |
| | _check_block_size(edata, blklen) |
| | salt, kbuf = _get_sshstr(kdfoptions) |
| | rounds, kbuf = _get_u32(kbuf) |
| | _check_empty(kbuf) |
| | ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds) |
| | dec = ciph.decryptor() |
| | edata = memoryview(dec.update(edata)) |
| | if _SSH_CIPHERS[ciphername_bytes].is_aead: |
| | assert isinstance(dec, AEADDecryptionContext) |
| | _check_empty(dec.finalize_with_tag(tag)) |
| | else: |
| | |
| | |
| | _check_empty(dec.finalize()) |
| | else: |
| | if password: |
| | raise TypeError( |
| | "Password was given but private key is not encrypted." |
| | ) |
| | |
| | edata, data = _get_sshstr(data) |
| | _check_empty(data) |
| | blklen = 8 |
| | _check_block_size(edata, blklen) |
| | ck1, edata = _get_u32(edata) |
| | ck2, edata = _get_u32(edata) |
| | if ck1 != ck2: |
| | raise ValueError("Corrupt data: broken checksum") |
| |
|
| | |
| | key_type, edata = _get_sshstr(edata) |
| | if key_type != pub_key_type: |
| | raise ValueError("Corrupt data: key type mismatch") |
| | private_key, edata = kformat.load_private( |
| | edata, |
| | pubfields, |
| | unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation, |
| | ) |
| | |
| | _, edata = _get_sshstr(edata) |
| |
|
| | |
| | |
| | if edata != _PADDING[: len(edata)]: |
| | raise ValueError("Corrupt data: invalid padding") |
| |
|
| | if isinstance(private_key, dsa.DSAPrivateKey): |
| | warnings.warn( |
| | "SSH DSA keys are deprecated and will be removed in a future " |
| | "release.", |
| | utils.DeprecatedIn40, |
| | stacklevel=2, |
| | ) |
| |
|
| | return private_key |
| |
|
| |
|
| | def _serialize_ssh_private_key( |
| | private_key: SSHPrivateKeyTypes, |
| | password: bytes, |
| | encryption_algorithm: KeySerializationEncryption, |
| | ) -> bytes: |
| | """Serialize private key with OpenSSH custom encoding.""" |
| | utils._check_bytes("password", password) |
| | if isinstance(private_key, dsa.DSAPrivateKey): |
| | warnings.warn( |
| | "SSH DSA key support is deprecated and will be " |
| | "removed in a future release", |
| | utils.DeprecatedIn40, |
| | stacklevel=4, |
| | ) |
| |
|
| | key_type = _get_ssh_key_type(private_key) |
| | kformat = _lookup_kformat(key_type) |
| |
|
| | |
| | f_kdfoptions = _FragList() |
| | if password: |
| | ciphername = _DEFAULT_CIPHER |
| | blklen = _SSH_CIPHERS[ciphername].block_len |
| | kdfname = _BCRYPT |
| | rounds = _DEFAULT_ROUNDS |
| | if ( |
| | isinstance(encryption_algorithm, _KeySerializationEncryption) |
| | and encryption_algorithm._kdf_rounds is not None |
| | ): |
| | rounds = encryption_algorithm._kdf_rounds |
| | salt = os.urandom(16) |
| | f_kdfoptions.put_sshstr(salt) |
| | f_kdfoptions.put_u32(rounds) |
| | ciph = _init_cipher(ciphername, password, salt, rounds) |
| | else: |
| | ciphername = kdfname = _NONE |
| | blklen = 8 |
| | ciph = None |
| | nkeys = 1 |
| | checkval = os.urandom(4) |
| | comment = b"" |
| |
|
| | |
| | f_public_key = _FragList() |
| | f_public_key.put_sshstr(key_type) |
| | kformat.encode_public(private_key.public_key(), f_public_key) |
| |
|
| | f_secrets = _FragList([checkval, checkval]) |
| | f_secrets.put_sshstr(key_type) |
| | kformat.encode_private(private_key, f_secrets) |
| | f_secrets.put_sshstr(comment) |
| | f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)]) |
| |
|
| | |
| | f_main = _FragList() |
| | f_main.put_raw(_SK_MAGIC) |
| | f_main.put_sshstr(ciphername) |
| | f_main.put_sshstr(kdfname) |
| | f_main.put_sshstr(f_kdfoptions) |
| | f_main.put_u32(nkeys) |
| | f_main.put_sshstr(f_public_key) |
| | f_main.put_sshstr(f_secrets) |
| |
|
| | |
| | slen = f_secrets.size() |
| | mlen = f_main.size() |
| | buf = memoryview(bytearray(mlen + blklen)) |
| | f_main.render(buf) |
| | ofs = mlen - slen |
| |
|
| | |
| | if ciph is not None: |
| | ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:]) |
| |
|
| | return _ssh_pem_encode(buf[:mlen]) |
| |
|
| |
|
| | SSHPublicKeyTypes = typing.Union[ |
| | ec.EllipticCurvePublicKey, |
| | rsa.RSAPublicKey, |
| | dsa.DSAPublicKey, |
| | ed25519.Ed25519PublicKey, |
| | ] |
| |
|
| | SSHCertPublicKeyTypes = typing.Union[ |
| | ec.EllipticCurvePublicKey, |
| | rsa.RSAPublicKey, |
| | ed25519.Ed25519PublicKey, |
| | ] |
| |
|
| |
|
| | class SSHCertificateType(enum.Enum): |
| | USER = 1 |
| | HOST = 2 |
| |
|
| |
|
| | class SSHCertificate: |
| | def __init__( |
| | self, |
| | _nonce: memoryview, |
| | _public_key: SSHPublicKeyTypes, |
| | _serial: int, |
| | _cctype: int, |
| | _key_id: memoryview, |
| | _valid_principals: list[bytes], |
| | _valid_after: int, |
| | _valid_before: int, |
| | _critical_options: dict[bytes, bytes], |
| | _extensions: dict[bytes, bytes], |
| | _sig_type: memoryview, |
| | _sig_key: memoryview, |
| | _inner_sig_type: memoryview, |
| | _signature: memoryview, |
| | _tbs_cert_body: memoryview, |
| | _cert_key_type: bytes, |
| | _cert_body: memoryview, |
| | ): |
| | self._nonce = _nonce |
| | self._public_key = _public_key |
| | self._serial = _serial |
| | try: |
| | self._type = SSHCertificateType(_cctype) |
| | except ValueError: |
| | raise ValueError("Invalid certificate type") |
| | self._key_id = _key_id |
| | self._valid_principals = _valid_principals |
| | self._valid_after = _valid_after |
| | self._valid_before = _valid_before |
| | self._critical_options = _critical_options |
| | self._extensions = _extensions |
| | self._sig_type = _sig_type |
| | self._sig_key = _sig_key |
| | self._inner_sig_type = _inner_sig_type |
| | self._signature = _signature |
| | self._cert_key_type = _cert_key_type |
| | self._cert_body = _cert_body |
| | self._tbs_cert_body = _tbs_cert_body |
| |
|
| | @property |
| | def nonce(self) -> bytes: |
| | return bytes(self._nonce) |
| |
|
| | def public_key(self) -> SSHCertPublicKeyTypes: |
| | |
| | |
| | return typing.cast(SSHCertPublicKeyTypes, self._public_key) |
| |
|
| | @property |
| | def serial(self) -> int: |
| | return self._serial |
| |
|
| | @property |
| | def type(self) -> SSHCertificateType: |
| | return self._type |
| |
|
| | @property |
| | def key_id(self) -> bytes: |
| | return bytes(self._key_id) |
| |
|
| | @property |
| | def valid_principals(self) -> list[bytes]: |
| | return self._valid_principals |
| |
|
| | @property |
| | def valid_before(self) -> int: |
| | return self._valid_before |
| |
|
| | @property |
| | def valid_after(self) -> int: |
| | return self._valid_after |
| |
|
| | @property |
| | def critical_options(self) -> dict[bytes, bytes]: |
| | return self._critical_options |
| |
|
| | @property |
| | def extensions(self) -> dict[bytes, bytes]: |
| | return self._extensions |
| |
|
| | def signature_key(self) -> SSHCertPublicKeyTypes: |
| | sigformat = _lookup_kformat(self._sig_type) |
| | signature_key, sigkey_rest = sigformat.load_public(self._sig_key) |
| | _check_empty(sigkey_rest) |
| | return signature_key |
| |
|
| | def public_bytes(self) -> bytes: |
| | return ( |
| | bytes(self._cert_key_type) |
| | + b" " |
| | + binascii.b2a_base64(bytes(self._cert_body), newline=False) |
| | ) |
| |
|
| | def verify_cert_signature(self) -> None: |
| | signature_key = self.signature_key() |
| | if isinstance(signature_key, ed25519.Ed25519PublicKey): |
| | signature_key.verify( |
| | bytes(self._signature), bytes(self._tbs_cert_body) |
| | ) |
| | elif isinstance(signature_key, ec.EllipticCurvePublicKey): |
| | |
| | r, data = _get_mpint(self._signature) |
| | s, data = _get_mpint(data) |
| | _check_empty(data) |
| | computed_sig = asym_utils.encode_dss_signature(r, s) |
| | hash_alg = _get_ec_hash_alg(signature_key.curve) |
| | signature_key.verify( |
| | computed_sig, bytes(self._tbs_cert_body), ec.ECDSA(hash_alg) |
| | ) |
| | else: |
| | assert isinstance(signature_key, rsa.RSAPublicKey) |
| | if self._inner_sig_type == _SSH_RSA: |
| | hash_alg = hashes.SHA1() |
| | elif self._inner_sig_type == _SSH_RSA_SHA256: |
| | hash_alg = hashes.SHA256() |
| | else: |
| | assert self._inner_sig_type == _SSH_RSA_SHA512 |
| | hash_alg = hashes.SHA512() |
| | signature_key.verify( |
| | bytes(self._signature), |
| | bytes(self._tbs_cert_body), |
| | padding.PKCS1v15(), |
| | hash_alg, |
| | ) |
| |
|
| |
|
| | def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm: |
| | if isinstance(curve, ec.SECP256R1): |
| | return hashes.SHA256() |
| | elif isinstance(curve, ec.SECP384R1): |
| | return hashes.SHA384() |
| | else: |
| | assert isinstance(curve, ec.SECP521R1) |
| | return hashes.SHA512() |
| |
|
| |
|
| | def _load_ssh_public_identity( |
| | data: utils.Buffer, |
| | _legacy_dsa_allowed=False, |
| | ) -> SSHCertificate | SSHPublicKeyTypes: |
| | utils._check_byteslike("data", data) |
| |
|
| | m = _SSH_PUBKEY_RC.match(data) |
| | if not m: |
| | raise ValueError("Invalid line format") |
| | key_type = orig_key_type = m.group(1) |
| | key_body = m.group(2) |
| | with_cert = False |
| | if key_type.endswith(_CERT_SUFFIX): |
| | with_cert = True |
| | key_type = key_type[: -len(_CERT_SUFFIX)] |
| | if key_type == _SSH_DSA and not _legacy_dsa_allowed: |
| | raise UnsupportedAlgorithm( |
| | "DSA keys aren't supported in SSH certificates" |
| | ) |
| | kformat = _lookup_kformat(key_type) |
| |
|
| | try: |
| | rest = memoryview(binascii.a2b_base64(key_body)) |
| | except (TypeError, binascii.Error): |
| | raise ValueError("Invalid format") |
| |
|
| | if with_cert: |
| | cert_body = rest |
| | inner_key_type, rest = _get_sshstr(rest) |
| | if inner_key_type != orig_key_type: |
| | raise ValueError("Invalid key format") |
| | if with_cert: |
| | nonce, rest = _get_sshstr(rest) |
| | public_key, rest = kformat.load_public(rest) |
| | if with_cert: |
| | serial, rest = _get_u64(rest) |
| | cctype, rest = _get_u32(rest) |
| | key_id, rest = _get_sshstr(rest) |
| | principals, rest = _get_sshstr(rest) |
| | valid_principals = [] |
| | while principals: |
| | principal, principals = _get_sshstr(principals) |
| | valid_principals.append(bytes(principal)) |
| | valid_after, rest = _get_u64(rest) |
| | valid_before, rest = _get_u64(rest) |
| | crit_options, rest = _get_sshstr(rest) |
| | critical_options = _parse_exts_opts(crit_options) |
| | exts, rest = _get_sshstr(rest) |
| | extensions = _parse_exts_opts(exts) |
| | |
| | _, rest = _get_sshstr(rest) |
| | sig_key_raw, rest = _get_sshstr(rest) |
| | sig_type, sig_key = _get_sshstr(sig_key_raw) |
| | if sig_type == _SSH_DSA and not _legacy_dsa_allowed: |
| | raise UnsupportedAlgorithm( |
| | "DSA signatures aren't supported in SSH certificates" |
| | ) |
| | |
| | tbs_cert_body = cert_body[: -len(rest)] |
| | signature_raw, rest = _get_sshstr(rest) |
| | _check_empty(rest) |
| | inner_sig_type, sig_rest = _get_sshstr(signature_raw) |
| | |
| | if ( |
| | sig_type == _SSH_RSA |
| | and inner_sig_type |
| | not in [_SSH_RSA_SHA256, _SSH_RSA_SHA512, _SSH_RSA] |
| | ) or (sig_type != _SSH_RSA and inner_sig_type != sig_type): |
| | raise ValueError("Signature key type does not match") |
| | signature, sig_rest = _get_sshstr(sig_rest) |
| | _check_empty(sig_rest) |
| | return SSHCertificate( |
| | nonce, |
| | public_key, |
| | serial, |
| | cctype, |
| | key_id, |
| | valid_principals, |
| | valid_after, |
| | valid_before, |
| | critical_options, |
| | extensions, |
| | sig_type, |
| | sig_key, |
| | inner_sig_type, |
| | signature, |
| | tbs_cert_body, |
| | orig_key_type, |
| | cert_body, |
| | ) |
| | else: |
| | _check_empty(rest) |
| | return public_key |
| |
|
| |
|
| | def load_ssh_public_identity( |
| | data: utils.Buffer, |
| | ) -> SSHCertificate | SSHPublicKeyTypes: |
| | return _load_ssh_public_identity(data) |
| |
|
| |
|
| | def _parse_exts_opts(exts_opts: memoryview) -> dict[bytes, bytes]: |
| | result: dict[bytes, bytes] = {} |
| | last_name = None |
| | while exts_opts: |
| | name, exts_opts = _get_sshstr(exts_opts) |
| | bname: bytes = bytes(name) |
| | if bname in result: |
| | raise ValueError("Duplicate name") |
| | if last_name is not None and bname < last_name: |
| | raise ValueError("Fields not lexically sorted") |
| | value, exts_opts = _get_sshstr(exts_opts) |
| | if len(value) > 0: |
| | value, extra = _get_sshstr(value) |
| | if len(extra) > 0: |
| | raise ValueError("Unexpected extra data after value") |
| | result[bname] = bytes(value) |
| | last_name = bname |
| | return result |
| |
|
| |
|
| | def ssh_key_fingerprint( |
| | key: SSHPublicKeyTypes, |
| | hash_algorithm: hashes.MD5 | hashes.SHA256, |
| | ) -> bytes: |
| | if not isinstance(hash_algorithm, (hashes.MD5, hashes.SHA256)): |
| | raise TypeError("hash_algorithm must be either MD5 or SHA256") |
| |
|
| | key_type = _get_ssh_key_type(key) |
| | kformat = _lookup_kformat(key_type) |
| |
|
| | f_pub = _FragList() |
| | f_pub.put_sshstr(key_type) |
| | kformat.encode_public(key, f_pub) |
| |
|
| | ssh_binary_data = f_pub.tobytes() |
| |
|
| | |
| | hash_obj = hashes.Hash(hash_algorithm) |
| | hash_obj.update(ssh_binary_data) |
| | return hash_obj.finalize() |
| |
|
| |
|
| | def load_ssh_public_key( |
| | data: utils.Buffer, backend: typing.Any = None |
| | ) -> SSHPublicKeyTypes: |
| | cert_or_key = _load_ssh_public_identity(data, _legacy_dsa_allowed=True) |
| | public_key: SSHPublicKeyTypes |
| | if isinstance(cert_or_key, SSHCertificate): |
| | public_key = cert_or_key.public_key() |
| | else: |
| | public_key = cert_or_key |
| |
|
| | if isinstance(public_key, dsa.DSAPublicKey): |
| | warnings.warn( |
| | "SSH DSA keys are deprecated and will be removed in a future " |
| | "release.", |
| | utils.DeprecatedIn40, |
| | stacklevel=2, |
| | ) |
| | return public_key |
| |
|
| |
|
| | def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes: |
| | """One-line public key format for OpenSSH""" |
| | if isinstance(public_key, dsa.DSAPublicKey): |
| | warnings.warn( |
| | "SSH DSA key support is deprecated and will be " |
| | "removed in a future release", |
| | utils.DeprecatedIn40, |
| | stacklevel=4, |
| | ) |
| | key_type = _get_ssh_key_type(public_key) |
| | kformat = _lookup_kformat(key_type) |
| |
|
| | f_pub = _FragList() |
| | f_pub.put_sshstr(key_type) |
| | kformat.encode_public(public_key, f_pub) |
| |
|
| | pub = binascii.b2a_base64(f_pub.tobytes()).strip() |
| | return b"".join([key_type, b" ", pub]) |
| |
|
| |
|
| | SSHCertPrivateKeyTypes = typing.Union[ |
| | ec.EllipticCurvePrivateKey, |
| | rsa.RSAPrivateKey, |
| | ed25519.Ed25519PrivateKey, |
| | ] |
| |
|
| |
|
| | |
| | |
| | _SSHKEY_CERT_MAX_PRINCIPALS = 256 |
| |
|
| |
|
| | class SSHCertificateBuilder: |
| | def __init__( |
| | self, |
| | _public_key: SSHCertPublicKeyTypes | None = None, |
| | _serial: int | None = None, |
| | _type: SSHCertificateType | None = None, |
| | _key_id: bytes | None = None, |
| | _valid_principals: list[bytes] = [], |
| | _valid_for_all_principals: bool = False, |
| | _valid_before: int | None = None, |
| | _valid_after: int | None = None, |
| | _critical_options: list[tuple[bytes, bytes]] = [], |
| | _extensions: list[tuple[bytes, bytes]] = [], |
| | ): |
| | self._public_key = _public_key |
| | self._serial = _serial |
| | self._type = _type |
| | self._key_id = _key_id |
| | self._valid_principals = _valid_principals |
| | self._valid_for_all_principals = _valid_for_all_principals |
| | self._valid_before = _valid_before |
| | self._valid_after = _valid_after |
| | self._critical_options = _critical_options |
| | self._extensions = _extensions |
| |
|
| | def public_key( |
| | self, public_key: SSHCertPublicKeyTypes |
| | ) -> SSHCertificateBuilder: |
| | if not isinstance( |
| | public_key, |
| | ( |
| | ec.EllipticCurvePublicKey, |
| | rsa.RSAPublicKey, |
| | ed25519.Ed25519PublicKey, |
| | ), |
| | ): |
| | raise TypeError("Unsupported key type") |
| | if self._public_key is not None: |
| | raise ValueError("public_key already set") |
| |
|
| | return SSHCertificateBuilder( |
| | _public_key=public_key, |
| | _serial=self._serial, |
| | _type=self._type, |
| | _key_id=self._key_id, |
| | _valid_principals=self._valid_principals, |
| | _valid_for_all_principals=self._valid_for_all_principals, |
| | _valid_before=self._valid_before, |
| | _valid_after=self._valid_after, |
| | _critical_options=self._critical_options, |
| | _extensions=self._extensions, |
| | ) |
| |
|
| | def serial(self, serial: int) -> SSHCertificateBuilder: |
| | if not isinstance(serial, int): |
| | raise TypeError("serial must be an integer") |
| | if not 0 <= serial < 2**64: |
| | raise ValueError("serial must be between 0 and 2**64") |
| | if self._serial is not None: |
| | raise ValueError("serial already set") |
| |
|
| | return SSHCertificateBuilder( |
| | _public_key=self._public_key, |
| | _serial=serial, |
| | _type=self._type, |
| | _key_id=self._key_id, |
| | _valid_principals=self._valid_principals, |
| | _valid_for_all_principals=self._valid_for_all_principals, |
| | _valid_before=self._valid_before, |
| | _valid_after=self._valid_after, |
| | _critical_options=self._critical_options, |
| | _extensions=self._extensions, |
| | ) |
| |
|
| | def type(self, type: SSHCertificateType) -> SSHCertificateBuilder: |
| | if not isinstance(type, SSHCertificateType): |
| | raise TypeError("type must be an SSHCertificateType") |
| | if self._type is not None: |
| | raise ValueError("type already set") |
| |
|
| | return SSHCertificateBuilder( |
| | _public_key=self._public_key, |
| | _serial=self._serial, |
| | _type=type, |
| | _key_id=self._key_id, |
| | _valid_principals=self._valid_principals, |
| | _valid_for_all_principals=self._valid_for_all_principals, |
| | _valid_before=self._valid_before, |
| | _valid_after=self._valid_after, |
| | _critical_options=self._critical_options, |
| | _extensions=self._extensions, |
| | ) |
| |
|
| | def key_id(self, key_id: bytes) -> SSHCertificateBuilder: |
| | if not isinstance(key_id, bytes): |
| | raise TypeError("key_id must be bytes") |
| | if self._key_id is not None: |
| | raise ValueError("key_id already set") |
| |
|
| | return SSHCertificateBuilder( |
| | _public_key=self._public_key, |
| | _serial=self._serial, |
| | _type=self._type, |
| | _key_id=key_id, |
| | _valid_principals=self._valid_principals, |
| | _valid_for_all_principals=self._valid_for_all_principals, |
| | _valid_before=self._valid_before, |
| | _valid_after=self._valid_after, |
| | _critical_options=self._critical_options, |
| | _extensions=self._extensions, |
| | ) |
| |
|
| | def valid_principals( |
| | self, valid_principals: list[bytes] |
| | ) -> SSHCertificateBuilder: |
| | if self._valid_for_all_principals: |
| | raise ValueError( |
| | "Principals can't be set because the cert is valid " |
| | "for all principals" |
| | ) |
| | if ( |
| | not all(isinstance(x, bytes) for x in valid_principals) |
| | or not valid_principals |
| | ): |
| | raise TypeError( |
| | "principals must be a list of bytes and can't be empty" |
| | ) |
| | if self._valid_principals: |
| | raise ValueError("valid_principals already set") |
| |
|
| | if len(valid_principals) > _SSHKEY_CERT_MAX_PRINCIPALS: |
| | raise ValueError( |
| | "Reached or exceeded the maximum number of valid_principals" |
| | ) |
| |
|
| | return SSHCertificateBuilder( |
| | _public_key=self._public_key, |
| | _serial=self._serial, |
| | _type=self._type, |
| | _key_id=self._key_id, |
| | _valid_principals=valid_principals, |
| | _valid_for_all_principals=self._valid_for_all_principals, |
| | _valid_before=self._valid_before, |
| | _valid_after=self._valid_after, |
| | _critical_options=self._critical_options, |
| | _extensions=self._extensions, |
| | ) |
| |
|
| | def valid_for_all_principals(self): |
| | if self._valid_principals: |
| | raise ValueError( |
| | "valid_principals already set, can't set " |
| | "valid_for_all_principals" |
| | ) |
| | if self._valid_for_all_principals: |
| | raise ValueError("valid_for_all_principals already set") |
| |
|
| | return SSHCertificateBuilder( |
| | _public_key=self._public_key, |
| | _serial=self._serial, |
| | _type=self._type, |
| | _key_id=self._key_id, |
| | _valid_principals=self._valid_principals, |
| | _valid_for_all_principals=True, |
| | _valid_before=self._valid_before, |
| | _valid_after=self._valid_after, |
| | _critical_options=self._critical_options, |
| | _extensions=self._extensions, |
| | ) |
| |
|
| | def valid_before(self, valid_before: int | float) -> SSHCertificateBuilder: |
| | if not isinstance(valid_before, (int, float)): |
| | raise TypeError("valid_before must be an int or float") |
| | valid_before = int(valid_before) |
| | if valid_before < 0 or valid_before >= 2**64: |
| | raise ValueError("valid_before must [0, 2**64)") |
| | if self._valid_before is not None: |
| | raise ValueError("valid_before already set") |
| |
|
| | return SSHCertificateBuilder( |
| | _public_key=self._public_key, |
| | _serial=self._serial, |
| | _type=self._type, |
| | _key_id=self._key_id, |
| | _valid_principals=self._valid_principals, |
| | _valid_for_all_principals=self._valid_for_all_principals, |
| | _valid_before=valid_before, |
| | _valid_after=self._valid_after, |
| | _critical_options=self._critical_options, |
| | _extensions=self._extensions, |
| | ) |
| |
|
| | def valid_after(self, valid_after: int | float) -> SSHCertificateBuilder: |
| | if not isinstance(valid_after, (int, float)): |
| | raise TypeError("valid_after must be an int or float") |
| | valid_after = int(valid_after) |
| | if valid_after < 0 or valid_after >= 2**64: |
| | raise ValueError("valid_after must [0, 2**64)") |
| | if self._valid_after is not None: |
| | raise ValueError("valid_after already set") |
| |
|
| | return SSHCertificateBuilder( |
| | _public_key=self._public_key, |
| | _serial=self._serial, |
| | _type=self._type, |
| | _key_id=self._key_id, |
| | _valid_principals=self._valid_principals, |
| | _valid_for_all_principals=self._valid_for_all_principals, |
| | _valid_before=self._valid_before, |
| | _valid_after=valid_after, |
| | _critical_options=self._critical_options, |
| | _extensions=self._extensions, |
| | ) |
| |
|
| | def add_critical_option( |
| | self, name: bytes, value: bytes |
| | ) -> SSHCertificateBuilder: |
| | if not isinstance(name, bytes) or not isinstance(value, bytes): |
| | raise TypeError("name and value must be bytes") |
| | |
| | if name in [name for name, _ in self._critical_options]: |
| | raise ValueError("Duplicate critical option name") |
| |
|
| | return SSHCertificateBuilder( |
| | _public_key=self._public_key, |
| | _serial=self._serial, |
| | _type=self._type, |
| | _key_id=self._key_id, |
| | _valid_principals=self._valid_principals, |
| | _valid_for_all_principals=self._valid_for_all_principals, |
| | _valid_before=self._valid_before, |
| | _valid_after=self._valid_after, |
| | _critical_options=[*self._critical_options, (name, value)], |
| | _extensions=self._extensions, |
| | ) |
| |
|
| | def add_extension( |
| | self, name: bytes, value: bytes |
| | ) -> SSHCertificateBuilder: |
| | if not isinstance(name, bytes) or not isinstance(value, bytes): |
| | raise TypeError("name and value must be bytes") |
| | |
| | if name in [name for name, _ in self._extensions]: |
| | raise ValueError("Duplicate extension name") |
| |
|
| | return SSHCertificateBuilder( |
| | _public_key=self._public_key, |
| | _serial=self._serial, |
| | _type=self._type, |
| | _key_id=self._key_id, |
| | _valid_principals=self._valid_principals, |
| | _valid_for_all_principals=self._valid_for_all_principals, |
| | _valid_before=self._valid_before, |
| | _valid_after=self._valid_after, |
| | _critical_options=self._critical_options, |
| | _extensions=[*self._extensions, (name, value)], |
| | ) |
| |
|
| | def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate: |
| | if not isinstance( |
| | private_key, |
| | ( |
| | ec.EllipticCurvePrivateKey, |
| | rsa.RSAPrivateKey, |
| | ed25519.Ed25519PrivateKey, |
| | ), |
| | ): |
| | raise TypeError("Unsupported private key type") |
| |
|
| | if self._public_key is None: |
| | raise ValueError("public_key must be set") |
| |
|
| | |
| | serial = 0 if self._serial is None else self._serial |
| |
|
| | if self._type is None: |
| | raise ValueError("type must be set") |
| |
|
| | |
| | key_id = b"" if self._key_id is None else self._key_id |
| |
|
| | |
| | |
| | |
| | |
| | if not self._valid_principals and not self._valid_for_all_principals: |
| | raise ValueError( |
| | "valid_principals must be set if valid_for_all_principals " |
| | "is False" |
| | ) |
| |
|
| | if self._valid_before is None: |
| | raise ValueError("valid_before must be set") |
| |
|
| | if self._valid_after is None: |
| | raise ValueError("valid_after must be set") |
| |
|
| | if self._valid_after > self._valid_before: |
| | raise ValueError("valid_after must be earlier than valid_before") |
| |
|
| | |
| | self._critical_options.sort(key=lambda x: x[0]) |
| | self._extensions.sort(key=lambda x: x[0]) |
| |
|
| | key_type = _get_ssh_key_type(self._public_key) |
| | cert_prefix = key_type + _CERT_SUFFIX |
| |
|
| | |
| | nonce = os.urandom(32) |
| | kformat = _lookup_kformat(key_type) |
| | f = _FragList() |
| | f.put_sshstr(cert_prefix) |
| | f.put_sshstr(nonce) |
| | kformat.encode_public(self._public_key, f) |
| | f.put_u64(serial) |
| | f.put_u32(self._type.value) |
| | f.put_sshstr(key_id) |
| | fprincipals = _FragList() |
| | for p in self._valid_principals: |
| | fprincipals.put_sshstr(p) |
| | f.put_sshstr(fprincipals.tobytes()) |
| | f.put_u64(self._valid_after) |
| | f.put_u64(self._valid_before) |
| | fcrit = _FragList() |
| | for name, value in self._critical_options: |
| | fcrit.put_sshstr(name) |
| | if len(value) > 0: |
| | foptval = _FragList() |
| | foptval.put_sshstr(value) |
| | fcrit.put_sshstr(foptval.tobytes()) |
| | else: |
| | fcrit.put_sshstr(value) |
| | f.put_sshstr(fcrit.tobytes()) |
| | fext = _FragList() |
| | for name, value in self._extensions: |
| | fext.put_sshstr(name) |
| | if len(value) > 0: |
| | fextval = _FragList() |
| | fextval.put_sshstr(value) |
| | fext.put_sshstr(fextval.tobytes()) |
| | else: |
| | fext.put_sshstr(value) |
| | f.put_sshstr(fext.tobytes()) |
| | f.put_sshstr(b"") |
| | |
| | ca_type = _get_ssh_key_type(private_key) |
| | caformat = _lookup_kformat(ca_type) |
| | caf = _FragList() |
| | caf.put_sshstr(ca_type) |
| | caformat.encode_public(private_key.public_key(), caf) |
| | f.put_sshstr(caf.tobytes()) |
| | |
| | |
| | |
| | if isinstance(private_key, ed25519.Ed25519PrivateKey): |
| | signature = private_key.sign(f.tobytes()) |
| | fsig = _FragList() |
| | fsig.put_sshstr(ca_type) |
| | fsig.put_sshstr(signature) |
| | f.put_sshstr(fsig.tobytes()) |
| | elif isinstance(private_key, ec.EllipticCurvePrivateKey): |
| | hash_alg = _get_ec_hash_alg(private_key.curve) |
| | signature = private_key.sign(f.tobytes(), ec.ECDSA(hash_alg)) |
| | r, s = asym_utils.decode_dss_signature(signature) |
| | fsig = _FragList() |
| | fsig.put_sshstr(ca_type) |
| | fsigblob = _FragList() |
| | fsigblob.put_mpint(r) |
| | fsigblob.put_mpint(s) |
| | fsig.put_sshstr(fsigblob.tobytes()) |
| | f.put_sshstr(fsig.tobytes()) |
| |
|
| | else: |
| | assert isinstance(private_key, rsa.RSAPrivateKey) |
| | |
| | |
| | |
| | |
| | fsig = _FragList() |
| | fsig.put_sshstr(_SSH_RSA_SHA512) |
| | signature = private_key.sign( |
| | f.tobytes(), padding.PKCS1v15(), hashes.SHA512() |
| | ) |
| | fsig.put_sshstr(signature) |
| | f.put_sshstr(fsig.tobytes()) |
| |
|
| | cert_data = binascii.b2a_base64(f.tobytes()).strip() |
| | |
| | |
| | |
| | return typing.cast( |
| | SSHCertificate, |
| | load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])), |
| | ) |
| |
|