davidtran999 commited on
Commit
d25a889
·
verified ·
1 Parent(s): 7b20742

Upload backend/venv/lib/python3.10/site-packages/jwt/algorithms.py with huggingface_hub

Browse files
backend/venv/lib/python3.10/site-packages/jwt/algorithms.py ADDED
@@ -0,0 +1,875 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import hmac
5
+ import json
6
+ from abc import ABC, abstractmethod
7
+ from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload
8
+
9
+ from .exceptions import InvalidKeyError
10
+ from .types import HashlibHash, JWKDict
11
+ from .utils import (
12
+ base64url_decode,
13
+ base64url_encode,
14
+ der_to_raw_signature,
15
+ force_bytes,
16
+ from_base64url_uint,
17
+ is_pem_format,
18
+ is_ssh_key,
19
+ raw_to_der_signature,
20
+ to_base64url_uint,
21
+ )
22
+
23
+ try:
24
+ from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
25
+ from cryptography.hazmat.backends import default_backend
26
+ from cryptography.hazmat.primitives import hashes
27
+ from cryptography.hazmat.primitives.asymmetric import padding
28
+ from cryptography.hazmat.primitives.asymmetric.ec import (
29
+ ECDSA,
30
+ SECP256K1,
31
+ SECP256R1,
32
+ SECP384R1,
33
+ SECP521R1,
34
+ EllipticCurve,
35
+ EllipticCurvePrivateKey,
36
+ EllipticCurvePrivateNumbers,
37
+ EllipticCurvePublicKey,
38
+ EllipticCurvePublicNumbers,
39
+ )
40
+ from cryptography.hazmat.primitives.asymmetric.ed448 import (
41
+ Ed448PrivateKey,
42
+ Ed448PublicKey,
43
+ )
44
+ from cryptography.hazmat.primitives.asymmetric.ed25519 import (
45
+ Ed25519PrivateKey,
46
+ Ed25519PublicKey,
47
+ )
48
+ from cryptography.hazmat.primitives.asymmetric.rsa import (
49
+ RSAPrivateKey,
50
+ RSAPrivateNumbers,
51
+ RSAPublicKey,
52
+ RSAPublicNumbers,
53
+ rsa_crt_dmp1,
54
+ rsa_crt_dmq1,
55
+ rsa_crt_iqmp,
56
+ rsa_recover_prime_factors,
57
+ )
58
+ from cryptography.hazmat.primitives.serialization import (
59
+ Encoding,
60
+ NoEncryption,
61
+ PrivateFormat,
62
+ PublicFormat,
63
+ load_pem_private_key,
64
+ load_pem_public_key,
65
+ load_ssh_public_key,
66
+ )
67
+
68
+ has_crypto = True
69
+ except ModuleNotFoundError:
70
+ has_crypto = False
71
+
72
+
73
+ if TYPE_CHECKING:
74
+ # Type aliases for convenience in algorithms method signatures
75
+ AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
76
+ AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
77
+ AllowedOKPKeys = (
78
+ Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
79
+ )
80
+ AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
81
+ AllowedPrivateKeys = (
82
+ RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
83
+ )
84
+ AllowedPublicKeys = (
85
+ RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
86
+ )
87
+
88
+
89
+ requires_cryptography = {
90
+ "RS256",
91
+ "RS384",
92
+ "RS512",
93
+ "ES256",
94
+ "ES256K",
95
+ "ES384",
96
+ "ES521",
97
+ "ES512",
98
+ "PS256",
99
+ "PS384",
100
+ "PS512",
101
+ "EdDSA",
102
+ }
103
+
104
+
105
+ def get_default_algorithms() -> dict[str, Algorithm]:
106
+ """
107
+ Returns the algorithms that are implemented by the library.
108
+ """
109
+ default_algorithms = {
110
+ "none": NoneAlgorithm(),
111
+ "HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
112
+ "HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
113
+ "HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
114
+ }
115
+
116
+ if has_crypto:
117
+ default_algorithms.update(
118
+ {
119
+ "RS256": RSAAlgorithm(RSAAlgorithm.SHA256),
120
+ "RS384": RSAAlgorithm(RSAAlgorithm.SHA384),
121
+ "RS512": RSAAlgorithm(RSAAlgorithm.SHA512),
122
+ "ES256": ECAlgorithm(ECAlgorithm.SHA256),
123
+ "ES256K": ECAlgorithm(ECAlgorithm.SHA256),
124
+ "ES384": ECAlgorithm(ECAlgorithm.SHA384),
125
+ "ES521": ECAlgorithm(ECAlgorithm.SHA512),
126
+ "ES512": ECAlgorithm(
127
+ ECAlgorithm.SHA512
128
+ ), # Backward compat for #219 fix
129
+ "PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
130
+ "PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
131
+ "PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
132
+ "EdDSA": OKPAlgorithm(),
133
+ }
134
+ )
135
+
136
+ return default_algorithms
137
+
138
+
139
+ class Algorithm(ABC):
140
+ """
141
+ The interface for an algorithm used to sign and verify tokens.
142
+ """
143
+
144
+ def compute_hash_digest(self, bytestr: bytes) -> bytes:
145
+ """
146
+ Compute a hash digest using the specified algorithm's hash algorithm.
147
+
148
+ If there is no hash algorithm, raises a NotImplementedError.
149
+ """
150
+ # lookup self.hash_alg if defined in a way that mypy can understand
151
+ hash_alg = getattr(self, "hash_alg", None)
152
+ if hash_alg is None:
153
+ raise NotImplementedError
154
+
155
+ if (
156
+ has_crypto
157
+ and isinstance(hash_alg, type)
158
+ and issubclass(hash_alg, hashes.HashAlgorithm)
159
+ ):
160
+ digest = hashes.Hash(hash_alg(), backend=default_backend())
161
+ digest.update(bytestr)
162
+ return bytes(digest.finalize())
163
+ else:
164
+ return bytes(hash_alg(bytestr).digest())
165
+
166
+ @abstractmethod
167
+ def prepare_key(self, key: Any) -> Any:
168
+ """
169
+ Performs necessary validation and conversions on the key and returns
170
+ the key value in the proper format for sign() and verify().
171
+ """
172
+
173
+ @abstractmethod
174
+ def sign(self, msg: bytes, key: Any) -> bytes:
175
+ """
176
+ Returns a digital signature for the specified message
177
+ using the specified key value.
178
+ """
179
+
180
+ @abstractmethod
181
+ def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
182
+ """
183
+ Verifies that the specified digital signature is valid
184
+ for the specified message and key values.
185
+ """
186
+
187
+ @overload
188
+ @staticmethod
189
+ @abstractmethod
190
+ def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ... # pragma: no cover
191
+
192
+ @overload
193
+ @staticmethod
194
+ @abstractmethod
195
+ def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ... # pragma: no cover
196
+
197
+ @staticmethod
198
+ @abstractmethod
199
+ def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str:
200
+ """
201
+ Serializes a given key into a JWK
202
+ """
203
+
204
+ @staticmethod
205
+ @abstractmethod
206
+ def from_jwk(jwk: str | JWKDict) -> Any:
207
+ """
208
+ Deserializes a given key from JWK back into a key object
209
+ """
210
+
211
+
212
+ class NoneAlgorithm(Algorithm):
213
+ """
214
+ Placeholder for use when no signing or verification
215
+ operations are required.
216
+ """
217
+
218
+ def prepare_key(self, key: str | None) -> None:
219
+ if key == "":
220
+ key = None
221
+
222
+ if key is not None:
223
+ raise InvalidKeyError('When alg = "none", key value must be None.')
224
+
225
+ return key
226
+
227
+ def sign(self, msg: bytes, key: None) -> bytes:
228
+ return b""
229
+
230
+ def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
231
+ return False
232
+
233
+ @staticmethod
234
+ def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
235
+ raise NotImplementedError()
236
+
237
+ @staticmethod
238
+ def from_jwk(jwk: str | JWKDict) -> NoReturn:
239
+ raise NotImplementedError()
240
+
241
+
242
+ class HMACAlgorithm(Algorithm):
243
+ """
244
+ Performs signing and verification operations using HMAC
245
+ and the specified hash function.
246
+ """
247
+
248
+ SHA256: ClassVar[HashlibHash] = hashlib.sha256
249
+ SHA384: ClassVar[HashlibHash] = hashlib.sha384
250
+ SHA512: ClassVar[HashlibHash] = hashlib.sha512
251
+
252
+ def __init__(self, hash_alg: HashlibHash) -> None:
253
+ self.hash_alg = hash_alg
254
+
255
+ def prepare_key(self, key: str | bytes) -> bytes:
256
+ key_bytes = force_bytes(key)
257
+
258
+ if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
259
+ raise InvalidKeyError(
260
+ "The specified key is an asymmetric key or x509 certificate and"
261
+ " should not be used as an HMAC secret."
262
+ )
263
+
264
+ return key_bytes
265
+
266
+ @overload
267
+ @staticmethod
268
+ def to_jwk(
269
+ key_obj: str | bytes, as_dict: Literal[True]
270
+ ) -> JWKDict: ... # pragma: no cover
271
+
272
+ @overload
273
+ @staticmethod
274
+ def to_jwk(
275
+ key_obj: str | bytes, as_dict: Literal[False] = False
276
+ ) -> str: ... # pragma: no cover
277
+
278
+ @staticmethod
279
+ def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
280
+ jwk = {
281
+ "k": base64url_encode(force_bytes(key_obj)).decode(),
282
+ "kty": "oct",
283
+ }
284
+
285
+ if as_dict:
286
+ return jwk
287
+ else:
288
+ return json.dumps(jwk)
289
+
290
+ @staticmethod
291
+ def from_jwk(jwk: str | JWKDict) -> bytes:
292
+ try:
293
+ if isinstance(jwk, str):
294
+ obj: JWKDict = json.loads(jwk)
295
+ elif isinstance(jwk, dict):
296
+ obj = jwk
297
+ else:
298
+ raise ValueError
299
+ except ValueError:
300
+ raise InvalidKeyError("Key is not valid JSON") from None
301
+
302
+ if obj.get("kty") != "oct":
303
+ raise InvalidKeyError("Not an HMAC key")
304
+
305
+ return base64url_decode(obj["k"])
306
+
307
+ def sign(self, msg: bytes, key: bytes) -> bytes:
308
+ return hmac.new(key, msg, self.hash_alg).digest()
309
+
310
+ def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
311
+ return hmac.compare_digest(sig, self.sign(msg, key))
312
+
313
+
314
+ if has_crypto:
315
+
316
+ class RSAAlgorithm(Algorithm):
317
+ """
318
+ Performs signing and verification operations using
319
+ RSASSA-PKCS-v1_5 and the specified hash function.
320
+ """
321
+
322
+ SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
323
+ SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
324
+ SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
325
+
326
+ def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
327
+ self.hash_alg = hash_alg
328
+
329
+ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
330
+ if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
331
+ return key
332
+
333
+ if not isinstance(key, (bytes, str)):
334
+ raise TypeError("Expecting a PEM-formatted key.")
335
+
336
+ key_bytes = force_bytes(key)
337
+
338
+ try:
339
+ if key_bytes.startswith(b"ssh-rsa"):
340
+ return cast(RSAPublicKey, load_ssh_public_key(key_bytes))
341
+ else:
342
+ return cast(
343
+ RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
344
+ )
345
+ except ValueError:
346
+ try:
347
+ return cast(RSAPublicKey, load_pem_public_key(key_bytes))
348
+ except (ValueError, UnsupportedAlgorithm):
349
+ raise InvalidKeyError(
350
+ "Could not parse the provided public key."
351
+ ) from None
352
+
353
+ @overload
354
+ @staticmethod
355
+ def to_jwk(
356
+ key_obj: AllowedRSAKeys, as_dict: Literal[True]
357
+ ) -> JWKDict: ... # pragma: no cover
358
+
359
+ @overload
360
+ @staticmethod
361
+ def to_jwk(
362
+ key_obj: AllowedRSAKeys, as_dict: Literal[False] = False
363
+ ) -> str: ... # pragma: no cover
364
+
365
+ @staticmethod
366
+ def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
367
+ obj: dict[str, Any] | None = None
368
+
369
+ if hasattr(key_obj, "private_numbers"):
370
+ # Private key
371
+ numbers = key_obj.private_numbers()
372
+
373
+ obj = {
374
+ "kty": "RSA",
375
+ "key_ops": ["sign"],
376
+ "n": to_base64url_uint(numbers.public_numbers.n).decode(),
377
+ "e": to_base64url_uint(numbers.public_numbers.e).decode(),
378
+ "d": to_base64url_uint(numbers.d).decode(),
379
+ "p": to_base64url_uint(numbers.p).decode(),
380
+ "q": to_base64url_uint(numbers.q).decode(),
381
+ "dp": to_base64url_uint(numbers.dmp1).decode(),
382
+ "dq": to_base64url_uint(numbers.dmq1).decode(),
383
+ "qi": to_base64url_uint(numbers.iqmp).decode(),
384
+ }
385
+
386
+ elif hasattr(key_obj, "verify"):
387
+ # Public key
388
+ numbers = key_obj.public_numbers()
389
+
390
+ obj = {
391
+ "kty": "RSA",
392
+ "key_ops": ["verify"],
393
+ "n": to_base64url_uint(numbers.n).decode(),
394
+ "e": to_base64url_uint(numbers.e).decode(),
395
+ }
396
+ else:
397
+ raise InvalidKeyError("Not a public or private key")
398
+
399
+ if as_dict:
400
+ return obj
401
+ else:
402
+ return json.dumps(obj)
403
+
404
+ @staticmethod
405
+ def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
406
+ try:
407
+ if isinstance(jwk, str):
408
+ obj = json.loads(jwk)
409
+ elif isinstance(jwk, dict):
410
+ obj = jwk
411
+ else:
412
+ raise ValueError
413
+ except ValueError:
414
+ raise InvalidKeyError("Key is not valid JSON") from None
415
+
416
+ if obj.get("kty") != "RSA":
417
+ raise InvalidKeyError("Not an RSA key") from None
418
+
419
+ if "d" in obj and "e" in obj and "n" in obj:
420
+ # Private key
421
+ if "oth" in obj:
422
+ raise InvalidKeyError(
423
+ "Unsupported RSA private key: > 2 primes not supported"
424
+ )
425
+
426
+ other_props = ["p", "q", "dp", "dq", "qi"]
427
+ props_found = [prop in obj for prop in other_props]
428
+ any_props_found = any(props_found)
429
+
430
+ if any_props_found and not all(props_found):
431
+ raise InvalidKeyError(
432
+ "RSA key must include all parameters if any are present besides d"
433
+ ) from None
434
+
435
+ public_numbers = RSAPublicNumbers(
436
+ from_base64url_uint(obj["e"]),
437
+ from_base64url_uint(obj["n"]),
438
+ )
439
+
440
+ if any_props_found:
441
+ numbers = RSAPrivateNumbers(
442
+ d=from_base64url_uint(obj["d"]),
443
+ p=from_base64url_uint(obj["p"]),
444
+ q=from_base64url_uint(obj["q"]),
445
+ dmp1=from_base64url_uint(obj["dp"]),
446
+ dmq1=from_base64url_uint(obj["dq"]),
447
+ iqmp=from_base64url_uint(obj["qi"]),
448
+ public_numbers=public_numbers,
449
+ )
450
+ else:
451
+ d = from_base64url_uint(obj["d"])
452
+ p, q = rsa_recover_prime_factors(
453
+ public_numbers.n, d, public_numbers.e
454
+ )
455
+
456
+ numbers = RSAPrivateNumbers(
457
+ d=d,
458
+ p=p,
459
+ q=q,
460
+ dmp1=rsa_crt_dmp1(d, p),
461
+ dmq1=rsa_crt_dmq1(d, q),
462
+ iqmp=rsa_crt_iqmp(p, q),
463
+ public_numbers=public_numbers,
464
+ )
465
+
466
+ return numbers.private_key()
467
+ elif "n" in obj and "e" in obj:
468
+ # Public key
469
+ return RSAPublicNumbers(
470
+ from_base64url_uint(obj["e"]),
471
+ from_base64url_uint(obj["n"]),
472
+ ).public_key()
473
+ else:
474
+ raise InvalidKeyError("Not a public or private key")
475
+
476
+ def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
477
+ return key.sign(msg, padding.PKCS1v15(), self.hash_alg())
478
+
479
+ def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
480
+ try:
481
+ key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
482
+ return True
483
+ except InvalidSignature:
484
+ return False
485
+
486
+ class ECAlgorithm(Algorithm):
487
+ """
488
+ Performs signing and verification operations using
489
+ ECDSA and the specified hash function
490
+ """
491
+
492
+ SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
493
+ SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
494
+ SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
495
+
496
+ def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
497
+ self.hash_alg = hash_alg
498
+
499
+ def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
500
+ if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
501
+ return key
502
+
503
+ if not isinstance(key, (bytes, str)):
504
+ raise TypeError("Expecting a PEM-formatted key.")
505
+
506
+ key_bytes = force_bytes(key)
507
+
508
+ # Attempt to load key. We don't know if it's
509
+ # a Signing Key or a Verifying Key, so we try
510
+ # the Verifying Key first.
511
+ try:
512
+ if key_bytes.startswith(b"ecdsa-sha2-"):
513
+ crypto_key = load_ssh_public_key(key_bytes)
514
+ else:
515
+ crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment]
516
+ except ValueError:
517
+ crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
518
+
519
+ # Explicit check the key to prevent confusing errors from cryptography
520
+ if not isinstance(
521
+ crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
522
+ ):
523
+ raise InvalidKeyError(
524
+ "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
525
+ ) from None
526
+
527
+ return crypto_key
528
+
529
+ def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
530
+ der_sig = key.sign(msg, ECDSA(self.hash_alg()))
531
+
532
+ return der_to_raw_signature(der_sig, key.curve)
533
+
534
+ def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
535
+ try:
536
+ der_sig = raw_to_der_signature(sig, key.curve)
537
+ except ValueError:
538
+ return False
539
+
540
+ try:
541
+ public_key = (
542
+ key.public_key()
543
+ if isinstance(key, EllipticCurvePrivateKey)
544
+ else key
545
+ )
546
+ public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
547
+ return True
548
+ except InvalidSignature:
549
+ return False
550
+
551
+ @overload
552
+ @staticmethod
553
+ def to_jwk(
554
+ key_obj: AllowedECKeys, as_dict: Literal[True]
555
+ ) -> JWKDict: ... # pragma: no cover
556
+
557
+ @overload
558
+ @staticmethod
559
+ def to_jwk(
560
+ key_obj: AllowedECKeys, as_dict: Literal[False] = False
561
+ ) -> str: ... # pragma: no cover
562
+
563
+ @staticmethod
564
+ def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
565
+ if isinstance(key_obj, EllipticCurvePrivateKey):
566
+ public_numbers = key_obj.public_key().public_numbers()
567
+ elif isinstance(key_obj, EllipticCurvePublicKey):
568
+ public_numbers = key_obj.public_numbers()
569
+ else:
570
+ raise InvalidKeyError("Not a public or private key")
571
+
572
+ if isinstance(key_obj.curve, SECP256R1):
573
+ crv = "P-256"
574
+ elif isinstance(key_obj.curve, SECP384R1):
575
+ crv = "P-384"
576
+ elif isinstance(key_obj.curve, SECP521R1):
577
+ crv = "P-521"
578
+ elif isinstance(key_obj.curve, SECP256K1):
579
+ crv = "secp256k1"
580
+ else:
581
+ raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
582
+
583
+ obj: dict[str, Any] = {
584
+ "kty": "EC",
585
+ "crv": crv,
586
+ "x": to_base64url_uint(
587
+ public_numbers.x,
588
+ bit_length=key_obj.curve.key_size,
589
+ ).decode(),
590
+ "y": to_base64url_uint(
591
+ public_numbers.y,
592
+ bit_length=key_obj.curve.key_size,
593
+ ).decode(),
594
+ }
595
+
596
+ if isinstance(key_obj, EllipticCurvePrivateKey):
597
+ obj["d"] = to_base64url_uint(
598
+ key_obj.private_numbers().private_value,
599
+ bit_length=key_obj.curve.key_size,
600
+ ).decode()
601
+
602
+ if as_dict:
603
+ return obj
604
+ else:
605
+ return json.dumps(obj)
606
+
607
+ @staticmethod
608
+ def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
609
+ try:
610
+ if isinstance(jwk, str):
611
+ obj = json.loads(jwk)
612
+ elif isinstance(jwk, dict):
613
+ obj = jwk
614
+ else:
615
+ raise ValueError
616
+ except ValueError:
617
+ raise InvalidKeyError("Key is not valid JSON") from None
618
+
619
+ if obj.get("kty") != "EC":
620
+ raise InvalidKeyError("Not an Elliptic curve key") from None
621
+
622
+ if "x" not in obj or "y" not in obj:
623
+ raise InvalidKeyError("Not an Elliptic curve key") from None
624
+
625
+ x = base64url_decode(obj.get("x"))
626
+ y = base64url_decode(obj.get("y"))
627
+
628
+ curve = obj.get("crv")
629
+ curve_obj: EllipticCurve
630
+
631
+ if curve == "P-256":
632
+ if len(x) == len(y) == 32:
633
+ curve_obj = SECP256R1()
634
+ else:
635
+ raise InvalidKeyError(
636
+ "Coords should be 32 bytes for curve P-256"
637
+ ) from None
638
+ elif curve == "P-384":
639
+ if len(x) == len(y) == 48:
640
+ curve_obj = SECP384R1()
641
+ else:
642
+ raise InvalidKeyError(
643
+ "Coords should be 48 bytes for curve P-384"
644
+ ) from None
645
+ elif curve == "P-521":
646
+ if len(x) == len(y) == 66:
647
+ curve_obj = SECP521R1()
648
+ else:
649
+ raise InvalidKeyError(
650
+ "Coords should be 66 bytes for curve P-521"
651
+ ) from None
652
+ elif curve == "secp256k1":
653
+ if len(x) == len(y) == 32:
654
+ curve_obj = SECP256K1()
655
+ else:
656
+ raise InvalidKeyError(
657
+ "Coords should be 32 bytes for curve secp256k1"
658
+ )
659
+ else:
660
+ raise InvalidKeyError(f"Invalid curve: {curve}")
661
+
662
+ public_numbers = EllipticCurvePublicNumbers(
663
+ x=int.from_bytes(x, byteorder="big"),
664
+ y=int.from_bytes(y, byteorder="big"),
665
+ curve=curve_obj,
666
+ )
667
+
668
+ if "d" not in obj:
669
+ return public_numbers.public_key()
670
+
671
+ d = base64url_decode(obj.get("d"))
672
+ if len(d) != len(x):
673
+ raise InvalidKeyError(
674
+ "D should be {} bytes for curve {}", len(x), curve
675
+ )
676
+
677
+ return EllipticCurvePrivateNumbers(
678
+ int.from_bytes(d, byteorder="big"), public_numbers
679
+ ).private_key()
680
+
681
+ class RSAPSSAlgorithm(RSAAlgorithm):
682
+ """
683
+ Performs a signature using RSASSA-PSS with MGF1
684
+ """
685
+
686
+ def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
687
+ return key.sign(
688
+ msg,
689
+ padding.PSS(
690
+ mgf=padding.MGF1(self.hash_alg()),
691
+ salt_length=self.hash_alg().digest_size,
692
+ ),
693
+ self.hash_alg(),
694
+ )
695
+
696
+ def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
697
+ try:
698
+ key.verify(
699
+ sig,
700
+ msg,
701
+ padding.PSS(
702
+ mgf=padding.MGF1(self.hash_alg()),
703
+ salt_length=self.hash_alg().digest_size,
704
+ ),
705
+ self.hash_alg(),
706
+ )
707
+ return True
708
+ except InvalidSignature:
709
+ return False
710
+
711
+ class OKPAlgorithm(Algorithm):
712
+ """
713
+ Performs signing and verification operations using EdDSA
714
+
715
+ This class requires ``cryptography>=2.6`` to be installed.
716
+ """
717
+
718
+ def __init__(self, **kwargs: Any) -> None:
719
+ pass
720
+
721
+ def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
722
+ if isinstance(key, (bytes, str)):
723
+ key_str = key.decode("utf-8") if isinstance(key, bytes) else key
724
+ key_bytes = key.encode("utf-8") if isinstance(key, str) else key
725
+
726
+ if "-----BEGIN PUBLIC" in key_str:
727
+ key = load_pem_public_key(key_bytes) # type: ignore[assignment]
728
+ elif "-----BEGIN PRIVATE" in key_str:
729
+ key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
730
+ elif key_str[0:4] == "ssh-":
731
+ key = load_ssh_public_key(key_bytes) # type: ignore[assignment]
732
+
733
+ # Explicit check the key to prevent confusing errors from cryptography
734
+ if not isinstance(
735
+ key,
736
+ (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
737
+ ):
738
+ raise InvalidKeyError(
739
+ "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for EdDSA algorithms"
740
+ )
741
+
742
+ return key
743
+
744
+ def sign(
745
+ self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
746
+ ) -> bytes:
747
+ """
748
+ Sign a message ``msg`` using the EdDSA private key ``key``
749
+ :param str|bytes msg: Message to sign
750
+ :param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
751
+ or :class:`.Ed448PrivateKey` isinstance
752
+ :return bytes signature: The signature, as bytes
753
+ """
754
+ msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
755
+ return key.sign(msg_bytes)
756
+
757
+ def verify(
758
+ self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
759
+ ) -> bool:
760
+ """
761
+ Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
762
+
763
+ :param str|bytes sig: EdDSA signature to check ``msg`` against
764
+ :param str|bytes msg: Message to sign
765
+ :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
766
+ A private or public EdDSA key instance
767
+ :return bool verified: True if signature is valid, False if not.
768
+ """
769
+ try:
770
+ msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
771
+ sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig
772
+
773
+ public_key = (
774
+ key.public_key()
775
+ if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
776
+ else key
777
+ )
778
+ public_key.verify(sig_bytes, msg_bytes)
779
+ return True # If no exception was raised, the signature is valid.
780
+ except InvalidSignature:
781
+ return False
782
+
783
+ @overload
784
+ @staticmethod
785
+ def to_jwk(
786
+ key: AllowedOKPKeys, as_dict: Literal[True]
787
+ ) -> JWKDict: ... # pragma: no cover
788
+
789
+ @overload
790
+ @staticmethod
791
+ def to_jwk(
792
+ key: AllowedOKPKeys, as_dict: Literal[False] = False
793
+ ) -> str: ... # pragma: no cover
794
+
795
+ @staticmethod
796
+ def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
797
+ if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
798
+ x = key.public_bytes(
799
+ encoding=Encoding.Raw,
800
+ format=PublicFormat.Raw,
801
+ )
802
+ crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
803
+
804
+ obj = {
805
+ "x": base64url_encode(force_bytes(x)).decode(),
806
+ "kty": "OKP",
807
+ "crv": crv,
808
+ }
809
+
810
+ if as_dict:
811
+ return obj
812
+ else:
813
+ return json.dumps(obj)
814
+
815
+ if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
816
+ d = key.private_bytes(
817
+ encoding=Encoding.Raw,
818
+ format=PrivateFormat.Raw,
819
+ encryption_algorithm=NoEncryption(),
820
+ )
821
+
822
+ x = key.public_key().public_bytes(
823
+ encoding=Encoding.Raw,
824
+ format=PublicFormat.Raw,
825
+ )
826
+
827
+ crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
828
+ obj = {
829
+ "x": base64url_encode(force_bytes(x)).decode(),
830
+ "d": base64url_encode(force_bytes(d)).decode(),
831
+ "kty": "OKP",
832
+ "crv": crv,
833
+ }
834
+
835
+ if as_dict:
836
+ return obj
837
+ else:
838
+ return json.dumps(obj)
839
+
840
+ raise InvalidKeyError("Not a public or private key")
841
+
842
+ @staticmethod
843
+ def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
844
+ try:
845
+ if isinstance(jwk, str):
846
+ obj = json.loads(jwk)
847
+ elif isinstance(jwk, dict):
848
+ obj = jwk
849
+ else:
850
+ raise ValueError
851
+ except ValueError:
852
+ raise InvalidKeyError("Key is not valid JSON") from None
853
+
854
+ if obj.get("kty") != "OKP":
855
+ raise InvalidKeyError("Not an Octet Key Pair")
856
+
857
+ curve = obj.get("crv")
858
+ if curve != "Ed25519" and curve != "Ed448":
859
+ raise InvalidKeyError(f"Invalid curve: {curve}")
860
+
861
+ if "x" not in obj:
862
+ raise InvalidKeyError('OKP should have "x" parameter')
863
+ x = base64url_decode(obj.get("x"))
864
+
865
+ try:
866
+ if "d" not in obj:
867
+ if curve == "Ed25519":
868
+ return Ed25519PublicKey.from_public_bytes(x)
869
+ return Ed448PublicKey.from_public_bytes(x)
870
+ d = base64url_decode(obj.get("d"))
871
+ if curve == "Ed25519":
872
+ return Ed25519PrivateKey.from_private_bytes(d)
873
+ return Ed448PrivateKey.from_private_bytes(d)
874
+ except ValueError as err:
875
+ raise InvalidKeyError("Invalid key parameter") from err