// pqc / rust-engine/src/lwe.rs
// Author: wuhp
// Commit message: "Update rust-engine/src/lwe.rs"
// Commit: 9447f63 (verified)
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;
use rand_distr::{Distribution, Uniform};
/// Largest supported ring dimension: every polynomial lives in a fixed-size
/// buffer of this many coefficients (2^11).
pub const N: usize = 1 << 11;
/// Fixed-capacity polynomial; a parameter set with dimension `n` only uses
/// the first `n` coefficients.
pub type Poly = [i16; N];
/// Twiddle factors for the Kyber / FIPS 203 NTT over q = 3329.
///
/// `ZETAS[k] = 17^BitRev7(k) mod 3329`, where 17 is a primitive 256-th root of
/// unity mod 3329 and `BitRev7` reverses the 7-bit index. Entries are stored in
/// plain (non-Montgomery) form as centered representatives (|x| < q/2), e.g.
/// `ZETAS[1] = -1600 ≡ 1729 = 17^64`.
///
/// The forward and inverse transforms use indices 1..=127 and the base-case
/// multiplication uses 64..=127, so `ZETAS[0] = 1` is never read.
const ZETAS: [i16; 128] = [
1, -1600, -749, -40, -687, 630, -1432, 848, 1062, -1410, 193, 797, -543, -69, 569, -1583, 296, -882, 1339, 1476, -283, 56, -1089, 1333, 1426, -1235, 535, -447, -936, -450, -1355, 821, 289, 331, -76, -1573, 1197, -1025, -1052, -1274, 650, -1352, -816, 632, -464, 33, 1320, -1414, -1010, 1435, 807, 452, 1438, -461, 1534, -927, -682, -712, 1481, 648, -855, -219, 1227, 910, 17, -568, 583, -680, 1637, 723, -1041, 1100, 1409, -667, -48, 233, 756, -1173, -314, -279, -1626, 1651, -540, -1540, -1482, 952, 1461, -642, 939, -1021, -892, -941, 733, -992, 268, 641, 1584, -1031, -1292, -109, 375, -780, -1239, 1645, 1063, 319, -556, 757, -1230, 561, -863, -735, -525, 1092, 403, 1026, 1143, -1179, -554, 886, -1607, 1212, -1455, 1029, -1219, -394, 885, -1175,
];
/// Parameters of a single-polynomial Ring-LWE public-key encryption scheme over
/// `Z_q[x] / (x^n + 1)`.
///
/// With `n = 256` and `q = 3329` (Kyber's ring) multiplication goes through the
/// NTT; any other dimension `n <= N` uses schoolbook multiplication.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RingLwe {
    /// Ring dimension: number of coefficients of each `Poly` that are used.
    pub n: usize,
    /// Coefficient modulus.
    pub q: i16,
    /// Noise parameter: `eta` of the centered binomial distribution (rounded to
    /// the nearest integer) when `is_cbd` is true, otherwise the standard
    /// deviation of a rounded Gaussian.
    pub std_dev: f64,
    /// `true` selects centered-binomial noise, `false` rounded-Gaussian noise.
    pub is_cbd: bool,
}
impl RingLwe {
    /// Creates a parameter set with centered-binomial noise, `eta = std_dev.round()`.
    ///
    /// # Panics
    /// Panics if `n > N`, or if `q` is outside `2..=i16::MAX` (coefficients are
    /// stored as `i16`, so a larger modulus would be silently truncated).
    pub fn new(n: usize, q: i32, std_dev: f64) -> Self {
        // Delegate so both constructors enforce the same checks; without the
        // `n <= N` check, sampling would later index past the `Poly` buffer.
        Self::with_type(n, q, std_dev, true)
    }

    /// Creates a parameter set with an explicit noise distribution.
    ///
    /// `noise_param` is `eta` of the centered binomial distribution when `is_cbd`
    /// is `true` (rounded to the nearest integer), otherwise the standard
    /// deviation of a rounded Gaussian.
    ///
    /// # Panics
    /// Panics if `n > N`, or if `q` is outside `2..=i16::MAX`.
    pub fn with_type(n: usize, q: i32, noise_param: f64, is_cbd: bool) -> Self {
        assert!(n <= N, "Dimension n {} exceeds max N {}", n, N);
        assert!(
            q > 1 && q <= i32::from(i16::MAX),
            "Modulus q {} must be in 2..={}",
            q,
            i16::MAX
        );
        Self { n, q: q as i16, std_dev: noise_param, is_cbd }
    }

    /// `true` for Kyber's parameters (`n = 256`, `q = 3329`), the only set for
    /// which the NTT code paths are used.
    fn uses_ntt(&self) -> bool {
        self.n == 256 && self.q == 3329
    }

    /// Reduces `a` into `[0, q)` for an arbitrary modulus `q`.
    fn reduce_gen(&self, a: i32) -> i16 {
        let mut res = (a % self.q as i32) as i16;
        // Branch-free: add q when the truncated remainder is negative.
        res += (res >> 15) & self.q;
        res
    }

    /// Barrett reduction into `[0, 3329)` when `q = 3329`; any other modulus
    /// falls back to [`Self::reduce_gen`].
    ///
    /// The quotient estimate uses `v = floor(2^26 / 3329) = 20158`. For
    /// `|a| < 2^26` (all call sites here) the raw remainder lies in `[-q, 2q)`,
    /// and the two masked corrections bring it into `[0, q)`.
    fn f_reduce(&self, a: i32) -> i16 {
        if self.q != 3329 {
            return self.reduce_gen(a);
        }
        let v = 20158i32;
        let t = ((v as i64 * a as i64) >> 26) as i32;
        let mut res = (a - t * 3329) as i16;
        // Subtract q if res >= q, then add q if res < 0 (both branch-free).
        res -= 3329 & (((3328 - res as i32) >> 15) as i16);
        res += 3329 & ((res as i32 >> 15) as i16);
        res
    }

    /// Fills `poly[..n]` with noise coefficients reduced into `[0, q)`; entries
    /// at index `>= n` are left untouched.
    ///
    /// CBD mode: each coefficient is the difference of two counts of heads over
    /// `eta = std_dev.round()` fair coin flips. Gaussian mode: rounded samples
    /// from `Normal(0, std_dev)`.
    ///
    /// # Panics
    /// In Gaussian mode, panics if `rand_distr::Normal::new` rejects `std_dev`.
    fn sample_noise(&self, rng: &mut ChaCha20Rng, poly: &mut Poly) {
        if self.is_cbd {
            let eta = self.std_dev.round() as usize;
            for i in 0..self.n {
                let mut a = 0;
                let mut b = 0;
                for _ in 0..eta {
                    if rng.gen::<bool>() {
                        a += 1;
                    }
                    if rng.gen::<bool>() {
                        b += 1;
                    }
                }
                poly[i] = self.f_reduce(a - b);
            }
        } else {
            let dist = rand_distr::Normal::new(0.0, self.std_dev)
                .expect("std_dev must be a valid normal standard deviation");
            for i in 0..self.n {
                poly[i] = self.reduce_gen(dist.sample(rng).round() as i32);
            }
        }
    }

    /// Draws one noise polynomial from a freshly OS-seeded RNG and returns its
    /// first coefficient, lifted to the centered range (values above `q/2`
    /// are mapped to negative numbers).
    pub fn sample_noise_debug(&self) -> i32 {
        let mut rng = ChaCha20Rng::from_entropy();
        let mut poly = [0i16; N];
        self.sample_noise(&mut rng, &mut poly);
        let val = poly[0] as i32;
        if val > (self.q as i32) / 2 {
            val - self.q as i32
        } else {
            val
        }
    }

    /// In-place forward NTT (FIPS 203 Algorithm 9): seven butterfly layers
    /// mapping 256 coefficients to 128 degree-1 residues.
    ///
    /// No-op unless the parameters are Kyber's (`n = 256`, `q = 3329`).
    /// Output coefficients are reduced into `[0, q)`.
    pub fn ntt(&self, p: &mut Poly) {
        if !self.uses_ntt() {
            return;
        }
        let mut k = 1;
        let mut len = 128;
        while len >= 2 {
            for start in (0..self.n).step_by(2 * len) {
                let zeta = ZETAS[k];
                k += 1;
                for j in start..(start + len) {
                    let t = self.f_reduce(zeta as i32 * p[j + len] as i32);
                    let pj = p[j];
                    p[j + len] = self.f_reduce(pj as i32 - t as i32);
                    p[j] = self.f_reduce(pj as i32 + t as i32);
                }
            }
            len >>= 1;
        }
        for i in 0..self.n {
            p[i] = self.f_reduce(p[i] as i32);
        }
    }

    /// In-place inverse NTT (FIPS 203 Algorithm 10), including the final scaling
    /// by `128^-1 mod q`, so that `intt(ntt(f)) == f`.
    ///
    /// No-op unless the parameters are Kyber's (`n = 256`, `q = 3329`).
    /// Output coefficients are reduced into `[0, q)`.
    pub fn intt(&self, p: &mut Poly) {
        if !self.uses_ntt() {
            return;
        }
        let mut k = 127;
        let mut len = 2;
        while len <= 128 {
            for start in (0..self.n).step_by(2 * len) {
                let zeta = ZETAS[k];
                k -= 1;
                for j in start..(start + len) {
                    let t = p[j];
                    p[j] = self.f_reduce(t as i32 + p[j + len] as i32);
                    // FIPS 203 multiplies by +zeta here: ZETAS[k] already equals
                    // -(matching forward twiddle)^-1 because 17^128 = -1 mod q.
                    // Negating it would flip the sign of every upper-half output.
                    let diff = p[j + len] as i32 - t as i32;
                    p[j + len] = self.f_reduce(zeta as i32 * diff);
                }
            }
            len <<= 1;
        }
        // 3303 = 128^-1 mod 3329 undoes the factor 2 picked up in each of the 7 layers.
        let n_inv: i32 = 3303;
        for i in 0..self.n {
            p[i] = self.f_reduce(p[i] as i32 * n_inv);
        }
    }

    /// Product of two degree-1 residues modulo `x^2 - zeta`:
    /// `(a0 + a1 x)(b0 + b1 x) = (a0 b0 + a1 b1 zeta) + (a0 b1 + a1 b0) x`.
    fn basemul(&self, a0: i16, a1: i16, b0: i16, b1: i16, zeta: i16) -> (i16, i16) {
        let r0 = self.f_reduce(
            a0 as i32 * b0 as i32 + self.f_reduce(a1 as i32 * b1 as i32) as i32 * zeta as i32,
        );
        let r1 = self.f_reduce(a0 as i32 * b1 as i32 + a1 as i32 * b0 as i32);
        (r0, r1)
    }

    /// Pointwise product of two NTT-domain polynomials (FIPS 203 MultiplyNTTs).
    /// Residue pairs `4i..4i+2` and `4i+2..4i+4` use the moduli `x^2 - ZETAS[64+i]`
    /// and `x^2 + ZETAS[64+i]`, respectively.
    fn ntt_mul(&self, a: &Poly, b: &Poly) -> Poly {
        let mut c = [0i16; N];
        for i in 0..64 {
            let zeta = ZETAS[64 + i];
            let (r0, r1) = self.basemul(a[4 * i], a[4 * i + 1], b[4 * i], b[4 * i + 1], zeta);
            let (r2, r3) =
                self.basemul(a[4 * i + 2], a[4 * i + 3], b[4 * i + 2], b[4 * i + 3], -zeta);
            c[4 * i] = r0;
            c[4 * i + 1] = r1;
            c[4 * i + 2] = r2;
            c[4 * i + 3] = r3;
        }
        c
    }

    /// Multiplies two polynomials in `Z_q[x] / (x^n + 1)`.
    ///
    /// For Kyber's parameters both inputs must already be in the NTT domain, and
    /// the result is in the NTT domain. For every other parameter set, the inputs
    /// and the result are in the coefficient domain, and an O(n^2) negacyclic
    /// schoolbook product is used.
    pub fn poly_mul(&self, a: &Poly, b: &Poly) -> Poly {
        if self.uses_ntt() {
            return self.ntt_mul(a, b);
        }
        let mut c = [0i16; N];
        for i in 0..self.n {
            for j in 0..self.n {
                let prod = a[i] as i32 * b[j] as i32;
                // x^n = -1 in this ring, so terms that wrap around change sign.
                let (idx, term) = if i + j < self.n {
                    (i + j, prod)
                } else {
                    (i + j - self.n, -prod)
                };
                c[idx] = self.reduce_gen(c[idx] as i32 + term);
            }
        }
        c
    }

    /// Deterministically expands the public polynomial `a` (coefficients uniform
    /// in `[0, q)`, coefficient domain) from a 32-byte seed. Key generation and
    /// encryption must both derive `a` through this function.
    fn expand_a(&self, seed: &[u8; 32]) -> Poly {
        let mut a_rng = ChaCha20Rng::from_seed(*seed);
        let uniform = Uniform::new(0, self.q);
        let mut a_poly = [0i16; N];
        for coeff in a_poly.iter_mut().take(self.n) {
            *coeff = uniform.sample(&mut a_rng);
        }
        a_poly
    }

    /// Generates a key pair using `rng` for the public seed and the noise.
    ///
    /// Returns `(seed, t, s)`, where `t = a*s + e` is the public key and `s` the
    /// secret key. Both are in the NTT domain for Kyber parameters, and in the
    /// coefficient domain otherwise.
    pub fn keygen_with_rng(&self, rng: &mut ChaCha20Rng) -> ([u8; 32], Poly, Poly) {
        let mut seed = [0u8; 32];
        rng.fill(&mut seed);
        let mut a_domain = self.expand_a(&seed);
        let mut s_domain = [0i16; N];
        let mut e_domain = [0i16; N];
        self.sample_noise(rng, &mut s_domain);
        self.sample_noise(rng, &mut e_domain);
        if self.uses_ntt() {
            self.ntt(&mut s_domain);
            self.ntt(&mut e_domain);
            self.ntt(&mut a_domain);
        }
        let prod = self.poly_mul(&a_domain, &s_domain);
        let mut t_domain = [0i16; N];
        for i in 0..self.n {
            t_domain[i] = self.reduce_gen(prod[i] as i32 + e_domain[i] as i32);
        }
        (seed, t_domain, s_domain)
    }

    /// [`Self::keygen_with_rng`] with a freshly OS-seeded RNG.
    pub fn keygen(&self) -> ([u8; 32], Poly, Poly) {
        let mut rng = ChaCha20Rng::from_entropy();
        self.keygen_with_rng(&mut rng)
    }

    /// Encrypts one 256-bit message under the public key `(seed, t_domain)`.
    ///
    /// Returns `(u, v)`, where `u = a*r + e1` (same domain as the keys) and
    /// `v = t*r + e2 + floor(q/2) * m` (always the coefficient domain). Bit `i` of
    /// the message goes into coefficient `i`; when `n > 256`, the message bytes
    /// are reused cyclically, and when `n < 256`, only the first `n` bits are encoded.
    pub fn encrypt_with_rng(&self, rng: &mut ChaCha20Rng, seed: &[u8; 32], t_domain: &Poly, ptxt: &[u8; 32]) -> (Poly, Poly) {
        let mut a_domain = self.expand_a(seed);
        let mut r_domain = [0i16; N];
        let mut e1_domain = [0i16; N];
        let mut e2 = [0i16; N];
        self.sample_noise(rng, &mut r_domain);
        self.sample_noise(rng, &mut e1_domain);
        self.sample_noise(rng, &mut e2);
        let use_ntt = self.uses_ntt();
        if use_ntt {
            self.ntt(&mut r_domain);
            self.ntt(&mut e1_domain);
            self.ntt(&mut a_domain);
        }
        // u stays in the key domain; decrypt multiplies it by s there.
        let ar_prod = self.poly_mul(&a_domain, &r_domain);
        let mut u_domain = [0i16; N];
        for i in 0..self.n {
            u_domain[i] = self.reduce_gen(ar_prod[i] as i32 + e1_domain[i] as i32);
        }
        let mut tr_prod = self.poly_mul(t_domain, &r_domain);
        if use_ntt {
            self.intt(&mut tr_prod);
        }
        let half_q = i32::from(self.q / 2);
        let mut v = [0i16; N];
        for i in 0..self.n {
            let bit = i32::from((ptxt[(i / 8) % 32] >> (i % 8)) & 1);
            v[i] = self.reduce_gen(tr_prod[i] as i32 + e2[i] as i32 + bit * half_q);
        }
        (u_domain, v)
    }

    /// [`Self::encrypt_with_rng`] with a freshly OS-seeded RNG.
    pub fn encrypt(&self, seed: &[u8; 32], t_domain: &Poly, ptxt: &[u8; 32]) -> (Poly, Poly) {
        let mut rng = ChaCha20Rng::from_entropy();
        self.encrypt_with_rng(&mut rng, seed, t_domain, ptxt)
    }

    /// Decrypts `(u, v)` with the secret key `s`, returning the 256-bit message.
    ///
    /// Each of the first `min(n, 256)` coefficients of `v - s*u` decodes to bit 1
    /// when it is closer to `q/2` than to `0`. Bits beyond `n` stay zero.
    pub fn decrypt(&self, s_domain: &Poly, u_domain: &Poly, v: &Poly) -> [u8; 32] {
        let mut su_prod = self.poly_mul(s_domain, u_domain);
        if self.uses_ntt() {
            self.intt(&mut su_prod);
        }
        let q = self.q as i32;
        let mut ptxt = [0u8; 32];
        for i in 0..self.n.min(256) {
            let mut m_approx = (v[i] as i32 - su_prod[i] as i32) % q;
            // Branch-free lift of a negative remainder into [0, q) (i32 sign mask).
            m_approx += (m_approx >> 31) & q;
            // round(2 * m_approx / q) mod 2.
            let bit = (((m_approx * 2 + q / 2) / q) & 1) as u8;
            ptxt[i / 8] |= bit << (i % 8);
        }
        ptxt
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Independent reference: negacyclic schoolbook product mod `(x^n + 1, q)`,
    /// with coefficients returned in `[0, q)`.
    fn reference_mul(a: &Poly, b: &Poly, n: usize, q: i64) -> Vec<i64> {
        let mut c = vec![0i64; n];
        for i in 0..n {
            for j in 0..n {
                let prod = i64::from(a[i]) * i64::from(b[j]);
                if i + j < n {
                    c[i + j] += prod;
                } else {
                    c[i + j - n] -= prod;
                }
            }
        }
        c.into_iter().map(|x| x.rem_euclid(q)).collect()
    }

    /// A fixed, non-trivial 32-byte message.
    fn sample_message() -> [u8; 32] {
        let mut msg = [0u8; 32];
        for (i, byte) in msg.iter_mut().enumerate() {
            *byte = (i as u8).wrapping_mul(37) ^ 0x5a;
        }
        msg
    }

    #[test]
    fn test_ntt_roundtrip() {
        // FIPS 203's NTT^-1 ends with a multiplication by 128^-1 mod q, so the
        // roundtrip must be the exact identity, not a scaled copy.
        let rlwe = RingLwe::new(256, 3329, 2.0);
        let mut poly = [0i16; N];
        for (i, coeff) in poly.iter_mut().enumerate().take(256) {
            *coeff = (i as i16 * 17) % 3329;
        }
        let original = poly; // `Poly` is `Copy`.
        rlwe.ntt(&mut poly);
        rlwe.intt(&mut poly);
        for i in 0..256 {
            let diff = (poly[i] as i32 - original[i] as i32).rem_euclid(3329);
            assert_eq!(diff, 0, "Mismatch at index {}", i);
        }
    }

    #[test]
    fn test_ntt_mul_matches_schoolbook() {
        let rlwe = RingLwe::new(256, 3329, 2.0);
        let mut a = [0i16; N];
        let mut b = [0i16; N];
        for i in 0..256 {
            a[i] = ((i * 31 + 7) % 3329) as i16;
            b[i] = ((i * i + 3) % 3329) as i16;
        }
        let expected = reference_mul(&a, &b, 256, 3329);
        rlwe.ntt(&mut a);
        rlwe.ntt(&mut b);
        let mut c = rlwe.poly_mul(&a, &b);
        rlwe.intt(&mut c);
        for i in 0..256 {
            assert_eq!(i64::from(c[i]), expected[i], "Mismatch at index {}", i);
        }
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip_ntt() {
        let rlwe = RingLwe::new(256, 3329, 2.0);
        let mut rng = ChaCha20Rng::seed_from_u64(7);
        let (seed, t, s) = rlwe.keygen_with_rng(&mut rng);
        let msg = sample_message();
        let (u, v) = rlwe.encrypt_with_rng(&mut rng, &seed, &t, &msg);
        assert_eq!(rlwe.decrypt(&s, &u, &v), msg);
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip_schoolbook() {
        // n != 256 takes the schoolbook path. With n = 64 and eta = 2 the total
        // noise is at most 2 * 64 * 4 + 2 < q / 4, so decryption cannot fail.
        // Only the first n / 8 = 8 message bytes are encoded.
        let rlwe = RingLwe::new(64, 3329, 2.0);
        let mut rng = ChaCha20Rng::seed_from_u64(11);
        let (seed, t, s) = rlwe.keygen_with_rng(&mut rng);
        let msg = sample_message();
        let (u, v) = rlwe.encrypt_with_rng(&mut rng, &seed, &t, &msg);
        let out = rlwe.decrypt(&s, &u, &v);
        assert_eq!(out[..8], msg[..8]);
        assert!(out[8..].iter().all(|&b| b == 0));
    }
}