Spaces:
Paused
Paused
| from authlib.common.encoding import ( | |
| to_bytes, | |
| to_unicode, | |
| urlsafe_b64encode, | |
| json_b64encode, | |
| ) | |
| from authlib.jose.util import ( | |
| extract_header, | |
| extract_segment, ensure_dict, | |
| ) | |
| from authlib.jose.errors import ( | |
| DecodeError, | |
| MissingAlgorithmError, | |
| UnsupportedAlgorithmError, | |
| BadSignatureError, | |
| InvalidHeaderParameterNameError, | |
| ) | |
| from .models import JWSHeader, JWSObject | |
| class JsonWebSignature: | |
| #: Registered Header Parameter Names defined by Section 4.1 | |
| REGISTERED_HEADER_PARAMETER_NAMES = frozenset([ | |
| 'alg', 'jku', 'jwk', 'kid', | |
| 'x5u', 'x5c', 'x5t', 'x5t#S256', | |
| 'typ', 'cty', 'crit' | |
| ]) | |
| #: Defined available JWS algorithms in the registry | |
| ALGORITHMS_REGISTRY = {} | |
| def __init__(self, algorithms=None, private_headers=None): | |
| self._private_headers = private_headers | |
| self._algorithms = algorithms | |
| def register_algorithm(cls, algorithm): | |
| if not algorithm or algorithm.algorithm_type != 'JWS': | |
| raise ValueError( | |
| f'Invalid algorithm for JWS, {algorithm!r}') | |
| cls.ALGORITHMS_REGISTRY[algorithm.name] = algorithm | |
| def serialize_compact(self, protected, payload, key): | |
| """Generate a JWS Compact Serialization. The JWS Compact Serialization | |
| represents digitally signed or MACed content as a compact, URL-safe | |
| string, per `Section 7.1`_. | |
| .. code-block:: text | |
| BASE64URL(UTF8(JWS Protected Header)) || '.' || | |
| BASE64URL(JWS Payload) || '.' || | |
| BASE64URL(JWS Signature) | |
| :param protected: A dict of protected header | |
| :param payload: A bytes/string of payload | |
| :param key: Private key used to generate signature | |
| :return: byte | |
| """ | |
| jws_header = JWSHeader(protected, None) | |
| self._validate_private_headers(protected) | |
| algorithm, key = self._prepare_algorithm_key(protected, payload, key) | |
| protected_segment = json_b64encode(jws_header.protected) | |
| payload_segment = urlsafe_b64encode(to_bytes(payload)) | |
| # calculate signature | |
| signing_input = b'.'.join([protected_segment, payload_segment]) | |
| signature = urlsafe_b64encode(algorithm.sign(signing_input, key)) | |
| return b'.'.join([protected_segment, payload_segment, signature]) | |
| def deserialize_compact(self, s, key, decode=None): | |
| """Exact JWS Compact Serialization, and validate with the given key. | |
| If key is not provided, the returned dict will contain the signature, | |
| and signing input values. Via `Section 7.1`_. | |
| :param s: text of JWS Compact Serialization | |
| :param key: key used to verify the signature | |
| :param decode: a function to decode payload data | |
| :return: JWSObject | |
| :raise: BadSignatureError | |
| .. _`Section 7.1`: https://tools.ietf.org/html/rfc7515#section-7.1 | |
| """ | |
| try: | |
| s = to_bytes(s) | |
| signing_input, signature_segment = s.rsplit(b'.', 1) | |
| protected_segment, payload_segment = signing_input.split(b'.', 1) | |
| except ValueError: | |
| raise DecodeError('Not enough segments') | |
| protected = _extract_header(protected_segment) | |
| jws_header = JWSHeader(protected, None) | |
| payload = _extract_payload(payload_segment) | |
| if decode: | |
| payload = decode(payload) | |
| signature = _extract_signature(signature_segment) | |
| rv = JWSObject(jws_header, payload, 'compact') | |
| algorithm, key = self._prepare_algorithm_key(jws_header, payload, key) | |
| if algorithm.verify(signing_input, signature, key): | |
| return rv | |
| raise BadSignatureError(rv) | |
| def serialize_json(self, header_obj, payload, key): | |
| """Generate a JWS JSON Serialization. The JWS JSON Serialization | |
| represents digitally signed or MACed content as a JSON object, | |
| per `Section 7.2`_. | |
| :param header_obj: A dict/list of header | |
| :param payload: A string/dict of payload | |
| :param key: Private key used to generate signature | |
| :return: JWSObject | |
| Example ``header_obj`` of JWS JSON Serialization:: | |
| { | |
| "protected: {"alg": "HS256"}, | |
| "header": {"kid": "jose"} | |
| } | |
| Pass a dict to generate flattened JSON Serialization, pass a list of | |
| header dict to generate standard JSON Serialization. | |
| """ | |
| payload_segment = json_b64encode(payload) | |
| def _sign(jws_header): | |
| self._validate_private_headers(jws_header) | |
| _alg, _key = self._prepare_algorithm_key(jws_header, payload, key) | |
| protected_segment = json_b64encode(jws_header.protected) | |
| signing_input = b'.'.join([protected_segment, payload_segment]) | |
| signature = urlsafe_b64encode(_alg.sign(signing_input, _key)) | |
| rv = { | |
| 'protected': to_unicode(protected_segment), | |
| 'signature': to_unicode(signature) | |
| } | |
| if jws_header.header is not None: | |
| rv['header'] = jws_header.header | |
| return rv | |
| if isinstance(header_obj, dict): | |
| data = _sign(JWSHeader.from_dict(header_obj)) | |
| data['payload'] = to_unicode(payload_segment) | |
| return data | |
| signatures = [_sign(JWSHeader.from_dict(h)) for h in header_obj] | |
| return { | |
| 'payload': to_unicode(payload_segment), | |
| 'signatures': signatures | |
| } | |
| def deserialize_json(self, obj, key, decode=None): | |
| """Exact JWS JSON Serialization, and validate with the given key. | |
| If key is not provided, it will return a dict without signature | |
| verification. Header will still be validated. Via `Section 7.2`_. | |
| :param obj: text of JWS JSON Serialization | |
| :param key: key used to verify the signature | |
| :param decode: a function to decode payload data | |
| :return: JWSObject | |
| :raise: BadSignatureError | |
| .. _`Section 7.2`: https://tools.ietf.org/html/rfc7515#section-7.2 | |
| """ | |
| obj = ensure_dict(obj, 'JWS') | |
| payload_segment = obj.get('payload') | |
| if payload_segment is None: | |
| raise DecodeError('Missing "payload" value') | |
| payload_segment = to_bytes(payload_segment) | |
| payload = _extract_payload(payload_segment) | |
| if decode: | |
| payload = decode(payload) | |
| if 'signatures' not in obj: | |
| # flattened JSON JWS | |
| jws_header, valid = self._validate_json_jws( | |
| payload_segment, payload, obj, key) | |
| rv = JWSObject(jws_header, payload, 'flat') | |
| if valid: | |
| return rv | |
| raise BadSignatureError(rv) | |
| headers = [] | |
| is_valid = True | |
| for header_obj in obj['signatures']: | |
| jws_header, valid = self._validate_json_jws( | |
| payload_segment, payload, header_obj, key) | |
| headers.append(jws_header) | |
| if not valid: | |
| is_valid = False | |
| rv = JWSObject(headers, payload, 'json') | |
| if is_valid: | |
| return rv | |
| raise BadSignatureError(rv) | |
| def serialize(self, header, payload, key): | |
| """Generate a JWS Serialization. It will automatically generate a | |
| Compact or JSON Serialization depending on the given header. If a | |
| header is in a JSON header format, it will call | |
| :meth:`serialize_json`, otherwise it will call | |
| :meth:`serialize_compact`. | |
| :param header: A dict/list of header | |
| :param payload: A string/dict of payload | |
| :param key: Private key used to generate signature | |
| :return: byte/dict | |
| """ | |
| if isinstance(header, (list, tuple)): | |
| return self.serialize_json(header, payload, key) | |
| if 'protected' in header: | |
| return self.serialize_json(header, payload, key) | |
| return self.serialize_compact(header, payload, key) | |
| def deserialize(self, s, key, decode=None): | |
| """Deserialize JWS Serialization, both compact and JSON format. | |
| It will automatically deserialize depending on the given JWS. | |
| :param s: text of JWS Compact/JSON Serialization | |
| :param key: key used to verify the signature | |
| :param decode: a function to decode payload data | |
| :return: dict | |
| :raise: BadSignatureError | |
| If key is not provided, it will still deserialize the serialization | |
| without verification. | |
| """ | |
| if isinstance(s, dict): | |
| return self.deserialize_json(s, key, decode) | |
| s = to_bytes(s) | |
| if s.startswith(b'{') and s.endswith(b'}'): | |
| return self.deserialize_json(s, key, decode) | |
| return self.deserialize_compact(s, key, decode) | |
| def _prepare_algorithm_key(self, header, payload, key): | |
| if 'alg' not in header: | |
| raise MissingAlgorithmError() | |
| alg = header['alg'] | |
| if self._algorithms is not None and alg not in self._algorithms: | |
| raise UnsupportedAlgorithmError() | |
| if alg not in self.ALGORITHMS_REGISTRY: | |
| raise UnsupportedAlgorithmError() | |
| algorithm = self.ALGORITHMS_REGISTRY[alg] | |
| if callable(key): | |
| key = key(header, payload) | |
| elif key is None and 'jwk' in header: | |
| key = header['jwk'] | |
| key = algorithm.prepare_key(key) | |
| return algorithm, key | |
| def _validate_private_headers(self, header): | |
| # only validate private headers when developers set | |
| # private headers explicitly | |
| if self._private_headers is not None: | |
| names = self.REGISTERED_HEADER_PARAMETER_NAMES.copy() | |
| names = names.union(self._private_headers) | |
| for k in header: | |
| if k not in names: | |
| raise InvalidHeaderParameterNameError(k) | |
| def _validate_json_jws(self, payload_segment, payload, header_obj, key): | |
| protected_segment = header_obj.get('protected') | |
| if not protected_segment: | |
| raise DecodeError('Missing "protected" value') | |
| signature_segment = header_obj.get('signature') | |
| if not signature_segment: | |
| raise DecodeError('Missing "signature" value') | |
| protected_segment = to_bytes(protected_segment) | |
| protected = _extract_header(protected_segment) | |
| header = header_obj.get('header') | |
| if header and not isinstance(header, dict): | |
| raise DecodeError('Invalid "header" value') | |
| jws_header = JWSHeader(protected, header) | |
| algorithm, key = self._prepare_algorithm_key(jws_header, payload, key) | |
| signing_input = b'.'.join([protected_segment, payload_segment]) | |
| signature = _extract_signature(to_bytes(signature_segment)) | |
| if algorithm.verify(signing_input, signature, key): | |
| return jws_header, True | |
| return jws_header, False | |
| def _extract_header(header_segment): | |
| return extract_header(header_segment, DecodeError) | |
| def _extract_signature(signature_segment): | |
| return extract_segment(signature_segment, DecodeError, 'signature') | |
| def _extract_payload(payload_segment): | |
| return extract_segment(payload_segment, DecodeError, 'payload') | |