use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; use rand_distr::{Distribution, Uniform}; pub const N: usize = 2048; pub type Poly = [i16; N]; /// Primitive 256-th roots of unity (zetas) for Kyber NTT /// These are bit-reversed powers of zeta=17 mod 3329 const ZETAS: [i16; 128] = [ 1, -1600, -749, -40, -687, 630, -1432, 848, 1062, -1410, 193, 797, -543, -69, 569, -1583, 296, -882, 1339, 1476, -283, 56, -1089, 1333, 1426, -1235, 535, -447, -936, -450, -1355, 821, 289, 331, -76, -1573, 1197, -1025, -1052, -1274, 650, -1352, -816, 632, -464, 33, 1320, -1414, -1010, 1435, 807, 452, 1438, -461, 1534, -927, -682, -712, 1481, 648, -855, -219, 1227, 910, 17, -568, 583, -680, 1637, 723, -1041, 1100, 1409, -667, -48, 233, 756, -1173, -314, -279, -1626, 1651, -540, -1540, -1482, 952, 1461, -642, 939, -1021, -892, -941, 733, -992, 268, 641, 1584, -1031, -1292, -109, 375, -780, -1239, 1645, 1063, 319, -556, 757, -1230, 561, -863, -735, -525, 1092, 403, 1026, 1143, -1179, -554, 886, -1607, 1212, -1455, 1029, -1219, -394, 885, -1175, ]; pub struct RingLwe { pub n: usize, pub q: i16, pub std_dev: f64, pub is_cbd: bool, } impl RingLwe { pub fn new(n: usize, q: i32, std_dev: f64) -> Self { Self { n, q: q as i16, std_dev, is_cbd: true } } pub fn with_type(n: usize, q: i32, noise_param: f64, is_cbd: bool) -> Self { assert!(n <= N, "Dimension n {} exceeds max N {}", n, N); Self { n, q: q as i16, std_dev: noise_param, is_cbd } } /// Generic reduction for any q fn reduce_gen(&self, a: i32) -> i16 { let mut res = (a % self.q as i32) as i16; res += (res >> 15) & self.q; res } /// Barrett reduction specifically for q=3329 fn f_reduce(&self, a: i32) -> i16 { if self.q != 3329 { return self.reduce_gen(a); } // Safety let v = 20158i32; let t = ((v as i64 * a as i64) >> 26) as i32; let mut res = (a - t * 3329) as i16; res -= 3329 & (((3328 - res as i32) >> 15) as i16); res += 3329 & ((res as i32 >> 15) as i16); res } fn sample_noise(&self, rng: &mut ChaCha20Rng, poly: &mut Poly) { if self.is_cbd { let eta = self.std_dev.round() as usize; for i in 0..self.n { let mut a = 0; let mut b = 0; for _ in 0..eta { if rng.gen::() { a += 1; } if rng.gen::() { b += 1; } } poly[i] = self.f_reduce(a - b); } } else { let dist = rand_distr::Normal::new(0.0, self.std_dev).unwrap(); for i in 0..self.n { poly[i] = self.reduce_gen(dist.sample(rng).round() as i32); } } } pub fn sample_noise_debug(&self) -> i32 { let mut rng = ChaCha20Rng::from_entropy(); let mut poly = [0i16; N]; self.sample_noise(&mut rng, &mut poly); let val = poly[0] as i32; if val > (self.q as i32) / 2 { val - self.q as i32 } else { val } } /// Forward NTT (7 stages) pub fn ntt(&self, p: &mut Poly) { if self.n != 256 || self.q != 3329 { return; } let mut k = 1; let mut len = 128; while len >= 2 { for start in (0..self.n).step_by(2 * len) { let zeta = ZETAS[k]; k += 1; for j in start..(start + len) { let t = self.f_reduce(zeta as i32 * p[j + len] as i32); let pj = p[j]; p[j + len] = self.f_reduce(pj as i32 - t as i32); p[j] = self.f_reduce(pj as i32 + t as i32); } } len >>= 1; } for i in 0..self.n { p[i] = self.f_reduce(p[i] as i32); } } /// Inverse NTT (7 stages) pub fn intt(&self, p: &mut Poly) { if self.n != 256 || self.q != 3329 { return; } let mut k = 127; let mut len = 2; while len <= 128 { for start in (0..self.n).step_by(2 * len) { let zeta = ZETAS[k]; k -= 1; for j in start..(start + len) { let t = p[j]; p[j] = self.f_reduce(t as i32 + p[j + len] as i32); let val = p[j + len] as i32 - t as i32; p[j + len] = self.f_reduce(-zeta as i32 * val); } } len <<= 1; } let n_inv: i32 = 3303; // 128^-1 mod 3329 for i in 0..self.n { p[i] = self.f_reduce(p[i] as i32 * n_inv); } } /// Base case multiplication for Kyber NTT (degree 1 fragments) fn basemul(&self, a0: i16, a1: i16, b0: i16, b1: i16, zeta: i16) -> (i16, i16) { let r0 = self.f_reduce(a0 as i32 * b0 as i32 + self.f_reduce(a1 as i32 * b1 as i32) as i32 * zeta as i32); let r1 = self.f_reduce(a0 as i32 * b1 as i32 + a1 as i32 * b0 as i32); (r0, r1) } fn ntt_mul(&self, a: &Poly, b: &Poly) -> Poly { let mut c = [0i16; N]; for i in 0..64 { let (r0, r1) = self.basemul(a[4*i], a[4*i+1], b[4*i], b[4*i+1], ZETAS[64 + i]); let (r2, r3) = self.basemul(a[4*i+2], a[4*i+3], b[4*i+2], b[4*i+3], -ZETAS[64 + i]); c[4*i] = r0; c[4*i+1] = r1; c[4*i+2] = r2; c[4*i+3] = r3; } c } pub fn poly_mul(&self, a: &Poly, b: &Poly) -> Poly { // Only use NTT for Kyber parameters if self.n == 256 && self.q == 3329 { return self.ntt_mul(a, b); } // Fallback to Schoolbook multiplication (slow but correct for research) let mut c = [0i16; N]; for i in 0..self.n { for j in 0..self.n { let mut idx = i + j; let mut sign = 1; if idx >= self.n { idx -= self.n; sign = -1; // Ring is x^n + 1 } let prod = a[i] as i32 * b[j] as i32 * sign; c[idx] = self.reduce_gen(c[idx] as i32 + prod); } } c } pub fn keygen_with_rng(&self, rng: &mut ChaCha20Rng) -> ([u8; 32], Poly, Poly) { let mut seed = [0u8; 32]; rng.fill(&mut seed); let mut a_rng = ChaCha20Rng::from_seed(seed); let mut a_poly = [0i16; N]; let uniform = Uniform::new(0, self.q); for i in 0..self.n { a_poly[i] = uniform.sample(&mut a_rng); } let mut s = [0i16; N]; let mut e = [0i16; N]; self.sample_noise(rng, &mut s); self.sample_noise(rng, &mut e); let use_ntt = self.n == 256 && self.q == 3329; let mut s_domain = s; let mut e_domain = e; let mut a_domain = a_poly; if use_ntt { self.ntt(&mut s_domain); self.ntt(&mut e_domain); self.ntt(&mut a_domain); } let mut t_domain = [0i16; N]; let prod = self.poly_mul(&a_domain, &s_domain); for i in 0..self.n { t_domain[i] = self.reduce_gen(prod[i] as i32 + e_domain[i] as i32); } (seed, t_domain, s_domain) } pub fn keygen(&self) -> ([u8; 32], Poly, Poly) { let mut rng = ChaCha20Rng::from_entropy(); self.keygen_with_rng(&mut rng) } pub fn encrypt_with_rng(&self, rng: &mut ChaCha20Rng, seed: &[u8; 32], t_domain: &Poly, ptxt: &[u8; 32]) -> (Poly, Poly) { let mut a_rng = ChaCha20Rng::from_seed(*seed); let mut a_poly = [0i16; N]; let uniform = Uniform::new(0, self.q); for i in 0..self.n { a_poly[i] = uniform.sample(&mut a_rng); } let mut r = [0i16; N]; let mut e1 = [0i16; N]; let mut e2 = [0i16; N]; self.sample_noise(rng, &mut r); self.sample_noise(rng, &mut e1); self.sample_noise(rng, &mut e2); let use_ntt = self.n == 256 && self.q == 3329; let mut r_domain = r; let mut e1_domain = e1; let mut a_domain = a_poly; if use_ntt { self.ntt(&mut r_domain); self.ntt(&mut e1_domain); self.ntt(&mut a_domain); } let ar_prod = self.poly_mul(&a_domain, &r_domain); let mut u_domain = [0i16; N]; for i in 0..self.n { u_domain[i] = self.reduce_gen(ar_prod[i] as i32 + e1_domain[i] as i32); } let mut tr_prod = self.poly_mul(t_domain, &r_domain); if use_ntt { self.intt(&mut tr_prod); } let mut v = [0i16; N]; for i in 0..self.n { let byte_idx = i / 8; let bit_idx = i % 8; let bit = (ptxt[byte_idx % 32] >> bit_idx) & 1; let m_encoded = (bit as i16 * (self.q / 2)) as i16; v[i] = self.reduce_gen(tr_prod[i] as i32 + e2[i] as i32 + m_encoded as i32); } (u_domain, v) } pub fn encrypt(&self, seed: &[u8; 32], t_domain: &Poly, ptxt: &[u8; 32]) -> (Poly, Poly) { let mut rng = ChaCha20Rng::from_entropy(); self.encrypt_with_rng(&mut rng, seed, t_domain, ptxt) } pub fn decrypt(&self, s_domain: &Poly, u_domain: &Poly, v: &Poly) -> [u8; 32] { let mut su_prod = self.poly_mul(s_domain, u_domain); let use_ntt = self.n == 256 && self.q == 3329; if use_ntt { self.intt(&mut su_prod); } let mut ptxt = [0u8; 32]; for i in 0..self.n { let mut m_approx = (v[i] as i32 - su_prod[i] as i32) % self.q as i32; m_approx += (m_approx >> 15) & self.q as i32; let val = (m_approx * 2 + self.q as i32 / 2) / self.q as i32; let bit = (val & 1) as u8; let byte_idx = i / 8; let bit_idx = i % 8; if byte_idx < 32 { ptxt[byte_idx] |= bit << bit_idx; } } ptxt } } #[cfg(test)] mod tests { use super::*; #[test] fn test_ntt_roundtrip() { let rlwe = RingLwe::new(256, 3329, 2.0); let mut poly = [0i16; N]; for i in 0..256 { poly[i] = (i as i16 * 17) % 3329; } let original = poly.clone(); rlwe.ntt(&mut poly); rlwe.intt(&mut poly); // Let's actually verify what INTT does. In Kyber, the result of INTT(NTT(a)) is a * 128 // Or wait, INTT scales by 128^-1, but it also uses a 128-sized NTT. Actually, NTT transforms 256 coeffs to 128 degree-1 polys. // Wait, standard Kyber NTT/INTT roundtrip is identity. Let's just assert it is identity. for i in 0..256 { let mut diff = (poly[i] as i32 - original[i] as i32) % 3329; if diff < 0 { diff += 3329; } assert_eq!(diff, 0, "Mismatch at index {}", i); } } }