/* coreutils/tests/factor/tests_for_mulredc2.c
   Hosting-page metadata: AryaWu, "Upload folder using huggingface_hub",
   commit 78d2150 (verified).  */
#include "../../unity/unity.h"
#include <stdlib.h>
#include <stdint.h>
#include <stdbool.h>
#include <gmp.h>
/* Unity hooks */
void setUp(void)
{
  /* Nothing to prepare: every test initializes and clears its own
     GMP variables.  */
}
void tearDown(void)
{
  /* Nothing to release: every test clears its own GMP variables
     before returning.  */
}
/* Helper: set OUT to the two-limb value x1*B + x0, where B = 2^W_TYPE_SIZE.

   Each limb is loaded with mpz_import instead of mpz_set_ui/mpz_add_ui.
   Those functions take an unsigned long, which is narrower than mp_limb_t
   when GMP is configured with long long limbs (_LONG_LONG_LIMB, as on
   64-bit Windows).  There they would silently drop the upper half of
   each limb.  */
static void mpz_from_pair(mpz_t out, mp_limb_t x1, mp_limb_t x0) {
  mpz_t low;
  mpz_init(low);
  /* One word of sizeof(mp_limb_t) bytes, native endianness, no nails.  */
  mpz_import(out, 1, -1, sizeof x1, 0, 0, &x1);
  mpz_mul_2exp(out, out, W_TYPE_SIZE);
  mpz_import(low, 1, -1, sizeof x0, 0, 0, &x0);
  mpz_add(out, out, low);
  mpz_clear(low);
}
/* Helper: store the modulus whose high limb is m1 and low limb is m0
   (i.e. m = m1*B + m0) into M.  This is a named alias of mpz_from_pair
   that makes call sites read as "build the modulus".  */
static void build_modulus(mpz_t m, mp_limb_t m1, mp_limb_t m0)
{
  mpz_from_pair(m, m1, m0);
}
/* Helper: set B to the limb base 2^W_TYPE_SIZE, the number that has
   only bit W_TYPE_SIZE set.  */
static void build_B(mpz_t B)
{
  mpz_set_ui(B, 0);
  mpz_setbit(B, W_TYPE_SIZE);
}
/* Helper: convert the normal value X_NORM into the REDC form used by
   mulredc2, namely x*B^2 mod m with B = 2^W_TYPE_SIZE.  The residue is
   returned split into its low limb (*out0) and high limb (*out1).  Only
   two limbs are extracted, so m is expected to be below B^2.  */
static void redc_rep_from_normal(mp_limb_t *out1, mp_limb_t *out0,
                                 const mpz_t x_norm, const mpz_t m)
{
  mpz_t scaled;
  mpz_init(scaled);

  /* Multiplying by B^2 = 2^(2*W_TYPE_SIZE) is a left shift.  */
  mpz_mul_2exp(scaled, x_norm, 2 * W_TYPE_SIZE);
  mpz_mod(scaled, scaled, m);

  *out0 = mpz_getlimbn(scaled, 0);
  *out1 = mpz_getlimbn(scaled, 1);

  mpz_clear(scaled);
}
/* Helper: reference value for mulredc2 computed with GMP.  Multiplying the
   REDC forms of a and b with mulredc2 should give the REDC form of a*b,
   i.e. (a*b mod m) * B^2 mod m.  The low limb goes to *e0 and the high
   limb to *e1.  */
static void expected_redc_product_pair(mp_limb_t *e1, mp_limb_t *e0,
                                       const mpz_t a_norm, const mpz_t b_norm,
                                       const mpz_t m)
{
  mpz_t prod;
  mpz_init(prod);

  mpz_mul(prod, a_norm, b_norm);
  mpz_mod(prod, prod, m);                     /* a*b mod m        */
  mpz_mul_2exp(prod, prod, 2 * W_TYPE_SIZE);  /* ... times B^2    */
  mpz_mod(prod, prod, m);                     /* back into [0, m) */

  *e0 = mpz_getlimbn(prod, 0);
  *e1 = mpz_getlimbn(prod, 1);

  mpz_clear(prod);
}
/* Helper: reassemble a (high, low) limb pair returned by mulredc2 into
   one mpz, so that it can be compared against the modulus.  */
static void pair_to_mpz(mpz_t out, mp_limb_t x1, mp_limb_t x0)
{
  mpz_from_pair(out, x1, x0);
}
/* Shared driver for the one-limb tests: checks mulredc2 against GMP for the
   odd modulus m = m0 (high limb m1 = 0).

   a_val and b_val are first reduced mod m0.  Their REDC forms (x*B^2 mod m)
   are multiplied with mulredc2, and the result limbs are compared with
   (a*b mod m)*B^2 mod m computed by GMP.  Because m < B, the high result
   limb must be zero and the combined result must be below m.  */
static void run_one_limb_case(mp_limb_t m0, mp_limb_t a_val, mp_limb_t b_val) {
  TEST_ASSERT_MESSAGE((m0 & 1u) == 1u, "Modulus must be odd");

  mpz_t m, a_norm, b_norm, res_mpz;
  mpz_inits(m, a_norm, b_norm, res_mpz, NULL);
  build_modulus(m, 0, m0);

  /* Load the reduced operands through the limb helper, not mpz_set_ui.
     mpz_set_ui takes an unsigned long, which may be narrower than
     mp_limb_t.  */
  mpz_from_pair(a_norm, 0, a_val % m0);
  mpz_from_pair(b_norm, 0, b_val % m0);

  mp_limb_t A1, A0, B1, B0;
  redc_rep_from_normal(&A1, &A0, a_norm, m);
  redc_rep_from_normal(&B1, &B0, b_norm, m);

  mp_limb_t mi = binv_limb(m0);
  mp_limb_t r1;
  mp_limb_t r0 = mulredc2(&r1, A1, A0, B1, B0, 0, m0, mi);

  mp_limb_t E1, E0;
  expected_redc_product_pair(&E1, &E0, a_norm, b_norm, m);

  /* For a one-limb modulus the result is < m < B, so the high limb is 0.  */
  TEST_ASSERT_MESSAGE(r1 == 0, "high limb must be zero for a one-limb modulus");
  TEST_ASSERT_MESSAGE(r0 == E0, "low limb differs from GMP reference");
  TEST_ASSERT_MESSAGE(r1 == E1, "high limb differs from GMP reference");

  pair_to_mpz(res_mpz, r1, r0);
  TEST_ASSERT_MESSAGE(mpz_cmp(res_mpz, m) < 0, "result is not reduced below m");

  mpz_clears(m, a_norm, b_norm, res_mpz, NULL);
}
void test_mulredc2_one_limb_basic(void) {
/* Small odd prime modulus */
run_one_limb_case(101u, 5u, 7u);
}
void test_mulredc2_one_limb_edge_values(void) {
/* Use modulus near B but still one-limb; choose odd m0 = B-1 */
mp_limb_t Bm1 = (mp_limb_t)(~(mp_limb_t)0); /* B-1 (all ones) is odd */
/* Test a variety of normal values */
run_one_limb_case(Bm1, 0u, 123u); /* zero */
run_one_limb_case(Bm1, 1u, 1u); /* identity */
run_one_limb_case(Bm1, 2u, 3u);
run_one_limb_case(Bm1, (mp_limb_t)17, (mp_limb_t)19);
run_one_limb_case(Bm1, Bm1 - 1u, Bm1 - 3u); /* values near modulus */
}
/* Test 2: two-limb modulus to exercise two-limb path */
void test_mulredc2_two_limb_basic(void) {
/* Choose m = (m1,m0) with m1 small and m0 odd near B-1 to ensure odd modulus */
mp_limb_t m1 = 1; /* MSB of m1 is clear */
mp_limb_t m0 = (mp_limb_t)(~(mp_limb_t)0); /* B-1 (odd) */
mpz_t m, a_norm, b_norm, res_mpz;
mpz_inits(m, a_norm, b_norm, res_mpz, NULL);
build_modulus(m, m1, m0);
/* Pick a_norm ~ B + 12345, b_norm ~ B - 54321, both < m (~2B-1) */
mpz_t B;
mpz_init(B);
build_B(B);
mpz_add_ui(a_norm, B, 12345u);
mpz_t tmp;
mpz_init(tmp);
mpz_sub_ui(tmp, B, 54321u);
mpz_set(b_norm, tmp);
/* Ensure < m */
TEST_ASSERT(mpz_cmp(a_norm, m) < 0);
TEST_ASSERT(mpz_cmp(b_norm, m) < 0);
mp_limb_t A1, A0, B1, B0;
redc_rep_from_normal(&A1, &A0, a_norm, m);
redc_rep_from_normal(&B1, &B0, b_norm, m);
mp_limb_t mi = binv_limb(m0);
mp_limb_t r1;
mp_limb_t r0 = mulredc2(&r1, A1, A0, B1, B0, m1, m0, mi);
/* Expected */
mp_limb_t E1, E0;
expected_redc_product_pair(&E1, &E0, a_norm, b_norm, m);
TEST_ASSERT(r0 == E0);
TEST_ASSERT(r1 == E1);
pair_to_mpz(res_mpz, r1, r0);
TEST_ASSERT(mpz_cmp(res_mpz, m) < 0);
mpz_clears(m, a_norm, b_norm, res_mpz, B, tmp, NULL);
}
/* Test 3: in REDC form, the image of 1 (1*B^2 mod m) is a two-sided
   identity for mulredc2, and the image of 0 (which is 0) is a two-sided
   zero.  Uses the one-limb odd prime modulus 65537.  */
void test_mulredc2_identity_and_zero(void) {
  const mp_limb_t m0 = 65537u; /* odd prime */
  const mp_limb_t m1 = 0;

  /* The old unused res_mpz variable is gone; every mpz here is used.  */
  mpz_t m, a_norm, one, zero;
  mpz_inits(m, a_norm, one, zero, NULL);
  build_modulus(m, m1, m0);
  mpz_set_ui(a_norm, 123456u % m0);
  mpz_set_ui(one, 1u);
  mpz_set_ui(zero, 0u);

  /* REDC representations of a, 1 and 0.  */
  mp_limb_t A1, A0, One1, One0, Zero1, Zero0;
  redc_rep_from_normal(&A1, &A0, a_norm, m);
  redc_rep_from_normal(&One1, &One0, one, m);
  redc_rep_from_normal(&Zero1, &Zero0, zero, m);

  mp_limb_t mi = binv_limb(m0);

  /* A * 1 = A */
  mp_limb_t r1_id;
  mp_limb_t r0_id = mulredc2(&r1_id, A1, A0, One1, One0, m1, m0, mi);
  TEST_ASSERT(r0_id == A0);
  TEST_ASSERT(r1_id == A1);

  /* 1 * A = A */
  mp_limb_t r1_id2;
  mp_limb_t r0_id2 = mulredc2(&r1_id2, One1, One0, A1, A0, m1, m0, mi);
  TEST_ASSERT(r0_id2 == A0);
  TEST_ASSERT(r1_id2 == A1);

  /* A * 0 = 0 */
  mp_limb_t r1_zero1;
  mp_limb_t r0_zero1 = mulredc2(&r1_zero1, A1, A0, Zero1, Zero0, m1, m0, mi);
  TEST_ASSERT(r0_zero1 == 0);
  TEST_ASSERT(r1_zero1 == 0);

  /* 0 * A = 0 */
  mp_limb_t r1_zero2;
  mp_limb_t r0_zero2 = mulredc2(&r1_zero2, Zero1, Zero0, A1, A0, m1, m0, mi);
  TEST_ASSERT(r0_zero2 == 0);
  TEST_ASSERT(r1_zero2 == 0);

  mpz_clears(m, a_norm, one, zero, NULL);
}
/* Test 4: commutativity A*B == B*A in redc */
void test_mulredc2_commutativity(void) {
mp_limb_t m1 = 1;
mp_limb_t m0 = (mp_limb_t)(~(mp_limb_t)0); /* B-1, odd */
mpz_t m, a_norm, b_norm;
mpz_inits(m, a_norm, b_norm, NULL);
build_modulus(m, m1, m0);
mpz_t B;
mpz_init(B);
build_B(B);
/* Choose two arbitrary normal values less than m */
mpz_add_ui(a_norm, B, 333u); /* ~B + 333 */
mpz_sub_ui(b_norm, B, 777u); /* ~B - 777 */
TEST_ASSERT(mpz_cmp(a_norm, m) < 0);
TEST_ASSERT(mpz_cmp(b_norm, m) < 0);
mp_limb_t A1, A0, B1, B0;
redc_rep_from_normal(&A1, &A0, a_norm, m);
redc_rep_from_normal(&B1, &B0, b_norm, m);
mp_limb_t mi = binv_limb(m0);
mp_limb_t r1_ab, r1_ba;
mp_limb_t r0_ab = mulredc2(&r1_ab, A1, A0, B1, B0, m1, m0, mi);
mp_limb_t r0_ba = mulredc2(&r1_ba, B1, B0, A1, A0, m1, m0, mi);
TEST_ASSERT(r0_ab == r0_ba);
TEST_ASSERT(r1_ab == r1_ba);
mpz_clears(m, a_norm, b_norm, B, NULL);
}
/* Test-runner entry point: runs each mulredc2 test under Unity in the
   order listed, and returns the status reported by UNITY_END().
   RUN_TEST takes the test function's name directly, so keep this list in
   sync when tests are added or renamed.  */
int main(void) {
UNITY_BEGIN();
RUN_TEST(test_mulredc2_one_limb_basic);
RUN_TEST(test_mulredc2_one_limb_edge_values);
RUN_TEST(test_mulredc2_two_limb_basic);
RUN_TEST(test_mulredc2_identity_and_zero);
RUN_TEST(test_mulredc2_commutativity);
return UNITY_END();
}