File size: 3,751 Bytes
2143587
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# SPDX-License-Identifier: MIT

from __future__ import annotations

import platform
import sys

from dataclasses import dataclass
from typing import Any

from .exceptions import InvalidHashError, UnsupportedParametersError
from .low_level import Type


# The class of ``None``; handy for building ``(T, NoneType)`` type tuples.
NoneType = None.__class__


def _check_types(**kw: Any) -> str | None:
    """
    Check each ``name: (value, types)`` in *kw*.

    *types* is either a single class or a tuple of classes, as passed to
    :func:`isinstance`.

    Returns a human-readable string of all violations, or ``None`` if every
    value has an acceptable type.
    """
    errors = []
    for name, (value, types) in kw.items():
        if isinstance(value, types):
            continue

        # Render the expected type(s) by class name for the message.
        if isinstance(types, tuple):
            expected = ", or ".join(t.__name__ for t in types)
        else:
            expected = types.__name__
        errors.append(
            f"'{name}' must be a {expected} (got {type(value).__name__})"
        )

    if errors:
        return ", ".join(errors) + "."

    return None


def _is_wasm() -> bool:
    """
    Return whether the interpreter runs on a WebAssembly platform.

    That is the case on Emscripten or when the machine type is
    ``wasm32``/``wasm64``.
    """
    if sys.platform == "emscripten":
        return True

    return platform.machine() in ("wasm32", "wasm64")


def _decoded_str_len(length: int) -> int:
    """
    Compute how many bytes an unpadded base64 string of *length* chars holds.

    Every full group of 4 characters decodes to 3 bytes. A trailing group of
    3 or 2 characters decodes to 2 or 1 bytes, respectively. A lone trailing
    character (invalid base64) contributes nothing.
    """
    # floor(3 * (4q + r) / 4) == 3q + floor(3r / 4), and floor(3r / 4) for
    # r = 0, 1, 2, 3 is 0, 0, 1, 2 -- exactly the per-remainder table above.
    return length * 3 // 4


@dataclass
class Parameters:
    """
    Argon2 hash parameters.

    See :doc:`parameters` on how to pick them.

    Attributes:
        type: Argon2 variant used for the hash.

        version: Argon2 version number.

        salt_len: Salt length in bytes.

        hash_len: Hash length in bytes.

        time_cost: Time cost as a number of iterations.

        memory_cost: Memory cost in kibibytes.

        parallelism: Number of parallel threads.

    .. versionadded:: 18.2.0
    """

    # Explicit slots: instances carry no per-instance ``__dict__``. None of
    # the fields has a default, so these slot names don't clash with
    # dataclass-generated class attributes.
    __slots__ = (
        "hash_len",
        "memory_cost",
        "parallelism",
        "salt_len",
        "time_cost",
        "type",
        "version",
    )

    # Field order defines the positional ``__init__`` signature and ``repr``.
    type: Type
    version: int
    salt_len: int
    hash_len: int
    time_cost: int
    memory_cost: int
    parallelism: int


# Maps the variant name found in an encoded hash to its ``Type`` member.
_NAME_TO_TYPE = dict(argon2id=Type.ID, argon2i=Type.I, argon2d=Type.D)
# Setting keys every encoded hash must carry, kept sorted for comparison.
_REQUIRED_KEYS = ["m", "p", "t", "v"]


def extract_parameters(hash: str) -> Parameters:
    """
    Extract parameters from an encoded *hash*.

    Args:
        hash: An encoded Argon2 hash string, shaped like
            ``$<type>$v=<version>$<m, t and p settings>$<salt>$<hash>``.
            Hashes without the ``v=`` section (Argon2 v1.2) are accepted
            too and reported with version 18.

    Returns:
        The parameters used to create the hash.

    Raises:
        InvalidHashError: If *hash* is not a well-formed encoded Argon2
            hash, including when a setting is missing, unknown, repeated,
            or not an integer.

    .. versionadded:: 18.2.0
    """
    parts = hash.split("$")

    # Backwards compatibility for Argon v1.2 hashes: they lack the "v="
    # section, so insert one and parse both formats the same way.
    if len(parts) == 5:
        parts.insert(2, "v=18")

    # An encoded hash starts with "$", so the first part must be empty.
    if len(parts) != 6 or parts[0]:
        raise InvalidHashError

    try:
        hash_type = _NAME_TO_TYPE[parts[1]]

        pairs = [s.split("=") for s in [parts[2], *parts[3].split(",")]]
        kvs = {k: int(v) for k, v in pairs}
    except (KeyError, ValueError):
        # Unknown variant name, a setting that isn't exactly "key=value",
        # or a non-integer value.
        raise InvalidHashError from None

    # Compare *all* keys (not the dict's, where duplicates collapse) so a
    # repeated setting like "m=1,m=2" is rejected instead of silently
    # keeping the last value.
    if sorted(k for k, _ in pairs) != _REQUIRED_KEYS:
        raise InvalidHashError

    return Parameters(
        type=hash_type,
        salt_len=_decoded_str_len(len(parts[4])),
        hash_len=_decoded_str_len(len(parts[5])),
        version=kvs["v"],
        time_cost=kvs["t"],
        memory_cost=kvs["m"],
        parallelism=kvs["p"],
    )


def validate_params_for_platform(params: Parameters) -> None:
    """
    Validate *params* against the current platform.

    The only check today: under WebAssembly (as detected by ``_is_wasm``),
    ``parallelism`` must be exactly 1.

    Args:
        params: Parameters to be validated.

    Raises:
        UnsupportedParametersError: If running under WebAssembly and
            ``params.parallelism`` is not 1.
    """
    if not _is_wasm():
        return

    if params.parallelism != 1:
        msg = "In WebAssembly environments `parallelism` must be 1."
        raise UnsupportedParametersError(msg)