Spaces:
Sleeping
Sleeping
Upload backend/venv/lib/python3.10/site-packages/jwt/api_jws.py with huggingface_hub
ac4ad5e
verified
| from __future__ import annotations | |
| import binascii | |
| import json | |
| import warnings | |
| from collections.abc import Sequence | |
| from typing import TYPE_CHECKING, Any | |
| from .algorithms import ( | |
| Algorithm, | |
| get_default_algorithms, | |
| has_crypto, | |
| requires_cryptography, | |
| ) | |
| from .api_jwk import PyJWK | |
| from .exceptions import ( | |
| DecodeError, | |
| InvalidAlgorithmError, | |
| InvalidSignatureError, | |
| InvalidTokenError, | |
| ) | |
| from .utils import base64url_decode, base64url_encode | |
| from .warnings import RemovedInPyjwt3Warning | |
| if TYPE_CHECKING: | |
| from .algorithms import AllowedPrivateKeys, AllowedPublicKeys | |
class PyJWS:
    """Encoder/decoder for tokens in JWS compact form (``header.payload.signature``).

    Each instance keeps its own registry of signing algorithms (seeded from
    ``get_default_algorithms()``) and a dict of default decode options.
    """

    # Default value written to the "typ" header by :meth:`encode`.
    header_typ = "JWT"

    def __init__(
        self,
        algorithms: Sequence[str] | None = None,
        options: dict[str, Any] | None = None,
    ) -> None:
        """Create a JWS handler.

        :param algorithms: names of the algorithms this instance may use;
            ``None`` keeps every default algorithm.
        :param options: overrides merged on top of the default options
            (see :meth:`_get_default_options`).
        """
        self._algorithms = get_default_algorithms()
        self._valid_algs = (
            set(algorithms) if algorithms is not None else set(self._algorithms)
        )

        # Remove algorithms that aren't on the whitelist. Iterate over a
        # snapshot of the keys because the dict is mutated inside the loop.
        for key in list(self._algorithms):
            if key not in self._valid_algs:
                del self._algorithms[key]

        if options is None:
            options = {}
        self.options = {**self._get_default_options(), **options}

    # Must be a staticmethod: the function takes no ``self`` but is invoked
    # through the instance in ``__init__``.
    @staticmethod
    def _get_default_options() -> dict[str, bool]:
        """Return the default decode options (signature verification on)."""
        return {"verify_signature": True}

    def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
        """
        Registers a new Algorithm for use when creating and verifying tokens.

        :raises ValueError: if ``alg_id`` already has a handler.
        :raises TypeError: if ``alg_obj`` is not an ``Algorithm`` instance.
        """
        if alg_id in self._algorithms:
            raise ValueError("Algorithm already has a handler.")

        if not isinstance(alg_obj, Algorithm):
            raise TypeError("Object is not of type `Algorithm`")

        self._algorithms[alg_id] = alg_obj
        self._valid_algs.add(alg_id)

    def unregister_algorithm(self, alg_id: str) -> None:
        """
        Unregisters an Algorithm for use when creating and verifying tokens
        Throws KeyError if algorithm is not registered.
        """
        if alg_id not in self._algorithms:
            raise KeyError(
                "The specified algorithm could not be removed"
                " because it is not registered."
            )

        del self._algorithms[alg_id]
        self._valid_algs.remove(alg_id)

    def get_algorithms(self) -> list[str]:
        """
        Returns a list of supported values for the 'alg' parameter.
        """
        return list(self._valid_algs)

    def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
        """
        For a given string name, return the matching Algorithm object.

        Example usage:

        >>> jws_obj.get_algorithm_by_name("RS256")

        :raises NotImplementedError: if no algorithm is registered under
            ``alg_name``; the message hints at a missing ``cryptography``
            install when the name is one that requires it.
        """
        try:
            return self._algorithms[alg_name]
        except KeyError as e:
            if not has_crypto and alg_name in requires_cryptography:
                raise NotImplementedError(
                    f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
                ) from e
            raise NotImplementedError("Algorithm not supported") from e

    def encode(
        self,
        payload: bytes,
        key: AllowedPrivateKeys | PyJWK | str | bytes,
        algorithm: str | None = None,
        headers: dict[str, Any] | None = None,
        json_encoder: type[json.JSONEncoder] | None = None,
        is_payload_detached: bool = False,
        sort_headers: bool = True,
    ) -> str:
        """Sign ``payload`` and return the token in compact form.

        :param payload: raw bytes to sign.
        :param key: signing key, or a ``PyJWK`` whose ``.key`` is used.
        :param algorithm: algorithm name. When ``None`` the name comes from
            the ``PyJWK`` key, or falls back to ``"HS256"``. A truthy
            ``"alg"`` entry in ``headers`` overrides either.
        :param headers: extra header fields merged over the default
            ``typ``/``alg`` header; ``"b64": False`` forces detached mode.
        :param json_encoder: ``json.JSONEncoder`` subclass used to serialise
            the header.
        :param is_payload_detached: sign the payload un-encoded and leave the
            payload segment of the returned token empty.
        :param sort_headers: passed to ``json.dumps`` as ``sort_keys``.
        :returns: ``header.payload.signature`` as ``str``.
        :raises NotImplementedError: if the algorithm is not registered.
        """
        segments = []

        # declare a new var to narrow the type for type checkers
        algorithm_: str
        if algorithm is None:
            if isinstance(key, PyJWK):
                algorithm_ = key.algorithm_name
            else:
                algorithm_ = "HS256"
        else:
            algorithm_ = algorithm

        # Prefer headers values if present to function parameters.
        if headers:
            headers_alg = headers.get("alg")
            if headers_alg:
                algorithm_ = headers_alg

            headers_b64 = headers.get("b64")
            if headers_b64 is False:
                is_payload_detached = True

        # Header
        header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}

        if headers:
            self._validate_headers(headers)
            header.update(headers)

        # A falsy "typ" (e.g. None or "") means "omit the field entirely".
        if not header["typ"]:
            del header["typ"]

        if is_payload_detached:
            header["b64"] = False
        elif "b64" in header:
            # True is the standard value for b64, so no need for it
            del header["b64"]

        json_header = json.dumps(
            header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
        ).encode()

        segments.append(base64url_encode(json_header))

        # In detached mode the signature covers the raw payload bytes.
        if is_payload_detached:
            msg_payload = payload
        else:
            msg_payload = base64url_encode(payload)
        segments.append(msg_payload)

        # Segments
        signing_input = b".".join(segments)

        alg_obj = self.get_algorithm_by_name(algorithm_)
        if isinstance(key, PyJWK):
            key = key.key
        key = alg_obj.prepare_key(key)
        signature = alg_obj.sign(signing_input, key)

        segments.append(base64url_encode(signature))

        # Don't put the payload content inside the encoded token when detached
        if is_payload_detached:
            segments[1] = b""
        encoded_string = b".".join(segments)

        return encoded_string.decode("utf-8")

    def decode_complete(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: dict[str, Any] | None = None,
        detached_payload: bytes | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Parse a token and, unless disabled, verify its signature.

        :param jwt: the compact-form token.
        :param key: verification key or ``PyJWK``.
        :param algorithms: accepted ``alg`` values. Required when the
            signature is verified and ``key`` is not a ``PyJWK``.
        :param options: per-call overrides merged over ``self.options``.
        :param detached_payload: the payload bytes, required when the token
            header carries ``"b64": false``.
        :param kwargs: deprecated; any extra keyword triggers a
            ``RemovedInPyjwt3Warning`` and is otherwise ignored.
        :returns: dict with keys ``"payload"``, ``"header"``, ``"signature"``.
        :raises DecodeError: on a malformed token or missing required argument.
        :raises InvalidAlgorithmError: if the token's ``alg`` is not acceptable.
        :raises InvalidSignatureError: if the signature does not verify.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode_complete() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        if options is None:
            options = {}
        merged_options = {**self.options, **options}
        verify_signature = merged_options["verify_signature"]

        if verify_signature and not algorithms and not isinstance(key, PyJWK):
            raise DecodeError(
                'It is required that you pass in a value for the "algorithms" argument when calling decode().'
            )

        payload, signing_input, header, signature = self._load(jwt)

        if header.get("b64", True) is False:
            if detached_payload is None:
                raise DecodeError(
                    'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
                )
            payload = detached_payload
            # Rebuild the signing input as "<header segment>.<raw payload>".
            signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload])

        if verify_signature:
            self._verify_signature(signing_input, header, signature, key, algorithms)

        return {
            "payload": payload,
            "header": header,
            "signature": signature,
        }

    def decode(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: dict[str, Any] | None = None,
        detached_payload: bytes | None = None,
        **kwargs: Any,
    ) -> Any:
        """Like :meth:`decode_complete` but return only the payload.

        Extra ``kwargs`` are deprecated: they trigger a
        ``RemovedInPyjwt3Warning`` and are not forwarded.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        decoded = self.decode_complete(
            jwt, key, algorithms, options, detached_payload=detached_payload
        )
        return decoded["payload"]

    def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
        """Returns back the JWT header parameters as a dict()

        Note: The signature is not verified so the header parameters
        should not be fully trusted until signature verification is complete
        """
        headers = self._load(jwt)[2]
        self._validate_headers(headers)

        return headers

    def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
        """Split a compact token and decode its three segments.

        :returns: ``(payload, signing_input, header, signature)`` where
            ``signing_input`` is the still-encoded ``header.payload`` bytes.
        :raises DecodeError: if the token is not ``str``/``bytes``, has too
            few segments, a segment is not valid base64url, or the header is
            not a JSON object.
        """
        if isinstance(jwt, str):
            jwt = jwt.encode("utf-8")

        if not isinstance(jwt, bytes):
            raise DecodeError(f"Invalid token type. Token must be a {bytes}")

        try:
            # Peel the signature off the right, then the header off the left;
            # either unpacking fails with ValueError if a "." is missing.
            signing_input, crypto_segment = jwt.rsplit(b".", 1)
            header_segment, payload_segment = signing_input.split(b".", 1)
        except ValueError as err:
            raise DecodeError("Not enough segments") from err

        try:
            header_data = base64url_decode(header_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid header padding") from err

        try:
            header = json.loads(header_data)
        except ValueError as e:
            raise DecodeError(f"Invalid header string: {e}") from e

        if not isinstance(header, dict):
            raise DecodeError("Invalid header string: must be a json object")

        try:
            payload = base64url_decode(payload_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid payload padding") from err

        try:
            signature = base64url_decode(crypto_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid crypto padding") from err

        return (payload, signing_input, header, signature)

    def _verify_signature(
        self,
        signing_input: bytes,
        header: dict[str, Any],
        signature: bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
    ) -> None:
        """Check ``signature`` over ``signing_input`` using the header's ``alg``.

        :raises InvalidAlgorithmError: if ``alg`` is missing, empty, not in
            ``algorithms``, or not registered on this instance.
        :raises InvalidSignatureError: if verification fails.
        """
        # A PyJWK carries its own algorithm, which becomes the allow-list
        # when the caller did not supply one.
        if algorithms is None and isinstance(key, PyJWK):
            algorithms = [key.algorithm_name]
        try:
            alg = header["alg"]
        except KeyError:
            raise InvalidAlgorithmError("Algorithm not specified") from None

        if not alg or (algorithms is not None and alg not in algorithms):
            raise InvalidAlgorithmError("The specified alg value is not allowed")

        if isinstance(key, PyJWK):
            alg_obj = key.Algorithm
            prepared_key = key.key
        else:
            try:
                alg_obj = self.get_algorithm_by_name(alg)
            except NotImplementedError as e:
                raise InvalidAlgorithmError("Algorithm not supported") from e
            prepared_key = alg_obj.prepare_key(key)

        if not alg_obj.verify(signing_input, prepared_key, signature):
            raise InvalidSignatureError("Signature verification failed")

    def _validate_headers(self, headers: dict[str, Any]) -> None:
        """Validate the header fields this class has rules for (only ``kid``)."""
        if "kid" in headers:
            self._validate_kid(headers["kid"])

    def _validate_kid(self, kid: Any) -> None:
        """Raise ``InvalidTokenError`` unless ``kid`` is a ``str``."""
        if not isinstance(kid, str):
            raise InvalidTokenError("Key ID header parameter must be a string")
# Shared module-level instance configured with the default algorithms and
# options; the names below expose its bound methods as plain functions.
_jws_global_obj = PyJWS()

# Token creation and parsing.
encode = _jws_global_obj.encode
decode = _jws_global_obj.decode
decode_complete = _jws_global_obj.decode_complete
get_unverified_header = _jws_global_obj.get_unverified_header

# Algorithm registry management on the shared instance.
get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm