#include "../../unity/unity.h"

/* NOTE(review): the five header names that followed unity.h were lost in
   this copy of the file.  The list below covers everything the tests use
   directly (size_t, uint64_t, mpz_*); confirm against the original. */
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <gmp.h>

/* Tests for mulredc (single-limb Montgomery multiplication).

   These tests rely on internals of the factor implementation that are
   defined outside this file: mulredc, binv_limb, the redcify macro and
   MP_LIMB_MAX.  Throughout, B denotes the limb base and R = B mod m. */

/* Store the single limb X into Z.  mpz_import is used instead of
   mpz_set_ui so the value is not truncated when mp_limb_t is wider than
   unsigned long. */
static void
mpz_set_limb(mpz_t z, mp_limb_t x)
{
  mpz_import(z, 1, -1, sizeof(mp_limb_t), 0, 0, &x);
}

/* Return R = B mod m, obtained as redcify(R, 1, m). */
static mp_limb_t
mont_R_mod_m(mp_limb_t m)
{
  mp_limb_t R;
  redcify(R, 1, m);
  return R;
}

/* Return x * B mod m, i.e. x converted to Montgomery form, via the
   redcify macro.  Requires x < m and m odd. */
static mp_limb_t
mont_redcify_value(mp_limb_t x, mp_limb_t m)
{
  mp_limb_t xR;
  redcify(xR, x, m);
  return xR;
}

/* Reference conversion to Montgomery form using GMP arithmetic only:
   return ((t mod m) * R) mod m as a limb.  T is used as scratch space and
   holds the result on return. */
static mp_limb_t
mont_form_of_mpz(mpz_t t, mp_limb_t m)
{
  mpz_t modulus, r;
  mpz_inits(modulus, r, NULL);
  mpz_set_limb(modulus, m);
  mpz_set_limb(r, mont_R_mod_m(m));

  mpz_mod(t, t, modulus);
  mpz_mul(t, t, r);
  mpz_mod(t, t, modulus);

  /* The result is < m, so it fits in limb 0; mpz_getlimbn yields 0 when
     t is zero (no limbs in use). */
  mp_limb_t out = mpz_getlimbn(t, 0);

  mpz_clears(modulus, r, NULL);
  return out;
}

/* Expected result of mulredc(aR, bR, m, mi): (a*b mod m) * R mod m, where
   a and b are the standard (non-Montgomery) representatives.  Requires
   a < m and b < m. */
static mp_limb_t
expected_redc_mul(mp_limb_t a, mp_limb_t b, mp_limb_t m)
{
  mpz_t az, bz, prod;
  mpz_inits(az, bz, prod, NULL);
  mpz_set_limb(az, a);
  mpz_set_limb(bz, b);
  mpz_mul(prod, az, bz);

  mp_limb_t out = mont_form_of_mpz(prod, m);

  mpz_clears(az, bz, prod, NULL);
  return out;
}

/* Expected Montgomery form of a^k: (a^k mod m) * R mod m. */
static mp_limb_t
expected_redc_pow(mp_limb_t a, unsigned long k, mp_limb_t m)
{
  mpz_t az, power;
  mpz_inits(az, power, NULL);
  mpz_set_limb(az, a);
  mpz_pow_ui(power, az, k);

  mp_limb_t out = mont_form_of_mpz(power, m);

  mpz_clears(az, power, NULL);
  return out;
}

/* Check one (m, a, b) case: mulredc on the Montgomery forms of a and b
   must return a value below m equal to the GMP reference.
   Requires m odd, a < m, b < m (asserted). */
static void
check_case(mp_limb_t m, mp_limb_t a, mp_limb_t b)
{
  TEST_ASSERT_TRUE_MESSAGE((m & 1u) == 1u, "Modulus must be odd");
  TEST_ASSERT_TRUE(a < m);
  TEST_ASSERT_TRUE(b < m);

  mp_limb_t mi = binv_limb(m);  /* mi = m^{-1} mod B */
  mp_limb_t aR = mont_redcify_value(a, m);
  mp_limb_t bR = mont_redcify_value(b, m);

  mp_limb_t got = mulredc(aR, bR, m, mi);
  mp_limb_t exp = expected_redc_mul(a, b, m);

  /* Range check */
  TEST_ASSERT_TRUE(got < m);

  /* Value check */
  TEST_ASSERT_EQUAL_HEX64_MESSAGE((uint64_t)exp, (uint64_t)got,
                                  "mulredc result mismatch");
}

/* Check that mulredc gives the same result with its two multiplicands
   swapped.  MI must be binv_limb(m); it is passed in so callers compute
   it once per modulus. */
static void
check_commutative(mp_limb_t m, mp_limb_t mi, mp_limb_t a, mp_limb_t b)
{
  mp_limb_t aR = mont_redcify_value(a, m);
  mp_limb_t bR = mont_redcify_value(b, m);

  mp_limb_t ab = mulredc(aR, bR, m, mi);
  mp_limb_t ba = mulredc(bR, aR, m, mi);

  TEST_ASSERT_EQUAL_HEX64((uint64_t)ab, (uint64_t)ba);
}

void
setUp(void)
{
  /* Nothing to set up */
}

void
tearDown(void)
{
  /* Nothing to tear down */
}

/* Small odd moduli (prime and composite) with fixed operands, including
   the edge values 0, 1, m-2 and m-1. */
static void
test_mulredc_small_moduli_fixed(void)
{
  const mp_limb_t moduli[] = {3, 5, 7, 9, 11, 15, 21, 25, 27, 33, 35};
  const size_t nmoduli = sizeof(moduli) / sizeof(moduli[0]);

  for (size_t i = 0; i < nmoduli; i++) {
    mp_limb_t m = moduli[i];
    mp_limb_t mi = binv_limb(m);

    /* Try various a,b pairs including edges */
    mp_limb_t vals[] = {0, 1, (m > 2 ? 2 : 1), m - 2, m - 1};
    const size_t nvals = sizeof(vals) / sizeof(vals[0]);

    for (size_t ia = 0; ia < nvals; ia++) {
      for (size_t ib = 0; ib < nvals; ib++) {
        mp_limb_t a = vals[ia] % m;
        mp_limb_t b = vals[ib] % m;

        check_case(m, a, b);
        check_commutative(m, mi, a, b);
      }
    }
  }
}

/* Identities in Montgomery form: multiplying by 0 gives 0, and
   multiplying by "1" (whose Montgomery form is R) is a no-op. */
static void
test_mulredc_identities(void)
{
  /* Choose an odd modulus not too small */
  mp_limb_t m = 1000003u;  /* Prime, odd, < 2^32 and < 2^64 */
  mp_limb_t mi = binv_limb(m);

  /* R = B mod m, and 1R = R, 0R = 0 */
  mp_limb_t R = mont_R_mod_m(m);
  mp_limb_t zeroR = mont_redcify_value(0, m);

  /* Test 0 * x = 0 */
  for (mp_limb_t x = 0; x < 20; x++) {
    mp_limb_t xR = mont_redcify_value(x, m);

    mp_limb_t got = mulredc(zeroR, xR, m, mi);
    TEST_ASSERT_EQUAL_HEX64(0u, (uint64_t)got);

    got = mulredc(xR, zeroR, m, mi);
    TEST_ASSERT_EQUAL_HEX64(0u, (uint64_t)got);
  }

  /* Test 1 * x = x in redc form */
  for (mp_limb_t x = 0; x < 20; x++) {
    mp_limb_t xR = mont_redcify_value(x, m);

    /* "1" in redc form is R. */
    mp_limb_t got = mulredc(R, xR, m, mi);
    TEST_ASSERT_EQUAL_HEX64((uint64_t)xR, (uint64_t)got);

    got = mulredc(xR, R, m, mi);
    TEST_ASSERT_EQUAL_HEX64((uint64_t)xR, (uint64_t)got);
  }
}

/* Advance the 64-bit linear congruential generator STATE by one step and
   return the new state.  Deterministic, so failures are reproducible. */
static uint64_t
lcg_next(uint64_t *state)
{
  *state = *state * 6364136223846793005ULL + 1ULL;
  return *state;
}

/* Pseudo-random operands under a fixed modulus. */
static void
test_mulredc_random_values(void)
{
  mp_limb_t m = 4294967291u;  /* large 32-bit prime, odd; OK on 64-bit too */
  mp_limb_t mi = binv_limb(m);

  uint64_t seed = 0x123456789abcdef0ULL;

  for (int i = 0; i < 200; i++) {
    mp_limb_t a = (mp_limb_t)(lcg_next(&seed) % m);
    mp_limb_t b = (mp_limb_t)(lcg_next(&seed) % m);

    check_case(m, a, b);

    /* Also directly check commutativity */
    check_commutative(m, mi, a, b);
  }
}

/* A modulus near the maximum limb value, with operands just below it, to
   exercise carries and the final conditional correction. */
static void
test_mulredc_large_modulus_edge(void)
{
  mp_limb_t m = MP_LIMB_MAX - 123;
  /* ensure odd */
  if ((m & 1u) == 0) {
    m -= 1;
  }

  /* Keep a,b just below m to exercise large products */
  check_case(m, m - 2, m - 3);

  /* A few more around the edge */
  check_case(m, m - 1, m - 1);
  check_case(m, m - 5, m - 7);
  check_case(m, m / 2, (m / 2) - 1);
}

/* Successive powers: build the Montgomery form of a^k for k = 2..16 by
   repeatedly multiplying by a with mulredc, and compare each step with
   the GMP reference.  (Despite the test's name, each step multiplies by a
   rather than squaring; the name is kept so test reports do not change.) */
static void
test_mulredc_repeated_squaring(void)
{
  mp_limb_t m = 1000003u;  /* prime */
  mp_limb_t mi = binv_limb(m);

  mp_limb_t a = 123456u % m;
  mp_limb_t aR = mont_redcify_value(a, m);

  mp_limb_t X = aR;  /* redc(a^1) */
  for (unsigned k = 2; k <= 16; k++) {
    X = mulredc(X, aR, m, mi);  /* redc(a^k) */

    /* Expected redc(a^k) using GMP: (a^k mod m) * R mod m */
    mp_limb_t exp = expected_redc_pow(a, k, m);

    TEST_ASSERT_TRUE(X < m);
    TEST_ASSERT_EQUAL_HEX64((uint64_t)exp, (uint64_t)X);
  }
}

int
main(void)
{
  UNITY_BEGIN();
  RUN_TEST(test_mulredc_small_moduli_fixed);
  RUN_TEST(test_mulredc_identities);
  RUN_TEST(test_mulredc_random_values);
  RUN_TEST(test_mulredc_large_modulus_edge);
  RUN_TEST(test_mulredc_repeated_squaring);
  return UNITY_END();
}