Spaces:
Paused
Paused
| from cryptography.hazmat.primitives.asymmetric import rsa | |
| from cryptography.hazmat.primitives.asymmetric.rsa import ( | |
| RSAPublicKey, RSAPrivateKeyWithSerialization, | |
| RSAPrivateNumbers, RSAPublicNumbers, | |
| rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp | |
| ) | |
| from cryptography.hazmat.backends import default_backend | |
| from authlib.common.encoding import base64_to_int, int_to_base64 | |
| from ..rfc7517 import AsymmetricKey | |
class RSAKey(AsymmetricKey):
    """Key class of the ``RSA`` key type."""

    kty = 'RSA'
    PUBLIC_KEY_CLS = RSAPublicKey
    PRIVATE_KEY_CLS = RSAPrivateKeyWithSerialization

    # JWK member names used for public and private RSA keys.
    PUBLIC_KEY_FIELDS = ['e', 'n']
    PRIVATE_KEY_FIELDS = ['d', 'dp', 'dq', 'e', 'n', 'p', 'q', 'qi']
    REQUIRED_JSON_FIELDS = ['e', 'n']
    SSH_PUBLIC_PREFIX = b'ssh-rsa'

    def dumps_private_key(self):
        """Return the private key numbers as a dict of JWK members.

        Every integer is encoded with ``int_to_base64``. The result holds
        the public members ``n`` and ``e`` plus ``d``, ``p``, ``q``,
        ``dp``, ``dq`` and ``qi``.
        """
        numbers = self.private_key.private_numbers()
        return {
            'n': int_to_base64(numbers.public_numbers.n),
            'e': int_to_base64(numbers.public_numbers.e),
            'd': int_to_base64(numbers.d),
            'p': int_to_base64(numbers.p),
            'q': int_to_base64(numbers.q),
            'dp': int_to_base64(numbers.dmp1),
            'dq': int_to_base64(numbers.dmq1),
            'qi': int_to_base64(numbers.iqmp)
        }

    def dumps_public_key(self):
        """Return the public key numbers as a dict with ``n`` and ``e``.

        Both integers are encoded with ``int_to_base64``.
        """
        numbers = self.public_key.public_numbers()
        return {
            'n': int_to_base64(numbers.n),
            'e': int_to_base64(numbers.e)
        }

    def load_private_key(self):
        """Build a private key object from ``self._dict_data``.

        When the dict carries all of ``p``, ``q``, ``dp``, ``dq`` and
        ``qi`` they are used directly. When it carries none of them, the
        prime factors and CRT values are derived from ``n``, ``e`` and
        ``d``.

        :raises ValueError: if the dict contains ``oth``, or if only some
            of the prime-factor members are present.
        """
        obj = self._dict_data

        if 'oth' in obj:  # pragma: no cover
            # https://tools.ietf.org/html/rfc7518#section-6.3.2.7
            raise ValueError('"oth" is not supported yet')

        public_numbers = RSAPublicNumbers(
            base64_to_int(obj['e']), base64_to_int(obj['n']))

        if has_all_prime_factors(obj):
            numbers = RSAPrivateNumbers(
                d=base64_to_int(obj['d']),
                p=base64_to_int(obj['p']),
                q=base64_to_int(obj['q']),
                dmp1=base64_to_int(obj['dp']),
                dmq1=base64_to_int(obj['dq']),
                iqmp=base64_to_int(obj['qi']),
                public_numbers=public_numbers)
        else:
            d = base64_to_int(obj['d'])
            # Documented argument order is (n, e, d).
            p, q = rsa_recover_prime_factors(
                public_numbers.n, public_numbers.e, d)
            numbers = RSAPrivateNumbers(
                d=d,
                p=p,
                q=q,
                dmp1=rsa_crt_dmp1(d, p),
                dmq1=rsa_crt_dmq1(d, q),
                iqmp=rsa_crt_iqmp(p, q),
                public_numbers=public_numbers)

        return numbers.private_key(default_backend())

    def load_public_key(self):
        """Build a public key object from ``e`` and ``n`` in ``self._dict_data``."""
        numbers = RSAPublicNumbers(
            base64_to_int(self._dict_data['e']),
            base64_to_int(self._dict_data['n'])
        )
        return numbers.public_key(default_backend())

    @classmethod
    def generate_key(cls, key_size=2048, options=None, is_private=False) -> 'RSAKey':
        """Generate a new RSA key with public exponent 65537.

        :param key_size: key size in bits; at least 512 and a multiple of 8.
        :param options: passed through to ``cls.import_key``.
        :param is_private: when false (the default) only the public half
            of the generated key is imported.
        :return: the result of ``cls.import_key`` for the generated key.
        :raises ValueError: if ``key_size`` is below 512 or not a
            multiple of 8.
        """
        if key_size < 512:
            raise ValueError('key_size must not be less than 512')
        if key_size % 8 != 0:
            raise ValueError('Invalid key_size for RSAKey')
        raw_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
            backend=default_backend(),
        )
        if not is_private:
            raw_key = raw_key.public_key()
        return cls.import_key(raw_key, options=options)

    @classmethod
    def import_dict_key(cls, raw, options=None):
        """Create a key instance from a JWK dict.

        :param raw: dict of JWK members; checked with
            ``cls.check_required_fields``.
        :param options: passed to the class constructor.
        :return: a new instance whose ``_dict_data`` is ``raw``.
        :raises ValueError: if ``raw`` has ``d`` and only some of the
            prime-factor members.
        """
        cls.check_required_fields(raw)
        key = cls(options=options)
        key._dict_data = raw
        if 'd' in raw and not has_all_prime_factors(raw):
            # reload dict key
            # NOTE(review): presumably this fills in the missing CRT
            # members by round-tripping through the raw key - confirm
            # against AsymmetricKey.load_raw_key / load_dict_key.
            key.load_raw_key()
            key.load_dict_key()
        return key
def has_all_prime_factors(obj):
    """Tell whether ``obj`` carries the full set of RSA prime-factor members.

    The members checked are ``p``, ``q``, ``dp``, ``dq`` and ``qi``.

    :param obj: container supporting ``in`` membership tests.
    :return: ``True`` when all five members are present, ``False`` when
        none are.
    :raises ValueError: when only some of the members are present.
    """
    crt_fields = ('p', 'q', 'dp', 'dq', 'qi')
    present = sum(1 for name in crt_fields if name in obj)

    if present == len(crt_fields):
        return True
    if present == 0:
        return False

    # A partial set is ambiguous, so it is rejected outright.
    raise ValueError(
        'RSA key must include all parameters '
        'if any are present besides d')