// pqc / rust-engine/src/fuzzer.rs
// Last updated by wuhp: "Update rust-engine/src/fuzzer.rs"
// Commit a24962b (verified)
use std::time::Instant;
use serde::{Deserialize, Serialize};
use crate::lwe::RingLwe;
use crate::validator::{ParamValidator, ValidationResult};
/// Parameters for one fuzzing campaign executed by `FuzzEngine::run`.
///
/// Serialized with camelCase field names (e.g. `noiseParam`, `fuzzEdgeCases`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FuzzConfig {
    /// Number of timed key-generation / encryption / decryption round trips.
    pub iterations: usize,
    /// Dimension passed to `RingLwe::with_type` and to the parameter validator.
    pub n: usize,
    /// Modulus; narrowed with `as i32` when the `RingLwe` instance is built.
    pub q: u64,
    /// Noise parameter, interpreted according to `noise_type` (truncated to an integer for CBD).
    pub noise_param: f64,
    /// Noise family; selects the validator and whether `RingLwe` is built in CBD mode.
    pub noise_type: NoiseType,
    /// Also round-trip the all-zero and all-0xFF plaintexts.
    pub fuzz_edge_cases: bool,
    /// Also decrypt ciphertexts with single flipped bits.
    pub fuzz_ciphertext_malleability: bool,
}
/// Noise distribution family used by the fuzzed scheme.
///
/// Serialized by variant name (`"Uniform"`, `"Gaussian"`, `"Cbd"`). The enum is fieldless,
/// so it is `Copy` and `Eq` and can be compared and passed around by value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum NoiseType {
    /// Uniform noise; validated with `ParamValidator::validate_lwe`.
    Uniform,
    /// Gaussian noise; validated with `ParamValidator::validate_gaussian`.
    Gaussian,
    /// Centred binomial (CBD) noise; validated with `ParamValidator::validate_kyber`.
    Cbd,
}
/// Measurements and estimates gathered by one `FuzzEngine::run`.
///
/// Serialized with camelCase field names; field order is the serialization order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FuzzStats {
    /// Number of timed round trips performed.
    pub iterations: usize,
    /// Round trips whose decrypted message differed from the original plaintext.
    pub failures: usize,
    /// Fraction of round trips that failed.
    pub failure_rate: f64,
    /// Failures among the all-zero / all-0xFF plaintext probes (0 when disabled).
    pub edge_case_failures: usize,
    /// Bit-flipped ciphertexts whose decryption differed from the original plaintext
    /// (0 when disabled). Despite the name, this counts changed decryptions, not panics.
    pub malleability_panics: usize,
    /// Mean key-generation time per iteration, in microseconds.
    pub avg_keygen_us: f64,
    /// Mean encryption time per iteration, in microseconds.
    pub avg_encrypt_us: f64,
    /// Mean decryption time per iteration, in microseconds.
    pub avg_decrypt_us: f64,
    /// Elapsed time of the run, in milliseconds.
    pub total_elapsed_ms: f64,
    /// Heuristic dual-attack cost estimate, in bits.
    pub dual_attack_complexity_bits: f64,
    /// Heuristic primal-attack cost estimate, in bits.
    pub primal_attack_bits: f64,
    /// Heuristic key-recovery cost estimate, in bits.
    pub key_recovery_bits: f64,
    /// Side-channel index reported by the security estimate.
    pub side_channel_index: f64,
    /// Largest absolute centred decryption-noise coefficient observed.
    pub observed_noise_max: f64,
    /// Failure probability reported by the parameter validator.
    pub theoretical_failure_prob: f64,
    /// Validator's expected failures per 500 trials, scaled to `iterations`.
    pub expected_failures: f64,
}
/// Complete report produced by `FuzzEngine::run`.
///
/// Serialized with camelCase field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FuzzResult {
    /// Copy of the configuration that produced this report.
    pub config: FuzzConfig,
    /// Timing, failure, noise and security measurements.
    pub stats: FuzzStats,
    /// Result of the static parameter validation.
    pub validation: ValidationResult,
    /// Overall pass/fail classification.
    pub verdict: Verdict,
    /// Human-readable advice derived from `stats` and `validation`.
    pub recommendations: Vec<String>,
}
/// Overall outcome of a fuzzing run, as decided by `FuzzEngine::determine_verdict`.
///
/// Serialized by variant name. Fieldless, so it is `Copy` and `Eq`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum Verdict {
    /// No correctness failures and the security checks passed.
    Pass,
    /// Decryption failures were observed, but the security checks passed.
    CorrectnessFail,
    /// Decryption was always correct, but the security checks failed.
    SecurityFail,
    /// Both the correctness and the security checks failed.
    CriticalFail,
}
/// Runs fuzzing campaigns against `RingLwe` for a fixed [`FuzzConfig`].
#[derive(Clone, Debug)]
pub struct FuzzEngine {
    /// Campaign parameters; cloned into every `FuzzResult`.
    config: FuzzConfig,
}
impl FuzzEngine {
    /// Minimum dual-attack cost, in bits, accepted as secure.
    const MIN_SECURITY_BITS: f64 = 80.0;

    /// Creates an engine that will run the given configuration.
    pub fn new(config: FuzzConfig) -> Self {
        Self { config }
    }

    /// Runs the full fuzzing campaign described by the configuration.
    ///
    /// Steps:
    /// 1. validate the parameters with `ParamValidator` (method chosen by `noise_type`);
    /// 2. build a `RingLwe` instance and warm it up;
    /// 3. time `iterations` key generations, encryptions and decryptions on random messages;
    /// 4. record the largest centred decryption-noise coefficient and count round-trip failures;
    /// 5. optionally run the edge-case and ciphertext-malleability probes;
    /// 6. derive security estimates, a verdict and recommendations.
    ///
    /// With `iterations == 0`, the per-iteration averages and the failure rate are reported
    /// as `0.0` rather than `NaN`.
    pub fn run(&self) -> FuzzResult {
        let start = Instant::now();
        let iterations = self.config.iterations;

        let validation = match self.config.noise_type {
            NoiseType::Uniform => ParamValidator::validate_lwe(
                self.config.n,
                self.config.q,
                self.config.noise_param,
            ),
            NoiseType::Gaussian => ParamValidator::validate_gaussian(
                self.config.n,
                self.config.q,
                self.config.noise_param,
            ),
            // The CBD validator takes an integer parameter; the fractional part is dropped.
            NoiseType::Cbd => ParamValidator::validate_kyber(
                self.config.n,
                self.config.q,
                self.config.noise_param as usize,
            ),
        };

        // NOTE(review): `q as i32` silently wraps moduli above `i32::MAX`; confirm callers
        // never pass such values.
        let rlwe = RingLwe::with_type(
            self.config.n,
            self.config.q as i32,
            self.config.noise_param,
            self.config.noise_type == NoiseType::Cbd,
        );

        // Warm-up round trips so the timed loops below do not pay cold-cache costs.
        for _ in 0..10 {
            let (seed, t_ntt, s_ntt) = rlwe.keygen();
            let (u_ntt, v) = rlwe.encrypt(&seed, &t_ntt, &[0u8; 32]);
            let _ = rlwe.decrypt(&s_ntt, &u_ntt, &v);
        }

        // Random 32-byte messages, generated up front so RNG cost is excluded from timings.
        let ptxts: Vec<[u8; 32]> = (0..iterations)
            .map(|_| {
                let mut ptxt = [0u8; 32];
                for byte in ptxt.iter_mut() {
                    *byte = rand::random::<u8>();
                }
                ptxt
            })
            .collect();

        let t0 = Instant::now();
        let keys: Vec<_> = (0..iterations).map(|_| rlwe.keygen()).collect();
        let avg_keygen_us = Self::per_iteration_us(t0.elapsed(), iterations);

        let t1 = Instant::now();
        let ciphers: Vec<_> = keys
            .iter()
            .zip(&ptxts)
            .map(|((seed, t_ntt, _), ptxt)| rlwe.encrypt(seed, t_ntt, ptxt))
            .collect();
        let avg_encrypt_us = Self::per_iteration_us(t1.elapsed(), iterations);

        let t2 = Instant::now();
        let ptxts_dec: Vec<_> = keys
            .iter()
            .zip(&ciphers)
            .map(|((_, _, s_ntt), (u_ntt, v))| rlwe.decrypt(s_ntt, u_ntt, v))
            .collect();
        let avg_decrypt_us = Self::per_iteration_us(t2.elapsed(), iterations);

        // Working modulus as i32; a 1-bit is encoded as `rlwe.q / 2` in the check below.
        let q = rlwe.q as i32;
        let half_q = (rlwe.q / 2) as i32;
        let mut failures = 0usize;
        let mut observed_noise_max = 0i32;
        for ((ptxt, ptxt_dec), ((_, _, s_ntt), (u_ntt, v))) in
            ptxts.iter().zip(&ptxts_dec).zip(keys.iter().zip(&ciphers))
        {
            let mut su_prod = rlwe.poly_mul(s_ntt, u_ntt);
            // NOTE(review): presumably `poly_mul` returns an NTT-domain product only for the
            // (n = 256, q = 3329) parameter set — confirm against `RingLwe::poly_mul`.
            if rlwe.n == 256 && rlwe.q == 3329 {
                rlwe.intt(&mut su_prod);
            }
            for j in 0..rlwe.n {
                // v - s·u reduced to its representative in [0, q), whatever its sign.
                let diff = (v[j] as i32 - su_prod[j] as i32).rem_euclid(q);
                // Message bit j; the 32 message bytes wrap around when n > 256.
                let bit = (ptxt[(j / 8) % 32] >> (j % 8)) & 1;
                let m_encoded = i32::from(bit) * half_q;
                // Residual noise, centred into (-q/2, q/2].
                let mut noise = (diff - m_encoded).rem_euclid(q);
                if noise > q / 2 {
                    noise -= q;
                }
                observed_noise_max = observed_noise_max.max(noise.abs());
            }
            if ptxt != ptxt_dec {
                failures += 1;
            }
        }

        let edge_case_failures = if self.config.fuzz_edge_cases {
            // One key pair, two extreme messages: all bits 0, then all bits 1.
            let (seed, t_ntt, s_ntt) = rlwe.keygen();
            let mut fails = 0;
            for msg in &[[0u8; 32], [255u8; 32]] {
                let (u_ntt, v) = rlwe.encrypt(&seed, &t_ntt, msg);
                if rlwe.decrypt(&s_ntt, &u_ntt, &v) != *msg {
                    fails += 1;
                }
            }
            fails
        } else {
            0
        };

        // Counts tampered ciphertexts whose decryption differs from the original message.
        // A genuine panic inside `decrypt` is not caught here and propagates out of `run`.
        let malleability_panics = if self.config.fuzz_ciphertext_malleability {
            let mut changed = 0;
            let (seed, t_ntt, s_ntt) = rlwe.keygen();
            let ptxt = [0u8; 32];
            let (mut u_ntt, mut v) = rlwe.encrypt(&seed, &t_ntt, &ptxt);
            // Flip the low bit of v[0].
            v[0] ^= 1;
            if rlwe.decrypt(&s_ntt, &u_ntt, &v) != ptxt {
                changed += 1;
            }
            // Additionally flip the low bit of u[0] (v keeps its flipped bit).
            u_ntt[0] ^= 1;
            if rlwe.decrypt(&s_ntt, &u_ntt, &v) != ptxt {
                changed += 1;
            }
            changed
        } else {
            0
        };

        let failure_rate = if iterations == 0 {
            0.0
        } else {
            failures as f64 / iterations as f64
        };
        let total_elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
        let (dual, primal, key_rec, side_channel) = self.estimate_security();

        let stats = FuzzStats {
            iterations,
            failures,
            failure_rate,
            edge_case_failures,
            malleability_panics,
            avg_keygen_us,
            avg_encrypt_us,
            avg_decrypt_us,
            total_elapsed_ms,
            dual_attack_complexity_bits: dual,
            primal_attack_bits: primal,
            key_recovery_bits: key_rec,
            side_channel_index: side_channel,
            observed_noise_max: observed_noise_max as f64,
            theoretical_failure_prob: validation.failure_prob,
            expected_failures: validation.expected_failures_in_500 * (iterations as f64 / 500.0),
        };
        let verdict = self.determine_verdict(&stats, &validation);
        let recommendations = self.build_recommendations(&stats, &validation);
        FuzzResult {
            config: self.config.clone(),
            stats,
            validation,
            verdict,
            recommendations,
        }
    }

    /// Mean wall-clock time per iteration in microseconds; `0.0` when `iterations` is zero.
    fn per_iteration_us(elapsed: std::time::Duration, iterations: usize) -> f64 {
        if iterations == 0 {
            0.0
        } else {
            elapsed.as_secs_f64() * 1e6 / iterations as f64
        }
    }

    /// Returns `true` when `bits` is below the security floor, or is `NaN`.
    fn security_too_low(bits: f64) -> bool {
        bits.is_nan() || bits < Self::MIN_SECURITY_BITS
    }

    /// Returns `(dual, primal, key_recovery, side_channel)` security estimates, in bits.
    ///
    /// The three Kyber-sized parameter sets (n ∈ {512, 768, 1024}, q = 3329) use hard-coded
    /// Core-SVP/BKZ figures. Any other parameters use a rough scaling of the 118-bit
    /// n = 512 figure, clamped to be non-negative. `primal` and `key_recovery` are fixed
    /// fractions (0.9 and 0.8) of `dual`, and the side-channel index is always `0.0`.
    /// These are heuristics, not the output of a lattice estimator.
    fn estimate_security(&self) -> (f64, f64, f64, f64) {
        let dual = match (self.config.n, self.config.q) {
            (512, 3329) => 118.0,
            (768, 3329) => 182.0,
            (1024, 3329) => 255.0,
            (n, q) => {
                let sigma = self.config.noise_param.max(0.1);
                // 10.7 ≈ log2(3329 / 2), so n = 512, q = 3329, σ = 2 maps to ~118 bits.
                let raw = (n as f64 / 512.0) * 118.0 * (q as f64 / sigma).log2() / 10.7;
                // q < σ gives a negative logarithm and n = q = 0 gives NaN (0 · −∞); both are
                // clamped to 0 bits (`f64::max` returns the non-NaN operand).
                raw.max(0.0)
            }
        };
        (dual, dual * 0.9, dual * 0.8, 0.0)
    }

    /// Classifies the run.
    ///
    /// Correctness fails on any round-trip or edge-case failure. Security fails when the
    /// dual-attack estimate is below 80 bits (or `NaN`), or when the validator rejects the
    /// parameters. Malleability results do not affect the verdict.
    fn determine_verdict(&self, stats: &FuzzStats, val: &ValidationResult) -> Verdict {
        let correctness_fail = stats.failures > 0 || stats.edge_case_failures > 0;
        let security_fail = Self::security_too_low(stats.dual_attack_complexity_bits) || !val.passes;
        match (correctness_fail, security_fail) {
            (true, true) => Verdict::CriticalFail,
            (true, false) => Verdict::CorrectnessFail,
            (false, true) => Verdict::SecurityFail,
            (false, false) => Verdict::Pass,
        }
    }

    /// Builds human-readable advice for every failing criterion.
    ///
    /// The "all passed" message is emitted only when the verdict is `Verdict::Pass`, so it
    /// never contradicts a security or edge-case warning.
    fn build_recommendations(&self, stats: &FuzzStats, val: &ValidationResult) -> Vec<String> {
        let mut recs = Vec::new();
        if stats.failures > 0 {
            recs.push(format!("⚠️ {} failures. Noise {:.1} ≥ q/4. Reduce beta or increase q.", stats.failures, val.noise_budget_rms));
        }
        if stats.edge_case_failures > 0 {
            recs.push(format!(
                "⚠️ {} edge-case failures (all-zero / all-0xFF plaintexts).",
                stats.edge_case_failures
            ));
        }
        if Self::security_too_low(stats.dual_attack_complexity_bits) {
            recs.push(format!("🔴 Security critically low ({:.0} bits).", stats.dual_attack_complexity_bits));
        }
        if !val.passes {
            recs.push("🔴 Parameter validation failed.".to_string());
        }
        if self.determine_verdict(stats, val) == Verdict::Pass {
            recs.push("✅ All trials passed. Parameters look sound.".to_string());
        }
        recs
    }
}