|
|
#include "../../unity/unity.h" |
|
|
#include <stdlib.h> |
|
|
#include <stdint.h> |
|
|
#include <string.h> |
|
|
#include <limits.h> |
|
|
#include <gmp.h> |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/* Set z to the unsigned single-limb value x.
 * mpz_import is used instead of mpz_set_ui because mp_limb_t may be wider
 * than unsigned long on some ABIs.  With exactly one word the word order is
 * irrelevant; native byte order and no nail bits are requested. */
static void mpz_set_limb(mpz_t z, mp_limb_t x) {
    const mp_limb_t word = x;
    const size_t word_count = 1;
    mpz_import(z, word_count, -1, sizeof word, 0, 0, &word);
}
|
|
|
|
|
|
|
|
/* Return the Montgomery representation of 1 modulo m (used in this file as
 * "R mod m").  Delegates to the project's redcify(dst, x, m) macro/function,
 * presumably storing x * 2^(limb bits) mod m into dst -- TODO confirm against
 * its definition. */
static mp_limb_t mont_R_mod_m(mp_limb_t m) {
    mp_limb_t one_in_mont = 0;
    redcify(one_in_mont, 1, m);
    return one_in_mont;
}
|
|
|
|
|
|
|
|
/* Convert x to Montgomery form modulo m via the project's redcify helper.
 * Every caller in this file passes an already-reduced x (x < m); whether
 * redcify tolerates x >= m is not visible here. */
static mp_limb_t mont_redcify_value(mp_limb_t x, mp_limb_t m) {
    mp_limb_t x_mont = 0;
    redcify(x_mont, x, m);
    return x_mont;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/* Reference value for mulredc applied to the Montgomery forms of a and b.
 *
 * If aR = a*R mod m and bR = b*R mod m, Montgomery multiplication of aR and
 * bR should give a*b*R mod m.  This computes that value independently with
 * GMP as ((a*b) mod m) * (R mod m) mod m, where R mod m comes from
 * mont_R_mod_m.  m must be nonzero (it is used as the GMP divisor).
 */
static mp_limb_t expected_redc_mul(mp_limb_t a, mp_limb_t b, mp_limb_t m) {
    mpz_t lhs, rhs, mod, r_mod_m, acc;
    mpz_init(lhs);
    mpz_init(rhs);
    mpz_init(mod);
    mpz_init(r_mod_m);
    mpz_init(acc);

    mpz_set_limb(lhs, a);
    mpz_set_limb(rhs, b);
    mpz_set_limb(mod, m);
    mpz_set_limb(r_mod_m, mont_R_mod_m(m));

    mpz_mul(acc, lhs, rhs);
    mpz_mod(acc, acc, mod);        /* a*b mod m */
    mpz_mul(acc, acc, r_mod_m);
    mpz_mod(acc, acc, mod);        /* (a*b mod m) * R mod m */

    /* acc < m, so it occupies at most one limb; limb 0 of zero reads as 0. */
    const mp_limb_t result = mpz_getlimbn(acc, 0);

    mpz_clear(acc);
    mpz_clear(r_mod_m);
    mpz_clear(mod);
    mpz_clear(rhs);
    mpz_clear(lhs);
    return result;
}
|
|
|
|
|
|
|
|
/* Check mulredc for one operand pair against the GMP reference.
 *
 * Asserted preconditions: m is odd, and a and b are already reduced below m.
 * a and b are converted to Montgomery form and multiplied with mulredc.  The
 * result must be fully reduced (< m) and equal expected_redc_mul(a, b, m).
 */
static void check_case(mp_limb_t m, mp_limb_t a, mp_limb_t b) {
    TEST_ASSERT_TRUE_MESSAGE((m & 1u) == 1u, "Modulus must be odd");
    TEST_ASSERT_TRUE(a < m);
    TEST_ASSERT_TRUE(b < m);

    const mp_limb_t inv = binv_limb(m);
    const mp_limb_t a_mont = mont_redcify_value(a, m);
    const mp_limb_t b_mont = mont_redcify_value(b, m);

    const mp_limb_t actual = mulredc(a_mont, b_mont, m, inv);
    const mp_limb_t wanted = expected_redc_mul(a, b, m);

    /* The output must be the canonical residue, not merely congruent. */
    TEST_ASSERT_TRUE(actual < m);
    TEST_ASSERT_EQUAL_HEX64_MESSAGE((uint64_t)wanted, (uint64_t)actual,
                                    "mulredc result mismatch");
}
|
|
|
|
|
/* Unity per-test hook: the tests in this file share no fixture state. */
void setUp(void)
{
}

/* Unity per-test hook: nothing to release between tests. */
void tearDown(void)
{
}
|
|
|
|
|
|
|
|
static void test_mulredc_small_moduli_fixed(void) { |
|
|
const mp_limb_t moduli[] = {3, 5, 7, 9, 11, 15, 21, 25, 27, 33, 35}; |
|
|
for (size_t i = 0; i < sizeof(moduli)/sizeof(moduli[0]); i++) { |
|
|
mp_limb_t m = moduli[i]; |
|
|
|
|
|
mp_limb_t vals[] = {0, 1, (m > 2 ? 2 : 1), m - 2, m - 1}; |
|
|
size_t nvals = sizeof(vals)/sizeof(vals[0]); |
|
|
for (size_t ia = 0; ia < nvals; ia++) { |
|
|
for (size_t ib = 0; ib < nvals; ib++) { |
|
|
mp_limb_t a = vals[ia] % m; |
|
|
mp_limb_t b = vals[ib] % m; |
|
|
check_case(m, a, b); |
|
|
|
|
|
|
|
|
mp_limb_t mi = binv_limb(m); |
|
|
mp_limb_t aR = mont_redcify_value(a, m); |
|
|
mp_limb_t bR = mont_redcify_value(b, m); |
|
|
mp_limb_t r1 = mulredc(aR, bR, m, mi); |
|
|
mp_limb_t r2 = mulredc(bR, aR, m, mi); |
|
|
TEST_ASSERT_EQUAL_HEX64((uint64_t)r1, (uint64_t)r2); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
static void test_mulredc_identities(void) { |
|
|
|
|
|
mp_limb_t m = 1000003u; |
|
|
mp_limb_t mi = binv_limb(m); |
|
|
|
|
|
|
|
|
mp_limb_t R = mont_R_mod_m(m); |
|
|
|
|
|
|
|
|
for (mp_limb_t x = 0; x < 20; x++) { |
|
|
mp_limb_t xR = mont_redcify_value(x, m); |
|
|
mp_limb_t zeroR = mont_redcify_value(0, m); |
|
|
mp_limb_t got = mulredc(zeroR, xR, m, mi); |
|
|
TEST_ASSERT_EQUAL_HEX64(0u, (uint64_t)got); |
|
|
got = mulredc(xR, zeroR, m, mi); |
|
|
TEST_ASSERT_EQUAL_HEX64(0u, (uint64_t)got); |
|
|
} |
|
|
|
|
|
|
|
|
for (mp_limb_t x = 0; x < 20; x++) { |
|
|
mp_limb_t xR = mont_redcify_value(x, m); |
|
|
|
|
|
mp_limb_t got = mulredc(R, xR, m, mi); |
|
|
TEST_ASSERT_EQUAL_HEX64((uint64_t)xR, (uint64_t)got); |
|
|
got = mulredc(xR, R, m, mi); |
|
|
TEST_ASSERT_EQUAL_HEX64((uint64_t)xR, (uint64_t)got); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
static void test_mulredc_random_values(void) { |
|
|
mp_limb_t m = 4294967291u; |
|
|
mp_limb_t mi = binv_limb(m); |
|
|
|
|
|
|
|
|
uint64_t seed = 0x123456789abcdef0ULL; |
|
|
for (int i = 0; i < 200; i++) { |
|
|
seed = seed * 6364136223846793005ULL + 1ULL; |
|
|
mp_limb_t a = (mp_limb_t)(seed % m); |
|
|
seed = seed * 6364136223846793005ULL + 1ULL; |
|
|
mp_limb_t b = (mp_limb_t)(seed % m); |
|
|
check_case(m, a, b); |
|
|
|
|
|
|
|
|
mp_limb_t aR = mont_redcify_value(a, m); |
|
|
mp_limb_t bR = mont_redcify_value(b, m); |
|
|
mp_limb_t r1 = mulredc(aR, bR, m, mi); |
|
|
mp_limb_t r2 = mulredc(bR, aR, m, mi); |
|
|
TEST_ASSERT_EQUAL_HEX64((uint64_t)r1, (uint64_t)r2); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
static void test_mulredc_large_modulus_edge(void) { |
|
|
mp_limb_t m = MP_LIMB_MAX - 123; |
|
|
if ((m & 1u) == 0) m -= 1; |
|
|
|
|
|
mp_limb_t a = m - 2; |
|
|
mp_limb_t b = m - 3; |
|
|
check_case(m, a, b); |
|
|
|
|
|
|
|
|
check_case(m, m - 1, m - 1); |
|
|
check_case(m, m - 5, m - 7); |
|
|
check_case(m, m / 2, (m / 2) - 1); |
|
|
} |
|
|
|
|
|
|
|
|
static void test_mulredc_repeated_squaring(void) { |
|
|
mp_limb_t m = 1000003u; |
|
|
mp_limb_t mi = binv_limb(m); |
|
|
|
|
|
mp_limb_t a = 123456u % m; |
|
|
mp_limb_t aR = mont_redcify_value(a, m); |
|
|
|
|
|
mp_limb_t X = aR; |
|
|
for (unsigned k = 2; k <= 16; k++) { |
|
|
X = mulredc(X, aR, m, mi); |
|
|
|
|
|
|
|
|
mpz_t A, M, T, Rz; mpz_inits(A, M, T, Rz, NULL); |
|
|
mpz_set_limb(A, a); |
|
|
mpz_set_limb(M, m); |
|
|
mpz_pow_ui(T, A, k); |
|
|
mpz_mod(T, T, M); |
|
|
mp_limb_t R = mont_R_mod_m(m); |
|
|
mpz_set_limb(Rz, R); |
|
|
mpz_mul(T, T, Rz); |
|
|
mpz_mod(T, T, M); |
|
|
mp_limb_t exp = mpz_getlimbn(T, 0); |
|
|
mpz_clears(A, M, T, Rz, NULL); |
|
|
|
|
|
TEST_ASSERT_TRUE(X < m); |
|
|
TEST_ASSERT_EQUAL_HEX64((uint64_t)exp, (uint64_t)X); |
|
|
} |
|
|
} |
|
|
|
|
|
int main(void) { |
|
|
UNITY_BEGIN(); |
|
|
RUN_TEST(test_mulredc_small_moduli_fixed); |
|
|
RUN_TEST(test_mulredc_identities); |
|
|
RUN_TEST(test_mulredc_random_values); |
|
|
RUN_TEST(test_mulredc_large_modulus_edge); |
|
|
RUN_TEST(test_mulredc_repeated_squaring); |
|
|
return UNITY_END(); |
|
|
} |