#include "../../unity/unity.h"

/* System headers.  <stdio.h> must precede <gmp.h> so that GMP declares its
   FILE-based functions.
   NOTE(review): this header list is reconstructed from what the file uses
   (mpz_* and mp_limb_t from <gmp.h>, snprintf from <stdio.h>) -- confirm
   against the build.  mulredc2, binv_limb and W_TYPE_SIZE are not declared
   by any header listed here; presumably the build makes the definitions
   from the code under test visible to this translation unit -- confirm.  */
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <gmp.h>

/* Tests for mulredc2, two-limb Montgomery multiplication.

   Notation: B = 2^W_TYPE_SIZE, R = B^2, odd modulus m = m1*B + m0.
   A value x is passed to mulredc2 in "redc form" x*R mod m, split into two
   limbs.  The property checked is the one encoded by
   expected_redc_product_pair: multiplying the redc forms of a and b yields
   the redc form of a*b, i.e. (a*R)(b*R)/R = a*b*R (mod m), fully reduced
   below m.  mulredc2 receives binv_limb (m0) as its inverse argument.  */

/* Unity hooks: no per-test fixture is needed.  */
void
setUp(void)
{
}

void
tearDown(void)
{
}

/* Set OUT = x1*B + x0.
   mpz_import is used instead of mpz_set_ui/mpz_add_ui because those take
   unsigned long, which may be narrower than mp_limb_t (e.g. 64-bit limbs
   on LLP64 platforms) and would silently truncate.  Words are given least
   significant first (order -1), in native byte order (endian 0), with no
   nail bits.  Like the rest of this file, this assumes mp_limb_t is
   W_TYPE_SIZE bits wide.  */
static void
mpz_from_pair(mpz_t out, mp_limb_t x1, mp_limb_t x0)
{
    const mp_limb_t words[2] = { x0, x1 };

    mpz_import(out, 2, -1, sizeof words[0], 0, 0, words);
}

/* Set M = m1*B + m0, the modulus.  */
static void
build_modulus(mpz_t m, mp_limb_t m1, mp_limb_t m0)
{
    mpz_from_pair(m, m1, m0);
}

/* Set B = 2^W_TYPE_SIZE, the limb base.  */
static void
build_B(mpz_t B)
{
    mpz_set_ui(B, 1);
    mpz_mul_2exp(B, B, W_TYPE_SIZE);
}

/* Set R = B^2 = 2^(2*W_TYPE_SIZE), the two-limb Montgomery radix.  */
static void
build_R(mpz_t R)
{
    build_B(R);
    mpz_mul_2exp(R, R, W_TYPE_SIZE);
}

/* Split X (required: 0 <= X < B^2) into its high limb *X1 and low limb *X0.
   Fails the current test rather than silently dropping higher limbs.  */
static void
mpz_to_pair(mp_limb_t *x1, mp_limb_t *x0, const mpz_t x)
{
    TEST_ASSERT_MESSAGE(mpz_sgn(x) >= 0 && mpz_size(x) <= 2,
                        "value does not fit in two limbs");
    *x0 = mpz_getlimbn(x, 0);
    *x1 = mpz_getlimbn(x, 1);
}

/* Compute the redc form of X_NORM, x_norm * R mod m, as two limbs.  */
static void
redc_rep_from_normal(mp_limb_t *out1, mp_limb_t *out0,
                     const mpz_t x_norm, const mpz_t m)
{
    mpz_t rep;

    mpz_init(rep);
    build_R(rep);
    mpz_mul(rep, rep, x_norm);
    mpz_mod(rep, rep, m);
    mpz_to_pair(out1, out0, rep);
    mpz_clear(rep);
}

/* Compute the expected mulredc2 result for normal inputs A_NORM and B_NORM:
   the redc form of a*b, i.e. (a*b mod m) * R mod m, as two limbs.  */
static void
expected_redc_product_pair(mp_limb_t *e1, mp_limb_t *e0,
                           const mpz_t a_norm, const mpz_t b_norm,
                           const mpz_t m)
{
    mpz_t prod;

    mpz_init(prod);
    mpz_mul(prod, a_norm, b_norm);
    mpz_mod(prod, prod, m);
    redc_rep_from_normal(e1, e0, prod, m);
    mpz_clear(prod);
}

/* Set OUT = x1*B + x0 (alias kept for readability at comparison sites).  */
static void
pair_to_mpz(mpz_t out, mp_limb_t x1, mp_limb_t x0)
{
    mpz_from_pair(out, x1, x0);
}

/* Fail the current test, printing both pairs in hex, unless
   (got1, got0) == (exp1, exp0).  WHAT labels the comparison.  */
static void
assert_pair_equal(const char *what, mp_limb_t got1, mp_limb_t got0,
                  mp_limb_t exp1, mp_limb_t exp0)
{
    char msg[192];

    if (got1 == exp1 && got0 == exp0)
        return;
    snprintf(msg, sizeof msg,
             "%s: got (0x%llx, 0x%llx), expected (0x%llx, 0x%llx)",
             what, (unsigned long long) got1, (unsigned long long) got0,
             (unsigned long long) exp1, (unsigned long long) exp0);
    TEST_FAIL_MESSAGE(msg);
}

/* Run mulredc2 on the redc forms of A_NORM and B_NORM modulo m = (m1, m0).
   Check that the result equals the GMP reference exactly and is reduced
   below m.
   All GMP temporaries are released before asserting, because a failing
   Unity assertion longjmps out of the function.  */
static void
check_mulredc2(const char *what, mp_limb_t m1, mp_limb_t m0,
               const mpz_t a_norm, const mpz_t b_norm)
{
    TEST_ASSERT_MESSAGE((m0 & 1u) == 1u, "Modulus must be odd");
    /* NOTE(review): mulredc2 presumably requires the top bit of m1 to be
       clear (the two-limb test below documents this for its m1) -- confirm
       against the implementation.  */
    TEST_ASSERT_MESSAGE((m1 >> (W_TYPE_SIZE - 1)) == 0,
                        "top bit of m1 must be clear");

    mpz_t m, r;
    mp_limb_t A1, A0, B1, B0, E1, E0;

    mpz_inits(m, r, NULL);
    build_modulus(m, m1, m0);
    int inputs_reduced = mpz_sgn(a_norm) >= 0 && mpz_cmp(a_norm, m) < 0
                         && mpz_sgn(b_norm) >= 0 && mpz_cmp(b_norm, m) < 0;
    redc_rep_from_normal(&A1, &A0, a_norm, m);
    redc_rep_from_normal(&B1, &B0, b_norm, m);
    expected_redc_product_pair(&E1, &E0, a_norm, b_norm, m);

    mp_limb_t mi = binv_limb(m0);
    mp_limb_t r1;
    mp_limb_t r0 = mulredc2(&r1, A1, A0, B1, B0, m1, m0, mi);

    pair_to_mpz(r, r1, r0);
    int result_reduced = mpz_cmp(r, m) < 0;
    mpz_clears(m, r, NULL);

    TEST_ASSERT_MESSAGE(inputs_reduced, "normal inputs must lie in [0, m)");
    assert_pair_equal(what, r1, r0, E1, E0);
    TEST_ASSERT_MESSAGE(result_reduced,
                        "mulredc2 result is not reduced below m");
}

/* Test 1 driver: one-limb odd modulus m0, normal inputs a_val, b_val
   (each reduced mod m0 first).  Since m < B, the exact comparison against
   the reference also pins the high result limb to zero.  */
static void
run_one_limb_case(mp_limb_t m0, mp_limb_t a_val, mp_limb_t b_val)
{
    /* Checked before the '%' below so that m0 == 0 cannot divide by zero.  */
    TEST_ASSERT_MESSAGE((m0 & 1u) == 1u, "Modulus must be odd");

    mpz_t a_norm, b_norm;

    mpz_inits(a_norm, b_norm, NULL);
    mpz_from_pair(a_norm, 0, a_val % m0);
    mpz_from_pair(b_norm, 0, b_val % m0);
    check_mulredc2("mulredc2 (one-limb m)", 0, m0, a_norm, b_norm);
    mpz_clears(a_norm, b_norm, NULL);
}

/* Test 2 driver: two-limb modulus (m1, m0) with normal inputs given as
   limb pairs (a1, a0) and (b1, b0); each input must be below m.  */
static void
run_two_limb_case(mp_limb_t m1, mp_limb_t m0,
                  mp_limb_t a1, mp_limb_t a0, mp_limb_t b1, mp_limb_t b0)
{
    mpz_t a_norm, b_norm;

    mpz_inits(a_norm, b_norm, NULL);
    mpz_from_pair(a_norm, a1, a0);
    mpz_from_pair(b_norm, b1, b0);
    check_mulredc2("mulredc2 (two-limb m)", m1, m0, a_norm, b_norm);
    mpz_clears(a_norm, b_norm, NULL);
}

void
test_mulredc2_one_limb_basic(void)
{
    /* Small odd prime modulus.  */
    run_one_limb_case(101u, 5u, 7u);
}

void
test_mulredc2_one_limb_edge_values(void)
{
    /* Largest one-limb modulus: m0 = B - 1 (all ones, hence odd).  */
    const mp_limb_t Bm1 = ~(mp_limb_t) 0;

    run_one_limb_case(Bm1, 0u, 123u);              /* zero */
    run_one_limb_case(Bm1, 1u, 1u);                /* identity */
    run_one_limb_case(Bm1, 2u, 3u);
    run_one_limb_case(Bm1, (mp_limb_t) 17, (mp_limb_t) 19);
    run_one_limb_case(Bm1, Bm1 - 1u, Bm1 - 3u);    /* values near modulus */
}

/* Test 2: two-limb modulus to exercise the two-limb path.  */
void
test_mulredc2_two_limb_basic(void)
{
    const mp_limb_t max = ~(mp_limb_t) 0;          /* B - 1 (odd) */

    /* m = 1*B + (B - 1) = 2B - 1: odd, and the MSB of m1 is clear.
       a = B + 12345 and b = B - 54321 = (0, (B - 1) - 54320), both < m.  */
    run_two_limb_case(1, max, 1, 12345, 0, max - 54320);
}

/* Test 2b: two-limb moduli whose high limb is as large as possible while
   keeping its top bit clear.  */
void
test_mulredc2_two_limb_large_modulus(void)
{
    const mp_limb_t max = ~(mp_limb_t) 0;
    const mp_limb_t top = max >> 1;                /* B/2 - 1 */

    run_two_limb_case(top, max, top - 1, 7, 12345, max - 2);
    /* Both inputs just below m.  */
    run_two_limb_case(top, max, top, max - 1, top, max - 2);
    /* Small odd low limb.  */
    run_two_limb_case(top, 12345, top, 12344, top - 1, max);
}

/* Test 3: the redc form of 1 (R mod m) is a multiplicative identity, and
   the redc form of 0 annihilates.  */
void
test_mulredc2_identity_and_zero(void)
{
    const mp_limb_t m0 = 65537u;                   /* odd prime */
    const mp_limb_t m1 = 0;
    mpz_t m, x;
    mp_limb_t A1, A0, One1, One0, Zero1, Zero0;

    mpz_inits(m, x, NULL);
    build_modulus(m, m1, m0);
    mpz_set_ui(x, 123456u % m0);
    redc_rep_from_normal(&A1, &A0, x, m);
    mpz_set_ui(x, 1u);
    redc_rep_from_normal(&One1, &One0, x, m);
    mpz_set_ui(x, 0u);
    redc_rep_from_normal(&Zero1, &Zero0, x, m);
    mpz_clears(m, x, NULL);

    mp_limb_t mi = binv_limb(m0);
    mp_limb_t r1, r0;

    r0 = mulredc2(&r1, A1, A0, One1, One0, m1, m0, mi);
    assert_pair_equal("A * 1", r1, r0, A1, A0);

    r0 = mulredc2(&r1, One1, One0, A1, A0, m1, m0, mi);
    assert_pair_equal("1 * A", r1, r0, A1, A0);

    r0 = mulredc2(&r1, A1, A0, Zero1, Zero0, m1, m0, mi);
    assert_pair_equal("A * 0", r1, r0, 0, 0);

    r0 = mulredc2(&r1, Zero1, Zero0, A1, A0, m1, m0, mi);
    assert_pair_equal("0 * A", r1, r0, 0, 0);
}

/* Test 4: commutativity, A*B == B*A in redc form.  */
void
test_mulredc2_commutativity(void)
{
    const mp_limb_t m1 = 1;
    const mp_limb_t m0 = ~(mp_limb_t) 0;           /* B - 1, odd */
    mpz_t m, x;
    mp_limb_t A1, A0, B1, B0;

    mpz_inits(m, x, NULL);
    build_modulus(m, m1, m0);

    mpz_from_pair(x, 1, 333);                      /* B + 333 */
    TEST_ASSERT(mpz_cmp(x, m) < 0);
    redc_rep_from_normal(&A1, &A0, x, m);

    mpz_from_pair(x, 0, m0 - 776);                 /* B - 777 */
    TEST_ASSERT(mpz_cmp(x, m) < 0);
    redc_rep_from_normal(&B1, &B0, x, m);
    mpz_clears(m, x, NULL);

    mp_limb_t mi = binv_limb(m0);
    mp_limb_t r1_ab, r1_ba;
    mp_limb_t r0_ab = mulredc2(&r1_ab, A1, A0, B1, B0, m1, m0, mi);
    mp_limb_t r0_ba = mulredc2(&r1_ba, B1, B0, A1, A0, m1, m0, mi);

    assert_pair_equal("A * B vs B * A", r1_ab, r0_ab, r1_ba, r0_ba);
}

int
main(void)
{
    UNITY_BEGIN();
    RUN_TEST(test_mulredc2_one_limb_basic);
    RUN_TEST(test_mulredc2_one_limb_edge_values);
    RUN_TEST(test_mulredc2_two_limb_basic);
    RUN_TEST(test_mulredc2_two_limb_large_modulus);
    RUN_TEST(test_mulredc2_identity_and_zero);
    RUN_TEST(test_mulredc2_commutativity);
    return UNITY_END();
}